Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from ..modeling import Sam
|
||||
from .amg import calculate_stability_score
|
||||
|
||||
|
||||
class SamOnnxModel(nn.Module):
|
||||
"""
|
||||
This model should not be called directly, but is used in ONNX export.
|
||||
It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
|
||||
with some functions modified to enable model tracing. Also supports extra
|
||||
options controlling what information. See the ONNX export script for details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Sam,
|
||||
return_single_mask: bool,
|
||||
use_stability_score: bool = False,
|
||||
return_extra_metrics: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.mask_decoder = model.mask_decoder
|
||||
self.model = model
|
||||
self.img_size = model.image_encoder.img_size
|
||||
self.return_single_mask = return_single_mask
|
||||
self.use_stability_score = use_stability_score
|
||||
self.stability_score_offset = 1.0
|
||||
self.return_extra_metrics = return_extra_metrics
|
||||
|
||||
@staticmethod
|
||||
def resize_longest_image_size(
|
||||
input_image_size: torch.Tensor, longest_side: int
|
||||
) -> torch.Tensor:
|
||||
input_image_size = input_image_size.to(torch.float32)
|
||||
scale = longest_side / torch.max(input_image_size)
|
||||
transformed_size = scale * input_image_size
|
||||
transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
|
||||
return transformed_size
|
||||
|
||||
def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
|
||||
point_coords = point_coords + 0.5
|
||||
point_coords = point_coords / self.img_size
|
||||
point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
|
||||
point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
|
||||
|
||||
point_embedding = point_embedding * (point_labels != -1)
|
||||
point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
|
||||
point_labels == -1
|
||||
)
|
||||
|
||||
for i in range(self.model.prompt_encoder.num_point_embeddings):
|
||||
point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
|
||||
i
|
||||
].weight * (point_labels == i)
|
||||
|
||||
return point_embedding
|
||||
|
||||
def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
|
||||
mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
|
||||
mask_embedding = mask_embedding + (
|
||||
1 - has_mask_input
|
||||
) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
|
||||
return mask_embedding
|
||||
|
||||
def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
|
||||
masks = F.interpolate(
|
||||
masks,
|
||||
size=(self.img_size, self.img_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
|
||||
masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
|
||||
|
||||
orig_im_size = orig_im_size.to(torch.int64)
|
||||
h, w = orig_im_size[0], orig_im_size[1]
|
||||
masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
|
||||
return masks
|
||||
|
||||
def select_masks(
|
||||
self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Determine if we should return the multiclick mask or not from the number of points.
|
||||
# The reweighting is used to avoid control flow.
|
||||
score_reweight = torch.tensor(
|
||||
[[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
|
||||
).to(iou_preds.device)
|
||||
score = iou_preds + (num_points - 2.5) * score_reweight
|
||||
best_idx = torch.argmax(score, dim=1)
|
||||
masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
|
||||
iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
|
||||
|
||||
return masks, iou_preds
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
point_coords: torch.Tensor,
|
||||
point_labels: torch.Tensor,
|
||||
mask_input: torch.Tensor,
|
||||
has_mask_input: torch.Tensor,
|
||||
orig_im_size: torch.Tensor,
|
||||
):
|
||||
sparse_embedding = self._embed_points(point_coords, point_labels)
|
||||
dense_embedding = self._embed_masks(mask_input, has_mask_input)
|
||||
|
||||
masks, scores = self.model.mask_decoder.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embedding,
|
||||
dense_prompt_embeddings=dense_embedding,
|
||||
)
|
||||
|
||||
if self.use_stability_score:
|
||||
scores = calculate_stability_score(
|
||||
masks, self.model.mask_threshold, self.stability_score_offset
|
||||
)
|
||||
|
||||
if self.return_single_mask:
|
||||
masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
|
||||
|
||||
upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
|
||||
|
||||
if self.return_extra_metrics:
|
||||
stability_scores = calculate_stability_score(
|
||||
upscaled_masks, self.model.mask_threshold, self.stability_score_offset
|
||||
)
|
||||
areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
|
||||
return upscaled_masks, scores, stability_scores, areas, masks
|
||||
|
||||
return upscaled_masks, scores, masks
|
||||
Reference in New Issue
Block a user