Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled

Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 00:55:26 +00:00
parent 2b70ab9ad0
commit f09734b0ee
2274 changed files with 748556 additions and 3 deletions

View File

@@ -0,0 +1,183 @@
# Initially taken from Github's Python gitignore file
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# tests and logs
tests/fixtures/cached_*_text.txt
logs/
lightning_logs/
lang_code_data/
tests/outputs
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# vscode
.vs
.vscode
# Pycharm
.idea
# TF code
tensorflow_code
# Models
proc_data
# examples
runs
/runs_old
/wandb
/examples/runs
/examples/**/*.args
/examples/rag/sweep
# data
/data
serialization_dir
# emacs
*.*~
debug.env
# vim
.*.swp
#ctags
tags
# pre-commit
.pre-commit*
# .lock
*.lock
# DS_Store (MacOS)
.DS_Store
# RL pipelines may produce mp4 outputs
*.mp4
# dependencies
/transformers
# ruff
.ruff_cache
wandb
ckpts/
test.ipynb
config.yaml
test.ipynb

View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,252 @@
# ComfyUI's ControlNet Auxiliary Preprocessors
Plug-and-play [ComfyUI](https://github.com/comfyanonymous/ComfyUI) node sets for making [ControlNet](https://github.com/lllyasviel/ControlNet/) hint images
"anime style, a protest in the street, cyberpunk city, a woman with pink hair and golden eyes (looking at the viewer) is holding a sign with the text "ComfyUI ControlNet Aux" in bold, neon pink" on Flux.1 Dev
![](./examples/CNAuxBanner.jpg)
The code is copy-pasted from the respective folders in https://github.com/lllyasviel/ControlNet/tree/main/annotator and connected to [the 🤗 Hub](https://huggingface.co/lllyasviel/Annotators).
All credit & copyright goes to https://github.com/lllyasviel.
# Updates
Go to [Update page](./UPDATES.md) to follow updates
# Installation:
## Using ComfyUI Manager (recommended):
Install [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) and do steps introduced there to install this repo.
## Alternative:
If you're running on Linux, or non-admin account on windows you'll want to ensure `/ComfyUI/custom_nodes` and `comfyui_controlnet_aux` has write permissions.
There is now a **install.bat** you can run to install to portable if detected. Otherwise it will default to system and assume you followed ConfyUI's manual installation steps.
If you can't run **install.bat** (e.g. you are a Linux user). Open the CMD/Shell and do the following:
- Navigate to your `/ComfyUI/custom_nodes/` folder
- Run `git clone https://github.com/Fannovel16/comfyui_controlnet_aux/`
- Navigate to your `comfyui_controlnet_aux` folder
- Portable/venv:
- Run `path/to/ComfUI/python_embeded/python.exe -s -m pip install -r requirements.txt`
- With system python
- Run `pip install -r requirements.txt`
- Start ComfyUI
# Nodes
Please note that this repo only supports preprocessors making hint images (e.g. stickman, canny edge, etc).
All preprocessors except Inpaint are intergrated into `AIO Aux Preprocessor` node.
This node allow you to quickly get the preprocessor but a preprocessor's own threshold parameters won't be able to set.
You need to use its node directly to set thresholds.
# Nodes (sections are categories in Comfy menu)
## Line Extractors
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|-----------------------------|---------------------------|-------------------------------------------|
| Binary Lines | binary | control_scribble |
| Canny Edge | canny | control_v11p_sd15_canny <br> control_canny <br> t2iadapter_canny |
| HED Soft-Edge Lines | hed | control_v11p_sd15_softedge <br> control_hed |
| Standard Lineart | standard_lineart | control_v11p_sd15_lineart |
| Realistic Lineart | lineart (or `lineart_coarse` if `coarse` is enabled) | control_v11p_sd15_lineart |
| Anime Lineart | lineart_anime | control_v11p_sd15s2_lineart_anime |
| Manga Lineart | lineart_anime_denoise | control_v11p_sd15s2_lineart_anime |
| M-LSD Lines | mlsd | control_v11p_sd15_mlsd <br> control_mlsd |
| PiDiNet Soft-Edge Lines | pidinet | control_v11p_sd15_softedge <br> control_scribble |
| Scribble Lines | scribble | control_v11p_sd15_scribble <br> control_scribble |
| Scribble XDoG Lines | scribble_xdog | control_v11p_sd15_scribble <br> control_scribble |
| Fake Scribble Lines | scribble_hed | control_v11p_sd15_scribble <br> control_scribble |
| TEED Soft-Edge Lines | teed | [controlnet-sd-xl-1.0-softedge-dexined](https://huggingface.co/SargeZT/controlnet-sd-xl-1.0-softedge-dexined/blob/main/controlnet-sd-xl-1.0-softedge-dexined.safetensors) <br> control_v11p_sd15_softedge (Theoretically)
| Scribble PiDiNet Lines | scribble_pidinet | control_v11p_sd15_scribble <br> control_scribble |
| AnyLine Lineart | | mistoLine_fp16.safetensors <br> mistoLine_rank256 <br> control_v11p_sd15s2_lineart_anime <br> control_v11p_sd15_lineart |
## Normal and Depth Estimators
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|-----------------------------|---------------------------|-------------------------------------------|
| MiDaS Depth Map | (normal) depth | control_v11f1p_sd15_depth <br> control_depth <br> t2iadapter_depth |
| LeReS Depth Map | depth_leres | control_v11f1p_sd15_depth <br> control_depth <br> t2iadapter_depth |
| Zoe Depth Map | depth_zoe | control_v11f1p_sd15_depth <br> control_depth <br> t2iadapter_depth |
| MiDaS Normal Map | normal_map | control_normal |
| BAE Normal Map | normal_bae | control_v11p_sd15_normalbae |
| MeshGraphormer Hand Refiner ([HandRefinder](https://github.com/wenquanlu/HandRefiner)) | depth_hand_refiner | [control_sd15_inpaint_depth_hand_fp16](https://huggingface.co/hr16/ControlNet-HandRefiner-pruned/blob/main/control_sd15_inpaint_depth_hand_fp16.safetensors) |
| Depth Anything | depth_anything | [Depth-Anything](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_controlnet/diffusion_pytorch_model.safetensors) |
| Zoe Depth Anything <br> (Basically Zoe but the encoder is replaced with DepthAnything) | depth_anything | [Depth-Anything](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_controlnet/diffusion_pytorch_model.safetensors) |
| Normal DSINE | | control_normal/control_v11p_sd15_normalbae |
| Metric3D Depth | | control_v11f1p_sd15_depth <br> control_depth <br> t2iadapter_depth |
| Metric3D Normal | | control_v11p_sd15_normalbae |
| Depth Anything V2 | | [Depth-Anything](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_controlnet/diffusion_pytorch_model.safetensors) |
## Faces and Poses Estimators
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|-----------------------------|---------------------------|-------------------------------------------|
| DWPose Estimator | dw_openpose_full | control_v11p_sd15_openpose <br> control_openpose <br> t2iadapter_openpose |
| OpenPose Estimator | openpose (detect_body) <br> openpose_hand (detect_body + detect_hand) <br> openpose_faceonly (detect_face) <br> openpose_full (detect_hand + detect_body + detect_face) | control_v11p_sd15_openpose <br> control_openpose <br> t2iadapter_openpose |
| MediaPipe Face Mesh | mediapipe_face | controlnet_sd21_laion_face_v2 |
| Animal Estimator | animal_openpose | [control_sd15_animal_openpose_fp16](https://huggingface.co/huchenlei/animal_openpose/blob/main/control_sd15_animal_openpose_fp16.pth) |
## Optical Flow Estimators
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|-----------------------------|---------------------------|-------------------------------------------|
| Unimatch Optical Flow | | [DragNUWA](https://github.com/ProjectNUWA/DragNUWA) |
### How to get OpenPose-format JSON?
#### User-side
This workflow will save images to ComfyUI's output folder (the same location as output images). If you haven't found `Save Pose Keypoints` node, update this extension
![](./examples/example_save_kps.png)
#### Dev-side
An array of [OpenPose-format JSON](https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md#json-output-format) corresponsding to each frame in an IMAGE batch can be gotten from DWPose and OpenPose using `app.nodeOutputs` on the UI or `/history` API endpoint. JSON output from AnimalPose uses a kinda similar format to OpenPose JSON:
```
[
{
"version": "ap10k",
"animals": [
[[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
[[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
...
],
"canvas_height": 512,
"canvas_width": 768
},
...
]
```
For extension developers (e.g. Openpose editor):
```js
const poseNodes = app.graph._nodes.filter(node => ["OpenposePreprocessor", "DWPreprocessor", "AnimalPosePreprocessor"].includes(node.type))
for (const poseNode of poseNodes) {
const openposeResults = JSON.parse(app.nodeOutputs[poseNode.id].openpose_json[0])
console.log(openposeResults) //An array containing Openpose JSON for each frame
}
```
For API users:
Javascript
```js
import fetch from "node-fetch" //Remember to add "type": "module" to "package.json"
async function main() {
const promptId = '792c1905-ecfe-41f4-8114-83e6a4a09a9f' //Too lazy to POST /queue
let history = await fetch(`http://127.0.0.1:8188/history/${promptId}`).then(re => re.json())
history = history[promptId]
const nodeOutputs = Object.values(history.outputs).filter(output => output.openpose_json)
for (const nodeOutput of nodeOutputs) {
const openposeResults = JSON.parse(nodeOutput.openpose_json[0])
console.log(openposeResults) //An array containing Openpose JSON for each frame
}
}
main()
```
Python
```py
import json, urllib.request
server_address = "127.0.0.1:8188"
prompt_id = '' #Too lazy to POST /queue
def get_history(prompt_id):
with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
return json.loads(response.read())
history = get_history(prompt_id)[prompt_id]
for o in history['outputs']:
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
if 'openpose_json' in node_output:
print(json.loads(node_output['openpose_json'][0])) #An list containing Openpose JSON for each frame
```
## Semantic Segmentation
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|-----------------------------|---------------------------|-------------------------------------------|
| OneFormer ADE20K Segmentor | oneformer_ade20k | control_v11p_sd15_seg |
| OneFormer COCO Segmentor | oneformer_coco | control_v11p_sd15_seg |
| UniFormer Segmentor | segmentation |control_sd15_seg <br> control_v11p_sd15_seg|
## T2IAdapter-only
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|-----------------------------|---------------------------|-------------------------------------------|
| Color Pallete | color | t2iadapter_color |
| Content Shuffle | shuffle | t2iadapter_style |
## Recolor
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|-----------------------------|---------------------------|-------------------------------------------|
| Image Luminance | recolor_luminance | [ioclab_sd15_recolor](https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/ioclab_sd15_recolor.safetensors) <br> [sai_xl_recolor_256lora](https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/sai_xl_recolor_256lora.safetensors) <br> [bdsqlsz_controlllite_xl_recolor_luminance](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/resolve/main/bdsqlsz_controlllite_xl_recolor_luminance.safetensors) |
| Image Intensity | recolor_intensity | Idk. Maybe same as above? |
# Examples
> A picture is worth a thousand words
![](./examples/ExecuteAll1.jpg)
![](./examples/ExecuteAll2.jpg)
# Testing workflow
https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/examples/ExecuteAll.png
Input image: https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/examples/comfyui-controlnet-aux-logo.png
# Q&A:
## Why some nodes doesn't appear after I installed this repo?
This repo has a new mechanism which will skip any custom node can't be imported. If you meet this case, please create a issue on [Issues tab](https://github.com/Fannovel16/comfyui_controlnet_aux/issues) with the log from the command line.
## DWPose/AnimalPose only uses CPU so it's so slow. How can I make it use GPU?
There are two ways to speed-up DWPose: using TorchScript checkpoints (.torchscript.pt) checkpoints or ONNXRuntime (.onnx). TorchScript way is little bit slower than ONNXRuntime but doesn't require any additional library and still way way faster than CPU.
A torchscript bbox detector is compatiable with an onnx pose estimator and vice versa.
### TorchScript
Set `bbox_detector` and `pose_estimator` according to this picture. You can try other bbox detector endings with `.torchscript.pt` to reduce bbox detection time if input images are ideal.
![](./examples/example_torchscript.png)
### ONNXRuntime
If onnxruntime is installed successfully and the checkpoint used endings with `.onnx`, it will replace default cv2 backend to take advantage of GPU. Note that if you are using NVidia card, this method currently can only works on CUDA 11.8 (ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z) unless you compile onnxruntime yourself.
1. Know your onnxruntime build:
* * NVidia CUDA 11.x or bellow/AMD GPU: `onnxruntime-gpu`
* * NVidia CUDA 12.x: `onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`
* * DirectML: `onnxruntime-directml`
* * OpenVINO: `onnxruntime-openvino`
Note that if this is your first time using ComfyUI, please test if it can run on your device before doing next steps.
2. Add it into `requirements.txt`
3. Run `install.bat` or pip command mentioned in Installation
![](./examples/example_onnx.png)
# Assets files of preprocessors
* anime_face_segment: [bdsqlsz/qinglong_controlnet-lllite/Annotators/UNet.pth](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/blob/main/Annotators/UNet.pth), [anime-seg/isnetis.ckpt](https://huggingface.co/skytnt/anime-seg/blob/main/isnetis.ckpt)
* densepose: [LayerNorm/DensePose-TorchScript-with-hint-image/densepose_r50_fpn_dl.torchscript](https://huggingface.co/LayerNorm/DensePose-TorchScript-with-hint-image/blob/main/densepose_r50_fpn_dl.torchscript)
* dwpose:
* * bbox_detector: Either [yzd-v/DWPose/yolox_l.onnx](https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx), [hr16/yolox-onnx/yolox_l.torchscript.pt](https://huggingface.co/hr16/yolox-onnx/blob/main/yolox_l.torchscript.pt), [hr16/yolo-nas-fp16/yolo_nas_l_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_l_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_m_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_m_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_s_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_s_fp16.onnx)
* * pose_estimator: Either [hr16/DWPose-TorchScript-BatchSize5/dw-ll_ucoco_384_bs5.torchscript.pt](https://huggingface.co/hr16/DWPose-TorchScript-BatchSize5/blob/main/dw-ll_ucoco_384_bs5.torchscript.pt), [yzd-v/DWPose/dw-ll_ucoco_384.onnx](https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx)
* animal_pose (ap10k):
* * bbox_detector: Either [yzd-v/DWPose/yolox_l.onnx](https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx), [hr16/yolox-onnx/yolox_l.torchscript.pt](https://huggingface.co/hr16/yolox-onnx/blob/main/yolox_l.torchscript.pt), [hr16/yolo-nas-fp16/yolo_nas_l_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_l_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_m_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_m_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_s_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_s_fp16.onnx)
* * pose_estimator: Either [hr16/DWPose-TorchScript-BatchSize5/rtmpose-m_ap10k_256_bs5.torchscript.pt](https://huggingface.co/hr16/DWPose-TorchScript-BatchSize5/blob/main/rtmpose-m_ap10k_256_bs5.torchscript.pt), [hr16/UnJIT-DWPose/rtmpose-m_ap10k_256.onnx](https://huggingface.co/hr16/UnJIT-DWPose/blob/main/rtmpose-m_ap10k_256.onnx)
* hed: [lllyasviel/Annotators/ControlNetHED.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/ControlNetHED.pth)
* leres: [lllyasviel/Annotators/res101.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/res101.pth), [lllyasviel/Annotators/latest_net_G.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/latest_net_G.pth)
* lineart: [lllyasviel/Annotators/sk_model.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/sk_model.pth), [lllyasviel/Annotators/sk_model2.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/sk_model2.pth)
* lineart_anime: [lllyasviel/Annotators/netG.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/netG.pth)
* manga_line: [lllyasviel/Annotators/erika.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/erika.pth)
* mesh_graphormer: [hr16/ControlNet-HandRefiner-pruned/graphormer_hand_state_dict.bin](https://huggingface.co/hr16/ControlNet-HandRefiner-pruned/blob/main/graphormer_hand_state_dict.bin), [hr16/ControlNet-HandRefiner-pruned/hrnetv2_w64_imagenet_pretrained.pth](https://huggingface.co/hr16/ControlNet-HandRefiner-pruned/blob/main/hrnetv2_w64_imagenet_pretrained.pth)
* midas: [lllyasviel/Annotators/dpt_hybrid-midas-501f0c75.pt](https://huggingface.co/lllyasviel/Annotators/blob/main/dpt_hybrid-midas-501f0c75.pt)
* mlsd: [lllyasviel/Annotators/mlsd_large_512_fp32.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/mlsd_large_512_fp32.pth)
* normalbae: [lllyasviel/Annotators/scannet.pt](https://huggingface.co/lllyasviel/Annotators/blob/main/scannet.pt)
* oneformer: [lllyasviel/Annotators/250_16_swin_l_oneformer_ade20k_160k.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/250_16_swin_l_oneformer_ade20k_160k.pth)
* open_pose: [lllyasviel/Annotators/body_pose_model.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/body_pose_model.pth), [lllyasviel/Annotators/hand_pose_model.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/hand_pose_model.pth), [lllyasviel/Annotators/facenet.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/facenet.pth)
* pidi: [lllyasviel/Annotators/table5_pidinet.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/table5_pidinet.pth)
* sam: [dhkim2810/MobileSAM/mobile_sam.pt](https://huggingface.co/dhkim2810/MobileSAM/blob/main/mobile_sam.pt)
* uniformer: [lllyasviel/Annotators/upernet_global_small.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/upernet_global_small.pth)
* zoe: [lllyasviel/Annotators/ZoeD_M12_N.pt](https://huggingface.co/lllyasviel/Annotators/blob/main/ZoeD_M12_N.pt)
* teed: [bdsqlsz/qinglong_controlnet-lllite/7_model.pth](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/blob/main/Annotators/7_model.pth)
* depth_anything: Either [LiheYoung/Depth-Anything/checkpoints/depth_anything_vitl14.pth](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints/depth_anything_vitl14.pth), [LiheYoung/Depth-Anything/checkpoints/depth_anything_vitb14.pth](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints/depth_anything_vitb14.pth) or [LiheYoung/Depth-Anything/checkpoints/depth_anything_vits14.pth](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints/depth_anything_vits14.pth)
* diffusion_edge: Either [hr16/Diffusion-Edge/diffusion_edge_indoor.pt](https://huggingface.co/hr16/Diffusion-Edge/blob/main/diffusion_edge_indoor.pt), [hr16/Diffusion-Edge/diffusion_edge_urban.pt](https://huggingface.co/hr16/Diffusion-Edge/blob/main/diffusion_edge_urban.pt) or [hr16/Diffusion-Edge/diffusion_edge_natrual.pt](https://huggingface.co/hr16/Diffusion-Edge/blob/main/diffusion_edge_natrual.pt)
* unimatch: Either [hr16/Unimatch/gmflow-scale2-regrefine6-mixdata.pth](https://huggingface.co/hr16/Unimatch/blob/main/gmflow-scale2-regrefine6-mixdata.pth), [hr16/Unimatch/gmflow-scale2-mixdata.pth](https://huggingface.co/hr16/Unimatch/blob/main/gmflow-scale2-mixdata.pth) or [hr16/Unimatch/gmflow-scale1-mixdata.pth](https://huggingface.co/hr16/Unimatch/blob/main/gmflow-scale1-mixdata.pth)
* zoe_depth_anything: Either [LiheYoung/Depth-Anything/checkpoints_metric_depth/depth_anything_metric_depth_indoor.pt](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_metric_depth/depth_anything_metric_depth_indoor.pt) or [LiheYoung/Depth-Anything/checkpoints_metric_depth/depth_anything_metric_depth_outdoor.pt](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_metric_depth/depth_anything_metric_depth_outdoor.pt)
# 2000 Stars 😄
<a href="https://star-history.com/#Fannovel16/comfyui_controlnet_aux&Date">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Fannovel16/comfyui_controlnet_aux&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Fannovel16/comfyui_controlnet_aux&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Fannovel16/comfyui_controlnet_aux&type=Date" />
</picture>
</a>
Thanks for yalls supports. I never thought the graph for stars would be linear lol.

View File

@@ -0,0 +1,45 @@
* `AIO Aux Preprocessor` intergrating all loadable aux preprocessors as dropdown options. Easy to copy, paste and get the preprocessor faster.
* Added OpenPose-format JSON output from OpenPose Preprocessor and DWPose Preprocessor. Checks [here](#faces-and-poses).
* Fixed wrong model path when downloading DWPose.
* Make hint images less blurry.
* Added `resolution` option, `PixelPerfectResolution` and `HintImageEnchance` nodes (TODO: Documentation).
* Added `RAFT Optical Flow Embedder` for TemporalNet2 (TODO: Workflow example).
* Fixed opencv's conflicts between this extension, [ReActor](https://github.com/Gourieff/comfyui-reactor-node) and Roop. Thanks `Gourieff` for [the solution](https://github.com/Fannovel16/comfyui_controlnet_aux/issues/7#issuecomment-1734319075)!
* RAFT is removed as the code behind it doesn't match what what the original code does
* Changed `lineart`'s display name from `Normal Lineart` to `Realistic Lineart`. This change won't affect old workflows
* Added support for `onnxruntime` to speed-up DWPose (see the Q&A)
* Fixed TypeError: expected size to be one of int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], but got size with types [<class 'numpy.int64'>, <class 'numpy.int64'>]: [Issue](https://github.com/Fannovel16/comfyui_controlnet_aux/issues/2), [PR](https://github.com/Fannovel16/comfyui_controlnet_aux/pull/71))
* Fixed ImageGenResolutionFromImage mishape (https://github.com/Fannovel16/comfyui_controlnet_aux/pull/74)
* Fixed LeRes and MiDaS's incomatipility with MPS device
* Fixed checking DWPose onnxruntime session multiple times: https://github.com/Fannovel16/comfyui_controlnet_aux/issues/89)
* Added `Anime Face Segmentor` (in `ControlNet Preprocessors/Semantic Segmentation`) for [ControlNet AnimeFaceSegmentV2](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite#animefacesegmentv2). Checks [here](#anime-face-segmentor)
* Change download functions and fix [download error](https://github.com/Fannovel16/comfyui_controlnet_aux/issues/39): [PR](https://github.com/Fannovel16/comfyui_controlnet_aux/pull/96)
* Caching DWPose Onnxruntime during the first use of DWPose node instead of ComfyUI startup
* Added alternative YOLOX models for faster speed when using DWPose
* Added alternative DWPose models
* Implemented the preprocessor for [AnimalPose ControlNet](https://github.com/abehonest/ControlNet_AnimalPose/tree/main). Check [Animal Pose AP-10K](#animal-pose-ap-10k)
* Added YOLO-NAS models which are drop-in replacements of YOLOX
* Fixed Openpose Face/Hands no longer detecting: https://github.com/Fannovel16/comfyui_controlnet_aux/issues/54
* Added TorchScript implementation of DWPose and AnimalPose
* Added TorchScript implementation of DensePose from [Colab notebook](https://colab.research.google.com/drive/16hcaaKs210ivpxjoyGNuvEXZD4eqOOSQ) which doesn't require detectron2. [Example](#densepose). Thanks [@LayerNome](https://github.com/Layer-norm) for fixing bugs related.
* Added Standard Lineart Preprocessor
* Fixed OpenPose misplacements in some cases
* Added Mesh Graphormer - Hand Depth Map & Mask
* Misaligned hands bug from MeshGraphormer was fixed
* Added more mask options for MeshGraphormer
* Added Save Pose Keypoint node for editing
* Added Unimatch Optical Flow
* Added Depth Anything & Zoe Depth Anything
* Removed resolution field from Unimatch Optical Flow as that interpolating optical flow seems unstable
* Added TEED Soft-Edge Preprocessor
* Added DiffusionEdge
* Added Image Luminance and Image Intensity
* Added Normal DSINE
* Added TTPlanet Tile (09/05/2024, DD/MM/YYYY)
* Added AnyLine, Metric3D (18/05/2024)
* Added Depth Anything V2 (16/06/2024)
* Added Union model of ControlNet and preprocessors
![345832280-edf41dab-7619-494c-9f60-60ec1f8789cb](https://github.com/user-attachments/assets/aa55f57c-cad7-48e6-84d3-8f506d847989)
* Refactor INPUT_TYPES and add Execute All node during the process of learning [Execution Model Inversion](https://github.com/comfyanonymous/ComfyUI/pull/2666)
* Added scale_stick_for_xinsr_cn (https://github.com/Fannovel16/comfyui_controlnet_aux/issues/447) (09/04/2024)
* PyTorch 2.7 compatibility fixes - eliminated custom_timm, custom_detectron2, and custom_midas_repo dependencies causing hanging issues. Refactored 7 major preprocessors including OneFormer (now using HuggingFace transformers), ZOE, DSINE, MiDaS, BAE, Metric3D, and Uniformer. Resolved ~59 GitHub issues related to import failures, hanging, and extension conflicts. Full modernization to actively maintained packages.

View File

@@ -0,0 +1,224 @@
import sys, os
# Disable NPU device initialization and problematic MMCV ops to prevent RuntimeError
# Must be set BEFORE any MMCV imports happen anywhere in ComfyUI
os.environ['NPU_DEVICE_COUNT'] = '0'
os.environ['MMCV_WITH_OPS'] = '0'
from .utils import here, define_preprocessor_inputs, INPUT
from pathlib import Path
import traceback
import importlib
from .log import log, blue_text, cyan_text, get_summary, get_label
from .hint_image_enchance import NODE_CLASS_MAPPINGS as HIE_NODE_CLASS_MAPPINGS
from .hint_image_enchance import NODE_DISPLAY_NAME_MAPPINGS as HIE_NODE_DISPLAY_NAME_MAPPINGS
#Ref: https://github.com/comfyanonymous/ComfyUI/blob/76d53c4622fc06372975ed2a43ad345935b8a551/nodes.py#L17
sys.path.insert(0, str(Path(here, "src").resolve()))
for pkg_name in ["custom_controlnet_aux", "custom_mmpkg"]:
sys.path.append(str(Path(here, "src", pkg_name).resolve()))
#Enable CPU fallback for ops not being supported by MPS like upsample_bicubic2d.out
#https://github.com/pytorch/pytorch/issues/77764
#https://github.com/Fannovel16/comfyui_controlnet_aux/issues/2#issuecomment-1763579485
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = os.getenv("PYTORCH_ENABLE_MPS_FALLBACK", '1')
def load_nodes():
shorted_errors = []
full_error_messages = []
node_class_mappings = {}
node_display_name_mappings = {}
for filename in (here / "node_wrappers").iterdir():
module_name = filename.stem
if module_name.startswith('.'): continue #Skip hidden files created by the OS (e.g. [.DS_Store](https://en.wikipedia.org/wiki/.DS_Store))
try:
module = importlib.import_module(
f".node_wrappers.{module_name}", package=__package__
)
node_class_mappings.update(getattr(module, "NODE_CLASS_MAPPINGS"))
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS"):
node_display_name_mappings.update(getattr(module, "NODE_DISPLAY_NAME_MAPPINGS"))
log.debug(f"Imported {module_name} nodes")
except AttributeError:
pass # wip nodes
except Exception:
error_message = traceback.format_exc()
full_error_messages.append(error_message)
error_message = error_message.splitlines()[-1]
shorted_errors.append(
f"Failed to import module {module_name} because {error_message}"
)
if len(shorted_errors) > 0:
full_err_log = '\n\n'.join(full_error_messages)
print(f"\n\nFull error log from comfyui_controlnet_aux: \n{full_err_log}\n\n")
log.info(
f"Some nodes failed to load:\n\t"
+ "\n\t".join(shorted_errors)
+ "\n\n"
+ "Check that you properly installed the dependencies.\n"
+ "If you think this is a bug, please report it on the github page (https://github.com/Fannovel16/comfyui_controlnet_aux/issues)"
)
return node_class_mappings, node_display_name_mappings
AUX_NODE_MAPPINGS, AUX_DISPLAY_NAME_MAPPINGS = load_nodes()
#For nodes not mapping image to image or has special requirements
AIO_NOT_SUPPORTED = ["InpaintPreprocessor", "MeshGraphormer+ImpactDetector-DepthMapPreprocessor", "DiffusionEdge_Preprocessor"]
AIO_NOT_SUPPORTED += ["SavePoseKpsAsJsonFile", "FacialPartColoringFromPoseKps", "UpperBodyTrackingFromPoseKps", "RenderPeopleKps", "RenderAnimalKps"]
AIO_NOT_SUPPORTED += ["Unimatch_OptFlowPreprocessor", "MaskOptFlow"]
def preprocessor_options():
auxs = list(AUX_NODE_MAPPINGS.keys())
auxs.insert(0, "none")
for name in AIO_NOT_SUPPORTED:
if name in auxs:
auxs.remove(name)
return auxs
PREPROCESSOR_OPTIONS = preprocessor_options()
class AIO_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
preprocessor=INPUT.COMBO(PREPROCESSOR_OPTIONS, default="none"),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors"
def execute(self, preprocessor, image, resolution=512):
if preprocessor == "none":
return (image, )
else:
aux_class = AUX_NODE_MAPPINGS[preprocessor]
input_types = aux_class.INPUT_TYPES()
input_types = {
**input_types["required"],
**(input_types["optional"] if "optional" in input_types else {})
}
params = {}
for name, input_type in input_types.items():
if name == "image":
params[name] = image
continue
if name == "resolution":
params[name] = resolution
continue
if len(input_type) == 2 and ("default" in input_type[1]):
params[name] = input_type[1]["default"]
continue
default_values = { "INT": 0, "FLOAT": 0.0 }
if type(input_type[0]) is list:
for input_type_value in input_type[0]:
if input_type_value in default_values:
params[name] = default_values[input_type[0]]
else:
if input_type[0] in default_values:
params[name] = default_values[input_type[0]]
return getattr(aux_class(), aux_class.FUNCTION)(**params)
class ControlNetAuxSimpleAddText:
@classmethod
def INPUT_TYPES(s):
return dict(
required=dict(image=INPUT.IMAGE(), text=INPUT.STRING())
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors"
def execute(self, image, text):
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import torch
font = ImageFont.truetype(str((here / "NotoSans-Regular.ttf").resolve()), 40)
img = Image.fromarray(image[0].cpu().numpy().__mul__(255.).astype(np.uint8))
ImageDraw.Draw(img).text((0,0), text, fill=(0,255,0), font=font)
return (torch.from_numpy(np.array(img)).unsqueeze(0) / 255.,)
class ExecuteAllControlNetPreprocessors:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors"
def execute(self, image, resolution=512):
try:
from comfy_execution.graph_utils import GraphBuilder
except:
raise RuntimeError("ExecuteAllControlNetPreprocessor requries [Execution Model Inversion](https://github.com/comfyanonymous/ComfyUI/commit/5cfe38). Update ComfyUI/SwarmUI to get this feature")
graph = GraphBuilder()
curr_outputs = []
for preprocc in PREPROCESSOR_OPTIONS:
preprocc_node = graph.node("AIO_Preprocessor", preprocessor=preprocc, image=image, resolution=resolution)
hint_img = preprocc_node.out(0)
add_text_node = graph.node("ControlNetAuxSimpleAddText", image=hint_img, text=preprocc)
curr_outputs.append(add_text_node.out(0))
while len(curr_outputs) > 1:
_outputs = []
for i in range(0, len(curr_outputs), 2):
if i+1 < len(curr_outputs):
image_batch = graph.node("ImageBatch", image1=curr_outputs[i], image2=curr_outputs[i+1])
_outputs.append(image_batch.out(0))
else:
_outputs.append(curr_outputs[i])
curr_outputs = _outputs
return {
"result": (curr_outputs[0],),
"expand": graph.finalize(),
}
class ControlNetPreprocessorSelector:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"preprocessor": (PREPROCESSOR_OPTIONS,),
}
}
RETURN_TYPES = (PREPROCESSOR_OPTIONS,)
RETURN_NAMES = ("preprocessor",)
FUNCTION = "get_preprocessor"
CATEGORY = "ControlNet Preprocessors"
def get_preprocessor(self, preprocessor: str):
return (preprocessor,)
NODE_CLASS_MAPPINGS = {
**AUX_NODE_MAPPINGS,
"AIO_Preprocessor": AIO_Preprocessor,
"ControlNetPreprocessorSelector": ControlNetPreprocessorSelector,
**HIE_NODE_CLASS_MAPPINGS,
"ExecuteAllControlNetPreprocessors": ExecuteAllControlNetPreprocessors,
"ControlNetAuxSimpleAddText": ControlNetAuxSimpleAddText
}
NODE_DISPLAY_NAME_MAPPINGS = {
**AUX_DISPLAY_NAME_MAPPINGS,
"AIO_Preprocessor": "AIO Aux Preprocessor",
"ControlNetPreprocessorSelector": "Preprocessor Selector",
**HIE_NODE_DISPLAY_NAME_MAPPINGS,
"ExecuteAllControlNetPreprocessors": "Execute All ControlNet Preprocessors"
}

View File

@@ -0,0 +1,20 @@
# this is an example for config.yaml file, you can rename it to config.yaml if you want to use it
# ###############################################################################################
# This path is for custom pressesor models base folder. default is "./ckpts"
# you can also use absolute paths like: "/root/ComfyUI/custom_nodes/comfyui_controlnet_aux/ckpts" or "D:\\ComfyUI\\custom_nodes\\comfyui_controlnet_aux\\ckpts"
annotator_ckpts_path: "./ckpts"
# ###############################################################################################
# This path is for downloading temporary files.
# You SHOULD use absolute path for this like"D:\\temp", DO NOT use relative paths. Empty for default.
custom_temp_path:
# ###############################################################################################
# if you already have downloaded ckpts via huggingface hub into default cache path like: ~/.cache/huggingface/hub, you can set this True to use symlinks to save space
USE_SYMLINKS: False
# ###############################################################################################
# EP_list is a list of execution providers for onnxruntime, if one of them is not available or not working well, you can delete that provider from here(config.yaml)
# you can find all available providers here: https://onnxruntime.ai/docs/execution-providers
# for example, if you have CUDA installed, you can set it to: ["CUDAExecutionProvider", "CPUExecutionProvider"]
# empty list or only keep ["CPUExecutionProvider"] means you use cv2.dnn.readNetFromONNX to load onnx models
# if your onnx models can only run on the CPU or have other issues, we recommend using pt model instead.
# default value is ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
EP_list: ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]

View File

@@ -0,0 +1,6 @@
from pathlib import Path
from utils import here
import sys
sys.path.append(str(Path(here, "src")))
from custom_controlnet_aux import *

Binary file not shown.

After

Width:  |  Height:  |  Size: 576 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.5 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 998 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 694 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 706 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 472 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 371 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 271 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 438 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 180 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 636 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 646 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 294 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 244 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 93 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 713 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 400 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 519 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 244 KiB

View File

@@ -0,0 +1,233 @@
from .log import log
from .utils import ResizeMode, safe_numpy
import numpy as np
import torch
import cv2
from .utils import get_unique_axis0
from .lvminthin import nake_nms, lvmin_thin
MAX_IMAGEGEN_RESOLUTION = 8192 #https://github.com/comfyanonymous/ComfyUI/blob/c910b4a01ca58b04e5d4ab4c747680b996ada02b/nodes.py#L42
RESIZE_MODES = [ResizeMode.RESIZE.value, ResizeMode.INNER_FIT.value, ResizeMode.OUTER_FIT.value]
#Port from https://github.com/Mikubill/sd-webui-controlnet/blob/e67e017731aad05796b9615dc6eadce911298ea1/internal_controlnet/external_code.py#L89
class PixelPerfectResolution:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"original_image": ("IMAGE", ),
"image_gen_width": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
"image_gen_height": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
#https://github.com/comfyanonymous/ComfyUI/blob/c910b4a01ca58b04e5d4ab4c747680b996ada02b/nodes.py#L854
"resize_mode": (RESIZE_MODES, {"default": ResizeMode.RESIZE.value})
}
}
RETURN_TYPES = ("INT",)
RETURN_NAMES = ("RESOLUTION (INT)", )
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors"
def execute(self, original_image, image_gen_width, image_gen_height, resize_mode):
_, raw_H, raw_W, _ = original_image.shape
k0 = float(image_gen_height) / float(raw_H)
k1 = float(image_gen_width) / float(raw_W)
if resize_mode == ResizeMode.OUTER_FIT.value:
estimation = min(k0, k1) * float(min(raw_H, raw_W))
else:
estimation = max(k0, k1) * float(min(raw_H, raw_W))
log.debug(f"Pixel Perfect Computation:")
log.debug(f"resize_mode = {resize_mode}")
log.debug(f"raw_H = {raw_H}")
log.debug(f"raw_W = {raw_W}")
log.debug(f"target_H = {image_gen_height}")
log.debug(f"target_W = {image_gen_width}")
log.debug(f"estimation = {estimation}")
return (int(np.round(estimation)), )
class HintImageEnchance:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"hint_image": ("IMAGE", ),
"image_gen_width": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
"image_gen_height": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
#https://github.com/comfyanonymous/ComfyUI/blob/c910b4a01ca58b04e5d4ab4c747680b996ada02b/nodes.py#L854
"resize_mode": (RESIZE_MODES, {"default": ResizeMode.RESIZE.value})
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors"
def execute(self, hint_image, image_gen_width, image_gen_height, resize_mode):
outs = []
for single_hint_image in hint_image:
np_hint_image = np.asarray(single_hint_image * 255., dtype=np.uint8)
if resize_mode == ResizeMode.RESIZE.value:
np_hint_image = self.execute_resize(np_hint_image, image_gen_width, image_gen_height)
elif resize_mode == ResizeMode.OUTER_FIT.value:
np_hint_image = self.execute_outer_fit(np_hint_image, image_gen_width, image_gen_height)
else:
np_hint_image = self.execute_inner_fit(np_hint_image, image_gen_width, image_gen_height)
outs.append(torch.from_numpy(np_hint_image.astype(np.float32) / 255.0))
return (torch.stack(outs, dim=0),)
def execute_resize(self, detected_map, w, h):
detected_map = self.high_quality_resize(detected_map, (w, h))
detected_map = safe_numpy(detected_map)
return detected_map
def execute_outer_fit(self, detected_map, w, h):
old_h, old_w, _ = detected_map.shape
old_w = float(old_w)
old_h = float(old_h)
k0 = float(h) / old_h
k1 = float(w) / old_w
safeint = lambda x: int(np.round(x))
k = min(k0, k1)
borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0)
high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype)
if len(high_quality_border_color) == 4:
# Inpaint hijack
high_quality_border_color[3] = 255
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
detected_map = self.high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = detected_map.shape
pad_h = max(0, (h - new_h) // 2)
pad_w = max(0, (w - new_w) // 2)
high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = detected_map
detected_map = high_quality_background
detected_map = safe_numpy(detected_map)
return detected_map
def execute_inner_fit(self, detected_map, w, h):
old_h, old_w, _ = detected_map.shape
old_w = float(old_w)
old_h = float(old_h)
k0 = float(h) / old_h
k1 = float(w) / old_w
safeint = lambda x: int(np.round(x))
k = max(k0, k1)
detected_map = self.high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = detected_map.shape
pad_h = max(0, (new_h - h) // 2)
pad_w = max(0, (new_w - w) // 2)
detected_map = detected_map[pad_h:pad_h+h, pad_w:pad_w+w]
detected_map = safe_numpy(detected_map)
return detected_map
def high_quality_resize(self, x, size):
# Written by lvmin
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
inpaint_mask = None
if x.ndim == 3 and x.shape[2] == 4:
inpaint_mask = x[:, :, 3]
x = x[:, :, 0:3]
if x.shape[0] != size[1] or x.shape[1] != size[0]:
new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
unique_color_count = len(get_unique_axis0(x.reshape(-1, x.shape[2])))
is_one_pixel_edge = False
is_binary = False
if unique_color_count == 2:
is_binary = np.min(x) < 16 and np.max(x) > 240
if is_binary:
xc = x
xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
one_pixel_edge_count = np.where(xc < x)[0].shape[0]
all_edge_count = np.where(x > 127)[0].shape[0]
is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
if 2 < unique_color_count < 200:
interpolation = cv2.INTER_NEAREST
elif new_size_is_smaller:
interpolation = cv2.INTER_AREA
else:
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
y = cv2.resize(x, size, interpolation=interpolation)
if inpaint_mask is not None:
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
if is_binary:
y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
if is_one_pixel_edge:
y = nake_nms(y)
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = lvmin_thin(y, prunings=new_size_is_bigger)
else:
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = np.stack([y] * 3, axis=2)
else:
y = x
if inpaint_mask is not None:
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
y = np.concatenate([y, inpaint_mask], axis=2)
return y
class ImageGenResolutionFromLatent:
@classmethod
def INPUT_TYPES(s):
return {
"required": { "latent": ("LATENT", ) }
}
RETURN_TYPES = ("INT", "INT")
RETURN_NAMES = ("IMAGE_GEN_WIDTH (INT)", "IMAGE_GEN_HEIGHT (INT)")
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors"
def execute(self, latent):
_, _, H, W = latent["samples"].shape
return (W * 8, H * 8)
class ImageGenResolutionFromImage:
@classmethod
def INPUT_TYPES(s):
return {
"required": { "image": ("IMAGE", ) }
}
RETURN_TYPES = ("INT", "INT")
RETURN_NAMES = ("IMAGE_GEN_WIDTH (INT)", "IMAGE_GEN_HEIGHT (INT)")
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors"
def execute(self, image):
_, H, W, _ = image.shape
return (W, H)
NODE_CLASS_MAPPINGS = {
"PixelPerfectResolution": PixelPerfectResolution,
"ImageGenResolutionFromImage": ImageGenResolutionFromImage,
"ImageGenResolutionFromLatent": ImageGenResolutionFromLatent,
"HintImageEnchance": HintImageEnchance
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PixelPerfectResolution": "Pixel Perfect Resolution",
"ImageGenResolutionFromImage": "Generation Resolution From Image",
"ImageGenResolutionFromLatent": "Generation Resolution From Latent",
"HintImageEnchance": "Enchance And Resize Hint Images"
}

View File

@@ -0,0 +1,20 @@
@echo off
set "requirements_txt=%~dp0\requirements.txt"
set "python_exec=..\..\..\python_embeded\python.exe"
echo Installing ComfyUI's ControlNet Auxiliary Preprocessors..
if exist "%python_exec%" (
echo Installing with ComfyUI Portable
for /f "delims=" %%i in (%requirements_txt%) do (
%python_exec% -s -m pip install "%%i"
)
) else (
echo Installing with system Python
for /f "delims=" %%i in (%requirements_txt%) do (
pip install "%%i"
)
)
pause

View File

@@ -0,0 +1,80 @@
#Cre: https://github.com/melMass/comfy_mtb/blob/main/log.py
import logging
import re
import os
base_log_level = logging.INFO
# Custom object that discards the output
class NullWriter:
def write(self, text):
pass
class Formatter(logging.Formatter):
grey = "\x1b[38;20m"
cyan = "\x1b[36;20m"
purple = "\x1b[35;20m"
yellow = "\x1b[33;20m"
red = "\x1b[31;20m"
bold_red = "\x1b[31;1m"
reset = "\x1b[0m"
# format = "%(asctime)s - [%(name)s] - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
format = "[%(name)s] | %(levelname)s -> %(message)s"
FORMATS = {
logging.DEBUG: purple + format + reset,
logging.INFO: cyan + format + reset,
logging.WARNING: yellow + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset,
}
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt)
return formatter.format(record)
def mklog(name, level=base_log_level):
logger = logging.getLogger(name)
logger.setLevel(level)
for handler in logger.handlers:
logger.removeHandler(handler)
ch = logging.StreamHandler()
ch.setLevel(level)
ch.setFormatter(Formatter())
logger.addHandler(ch)
# Disable log propagation
logger.propagate = False
return logger
# - The main app logger
log = mklog(__package__, base_log_level)
def log_user(arg):
print("\033[34mComfyUI ControlNet AUX:\033[0m {arg}")
def get_summary(docstring):
return docstring.strip().split("\n\n", 1)[0]
def blue_text(text):
return f"\033[94m{text}\033[0m"
def cyan_text(text):
return f"\033[96m{text}\033[0m"
def get_label(label):
words = re.findall(r"(?:^|[A-Z])[a-z]*", label)
return " ".join(words).strip()

View File

@@ -0,0 +1,87 @@
# High Quality Edge Thinning using Pure Python
# Written by Lvmin Zhang
# 2023 April
# Stanford University
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
import cv2
import numpy as np
lvmin_kernels_raw = [
np.array([
[-1, -1, -1],
[0, 1, 0],
[1, 1, 1]
], dtype=np.int32),
np.array([
[0, -1, -1],
[1, 1, -1],
[0, 1, 0]
], dtype=np.int32)
]
lvmin_kernels = []
lvmin_kernels += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_prunings_raw = [
np.array([
[-1, -1, -1],
[-1, 1, -1],
[0, 0, -1]
], dtype=np.int32),
np.array([
[-1, -1, -1],
[-1, 1, -1],
[-1, 0, 0]
], dtype=np.int32)
]
lvmin_prunings = []
lvmin_prunings += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_prunings_raw]
def remove_pattern(x, kernel):
objects = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel)
objects = np.where(objects > 127)
x[objects] = 0
return x, objects[0].shape[0] > 0
def thin_one_time(x, kernels):
y = x
is_done = True
for k in kernels:
y, has_update = remove_pattern(y, k)
if has_update:
is_done = False
return y, is_done
def lvmin_thin(x, prunings=True):
y = x
for i in range(32):
y, is_done = thin_one_time(y, lvmin_kernels)
if is_done:
break
if prunings:
y, _ = thin_one_time(y, lvmin_prunings)
return y
def nake_nms(x):
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(x)
for f in [f1, f2, f3, f4]:
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
return y

View File

@@ -0,0 +1,43 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
import torch
from einops import rearrange
class AnimeFace_SemSegPreprocessor:
@classmethod
def INPUT_TYPES(s):
#This preprocessor is only trained on 512x resolution
#https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/predict.py#L25
return define_preprocessor_inputs(
remove_background_using_abg=INPUT.BOOLEAN(True),
resolution=INPUT.RESOLUTION(default=512, min=512, max=512)
)
RETURN_TYPES = ("IMAGE", "MASK")
RETURN_NAMES = ("IMAGE", "ABG_CHARACTER_MASK (MASK)")
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"
def execute(self, image, remove_background_using_abg=True, resolution=512, **kwargs):
from custom_controlnet_aux.anime_face_segment import AnimeFaceSegmentor
model = AnimeFaceSegmentor.from_pretrained().to(model_management.get_torch_device())
if remove_background_using_abg:
out_image_with_mask = common_annotator_call(model, image, resolution=resolution, remove_background=True)
out_image = out_image_with_mask[..., :3]
mask = out_image_with_mask[..., 3:]
mask = rearrange(mask, "n h w c -> n c h w")
else:
out_image = common_annotator_call(model, image, resolution=resolution, remove_background=False)
N, H, W, C = out_image.shape
mask = torch.ones(N, C, H, W)
del model
return (out_image, mask)
NODE_CLASS_MAPPINGS = {
"AnimeFace_SemSegPreprocessor": AnimeFace_SemSegPreprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"AnimeFace_SemSegPreprocessor": "Anime Face Segmentor"
}

View File

@@ -0,0 +1,87 @@
import torch
import numpy as np
import comfy.model_management as model_management
import comfy.utils
# Requires comfyui_controlnet_aux funcsions and classes
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
def get_intensity_mask(image_array, lower_bound, upper_bound):
mask = image_array[:, :, 0]
mask = np.where((mask >= lower_bound) & (mask <= upper_bound), mask, 0)
mask = np.expand_dims(mask, 2).repeat(3, axis=2)
return mask
def combine_layers(base_layer, top_layer):
mask = top_layer.astype(bool)
temp = 1 - (1 - top_layer) * (1 - base_layer)
result = base_layer * (~mask) + temp * mask
return result
class AnyLinePreprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
merge_with_lineart=INPUT.COMBO(["lineart_standard", "lineart_realisitic", "lineart_anime", "manga_line"], default="lineart_standard"),
resolution=INPUT.RESOLUTION(default=1280, step=8),
lineart_lower_bound=INPUT.FLOAT(default=0),
lineart_upper_bound=INPUT.FLOAT(default=1),
object_min_size=INPUT.INT(default=36, min=1),
object_connectivity=INPUT.INT(default=1, min=1)
)
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
FUNCTION = "get_anyline"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def __init__(self):
self.device = model_management.get_torch_device()
def get_anyline(self, image, merge_with_lineart="lineart_standard", resolution=512, lineart_lower_bound=0, lineart_upper_bound=1, object_min_size=36, object_connectivity=1):
from custom_controlnet_aux.teed import TEDDetector
from skimage import morphology
pbar = comfy.utils.ProgressBar(3)
# Process the image with MTEED model
mteed_model = TEDDetector.from_pretrained("TheMistoAI/MistoLine", "MTEED.pth", subfolder="Anyline").to(self.device)
mteed_result = common_annotator_call(mteed_model, image, resolution=resolution, show_pbar=False)
mteed_result = mteed_result.numpy()
del mteed_model
pbar.update(1)
# Process the image with the lineart standard preprocessor
if merge_with_lineart == "lineart_standard":
from custom_controlnet_aux.lineart_standard import LineartStandardDetector
lineart_standard_detector = LineartStandardDetector()
lineart_result = common_annotator_call(lineart_standard_detector, image, guassian_sigma=2, intensity_threshold=3, resolution=resolution, show_pbar=False).numpy()
del lineart_standard_detector
else:
from custom_controlnet_aux.lineart import LineartDetector
from custom_controlnet_aux.lineart_anime import LineartAnimeDetector
from custom_controlnet_aux.manga_line import LineartMangaDetector
lineart_detector = dict(lineart_realisitic=LineartDetector, lineart_anime=LineartAnimeDetector, manga_line=LineartMangaDetector)[merge_with_lineart]
lineart_detector = lineart_detector.from_pretrained().to(self.device)
lineart_result = common_annotator_call(lineart_detector, image, resolution=resolution, show_pbar=False).numpy()
del lineart_detector
pbar.update(1)
final_result = []
for i in range(len(image)):
_lineart_result = get_intensity_mask(lineart_result[i], lower_bound=lineart_lower_bound, upper_bound=lineart_upper_bound)
_cleaned = morphology.remove_small_objects(_lineart_result.astype(bool), min_size=object_min_size, connectivity=object_connectivity)
_lineart_result = _lineart_result * _cleaned
_mteed_result = mteed_result[i]
# Combine the results
final_result.append(torch.from_numpy(combine_layers(_mteed_result, _lineart_result)))
pbar.update(1)
return (torch.stack(final_result),)
NODE_CLASS_MAPPINGS = {
"AnyLineArtPreprocessor_aux": AnyLinePreprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"AnyLineArtPreprocessor_aux": "AnyLine Lineart"
}

View File

@@ -0,0 +1,29 @@
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
import comfy.model_management as model_management
class Binary_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
bin_threshold=INPUT.INT(default=100, max=255),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, bin_threshold=100, resolution=512, **kwargs):
from custom_controlnet_aux.binary import BinaryDetector
return (common_annotator_call(BinaryDetector(), image, bin_threshold=bin_threshold, resolution=resolution), )
NODE_CLASS_MAPPINGS = {
"BinaryPreprocessor": Binary_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"BinaryPreprocessor": "Binary Lines"
}

View File

@@ -0,0 +1,30 @@
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
import comfy.model_management as model_management
class Canny_Edge_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
low_threshold=INPUT.INT(default=100, max=255),
high_threshold=INPUT.INT(default=200, max=255),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, low_threshold=100, high_threshold=200, resolution=512, **kwargs):
from custom_controlnet_aux.canny import CannyDetector
return (common_annotator_call(CannyDetector(), image, low_threshold=low_threshold, high_threshold=high_threshold, resolution=resolution), )
NODE_CLASS_MAPPINGS = {
"CannyEdgePreprocessor": Canny_Edge_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CannyEdgePreprocessor": "Canny Edge"
}

View File

@@ -0,0 +1,26 @@
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
import comfy.model_management as model_management
class Color_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/T2IAdapter-only"
def execute(self, image, resolution=512, **kwargs):
from custom_controlnet_aux.color import ColorDetector
return (common_annotator_call(ColorDetector(), image, resolution=resolution), )
NODE_CLASS_MAPPINGS = {
"ColorPreprocessor": Color_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ColorPreprocessor": "Color Pallete"
}

View File

@@ -0,0 +1,31 @@
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
import comfy.model_management as model_management
class DensePose_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
model=INPUT.COMBO(["densepose_r50_fpn_dl.torchscript", "densepose_r101_fpn_dl.torchscript"]),
cmap=INPUT.COMBO(["Viridis (MagicAnimate)", "Parula (CivitAI)"]),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"
def execute(self, image, model="densepose_r50_fpn_dl.torchscript", cmap="Viridis (MagicAnimate)", resolution=512):
from custom_controlnet_aux.densepose import DenseposeDetector
model = DenseposeDetector \
.from_pretrained(filename=model) \
.to(model_management.get_torch_device())
return (common_annotator_call(model, image, cmap="viridis" if "Viridis" in cmap else "parula", resolution=resolution), )
NODE_CLASS_MAPPINGS = {
"DensePosePreprocessor": DensePose_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DensePosePreprocessor": "DensePose Estimator"
}

View File

@@ -0,0 +1,55 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class Depth_Anything_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
ckpt_name=INPUT.COMBO(
["depth_anything_vitl14.pth", "depth_anything_vitb14.pth", "depth_anything_vits14.pth"]
),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, ckpt_name="depth_anything_vitl14.pth", resolution=512, **kwargs):
from custom_controlnet_aux.depth_anything import DepthAnythingDetector
model = DepthAnythingDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution)
del model
return (out, )
class Zoe_Depth_Anything_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
environment=INPUT.COMBO(["indoor", "outdoor"]),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, environment="indoor", resolution=512, **kwargs):
from custom_controlnet_aux.zoe import ZoeDepthAnythingDetector
ckpt_name = "depth_anything_metric_depth_indoor.pt" if environment == "indoor" else "depth_anything_metric_depth_outdoor.pt"
model = ZoeDepthAnythingDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution)
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"DepthAnythingPreprocessor": Depth_Anything_Preprocessor,
"Zoe_DepthAnythingPreprocessor": Zoe_Depth_Anything_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DepthAnythingPreprocessor": "Depth Anything",
"Zoe_DepthAnythingPreprocessor": "Zoe Depth Anything"
}

View File

@@ -0,0 +1,56 @@
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
import comfy.model_management as model_management
class Depth_Anything_V2_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
ckpt_name=INPUT.COMBO(
["depth_anything_v2_vitg.pth", "depth_anything_v2_vitl.pth", "depth_anything_v2_vitb.pth", "depth_anything_v2_vits.pth"],
default="depth_anything_v2_vitl.pth"
),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, ckpt_name="depth_anything_v2_vitl.pth", resolution=512, **kwargs):
from custom_controlnet_aux.depth_anything_v2 import DepthAnythingV2Detector
model = DepthAnythingV2Detector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution, max_depth=1)
del model
return (out, )
""" class Depth_Anything_Metric_V2_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return create_node_input_types(
environment=(["indoor", "outdoor"], {"default": "indoor"}),
max_depth=("FLOAT", {"min": 0, "max": 100, "default": 20.0, "step": 0.01})
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, environment, resolution=512, max_depth=20.0, **kwargs):
from custom_controlnet_aux.depth_anything_v2 import DepthAnythingV2Detector
filename = dict(indoor="depth_anything_v2_metric_hypersim_vitl.pth", outdoor="depth_anything_v2_metric_vkitti_vitl.pth")[environment]
model = DepthAnythingV2Detector.from_pretrained(filename=filename).to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution, max_depth=max_depth)
del model
return (out, ) """
NODE_CLASS_MAPPINGS = {
"DepthAnythingV2Preprocessor": Depth_Anything_V2_Preprocessor,
#"Metric_DepthAnythingV2Preprocessor": Depth_Anything_Metric_V2_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DepthAnythingV2Preprocessor": "Depth Anything V2 - Relative",
#"Metric_DepthAnythingV2Preprocessor": "Depth Anything V2 - Metric"
}

View File

@@ -0,0 +1,41 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, run_script
import comfy.model_management as model_management
import sys
def install_deps():
try:
import sklearn
except:
run_script([sys.executable, '-s', '-m', 'pip', 'install', 'scikit-learn'])
class DiffusionEdge_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
environment=INPUT.COMBO(["indoor", "urban", "natrual"]),
patch_batch_size=INPUT.INT(default=4, min=1, max=16),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, environment="indoor", patch_batch_size=4, resolution=512, **kwargs):
install_deps()
from custom_controlnet_aux.diffusion_edge import DiffusionEdgeDetector
model = DiffusionEdgeDetector \
.from_pretrained(filename = f"diffusion_edge_{environment}.pt") \
.to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution, patch_batch_size=patch_batch_size)
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"DiffusionEdge_Preprocessor": DiffusionEdge_Preprocessor,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DiffusionEdge_Preprocessor": "Diffusion Edge (batch size ↑ => speed ↑, VRAM ↑)",
}

View File

@@ -0,0 +1,31 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class DSINE_Normal_Map_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
fov=INPUT.FLOAT(max=365.0, default=60.0),
iterations=INPUT.INT(min=1, max=20, default=5),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, fov=60.0, iterations=5, resolution=512, **kwargs):
from custom_controlnet_aux.dsine import DsineDetector
model = DsineDetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, fov=fov, iterations=iterations, resolution=resolution)
del model
return (out,)
NODE_CLASS_MAPPINGS = {
"DSINE-NormalMapPreprocessor": DSINE_Normal_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DSINE-NormalMapPreprocessor": "DSINE Normal Map"
}

View File

@@ -0,0 +1,166 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
import numpy as np
import warnings
from ..src.custom_controlnet_aux.dwpose import DwposeDetector, AnimalposeDetector
import os
import json
DWPOSE_MODEL_NAME = "yzd-v/DWPose"
#Trigger startup caching for onnxruntime
GPU_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"]
def check_ort_gpu():
try:
import onnxruntime as ort
for provider in GPU_PROVIDERS:
if provider in ort.get_available_providers():
return True
return False
except:
return False
if not os.environ.get("DWPOSE_ONNXRT_CHECKED"):
if check_ort_gpu():
print("DWPose: Onnxruntime with acceleration providers detected")
else:
warnings.warn("DWPose: Onnxruntime not found or doesn't come with acceleration providers, switch to OpenCV with CPU device. DWPose might run very slowly")
os.environ['AUX_ORT_PROVIDERS'] = ''
os.environ["DWPOSE_ONNXRT_CHECKED"] = '1'
class DWPose_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
detect_hand=INPUT.COMBO(["enable", "disable"]),
detect_body=INPUT.COMBO(["enable", "disable"]),
detect_face=INPUT.COMBO(["enable", "disable"]),
resolution=INPUT.RESOLUTION(),
bbox_detector=INPUT.COMBO(
["None"] + ["yolox_l.torchscript.pt", "yolox_l.onnx", "yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"],
default="yolox_l.onnx"
),
pose_estimator=INPUT.COMBO(
["dw-ll_ucoco_384_bs5.torchscript.pt", "dw-ll_ucoco_384.onnx", "dw-ll_ucoco.onnx"],
default="dw-ll_ucoco_384_bs5.torchscript.pt"
),
scale_stick_for_xinsr_cn=INPUT.COMBO(["disable", "enable"])
)
RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
FUNCTION = "estimate_pose"
CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"
def estimate_pose(self, image, detect_hand="enable", detect_body="enable", detect_face="enable", resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="dw-ll_ucoco_384.onnx", scale_stick_for_xinsr_cn="disable", **kwargs):
if bbox_detector == "None":
yolo_repo = DWPOSE_MODEL_NAME
elif bbox_detector == "yolox_l.onnx":
yolo_repo = DWPOSE_MODEL_NAME
elif "yolox" in bbox_detector:
yolo_repo = "hr16/yolox-onnx"
elif "yolo_nas" in bbox_detector:
yolo_repo = "hr16/yolo-nas-fp16"
else:
raise NotImplementedError(f"Download mechanism for {bbox_detector}")
if pose_estimator == "dw-ll_ucoco_384.onnx":
pose_repo = DWPOSE_MODEL_NAME
elif pose_estimator.endswith(".onnx"):
pose_repo = "hr16/UnJIT-DWPose"
elif pose_estimator.endswith(".torchscript.pt"):
pose_repo = "hr16/DWPose-TorchScript-BatchSize5"
else:
raise NotImplementedError(f"Download mechanism for {pose_estimator}")
model = DwposeDetector.from_pretrained(
pose_repo,
yolo_repo,
det_filename=(None if bbox_detector == "None" else bbox_detector), pose_filename=pose_estimator,
torchscript_device=model_management.get_torch_device()
)
detect_hand = detect_hand == "enable"
detect_body = detect_body == "enable"
detect_face = detect_face == "enable"
scale_stick_for_xinsr_cn = scale_stick_for_xinsr_cn == "enable"
self.openpose_dicts = []
def func(image, **kwargs):
pose_img, openpose_dict = model(image, **kwargs)
self.openpose_dicts.append(openpose_dict)
return pose_img
out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, resolution=resolution, xinsr_stick_scaling=scale_stick_for_xinsr_cn)
del model
return {
'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
"result": (out, self.openpose_dicts)
}
class AnimalPose_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
bbox_detector = INPUT.COMBO(
["None"] + ["yolox_l.torchscript.pt", "yolox_l.onnx", "yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"],
default="yolox_l.torchscript.pt"
),
pose_estimator = INPUT.COMBO(
["rtmpose-m_ap10k_256_bs5.torchscript.pt", "rtmpose-m_ap10k_256.onnx"],
default="rtmpose-m_ap10k_256_bs5.torchscript.pt"
),
resolution = INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
FUNCTION = "estimate_pose"
CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"
def estimate_pose(self, image, resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="rtmpose-m_ap10k_256.onnx", **kwargs):
if bbox_detector == "None":
yolo_repo = DWPOSE_MODEL_NAME
elif bbox_detector == "yolox_l.onnx":
yolo_repo = DWPOSE_MODEL_NAME
elif "yolox" in bbox_detector:
yolo_repo = "hr16/yolox-onnx"
elif "yolo_nas" in bbox_detector:
yolo_repo = "hr16/yolo-nas-fp16"
else:
raise NotImplementedError(f"Download mechanism for {bbox_detector}")
if pose_estimator == "dw-ll_ucoco_384.onnx":
pose_repo = DWPOSE_MODEL_NAME
elif pose_estimator.endswith(".onnx"):
pose_repo = "hr16/UnJIT-DWPose"
elif pose_estimator.endswith(".torchscript.pt"):
pose_repo = "hr16/DWPose-TorchScript-BatchSize5"
else:
raise NotImplementedError(f"Download mechanism for {pose_estimator}")
model = AnimalposeDetector.from_pretrained(
pose_repo,
yolo_repo,
det_filename=(None if bbox_detector == "None" else bbox_detector), pose_filename=pose_estimator,
torchscript_device=model_management.get_torch_device()
)
self.openpose_dicts = []
def func(image, **kwargs):
pose_img, openpose_dict = model(image, **kwargs)
self.openpose_dicts.append(openpose_dict)
return pose_img
out = common_annotator_call(func, image, image_and_json=True, resolution=resolution)
del model
return {
'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
"result": (out, self.openpose_dicts)
}
NODE_CLASS_MAPPINGS = {
"DWPreprocessor": DWPose_Preprocessor,
"AnimalPosePreprocessor": AnimalPose_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DWPreprocessor": "DWPose Estimator",
"AnimalPosePreprocessor": "AnimalPose Estimator (AP10K)"
}

View File

@@ -0,0 +1,53 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class HED_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
safe=INPUT.COMBO(["enable", "disable"]),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, resolution=512, **kwargs):
from custom_controlnet_aux.hed import HEDdetector
model = HEDdetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution, safe = kwargs["safe"] == "enable")
del model
return (out, )
class Fake_Scribble_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
safe=INPUT.COMBO(["enable", "disable"]),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, resolution=512, **kwargs):
from custom_controlnet_aux.hed import HEDdetector
model = HEDdetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution, scribble=True, safe=kwargs["safe"]=="enable")
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"HEDPreprocessor": HED_Preprocessor,
"FakeScribblePreprocessor": Fake_Scribble_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"HEDPreprocessor": "HED Soft-Edge Lines",
"FakeScribblePreprocessor": "Fake Scribble Lines (aka scribble_hed)"
}

View File

@@ -0,0 +1,32 @@
import torch
from ..utils import INPUT
class InpaintPreprocessor:
@classmethod
def INPUT_TYPES(s):
return dict(
required=dict(image=INPUT.IMAGE(), mask=INPUT.MASK()),
optional=dict(black_pixel_for_xinsir_cn=INPUT.BOOLEAN(False))
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "preprocess"
CATEGORY = "ControlNet Preprocessors/others"
def preprocess(self, image, mask, black_pixel_for_xinsir_cn=False):
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(image.shape[1], image.shape[2]), mode="bilinear")
mask = mask.movedim(1,-1).expand((-1,-1,-1,3))
image = image.clone()
if black_pixel_for_xinsir_cn:
masked_pixel = 0.0
else:
masked_pixel = -1.0
image[mask > 0.5] = masked_pixel
return (image,)
NODE_CLASS_MAPPINGS = {
"InpaintPreprocessor": InpaintPreprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"InpaintPreprocessor": "Inpaint Preprocessor"
}

View File

@@ -0,0 +1,32 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class LERES_Depth_Map_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
rm_nearest=INPUT.FLOAT(max=100.0),
rm_background=INPUT.FLOAT(max=100.0),
boost=INPUT.COMBO(["disable", "enable"]),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, rm_nearest=0, rm_background=0, resolution=512, boost="disable", **kwargs):
from custom_controlnet_aux.leres import LeresDetector
model = LeresDetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution, thr_a=rm_nearest, thr_b=rm_background, boost=boost == "enable")
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"LeReS-DepthMapPreprocessor": LERES_Depth_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LeReS-DepthMapPreprocessor": "LeReS Depth Map (enable boost for leres++)"
}

View File

@@ -0,0 +1,30 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class LineArt_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
coarse=INPUT.COMBO((["disable", "enable"])),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, resolution=512, **kwargs):
from custom_controlnet_aux.lineart import LineartDetector
model = LineartDetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution, coarse = kwargs["coarse"] == "enable")
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"LineArtPreprocessor": LineArt_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LineArtPreprocessor": "Realistic Lineart"
}

View File

@@ -0,0 +1,27 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class AnimeLineArt_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, resolution=512, **kwargs):
from custom_controlnet_aux.lineart_anime import LineartAnimeDetector
model = LineartAnimeDetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution)
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"AnimeLineArtPreprocessor": AnimeLineArt_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"AnimeLineArtPreprocessor": "Anime Lineart"
}

View File

@@ -0,0 +1,27 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class Lineart_Standard_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
guassian_sigma=INPUT.FLOAT(default=6.0, max=100.0),
intensity_threshold=INPUT.INT(default=8, max=16),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, guassian_sigma=6, intensity_threshold=8, resolution=512, **kwargs):
from custom_controlnet_aux.lineart_standard import LineartStandardDetector
return (common_annotator_call(LineartStandardDetector(), image, guassian_sigma=guassian_sigma, intensity_threshold=intensity_threshold, resolution=resolution), )
NODE_CLASS_MAPPINGS = {
"LineartStandardPreprocessor": Lineart_Standard_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LineartStandardPreprocessor": "Standard Lineart"
}

View File

@@ -0,0 +1,27 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class Manga2Anime_LineArt_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, resolution=512, **kwargs):
from custom_controlnet_aux.manga_line import LineartMangaDetector
model = LineartMangaDetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution)
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"Manga2Anime_LineArt_Preprocessor": Manga2Anime_LineArt_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Manga2Anime_LineArt_Preprocessor": "Manga Lineart (aka lineart_anime_denoise)"
}

View File

@@ -0,0 +1,39 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, run_script
import comfy.model_management as model_management
import os, sys
import subprocess, threading
def install_deps():
try:
import mediapipe
except ImportError:
run_script([sys.executable, '-s', '-m', 'pip', 'install', 'mediapipe'])
run_script([sys.executable, '-s', '-m', 'pip', 'install', '--upgrade', 'protobuf'])
class Media_Pipe_Face_Mesh_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
max_faces=INPUT.INT(default=10, min=1, max=50), #Which image has more than 50 detectable faces?
min_confidence=INPUT.FLOAT(default=0.5, min=0.1),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "detect"
CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"
def detect(self, image, max_faces=10, min_confidence=0.5, resolution=512):
#Ref: https://github.com/Fannovel16/comfy_controlnet_preprocessors/issues/70#issuecomment-1677967369
install_deps()
from custom_controlnet_aux.mediapipe_face import MediapipeFaceDetector
return (common_annotator_call(MediapipeFaceDetector(), image, max_faces=max_faces, min_confidence=min_confidence, resolution=resolution), )
NODE_CLASS_MAPPINGS = {
"MediaPipe-FaceMeshPreprocessor": Media_Pipe_Face_Mesh_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"MediaPipe-FaceMeshPreprocessor": "MediaPipe Face Mesh"
}

View File

@@ -0,0 +1,158 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, MAX_RESOLUTION, run_script
import comfy.model_management as model_management
import numpy as np
import torch
from einops import rearrange
import os, sys
import subprocess, threading
import scipy.ndimage
import cv2
import torch.nn.functional as F
def install_deps():
try:
import mediapipe
except ImportError:
run_script([sys.executable, '-s', '-m', 'pip', 'install', 'mediapipe'])
run_script([sys.executable, '-s', '-m', 'pip', 'install', '--upgrade', 'protobuf'])
try:
import trimesh
except ImportError:
run_script([sys.executable, '-s', '-m', 'pip', 'install', 'trimesh[easy]'])
#Sauce: https://github.com/comfyanonymous/ComfyUI/blob/8c6493578b3dda233e9b9a953feeaf1e6ca434ad/comfy_extras/nodes_mask.py#L309
def expand_mask(mask, expand, tapered_corners):
c = 0 if tapered_corners else 1
kernel = np.array([[c, 1, c],
[1, 1, 1],
[c, 1, c]])
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
out = []
for m in mask:
output = m.numpy()
for _ in range(abs(expand)):
if expand < 0:
output = scipy.ndimage.grey_erosion(output, footprint=kernel)
else:
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
output = torch.from_numpy(output)
out.append(output)
return torch.stack(out, dim=0)
class Mesh_Graphormer_Depth_Map_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
mask_bbox_padding=("INT", {"default": 30, "min": 0, "max": 100}),
resolution=INPUT.RESOLUTION(),
mask_type=INPUT.COMBO(["based_on_depth", "tight_bboxes", "original"]),
mask_expand=INPUT.INT(default=5, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
rand_seed=INPUT.INT(default=88, min=0, max=0xffffffffffffffff),
detect_thr=INPUT.FLOAT(default=0.6, min=0.1),
presence_thr=INPUT.FLOAT(default=0.6, min=0.1)
)
RETURN_TYPES = ("IMAGE", "MASK")
RETURN_NAMES = ("IMAGE", "INPAINTING_MASK")
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, mask_bbox_padding=30, mask_type="based_on_depth", mask_expand=5, resolution=512, rand_seed=88, detect_thr=0.6, presence_thr=0.6, **kwargs):
install_deps()
from custom_controlnet_aux.mesh_graphormer import MeshGraphormerDetector
model = kwargs["model"] if "model" in kwargs \
else MeshGraphormerDetector.from_pretrained(detect_thr=detect_thr, presence_thr=presence_thr).to(model_management.get_torch_device())
depth_map_list = []
mask_list = []
for single_image in image:
np_image = np.asarray(single_image.cpu() * 255., dtype=np.uint8)
depth_map, mask, info = model(np_image, output_type="np", detect_resolution=resolution, mask_bbox_padding=mask_bbox_padding, seed=rand_seed)
if mask_type == "based_on_depth":
H, W = mask.shape[:2]
mask = cv2.resize(depth_map.copy(), (W, H))
mask[mask > 0] = 255
elif mask_type == "tight_bboxes":
mask = np.zeros_like(mask)
hand_bboxes = (info or {}).get("abs_boxes") or []
for hand_bbox in hand_bboxes:
x_min, x_max, y_min, y_max = hand_bbox
mask[y_min:y_max+1, x_min:x_max+1, :] = 255 #HWC
mask = mask[:, :, :1]
depth_map_list.append(torch.from_numpy(depth_map.astype(np.float32) / 255.0))
mask_list.append(torch.from_numpy(mask.astype(np.float32) / 255.0))
depth_maps, masks = torch.stack(depth_map_list, dim=0), rearrange(torch.stack(mask_list, dim=0), "n h w 1 -> n 1 h w")
return depth_maps, expand_mask(masks, mask_expand, tapered_corners=True)
def normalize_size_base_64(w, h):
short_side = min(w, h)
remainder = short_side % 64
return short_side - remainder + (64 if remainder > 0 else 0)
class Mesh_Graphormer_With_ImpactDetector_Depth_Map_Preprocessor:
@classmethod
def INPUT_TYPES(s):
types = define_preprocessor_inputs(
# Impact pack
bbox_threshold=INPUT.FLOAT(default=0.5, min=0.1),
bbox_dilation=INPUT.INT(default=10, min=-512, max=512),
bbox_crop_factor=INPUT.FLOAT(default=3.0, min=1.0, max=10.0),
drop_size=INPUT.INT(default=10, min=1, max=MAX_RESOLUTION),
# Mesh Graphormer
mask_bbox_padding=INPUT.INT(default=30, min=0, max=100),
mask_type=INPUT.COMBO(["based_on_depth", "tight_bboxes", "original"]),
mask_expand=INPUT.INT(default=5, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
rand_seed=INPUT.INT(default=88, min=0, max=0xffffffffffffffff),
resolution=INPUT.RESOLUTION()
)
types["required"]["bbox_detector"] = ("BBOX_DETECTOR", )
return types
RETURN_TYPES = ("IMAGE", "MASK")
RETURN_NAMES = ("IMAGE", "INPAINTING_MASK")
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, bbox_detector, bbox_threshold=0.5, bbox_dilation=10, bbox_crop_factor=3.0, drop_size=10, resolution=512, **mesh_graphormer_kwargs):
install_deps()
from custom_controlnet_aux.mesh_graphormer import MeshGraphormerDetector
mesh_graphormer_node = Mesh_Graphormer_Depth_Map_Preprocessor()
model = MeshGraphormerDetector.from_pretrained(detect_thr=0.6, presence_thr=0.6).to(model_management.get_torch_device())
mesh_graphormer_kwargs["model"] = model
frames = image
depth_maps, masks = [], []
for idx in range(len(frames)):
frame = frames[idx:idx+1,...] #Impact Pack's BBOX_DETECTOR only supports single batch image
bbox_detector.setAux('face') # make default prompt as 'face' if empty prompt for CLIPSeg
_, segs = bbox_detector.detect(frame, bbox_threshold, bbox_dilation, bbox_crop_factor, drop_size)
bbox_detector.setAux(None)
n, h, w, _ = frame.shape
depth_map, mask = torch.zeros_like(frame), torch.zeros(n, 1, h, w)
for i, seg in enumerate(segs):
x1, y1, x2, y2 = seg.crop_region
cropped_image = frame[:, y1:y2, x1:x2, :] # Never use seg.cropped_image to handle overlapping area
mesh_graphormer_kwargs["resolution"] = 0 #Disable resizing
sub_depth_map, sub_mask = mesh_graphormer_node.execute(cropped_image, **mesh_graphormer_kwargs)
depth_map[:, y1:y2, x1:x2, :] = sub_depth_map
mask[:, :, y1:y2, x1:x2] = sub_mask
depth_maps.append(depth_map)
masks.append(mask)
return (torch.cat(depth_maps), torch.cat(masks))
NODE_CLASS_MAPPINGS = {
"MeshGraphormer-DepthMapPreprocessor": Mesh_Graphormer_Depth_Map_Preprocessor,
"MeshGraphormer+ImpactDetector-DepthMapPreprocessor": Mesh_Graphormer_With_ImpactDetector_Depth_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"MeshGraphormer-DepthMapPreprocessor": "MeshGraphormer Hand Refiner",
"MeshGraphormer+ImpactDetector-DepthMapPreprocessor": "MeshGraphormer Hand Refiner With External Detector"
}

View File

@@ -0,0 +1,62 @@
import os
# Disable NPU device initialization and problematic MMCV ops to prevent RuntimeError
os.environ['NPU_DEVICE_COUNT'] = '0'
os.environ['MMCV_WITH_OPS'] = '0'
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, MAX_RESOLUTION
import comfy.model_management as model_management
class Metric3D_Depth_Map_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
backbone=INPUT.COMBO(["vit-small", "vit-large", "vit-giant2"]),
fx=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
fy=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, backbone="vit-small", fx=1000, fy=1000, resolution=512):
from custom_controlnet_aux.metric3d import Metric3DDetector
model = Metric3DDetector.from_pretrained(filename=f"metric_depth_{backbone.replace('-', '_')}_800k.pth").to(model_management.get_torch_device())
cb = lambda image, **kwargs: model(image, **kwargs)[0]
out = common_annotator_call(cb, image, resolution=resolution, fx=fx, fy=fy, depth_and_normal=True)
del model
return (out, )
class Metric3D_Normal_Map_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
backbone=INPUT.COMBO(["vit-small", "vit-large", "vit-giant2"]),
fx=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
fy=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, backbone="vit-small", fx=1000, fy=1000, resolution=512):
from custom_controlnet_aux.metric3d import Metric3DDetector
model = Metric3DDetector.from_pretrained(filename=f"metric_depth_{backbone.replace('-', '_')}_800k.pth").to(model_management.get_torch_device())
cb = lambda image, **kwargs: model(image, **kwargs)[1]
out = common_annotator_call(cb, image, resolution=resolution, fx=fx, fy=fy, depth_and_normal=True)
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"Metric3D-DepthMapPreprocessor": Metric3D_Depth_Map_Preprocessor,
"Metric3D-NormalMapPreprocessor": Metric3D_Normal_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Metric3D-DepthMapPreprocessor": "Metric3D Depth Map",
"Metric3D-NormalMapPreprocessor": "Metric3D Normal Map"
}

View File

@@ -0,0 +1,59 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
import numpy as np
class MIDAS_Normal_Map_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
a=INPUT.FLOAT(default=np.pi * 2.0, min=0.0, max=np.pi * 5.0),
bg_threshold=INPUT.FLOAT(default=0.1),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, a=np.pi * 2.0, bg_threshold=0.1, resolution=512, **kwargs):
from custom_controlnet_aux.midas import MidasDetector
model = MidasDetector.from_pretrained().to(model_management.get_torch_device())
#Dirty hack :))
cb = lambda image, **kargs: model(image, **kargs)[1]
out = common_annotator_call(cb, image, resolution=resolution, a=a, bg_th=bg_threshold, depth_and_normal=True)
del model
return (out, )
class MIDAS_Depth_Map_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
a=INPUT.FLOAT(default=np.pi * 2.0, min=0.0, max=np.pi * 5.0),
bg_threshold=INPUT.FLOAT(default=0.1),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, a=np.pi * 2.0, bg_threshold=0.1, resolution=512, **kwargs):
from custom_controlnet_aux.midas import MidasDetector
# Ref: https://github.com/lllyasviel/ControlNet/blob/main/gradio_depth2image.py
model = MidasDetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution, a=a, bg_th=bg_threshold)
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"MiDaS-NormalMapPreprocessor": MIDAS_Normal_Map_Preprocessor,
"MiDaS-DepthMapPreprocessor": MIDAS_Depth_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"MiDaS-NormalMapPreprocessor": "MiDaS Normal Map",
"MiDaS-DepthMapPreprocessor": "MiDaS Depth Map"
}

View File

@@ -0,0 +1,31 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
import numpy as np
class MLSD_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
score_threshold=INPUT.FLOAT(default=0.1, min=0.01, max=2.0),
dist_threshold=INPUT.FLOAT(default=0.1, min=0.01, max=20.0),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, score_threshold, dist_threshold, resolution=512, **kwargs):
from custom_controlnet_aux.mlsd import MLSDdetector
model = MLSDdetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution, thr_v=score_threshold, thr_d=dist_threshold)
return (out, )
NODE_CLASS_MAPPINGS = {
"M-LSDPreprocessor": MLSD_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"M-LSDPreprocessor": "M-LSD Lines"
}

View File

@@ -0,0 +1,27 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class BAE_Normal_Map_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, resolution=512, **kwargs):
from custom_controlnet_aux.normalbae import NormalBaeDetector
model = NormalBaeDetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution)
del model
return (out,)
NODE_CLASS_MAPPINGS = {
"BAE-NormalMapPreprocessor": BAE_Normal_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"BAE-NormalMapPreprocessor": "BAE Normal Map"
}

View File

@@ -0,0 +1,50 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class OneFormer_COCO_SemSegPreprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
RETURN_TYPES = ("IMAGE",)
FUNCTION = "semantic_segmentate"
CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"
def semantic_segmentate(self, image, resolution=512):
from custom_controlnet_aux.oneformer import OneformerSegmentor
model = OneformerSegmentor.from_pretrained(filename="150_16_swin_l_oneformer_coco_100ep.pth")
model = model.to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution)
del model
return (out,)
class OneFormer_ADE20K_SemSegPreprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
RETURN_TYPES = ("IMAGE",)
FUNCTION = "semantic_segmentate"
CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"
def semantic_segmentate(self, image, resolution=512):
from custom_controlnet_aux.oneformer import OneformerSegmentor
model = OneformerSegmentor.from_pretrained(filename="250_16_swin_l_oneformer_ade20k_160k.pth")
model = model.to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution)
del model
return (out,)
NODE_CLASS_MAPPINGS = {
"OneFormer-COCO-SemSegPreprocessor": OneFormer_COCO_SemSegPreprocessor,
"OneFormer-ADE20K-SemSegPreprocessor": OneFormer_ADE20K_SemSegPreprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"OneFormer-COCO-SemSegPreprocessor": "OneFormer COCO Segmentor",
"OneFormer-ADE20K-SemSegPreprocessor": "OneFormer ADE20K Segmentor"
}

View File

@@ -0,0 +1,48 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
import json
class OpenPose_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
detect_hand=INPUT.COMBO(["enable", "disable"]),
detect_body=INPUT.COMBO(["enable", "disable"]),
detect_face=INPUT.COMBO(["enable", "disable"]),
resolution=INPUT.RESOLUTION(),
scale_stick_for_xinsr_cn=INPUT.COMBO(["disable", "enable"])
)
RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
FUNCTION = "estimate_pose"
CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"
def estimate_pose(self, image, detect_hand="enable", detect_body="enable", detect_face="enable", scale_stick_for_xinsr_cn="disable", resolution=512, **kwargs):
from custom_controlnet_aux.open_pose import OpenposeDetector
detect_hand = detect_hand == "enable"
detect_body = detect_body == "enable"
detect_face = detect_face == "enable"
scale_stick_for_xinsr_cn = scale_stick_for_xinsr_cn == "enable"
model = OpenposeDetector.from_pretrained().to(model_management.get_torch_device())
self.openpose_dicts = []
def func(image, **kwargs):
pose_img, openpose_dict = model(image, **kwargs)
self.openpose_dicts.append(openpose_dict)
return pose_img
out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, xinsr_stick_scaling=scale_stick_for_xinsr_cn, resolution=resolution)
del model
return {
'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
"result": (out, self.openpose_dicts)
}
NODE_CLASS_MAPPINGS = {
"OpenposePreprocessor": OpenPose_Preprocessor,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"OpenposePreprocessor": "OpenPose Pose",
}

View File

@@ -0,0 +1,30 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class PIDINET_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
safe=INPUT.COMBO(["enable", "disable"]),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, safe, resolution=512, **kwargs):
from custom_controlnet_aux.pidi import PidiNetDetector
model = PidiNetDetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution, safe = safe == "enable")
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"PiDiNetPreprocessor": PIDINET_Preprocessor,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PiDiNetPreprocessor": "PiDiNet Soft-Edge Lines"
}

View File

@@ -0,0 +1,340 @@
import folder_paths
import json
import os
import numpy as np
import cv2
from PIL import ImageColor
from einops import rearrange
import torch
import itertools
from ..src.custom_controlnet_aux.dwpose import draw_poses, draw_animalposes, decode_json_as_poses
"""
Format of POSE_KEYPOINT (AP10K keypoints):
[{
"version": "ap10k",
"animals": [
[[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
[[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
...
],
"canvas_height": 512,
"canvas_width": 768
},...]
Format of POSE_KEYPOINT (OpenPose keypoints):
[{
"people": [
{
'pose_keypoints_2d': [[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]]
"face_keypoints_2d": [[x1, y1, 1], [x2, y2, 1],..., [x68, y68, 1]],
"hand_left_keypoints_2d": [[x1, y1, 1], [x2, y2, 1],..., [x21, y21, 1]],
"hand_right_keypoints_2d":[[x1, y1, 1], [x2, y2, 1],..., [x21, y21, 1]],
}
],
"canvas_height": canvas_height,
"canvas_width": canvas_width,
},...]
"""
class SavePoseKpsAsJsonFile:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"pose_kps": ("POSE_KEYPOINT",),
"filename_prefix": ("STRING", {"default": "PoseKeypoint"})
}
}
RETURN_TYPES = ()
FUNCTION = "save_pose_kps"
OUTPUT_NODE = True
CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
def save_pose_kps(self, pose_kps, filename_prefix):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = \
folder_paths.get_save_image_path(filename_prefix, self.output_dir, pose_kps[0]["canvas_width"], pose_kps[0]["canvas_height"])
file = f"{filename}_{counter:05}.json"
with open(os.path.join(full_output_folder, file), 'w') as f:
json.dump(pose_kps , f)
return {}
#COCO-Wholebody doesn't have eyebrows as it inherits 68 keypoints format
#Perhaps eyebrows can be estimated tho
FACIAL_PARTS = ["skin", "left_eye", "right_eye", "nose", "upper_lip", "inner_mouth", "lower_lip"]
LAPA_COLORS = dict(
skin="rgb(0, 153, 255)",
left_eye="rgb(0, 204, 153)",
right_eye="rgb(255, 153, 0)",
nose="rgb(255, 102, 255)",
upper_lip="rgb(102, 0, 51)",
inner_mouth="rgb(255, 204, 255)",
lower_lip="rgb(255, 0, 102)"
)
#One-based index
def kps_idxs(start, end):
step = -1 if start > end else 1
return list(range(start-1, end+1-1, step))
#Source: https://www.researchgate.net/profile/Fabrizio-Falchi/publication/338048224/figure/fig1/AS:837860722741255@1576772971540/68-facial-landmarks.jpg
FACIAL_PART_RANGES = dict(
skin=kps_idxs(1, 17) + kps_idxs(27, 18),
nose=kps_idxs(28, 36),
left_eye=kps_idxs(37, 42),
right_eye=kps_idxs(43, 48),
upper_lip=kps_idxs(49, 55) + kps_idxs(65, 61),
lower_lip=kps_idxs(61, 68),
inner_mouth=kps_idxs(61, 65) + kps_idxs(55, 49)
)
def is_normalized(keypoints) -> bool:
point_normalized = [
0 <= np.abs(k[0]) <= 1 and 0 <= np.abs(k[1]) <= 1
for k in keypoints
if k is not None
]
if not point_normalized:
return False
return np.all(point_normalized)
class FacialPartColoringFromPoseKps:
@classmethod
def INPUT_TYPES(s):
input_types = {
"required": {"pose_kps": ("POSE_KEYPOINT",), "mode": (["point", "polygon"], {"default": "polygon"})}
}
for facial_part in FACIAL_PARTS:
input_types["required"][facial_part] = ("STRING", {"default": LAPA_COLORS[facial_part], "multiline": False})
return input_types
RETURN_TYPES = ("IMAGE",)
FUNCTION = "colorize"
CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
def colorize(self, pose_kps, mode, **facial_part_colors):
pose_frames = pose_kps
np_frames = [self.draw_kps(pose_frame, mode, **facial_part_colors) for pose_frame in pose_frames]
np_frames = np.stack(np_frames, axis=0)
return (torch.from_numpy(np_frames).float() / 255.,)
def draw_kps(self, pose_frame, mode, **facial_part_colors):
width, height = pose_frame["canvas_width"], pose_frame["canvas_height"]
canvas = np.zeros((height, width, 3), dtype=np.uint8)
for person, part_name in itertools.product(pose_frame["people"], FACIAL_PARTS):
n = len(person["face_keypoints_2d"]) // 3
facial_kps = rearrange(np.array(person["face_keypoints_2d"]), "(n c) -> n c", n=n, c=3)[:, :2]
if is_normalized(facial_kps):
facial_kps *= (width, height)
facial_kps = facial_kps.astype(np.int32)
part_color = ImageColor.getrgb(facial_part_colors[part_name])[:3]
part_contours = facial_kps[FACIAL_PART_RANGES[part_name], :]
if mode == "point":
for pt in part_contours:
cv2.circle(canvas, pt, radius=2, color=part_color, thickness=-1)
else:
cv2.fillPoly(canvas, pts=[part_contours], color=part_color)
return canvas
# https://raw.githubusercontent.com/CMU-Perceptual-Computing-Lab/openpose/master/.github/media/keypoints_pose_18.png
BODY_PART_INDEXES = {
"Head": (16, 14, 0, 15, 17),
"Neck": (0, 1),
"Shoulder": (2, 5),
"Torso": (2, 5, 8, 11),
"RArm": (2, 3),
"RForearm": (3, 4),
"LArm": (5, 6),
"LForearm": (6, 7),
"RThigh": (8, 9),
"RLeg": (9, 10),
"LThigh": (11, 12),
"LLeg": (12, 13)
}
BODY_PART_DEFAULT_W_H = {
"Head": "256, 256",
"Neck": "100, 100",
"Shoulder": '',
"Torso": "350, 450",
"RArm": "128, 256",
"RForearm": "128, 256",
"LArm": "128, 256",
"LForearm": "128, 256",
"RThigh": "128, 256",
"RLeg": "128, 256",
"LThigh": "128, 256",
"LLeg": "128, 256"
}
class SinglePersonProcess:
@classmethod
def sort_and_get_max_people(s, pose_kps):
for idx in range(len(pose_kps)):
pose_kps[idx]["people"] = sorted(pose_kps[idx]["people"], key=lambda person:person["pose_keypoints_2d"][0])
return pose_kps, max(len(frame["people"]) for frame in pose_kps)
def __init__(self, pose_kps, person_idx=0) -> None:
self.width, self.height = pose_kps[0]["canvas_width"], pose_kps[0]["canvas_height"]
self.poses = [
self.normalize(pose_frame["people"][person_idx]["pose_keypoints_2d"])
if person_idx < len(pose_frame["people"])
else None
for pose_frame in pose_kps
]
def normalize(self, pose_kps_2d):
n = len(pose_kps_2d) // 3
pose_kps_2d = rearrange(np.array(pose_kps_2d), "(n c) -> n c", n=n, c=3)
pose_kps_2d[np.argwhere(pose_kps_2d[:,2]==0), :] = np.iinfo(np.int32).max // 2 #Safe large value
pose_kps_2d = pose_kps_2d[:, :2]
if is_normalized(pose_kps_2d):
pose_kps_2d *= (self.width, self.height)
return pose_kps_2d
def get_xyxy_bboxes(self, part_name, bbox_size=(128, 256)):
width, height = bbox_size
xyxy_bboxes = {}
for idx, pose in enumerate(self.poses):
if pose is None:
xyxy_bboxes[idx] = (np.iinfo(np.int32).max // 2,) * 4
continue
pts = pose[BODY_PART_INDEXES[part_name], :]
#top_left = np.min(pts[:,0]), np.min(pts[:,1])
#bottom_right = np.max(pts[:,0]), np.max(pts[:,1])
#pad_width = np.maximum(width - (bottom_right[0]-top_left[0]), 0) / 2
#pad_height = np.maximum(height - (bottom_right[1]-top_left[1]), 0) / 2
#xyxy_bboxes.append((
# top_left[0] - pad_width, top_left[1] - pad_height,
# bottom_right[0] + pad_width, bottom_right[1] + pad_height,
#))
x_mid, y_mid = np.mean(pts[:, 0]), np.mean(pts[:, 1])
xyxy_bboxes[idx] = (
x_mid - width/2, y_mid - height/2,
x_mid + width/2, y_mid + height/2
)
return xyxy_bboxes
class UpperBodyTrackingFromPoseKps:
PART_NAMES = ["Head", "Neck", "Shoulder", "Torso", "RArm", "RForearm", "LArm", "LForearm"]
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"pose_kps": ("POSE_KEYPOINT",),
"id_include": ("STRING", {"default": '', "multiline": False}),
**{part_name + "_width_height": ("STRING", {"default": BODY_PART_DEFAULT_W_H[part_name], "multiline": False}) for part_name in s.PART_NAMES}
}
}
RETURN_TYPES = ("TRACKING", "STRING")
RETURN_NAMES = ("tracking", "prompt")
FUNCTION = "convert"
CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
def convert(self, pose_kps, id_include, **parts_width_height):
parts_width_height = {part_name.replace("_width_height", ''): value for part_name, value in parts_width_height.items()}
enabled_part_names = [part_name for part_name in self.PART_NAMES if len(parts_width_height[part_name].strip())]
tracked = {part_name: {} for part_name in enabled_part_names}
id_include = id_include.strip()
id_include = list(map(int, id_include.split(','))) if len(id_include) else []
prompt_string = ''
pose_kps, max_people = SinglePersonProcess.sort_and_get_max_people(pose_kps)
for person_idx in range(max_people):
if len(id_include) and person_idx not in id_include:
continue
processor = SinglePersonProcess(pose_kps, person_idx)
for part_name in enabled_part_names:
bbox_size = tuple(map(int, parts_width_height[part_name].split(',')))
part_bboxes = processor.get_xyxy_bboxes(part_name, bbox_size)
id_coordinates = {idx: part_bbox+(processor.width, processor.height) for idx, part_bbox in part_bboxes.items()}
tracked[part_name][person_idx] = id_coordinates
for class_name, class_data in tracked.items():
for class_id in class_data.keys():
class_id_str = str(class_id)
# Use the incoming prompt for each class name and ID
_class_name = class_name.replace('L', '').replace('R', '').lower()
prompt_string += f'"{class_id_str}.{class_name}": "({_class_name})",\n'
return (tracked, prompt_string)
def numpy2torch(np_image: np.ndarray) -> torch.Tensor:
""" [H, W, C] => [B=1, H, W, C]"""
return torch.from_numpy(np_image.astype(np.float32) / 255).unsqueeze(0)
class RenderPeopleKps:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"kps": ("POSE_KEYPOINT",),
"render_body": ("BOOLEAN", {"default": True}),
"render_hand": ("BOOLEAN", {"default": True}),
"render_face": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "render"
CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
def render(self, kps, render_body, render_hand, render_face) -> tuple[np.ndarray]:
if isinstance(kps, list):
kps = kps[0]
poses, _, height, width = decode_json_as_poses(kps)
np_image = draw_poses(
poses,
height,
width,
render_body,
render_hand,
render_face,
)
return (numpy2torch(np_image),)
class RenderAnimalKps:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"kps": ("POSE_KEYPOINT",),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "render"
CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
def render(self, kps) -> tuple[np.ndarray]:
if isinstance(kps, list):
kps = kps[0]
_, poses, height, width = decode_json_as_poses(kps)
np_image = draw_animalposes(poses, height, width)
return (numpy2torch(np_image),)
NODE_CLASS_MAPPINGS = {
"SavePoseKpsAsJsonFile": SavePoseKpsAsJsonFile,
"FacialPartColoringFromPoseKps": FacialPartColoringFromPoseKps,
"UpperBodyTrackingFromPoseKps": UpperBodyTrackingFromPoseKps,
"RenderPeopleKps": RenderPeopleKps,
"RenderAnimalKps": RenderAnimalKps,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SavePoseKpsAsJsonFile": "Save Pose Keypoints",
"FacialPartColoringFromPoseKps": "Colorize Facial Parts from PoseKPS",
"UpperBodyTrackingFromPoseKps": "Upper Body Tracking From PoseKps (InstanceDiffusion)",
"RenderPeopleKps": "Render Pose JSON (Human)",
"RenderAnimalKps": "Render Pose JSON (Animal)",
}

View File

@@ -0,0 +1,30 @@
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
import comfy.model_management as model_management
class PyraCanny_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
low_threshold=INPUT.INT(default=64, max=255),
high_threshold=INPUT.INT(default=128, max=255),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, low_threshold=64, high_threshold=128, resolution=512, **kwargs):
from custom_controlnet_aux.pyracanny import PyraCannyDetector
return (common_annotator_call(PyraCannyDetector(), image, low_threshold=low_threshold, high_threshold=high_threshold, resolution=resolution), )
NODE_CLASS_MAPPINGS = {
"PyraCannyPreprocessor": PyraCanny_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PyraCannyPreprocessor": "PyraCanny"
}

View File

@@ -0,0 +1,46 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
class ImageLuminanceDetector:
@classmethod
def INPUT_TYPES(s):
#https://github.com/Mikubill/sd-webui-controlnet/blob/416c345072c9c2066101e225964e3986abe6945e/scripts/processor.py#L1229
return define_preprocessor_inputs(
gamma_correction=INPUT.FLOAT(default=1.0, min=0.1, max=2.0),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Recolor"
def execute(self, image, gamma_correction=1.0, resolution=512, **kwargs):
from custom_controlnet_aux.recolor import Recolorizer
return (common_annotator_call(Recolorizer(), image, mode="luminance", gamma_correction=gamma_correction , resolution=resolution), )
class ImageIntensityDetector:
@classmethod
def INPUT_TYPES(s):
#https://github.com/Mikubill/sd-webui-controlnet/blob/416c345072c9c2066101e225964e3986abe6945e/scripts/processor.py#L1229
return define_preprocessor_inputs(
gamma_correction=INPUT.FLOAT(default=1.0, min=0.1, max=2.0),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Recolor"
def execute(self, image, gamma_correction=1.0, resolution=512, **kwargs):
from custom_controlnet_aux.recolor import Recolorizer
return (common_annotator_call(Recolorizer(), image, mode="intensity", gamma_correction=gamma_correction , resolution=resolution), )
NODE_CLASS_MAPPINGS = {
"ImageLuminanceDetector": ImageLuminanceDetector,
"ImageIntensityDetector": ImageIntensityDetector
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageLuminanceDetector": "Image Luminance",
"ImageIntensityDetector": "Image Intensity"
}

View File

@@ -0,0 +1,74 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, nms
import comfy.model_management as model_management
import cv2
class Scribble_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, resolution=512, **kwargs):
from custom_controlnet_aux.scribble import ScribbleDetector
model = ScribbleDetector()
return (common_annotator_call(model, image, resolution=resolution), )
class Scribble_XDoG_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
threshold=INPUT.INT(default=32, min=1, max=64),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, threshold=32, resolution=512, **kwargs):
from custom_controlnet_aux.scribble import ScribbleXDog_Detector
model = ScribbleXDog_Detector()
return (common_annotator_call(model, image, resolution=resolution, thr_a=threshold), )
class Scribble_PiDiNet_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
safe=(["enable", "disable"],),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, safe="enable", resolution=512):
def model(img, **kwargs):
from custom_controlnet_aux.pidi import PidiNetDetector
pidinet = PidiNetDetector.from_pretrained().to(model_management.get_torch_device())
result = pidinet(img, scribble=True, **kwargs)
result = nms(result, 127, 3.0)
result = cv2.GaussianBlur(result, (0, 0), 3.0)
result[result > 4] = 255
result[result < 255] = 0
return result
return (common_annotator_call(model, image, resolution=resolution, safe=safe=="enable"),)
NODE_CLASS_MAPPINGS = {
"ScribblePreprocessor": Scribble_Preprocessor,
"Scribble_XDoG_Preprocessor": Scribble_XDoG_Preprocessor,
"Scribble_PiDiNet_Preprocessor": Scribble_PiDiNet_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ScribblePreprocessor": "Scribble Lines",
"Scribble_XDoG_Preprocessor": "Scribble XDoG Lines",
"Scribble_PiDiNet_Preprocessor": "Scribble PiDiNet Lines"
}

View File

@@ -0,0 +1,27 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class SAM_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/others"
def execute(self, image, resolution=512, **kwargs):
from custom_controlnet_aux.sam import SamDetector
mobile_sam = SamDetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(mobile_sam, image, resolution=resolution)
del mobile_sam
return (out, )
NODE_CLASS_MAPPINGS = {
"SAMPreprocessor": SAM_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SAMPreprocessor": "SAM Segmentor"
}

View File

@@ -0,0 +1,27 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, MAX_RESOLUTION
import comfy.model_management as model_management
class Shuffle_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
resolution=INPUT.RESOLUTION(),
seed=INPUT.SEED()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "preprocess"
CATEGORY = "ControlNet Preprocessors/T2IAdapter-only"
def preprocess(self, image, resolution=512, seed=0):
from custom_controlnet_aux.shuffle import ContentShuffleDetector
return (common_annotator_call(ContentShuffleDetector(), image, resolution=resolution, seed=seed), )
NODE_CLASS_MAPPINGS = {
"ShufflePreprocessor": Shuffle_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ShufflePreprocessor": "Content Shuffle"
}

View File

@@ -0,0 +1,30 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class TEED_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
safe_steps=INPUT.INT(default=2, max=10),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Line Extractors"
def execute(self, image, safe_steps=2, resolution=512, **kwargs):
from custom_controlnet_aux.teed import TEDDetector
model = TEDDetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution, safe_steps=safe_steps)
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"TEEDPreprocessor": TEED_Preprocessor,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"TEED_Preprocessor": "TEED Soft-Edge Lines",
}

View File

@@ -0,0 +1,73 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
class Tile_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
pyrUp_iters=INPUT.INT(default=3, min=1, max=10),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/tile"
def execute(self, image, pyrUp_iters, resolution=512, **kwargs):
from custom_controlnet_aux.tile import TileDetector
return (common_annotator_call(TileDetector(), image, pyrUp_iters=pyrUp_iters, resolution=resolution),)
class TTPlanet_TileGF_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
scale_factor=INPUT.FLOAT(default=1.00, min=1.000, max=8.00),
blur_strength=INPUT.FLOAT(default=2.0, min=1.0, max=10.0),
radius=INPUT.INT(default=7, min=1, max=20),
eps=INPUT.FLOAT(default=0.01, min=0.001, max=0.1, step=0.001),
resolution=INPUT.RESOLUTION()
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/tile"
def execute(self, image, scale_factor, blur_strength, radius, eps, **kwargs):
from custom_controlnet_aux.tile import TTPlanet_Tile_Detector_GF
return (common_annotator_call(TTPlanet_Tile_Detector_GF(), image, scale_factor=scale_factor, blur_strength=blur_strength, radius=radius, eps=eps),)
class TTPlanet_TileSimple_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(
scale_factor=INPUT.FLOAT(default=1.00, min=1.000, max=8.00),
blur_strength=INPUT.FLOAT(default=2.0, min=1.0, max=10.0),
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/tile"
def execute(self, image, scale_factor, blur_strength):
from custom_controlnet_aux.tile import TTPLanet_Tile_Detector_Simple
return (common_annotator_call(TTPLanet_Tile_Detector_Simple(), image, scale_factor=scale_factor, blur_strength=blur_strength),)
NODE_CLASS_MAPPINGS = {
"TilePreprocessor": Tile_Preprocessor,
"TTPlanet_TileGF_Preprocessor": TTPlanet_TileGF_Preprocessor,
"TTPlanet_TileSimple_Preprocessor": TTPlanet_TileSimple_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"TilePreprocessor": "Tile",
"TTPlanet_TileGF_Preprocessor": "TTPlanet Tile GuidedFilter",
"TTPlanet_TileSimple_Preprocessor": "TTPlanet Tile Simple"
}

View File

@@ -0,0 +1,34 @@
import os
# Disable NPU device initialization and problematic MMCV ops to prevent RuntimeError
os.environ['NPU_DEVICE_COUNT'] = '0'
os.environ['MMCV_WITH_OPS'] = '0'
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class Uniformer_SemSegPreprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
RETURN_TYPES = ("IMAGE",)
FUNCTION = "semantic_segmentate"
CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"
def semantic_segmentate(self, image, resolution=512):
from custom_controlnet_aux.uniformer import UniformerSegmentor
model = UniformerSegmentor.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution)
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"UniFormer-SemSegPreprocessor": Uniformer_SemSegPreprocessor,
"SemSegPreprocessor": Uniformer_SemSegPreprocessor,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"UniFormer-SemSegPreprocessor": "UniFormer Segmentor",
"SemSegPreprocessor": "Semantic Segmentor (legacy, alias for UniFormer)",
}

View File

@@ -0,0 +1,75 @@
from ..utils import common_annotator_call
import comfy.model_management as model_management
import torch
import numpy as np
from einops import rearrange
import torch.nn.functional as F
class Unimatch_OptFlowPreprocessor:
@classmethod
def INPUT_TYPES(s):
return {
"required": dict(
image=("IMAGE",),
ckpt_name=(
["gmflow-scale1-mixdata.pth", "gmflow-scale2-mixdata.pth", "gmflow-scale2-regrefine6-mixdata.pth"],
{"default": "gmflow-scale2-regrefine6-mixdata.pth"}
),
backward_flow=("BOOLEAN", {"default": False}),
bidirectional_flow=("BOOLEAN", {"default": False})
)
}
RETURN_TYPES = ("OPTICAL_FLOW", "IMAGE")
RETURN_NAMES = ("OPTICAL_FLOW", "PREVIEW_IMAGE")
FUNCTION = "estimate"
CATEGORY = "ControlNet Preprocessors/Optical Flow"
def estimate(self, image, ckpt_name, backward_flow=False, bidirectional_flow=False):
assert len(image) > 1, "[Unimatch] Requiring as least two frames as an optical flow estimator. Only use this node on video input."
from custom_controlnet_aux.unimatch import UnimatchDetector
tensor_images = image
model = UnimatchDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
flows, vis_flows = [], []
for i in range(len(tensor_images) - 1):
image0, image1 = np.asarray(image[i:i+2].cpu() * 255., dtype=np.uint8)
flow, vis_flow = model(image0, image1, output_type="np", pred_bwd_flow=backward_flow, pred_bidir_flow=bidirectional_flow)
flows.append(torch.from_numpy(flow).float())
vis_flows.append(torch.from_numpy(vis_flow).float() / 255.)
del model
return (torch.stack(flows, dim=0), torch.stack(vis_flows, dim=0))
class MaskOptFlow:
@classmethod
def INPUT_TYPES(s):
return {
"required": dict(optical_flow=("OPTICAL_FLOW",), mask=("MASK",))
}
RETURN_TYPES = ("OPTICAL_FLOW", "IMAGE")
RETURN_NAMES = ("OPTICAL_FLOW", "PREVIEW_IMAGE")
FUNCTION = "mask_opt_flow"
CATEGORY = "ControlNet Preprocessors/Optical Flow"
def mask_opt_flow(self, optical_flow, mask):
from custom_controlnet_aux.unimatch import flow_to_image
assert len(mask) >= len(optical_flow), f"Not enough masks to mask optical flow: {len(mask)} vs {len(optical_flow)}"
mask = mask[:optical_flow.shape[0]]
mask = F.interpolate(mask, optical_flow.shape[1:3])
mask = rearrange(mask, "n 1 h w -> n h w 1")
vis_flows = torch.stack([torch.from_numpy(flow_to_image(flow)).float() / 255. for flow in optical_flow.numpy()], dim=0)
vis_flows *= mask
optical_flow *= mask
return (optical_flow, vis_flows)
NODE_CLASS_MAPPINGS = {
"Unimatch_OptFlowPreprocessor": Unimatch_OptFlowPreprocessor,
"MaskOptFlow": MaskOptFlow
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Unimatch_OptFlowPreprocessor": "Unimatch Optical Flow",
"MaskOptFlow": "Mask Optical Flow (DragNUWA)"
}

View File

@@ -0,0 +1,27 @@
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
import comfy.model_management as model_management
class Zoe_Depth_Map_Preprocessor:
@classmethod
def INPUT_TYPES(s):
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
def execute(self, image, resolution=512, **kwargs):
from custom_controlnet_aux.zoe import ZoeDetector
model = ZoeDetector.from_pretrained().to(model_management.get_torch_device())
out = common_annotator_call(model, image, resolution=resolution)
del model
return (out, )
NODE_CLASS_MAPPINGS = {
"Zoe-DepthMapPreprocessor": Zoe_Depth_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Zoe-DepthMapPreprocessor": "Zoe Depth Map"
}

View File

@@ -0,0 +1,14 @@
[project]
name = "comfyui_controlnet_aux"
description = "Plug-and-play ComfyUI node sets for making ControlNet hint images"
version = "1.1.3"
dependencies = ["torch", "importlib_metadata", "huggingface_hub", "scipy", "opencv-python>=4.7.0.72", "filelock", "numpy", "Pillow", "einops", "torchvision", "pyyaml", "scikit-image", "python-dateutil", "mediapipe", "fvcore", "yapf", "omegaconf", "ftfy", "addict", "yacs", "trimesh[easy]", "albumentations", "scikit-learn", "matplotlib"]
[project.urls]
Repository = "https://github.com/Fannovel16/comfyui_controlnet_aux"
[tool.comfy]
PublisherId = "fannovel16"
DisplayName = "comfyui_controlnet_aux"
Icon = ""

View File

@@ -0,0 +1,25 @@
torch
importlib_metadata
huggingface_hub
scipy
opencv-python>=4.7.0.72
filelock
numpy
Pillow
einops
torchvision
pyyaml
scikit-image
python-dateutil
mediapipe
fvcore
yapf
omegaconf
ftfy
addict
yacs
yapf
trimesh[easy]
albumentations
scikit-learn
matplotlib

View File

@@ -0,0 +1,56 @@
from pathlib import Path
import os
import re
#Thanks ChatGPT
pattern = r'\bfrom_pretrained\(.*?pretrained_model_or_path\s*=\s*(.*?)(?:,|\))|filename\s*=\s*(.*?)(?:,|\))|(\w+_filename)\s*=\s*(.*?)(?:,|\))'
aux_dir = Path(__file__).parent / 'src' / 'custom_controlnet_aux'
VAR_DICT = dict(
HF_MODEL_NAME = "lllyasviel/Annotators",
DWPOSE_MODEL_NAME = "yzd-v/DWPose",
BDS_MODEL_NAME = "bdsqlsz/qinglong_controlnet-lllite",
DENSEPOSE_MODEL_NAME = "LayerNorm/DensePose-TorchScript-with-hint-image",
MESH_GRAPHORMER_MODEL_NAME = "hr16/ControlNet-HandRefiner-pruned",
SAM_MODEL_NAME = "dhkim2810/MobileSAM",
UNIMATCH_MODEL_NAME = "hr16/Unimatch",
DEPTH_ANYTHING_MODEL_NAME = "LiheYoung/Depth-Anything", #HF Space
DIFFUSION_EDGE_MODEL_NAME = "hr16/Diffusion-Edge"
)
re_result_dict = {}
for preprocc in os.listdir(aux_dir):
if preprocc in ["__pycache__", 'tests']: continue
if '.py' in preprocc: continue
f = open(aux_dir / preprocc / '__init__.py', 'r')
code = f.read()
matches = re.findall(pattern, code)
result = [match[0] or match[1] or match[3] for match in matches]
if not len(result):
print(preprocc)
continue
result = [el.replace("'", '').replace('"', '') for el in result]
result = [VAR_DICT.get(el, el) for el in result]
re_result_dict[preprocc] = result
f.close()
for preprocc, re_result in re_result_dict.items():
model_name, filenames = re_result[0], re_result[1:]
print(f"* {preprocc}: ", end=' ')
assests_md = ', '.join([f"[{model_name}/{filename}](https://huggingface.co/{model_name}/blob/main/{filename})" for filename in filenames])
print(assests_md)
preprocc = "dwpose"
model_name, filenames = VAR_DICT['DWPOSE_MODEL_NAME'], ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]
print(f"* {preprocc}: ", end=' ')
assests_md = ', '.join([f"[{model_name}/{filename}](https://huggingface.co/{model_name}/blob/main/{filename})" for filename in filenames])
print(assests_md)
preprocc = "yolo-nas"
model_name, filenames = "hr16/yolo-nas-fp16", ["yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"]
print(f"* {preprocc}: ", end=' ')
assests_md = ', '.join([f"[{model_name}/{filename}](https://huggingface.co/{model_name}/blob/main/{filename})" for filename in filenames])
print(assests_md)
preprocc = "dwpose-torchscript"
model_name, filenames = "hr16/DWPose-TorchScript-BatchSize5", ["dw-ll_ucoco_384_bs5.torchscript.pt", "rtmpose-m_ap10k_256_bs5.torchscript.pt"]
print(f"* {preprocc}: ", end=' ')
assests_md = ', '.join([f"[{model_name}/{filename}](https://huggingface.co/{model_name}/blob/main/{filename})" for filename in filenames])
print(assests_md)

View File

@@ -0,0 +1 @@
#Dummy file ensuring this package will be recognized

View File

@@ -0,0 +1 @@
#Dummy file ensuring this package will be recognized

View File

@@ -0,0 +1,66 @@
from .network import UNet
from .util import seg2img
import torch
import os
import cv2
from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, BDS_MODEL_NAME
from huggingface_hub import hf_hub_download
from PIL import Image
from einops import rearrange
from .anime_segmentation import AnimeSegmentation
import numpy as np
class AnimeFaceSegmentor:
def __init__(self, model, seg_model):
self.model = model
self.seg_model = seg_model
self.device = "cpu"
@classmethod
def from_pretrained(cls, pretrained_model_or_path=BDS_MODEL_NAME, filename="UNet.pth", seg_filename="isnetis.ckpt"):
model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder="Annotators")
seg_model_path = custom_hf_download("skytnt/anime-seg", seg_filename)
model = UNet()
ckpt = torch.load(model_path, map_location="cpu")
model.load_state_dict(ckpt)
model.eval()
seg_model = AnimeSegmentation(seg_model_path)
seg_model.net.eval()
return cls(model, seg_model)
def to(self, device):
self.model.to(device)
self.seg_model.net.to(device)
self.device = device
return self
def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", remove_background=True, **kwargs):
input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
with torch.no_grad():
if remove_background:
print(input_image.shape)
mask, input_image = self.seg_model(input_image, 0) #Don't resize image as it is resized
image_feed = torch.from_numpy(input_image).float().to(self.device)
image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
image_feed = image_feed / 255
seg = self.model(image_feed).squeeze(dim=0)
result = seg2img(seg.cpu().detach().numpy())
detected_map = HWC3(result)
detected_map = remove_pad(detected_map)
if remove_background:
mask = remove_pad(mask)
H, W, C = detected_map.shape
tmp = np.zeros([H, W, C + 1])
tmp[:,:,:C] = detected_map
tmp[:,:,3:] = mask
detected_map = tmp
if output_type == "pil":
detected_map = Image.fromarray(detected_map[..., :3])
return detected_map

View File

@@ -0,0 +1,58 @@
#https://github.com/SkyTNT/anime-segmentation/tree/main
#Only adapt isnet_is (https://huggingface.co/skytnt/anime-seg/blob/main/isnetis.ckpt)
import torch.nn as nn
import torch
from .isnet import ISNetDIS
import numpy as np
import cv2
from comfy.model_management import get_torch_device
DEVICE = get_torch_device()
class AnimeSegmentation:
def __init__(self, ckpt_path):
super(AnimeSegmentation).__init__()
sd = torch.load(ckpt_path, map_location="cpu")
self.net = ISNetDIS()
#gt_encoder isn't used during inference
self.net.load_state_dict({k.replace("net.", ''):v for k, v in sd.items() if k.startswith("net.")})
self.net = self.net.to(DEVICE)
self.net.eval()
def get_mask(self, input_img, s=640):
input_img = (input_img / 255).astype(np.float32)
if s == 0:
img_input = np.transpose(input_img, (2, 0, 1))
img_input = img_input[np.newaxis, :]
tmpImg = torch.from_numpy(img_input).float().to(DEVICE)
with torch.no_grad():
pred = self.net(tmpImg)[0][0].sigmoid() #https://github.com/SkyTNT/anime-segmentation/blob/main/train.py#L92C20-L92C47
pred = pred.cpu().numpy()[0]
pred = np.transpose(pred, (1, 2, 0))
#pred = pred[:, :, np.newaxis]
return pred
h, w = h0, w0 = input_img.shape[:-1]
h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
ph, pw = s - h, s - w
img_input = np.zeros([s, s, 3], dtype=np.float32)
img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h))
img_input = np.transpose(img_input, (2, 0, 1))
img_input = img_input[np.newaxis, :]
tmpImg = torch.from_numpy(img_input).float().to(DEVICE)
with torch.no_grad():
pred = self.net(tmpImg)[0][0].sigmoid() #https://github.com/SkyTNT/anime-segmentation/blob/main/train.py#L92C20-L92C47
pred = pred.cpu().numpy()[0]
pred = np.transpose(pred, (1, 2, 0))
pred = pred[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
#pred = cv2.resize(pred, (w0, h0))[:, :, np.newaxis]
pred = cv2.resize(pred, (w0, h0))
return pred
def __call__(self, np_img, img_size):
mask = self.get_mask(np_img, int(img_size))
np_img = (mask * np_img + 255 * (1 - mask)).astype(np.uint8)
mask = (mask * 255).astype(np.uint8)
#np_img = np.concatenate([np_img, mask], axis=2, dtype=np.uint8)
#mask = mask.repeat(3, axis=2)
return mask, np_img

View File

@@ -0,0 +1,619 @@
# Codes are borrowed from
# https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
def muti_loss_fusion(preds, target):
loss0 = 0.0
loss = 0.0
for i in range(0, len(preds)):
if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
tmp_target = F.interpolate(
target, size=preds[i].size()[2:], mode="bilinear", align_corners=True
)
loss = loss + bce_loss(preds[i], tmp_target)
else:
loss = loss + bce_loss(preds[i], target)
if i == 0:
loss0 = loss
return loss0, loss
fea_loss = nn.MSELoss(reduction="mean")
kl_loss = nn.KLDivLoss(reduction="mean")
l1_loss = nn.L1Loss(reduction="mean")
smooth_l1_loss = nn.SmoothL1Loss(reduction="mean")
def muti_loss_fusion_kl(preds, target, dfs, fs, mode="MSE"):
loss0 = 0.0
loss = 0.0
for i in range(0, len(preds)):
if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
tmp_target = F.interpolate(
target, size=preds[i].size()[2:], mode="bilinear", align_corners=True
)
loss = loss + bce_loss(preds[i], tmp_target)
else:
loss = loss + bce_loss(preds[i], target)
if i == 0:
loss0 = loss
for i in range(0, len(dfs)):
df = dfs[i]
fs_i = fs[i]
if mode == "MSE":
loss = loss + fea_loss(
df, fs_i
) ### add the mse loss of features as additional constraints
elif mode == "KL":
loss = loss + kl_loss(F.log_softmax(df, dim=1), F.softmax(fs_i, dim=1))
elif mode == "MAE":
loss = loss + l1_loss(df, fs_i)
elif mode == "SmoothL1":
loss = loss + smooth_l1_loss(df, fs_i)
return loss0, loss
class REBNCONV(nn.Module):
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
super(REBNCONV, self).__init__()
self.conv_s1 = nn.Conv2d(
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self, x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
return src
### RSU-7 ###
class RSU7(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
super(RSU7, self).__init__()
self.in_ch = in_ch
self.mid_ch = mid_ch
self.out_ch = out_ch
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
b, c, h, w = x.shape
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
hx6dup = _upsample_like(hx6d, hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
return hx1d + hxin
class myrebnconv(nn.Module):
def __init__(
self,
in_ch=3,
out_ch=1,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1,
):
super(myrebnconv, self).__init__()
self.conv = nn.Conv2d(
in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
self.bn = nn.BatchNorm2d(out_ch)
self.rl = nn.ReLU(inplace=True)
def forward(self, x):
return self.rl(self.bn(self.conv(x)))
class ISNetGTEncoder(nn.Module):
def __init__(self, in_ch=1, out_ch=1):
super(ISNetGTEncoder, self).__init__()
self.conv_in = myrebnconv(
in_ch, 16, 3, stride=2, padding=1
) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
self.stage1 = RSU7(16, 16, 64)
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage2 = RSU6(64, 16, 64)
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage3 = RSU5(64, 32, 128)
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage4 = RSU4(128, 32, 256)
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage5 = RSU4F(256, 64, 512)
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage6 = RSU4F(512, 64, 512)
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
@staticmethod
def compute_loss(args):
preds, targets = args
return muti_loss_fusion(preds, targets)
def forward(self, x):
hx = x
hxin = self.conv_in(hx)
# hx = self.pool_in(hxin)
# stage 1
hx1 = self.stage1(hxin)
hx = self.pool12(hx1)
# stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
# stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
# stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
# stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
# stage 6
hx6 = self.stage6(hx)
# side output
d1 = self.side1(hx1)
d1 = _upsample_like(d1, x)
d2 = self.side2(hx2)
d2 = _upsample_like(d2, x)
d3 = self.side3(hx3)
d3 = _upsample_like(d3, x)
d4 = self.side4(hx4)
d4 = _upsample_like(d4, x)
d5 = self.side5(hx5)
d5 = _upsample_like(d5, x)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, x)
# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
# return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6]
return [d1, d2, d3, d4, d5, d6], [hx1, hx2, hx3, hx4, hx5, hx6]
class ISNetDIS(nn.Module):
def __init__(self, in_ch=3, out_ch=1):
super(ISNetDIS, self).__init__()
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage1 = RSU7(64, 32, 64)
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage2 = RSU6(64, 32, 128)
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage3 = RSU5(128, 64, 256)
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage4 = RSU4(256, 128, 512)
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage5 = RSU4F(512, 256, 512)
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage6 = RSU4F(512, 256, 512)
# decoder
self.stage5d = RSU4F(1024, 256, 512)
self.stage4d = RSU4(1024, 128, 256)
self.stage3d = RSU5(512, 64, 128)
self.stage2d = RSU6(256, 32, 64)
self.stage1d = RSU7(128, 16, 64)
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
@staticmethod
def compute_loss_kl(preds, targets, dfs, fs, mode="MSE"):
return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
@staticmethod
def compute_loss(args):
if len(args) == 3:
ds, dfs, labels = args
return muti_loss_fusion(ds, labels)
else:
ds, dfs, labels, fs = args
return muti_loss_fusion_kl(ds, labels, dfs, fs, mode="MSE")
def forward(self, x):
hx = x
hxin = self.conv_in(hx)
hx = self.pool_in(hxin)
# stage 1
hx1 = self.stage1(hxin)
hx = self.pool12(hx1)
# stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
# stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
# stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
# stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
# stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6, hx5)
# -------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
# side output
d1 = self.side1(hx1d)
d1 = _upsample_like(d1, x)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2, x)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3, x)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4, x)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5, x)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, x)
# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
# return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]

View File

@@ -0,0 +1,100 @@
#https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/network.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from custom_controlnet_aux.util import custom_torch_download
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
self.NUM_SEG_CLASSES = 7 # Background, hair, face, eye, mouth, skin, clothes
mobilenet_v2 = torchvision.models.mobilenet_v2(pretrained=False)
mobilenet_v2.load_state_dict(torch.load(custom_torch_download(filename="mobilenet_v2-b0353104.pth")), strict=True)
mob_blocks = mobilenet_v2.features
# Encoder
self.en_block0 = nn.Sequential( # in_ch=3 out_ch=16
mob_blocks[0],
mob_blocks[1]
)
self.en_block1 = nn.Sequential( # in_ch=16 out_ch=24
mob_blocks[2],
mob_blocks[3],
)
self.en_block2 = nn.Sequential( # in_ch=24 out_ch=32
mob_blocks[4],
mob_blocks[5],
mob_blocks[6],
)
self.en_block3 = nn.Sequential( # in_ch=32 out_ch=96
mob_blocks[7],
mob_blocks[8],
mob_blocks[9],
mob_blocks[10],
mob_blocks[11],
mob_blocks[12],
mob_blocks[13],
)
self.en_block4 = nn.Sequential( # in_ch=96 out_ch=160
mob_blocks[14],
mob_blocks[15],
mob_blocks[16],
)
# Decoder
self.de_block4 = nn.Sequential( # in_ch=160 out_ch=96
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(160, 96, kernel_size=3, padding=1),
nn.InstanceNorm2d(96),
nn.LeakyReLU(0.1),
nn.Dropout(p=0.2)
)
self.de_block3 = nn.Sequential( # in_ch=96x2 out_ch=32
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(96*2, 32, kernel_size=3, padding=1),
nn.InstanceNorm2d(32),
nn.LeakyReLU(0.1),
nn.Dropout(p=0.2)
)
self.de_block2 = nn.Sequential( # in_ch=32x2 out_ch=24
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(32*2, 24, kernel_size=3, padding=1),
nn.InstanceNorm2d(24),
nn.LeakyReLU(0.1),
nn.Dropout(p=0.2)
)
self.de_block1 = nn.Sequential( # in_ch=24x2 out_ch=16
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(24*2, 16, kernel_size=3, padding=1),
nn.InstanceNorm2d(16),
nn.LeakyReLU(0.1),
nn.Dropout(p=0.2)
)
self.de_block0 = nn.Sequential( # in_ch=16x2 out_ch=7
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(16*2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1),
nn.Softmax2d()
)
def forward(self, x):
e0 = self.en_block0(x)
e1 = self.en_block1(e0)
e2 = self.en_block2(e1)
e3 = self.en_block3(e2)
e4 = self.en_block4(e3)
d4 = self.de_block4(e4)
c4 = torch.cat((d4,e3),1)
d3 = self.de_block3(c4)
c3 = torch.cat((d3,e2),1)
d2 = self.de_block2(c3)
c2 =torch.cat((d2,e1),1)
d1 = self.de_block1(c2)
c1 = torch.cat((d1,e0),1)
y = self.de_block0(c1)
return y

View File

@@ -0,0 +1,40 @@
#https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/util.py
#The color palette is changed according to https://github.com/Mikubill/sd-webui-controlnet/blob/91f67ddcc7bc47537a6285864abfc12590f46c3f/annotator/anime_face_segment/__init__.py
import cv2 as cv
import glob
import numpy as np
import os
"""
COLOR_BACKGROUND = (0,255,255)
COLOR_HAIR = (255,0,0)
COLOR_EYE = (0,0,255)
COLOR_MOUTH = (255,255,255)
COLOR_FACE = (0,255,0)
COLOR_SKIN = (255,255,0)
COLOR_CLOTHES = (255,0,255)
"""
COLOR_BACKGROUND = (255,255,0)
COLOR_HAIR = (0,0,255)
COLOR_EYE = (255,0,0)
COLOR_MOUTH = (255,255,255)
COLOR_FACE = (0,255,0)
COLOR_SKIN = (0,255,255)
COLOR_CLOTHES = (255,0,255)
PALETTE = [COLOR_BACKGROUND,COLOR_HAIR,COLOR_EYE,COLOR_MOUTH,COLOR_FACE,COLOR_SKIN,COLOR_CLOTHES]
def img2seg(path):
src = cv.imread(path)
src = src.reshape(-1, 3)
seg_list = []
for color in PALETTE:
seg_list.append(np.where(np.all(src==color, axis=1), 1.0, 0.0))
dst = np.stack(seg_list,axis=1).reshape(512,512,7)
return dst.astype(np.float32)
def seg2img(src):
src = np.moveaxis(src,0,2)
dst = [[PALETTE[np.argmax(val)] for val in buf]for buf in src]
return np.array(dst).astype(np.uint8)

View File

@@ -0,0 +1,38 @@
import warnings
import cv2
import numpy as np
from PIL import Image
from custom_controlnet_aux.util import HWC3, resize_image_with_pad
class BinaryDetector:
def __call__(self, input_image=None, bin_threshold=0, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
if "img" in kwargs:
warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
input_image = kwargs.pop("img")
if input_image is None:
raise ValueError("input_image must be defined.")
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
output_type = output_type or "pil"
else:
output_type = output_type or "np"
detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
img_gray = cv2.cvtColor(detected_map, cv2.COLOR_RGB2GRAY)
if bin_threshold == 0 or bin_threshold == 255:
# Otsu's threshold
otsu_threshold, img_bin = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
print("Otsu threshold:", otsu_threshold)
else:
_, img_bin = cv2.threshold(img_gray, bin_threshold, 255, cv2.THRESH_BINARY_INV)
detected_map = cv2.cvtColor(img_bin, cv2.COLOR_GRAY2RGB)
detected_map = HWC3(remove_pad(255 - detected_map))
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map

View File

@@ -0,0 +1,17 @@
import warnings
import cv2
import numpy as np
from PIL import Image
from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3
class CannyDetector:
def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
detected_map = cv2.Canny(detected_map, low_threshold, high_threshold)
detected_map = HWC3(remove_pad(detected_map))
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map

View File

@@ -0,0 +1,37 @@
import cv2
import warnings
import cv2
import numpy as np
from PIL import Image
from custom_controlnet_aux.util import HWC3, safer_memory, common_input_validate
def cv2_resize_shortest_edge(image, size):
h, w = image.shape[:2]
if h < w:
new_h = size
new_w = int(round(w / h * size))
else:
new_w = size
new_h = int(round(h / w * size))
resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
return resized_image
def apply_color(img, res=512):
img = cv2_resize_shortest_edge(img, res)
h, w = img.shape[:2]
input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC)
input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
return input_img_color
#Color T2I like multiples-of-64, upscale methods are fixed.
class ColorDetector:
def __call__(self, input_image=None, detect_resolution=512, output_type=None, **kwargs):
input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
input_image = HWC3(input_image)
detected_map = HWC3(apply_color(input_image, detect_resolution))
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map

View File

@@ -0,0 +1,66 @@
import torchvision # Fix issue Unknown builtin op: torchvision::nms
import cv2
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image
from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, DENSEPOSE_MODEL_NAME
from .densepose import DensePoseMaskedColormapResultsVisualizer, _extract_i_from_iuvarr, densepose_chart_predictor_output_to_result_with_confidences
N_PART_LABELS = 24
class DenseposeDetector:
def __init__(self, model):
self.dense_pose_estimation = model
self.device = "cpu"
self.result_visualizer = DensePoseMaskedColormapResultsVisualizer(
alpha=1,
data_extractor=_extract_i_from_iuvarr,
segm_extractor=_extract_i_from_iuvarr,
val_scale = 255.0 / N_PART_LABELS
)
@classmethod
def from_pretrained(cls, pretrained_model_or_path=DENSEPOSE_MODEL_NAME, filename="densepose_r50_fpn_dl.torchscript"):
torchscript_model_path = custom_hf_download(pretrained_model_or_path, filename)
densepose = torch.jit.load(torchscript_model_path, map_location="cpu")
return cls(densepose)
def to(self, device):
self.dense_pose_estimation.to(device)
self.device = device
return self
def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", cmap="viridis", **kwargs):
input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
H, W = input_image.shape[:2]
hint_image_canvas = np.zeros([H, W], dtype=np.uint8)
hint_image_canvas = np.tile(hint_image_canvas[:, :, np.newaxis], [1, 1, 3])
input_image = rearrange(torch.from_numpy(input_image).to(self.device), 'h w c -> c h w')
pred_boxes, corase_segm, fine_segm, u, v = self.dense_pose_estimation(input_image)
extractor = densepose_chart_predictor_output_to_result_with_confidences
densepose_results = [extractor(pred_boxes[i:i+1], corase_segm[i:i+1], fine_segm[i:i+1], u[i:i+1], v[i:i+1]) for i in range(len(pred_boxes))]
if cmap=="viridis":
self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_VIRIDIS
hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results)
hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)
hint_image[:, :, 0][hint_image[:, :, 0] == 0] = 68
hint_image[:, :, 1][hint_image[:, :, 1] == 0] = 1
hint_image[:, :, 2][hint_image[:, :, 2] == 0] = 84
else:
self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_PARULA
hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results)
hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)
detected_map = remove_pad(HWC3(hint_image))
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map

View File

@@ -0,0 +1,347 @@
from typing import Tuple
import math
import numpy as np
from enum import IntEnum
from typing import List, Tuple, Union
import torch
from torch.nn import functional as F
import logging
import cv2
Image = np.ndarray
Boxes = torch.Tensor
ImageSizeType = Tuple[int, int]
_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
IntTupleBox = Tuple[int, int, int, int]
class BoxMode(IntEnum):
"""
Enum of different ways to represent a box.
"""
XYXY_ABS = 0
"""
(x0, y0, x1, y1) in absolute floating points coordinates.
The coordinates in range [0, width or height].
"""
XYWH_ABS = 1
"""
(x0, y0, w, h) in absolute floating points coordinates.
"""
XYXY_REL = 2
"""
Not yet supported!
(x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
"""
XYWH_REL = 3
"""
Not yet supported!
(x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
"""
XYWHA_ABS = 4
"""
(xc, yc, w, h, a) in absolute floating points coordinates.
(xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
"""
@staticmethod
def convert(box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode") -> _RawBoxType:
"""
Args:
box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
from_mode, to_mode (BoxMode)
Returns:
The converted box of the same type.
"""
if from_mode == to_mode:
return box
original_type = type(box)
is_numpy = isinstance(box, np.ndarray)
single_box = isinstance(box, (list, tuple))
if single_box:
assert len(box) == 4 or len(box) == 5, (
"BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
" where k == 4 or 5"
)
arr = torch.tensor(box)[None, :]
else:
# avoid modifying the input box
if is_numpy:
arr = torch.from_numpy(np.asarray(box)).clone()
else:
arr = box.clone()
assert to_mode not in [BoxMode.XYXY_REL, BoxMode.XYWH_REL] and from_mode not in [
BoxMode.XYXY_REL,
BoxMode.XYWH_REL,
], "Relative mode not yet supported!"
if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
assert (
arr.shape[-1] == 5
), "The last dimension of input shape must be 5 for XYWHA format"
original_dtype = arr.dtype
arr = arr.double()
w = arr[:, 2]
h = arr[:, 3]
a = arr[:, 4]
c = torch.abs(torch.cos(a * math.pi / 180.0))
s = torch.abs(torch.sin(a * math.pi / 180.0))
# This basically computes the horizontal bounding rectangle of the rotated box
new_w = c * w + s * h
new_h = c * h + s * w
# convert center to top-left corner
arr[:, 0] -= new_w / 2.0
arr[:, 1] -= new_h / 2.0
# bottom-right corner
arr[:, 2] = arr[:, 0] + new_w
arr[:, 3] = arr[:, 1] + new_h
arr = arr[:, :4].to(dtype=original_dtype)
elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
original_dtype = arr.dtype
arr = arr.double()
arr[:, 0] += arr[:, 2] / 2.0
arr[:, 1] += arr[:, 3] / 2.0
angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)
else:
if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
arr[:, 2] += arr[:, 0]
arr[:, 3] += arr[:, 1]
elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
arr[:, 2] -= arr[:, 0]
arr[:, 3] -= arr[:, 1]
else:
raise NotImplementedError(
"Conversion from BoxMode {} to {} is not supported yet".format(
from_mode, to_mode
)
)
if single_box:
return original_type(arr.flatten().tolist())
if is_numpy:
return arr.numpy()
else:
return arr
class MatrixVisualizer:
"""
Base visualizer for matrix data
"""
def __init__(
self,
inplace=True,
cmap=cv2.COLORMAP_PARULA,
val_scale=1.0,
alpha=0.7,
interp_method_matrix=cv2.INTER_LINEAR,
interp_method_mask=cv2.INTER_NEAREST,
):
self.inplace = inplace
self.cmap = cmap
self.val_scale = val_scale
self.alpha = alpha
self.interp_method_matrix = interp_method_matrix
self.interp_method_mask = interp_method_mask
def visualize(self, image_bgr, mask, matrix, bbox_xywh):
self._check_image(image_bgr)
self._check_mask_matrix(mask, matrix)
if self.inplace:
image_target_bgr = image_bgr
else:
image_target_bgr = image_bgr * 0
x, y, w, h = [int(v) for v in bbox_xywh]
if w <= 0 or h <= 0:
return image_bgr
mask, matrix = self._resize(mask, matrix, w, h)
mask_bg = np.tile((mask == 0)[:, :, np.newaxis], [1, 1, 3])
matrix_scaled = matrix.astype(np.float32) * self.val_scale
_EPSILON = 1e-6
if np.any(matrix_scaled > 255 + _EPSILON):
logger = logging.getLogger(__name__)
logger.warning(
f"Matrix has values > {255 + _EPSILON} after " f"scaling, clipping to [0..255]"
)
matrix_scaled_8u = matrix_scaled.clip(0, 255).astype(np.uint8)
matrix_vis = cv2.applyColorMap(matrix_scaled_8u, self.cmap)
matrix_vis[mask_bg] = image_target_bgr[y : y + h, x : x + w, :][mask_bg]
image_target_bgr[y : y + h, x : x + w, :] = (
image_target_bgr[y : y + h, x : x + w, :] * (1.0 - self.alpha) + matrix_vis * self.alpha
)
return image_target_bgr.astype(np.uint8)
def _resize(self, mask, matrix, w, h):
if (w != mask.shape[1]) or (h != mask.shape[0]):
mask = cv2.resize(mask, (w, h), self.interp_method_mask)
if (w != matrix.shape[1]) or (h != matrix.shape[0]):
matrix = cv2.resize(matrix, (w, h), self.interp_method_matrix)
return mask, matrix
def _check_image(self, image_rgb):
assert len(image_rgb.shape) == 3
assert image_rgb.shape[2] == 3
assert image_rgb.dtype == np.uint8
def _check_mask_matrix(self, mask, matrix):
assert len(matrix.shape) == 2
assert len(mask.shape) == 2
assert mask.dtype == np.uint8
class DensePoseResultsVisualizer:
def visualize(
self,
image_bgr: Image,
results,
) -> Image:
context = self.create_visualization_context(image_bgr)
for i, result in enumerate(results):
boxes_xywh, labels, uv = result
iuv_array = torch.cat(
(labels[None].type(torch.float32), uv * 255.0)
).type(torch.uint8)
self.visualize_iuv_arr(context, iuv_array.cpu().numpy(), boxes_xywh)
image_bgr = self.context_to_image_bgr(context)
return image_bgr
def create_visualization_context(self, image_bgr: Image):
return image_bgr
def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
pass
def context_to_image_bgr(self, context):
return context
def get_image_bgr_from_context(self, context):
return context
class DensePoseMaskedColormapResultsVisualizer(DensePoseResultsVisualizer):
def __init__(
self,
data_extractor,
segm_extractor,
inplace=True,
cmap=cv2.COLORMAP_PARULA,
alpha=0.7,
val_scale=1.0,
**kwargs,
):
self.mask_visualizer = MatrixVisualizer(
inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha
)
self.data_extractor = data_extractor
self.segm_extractor = segm_extractor
def context_to_image_bgr(self, context):
return context
def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
image_bgr = self.get_image_bgr_from_context(context)
matrix = self.data_extractor(iuv_arr)
segm = self.segm_extractor(iuv_arr)
mask = np.zeros(matrix.shape, dtype=np.uint8)
mask[segm > 0] = 1
image_bgr = self.mask_visualizer.visualize(image_bgr, mask, matrix, bbox_xywh)
def _extract_i_from_iuvarr(iuv_arr):
return iuv_arr[0, :, :]
def _extract_u_from_iuvarr(iuv_arr):
return iuv_arr[1, :, :]
def _extract_v_from_iuvarr(iuv_arr):
return iuv_arr[2, :, :]
def make_int_box(box: torch.Tensor) -> IntTupleBox:
int_box = [0, 0, 0, 0]
int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
return int_box[0], int_box[1], int_box[2], int_box[3]
def densepose_chart_predictor_output_to_result_with_confidences(
boxes: Boxes,
coarse_segm,
fine_segm,
u, v
):
boxes_xyxy_abs = boxes.clone()
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
box_xywh = make_int_box(boxes_xywh_abs[0])
labels = resample_fine_and_coarse_segm_tensors_to_bbox(fine_segm, coarse_segm, box_xywh).squeeze(0)
uv = resample_uv_tensors_to_bbox(u, v, labels, box_xywh)
confidences = []
return box_xywh, labels, uv
def resample_fine_and_coarse_segm_tensors_to_bbox(
fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
):
"""
Resample fine and coarse segmentation tensors to the given
bounding box and derive labels for each pixel of the bounding box
Args:
fine_segm: float tensor of shape [1, C, Hout, Wout]
coarse_segm: float tensor of shape [1, K, Hout, Wout]
box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
corner coordinates, width (W) and height (H)
Return:
Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
"""
x, y, w, h = box_xywh_abs
w = max(int(w), 1)
h = max(int(h), 1)
# coarse segmentation
coarse_segm_bbox = F.interpolate(
coarse_segm,
(h, w),
mode="bilinear",
align_corners=False,
).argmax(dim=1)
# combined coarse and fine segmentation
labels = (
F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
* (coarse_segm_bbox > 0).long()
)
return labels
def resample_uv_tensors_to_bbox(
u: torch.Tensor,
v: torch.Tensor,
labels: torch.Tensor,
box_xywh_abs: IntTupleBox,
) -> torch.Tensor:
"""
Resamples U and V coordinate estimates for the given bounding box
Args:
u (tensor [1, C, H, W] of float): U coordinates
v (tensor [1, C, H, W] of float): V coordinates
labels (tensor [H, W] of long): labels obtained by resampling segmentation
outputs for the given bounding box
box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
Return:
Resampled U and V coordinates - a tensor [2, H, W] of float
"""
x, y, w, h = box_xywh_abs
w = max(int(w), 1)
h = max(int(h), 1)
u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
for part_id in range(1, u_bbox.size(1)):
uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
return uv

View File

@@ -0,0 +1 @@
from .transformers import DepthAnythingDetector

View File

@@ -0,0 +1,75 @@
"""
Modern DepthAnything implementation using HuggingFace transformers.
Replaces legacy torch.hub.load DINOv2 backbone with transformers pipeline.
"""
import numpy as np
import torch
from PIL import Image
from transformers import pipeline
from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad
class DepthAnythingDetector:
"""DepthAnything depth estimation using HuggingFace transformers."""
def __init__(self, model_name="LiheYoung/depth-anything-large-hf"):
"""Initialize DepthAnything with specified model."""
self.pipe = pipeline(task="depth-estimation", model=model_name)
self.device = "cpu"
@classmethod
def from_pretrained(cls, pretrained_model_or_path=None, filename="depth_anything_vitl14.pth"):
"""Create DepthAnything from pretrained model, mapping legacy names to HuggingFace models."""
# Map legacy checkpoint names to modern HuggingFace models
model_mapping = {
"depth_anything_vitl14.pth": "LiheYoung/depth-anything-large-hf",
"depth_anything_vitb14.pth": "LiheYoung/depth-anything-base-hf",
"depth_anything_vits14.pth": "LiheYoung/depth-anything-small-hf"
}
model_name = model_mapping.get(filename, "LiheYoung/depth-anything-large-hf")
return cls(model_name=model_name)
def to(self, device):
"""Move model to specified device."""
self.pipe.model = self.pipe.model.to(device)
self.device = device
return self
def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
"""Perform depth estimation on input image."""
input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
if isinstance(input_image, np.ndarray):
pil_image = Image.fromarray(input_image)
else:
pil_image = input_image
with torch.no_grad():
result = self.pipe(pil_image)
depth = result["depth"]
if isinstance(depth, Image.Image):
depth_array = np.array(depth, dtype=np.float32)
else:
depth_array = np.array(depth)
# Normalize depth values to 0-255 range
depth_min = depth_array.min()
depth_max = depth_array.max()
if depth_max > depth_min:
depth_array = (depth_array - depth_min) / (depth_max - depth_min) * 255.0
else:
depth_array = np.zeros_like(depth_array)
depth_image = depth_array.astype(np.uint8)
detected_map = remove_pad(HWC3(depth_image))
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map

View File

@@ -0,0 +1,56 @@
import numpy as np
import torch
from einops import repeat
from PIL import Image
from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, DEPTH_ANYTHING_V2_MODEL_NAME_DICT
from custom_controlnet_aux.depth_anything_v2.dpt import DepthAnythingV2
import cv2
import torch.nn.functional as F
# https://github.com/DepthAnything/Depth-Anything-V2/blob/main/app.py
model_configs = {
'depth_anything_v2_vits.pth': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'depth_anything_v2_vitb.pth': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
'depth_anything_v2_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
'depth_anything_v2_vitg.pth': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
'depth_anything_v2_metric_vkitti_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
'depth_anything_v2_metric_hypersim_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}
class DepthAnythingV2Detector:
def __init__(self, model, filename):
self.model = model
self.device = "cpu"
self.filename = filename
@classmethod
def from_pretrained(cls, pretrained_model_or_path=None, filename="depth_anything_v2_vits.pth"):
if pretrained_model_or_path is None:
pretrained_model_or_path = DEPTH_ANYTHING_V2_MODEL_NAME_DICT[filename]
model_path = custom_hf_download(pretrained_model_or_path, filename)
model = DepthAnythingV2(**model_configs[filename])
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model = model.eval()
return cls(model, filename)
def to(self, device):
self.model.to(device)
self.device = device
return self
def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", max_depth=20.0, **kwargs):
input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
depth = self.model.infer_image(cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR), input_size=518, max_depth=max_depth)
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth = depth.astype(np.uint8)
if 'metric' in self.filename:
depth = 255 - depth
detected_map = repeat(depth, "h w -> h w 3")
detected_map, remove_pad = resize_image_with_pad(detected_map, detect_resolution, upscale_method)
detected_map = remove_pad(detected_map)
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map

View File

@@ -0,0 +1,415 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_
from custom_controlnet_aux.depth_anything_v2.dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
logger = logging.getLogger("dinov2")
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = ".".join((name, child_name)) if name else child_name
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
fn(module=module, name=name)
return module
class BlockChunk(nn.ModuleList):
def forward(self, x):
for b in self:
x = b(x)
return x
class DinoVisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
ffn_bias=True,
proj_bias=True,
drop_path_rate=0.0,
drop_path_uniform=False,
init_values=None, # for layerscale: None or 0 => no layerscale
embed_layer=PatchEmbed,
act_layer=nn.GELU,
block_fn=Block,
ffn_layer="mlp",
block_chunks=1,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
proj_bias (bool): enable bias for proj in attn if True
ffn_bias (bool): enable bias for ffn if True
drop_path_rate (float): stochastic depth rate
drop_path_uniform (bool): apply uniform drop rate across blocks
weight_init (str): weight init scheme
init_values (float): layer-scale init values
embed_layer (nn.Module): patch embedding layer
act_layer (nn.Module): MLP activation layer
block_fn (nn.Module): transformer block class
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
"""
super().__init__()
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1
self.n_blocks = depth
self.num_heads = num_heads
self.patch_size = patch_size
self.num_register_tokens = num_register_tokens
self.interpolate_antialias = interpolate_antialias
self.interpolate_offset = interpolate_offset
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
assert num_register_tokens >= 0
self.register_tokens = (
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
)
if drop_path_uniform is True:
dpr = [drop_path_rate] * depth
else:
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
if ffn_layer == "mlp":
logger.info("using MLP layer as FFN")
ffn_layer = Mlp
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
logger.info("using SwiGLU layer as FFN")
ffn_layer = SwiGLUFFNFused
elif ffn_layer == "identity":
logger.info("using Identity layer as FFN")
def f(*args, **kwargs):
return nn.Identity()
ffn_layer = f
else:
raise NotImplementedError
blocks_list = [
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
ffn_layer=ffn_layer,
init_values=init_values,
)
for i in range(depth)
]
if block_chunks > 0:
self.chunked_blocks = True
chunked_blocks = []
chunksize = depth // block_chunks
for i in range(0, depth, chunksize):
# this is to keep the block index consistent if we chunk the block list
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
else:
self.chunked_blocks = False
self.blocks = nn.ModuleList(blocks_list)
self.norm = norm_layer(embed_dim)
self.head = nn.Identity()
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
self.init_weights()
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
nn.init.normal_(self.cls_token, std=1e-6)
if self.register_tokens is not None:
nn.init.normal_(self.register_tokens, std=1e-6)
named_apply(init_weights_vit_timm, self)
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
# DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
# w0, h0 = w0 + 0.1, h0 + 0.1
sqrt_N = math.sqrt(N)
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
scale_factor=(sx, sy),
# (int(w0), int(h0)), # to solve the upsampling shape issue
mode="bicubic",
antialias=self.interpolate_antialias
)
assert int(w0) == patch_pos_embed.shape[-2]
assert int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
def prepare_tokens_with_masks(self, x, masks=None):
B, nc, w, h = x.shape
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)
if self.register_tokens is not None:
x = torch.cat(
(
x[:, :1],
self.register_tokens.expand(x.shape[0], -1, -1),
x[:, 1:],
),
dim=1,
)
return x
def forward_features_list(self, x_list, masks_list):
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
for blk in self.blocks:
x = blk(x)
all_x = x
output = []
for x, masks in zip(all_x, masks_list):
x_norm = self.norm(x)
output.append(
{
"x_norm_clstoken": x_norm[:, 0],
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
"x_prenorm": x,
"masks": masks,
}
)
return output
def forward_features(self, x, masks=None):
if isinstance(x, list):
return self.forward_features_list(x, masks)
x = self.prepare_tokens_with_masks(x, masks)
for blk in self.blocks:
x = blk(x)
x_norm = self.norm(x)
return {
"x_norm_clstoken": x_norm[:, 0],
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
"x_prenorm": x,
"masks": masks,
}
def _get_intermediate_layers_not_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def _get_intermediate_layers_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
output, i, total_block_len = [], 0, len(self.blocks[-1])
# If n is an int, take the n last blocks. If it's a list, take them
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for block_chunk in self.blocks:
for blk in block_chunk[i:]: # Passing the nn.Identity()
x = blk(x)
if i in blocks_to_take:
output.append(x)
i += 1
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1, # Layers or n last layers to take
reshape: bool = False,
return_class_token: bool = False,
norm=True
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
if self.chunked_blocks:
outputs = self._get_intermediate_layers_chunked(x, n)
else:
outputs = self._get_intermediate_layers_not_chunked(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
if reshape:
B, _, w, h = x.shape
outputs = [
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
def forward(self, *args, is_training=False, **kwargs):
ret = self.forward_features(*args, **kwargs)
if is_training:
return ret
else:
return self.head(ret["x_norm_clstoken"])
def init_weights_vit_timm(module: nn.Module, name: str = ""):
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
"""
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
"""
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1536,
depth=40,
num_heads=24,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def DINOv2(model_name):
model_zoo = {
"vits": vit_small,
"vitb": vit_base,
"vitl": vit_large,
"vitg": vit_giant2
}
return model_zoo[model_name](
img_size=518,
patch_size=14,
init_values=1.0,
ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
block_chunks=0,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1
)

View File

@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention

View File

@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
from torch import Tensor
from torch import nn
logger = logging.getLogger("dinov2")
try:
from xformers.ops import memory_efficient_attention, unbind, fmha
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: Tensor) -> Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MemEffAttention(Attention):
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
if not XFORMERS_AVAILABLE:
assert attn_bias is None, "xFormers is required for nested tensors usage"
return super().forward(x)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = unbind(qkv, 2)
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
x = x.reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x

View File

@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
from typing import Callable, List, Any, Tuple, Dict
import torch
from torch import nn, Tensor
from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
logger = logging.getLogger("dinov2")
try:
from xformers.ops import fmha
from xformers.ops import scaled_index_add, index_select_cat
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
) -> None:
super().__init__()
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: Tensor) -> Tensor:
def attn_residual_func(x: Tensor) -> Tensor:
return self.ls1(self.attn(self.norm1(x)))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
# the overhead is compensated only for a drop path rate larger than 0.1
x = drop_add_residual_stochastic_depth(
x,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
x = drop_add_residual_stochastic_depth(
x,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x))
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
else:
x = x + attn_residual_func(x)
x = x + ffn_residual_func(x)
return x
def drop_add_residual_stochastic_depth(
x: Tensor,
residual_func: Callable[[Tensor], Tensor],
sample_drop_ratio: float = 0.0,
) -> Tensor:
# 1) extract subset using permutation
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
x_subset = x[brange]
# 2) apply residual_func to get residual
residual = residual_func(x_subset)
x_flat = x.flatten(1)
residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
return x_plus_residual.view_as(x)
def get_branges_scales(x, sample_drop_ratio=0.0):
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
residual_scale_factor = b / sample_subset_size
return brange, residual_scale_factor
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
if scaling_vector is None:
x_flat = x.flatten(1)
residual = residual.flatten(1)
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
else:
x_plus_residual = scaled_index_add(
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
)
return x_plus_residual
attn_bias_cache: Dict[Tuple, Any] = {}
def get_attn_bias_and_cat(x_list, branges=None):
"""
this will perform the index select, cat the tensors, and provide the attn_bias from cache
"""
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
if all_shapes not in attn_bias_cache.keys():
seqlens = []
for b, x in zip(batch_sizes, x_list):
for _ in range(b):
seqlens.append(x.shape[1])
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
attn_bias._batch_sizes = batch_sizes
attn_bias_cache[all_shapes] = attn_bias
if branges is not None:
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
else:
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
cat_tensors = torch.cat(tensors_bs1, dim=1)
return attn_bias_cache[all_shapes], cat_tensors
def drop_add_residual_stochastic_depth_list(
x_list: List[Tensor],
residual_func: Callable[[Tensor, Any], Tensor],
sample_drop_ratio: float = 0.0,
scaling_vector=None,
) -> Tensor:
# 1) generate random set of indices for dropping samples in the batch
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
branges = [s[0] for s in branges_scales]
residual_scale_factors = [s[1] for s in branges_scales]
# 2) get attention bias and index+concat the tensors
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
# 3) apply residual_func to get residual, and split the result
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
outputs = []
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
return outputs
class NestedTensorBlock(Block):
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
"""
x_list contains a list of tensors to nest together and run
"""
assert isinstance(self.attn, MemEffAttention)
if self.training and self.sample_drop_ratio > 0.0:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.attn(self.norm1(x), attn_bias=attn_bias)
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.mlp(self.norm2(x))
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
)
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
)
return x_list
else:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
attn_bias, x = get_attn_bias_and_cat(x_list)
x = x + attn_residual_func(x, attn_bias=attn_bias)
x = x + ffn_residual_func(x)
return attn_bias.split(x)
def forward(self, x_or_x_list):
if isinstance(x_or_x_list, Tensor):
return super().forward(x_or_x_list)
elif isinstance(x_or_x_list, list):
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
return self.forward_nested(x_or_x_list)
else:
raise AssertionError

View File

@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
from torch import nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0:
random_tensor.div_(keep_prob)
output = x * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)

View File

@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
from typing import Union
import torch
from torch import Tensor
from torch import nn
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: Union[float, Tensor] = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma

View File

@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
from typing import Callable, Optional
from torch import Tensor, nn
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x

View File

@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
from typing import Callable, Optional, Tuple, Union
from torch import Tensor
import torch.nn as nn
def make_2tuple(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
Args:
img_size: Image size.
patch_size: Patch token size.
in_chans: Number of input image channels.
embed_dim: Number of linear projection output channels.
norm_layer: Normalization layer.
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (
image_HW[0] // patch_HW[0],
image_HW[1] // patch_HW[1],
)
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.flatten_embedding = flatten_embedding
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
return x
def flops(self) -> float:
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops

View File

@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
from torch import Tensor, nn
import torch.nn.functional as F
class SwiGLUFFN(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
def forward(self, x: Tensor) -> Tensor:
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
return self.w3(hidden)
try:
from xformers.ops import SwiGLU
XFORMERS_AVAILABLE = True
except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
class SwiGLUFFNFused(SwiGLU):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
out_features = out_features or in_features
hidden_features = hidden_features or in_features
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
super().__init__(
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
bias=bias,
)

View File

@@ -0,0 +1,220 @@
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose
from custom_controlnet_aux.depth_anything_v2.dinov2 import DINOv2
from custom_controlnet_aux.depth_anything_v2.util.blocks import FeatureFusionBlock, _make_scratch
from custom_controlnet_aux.depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet
def _make_fusion_block(features, use_bn, size=None):
return FeatureFusionBlock(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
class ConvBlock(nn.Module):
def __init__(self, in_feature, out_feature):
super().__init__()
self.conv_block = nn.Sequential(
nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_feature),
nn.ReLU(True)
)
def forward(self, x):
return self.conv_block(x)
class DPTHead(nn.Module):
def __init__(
self,
in_channels,
features=256,
use_bn=False,
out_channels=[256, 512, 1024, 1024],
use_clstoken=False
):
super(DPTHead, self).__init__()
self.use_clstoken = use_clstoken
self.projects = nn.ModuleList([
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
stride=1,
padding=0,
) for out_channel in out_channels
])
self.resize_layers = nn.ModuleList([
nn.ConvTranspose2d(
in_channels=out_channels[0],
out_channels=out_channels[0],
kernel_size=4,
stride=4,
padding=0),
nn.ConvTranspose2d(
in_channels=out_channels[1],
out_channels=out_channels[1],
kernel_size=2,
stride=2,
padding=0),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3],
out_channels=out_channels[3],
kernel_size=3,
stride=2,
padding=1)
])
if use_clstoken:
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(
nn.Sequential(
nn.Linear(2 * in_channels, in_channels),
nn.GELU()))
self.scratch = _make_scratch(
out_channels,
features,
groups=1,
expand=False,
)
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
head_features_1 = features
head_features_2 = 32
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True),
nn.Identity(),
)
def forward(self, out_features, patch_h, patch_w):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
else:
x = x[0]
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
out = self.scratch.output_conv2(out)
return out
class DepthAnythingV2(nn.Module):
def __init__(
self,
encoder='vitl',
features=256,
out_channels=[256, 512, 1024, 1024],
use_bn=False,
use_clstoken=False
):
super(DepthAnythingV2, self).__init__()
self.intermediate_layer_idx = {
'vits': [2, 5, 8, 11],
'vitb': [2, 5, 8, 11],
'vitl': [4, 11, 17, 23],
'vitg': [9, 19, 29, 39]
}
self.encoder = encoder
self.pretrained = DINOv2(model_name=encoder)
self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
def forward(self, x, max_depth):
patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
depth = self.depth_head(features, patch_h, patch_w) * max_depth
return depth.squeeze(1)
@torch.no_grad()
def infer_image(self, raw_image, input_size=518, max_depth=20.0):
image, (h, w) = self.image2tensor(raw_image, input_size)
depth = self.forward(image, max_depth)
depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
return depth.cpu().numpy()
def image2tensor(self, raw_image, input_size=518):
transform = Compose([
Resize(
width=input_size,
height=input_size,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method='lower_bound',
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
])
h, w = raw_image.shape[:2]
image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
image = transform({'image': image})['image']
image = torch.from_numpy(image).unsqueeze(0)
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
image = image.to(DEVICE)
return image, (h, w)

Some files were not shown because too many files have changed in this diff Show More