Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled

Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 00:55:26 +00:00
parent 2b70ab9ad0
commit f09734b0ee
2274 changed files with 748556 additions and 3 deletions

View File

@@ -0,0 +1,270 @@
# File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig
# Created by https://www.toptal.com/developers/gitignore/api/windows,visualstudiocode,macos,python,venv
# Edit at https://www.toptal.com/developers/gitignore?templates=windows,visualstudiocode,macos,python,venv
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### macOS Patch ###
# iCloud generated files
*.icloud
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
### venv ###
# Virtualenv
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
[Bb]in
[Ii]nclude
[Ll]ib
[Ll]ib64
[Ll]ocal
[Ss]cripts
pyvenv.cfg
pip-selfcheck.json
### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
# End of https://www.toptal.com/developers/gitignore/api/windows,visualstudiocode,macos,python,venv
# Custom rules (everything added below won't be overridden by 'Generate .gitignore File' if you use 'Update' option)
ref/

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 pamparamm
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,86 @@
# Various Guidance implementations for ComfyUI / SD WebUI (reForge)
Implementation of
- Perturbed-Attention Guidance from [Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance (D. Ahn et al.)](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/)
- [Smoothed Energy Guidance: Guiding Diffusion Models with Reduced Energy Curvature of Attention (Susung Hong)](https://arxiv.org/abs/2408.00760)
- Sliding Window Guidance from [The Unreasonable Effectiveness of Guidance for Diffusion Models (Kaiser et al.)](https://arxiv.org/abs/2411.10257)
- [PLADIS: Pushing the Limits of Attention in Diffusion Models at Inference Time by Leveraging Sparsity](https://cubeyoung.github.io/pladis-proejct/) (ComfyUI-only)
- [Normalized Attention Guidance: Universal Negative Guidance for Diffusion Models](https://arxiv.org/abs/2505.21179) (ComfyUI-only, has a description inside ComfyUI)
- [Token Perturbation Guidance for Diffusion Models](https://arxiv.org/abs/2506.10036) (ComfyUI-only)
as an extension for [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [SD WebUI (reForge)](https://github.com/Panchovix/stable-diffusion-webui-reForge).
Works with SD1.5 and SDXL.
## Installation
### ComfyUI
You can either:
- `git clone https://github.com/pamparamm/sd-perturbed-attention.git` into `ComfyUI/custom_nodes/` folder.
- Install it via [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) (search for custom node named "Perturbed-Attention Guidance").
- Install it via [comfy-cli](https://comfydocs.org/comfy-cli/getting-started) with `comfy node registry-install sd-perturbed-attention`
### SD WebUI (reForge)
`git clone https://github.com/pamparamm/sd-perturbed-attention.git` into `stable-diffusion-webui-forge/extensions/` folder.
### SD WebUI (Auto1111)
As an alternative for A1111 WebUI you can use PAG implementation from [sd-webui-incantations](https://github.com/v0xie/sd-webui-incantations) extension.
## Guidance Nodes/Scripts
### ComfyUI
![comfyui-node-pag-basic](res/comfyui-node-pag-basic.png)
![comfyui-node-pag-advanced](res/comfyui-node-pag-advanced.png)
![comfyui-node-seg](res/comfyui-node-seg.png)
### SD WebUI (reForge)
![forge-pag](res/forge-pag.png)
![forge-seg](res/forge-seg.png)
> [!NOTE]
> You can override `CFG Scale` and `PAG Scale`/`SEG Scale` for Hires. fix by opening/enabling `Override for Hires. fix` tab.
> To disable PAG during Hires. fix, you can set `PAG Scale` under Override to 0.
### Inputs
- `scale`: Guidance scale, higher values can both increase structural coherence of an image and oversaturate/fry it entirely.
- `adaptive_scale` (PAG only): PAG dampening factor, it penalizes PAG during late denoising stages, resulting in overall speedup: 0.0 means no penalty and 1.0 completely removes PAG.
- `blur_sigma` (SEG only): Standard deviation of the Gaussian blur, higher values increase "clarity" of an image. Negative values set `blur_sigma` to infinity.
- `unet_block`: Part of the U-Net to which Guidance is applied; the original paper suggests using `middle`.
- `unet_block_id`: Id of U-Net layer in a selected block to which Guidance is applied. Guidance can be applied only to layers containing Self-attention blocks.
- `sigma_start` / `sigma_end`: Guidance will be active only between `sigma_start` and `sigma_end`. Set both values to negative to disable this feature.
- `rescale`: Acts similar to RescaleCFG node - it prevents over-exposure on high `scale` values. Based on Algorithm 2 from [Common Diffusion Noise Schedules and Sample Steps are Flawed (Lin et al.)](https://arxiv.org/abs/2305.08891). Set to 0 to disable this feature.
- `rescale_mode`:
- `full` - takes into account both CFG and Guidance.
- `partial` - depends only on Guidance.
- `snf` - Saliency-adaptive Noise Fusion from [High-fidelity Person-centric Subject-to-Image Synthesis (Wang et al.)](https://arxiv.org/abs/2311.10329). Should increase image quality on high guidance scales. Ignores `rescale` value.
- `unet_block_list`: Optional input, replaces both `unet_block` and `unet_block_id` and allows you to select multiple U-Net layers separated with commas. SDXL U-Net has multiple indices for layers, you can specify them by using dot symbol (if not specified, Guidance will be applied to the whole layer). Example value: `m0,u0.4` (it applies Guidance to middle block 0 and to output block 0 with index 4)
- In terms of U-Net `d` means `input`, `m` means `middle` and `u` means `output`.
- SD1.5 U-Net has layers `d0`-`d5`, `m0`, `u0`-`u8`.
- SDXL U-Net has layers `d0`-`d3`, `m0`, `u0`-`u5`. In addition, each block except `d0` and `d1` has `0-9` index values (like `m0.7` or `u0.4`). `d0` and `d1` have `0-1` index values.
- Supports block ranges (`d0-d3` corresponds to `d0,d1,d2,d3`) and index value ranges (`d2.2-9` corresponds to all index values of `d2` with the exclusion of `d2.0` and `d2.1`).
## ComfyUI TensorRT PAG (Experimental)
To use PAG together with [ComfyUI_TensorRT](https://github.com/comfyanonymous/ComfyUI_TensorRT), you'll need to:
0. Have 24GB of VRAM.
1. Build static/dynamic TRT engine of a desired model.
2. Build static/dynamic TRT engine of the same model with the same TRT parameters, but with fixed PAG injection in selected UNET blocks (`TensorRT Attach PAG` node).
3. Use `TensorRT Perturbed-Attention Guidance` node with two model inputs: one for base engine and one for PAG engine.
![trt-engines](res/trt-engines.png)
![trt-inference](res/trt-inference.png)

View File

@@ -0,0 +1,25 @@
from . import nag_nodes, tpg_nodes, pladis_nodes
from .pag_nodes import PerturbedAttention, SlidingWindowGuidanceAdvanced, SmoothedEnergyGuidanceAdvanced
from .pag_trt_nodes import TRTAttachPag, TRTPerturbedAttention
# ComfyUI discovers custom nodes through these two module-level mappings.
# Node identifier -> node class; sub-modules (NAG/TPG/PLADIS) merge in their own entries.
NODE_CLASS_MAPPINGS = {
    "PerturbedAttention": PerturbedAttention,
    "SmoothedEnergyGuidanceAdvanced": SmoothedEnergyGuidanceAdvanced,
    "SlidingWindowGuidanceAdvanced": SlidingWindowGuidanceAdvanced,
    "TRTAttachPag": TRTAttachPag,
    "TRTPerturbedAttention": TRTPerturbedAttention,
    **nag_nodes.NODE_CLASS_MAPPINGS,
    **tpg_nodes.NODE_CLASS_MAPPINGS,
    **pladis_nodes.NODE_CLASS_MAPPINGS,
}
# Node identifier -> human-readable title shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "PerturbedAttention": "Perturbed-Attention Guidance (Advanced)",
    "SmoothedEnergyGuidanceAdvanced": "Smoothed Energy Guidance (Advanced)",
    "SlidingWindowGuidanceAdvanced": "Sliding Window Guidance (Advanced)",
    "TRTAttachPag": "TensorRT Attach PAG",
    "TRTPerturbedAttention": "TensorRT Perturbed-Attention Guidance",
    **nag_nodes.NODE_DISPLAY_NAME_MAPPINGS,
    **tpg_nodes.NODE_DISPLAY_NAME_MAPPINGS,
    **pladis_nodes.NODE_DISPLAY_NAME_MAPPINGS,
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

View File

@@ -0,0 +1,625 @@
{
"id": "319b510b-b5ec-46d6-8605-a6a5fd7d6c6c",
"revision": 0,
"last_node_id": 25,
"last_link_id": 54,
"nodes": [
{
"id": 3,
"type": "KSampler",
"pos": [
1100,
620
],
"size": [
210,
474
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 41
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 35
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 36
},
{
"name": "latent_image",
"type": "LATENT",
"link": 2
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"slot_index": 0,
"links": [
7
]
}
],
"properties": {
"Node name for S&R": "KSampler"
},
"widgets_values": [
0,
"fixed",
25,
7,
"euler",
"sgm_uniform",
1
]
},
{
"id": 5,
"type": "EmptyLatentImage",
"pos": [
580,
810
],
"size": [
210,
106
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"slot_index": 0,
"links": [
2,
47
]
}
],
"properties": {
"Node name for S&R": "EmptyLatentImage"
},
"widgets_values": [
1024,
1024,
1
]
},
{
"id": 8,
"type": "VAEDecode",
"pos": [
1320,
620
],
"size": [
140,
46
],
"flags": {
"collapsed": false
},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 7
},
{
"name": "vae",
"type": "VAE",
"link": 51
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"slot_index": 0,
"links": [
12
]
}
],
"properties": {
"Node name for S&R": "VAEDecode"
},
"widgets_values": []
},
{
"id": 13,
"type": "PreviewImage",
"pos": [
1320,
700
],
"size": [
440,
480
],
"flags": {},
"order": 10,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 12
}
],
"outputs": [],
"properties": {
"Node name for S&R": "PreviewImage"
},
"widgets_values": []
},
{
"id": 18,
"type": "NormalizedAttentionGuidance",
"pos": [
850,
620
],
"size": [
233.67147827148438,
198
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 53
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 40
}
],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
41
]
}
],
"properties": {
"Node name for S&R": "NormalizedAttentionGuidance"
},
"widgets_values": [
4,
0.5,
1,
-1,
10.000000000000002,
""
]
},
{
"id": 19,
"type": "CLIPTextEncode",
"pos": [
400,
530
],
"size": [
390,
100
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 49
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
35,
45
]
}
],
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"elsa \\(frozen\\), portrait,"
]
},
{
"id": 20,
"type": "CLIPTextEncode",
"pos": [
400,
670
],
"size": [
390,
100
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 50
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
36,
40,
46
]
}
],
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"ugly, sketch, blurry, collage, blonde hair, blue eyes,"
]
},
{
"id": 22,
"type": "KSampler",
"pos": [
1100,
10
],
"size": [
210,
474
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 54
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 45
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 46
},
{
"name": "latent_image",
"type": "LATENT",
"link": 47
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"slot_index": 0,
"links": [
43
]
}
],
"properties": {
"Node name for S&R": "KSampler"
},
"widgets_values": [
0,
"fixed",
25,
7,
"euler",
"sgm_uniform",
1
]
},
{
"id": 23,
"type": "PreviewImage",
"pos": [
1320,
90
],
"size": [
440,
480
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 42
}
],
"outputs": [],
"properties": {
"Node name for S&R": "PreviewImage"
},
"widgets_values": []
},
{
"id": 24,
"type": "VAEDecode",
"pos": [
1320,
10
],
"size": [
140,
46
],
"flags": {
"collapsed": false
},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 43
},
{
"name": "vae",
"type": "VAE",
"link": 52
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"slot_index": 0,
"links": [
42
]
}
],
"properties": {
"Node name for S&R": "VAEDecode"
},
"widgets_values": []
},
{
"id": 25,
"type": "CheckpointLoaderSimple",
"pos": [
400,
390
],
"size": [
390,
98
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
53,
54
]
},
{
"name": "CLIP",
"type": "CLIP",
"links": [
49,
50
]
},
{
"name": "VAE",
"type": "VAE",
"links": [
51,
52
]
}
],
"properties": {
"Node name for S&R": "CheckpointLoaderSimple"
},
"widgets_values": [
"sdxl\\base\\sd_xl_base_1.0.safetensors"
]
}
],
"links": [
[
2,
5,
0,
3,
3,
"LATENT"
],
[
7,
3,
0,
8,
0,
"LATENT"
],
[
12,
8,
0,
13,
0,
"IMAGE"
],
[
35,
19,
0,
3,
1,
"CONDITIONING"
],
[
36,
20,
0,
3,
2,
"CONDITIONING"
],
[
40,
20,
0,
18,
1,
"CONDITIONING"
],
[
41,
18,
0,
3,
0,
"MODEL"
],
[
42,
24,
0,
23,
0,
"IMAGE"
],
[
43,
22,
0,
24,
0,
"LATENT"
],
[
45,
19,
0,
22,
1,
"CONDITIONING"
],
[
46,
20,
0,
22,
2,
"CONDITIONING"
],
[
47,
5,
0,
22,
3,
"LATENT"
],
[
49,
25,
1,
19,
0,
"CLIP"
],
[
50,
25,
1,
20,
0,
"CLIP"
],
[
51,
25,
2,
8,
1,
"VAE"
],
[
52,
25,
2,
24,
1,
"VAE"
],
[
53,
25,
0,
18,
0,
"MODEL"
],
[
54,
25,
0,
22,
0,
"MODEL"
]
],
"groups": [],
"config": {},
"extra": {
"frontendVersion": "1.23.0"
},
"version": 0.4
}

