Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
194
custom_nodes/ComfyUI-segment-anything-2/load_model.py
Normal file
194
custom_nodes/ComfyUI-segment-anything-2/load_model.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import yaml
|
||||
from .sam2.modeling.sam2_base import SAM2Base
|
||||
from .sam2.modeling.backbones.image_encoder import ImageEncoder
|
||||
from .sam2.modeling.backbones.hieradet import Hiera
|
||||
from .sam2.modeling.backbones.image_encoder import FpnNeck
|
||||
from .sam2.modeling.position_encoding import PositionEmbeddingSine
|
||||
from .sam2.modeling.memory_attention import MemoryAttention, MemoryAttentionLayer
|
||||
from .sam2.modeling.sam.transformer import RoPEAttention
|
||||
from .sam2.modeling.memory_encoder import MemoryEncoder, MaskDownSampler, Fuser, CXBlock
|
||||
|
||||
from .sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
from .sam2.sam2_video_predictor import SAM2VideoPredictor
|
||||
from .sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
||||
from comfy.utils import load_torch_file
|
||||
|
||||
def load_model(model_path, model_cfg_path, segmentor, dtype, device):
|
||||
# Load the YAML configuration
|
||||
with open(model_cfg_path, 'r') as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
# Extract the model configuration
|
||||
model_config = config['model']
|
||||
|
||||
# Instantiate the image encoder components
|
||||
trunk_config = model_config['image_encoder']['trunk']
|
||||
neck_config = model_config['image_encoder']['neck']
|
||||
position_encoding_config = neck_config['position_encoding']
|
||||
|
||||
position_encoding = PositionEmbeddingSine(
|
||||
num_pos_feats=position_encoding_config['num_pos_feats'],
|
||||
normalize=position_encoding_config['normalize'],
|
||||
scale=position_encoding_config['scale'],
|
||||
temperature=position_encoding_config['temperature']
|
||||
)
|
||||
|
||||
neck = FpnNeck(
|
||||
position_encoding=position_encoding,
|
||||
d_model=neck_config['d_model'],
|
||||
backbone_channel_list=neck_config['backbone_channel_list'],
|
||||
fpn_top_down_levels=neck_config['fpn_top_down_levels'],
|
||||
fpn_interp_model=neck_config['fpn_interp_model']
|
||||
)
|
||||
|
||||
keys_to_include = ['embed_dim', 'num_heads', 'global_att_blocks', 'window_pos_embed_bkg_spatial_size', 'stages']
|
||||
trunk_kwargs = {key: trunk_config[key] for key in keys_to_include if key in trunk_config}
|
||||
trunk = Hiera(**trunk_kwargs)
|
||||
|
||||
image_encoder = ImageEncoder(
|
||||
scalp=model_config['image_encoder']['scalp'],
|
||||
trunk=trunk,
|
||||
neck=neck
|
||||
)
|
||||
# Instantiate the memory attention components
|
||||
memory_attention_layer_config = config['model']['memory_attention']['layer']
|
||||
self_attention_config = memory_attention_layer_config['self_attention']
|
||||
cross_attention_config = memory_attention_layer_config['cross_attention']
|
||||
|
||||
self_attention = RoPEAttention(
|
||||
rope_theta=self_attention_config['rope_theta'],
|
||||
feat_sizes=self_attention_config['feat_sizes'],
|
||||
embedding_dim=self_attention_config['embedding_dim'],
|
||||
num_heads=self_attention_config['num_heads'],
|
||||
downsample_rate=self_attention_config['downsample_rate'],
|
||||
dropout=self_attention_config['dropout']
|
||||
)
|
||||
|
||||
cross_attention = RoPEAttention(
|
||||
rope_theta=cross_attention_config['rope_theta'],
|
||||
feat_sizes=cross_attention_config['feat_sizes'],
|
||||
rope_k_repeat=cross_attention_config['rope_k_repeat'],
|
||||
embedding_dim=cross_attention_config['embedding_dim'],
|
||||
num_heads=cross_attention_config['num_heads'],
|
||||
downsample_rate=cross_attention_config['downsample_rate'],
|
||||
dropout=cross_attention_config['dropout'],
|
||||
kv_in_dim=cross_attention_config['kv_in_dim']
|
||||
)
|
||||
|
||||
memory_attention_layer = MemoryAttentionLayer(
|
||||
activation=memory_attention_layer_config['activation'],
|
||||
dim_feedforward=memory_attention_layer_config['dim_feedforward'],
|
||||
dropout=memory_attention_layer_config['dropout'],
|
||||
pos_enc_at_attn=memory_attention_layer_config['pos_enc_at_attn'],
|
||||
self_attention=self_attention,
|
||||
d_model=memory_attention_layer_config['d_model'],
|
||||
pos_enc_at_cross_attn_keys=memory_attention_layer_config['pos_enc_at_cross_attn_keys'],
|
||||
pos_enc_at_cross_attn_queries=memory_attention_layer_config['pos_enc_at_cross_attn_queries'],
|
||||
cross_attention=cross_attention
|
||||
)
|
||||
|
||||
memory_attention = MemoryAttention(
|
||||
d_model=config['model']['memory_attention']['d_model'],
|
||||
pos_enc_at_input=config['model']['memory_attention']['pos_enc_at_input'],
|
||||
layer=memory_attention_layer,
|
||||
num_layers=config['model']['memory_attention']['num_layers']
|
||||
)
|
||||
|
||||
# Instantiate the memory encoder components
|
||||
memory_encoder_config = config['model']['memory_encoder']
|
||||
position_encoding_mem_enc_config = memory_encoder_config['position_encoding']
|
||||
mask_downsampler_config = memory_encoder_config['mask_downsampler']
|
||||
fuser_layer_config = memory_encoder_config['fuser']['layer']
|
||||
|
||||
position_encoding_mem_enc = PositionEmbeddingSine(
|
||||
num_pos_feats=position_encoding_mem_enc_config['num_pos_feats'],
|
||||
normalize=position_encoding_mem_enc_config['normalize'],
|
||||
scale=position_encoding_mem_enc_config['scale'],
|
||||
temperature=position_encoding_mem_enc_config['temperature']
|
||||
)
|
||||
|
||||
mask_downsampler = MaskDownSampler(
|
||||
kernel_size=mask_downsampler_config['kernel_size'],
|
||||
stride=mask_downsampler_config['stride'],
|
||||
padding=mask_downsampler_config['padding']
|
||||
)
|
||||
|
||||
fuser_layer = CXBlock(
|
||||
dim=fuser_layer_config['dim'],
|
||||
kernel_size=fuser_layer_config['kernel_size'],
|
||||
padding=fuser_layer_config['padding'],
|
||||
layer_scale_init_value=float(fuser_layer_config['layer_scale_init_value'])
|
||||
)
|
||||
fuser = Fuser(
|
||||
num_layers=memory_encoder_config['fuser']['num_layers'],
|
||||
layer=fuser_layer
|
||||
)
|
||||
|
||||
memory_encoder = MemoryEncoder(
|
||||
position_encoding=position_encoding_mem_enc,
|
||||
mask_downsampler=mask_downsampler,
|
||||
fuser=fuser,
|
||||
out_dim=memory_encoder_config['out_dim']
|
||||
)
|
||||
|
||||
sam_mask_decoder_extra_args = {
|
||||
"dynamic_multimask_via_stability": True,
|
||||
"dynamic_multimask_stability_delta": 0.05,
|
||||
"dynamic_multimask_stability_thresh": 0.98,
|
||||
}
|
||||
|
||||
def initialize_model(model_class, model_config, segmentor, image_encoder, memory_attention, memory_encoder, sam_mask_decoder_extra_args, dtype, device):
|
||||
return model_class(
|
||||
image_encoder=image_encoder,
|
||||
memory_attention=memory_attention,
|
||||
memory_encoder=memory_encoder,
|
||||
sam_mask_decoder_extra_args=sam_mask_decoder_extra_args,
|
||||
num_maskmem=model_config['num_maskmem'],
|
||||
image_size=model_config['image_size'],
|
||||
sigmoid_scale_for_mem_enc=model_config['sigmoid_scale_for_mem_enc'],
|
||||
sigmoid_bias_for_mem_enc=model_config['sigmoid_bias_for_mem_enc'],
|
||||
use_mask_input_as_output_without_sam=model_config['use_mask_input_as_output_without_sam'],
|
||||
directly_add_no_mem_embed=model_config['directly_add_no_mem_embed'],
|
||||
use_high_res_features_in_sam=model_config['use_high_res_features_in_sam'],
|
||||
multimask_output_in_sam=model_config['multimask_output_in_sam'],
|
||||
iou_prediction_use_sigmoid=model_config['iou_prediction_use_sigmoid'],
|
||||
use_obj_ptrs_in_encoder=model_config['use_obj_ptrs_in_encoder'],
|
||||
add_tpos_enc_to_obj_ptrs=model_config['add_tpos_enc_to_obj_ptrs'],
|
||||
only_obj_ptrs_in_the_past_for_eval=model_config['only_obj_ptrs_in_the_past_for_eval'],
|
||||
pred_obj_scores=model_config['pred_obj_scores'],
|
||||
pred_obj_scores_mlp=model_config['pred_obj_scores_mlp'],
|
||||
fixed_no_obj_ptr=model_config['fixed_no_obj_ptr'],
|
||||
multimask_output_for_tracking=model_config['multimask_output_for_tracking'],
|
||||
use_multimask_token_for_obj_ptr=model_config['use_multimask_token_for_obj_ptr'],
|
||||
compile_image_encoder=model_config['compile_image_encoder'],
|
||||
multimask_min_pt_num=model_config['multimask_min_pt_num'],
|
||||
multimask_max_pt_num=model_config['multimask_max_pt_num'],
|
||||
use_mlp_for_obj_ptr_proj=model_config['use_mlp_for_obj_ptr_proj'],
|
||||
proj_tpos_enc_in_obj_ptrs=model_config['proj_tpos_enc_in_obj_ptrs'],
|
||||
no_obj_embed_spatial=model_config['no_obj_embed_spatial'],
|
||||
use_signed_tpos_enc_to_obj_ptrs=model_config['use_signed_tpos_enc_to_obj_ptrs'],
|
||||
binarize_mask_from_pts_for_mem_enc=True if segmentor == 'video' else False,
|
||||
).to(dtype).to(device).eval()
|
||||
|
||||
# Load the state dictionary
|
||||
sd = load_torch_file(model_path)
|
||||
|
||||
# Initialize model based on segmentor type
|
||||
if segmentor == 'single_image':
|
||||
model_class = SAM2Base
|
||||
model = initialize_model(model_class, model_config, segmentor, image_encoder, memory_attention, memory_encoder, sam_mask_decoder_extra_args, dtype, device)
|
||||
model.load_state_dict(sd)
|
||||
model = SAM2ImagePredictor(model)
|
||||
elif segmentor == 'video':
|
||||
model_class = SAM2VideoPredictor
|
||||
model = initialize_model(model_class, model_config, segmentor, image_encoder, memory_attention, memory_encoder, sam_mask_decoder_extra_args, dtype, device)
|
||||
model.load_state_dict(sd)
|
||||
elif segmentor == 'automaskgenerator':
|
||||
model_class = SAM2Base
|
||||
model = initialize_model(model_class, model_config, segmentor, image_encoder, memory_attention, memory_encoder, sam_mask_decoder_extra_args, dtype, device)
|
||||
model.load_state_dict(sd)
|
||||
model = SAM2AutomaticMaskGenerator(model)
|
||||
else:
|
||||
raise ValueError(f"Segmentor {segmentor} not supported")
|
||||
|
||||
return model
|
||||
Reference in New Issue
Block a user