View File

@@ -0,0 +1,241 @@
import math
from itertools import groupby
from typing import Any, Callable, Literal
import torch
import torch.nn.functional as F
def parse_unet_blocks(model, unet_block_list: str, attn: Literal["attn1", "attn2"] | None):
output: list[tuple[str, int, int | None]] = []
names: list[str] = []
# Get all Self-attention blocks
input_blocks: list[tuple[int, str]] = []
middle_blocks: list[tuple[int, str]] = []
output_blocks: list[tuple[int, str]] = []
for name, module in model.model.diffusion_model.named_modules():
if module.__class__.__name__ == "BasicTransformerBlock" and (attn is None or hasattr(module, attn)):
parts = name.split(".")
unet_part = parts[0]
block_id = int(parts[1])
if unet_part.startswith("input"):
input_blocks.append((block_id, name))
elif unet_part.startswith("middle"):
middle_blocks.append((block_id - 1, name))
elif unet_part.startswith("output"):
output_blocks.append((block_id, name))
def group_blocks(blocks: list[tuple[int, str]]):
grouped_blocks = [(i, list(gr)) for i, gr in groupby(blocks, lambda b: b[0])]
return [(i, len(gr), list(idx[1] for idx in gr)) for i, gr in grouped_blocks]
input_blocks_gr, middle_blocks_gr, output_blocks_gr = (
group_blocks(input_blocks),
group_blocks(middle_blocks),
group_blocks(output_blocks),
)
user_inputs = [b.strip() for b in unet_block_list.split(",")]
for user_input in user_inputs:
unet_part_s, indices = user_input[0], user_input[1:].split(".")
match unet_part_s:
case "d":
unet_part, unet_group = "input", input_blocks_gr
case "m":
unet_part, unet_group = "middle", middle_blocks_gr
case "u":
unet_part, unet_group = "output", output_blocks_gr
case _:
raise ValueError(f"Block {user_input}: Unknown block prefix {unet_part_s}")
block_index_range = [int(b.strip()) for b in indices[0].split("-")]
block_index_range_start = block_index_range[0]
block_index_range_end = block_index_range[0] if len(block_index_range) != 2 else block_index_range[1]
for block_index in range(block_index_range_start, block_index_range_end + 1):
if block_index < 0 or block_index >= len(unet_group):
raise ValueError(
f"Block {user_input}: Block index in out of range 0 <= {block_index} < {len(unet_group)}"
)
block_group = unet_group[block_index]
block_index_real = block_group[0]
if len(indices) == 1:
output.append((unet_part, block_index_real, None))
names.extend(block_group[2])
else:
transformer_index_range = [int(b.strip()) for b in indices[1].split("-")]
transformer_index_range_start = transformer_index_range[0]
transformer_index_range_end = (
transformer_index_range[0] if len(transformer_index_range) != 2 else transformer_index_range[1]
)
for transformer_index in range(transformer_index_range_start, transformer_index_range_end + 1):
if transformer_index is not None and (transformer_index < 0 or transformer_index >= block_group[1]):
raise ValueError(
f"Block {user_input}: Transformer index in out of range 0 <= {transformer_index} < {block_group[1]}"
)
output.append((unet_part, block_index_real, transformer_index))
names.append(block_group[2][transformer_index])
return output, names
# Copied from https://github.com/comfyanonymous/ComfyUI/blob/719fb2c81d716ce8edd7f1bdc7804ae160a71d3a/comfy/model_patcher.py#L21 for backward compatibility
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
    """Register *patch* as an attention replacement under shallow-copied transformer_options.

    Each nesting level is copied before mutation so previously shared option dicts
    are not modified. Rebinds "transformer_options" on *model_options* and returns it.
    """
    transformer_options = dict(model_options["transformer_options"])
    patches_replace = dict(transformer_options.get("patches_replace", {}))
    named_patches = dict(patches_replace.get(name, {}))
    # Key includes the transformer sub-index only when one was given.
    block = (block_name, number) if transformer_index is None else (block_name, number, transformer_index)
    named_patches[block] = patch
    patches_replace[name] = named_patches
    transformer_options["patches_replace"] = patches_replace
    model_options["transformer_options"] = transformer_options
    return model_options
def set_model_options_value(model_options, key: str, value: Any):
    """Set *key* to *value* in a shallow copy of transformer_options and return model_options."""
    updated = dict(model_options["transformer_options"])
    updated[key] = value
    model_options["transformer_options"] = updated
    return model_options
def perturbed_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, extra_options, mask=None):
    """Perturbed self-attention: the attention map is replaced with identity, so the output is just `v`."""
    del q, k, extra_options, mask  # the PAG perturbation ignores everything but the values
    return v
# Modified 'Algorithm 2 Classifier-Free Guidance with Rescale' from Common Diffusion Noise Schedules and Sample Steps are Flawed (Lin et al.).
def rescale_guidance(
    guidance: torch.Tensor, cond_pred: torch.Tensor, cfg_result: torch.Tensor, rescale=0.0, rescale_mode="full"
):
    """Rescale a guidance term to curb over-exposure at high guidance scales.

    Matches the std of the guided result to the std of the conditional prediction,
    interpolated by *rescale* (0 = passthrough). `rescale_mode="full"` measures the
    complete CFG+guidance result; any other mode measures only cond+guidance.
    """
    if rescale == 0.0:
        return guidance
    base = cfg_result if rescale_mode == "full" else cond_pred
    guidance_result = base + guidance
    per_sample_dims = (1, 2, 3)
    std_ratio = torch.std(cond_pred, dim=per_sample_dims, keepdim=True) / torch.std(
        guidance_result, dim=per_sample_dims, keepdim=True
    )
    # Linear interpolation between unscaled guidance (factor 1) and fully rescaled guidance.
    return guidance * (rescale * std_ratio + (1.0 - rescale))
# Gaussian blur
def gaussian_blur_2d(img, kernel_size, sigma):
    """Depth-wise 2D Gaussian blur with reflect padding.

    The 2D kernel is built as the outer product of a normalized 1D Gaussian and
    applied per channel via a grouped convolution; output shape equals input shape.
    """
    side = img.shape[-1]
    # Clamp the kernel so it never exceeds the image extent (kept odd).
    ksize = min(kernel_size, side - (side % 2 - 1))
    half = (ksize - 1) * 0.5
    grid = torch.linspace(-half, half, steps=ksize)
    density = torch.exp(-0.5 * (grid / sigma).pow(2))
    kernel_1d = (density / density.sum()).to(device=img.device, dtype=img.dtype)
    kernel_2d = torch.mm(kernel_1d[:, None], kernel_1d[None, :])
    channels = img.shape[-3]
    kernel_2d = kernel_2d.expand(channels, 1, kernel_2d.shape[0], kernel_2d.shape[1])
    pad = ksize // 2
    padded = F.pad(img, [pad, pad, pad, pad], mode="reflect")
    return F.conv2d(padded, kernel_2d, groups=channels)
def seg_attention_wrapper(attention, blur_sigma=1.0):
    """Wrap an attention function so queries are Gaussian-blurred first (Smoothed Energy Guidance)."""

    def seg_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, extra_options, mask=None):
        """Smoothed Energy Guidance self-attention"""
        heads = extra_options["n_heads"]
        bs, area, inner_dim = q.shape
        height_orig, width_orig = extra_options["original_shape"][2:4]
        aspect_ratio = width_orig / height_orig
        # Recover the 2D spatial layout of the flattened query tokens from the aspect ratio.
        if aspect_ratio >= 1.0:
            rows = round((area / aspect_ratio) ** 0.5)
            q_2d = q.permute(0, 2, 1).reshape(bs, inner_dim, rows, -1)
        else:
            cols = round((area * aspect_ratio) ** 0.5)
            q_2d = q.permute(0, 2, 1).reshape(bs, inner_dim, -1, cols)
        if blur_sigma >= 0:
            # Odd kernel spanning roughly +-3 sigma.
            kernel_size = math.ceil(6 * blur_sigma) + 1 - math.ceil(6 * blur_sigma) % 2
            q_2d = gaussian_blur_2d(q_2d, kernel_size, blur_sigma)
        else:
            # Negative sigma = infinite blur: collapse queries to their spatial mean.
            q_2d[:] = q_2d.mean(dim=(-2, -1), keepdim=True)
        q_flat = q_2d.reshape(bs, inner_dim, -1).permute(0, 2, 1)
        return attention(q_flat, k, v, heads=heads)

    return seg_attention
# Modified algorithm from 2411.10257 'The Unreasonable Effectiveness of Guidance for Diffusion Models' (Figure 6.)
def swg_pred_calc(
    x: torch.Tensor, tile_width: int, tile_height: int, tile_overlap: int, calc_func: Callable[..., tuple[torch.Tensor]]
):
    """Sliding-window prediction: run *calc_func* on overlapping tiles of *x* and average the overlaps.

    Each tile of size (tile_height + tile_overlap, tile_width + tile_overlap) is
    predicted independently; per-pixel results are summed and normalized by how
    many tiles covered that pixel.
    """
    _, _, h, w = x.shape
    accum = torch.zeros_like(x)
    coverage = torch.zeros_like(x)
    n_cols = math.ceil(w / (tile_width - tile_overlap))
    n_rows = math.ceil(h / (tile_height - tile_overlap))
    for col in range(n_cols):
        left = tile_width * col
        right = left + tile_width + tile_overlap
        for row in range(n_rows):
            top = tile_height * row
            bottom = top + tile_height + tile_overlap
            window = x[:, :, top:bottom, left:right]
            # Tiles past the image edge collapse to zero-size slices; skip them.
            if 0 in window.shape[-2:]:
                continue
            pred = calc_func(x_in=window)[0]
            accum[:, :, top:bottom, left:right] += pred
            coverage[:, :, top:bottom, left:right] += torch.ones_like(pred)
    return accum / coverage
# Saliency-adaptive Noise Fusion based on High-fidelity Person-centric Subject-to-Image Synthesis (Wang et al.)
# https://github.com/CodeGoat24/Face-diffuser/blob/edff1a5178ac9984879d9f5e542c1d0f0059ca5f/facediffuser/pipeline.py#L535-L562
def snf_guidance(t_guidance: torch.Tensor, s_guidance: torch.Tensor):
    """Saliency-adaptive Noise Fusion: per element, keep the guidance whose smoothed saliency dominates.

    Saliency is the Gaussian-blurred magnitude of each guidance tensor,
    softmax-normalized over the spatial extent of every channel.
    """
    b, c, h, w = t_guidance.shape
    saliencies = [gaussian_blur_2d(torch.abs(g), 3, 1) for g in (t_guidance, s_guidance)]
    softmaxes = [torch.softmax(s.reshape(b * c, h * w), dim=1).reshape(b, c, h, w) for s in saliencies]
    stacked_guidance = torch.stack([t_guidance, s_guidance], dim=0)
    stacked_saliency = torch.stack(softmaxes, dim=0)
    # Element-wise winner between the two guidance sources (0 = t, 1 = s).
    winner = torch.argmax(stacked_saliency, dim=0, keepdim=True)
    return torch.gather(stacked_guidance, dim=0, index=winner).squeeze(0)

View File

@@ -0,0 +1,235 @@
from contextlib import suppress
from typing import Callable
import torch
import comfy.model_management
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from comfy.ldm.modules.attention import BasicTransformerBlock, CrossAttention, optimized_attention
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from .guidance_utils import parse_unet_blocks
# Batch-slot markers matching ComfyUI's `cond_or_uncond` layout: each entry of that
# list says whether the corresponding chunk of the batch is conditional (COND) or
# unconditional (UNCOND).
COND = 0
UNCOND = 1
def nag_attn2_replace_wrapper(
    nag_scale: float,
    tau: float,
    alpha: float,
    sigma_start: float,
    sigma_end: float,
    k_neg: torch.Tensor,
    v_neg: torch.Tensor,
    prev_attn2_replace: Callable | None = None,
):
    """Build an attn2 (cross-attention) replacement implementing Normalized Attention Guidance.

    Args:
        nag_scale: Guidance strength; 0 disables NAG entirely.
        tau: Normalization threshold clamping the L1-norm ratio between the guided
            and the positive attention outputs.
        alpha: Blend factor between the NAG result (alpha=1) and the original
            positive attention output (alpha=0).
        sigma_start: NAG is active only while sigma_end < sigma <= sigma_start.
        sigma_end: Lower bound of the active sigma window.
        k_neg: Pre-projected negative-prompt key tensor (single batch entry;
            repeated to match the conditional batch size at call time).
        v_neg: Pre-projected negative-prompt value tensor (same layout as k_neg).
        prev_attn2_replace: Previously installed attn2 replacement (e.g. IPAdapter)
            to chain with, so NAG composes with other model patches.

    Returns:
        The attn2 replacement callable to install via set_model_attn2_replace.
    """
    # Modified Algorithm 1 from 2505.21179 'Normalized Attention Guidance: Universal Negative Guidance for Diffusion Models'
    def nag_attn2_replace(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, extra_options: dict):
        heads = extra_options["n_heads"]
        attn_precision = extra_options.get("attn_precision")
        sigma = extra_options["sigmas"]
        cond_or_uncond: list[int] = extra_options.get("cond_or_uncond")  # type: ignore
        # Perform batched CA
        # (delegates to the chained replacement when one exists, so earlier patches still run)
        z = (
            optimized_attention(q, k, v, heads, attn_precision)
            if prev_attn2_replace is None
            else prev_attn2_replace(q, k, v, extra_options)
        )
        # Skip NAG when disabled, outside the sigma window, or when the batch has no conditional part.
        if nag_scale == 0 or not (sigma_end < sigma[0] <= sigma_start) or COND not in cond_or_uncond:
            return z
        # Total conditional batch size: samples per cond_or_uncond entry times number of COND entries.
        bs = q.shape[0] // len(cond_or_uncond) * cond_or_uncond.count(COND)
        k_neg_, v_neg_ = k_neg.repeat_interleave(bs, dim=0), v_neg.repeat_interleave(bs, dim=0)
        # Get conditional queries for NAG
        # Assume that cond_or_uncond has a layout [1, 1..., 0, 0...]
        q_chunked = q.chunk(len(cond_or_uncond))
        q_pos = torch.cat(q_chunked[cond_or_uncond.index(COND) :])
        # Apply NAG only to conditional parts of batched CA
        z_chunked = z.chunk(len(cond_or_uncond))
        z_pos = torch.cat(z_chunked[cond_or_uncond.index(COND) :])
        # Cross-attention against the negative prompt's keys/values.
        z_neg = optimized_attention(q_pos, k_neg_, v_neg_, heads, attn_precision)
        # Extrapolate away from the negative attention output.
        z_tilde = z_pos + nag_scale * (z_pos - z_neg)
        norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True)
        norm_tilde = torch.norm(z_tilde, p=1, dim=-1, keepdim=True)
        ratio = norm_tilde / norm_pos
        # Clamp the L1-norm growth to tau, then blend with the positive output by alpha.
        z_hat = torch.where(ratio > tau, tau, ratio) / ratio * z_tilde
        z_nag = alpha * z_hat + (1 - alpha) * z_pos
        # Prepend unconditional CA result to NAG result
        if UNCOND in cond_or_uncond:
            z_nag = torch.cat(z_chunked[cond_or_uncond.index(UNCOND) : cond_or_uncond.index(COND)] + (z_nag,))
        return z_nag
    return nag_attn2_replace
class NormalizedAttentionGuidance(ComfyNodeABC):
    """Normalized Attention Guidance (NAG).

    Applies a negative prompt directly inside the cross-attention (attn2)
    layers of the diffusion model, so it works alongside CFG/PAG and with
    guidance- or step-distilled models (see DESCRIPTION).
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        # sigma_start/sigma_end of -1.0 mean "unbounded"; patch() maps a
        # negative sigma_start to +inf.
        return {
            "required": {
                "model": (
                    IO.MODEL,
                    {
                        "tooltip": (
                            "The diffusion model.\n"
                            "If you are using any other attn2 replacer (such as `IPAdapter`), you should place this node after it."
                        )
                    },
                ),
                "negative": (
                    IO.CONDITIONING,
                    {"tooltip": "Negative conditioning: either the one you use for CFG or a completely different one."},
                ),
                "scale": (
                    IO.FLOAT,
                    {
                        "default": 2.0,
                        "min": 0.0,
                        "max": 100.0,
                        "step": 0.1,
                        "round": 0.01,
                        "tooltip": "Scale of NAG, does nothing when `tau=0`.",
                    },
                ),
                "tau": (
                    IO.FLOAT,
                    {
                        "default": 2.5,
                        "min": 0.0,
                        "max": 100.0,
                        "step": 0.1,
                        "round": 0.01,
                        "tooltip": "Normalization threshold, larger value should increase `scale` impact.",
                    },
                ),
                "alpha": (
                    IO.FLOAT,
                    {
                        "default": 0.5,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "round": 0.001,
                        "tooltip": "Linear interpolation between original (at `alpha=0`) and NAG (at `alpha=1`) results.",
                    },
                ),
                "sigma_start": (IO.FLOAT, {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": (IO.FLOAT, {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
            },
            "optional": {
                "unet_block_list": (
                    IO.STRING,
                    {
                        "default": "",
                        "tooltip": (
                            "Comma-separated blocks to which NAG is being applied to. When the list is empty, NAG is being applied to all block.\n"
                            "Read README from sd-perturbed-attention for more details."
                        ),
                    },
                ),
            },
        }

    RETURN_TYPES = (IO.MODEL,)
    FUNCTION = "patch"
    DESCRIPTION = (
        "An additional way to apply negative prompts to the image.\n"
        "It's compatible with CFG, PAG, and other guidances, and can be used with guidance- and step-distilled models as well.\n"
        "It's also compatible with other attn2 replacers (such as `IPAdapter`) - but make sure to place NAG node **after** other model patches!"
    )
    CATEGORY = "model_patches/unet"

    def patch(
        self,
        model: ModelPatcher,
        negative,
        scale=2.0,
        tau=2.5,
        alpha=0.5,
        sigma_start: float = -1.0,
        sigma_end: float = -1.0,
        unet_block_list="",
    ):
        """Clone *model* and install NAG attn2 replacements on the selected blocks.

        Args:
            model: Patcher wrapping the diffusion model; it is cloned, not mutated.
            negative: Conditioning list; only the first entry's tensor is used.
            scale, tau, alpha: NAG parameters forwarded to the attention wrapper.
            sigma_start, sigma_end: Sigma range in which NAG is active
                (a negative sigma_start means "no upper bound").
            unet_block_list: Optional comma-separated block filter.

        Returns:
            A single-element tuple containing the patched model clone.
        """
        m = model.clone()
        inner_model: BaseModel = m.model
        # Respect manual casting: the k/v projections computed below must use
        # the dtype the model actually runs in.
        dtype = inner_model.get_dtype()
        if inner_model.manual_cast_dtype is not None:
            dtype = inner_model.manual_cast_dtype
        device_model = inner_model.device
        device_infer = comfy.model_management.get_torch_device()
        sigma_start = float("inf") if sigma_start < 0 else sigma_start
        # Move the negative prompt embedding onto the model device for projection.
        negative_cond = negative[0][0].to(device_model, dtype=dtype)
        blocks, block_names = parse_unet_blocks(m, unet_block_list, "attn2") if unet_block_list else (None, None)
        # Apply NAG only to transformer blocks with cross-attention (attn2)
        for name, module in (
            (n, m)
            for n, m in inner_model.diffusion_model.named_modules()
            if isinstance(m, BasicTransformerBlock) and getattr(m, "attn2", None)
        ):
            attn2: CrossAttention = module.attn2  # type: ignore
            # Module path looks like "<kind>_blocks.<id>....transformer_blocks.<t>...";
            # derive the (block_name, block_id, transformer_index) triple used by
            # ComfyUI's patch-replace API.
            parts: list[str] = name.split(".")
            block_name: str = parts[0].split("_")[0]
            block_id = int(parts[1])
            if block_name == "middle":
                # NOTE(review): middle-block ids appear to be offset by one
                # relative to the patch API — confirm against parse_unet_blocks.
                block_id = block_id - 1
            t_idx = None
            if "transformer_blocks" in parts:
                t_pos = parts.index("transformer_blocks") + 1
                t_idx = int(parts[t_pos])
            # No filter -> patch everything; otherwise match with or without
            # the transformer index.
            if not blocks or (block_name, block_id, t_idx) in blocks or (block_name, block_id, None) in blocks:
                # Precompute the negative prompt's key/value projections once
                # per attention layer.
                k_neg, v_neg = attn2.to_k(negative_cond), attn2.to_v(negative_cond)
                # Compatibility with other attn2 replaces (such as IPAdapter)
                prev_attn2_replace = None
                with suppress(KeyError):
                    block = (block_name, block_id, t_idx)
                    block_full = (block_name, block_id)
                    attn2_patches = m.model_options["transformer_options"]["patches_replace"]["attn2"]
                    if block_full in attn2_patches:
                        prev_attn2_replace = attn2_patches[block_full]
                    elif block in attn2_patches:
                        prev_attn2_replace = attn2_patches[block]
                nag_attn2_replace = nag_attn2_replace_wrapper(
                    scale,
                    tau,
                    alpha,
                    sigma_start,
                    sigma_end,
                    k_neg.to(device_infer, dtype=dtype),
                    v_neg.to(device_infer, dtype=dtype),
                    prev_attn2_replace,
                )
                m.set_model_attn2_replace(nag_attn2_replace, block_name, block_id, t_idx)
        return (m,)
# ComfyUI registration tables: internal node name -> implementation class,
# and internal node name -> label shown in the UI.
NODE_CLASS_MAPPINGS = {"NormalizedAttentionGuidance": NormalizedAttentionGuidance}
NODE_DISPLAY_NAME_MAPPINGS = {"NormalizedAttentionGuidance": "Normalized Attention Guidance"}

View File

@@ -0,0 +1,315 @@
from functools import partial
BACKEND = None
try:
from comfy.ldm.modules.attention import optimized_attention
from comfy.model_patcher import ModelPatcher
from comfy.samplers import calc_cond_batch
from .guidance_utils import (
parse_unet_blocks,
perturbed_attention,
rescale_guidance,
seg_attention_wrapper,
snf_guidance,
swg_pred_calc,
)
try:
from comfy.model_patcher import set_model_options_patch_replace
except ImportError:
from .guidance_utils import set_model_options_patch_replace
BACKEND = "ComfyUI"
except ImportError:
from guidance_utils import (
parse_unet_blocks,
perturbed_attention,
rescale_guidance,
seg_attention_wrapper,
set_model_options_patch_replace,
snf_guidance,
swg_pred_calc,
)
try:
from ldm_patched.ldm.modules.attention import optimized_attention
from ldm_patched.modules.model_patcher import ModelPatcher
from ldm_patched.modules.samplers import calc_cond_uncond_batch
BACKEND = "reForge"
except ImportError:
from backend.attention import attention_function as optimized_attention
from backend.patcher.base import ModelPatcher
from backend.sampling.sampling_function import calc_cond_uncond_batch
BACKEND = "Forge"
class PerturbedAttention:
    """Perturbed-Attention Guidance (PAG) as a post-CFG model patch.

    Runs one extra conditional pass with self-attention (attn1) replaced by
    `perturbed_attention` and steers the denoised result away from it.
    Supports ComfyUI, Forge and reForge backends (see BACKEND).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "adaptive_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "round": 0.0001}),
                "unet_block": (["input", "middle", "output"], {"default": "middle"}),
                "unet_block_id": ("INT", {"default": 0}),
                "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "rescale_mode": (["full", "partial", "snf"], {"default": "full"}),
            },
            "optional": {
                "unet_block_list": ("STRING", {"default": ""}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"

    def patch(
        self,
        model: ModelPatcher,
        scale: float = 3.0,
        adaptive_scale: float = 0.0,
        unet_block: str = "middle",
        unet_block_id: int = 0,
        sigma_start: float = -1.0,
        sigma_end: float = -1.0,
        rescale: float = 0.0,
        rescale_mode: str = "full",
        unet_block_list: str = "",
    ):
        """Attach a CFG+PAG post-CFG function to a clone of *model*.

        Returns a single-element tuple with the patched clone.
        """
        m = model.clone()
        # Negative sigma_start means "no upper sigma limit".
        sigma_start = float("inf") if sigma_start < 0 else sigma_start
        single_block = (unet_block, unet_block_id, None)
        blocks, block_names = (
            parse_unet_blocks(model, unet_block_list, "attn1") if unet_block_list else ([single_block], None)
        )

        def post_cfg_function(args):
            """CFG+PAG"""
            model = args["model"]
            cond_pred = args["cond_denoised"]
            uncond_pred = args["uncond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            # Shallow copy so the attn1 replacement below stays local to this pass.
            model_options = args["model_options"].copy()
            x = args["input"]
            signal_scale = scale
            if adaptive_scale > 0:
                # Fade the guidance out as the timestep decreases.
                t = 0
                if hasattr(model, "model_sampling"):
                    t = model.model_sampling.timestep(sigma)[0].item()
                else:
                    # Forge-family backends expose the schedule via `predictor`.
                    ts = model.predictor.timestep(sigma)
                    t = ts[0].item()
                signal_scale -= scale * (adaptive_scale**4) * (1000 - t)
                if signal_scale < 0:
                    signal_scale = 0
            # Skip the extra forward pass when disabled or outside the sigma range.
            if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
                return cfg_result
            # Replace Self-attention with PAG
            for block in blocks:
                layer, number, index = block
                model_options = set_model_options_patch_replace(
                    model_options, perturbed_attention, "attn1", layer, number, index
                )
            if BACKEND == "ComfyUI":
                (pag_cond_pred,) = calc_cond_batch(model, [cond], x, sigma, model_options)
            if BACKEND in {"Forge", "reForge"}:
                (pag_cond_pred, _) = calc_cond_uncond_batch(model, cond, None, x, sigma, model_options)
            pag = (cond_pred - pag_cond_pred) * signal_scale
            if rescale_mode == "snf":
                # SNF combines the CFG and PAG deltas; it needs a real uncond pass,
                # otherwise fall back to plain addition.
                if uncond_pred.any():
                    return uncond_pred + snf_guidance(cfg_result - uncond_pred, pag)
                return cfg_result + pag
            return cfg_result + rescale_guidance(pag, cond_pred, cfg_result, rescale, rescale_mode)

        # The second argument requests uncond_denoised, needed only by snf mode.
        m.set_model_sampler_post_cfg_function(post_cfg_function, rescale_mode == "snf")
        return (m,)
class SmoothedEnergyGuidanceAdvanced:
    """Smoothed Energy Guidance (SEG) as a post-CFG model patch.

    Runs one extra conditional pass with self-attention replaced by the
    smoothed variant from `seg_attention_wrapper` and steers the denoised
    result away from it.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "blur_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 9999.0, "step": 0.01, "round": 0.001}),
                "unet_block": (["input", "middle", "output"], {"default": "middle"}),
                "unet_block_id": ("INT", {"default": 0}),
                "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "rescale_mode": (["full", "partial", "snf"], {"default": "full"}),
            },
            "optional": {
                "unet_block_list": ("STRING", {"default": ""}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"

    def patch(
        self,
        model: ModelPatcher,
        scale: float = 3.0,
        blur_sigma: float = -1.0,
        unet_block: str = "middle",
        unet_block_id: int = 0,
        sigma_start: float = -1.0,
        sigma_end: float = -1.0,
        rescale: float = 0.0,
        rescale_mode: str = "full",
        unet_block_list: str = "",
    ):
        """Attach a CFG+SEG post-CFG function to a clone of *model*.

        blur_sigma is forwarded untouched to seg_attention_wrapper
        (-1.0 presumably means maximal/uniform blur — confirm in guidance_utils).
        Returns a single-element tuple with the patched clone.
        """
        m = model.clone()
        # Negative sigma_start means "no upper sigma limit".
        sigma_start = float("inf") if sigma_start < 0 else sigma_start
        single_block = (unet_block, unet_block_id, None)
        blocks, block_names = (
            parse_unet_blocks(model, unet_block_list, "attn1") if unet_block_list else ([single_block], None)
        )

        def post_cfg_function(args):
            """CFG+SEG"""
            model = args["model"]
            cond_pred = args["cond_denoised"]
            uncond_pred = args["uncond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            # Shallow copy so the attn1 replacement below stays local to this pass.
            model_options = args["model_options"].copy()
            x = args["input"]
            signal_scale = scale
            # Skip the extra forward pass when disabled or outside the sigma range.
            if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
                return cfg_result
            seg_attention = seg_attention_wrapper(optimized_attention, blur_sigma)
            # Replace Self-attention with SEG attention
            for block in blocks:
                layer, number, index = block
                model_options = set_model_options_patch_replace(
                    model_options, seg_attention, "attn1", layer, number, index
                )
            if BACKEND == "ComfyUI":
                (seg_cond_pred,) = calc_cond_batch(model, [cond], x, sigma, model_options)
            if BACKEND in {"Forge", "reForge"}:
                (seg_cond_pred, _) = calc_cond_uncond_batch(model, cond, None, x, sigma, model_options)
            seg = (cond_pred - seg_cond_pred) * signal_scale
            if rescale_mode == "snf":
                # SNF needs the real uncond prediction to split the deltas;
                # otherwise fall back to plain addition.
                if uncond_pred.any():
                    return uncond_pred + snf_guidance(cfg_result - uncond_pred, seg)
                return cfg_result + seg
            return cfg_result + rescale_guidance(seg, cond_pred, cfg_result, rescale, rescale_mode)

        # The second argument requests uncond_denoised, needed only by snf mode.
        m.set_model_sampler_post_cfg_function(post_cfg_function, rescale_mode == "snf")
        return (m,)
class SlidingWindowGuidanceAdvanced:
    """Sliding Window Guidance (SWG) as a post-CFG model patch.

    Computes an extra conditional prediction tile-by-tile over the latent
    (via `swg_pred_calc`) and steers the denoised result away from it.
    """

    @classmethod
    def INPUT_TYPES(s):
        # Tile sizes/overlap are expressed in pixels and converted to latent
        # units in patch(). sigma_end defaults to 5.42, so SWG is only active
        # during the early (high-sigma) part of sampling by default.
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "tile_width": ("INT", {"default": 768, "min": 16, "max": 16384, "step": 8}),
                "tile_height": ("INT", {"default": 768, "min": 16, "max": 16384, "step": 8}),
                "tile_overlap": ("INT", {"default": 256, "min": 16, "max": 16384, "step": 8}),
                "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": ("FLOAT", {"default": 5.42, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"

    def patch(
        self,
        model: ModelPatcher,
        scale: float = 5.0,
        tile_width: int = 768,
        tile_height: int = 768,
        tile_overlap: int = 256,
        sigma_start: float = -1.0,
        sigma_end: float = 5.42,
    ):
        """Attach a CFG+SWG post-CFG function to a clone of *model*.

        Returns a single-element tuple with the patched clone.
        """
        m = model.clone()
        # Negative sigma_start means "no upper sigma limit".
        sigma_start = float("inf") if sigma_start < 0 else sigma_start
        # Pixel -> latent units; assumes an 8x latent downscale factor
        # (TODO confirm for non-8x models).
        tile_width, tile_height, tile_overlap = tile_width // 8, tile_height // 8, tile_overlap // 8

        def post_cfg_function(args):
            """CFG+SWG"""
            model = args["model"]
            cond_pred = args["cond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"].copy()
            x = args["input"]
            signal_scale = scale
            # Skip the tiled pass entirely when disabled or outside the sigma range.
            if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
                return cfg_result
            # Bind everything except the latent; swg_pred_calc invokes
            # calc_func once per sliding-window tile.
            calc_func = None
            if BACKEND == "ComfyUI":
                calc_func = partial(
                    calc_cond_batch,
                    model=model,
                    conds=[cond],
                    timestep=sigma,
                    model_options=model_options,
                )
            if BACKEND in {"Forge", "reForge"}:
                calc_func = partial(
                    calc_cond_uncond_batch,
                    model=model,
                    cond=cond,
                    uncond=None,
                    timestep=sigma,
                    model_options=model_options,
                )
            swg_pred = swg_pred_calc(x, tile_width, tile_height, tile_overlap, calc_func)
            swg = (cond_pred - swg_pred) * signal_scale
            return cfg_result + swg

        m.set_model_sampler_post_cfg_function(post_cfg_function)
        return (m,)

View File

@@ -0,0 +1,111 @@
from comfy.model_patcher import ModelPatcher
from comfy.samplers import calc_cond_batch
from .guidance_utils import parse_unet_blocks, perturbed_attention, rescale_guidance
class TRTAttachPag:
    """Bake PAG attention replacements into a model clone.

    Used for TensorRT workflows: the returned model carries the perturbed
    self-attention permanently, so it can be compiled into an engine and
    later combined with the base model by TRTPerturbedAttention.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "unet_block": (["input", "middle", "output"], {"default": "middle"}),
                "unet_block_id": ("INT", {"default": 0}),
            },
            "optional": {
                "unet_block_list": ("STRING", {"default": ""}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "attach"
    CATEGORY = "TensorRT"

    def attach(
        self,
        model: ModelPatcher,
        unet_block: str = "middle",
        unet_block_id: int = 0,
        unet_block_list: str = "",
    ):
        """Return a clone of *model* with attn1 swapped for perturbed attention."""
        patched = model.clone()
        # The block list string, when given, overrides the single-block widgets.
        if unet_block_list:
            target_blocks, _ = parse_unet_blocks(model, unet_block_list, "attn1")
        else:
            target_blocks = [(unet_block, unet_block_id, None)]
        # Replace self-attention with the PAG variant on every selected block.
        for layer, number, index in target_blocks:
            patched.set_model_attn1_replace(perturbed_attention, layer, number, index)
        return (patched,)
class TRTPerturbedAttention:
    """PAG for TensorRT engines.

    Instead of swapping attention at sampling time, this node takes two
    prebuilt models — the base one and one already carrying the PAG attention
    replacement (see TRTAttachPag) — and combines their predictions post-CFG.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_base": ("MODEL",),
                "model_pag": ("MODEL",),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "adaptive_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "round": 0.0001}),
                "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "rescale_mode": (["full", "partial"], {"default": "full"}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "TensorRT"

    def patch(
        self,
        model_base: ModelPatcher,
        model_pag: ModelPatcher,
        scale: float = 3.0,
        adaptive_scale: float = 0.0,
        sigma_start: float = -1.0,
        sigma_end: float = -1.0,
        rescale: float = 0.0,
        rescale_mode: str = "full",
    ):
        """Attach a CFG+PAG post-CFG function that uses *model_pag* for the
        perturbed pass; returns a single-element tuple with the patched clone
        of *model_base*."""
        m = model_base.clone()
        # Negative sigma_start means "no upper sigma limit".
        sigma_start = float("inf") if sigma_start < 0 else sigma_start

        def post_cfg_function(args):
            """CFG+PAG"""
            model = args["model"]
            cond_pred = args["cond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            x = args["input"]
            signal_scale = scale
            if adaptive_scale > 0:
                # Fade the guidance out as the timestep decreases.
                t = model.model_sampling.timestep(sigma)[0].item()
                signal_scale -= scale * (adaptive_scale**4) * (1000 - t)
                if signal_scale < 0:
                    signal_scale = 0
            # Skip the extra forward pass when disabled or outside the sigma range.
            if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
                return cfg_result
            # The perturbed pass runs on the separate PAG-patched model with
            # its own model_options.
            (pag_cond_pred,) = calc_cond_batch(model_pag.model, [cond], x, sigma, model_pag.model_options)
            pag = (cond_pred - pag_cond_pred) * signal_scale
            return cfg_result + rescale_guidance(pag, cond_pred, cfg_result, rescale, rescale_mode)

        m.set_model_sampler_post_cfg_function(post_cfg_function)
        return (m,)

View File

@@ -0,0 +1,83 @@
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from comfy.ldm.modules.attention import BasicTransformerBlock
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from .guidance_utils import parse_unet_blocks
from .pladis_utils import SPARSE_FUNCTIONS, pladis_attention_wrapper
class Pladis(ComfyNodeABC):
    """PLADIS: blend of sparse (entmax-family) and dense cross-attention.

    Replaces attn2 in the selected transformer blocks with the sparse/dense
    blend produced by `pladis_attention_wrapper`.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "model": (IO.MODEL, {}),
                "scale": (IO.FLOAT, {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "sparse_func": (IO.COMBO, {"default": SPARSE_FUNCTIONS[0], "options": SPARSE_FUNCTIONS}),
            },
            "optional": {
                "unet_block_list": (
                    IO.STRING,
                    {
                        "default": "",
                        "tooltip": (
                            "Comma-separated blocks to which Pladis is being applied to. When the list is empty, PLADIS is being applied to all `u` and `d` blocks.\n"
                            "Read README from sd-perturbed-attention for more details."
                        ),
                    },
                ),
            },
        }

    RETURN_TYPES = (IO.MODEL,)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"
    EXPERIMENTAL = True

    def patch(
        self,
        model: ModelPatcher,
        scale=2.0,
        sparse_func=SPARSE_FUNCTIONS[0],
        unet_block_list="",
    ):
        """Clone *model* and replace attn2 with PLADIS attention on matching blocks.

        Args:
            model: Patcher wrapping the diffusion model; cloned, not mutated.
            scale: PLADIS blend scale, forwarded to the attention wrapper.
            sparse_func: One of SPARSE_FUNCTIONS (entmax1.5 or sparsemax).
            unet_block_list: Optional comma-separated block filter; an empty
                string applies PLADIS to every cross-attention block, matching
                the tooltip above.

        Returns:
            A single-element tuple containing the patched model clone.
        """
        m = model.clone()
        inner_model: BaseModel = m.model
        pladis_attention = pladis_attention_wrapper(scale, sparse_func)
        blocks, _ = parse_unet_blocks(m, unet_block_list, "attn2") if unet_block_list else (None, None)
        # Apply PLADIS only to transformer blocks with cross-attention (attn2)
        for name, module in (
            (n, m)
            for n, m in inner_model.diffusion_model.named_modules()
            if isinstance(m, BasicTransformerBlock) and getattr(m, "attn2", None)
        ):
            parts = name.split(".")
            block_name: str = parts[0].split("_")[0]
            block_id = int(parts[1])
            if block_name == "middle":
                block_id = block_id - 1
            t_idx = None
            if "transformer_blocks" in parts:
                t_pos = parts.index("transformer_blocks") + 1
                t_idx = int(parts[t_pos])
            # Bug fix: an early `if not blocks: continue` used to skip every
            # module when no block list was given, so the documented default
            # ("apply to all blocks") patched nothing. With the early exit
            # removed, `not blocks` falls through to the match below, which
            # treats an absent filter as "match everything" — consistent with
            # NormalizedAttentionGuidance.
            if not blocks or (block_name, block_id, t_idx) in blocks or (block_name, block_id, None) in blocks:
                m.set_model_attn2_replace(pladis_attention, block_name, block_id, t_idx)
        return (m,)
# ComfyUI registration tables: internal node name -> implementation class,
# and internal node name -> label shown in the UI.
NODE_CLASS_MAPPINGS = {"PLADIS": Pladis}
NODE_DISPLAY_NAME_MAPPINGS = {"PLADIS": "PLADIS"}

View File

@@ -0,0 +1,166 @@
from typing import Optional
import torch
# Identifiers for the sparse activations PLADIS can use in place of softmax.
ENTMAX15_FUNC = "entmax1.5"  # sparse attention with alpha=1.5
SPARSEMAX_FUNC = "sparsemax"  # sparse attention with alpha=2
# Combo-widget options; index 0 (entmax1.5) is the default fallback.
SPARSE_FUNCTIONS: list = [ENTMAX15_FUNC, SPARSEMAX_FUNC]
def pladis_attention_wrapper(pladis_scale=2.0, sparse_func=SPARSE_FUNCTIONS[0]):
    """Build an attn2 replacement implementing PLADIS.

    The returned callable computes attention weights twice — once with a
    dense softmax and once with a sparse entmax-family activation — and
    blends them as `pladis_scale * sparse + (1 - pladis_scale) * dense`.
    """

    # Simplified attention_basic with sparse functions instead of a softmax
    def _pladis_sparse_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        extra_options: dict,
    ):
        heads = extra_options["n_heads"]
        # NOTE(review): attn_precision is read but never used below — confirm
        # whether precision handling was intended here.
        attn_precision = extra_options.get("attn_precision")
        b, _, dim_head = q.shape
        dim_head //= heads
        # Standard 1/sqrt(d_head) attention scaling.
        scale: float = dim_head**-0.5
        # (b, tokens, heads*dim_head) -> (b*heads, tokens, dim_head)
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )
        sim = q @ k.transpose(-2, -1) * scale
        del q, k
        dense_sim = torch.softmax(sim, dim=-1)
        if sparse_func == ENTMAX15_FUNC:
            sparse_sim = Entmax.entmax15(sim, dim=-1)
        elif sparse_func == SPARSEMAX_FUNC:
            sparse_sim = Entmax.sparsemax(sim, dim=-1)
        else:  # fallback to the default from paper
            sparse_sim = Entmax.entmax15(sim, dim=-1)
        # Linear blend of sparse and dense attention weights (PLADIS core idea).
        pladis_sim = pladis_scale * sparse_sim + (1 - pladis_scale) * dense_sim
        out = pladis_sim.to(v.dtype) @ v
        # (b*heads, tokens, dim_head) -> (b, tokens, heads*dim_head)
        out = out.unsqueeze(0).reshape(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
        return out

    return _pladis_sparse_attention
class Entmax:
    """
    Activations from `entmax` module converted to a static class.
    Both sparsemax and entmax15, and all their inner function implementations
    are taken from https://github.com/deep-spin/entmax/blob/c2bec6d5e7d649cba7766c2172d89123ec2a6d70/entmax/activations.py
    (as recommended by PLADIS paper).
    Author: Ben Peters
    Author: Vlad Niculae <vlad@vene.ro>
    License: MIT
    """

    # NOTE: vendored code — kept intentionally byte-identical to the upstream
    # implementation; only documentation has been added.

    @staticmethod
    def entmax15(X: torch.Tensor, dim=-1, k: Optional[int] = None):
        """1.5-entmax mapping of X along *dim* (sparse softmax alternative).

        k, when given, limits the solve to a partial top-k sort.
        """
        max_val, _ = X.max(dim=dim, keepdim=True)
        X = X - max_val  # same numerical stability trick as for softmax
        X = X / 2  # divide by 2 to solve actual Entmax
        tau_star, _ = Entmax._entmax_threshold_and_support(X, dim=dim, k=k)
        Y = torch.clamp(X - tau_star, min=0) ** 2
        return Y

    @staticmethod
    def sparsemax(X: torch.Tensor, dim=-1, k: Optional[int] = None):
        """Sparsemax mapping of X along *dim* (entmax with alpha=2)."""
        max_val, _ = X.max(dim=dim, keepdim=True)
        X = X - max_val  # same numerical stability trick as softmax
        tau, _ = Entmax._sparsemax_threshold_and_support(X, dim=dim, k=k)
        output = torch.clamp(X - tau, min=0)
        return output

    @staticmethod
    def _entmax_threshold_and_support(X, dim=-1, k=None):
        """Return (tau_star, support_size) for the 1.5-entmax threshold.

        Falls back to a doubled top-k recursion when the partial sort was
        too small to contain the full support.
        """
        if k is None or k >= X.shape[dim]:  # do full sort
            Xsrt, _ = torch.sort(X, dim=dim, descending=True)
        else:
            Xsrt, _ = torch.topk(X, k=k, dim=dim)
        rho = Entmax._make_ix_like(Xsrt, dim)
        mean = Xsrt.cumsum(dim) / rho
        mean_sq = (Xsrt**2).cumsum(dim) / rho
        ss = rho * (mean_sq - mean**2)
        delta = (1 - ss) / rho
        delta_nz = torch.clamp(delta, 0)
        tau = mean - torch.sqrt(delta_nz)
        support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim)
        tau_star = tau.gather(dim, support_size - 1)
        if k is not None and k < X.shape[dim]:
            unsolved = (support_size == k).squeeze(dim)
            if torch.any(unsolved):
                X_ = Entmax._roll_last(X, dim)[unsolved]
                tau_, ss_ = Entmax._entmax_threshold_and_support(X_, dim=-1, k=2 * k)
                Entmax._roll_last(tau_star, dim)[unsolved] = tau_
                Entmax._roll_last(support_size, dim)[unsolved] = ss_
        return tau_star, support_size

    @staticmethod
    def _sparsemax_threshold_and_support(X: torch.Tensor, dim=-1, k=None):
        """Return (tau, support_size) for the sparsemax threshold.

        Same top-k-doubling fallback strategy as the entmax variant above.
        """
        if k is None or k >= X.shape[dim]:  # do full sort
            topk, _ = torch.sort(X, dim=dim, descending=True)
        else:
            topk, _ = torch.topk(X, k=k, dim=dim)
        topk_cumsum = topk.cumsum(dim) - 1
        rhos = Entmax._make_ix_like(topk, dim)
        support = rhos * topk > topk_cumsum
        support_size = support.sum(dim=dim).unsqueeze(dim)
        tau = topk_cumsum.gather(dim, support_size - 1)
        tau /= support_size.to(X.dtype)
        if k is not None and k < X.shape[dim]:
            unsolved = (support_size == k).squeeze(dim)
            if torch.any(unsolved):
                in_ = Entmax._roll_last(X, dim)[unsolved]
                tau_, ss_ = Entmax._sparsemax_threshold_and_support(in_, dim=-1, k=2 * k)
                Entmax._roll_last(tau, dim)[unsolved] = tau_
                Entmax._roll_last(support_size, dim)[unsolved] = ss_
        return tau, support_size

    @staticmethod
    def _make_ix_like(X: torch.Tensor, dim=-1):
        """Return [1..d] broadcast-shaped along *dim* of X (rank indices)."""
        d = X.size(dim)
        rho = torch.arange(1, d + 1, device=X.device, dtype=X.dtype)
        view = [1] * X.dim()
        view[0] = -1
        return rho.view(view).transpose(0, dim)

    @staticmethod
    def _roll_last(X: torch.Tensor, dim=-1):
        """Return a view of X with *dim* moved to the last axis."""
        if dim == -1:
            return X
        elif dim < 0:
            dim = X.dim() - dim
        perm = [i for i in range(X.dim()) if i != dim] + [dim]
        return X.permute(perm)

View File

@@ -0,0 +1,14 @@
[project]
name = "sd-perturbed-attention"
description = "Perturbed-Attention Guidance (PAG), Smoothed Energy Guidance (SEG), Sliding Window Guidance (SWG), PLADIS, Normalized Attention Guidance (NAG), Token Perturbation Guidance (TPG) for ComfyUI and SD reForge."
version = "1.2.15"
license = { text = "MIT License" }
[project.urls]
Repository = "https://github.com/pamparamm/sd-perturbed-attention"
# Used by Comfy Registry https://comfyregistry.org
[tool.comfy]
PublisherId = "pamparamm"
DisplayName = "sd-perturbed-attention"
Icon = ""

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 162 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

View File

@@ -0,0 +1,127 @@
from typing import Any
import torch
from torch import nn
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from comfy.ldm.modules.attention import BasicTransformerBlock
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.samplers import calc_cond_batch
from .guidance_utils import parse_unet_blocks, rescale_guidance, set_model_options_value, snf_guidance
# transformer_options key that tells TPGTransformerWrapper to shuffle tokens.
TPG_OPTION = "tpg"
# Implementation of 2506.10036 'Token Perturbation Guidance for Diffusion Models'
class TPGTransformerWrapper(nn.Module):
    """Wrapper around a transformer block that can shuffle its input tokens.

    When the TPG_OPTION flag is set in transformer_options, the input
    sequence is randomly permuted along the token axis before being
    forwarded to the wrapped block; otherwise the wrapper is transparent.
    """

    def __init__(self, transformer_block: BasicTransformerBlock) -> None:
        super().__init__()
        self.wrapped_block = transformer_block

    def shuffle_tokens(self, x: torch.Tensor):
        """Return *x* with its token dimension (dim 1) randomly permuted."""
        # ComfyUI's torch.manual_seed generator should produce the same results here.
        permutation = torch.randperm(x.shape[1], device=x.device)
        return x[:, permutation]

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None, transformer_options: dict[str, Any] | None = None):
        # Fix: the previous `transformer_options: dict[str, Any] = {}` default
        # was a single mutable dict shared across all calls (and passed into
        # the wrapped block, which could mutate it); normalize None to a fresh
        # dict per call instead.
        if transformer_options is None:
            transformer_options = {}
        is_tpg = transformer_options.get(TPG_OPTION, False)
        x_ = self.shuffle_tokens(x) if is_tpg else x
        return self.wrapped_block(x_, context=context, transformer_options=transformer_options)
class TokenPerturbationGuidance(ComfyNodeABC):
    """Token Perturbation Guidance (TPG), arXiv:2506.10036.

    Wraps selected transformer blocks so their input tokens can be shuffled,
    then uses the shuffled ("perturbed") prediction as extra guidance on top
    of CFG in a post-CFG function.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "model": (IO.MODEL, {}),
                "scale": (IO.FLOAT, {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "sigma_start": (IO.FLOAT, {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": (IO.FLOAT, {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "rescale": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "rescale_mode": (IO.COMBO, {"options": ["full", "partial", "snf"], "default": "full"}),
            },
            "optional": {
                "unet_block_list": (IO.STRING, {"default": "d2.2-9,d3", "tooltip": "Blocks to which TPG is applied. "}),
            },
        }

    RETURN_TYPES = (IO.MODEL,)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"

    def patch(
        self,
        model: ModelPatcher,
        scale: float = 3.0,
        sigma_start: float = -1.0,
        sigma_end: float = -1.0,
        rescale: float = 0.0,
        rescale_mode: str = "full",
        unet_block_list: str = "",
    ):
        """Attach TPG wrappers and a CFG+TPG post-CFG function to a model clone.

        Returns a single-element tuple with the patched clone.
        """
        m = model.clone()
        inner_model: BaseModel = m.model
        # Negative sigma_start means "no upper sigma limit".
        sigma_start = float("inf") if sigma_start < 0 else sigma_start
        blocks, block_names = parse_unet_blocks(model, unet_block_list, None) if unet_block_list else (None, None)
        # Patch transformer blocks with TPG wrapper
        for name, module in inner_model.diffusion_model.named_modules():
            if (
                isinstance(module, BasicTransformerBlock)
                # Skip modules already nested under an existing wrapper.
                and not "wrapped_block" in name
                and (block_names is None or name in block_names)
            ):
                # Potential memory leak?
                wrapper = TPGTransformerWrapper(module)
                m.add_object_patch(f"diffusion_model.{name}", wrapper)

        def post_cfg_function(args):
            """CFG+TPG"""
            model: BaseModel = args["model"]
            cond_pred = args["cond_denoised"]
            uncond_pred = args["uncond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            # Shallow copy so the TPG flag set below stays local to this pass.
            model_options = args["model_options"].copy()
            x = args["input"]
            signal_scale = scale
            # Skip the extra forward pass when disabled or outside the sigma range.
            if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
                return cfg_result
            # Enable TPG in patched transformer blocks
            # NOTE(review): this sets the same flag once per wrapper and ignores
            # any return value of set_model_options_value — verify the helper
            # mutates model_options in place; a single call guarded by any(...)
            # would then suffice.
            for name, module in model.diffusion_model.named_modules():
                if isinstance(module, TPGTransformerWrapper):
                    set_model_options_value(model_options, TPG_OPTION, True)
            (tpg_cond_pred,) = calc_cond_batch(model, [cond], x, sigma, model_options)
            tpg = (cond_pred - tpg_cond_pred) * signal_scale
            if rescale_mode == "snf":
                # SNF needs the real uncond prediction to split the deltas;
                # otherwise fall back to plain addition.
                if uncond_pred.any():
                    return uncond_pred + snf_guidance(cfg_result - uncond_pred, tpg)
                return cfg_result + tpg
            return cfg_result + rescale_guidance(tpg, cond_pred, cfg_result, rescale, rescale_mode)

        # The second argument requests uncond_denoised, needed only by snf mode.
        m.set_model_sampler_post_cfg_function(post_cfg_function, rescale_mode == "snf")
        return (m,)
# ComfyUI registration tables: internal node name -> implementation class,
# and internal node name -> label shown in the UI.
NODE_CLASS_MAPPINGS = {"TokenPerturbationGuidance": TokenPerturbationGuidance}
NODE_DISPLAY_NAME_MAPPINGS = {"TokenPerturbationGuidance": "Token Perturbation Guidance"}