Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
183
custom_nodes/comfyui_controlnet_aux/.gitignore
vendored
Normal file
@@ -0,0 +1,183 @@
|
||||
# Initially taken from Github's Python gitignore file
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# tests and logs
|
||||
tests/fixtures/cached_*_text.txt
|
||||
logs/
|
||||
lightning_logs/
|
||||
lang_code_data/
|
||||
tests/outputs
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# vscode
|
||||
.vs
|
||||
.vscode
|
||||
|
||||
# Pycharm
|
||||
.idea
|
||||
|
||||
# TF code
|
||||
tensorflow_code
|
||||
|
||||
# Models
|
||||
proc_data
|
||||
|
||||
# examples
|
||||
runs
|
||||
/runs_old
|
||||
/wandb
|
||||
/examples/runs
|
||||
/examples/**/*.args
|
||||
/examples/rag/sweep
|
||||
|
||||
# data
|
||||
/data
|
||||
serialization_dir
|
||||
|
||||
# emacs
|
||||
*.*~
|
||||
debug.env
|
||||
|
||||
# vim
|
||||
.*.swp
|
||||
|
||||
#ctags
|
||||
tags
|
||||
|
||||
# pre-commit
|
||||
.pre-commit*
|
||||
|
||||
# .lock
|
||||
*.lock
|
||||
|
||||
# DS_Store (MacOS)
|
||||
.DS_Store
|
||||
# RL pipelines may produce mp4 outputs
|
||||
*.mp4
|
||||
|
||||
# dependencies
|
||||
/transformers
|
||||
|
||||
# ruff
|
||||
.ruff_cache
|
||||
|
||||
wandb
|
||||
|
||||
ckpts/
|
||||
|
||||
test.ipynb
|
||||
config.yaml
|
||||
test.ipynb
|
||||
201
custom_nodes/comfyui_controlnet_aux/LICENSE.txt
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
BIN
custom_nodes/comfyui_controlnet_aux/NotoSans-Regular.ttf
Normal file
252
custom_nodes/comfyui_controlnet_aux/README.md
Normal file
@@ -0,0 +1,252 @@
|
||||
# ComfyUI's ControlNet Auxiliary Preprocessors
|
||||
Plug-and-play [ComfyUI](https://github.com/comfyanonymous/ComfyUI) node sets for making [ControlNet](https://github.com/lllyasviel/ControlNet/) hint images
|
||||
|
||||
"anime style, a protest in the street, cyberpunk city, a woman with pink hair and golden eyes (looking at the viewer) is holding a sign with the text "ComfyUI ControlNet Aux" in bold, neon pink" on Flux.1 Dev
|
||||
|
||||

|
||||
|
||||
The code is copy-pasted from the respective folders in https://github.com/lllyasviel/ControlNet/tree/main/annotator and connected to [the 🤗 Hub](https://huggingface.co/lllyasviel/Annotators).
|
||||
|
||||
All credit & copyright goes to https://github.com/lllyasviel.
|
||||
|
||||
# Updates
|
||||
Go to [Update page](./UPDATES.md) to follow updates
|
||||
|
||||
# Installation:
|
||||
## Using ComfyUI Manager (recommended):
|
||||
Install [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) and do steps introduced there to install this repo.
|
||||
|
||||
## Alternative:
|
||||
If you're running on Linux, or non-admin account on windows you'll want to ensure `/ComfyUI/custom_nodes` and `comfyui_controlnet_aux` has write permissions.
|
||||
|
||||
There is now a **install.bat** you can run to install to portable if detected. Otherwise it will default to system and assume you followed ConfyUI's manual installation steps.
|
||||
|
||||
If you can't run **install.bat** (e.g. you are a Linux user). Open the CMD/Shell and do the following:
|
||||
- Navigate to your `/ComfyUI/custom_nodes/` folder
|
||||
- Run `git clone https://github.com/Fannovel16/comfyui_controlnet_aux/`
|
||||
- Navigate to your `comfyui_controlnet_aux` folder
|
||||
- Portable/venv:
|
||||
- Run `path/to/ComfUI/python_embeded/python.exe -s -m pip install -r requirements.txt`
|
||||
- With system python
|
||||
- Run `pip install -r requirements.txt`
|
||||
- Start ComfyUI
|
||||
|
||||
# Nodes
|
||||
Please note that this repo only supports preprocessors making hint images (e.g. stickman, canny edge, etc).
|
||||
All preprocessors except Inpaint are intergrated into `AIO Aux Preprocessor` node.
|
||||
This node allow you to quickly get the preprocessor but a preprocessor's own threshold parameters won't be able to set.
|
||||
You need to use its node directly to set thresholds.
|
||||
|
||||
# Nodes (sections are categories in Comfy menu)
|
||||
## Line Extractors
|
||||
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|
||||
|-----------------------------|---------------------------|-------------------------------------------|
|
||||
| Binary Lines | binary | control_scribble |
|
||||
| Canny Edge | canny | control_v11p_sd15_canny <br> control_canny <br> t2iadapter_canny |
|
||||
| HED Soft-Edge Lines | hed | control_v11p_sd15_softedge <br> control_hed |
|
||||
| Standard Lineart | standard_lineart | control_v11p_sd15_lineart |
|
||||
| Realistic Lineart | lineart (or `lineart_coarse` if `coarse` is enabled) | control_v11p_sd15_lineart |
|
||||
| Anime Lineart | lineart_anime | control_v11p_sd15s2_lineart_anime |
|
||||
| Manga Lineart | lineart_anime_denoise | control_v11p_sd15s2_lineart_anime |
|
||||
| M-LSD Lines | mlsd | control_v11p_sd15_mlsd <br> control_mlsd |
|
||||
| PiDiNet Soft-Edge Lines | pidinet | control_v11p_sd15_softedge <br> control_scribble |
|
||||
| Scribble Lines | scribble | control_v11p_sd15_scribble <br> control_scribble |
|
||||
| Scribble XDoG Lines | scribble_xdog | control_v11p_sd15_scribble <br> control_scribble |
|
||||
| Fake Scribble Lines | scribble_hed | control_v11p_sd15_scribble <br> control_scribble |
|
||||
| TEED Soft-Edge Lines | teed | [controlnet-sd-xl-1.0-softedge-dexined](https://huggingface.co/SargeZT/controlnet-sd-xl-1.0-softedge-dexined/blob/main/controlnet-sd-xl-1.0-softedge-dexined.safetensors) <br> control_v11p_sd15_softedge (Theoretically)
|
||||
| Scribble PiDiNet Lines | scribble_pidinet | control_v11p_sd15_scribble <br> control_scribble |
|
||||
| AnyLine Lineart | | mistoLine_fp16.safetensors <br> mistoLine_rank256 <br> control_v11p_sd15s2_lineart_anime <br> control_v11p_sd15_lineart |
|
||||
|
||||
## Normal and Depth Estimators
|
||||
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|
||||
|-----------------------------|---------------------------|-------------------------------------------|
|
||||
| MiDaS Depth Map | (normal) depth | control_v11f1p_sd15_depth <br> control_depth <br> t2iadapter_depth |
|
||||
| LeReS Depth Map | depth_leres | control_v11f1p_sd15_depth <br> control_depth <br> t2iadapter_depth |
|
||||
| Zoe Depth Map | depth_zoe | control_v11f1p_sd15_depth <br> control_depth <br> t2iadapter_depth |
|
||||
| MiDaS Normal Map | normal_map | control_normal |
|
||||
| BAE Normal Map | normal_bae | control_v11p_sd15_normalbae |
|
||||
| MeshGraphormer Hand Refiner ([HandRefinder](https://github.com/wenquanlu/HandRefiner)) | depth_hand_refiner | [control_sd15_inpaint_depth_hand_fp16](https://huggingface.co/hr16/ControlNet-HandRefiner-pruned/blob/main/control_sd15_inpaint_depth_hand_fp16.safetensors) |
|
||||
| Depth Anything | depth_anything | [Depth-Anything](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_controlnet/diffusion_pytorch_model.safetensors) |
|
||||
| Zoe Depth Anything <br> (Basically Zoe but the encoder is replaced with DepthAnything) | depth_anything | [Depth-Anything](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_controlnet/diffusion_pytorch_model.safetensors) |
|
||||
| Normal DSINE | | control_normal/control_v11p_sd15_normalbae |
|
||||
| Metric3D Depth | | control_v11f1p_sd15_depth <br> control_depth <br> t2iadapter_depth |
|
||||
| Metric3D Normal | | control_v11p_sd15_normalbae |
|
||||
| Depth Anything V2 | | [Depth-Anything](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_controlnet/diffusion_pytorch_model.safetensors) |
|
||||
|
||||
## Faces and Poses Estimators
|
||||
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|
||||
|-----------------------------|---------------------------|-------------------------------------------|
|
||||
| DWPose Estimator | dw_openpose_full | control_v11p_sd15_openpose <br> control_openpose <br> t2iadapter_openpose |
|
||||
| OpenPose Estimator | openpose (detect_body) <br> openpose_hand (detect_body + detect_hand) <br> openpose_faceonly (detect_face) <br> openpose_full (detect_hand + detect_body + detect_face) | control_v11p_sd15_openpose <br> control_openpose <br> t2iadapter_openpose |
|
||||
| MediaPipe Face Mesh | mediapipe_face | controlnet_sd21_laion_face_v2 |
|
||||
| Animal Estimator | animal_openpose | [control_sd15_animal_openpose_fp16](https://huggingface.co/huchenlei/animal_openpose/blob/main/control_sd15_animal_openpose_fp16.pth) |
|
||||
|
||||
## Optical Flow Estimators
|
||||
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|
||||
|-----------------------------|---------------------------|-------------------------------------------|
|
||||
| Unimatch Optical Flow | | [DragNUWA](https://github.com/ProjectNUWA/DragNUWA) |
|
||||
|
||||
### How to get OpenPose-format JSON?
|
||||
#### User-side
|
||||
This workflow will save images to ComfyUI's output folder (the same location as output images). If you haven't found `Save Pose Keypoints` node, update this extension
|
||||

|
||||
|
||||
#### Dev-side
|
||||
An array of [OpenPose-format JSON](https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md#json-output-format) corresponsding to each frame in an IMAGE batch can be gotten from DWPose and OpenPose using `app.nodeOutputs` on the UI or `/history` API endpoint. JSON output from AnimalPose uses a kinda similar format to OpenPose JSON:
|
||||
```
|
||||
[
|
||||
{
|
||||
"version": "ap10k",
|
||||
"animals": [
|
||||
[[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
|
||||
[[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
|
||||
...
|
||||
],
|
||||
"canvas_height": 512,
|
||||
"canvas_width": 768
|
||||
},
|
||||
...
|
||||
]
|
||||
```
|
||||
|
||||
For extension developers (e.g. Openpose editor):
|
||||
```js
|
||||
const poseNodes = app.graph._nodes.filter(node => ["OpenposePreprocessor", "DWPreprocessor", "AnimalPosePreprocessor"].includes(node.type))
|
||||
for (const poseNode of poseNodes) {
|
||||
const openposeResults = JSON.parse(app.nodeOutputs[poseNode.id].openpose_json[0])
|
||||
console.log(openposeResults) //An array containing Openpose JSON for each frame
|
||||
}
|
||||
```
|
||||
|
||||
For API users:
|
||||
Javascript
|
||||
```js
|
||||
import fetch from "node-fetch" //Remember to add "type": "module" to "package.json"
|
||||
async function main() {
|
||||
const promptId = '792c1905-ecfe-41f4-8114-83e6a4a09a9f' //Too lazy to POST /queue
|
||||
let history = await fetch(`http://127.0.0.1:8188/history/${promptId}`).then(re => re.json())
|
||||
history = history[promptId]
|
||||
const nodeOutputs = Object.values(history.outputs).filter(output => output.openpose_json)
|
||||
for (const nodeOutput of nodeOutputs) {
|
||||
const openposeResults = JSON.parse(nodeOutput.openpose_json[0])
|
||||
console.log(openposeResults) //An array containing Openpose JSON for each frame
|
||||
}
|
||||
}
|
||||
main()
|
||||
```
|
||||
|
||||
Python
|
||||
```py
|
||||
import json, urllib.request
|
||||
|
||||
server_address = "127.0.0.1:8188"
|
||||
prompt_id = '' #Too lazy to POST /queue
|
||||
|
||||
def get_history(prompt_id):
|
||||
with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
|
||||
return json.loads(response.read())
|
||||
|
||||
history = get_history(prompt_id)[prompt_id]
|
||||
for o in history['outputs']:
|
||||
for node_id in history['outputs']:
|
||||
node_output = history['outputs'][node_id]
|
||||
if 'openpose_json' in node_output:
|
||||
print(json.loads(node_output['openpose_json'][0])) #An list containing Openpose JSON for each frame
|
||||
```
|
||||
## Semantic Segmentation
|
||||
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|
||||
|-----------------------------|---------------------------|-------------------------------------------|
|
||||
| OneFormer ADE20K Segmentor | oneformer_ade20k | control_v11p_sd15_seg |
|
||||
| OneFormer COCO Segmentor | oneformer_coco | control_v11p_sd15_seg |
|
||||
| UniFormer Segmentor | segmentation |control_sd15_seg <br> control_v11p_sd15_seg|
|
||||
|
||||
## T2IAdapter-only
|
||||
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|
||||
|-----------------------------|---------------------------|-------------------------------------------|
|
||||
| Color Pallete | color | t2iadapter_color |
|
||||
| Content Shuffle | shuffle | t2iadapter_style |
|
||||
|
||||
## Recolor
|
||||
| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
|
||||
|-----------------------------|---------------------------|-------------------------------------------|
|
||||
| Image Luminance | recolor_luminance | [ioclab_sd15_recolor](https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/ioclab_sd15_recolor.safetensors) <br> [sai_xl_recolor_256lora](https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/sai_xl_recolor_256lora.safetensors) <br> [bdsqlsz_controlllite_xl_recolor_luminance](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/resolve/main/bdsqlsz_controlllite_xl_recolor_luminance.safetensors) |
|
||||
| Image Intensity | recolor_intensity | Idk. Maybe same as above? |
|
||||
|
||||
# Examples
|
||||
> A picture is worth a thousand words
|
||||
|
||||

|
||||

|
||||
|
||||
# Testing workflow
|
||||
https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/examples/ExecuteAll.png
|
||||
Input image: https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/examples/comfyui-controlnet-aux-logo.png
|
||||
|
||||
# Q&A:
|
||||
## Why some nodes doesn't appear after I installed this repo?
|
||||
|
||||
This repo has a new mechanism which will skip any custom node can't be imported. If you meet this case, please create a issue on [Issues tab](https://github.com/Fannovel16/comfyui_controlnet_aux/issues) with the log from the command line.
|
||||
|
||||
## DWPose/AnimalPose only uses CPU so it's so slow. How can I make it use GPU?
|
||||
There are two ways to speed-up DWPose: using TorchScript checkpoints (.torchscript.pt) checkpoints or ONNXRuntime (.onnx). TorchScript way is little bit slower than ONNXRuntime but doesn't require any additional library and still way way faster than CPU.
|
||||
|
||||
A torchscript bbox detector is compatiable with an onnx pose estimator and vice versa.
|
||||
### TorchScript
|
||||
Set `bbox_detector` and `pose_estimator` according to this picture. You can try other bbox detector endings with `.torchscript.pt` to reduce bbox detection time if input images are ideal.
|
||||

|
||||
### ONNXRuntime
|
||||
If onnxruntime is installed successfully and the checkpoint used endings with `.onnx`, it will replace default cv2 backend to take advantage of GPU. Note that if you are using NVidia card, this method currently can only works on CUDA 11.8 (ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z) unless you compile onnxruntime yourself.
|
||||
|
||||
1. Know your onnxruntime build:
|
||||
* * NVidia CUDA 11.x or bellow/AMD GPU: `onnxruntime-gpu`
|
||||
* * NVidia CUDA 12.x: `onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`
|
||||
* * DirectML: `onnxruntime-directml`
|
||||
* * OpenVINO: `onnxruntime-openvino`
|
||||
|
||||
Note that if this is your first time using ComfyUI, please test if it can run on your device before doing next steps.
|
||||
|
||||
2. Add it into `requirements.txt`
|
||||
|
||||
3. Run `install.bat` or pip command mentioned in Installation
|
||||
|
||||

|
||||
|
||||
# Assets files of preprocessors
|
||||
* anime_face_segment: [bdsqlsz/qinglong_controlnet-lllite/Annotators/UNet.pth](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/blob/main/Annotators/UNet.pth), [anime-seg/isnetis.ckpt](https://huggingface.co/skytnt/anime-seg/blob/main/isnetis.ckpt)
|
||||
* densepose: [LayerNorm/DensePose-TorchScript-with-hint-image/densepose_r50_fpn_dl.torchscript](https://huggingface.co/LayerNorm/DensePose-TorchScript-with-hint-image/blob/main/densepose_r50_fpn_dl.torchscript)
|
||||
* dwpose:
|
||||
* * bbox_detector: Either [yzd-v/DWPose/yolox_l.onnx](https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx), [hr16/yolox-onnx/yolox_l.torchscript.pt](https://huggingface.co/hr16/yolox-onnx/blob/main/yolox_l.torchscript.pt), [hr16/yolo-nas-fp16/yolo_nas_l_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_l_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_m_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_m_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_s_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_s_fp16.onnx)
|
||||
* * pose_estimator: Either [hr16/DWPose-TorchScript-BatchSize5/dw-ll_ucoco_384_bs5.torchscript.pt](https://huggingface.co/hr16/DWPose-TorchScript-BatchSize5/blob/main/dw-ll_ucoco_384_bs5.torchscript.pt), [yzd-v/DWPose/dw-ll_ucoco_384.onnx](https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx)
|
||||
* animal_pose (ap10k):
|
||||
* * bbox_detector: Either [yzd-v/DWPose/yolox_l.onnx](https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx), [hr16/yolox-onnx/yolox_l.torchscript.pt](https://huggingface.co/hr16/yolox-onnx/blob/main/yolox_l.torchscript.pt), [hr16/yolo-nas-fp16/yolo_nas_l_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_l_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_m_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_m_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_s_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_s_fp16.onnx)
|
||||
* * pose_estimator: Either [hr16/DWPose-TorchScript-BatchSize5/rtmpose-m_ap10k_256_bs5.torchscript.pt](https://huggingface.co/hr16/DWPose-TorchScript-BatchSize5/blob/main/rtmpose-m_ap10k_256_bs5.torchscript.pt), [hr16/UnJIT-DWPose/rtmpose-m_ap10k_256.onnx](https://huggingface.co/hr16/UnJIT-DWPose/blob/main/rtmpose-m_ap10k_256.onnx)
|
||||
* hed: [lllyasviel/Annotators/ControlNetHED.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/ControlNetHED.pth)
|
||||
* leres: [lllyasviel/Annotators/res101.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/res101.pth), [lllyasviel/Annotators/latest_net_G.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/latest_net_G.pth)
|
||||
* lineart: [lllyasviel/Annotators/sk_model.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/sk_model.pth), [lllyasviel/Annotators/sk_model2.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/sk_model2.pth)
|
||||
* lineart_anime: [lllyasviel/Annotators/netG.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/netG.pth)
|
||||
* manga_line: [lllyasviel/Annotators/erika.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/erika.pth)
|
||||
* mesh_graphormer: [hr16/ControlNet-HandRefiner-pruned/graphormer_hand_state_dict.bin](https://huggingface.co/hr16/ControlNet-HandRefiner-pruned/blob/main/graphormer_hand_state_dict.bin), [hr16/ControlNet-HandRefiner-pruned/hrnetv2_w64_imagenet_pretrained.pth](https://huggingface.co/hr16/ControlNet-HandRefiner-pruned/blob/main/hrnetv2_w64_imagenet_pretrained.pth)
|
||||
* midas: [lllyasviel/Annotators/dpt_hybrid-midas-501f0c75.pt](https://huggingface.co/lllyasviel/Annotators/blob/main/dpt_hybrid-midas-501f0c75.pt)
|
||||
* mlsd: [lllyasviel/Annotators/mlsd_large_512_fp32.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/mlsd_large_512_fp32.pth)
|
||||
* normalbae: [lllyasviel/Annotators/scannet.pt](https://huggingface.co/lllyasviel/Annotators/blob/main/scannet.pt)
|
||||
* oneformer: [lllyasviel/Annotators/250_16_swin_l_oneformer_ade20k_160k.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/250_16_swin_l_oneformer_ade20k_160k.pth)
|
||||
* open_pose: [lllyasviel/Annotators/body_pose_model.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/body_pose_model.pth), [lllyasviel/Annotators/hand_pose_model.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/hand_pose_model.pth), [lllyasviel/Annotators/facenet.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/facenet.pth)
|
||||
* pidi: [lllyasviel/Annotators/table5_pidinet.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/table5_pidinet.pth)
|
||||
* sam: [dhkim2810/MobileSAM/mobile_sam.pt](https://huggingface.co/dhkim2810/MobileSAM/blob/main/mobile_sam.pt)
|
||||
* uniformer: [lllyasviel/Annotators/upernet_global_small.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/upernet_global_small.pth)
|
||||
* zoe: [lllyasviel/Annotators/ZoeD_M12_N.pt](https://huggingface.co/lllyasviel/Annotators/blob/main/ZoeD_M12_N.pt)
|
||||
* teed: [bdsqlsz/qinglong_controlnet-lllite/7_model.pth](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/blob/main/Annotators/7_model.pth)
|
||||
* depth_anything: Either [LiheYoung/Depth-Anything/checkpoints/depth_anything_vitl14.pth](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints/depth_anything_vitl14.pth), [LiheYoung/Depth-Anything/checkpoints/depth_anything_vitb14.pth](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints/depth_anything_vitb14.pth) or [LiheYoung/Depth-Anything/checkpoints/depth_anything_vits14.pth](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints/depth_anything_vits14.pth)
|
||||
* diffusion_edge: Either [hr16/Diffusion-Edge/diffusion_edge_indoor.pt](https://huggingface.co/hr16/Diffusion-Edge/blob/main/diffusion_edge_indoor.pt), [hr16/Diffusion-Edge/diffusion_edge_urban.pt](https://huggingface.co/hr16/Diffusion-Edge/blob/main/diffusion_edge_urban.pt) or [hr16/Diffusion-Edge/diffusion_edge_natrual.pt](https://huggingface.co/hr16/Diffusion-Edge/blob/main/diffusion_edge_natrual.pt)
|
||||
* unimatch: Either [hr16/Unimatch/gmflow-scale2-regrefine6-mixdata.pth](https://huggingface.co/hr16/Unimatch/blob/main/gmflow-scale2-regrefine6-mixdata.pth), [hr16/Unimatch/gmflow-scale2-mixdata.pth](https://huggingface.co/hr16/Unimatch/blob/main/gmflow-scale2-mixdata.pth) or [hr16/Unimatch/gmflow-scale1-mixdata.pth](https://huggingface.co/hr16/Unimatch/blob/main/gmflow-scale1-mixdata.pth)
|
||||
* zoe_depth_anything: Either [LiheYoung/Depth-Anything/checkpoints_metric_depth/depth_anything_metric_depth_indoor.pt](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_metric_depth/depth_anything_metric_depth_indoor.pt) or [LiheYoung/Depth-Anything/checkpoints_metric_depth/depth_anything_metric_depth_outdoor.pt](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_metric_depth/depth_anything_metric_depth_outdoor.pt)
|
||||
# 2000 Stars 😄
|
||||
<a href="https://star-history.com/#Fannovel16/comfyui_controlnet_aux&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Fannovel16/comfyui_controlnet_aux&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Fannovel16/comfyui_controlnet_aux&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Fannovel16/comfyui_controlnet_aux&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
Thanks for yalls supports. I never thought the graph for stars would be linear lol.
|
||||
45
custom_nodes/comfyui_controlnet_aux/UPDATES.md
Normal file
@@ -0,0 +1,45 @@
|
||||
* `AIO Aux Preprocessor` integrating all loadable aux preprocessors as dropdown options. Easy to copy, paste and get the preprocessor faster.
|
||||
* Added OpenPose-format JSON output from OpenPose Preprocessor and DWPose Preprocessor. Checks [here](#faces-and-poses).
|
||||
* Fixed wrong model path when downloading DWPose.
|
||||
* Make hint images less blurry.
|
||||
* Added `resolution` option, `PixelPerfectResolution` and `HintImageEnchance` nodes (TODO: Documentation).
|
||||
* Added `RAFT Optical Flow Embedder` for TemporalNet2 (TODO: Workflow example).
|
||||
* Fixed opencv's conflicts between this extension, [ReActor](https://github.com/Gourieff/comfyui-reactor-node) and Roop. Thanks `Gourieff` for [the solution](https://github.com/Fannovel16/comfyui_controlnet_aux/issues/7#issuecomment-1734319075)!
|
||||
* RAFT is removed as the code behind it doesn't match what the original code does
|
||||
* Changed `lineart`'s display name from `Normal Lineart` to `Realistic Lineart`. This change won't affect old workflows
|
||||
* Added support for `onnxruntime` to speed-up DWPose (see the Q&A)
|
||||
* Fixed TypeError: expected size to be one of int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], but got size with types [<class 'numpy.int64'>, <class 'numpy.int64'>]: [Issue](https://github.com/Fannovel16/comfyui_controlnet_aux/issues/2), [PR](https://github.com/Fannovel16/comfyui_controlnet_aux/pull/71))
|
||||
* Fixed ImageGenResolutionFromImage mishape (https://github.com/Fannovel16/comfyui_controlnet_aux/pull/74)
|
||||
* Fixed LeRes and MiDaS's incompatibility with MPS device
|
||||
* Fixed checking DWPose onnxruntime session multiple times: https://github.com/Fannovel16/comfyui_controlnet_aux/issues/89)
|
||||
* Added `Anime Face Segmentor` (in `ControlNet Preprocessors/Semantic Segmentation`) for [ControlNet AnimeFaceSegmentV2](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite#animefacesegmentv2). Checks [here](#anime-face-segmentor)
|
||||
* Change download functions and fix [download error](https://github.com/Fannovel16/comfyui_controlnet_aux/issues/39): [PR](https://github.com/Fannovel16/comfyui_controlnet_aux/pull/96)
|
||||
* Caching DWPose Onnxruntime during the first use of DWPose node instead of ComfyUI startup
|
||||
* Added alternative YOLOX models for faster speed when using DWPose
|
||||
* Added alternative DWPose models
|
||||
* Implemented the preprocessor for [AnimalPose ControlNet](https://github.com/abehonest/ControlNet_AnimalPose/tree/main). Check [Animal Pose AP-10K](#animal-pose-ap-10k)
|
||||
* Added YOLO-NAS models which are drop-in replacements of YOLOX
|
||||
* Fixed Openpose Face/Hands no longer detecting: https://github.com/Fannovel16/comfyui_controlnet_aux/issues/54
|
||||
* Added TorchScript implementation of DWPose and AnimalPose
|
||||
* Added TorchScript implementation of DensePose from [Colab notebook](https://colab.research.google.com/drive/16hcaaKs210ivpxjoyGNuvEXZD4eqOOSQ) which doesn't require detectron2. [Example](#densepose). Thanks [@LayerNome](https://github.com/Layer-norm) for fixing bugs related.
|
||||
* Added Standard Lineart Preprocessor
|
||||
* Fixed OpenPose misplacements in some cases
|
||||
* Added Mesh Graphormer - Hand Depth Map & Mask
|
||||
* Misaligned hands bug from MeshGraphormer was fixed
|
||||
* Added more mask options for MeshGraphormer
|
||||
* Added Save Pose Keypoint node for editing
|
||||
* Added Unimatch Optical Flow
|
||||
* Added Depth Anything & Zoe Depth Anything
|
||||
* Removed resolution field from Unimatch Optical Flow as that interpolating optical flow seems unstable
|
||||
* Added TEED Soft-Edge Preprocessor
|
||||
* Added DiffusionEdge
|
||||
* Added Image Luminance and Image Intensity
|
||||
* Added Normal DSINE
|
||||
* Added TTPlanet Tile (09/05/2024, DD/MM/YYYY)
|
||||
* Added AnyLine, Metric3D (18/05/2024)
|
||||
* Added Depth Anything V2 (16/06/2024)
|
||||
* Added Union model of ControlNet and preprocessors
|
||||

|
||||
* Refactor INPUT_TYPES and add Execute All node during the process of learning [Execution Model Inversion](https://github.com/comfyanonymous/ComfyUI/pull/2666)
|
||||
* Added scale_stick_for_xinsr_cn (https://github.com/Fannovel16/comfyui_controlnet_aux/issues/447) (09/04/2024)
|
||||
* PyTorch 2.7 compatibility fixes - eliminated custom_timm, custom_detectron2, and custom_midas_repo dependencies causing hanging issues. Refactored 7 major preprocessors including OneFormer (now using HuggingFace transformers), ZOE, DSINE, MiDaS, BAE, Metric3D, and Uniformer. Resolved ~59 GitHub issues related to import failures, hanging, and extension conflicts. Full modernization to actively maintained packages.
|
||||
224
custom_nodes/comfyui_controlnet_aux/__init__.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""comfyui_controlnet_aux package entry point.

Sets required environment variables, makes the vendored packages under
src/ importable, and (further down this file) registers every ControlNet
auxiliary preprocessor node with ComfyUI.
"""
import sys, os

# Disable NPU device initialization and problematic MMCV ops to prevent RuntimeError
# Must be set BEFORE any MMCV imports happen anywhere in ComfyUI
os.environ['NPU_DEVICE_COUNT'] = '0'
os.environ['MMCV_WITH_OPS'] = '0'
from .utils import here, define_preprocessor_inputs, INPUT
from pathlib import Path
import traceback
import importlib
from .log import log, blue_text, cyan_text, get_summary, get_label
from .hint_image_enchance import NODE_CLASS_MAPPINGS as HIE_NODE_CLASS_MAPPINGS
from .hint_image_enchance import NODE_DISPLAY_NAME_MAPPINGS as HIE_NODE_DISPLAY_NAME_MAPPINGS
#Ref: https://github.com/comfyanonymous/ComfyUI/blob/76d53c4622fc06372975ed2a43ad345935b8a551/nodes.py#L17
# Make the vendored src/ packages importable without a pip install.
sys.path.insert(0, str(Path(here, "src").resolve()))
for pkg_name in ["custom_controlnet_aux", "custom_mmpkg"]:
    sys.path.append(str(Path(here, "src", pkg_name).resolve()))

#Enable CPU fallback for ops not being supported by MPS like upsample_bicubic2d.out
#https://github.com/pytorch/pytorch/issues/77764
#https://github.com/Fannovel16/comfyui_controlnet_aux/issues/2#issuecomment-1763579485
# Respect an existing user-provided setting; only default to '1' when unset.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = os.getenv("PYTORCH_ENABLE_MPS_FALLBACK", '1')
|
||||
|
||||
|
||||
def load_nodes():
    """Import every module under node_wrappers/ and aggregate its node mappings.

    Returns a ``(class_mappings, display_name_mappings)`` tuple merged across
    all wrapper modules that imported successfully.  Modules raising
    AttributeError (work-in-progress nodes without mappings yet) are skipped
    silently; any other failure is collected and reported at the end.
    """
    class_mappings = {}
    display_mappings = {}
    short_errors = []
    full_errors = []

    for entry in (here / "node_wrappers").iterdir():
        wrapper_name = entry.stem
        # Skip hidden files created by the OS (e.g. .DS_Store).
        if wrapper_name.startswith('.'):
            continue

        try:
            module = importlib.import_module(
                f".node_wrappers.{wrapper_name}", package=__package__
            )
            class_mappings.update(module.NODE_CLASS_MAPPINGS)
            if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS"):
                display_mappings.update(module.NODE_DISPLAY_NAME_MAPPINGS)
            log.debug(f"Imported {wrapper_name} nodes")
        except AttributeError:
            pass  # wip nodes
        except Exception:
            tb_text = traceback.format_exc()
            full_errors.append(tb_text)
            last_line = tb_text.splitlines()[-1]
            short_errors.append(
                f"Failed to import module {wrapper_name} because {last_line}"
            )

    if short_errors:
        full_err_log = '\n\n'.join(full_errors)
        print(f"\n\nFull error log from comfyui_controlnet_aux: \n{full_err_log}\n\n")
        log.info(
            f"Some nodes failed to load:\n\t"
            + "\n\t".join(short_errors)
            + "\n\n"
            + "Check that you properly installed the dependencies.\n"
            + "If you think this is a bug, please report it on the github page (https://github.com/Fannovel16/comfyui_controlnet_aux/issues)"
        )
    return class_mappings, display_mappings
|
||||
|
||||
AUX_NODE_MAPPINGS, AUX_DISPLAY_NAME_MAPPINGS = load_nodes()
|
||||
|
||||
#For nodes not mapping image to image or has special requirements
AIO_NOT_SUPPORTED = [
    "InpaintPreprocessor",
    "MeshGraphormer+ImpactDetector-DepthMapPreprocessor",
    "DiffusionEdge_Preprocessor",
    "SavePoseKpsAsJsonFile",
    "FacialPartColoringFromPoseKps",
    "UpperBodyTrackingFromPoseKps",
    "RenderPeopleKps",
    "RenderAnimalKps",
    "Unimatch_OptFlowPreprocessor",
    "MaskOptFlow",
]


def preprocessor_options():
    """Return the AIO dropdown choices: 'none' followed by every supported aux node."""
    excluded = set(AIO_NOT_SUPPORTED)
    return ["none"] + [name for name in AUX_NODE_MAPPINGS if name not in excluded]


PREPROCESSOR_OPTIONS = preprocessor_options()
|
||||
|
||||
class AIO_Preprocessor:
    """All-in-one node exposing every loadable aux preprocessor as a dropdown.

    Delegates to the selected preprocessor class, passing the image and
    resolution through and filling every other input with its declared (or a
    type-based) default.
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            preprocessor=INPUT.COMBO(PREPROCESSOR_OPTIONS, default="none"),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors"

    def execute(self, preprocessor, image, resolution=512):
        """Run `preprocessor` on `image`; return the image unchanged for "none".

        Defaults for the delegate's extra inputs come from, in order: the
        input spec's declared "default", then a zero value for INT/FLOAT.
        """
        if preprocessor == "none":
            return (image, )

        aux_class = AUX_NODE_MAPPINGS[preprocessor]
        input_types = aux_class.INPUT_TYPES()
        input_types = {
            **input_types["required"],
            **(input_types["optional"] if "optional" in input_types else {})
        }
        default_values = {"INT": 0, "FLOAT": 0.0}
        params = {}
        for name, input_type in input_types.items():
            if name == "image":
                params[name] = image
                continue

            if name == "resolution":
                params[name] = resolution
                continue

            if len(input_type) == 2 and ("default" in input_type[1]):
                params[name] = input_type[1]["default"]
                continue

            if type(input_type[0]) is list:
                # Multi-choice type spec: take the first entry we know a
                # default for.  (Bug fix: this previously indexed
                # default_values with the whole list — an unhashable key —
                # instead of the matched element.)
                for type_name in input_type[0]:
                    if type_name in default_values:
                        params[name] = default_values[type_name]
                        break
            else:
                if input_type[0] in default_values:
                    params[name] = default_values[input_type[0]]

        return getattr(aux_class(), aux_class.FUNCTION)(**params)
|
||||
|
||||
class ControlNetAuxSimpleAddText:
    """Draw a green text label at the top-left of the first image in a batch."""

    @classmethod
    def INPUT_TYPES(s):
        return dict(
            required=dict(image=INPUT.IMAGE(), text=INPUT.STRING())
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors"

    def execute(self, image, text):
        # Imported lazily so the node registry loads even without PIL/torch.
        from PIL import Image, ImageDraw, ImageFont
        import numpy as np
        import torch

        # Bundled font shipped with this extension, 40pt.
        font_path = str((here / "NotoSans-Regular.ttf").resolve())
        font = ImageFont.truetype(font_path, 40)

        # Tensor [0,1] float -> uint8 PIL image for drawing.
        frame = (image[0].cpu().numpy() * 255.).astype(np.uint8)
        pil_img = Image.fromarray(frame)
        ImageDraw.Draw(pil_img).text((0, 0), text, fill=(0, 255, 0), font=font)

        # Back to a batched float tensor in [0,1].
        labeled = torch.from_numpy(np.array(pil_img)).unsqueeze(0) / 255.
        return (labeled,)
|
||||
|
||||
class ExecuteAllControlNetPreprocessors:
    """Expand into a subgraph that runs every supported preprocessor.

    Each preprocessor's output is labeled with its name, then all results are
    pairwise-merged with ImageBatch nodes into a single batch.
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors"

    def execute(self, image, resolution=512):
        """Build and return the expanded graph (requires Execution Model Inversion).

        Raises:
            RuntimeError: if the running ComfyUI is too old to provide
                ``comfy_execution.graph_utils``.
        """
        try:
            from comfy_execution.graph_utils import GraphBuilder
        except ImportError as exc:
            # Fix: was a bare `except:` which also swallowed SystemExit and
            # KeyboardInterrupt; narrowed and chained for a useful traceback.
            raise RuntimeError("ExecuteAllControlNetPreprocessor requires [Execution Model Inversion](https://github.com/comfyanonymous/ComfyUI/commit/5cfe38). Update ComfyUI/SwarmUI to get this feature") from exc

        graph = GraphBuilder()
        curr_outputs = []
        for preprocc in PREPROCESSOR_OPTIONS:
            preprocc_node = graph.node("AIO_Preprocessor", preprocessor=preprocc, image=image, resolution=resolution)
            hint_img = preprocc_node.out(0)
            # Stamp the preprocessor's name onto its hint image.
            add_text_node = graph.node("ControlNetAuxSimpleAddText", image=hint_img, text=preprocc)
            curr_outputs.append(add_text_node.out(0))

        # Pairwise reduction: merge outputs two at a time until one remains,
        # keeping the ImageBatch tree balanced (depth O(log n)).
        while len(curr_outputs) > 1:
            _outputs = []
            for i in range(0, len(curr_outputs), 2):
                if i + 1 < len(curr_outputs):
                    image_batch = graph.node("ImageBatch", image1=curr_outputs[i], image2=curr_outputs[i + 1])
                    _outputs.append(image_batch.out(0))
                else:
                    # Odd element carries over to the next round unmerged.
                    _outputs.append(curr_outputs[i])
            curr_outputs = _outputs

        return {
            "result": (curr_outputs[0],),
            "expand": graph.finalize(),
        }
|
||||
|
||||
class ControlNetPreprocessorSelector:
    """Pass-through node exposing the preprocessor choice as a wirable output."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "preprocessor": (PREPROCESSOR_OPTIONS,),
            }
        }

    RETURN_TYPES = (PREPROCESSOR_OPTIONS,)
    RETURN_NAMES = ("preprocessor",)
    FUNCTION = "get_preprocessor"

    CATEGORY = "ControlNet Preprocessors"

    def get_preprocessor(self, preprocessor: str):
        """Return the selected preprocessor name unchanged."""
        return (preprocessor,)
|
||||
|
||||
|
||||
# Registry consumed by ComfyUI: the auto-loaded aux preprocessor nodes merged
# with the hint-image nodes and the utility nodes defined in this module.
NODE_CLASS_MAPPINGS = {
    **AUX_NODE_MAPPINGS,
    "AIO_Preprocessor": AIO_Preprocessor,
    "ControlNetPreprocessorSelector": ControlNetPreprocessorSelector,
    **HIE_NODE_CLASS_MAPPINGS,
    "ExecuteAllControlNetPreprocessors": ExecuteAllControlNetPreprocessors,
    "ControlNetAuxSimpleAddText": ControlNetAuxSimpleAddText
}

# Human-readable labels shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    **AUX_DISPLAY_NAME_MAPPINGS,
    "AIO_Preprocessor": "AIO Aux Preprocessor",
    "ControlNetPreprocessorSelector": "Preprocessor Selector",
    **HIE_NODE_DISPLAY_NAME_MAPPINGS,
    "ExecuteAllControlNetPreprocessors": "Execute All ControlNet Preprocessors"
}
|
||||
20
custom_nodes/comfyui_controlnet_aux/config.example.yaml
Normal file
@@ -0,0 +1,20 @@
|
||||
# this is an example for config.yaml file, you can rename it to config.yaml if you want to use it
|
||||
# ###############################################################################################
|
||||
# This path is for the custom preprocessor models base folder. default is "./ckpts"
|
||||
# you can also use absolute paths like: "/root/ComfyUI/custom_nodes/comfyui_controlnet_aux/ckpts" or "D:\\ComfyUI\\custom_nodes\\comfyui_controlnet_aux\\ckpts"
|
||||
annotator_ckpts_path: "./ckpts"
|
||||
# ###############################################################################################
|
||||
# This path is for downloading temporary files.
|
||||
# You SHOULD use an absolute path for this, like "D:\\temp". DO NOT use relative paths. Empty for default.
|
||||
custom_temp_path:
|
||||
# ###############################################################################################
|
||||
# if you already have downloaded ckpts via huggingface hub into default cache path like: ~/.cache/huggingface/hub, you can set this True to use symlinks to save space
|
||||
USE_SYMLINKS: False
|
||||
# ###############################################################################################
|
||||
# EP_list is a list of execution providers for onnxruntime, if one of them is not available or not working well, you can delete that provider from here(config.yaml)
|
||||
# you can find all available providers here: https://onnxruntime.ai/docs/execution-providers
|
||||
# for example, if you have CUDA installed, you can set it to: ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
# empty list or only keep ["CPUExecutionProvider"] means you use cv2.dnn.readNetFromONNX to load onnx models
|
||||
# if your onnx models can only run on the CPU or have other issues, we recommend using pt model instead.
|
||||
# default value is ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
|
||||
EP_list: ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
|
||||
6
custom_nodes/comfyui_controlnet_aux/dev_interface.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Development helper: exposes the vendored custom_controlnet_aux package at the
# top level so preprocessors can be imported directly when working on the repo.
from pathlib import Path
from utils import here
import sys
# Make the vendored src/ directory importable before the star import below.
sys.path.append(str(Path(here, "src")))

from custom_controlnet_aux import *
|
||||
BIN
custom_nodes/comfyui_controlnet_aux/examples/CNAuxBanner.jpg
Normal file
|
After Width: | Height: | Size: 576 KiB |
BIN
custom_nodes/comfyui_controlnet_aux/examples/ExecuteAll.png
Normal file
|
After Width: | Height: | Size: 9.5 MiB |
BIN
custom_nodes/comfyui_controlnet_aux/examples/ExecuteAll1.jpg
Normal file
|
After Width: | Height: | Size: 1.1 MiB |
BIN
custom_nodes/comfyui_controlnet_aux/examples/ExecuteAll2.jpg
Normal file
|
After Width: | Height: | Size: 998 KiB |
|
After Width: | Height: | Size: 694 KiB |
|
After Width: | Height: | Size: 706 KiB |
|
After Width: | Height: | Size: 472 KiB |
BIN
custom_nodes/comfyui_controlnet_aux/examples/example_anyline.png
Normal file
|
After Width: | Height: | Size: 371 KiB |
|
After Width: | Height: | Size: 271 KiB |
|
After Width: | Height: | Size: 438 KiB |
|
After Width: | Height: | Size: 180 KiB |
BIN
custom_nodes/comfyui_controlnet_aux/examples/example_dsine.png
Normal file
|
After Width: | Height: | Size: 636 KiB |
|
After Width: | Height: | Size: 646 KiB |
|
After Width: | Height: | Size: 294 KiB |
|
After Width: | Height: | Size: 5.2 MiB |
|
After Width: | Height: | Size: 244 KiB |
BIN
custom_nodes/comfyui_controlnet_aux/examples/example_onnx.png
Normal file
|
After Width: | Height: | Size: 93 KiB |
BIN
custom_nodes/comfyui_controlnet_aux/examples/example_recolor.png
Normal file
|
After Width: | Height: | Size: 713 KiB |
|
After Width: | Height: | Size: 400 KiB |
BIN
custom_nodes/comfyui_controlnet_aux/examples/example_teed.png
Normal file
|
After Width: | Height: | Size: 519 KiB |
|
After Width: | Height: | Size: 107 KiB |
|
After Width: | Height: | Size: 244 KiB |
233
custom_nodes/comfyui_controlnet_aux/hint_image_enchance.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from .log import log
|
||||
from .utils import ResizeMode, safe_numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
from .utils import get_unique_axis0
|
||||
from .lvminthin import nake_nms, lvmin_thin
|
||||
|
||||
MAX_IMAGEGEN_RESOLUTION = 8192 #https://github.com/comfyanonymous/ComfyUI/blob/c910b4a01ca58b04e5d4ab4c747680b996ada02b/nodes.py#L42
|
||||
RESIZE_MODES = [ResizeMode.RESIZE.value, ResizeMode.INNER_FIT.value, ResizeMode.OUTER_FIT.value]
|
||||
|
||||
#Port from https://github.com/Mikubill/sd-webui-controlnet/blob/e67e017731aad05796b9615dc6eadce911298ea1/internal_controlnet/external_code.py#L89
class PixelPerfectResolution:
    """Compute the detector resolution that maps 1:1 onto the generation size."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "original_image": ("IMAGE", ),
                "image_gen_width": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
                "image_gen_height": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
                #https://github.com/comfyanonymous/ComfyUI/blob/c910b4a01ca58b04e5d4ab4c747680b996ada02b/nodes.py#L854
                "resize_mode": (RESIZE_MODES, {"default": ResizeMode.RESIZE.value})
            }
        }

    RETURN_TYPES = ("INT",)
    RETURN_NAMES = ("RESOLUTION (INT)", )
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors"

    def execute(self, original_image, image_gen_width, image_gen_height, resize_mode):
        """Return the rounded detector resolution for the given target size/mode."""
        _, raw_H, raw_W, _ = original_image.shape

        scale_h = float(image_gen_height) / float(raw_H)
        scale_w = float(image_gen_width) / float(raw_W)

        # OUTER_FIT letterboxes (use the smaller scale factor); all other
        # modes cover/stretch and need the larger one.
        if resize_mode == ResizeMode.OUTER_FIT.value:
            chosen_scale = min(scale_h, scale_w)
        else:
            chosen_scale = max(scale_h, scale_w)
        estimation = chosen_scale * float(min(raw_H, raw_W))

        log.debug("Pixel Perfect Computation:")
        log.debug(f"resize_mode = {resize_mode}")
        log.debug(f"raw_H = {raw_H}")
        log.debug(f"raw_W = {raw_W}")
        log.debug(f"target_H = {image_gen_height}")
        log.debug(f"target_W = {image_gen_width}")
        log.debug(f"estimation = {estimation}")

        return (int(np.round(estimation)), )
|
||||
|
||||
class HintImageEnchance:
    """Resize/enhance hint (control) images to the target generation size.

    Three modes mirroring A1111/sd-webui-controlnet behavior:
    - RESIZE: stretch to the target size.
    - OUTER_FIT: letterbox — fit inside, padding with the median border color.
    - INNER_FIT: cover — scale to fill, then center-crop the overflow.
    """
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "hint_image": ("IMAGE", ),
                "image_gen_width": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
                "image_gen_height": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
                #https://github.com/comfyanonymous/ComfyUI/blob/c910b4a01ca58b04e5d4ab4c747680b996ada02b/nodes.py#L854
                "resize_mode": (RESIZE_MODES, {"default": ResizeMode.RESIZE.value})
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors"
    def execute(self, hint_image, image_gen_width, image_gen_height, resize_mode):
        """Process each image in the batch and restack as a float tensor in [0, 1]."""
        outs = []
        for single_hint_image in hint_image:
            # Work in uint8 [0, 255] for the OpenCV-based resize helpers.
            np_hint_image = np.asarray(single_hint_image * 255., dtype=np.uint8)

            if resize_mode == ResizeMode.RESIZE.value:
                np_hint_image = self.execute_resize(np_hint_image, image_gen_width, image_gen_height)
            elif resize_mode == ResizeMode.OUTER_FIT.value:
                np_hint_image = self.execute_outer_fit(np_hint_image, image_gen_width, image_gen_height)
            else:
                np_hint_image = self.execute_inner_fit(np_hint_image, image_gen_width, image_gen_height)

            outs.append(torch.from_numpy(np_hint_image.astype(np.float32) / 255.0))

        return (torch.stack(outs, dim=0),)

    def execute_resize(self, detected_map, w, h):
        # Plain stretch to (w, h).
        detected_map = self.high_quality_resize(detected_map, (w, h))
        detected_map = safe_numpy(detected_map)
        return detected_map

    def execute_outer_fit(self, detected_map, w, h):
        # Letterbox: scale so the whole map fits inside (w, h), pad the rest.
        old_h, old_w, _ = detected_map.shape
        old_w = float(old_w)
        old_h = float(old_h)
        k0 = float(h) / old_h
        k1 = float(w) / old_w
        safeint = lambda x: int(np.round(x))
        k = min(k0, k1)

        # Pad with the median color of the map's four borders so the padding
        # blends with the content instead of introducing a hard edge.
        borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0)
        high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype)
        if len(high_quality_border_color) == 4:
            # Inpaint hijack: force the alpha/mask channel of the padding opaque.
            high_quality_border_color[3] = 255
        high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
        detected_map = self.high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
        new_h, new_w, _ = detected_map.shape
        pad_h = max(0, (h - new_h) // 2)
        pad_w = max(0, (w - new_w) // 2)
        high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = detected_map
        detected_map = high_quality_background
        detected_map = safe_numpy(detected_map)
        return detected_map

    def execute_inner_fit(self, detected_map, w, h):
        # Cover: scale so the map covers (w, h), then center-crop the overflow.
        old_h, old_w, _ = detected_map.shape
        old_w = float(old_w)
        old_h = float(old_h)
        k0 = float(h) / old_h
        k1 = float(w) / old_w
        safeint = lambda x: int(np.round(x))
        k = max(k0, k1)

        detected_map = self.high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
        new_h, new_w, _ = detected_map.shape
        pad_h = max(0, (new_h - h) // 2)
        pad_w = max(0, (new_w - w) // 2)
        detected_map = detected_map[pad_h:pad_h+h, pad_w:pad_w+w]
        detected_map = safe_numpy(detected_map)
        return detected_map

    def high_quality_resize(self, x, size):
        # Written by lvmin
        # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
        # NOTE(review): order-sensitive heuristic pipeline — do not reorder steps.

        # Split off an alpha channel if present; it is treated as an inpaint mask.
        inpaint_mask = None
        if x.ndim == 3 and x.shape[2] == 4:
            inpaint_mask = x[:, :, 3]
            x = x[:, :, 0:3]

        if x.shape[0] != size[1] or x.shape[1] != size[0]:
            new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
            new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
            unique_color_count = len(get_unique_axis0(x.reshape(-1, x.shape[2])))
            is_one_pixel_edge = False
            is_binary = False
            if unique_color_count == 2:
                # Exactly two colors near 0 and 255 -> treat as a binary map.
                is_binary = np.min(x) < 16 and np.max(x) > 240
                if is_binary:
                    # Erode+dilate round trip: pixels lost by erosion and not
                    # restored indicate one-pixel-wide edges.
                    xc = x
                    xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
                    xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
                    one_pixel_edge_count = np.where(xc < x)[0].shape[0]
                    all_edge_count = np.where(x > 127)[0].shape[0]
                    is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count

            # Few distinct colors (e.g. segmentation maps) must not be blended.
            if 2 < unique_color_count < 200:
                interpolation = cv2.INTER_NEAREST
            elif new_size_is_smaller:
                interpolation = cv2.INTER_AREA
            else:
                interpolation = cv2.INTER_CUBIC  # Must be CUBIC because we now use nms. NEVER CHANGE THIS

            y = cv2.resize(x, size, interpolation=interpolation)
            if inpaint_mask is not None:
                inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)

            if is_binary:
                # Re-binarize after interpolation; thin one-pixel edges back down.
                y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
                if is_one_pixel_edge:
                    y = nake_nms(y)
                    _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                    y = lvmin_thin(y, prunings=new_size_is_bigger)
                else:
                    _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                y = np.stack([y] * 3, axis=2)
        else:
            # Already at the target size: pass through unchanged.
            y = x

        if inpaint_mask is not None:
            # Re-binarize the mask and reattach it as the fourth channel.
            inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
            inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
            y = np.concatenate([y, inpaint_mask], axis=2)

        return y
|
||||
|
||||
|
||||
class ImageGenResolutionFromLatent:
    """Report the pixel-space generation resolution implied by a latent.

    Latents are 1/8 of pixel resolution, so both dimensions are scaled by 8.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"latent": ("LATENT",)}}

    RETURN_TYPES = ("INT", "INT")
    RETURN_NAMES = ("IMAGE_GEN_WIDTH (INT)", "IMAGE_GEN_HEIGHT (INT)")
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors"

    def execute(self, latent):
        # Latent tensors are laid out as (batch, channels, height, width).
        _, _, latent_height, latent_width = latent["samples"].shape
        return (latent_width * 8, latent_height * 8)
|
||||
|
||||
class ImageGenResolutionFromImage:
    """Expose an image's own width/height as INT outputs.

    Images are (batch, height, width, channels); no scaling is applied.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("INT", "INT")
    RETURN_NAMES = ("IMAGE_GEN_WIDTH (INT)", "IMAGE_GEN_HEIGHT (INT)")
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors"

    def execute(self, image):
        _, img_height, img_width, _ = image.shape
        return (img_width, img_height)
|
||||
|
||||
# Registration table consumed by ComfyUI at import time: node ID -> class.
NODE_CLASS_MAPPINGS = {
    "PixelPerfectResolution": PixelPerfectResolution,
    "ImageGenResolutionFromImage": ImageGenResolutionFromImage,
    "ImageGenResolutionFromLatent": ImageGenResolutionFromLatent,
    "HintImageEnchance": HintImageEnchance
}
# Human-readable labels shown in the ComfyUI node picker.
# NOTE(review): "Enchance" is misspelled, but it is a registration key/label
# that saved workflows reference by name -- do not rename casually.
NODE_DISPLAY_NAME_MAPPINGS = {
    "PixelPerfectResolution": "Pixel Perfect Resolution",
    "ImageGenResolutionFromImage": "Generation Resolution From Image",
    "ImageGenResolutionFromLatent": "Generation Resolution From Latent",
    "HintImageEnchance": "Enchance And Resize Hint Images"
}
|
||||
20
custom_nodes/comfyui_controlnet_aux/install.bat
Normal file
@@ -0,0 +1,20 @@
|
||||
@echo off
REM Installer for the comfyui_controlnet_aux node pack (Windows).
REM Installs each line of requirements.txt individually, preferring the
REM embedded Python that ships with ComfyUI Portable when it is present.

set "requirements_txt=%~dp0\requirements.txt"
set "python_exec=..\..\..\python_embeded\python.exe"

echo Installing ComfyUI's ControlNet Auxiliary Preprocessors..

REM ComfyUI Portable keeps python_embeded three directories above this script;
REM -s skips the user site-packages so packages go into the embedded env.
if exist "%python_exec%" (
    echo Installing with ComfyUI Portable
    for /f "delims=" %%i in (%requirements_txt%) do (
        %python_exec% -s -m pip install "%%i"
    )
) else (
    echo Installing with system Python
    for /f "delims=" %%i in (%requirements_txt%) do (
        pip install "%%i"
    )
)

pause
|
||||
80
custom_nodes/comfyui_controlnet_aux/log.py
Normal file
@@ -0,0 +1,80 @@
|
||||
#Cre: https://github.com/melMass/comfy_mtb/blob/main/log.py
|
||||
import logging
|
||||
import re
|
||||
import os
|
||||
|
||||
base_log_level = logging.INFO
|
||||
|
||||
|
||||
# Custom object that discards the output
|
||||
class NullWriter:
    """File-like sink that silently discards everything written to it."""

    def write(self, text):
        # Intentionally a no-op; used to suppress unwanted output streams.
        return None
|
||||
|
||||
|
||||
class Formatter(logging.Formatter):
    """Colorized log formatter: picks an ANSI color per log level."""

    # ANSI SGR escape sequences ("38;20" = normal-intensity foreground color).
    grey = "\x1b[38;20m"
    cyan = "\x1b[36;20m"
    purple = "\x1b[35;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    # format = "%(asctime)s - [%(name)s] - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
    # NOTE: this string class attribute is shadowed by the format() method
    # below; it is only read while FORMATS is being built at class-body time.
    format = "[%(name)s] | %(levelname)s -> %(message)s"

    # Per-level format strings, each wrapped in its color escape + reset.
    FORMATS = {
        logging.DEBUG: purple + format + reset,
        logging.INFO: cyan + format + reset,
        logging.WARNING: yellow + format + reset,
        logging.ERROR: red + format + reset,
        logging.CRITICAL: bold_red + format + reset,
    }

    def format(self, record):
        # Delegate to a fresh stdlib Formatter built with the level-specific
        # colored format string.
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)
|
||||
|
||||
|
||||
def mklog(name, level=base_log_level):
    """Create (or reconfigure) the named logger with the colorized Formatter.

    Any handlers previously attached to the logger are removed first, so
    repeated calls do not stack duplicate handlers.

    :param name: logger name passed to ``logging.getLogger``.
    :param level: log level applied to both the logger and its handler.
    :return: the configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Iterate over a COPY: removeHandler() mutates logger.handlers, and
    # removing while iterating the live list skips every other handler,
    # leaving stale handlers attached.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)

    ch = logging.StreamHandler()
    ch.setLevel(level)
    ch.setFormatter(Formatter())
    logger.addHandler(ch)

    # Disable log propagation so ancestors do not emit duplicate records.
    logger.propagate = False

    return logger
|
||||
|
||||
|
||||
# - The main app logger
|
||||
log = mklog(__package__, base_log_level)
|
||||
|
||||
|
||||
def log_user(arg):
    """Print *arg* to stdout behind a blue "ComfyUI ControlNet AUX:" tag."""
    # Bug fix: the f-string prefix was missing, so the literal text "{arg}"
    # was printed instead of the value.
    print(f"\033[34mComfyUI ControlNet AUX:\033[0m {arg}")
|
||||
|
||||
|
||||
def get_summary(docstring):
    """Return the first paragraph of *docstring* (text before the first blank line)."""
    first_paragraph, _, _ = docstring.strip().partition("\n\n")
    return first_paragraph
|
||||
|
||||
|
||||
def blue_text(text):
    """Wrap *text* in ANSI escape codes for bright-blue terminal output."""
    return "\033[94m" + str(text) + "\033[0m"
|
||||
|
||||
|
||||
def cyan_text(text):
    """Wrap *text* in ANSI escape codes for bright-cyan terminal output."""
    return "\033[96m" + str(text) + "\033[0m"
|
||||
|
||||
|
||||
def get_label(label):
    """Split a camelCase identifier into space-separated words."""
    # Each match is either the lowercase run at the start of the string or
    # an uppercase letter followed by its lowercase run.
    word_pattern = re.compile(r"(?:^|[A-Z])[a-z]*")
    return " ".join(word_pattern.findall(label)).strip()
|
||||
87
custom_nodes/comfyui_controlnet_aux/lvminthin.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# High Quality Edge Thinning using Pure Python
|
||||
# Written by Lvmin Zhang
|
||||
# 2023 April
|
||||
# Stanford University
|
||||
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
|
||||
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
# 3x3 hit-or-miss kernels for edge thinning (1 = pixel must be set,
# -1 = pixel must be empty, 0 = don't care). Only two base orientations are
# written out; the full set is generated below by rotating each by
# 90/180/270 degrees.
lvmin_kernels_raw = [
    np.array([
        [-1, -1, -1],
        [0, 1, 0],
        [1, 1, 1]
    ], dtype=np.int32),
    np.array([
        [0, -1, -1],
        [1, 1, -1],
        [0, 1, 0]
    ], dtype=np.int32)
]

# All four 90-degree rotations of each raw thinning kernel.
lvmin_kernels = []
lvmin_kernels += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]

# Pruning kernels: match dangling single-pixel tips left after thinning.
lvmin_prunings_raw = [
    np.array([
        [-1, -1, -1],
        [-1, 1, -1],
        [0, 0, -1]
    ], dtype=np.int32),
    np.array([
        [-1, -1, -1],
        [-1, 1, -1],
        [-1, 0, 0]
    ], dtype=np.int32)
]

# All four 90-degree rotations of each raw pruning kernel.
lvmin_prunings = []
lvmin_prunings += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||
|
||||
|
||||
def remove_pattern(x, kernel):
    """Erase pixels of *x* matching *kernel* via hit-or-miss morphology.

    Mutates *x* in place. Returns the (mutated) image and a flag telling
    whether any pixel was removed.
    """
    hits = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel)
    hit_coords = np.where(hits > 127)
    x[hit_coords] = 0
    removed_any = hit_coords[0].shape[0] > 0
    return x, removed_any
|
||||
|
||||
|
||||
def thin_one_time(x, kernels):
    """Run one thinning pass, removing every kernel pattern once.

    Returns the thinned image and True when no kernel removed anything,
    i.e. thinning has converged.
    """
    result = x
    converged = True
    for kernel in kernels:
        result, changed = remove_pattern(result, kernel)
        if changed:
            converged = False
    return result, converged
|
||||
|
||||
|
||||
def lvmin_thin(x, prunings=True):
    """Thin binary edges in *x* using up to 32 passes of Lvmin's kernels.

    Stops early once a pass removes nothing. When *prunings* is truthy a
    final pass with the pruning kernels removes dangling tips.
    """
    thinned = x
    for _ in range(32):
        thinned, converged = thin_one_time(thinned, lvmin_kernels)
        if converged:
            break
    if prunings:
        thinned, _ = thin_one_time(thinned, lvmin_prunings)
    return thinned
|
||||
|
||||
|
||||
def nake_nms(x):
    """Directional non-maximum suppression for edge maps.

    A pixel survives when it equals the dilation of the image along at least
    one of four line orientations (horizontal, vertical, both diagonals).
    """
    orientation_kernels = (
        np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8),  # horizontal
        np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8),  # vertical
        np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8),  # diagonal "\"
        np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8),  # diagonal "/"
    )
    suppressed = np.zeros_like(x)
    for kernel in orientation_kernels:
        np.putmask(suppressed, cv2.dilate(x, kernel=kernel) == x, x)
    return suppressed
|
||||
@@ -0,0 +1,43 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
class AnimeFace_SemSegPreprocessor:
    """Semantic segmentation of anime faces; optionally strips the background
    (ABG) and returns the character mask alongside the segmented image."""

    @classmethod
    def INPUT_TYPES(s):
        #This preprocessor is only trained on 512x resolution
        #https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/predict.py#L25
        return define_preprocessor_inputs(
            remove_background_using_abg=INPUT.BOOLEAN(True),
            resolution=INPUT.RESOLUTION(default=512, min=512, max=512)
        )

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("IMAGE", "ABG_CHARACTER_MASK (MASK)")
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"

    def execute(self, image, remove_background_using_abg=True, resolution=512, **kwargs):
        # Imported lazily so the node list loads even when the detector's
        # dependencies are unavailable.
        from custom_controlnet_aux.anime_face_segment import AnimeFaceSegmentor

        model = AnimeFaceSegmentor.from_pretrained().to(model_management.get_torch_device())
        if remove_background_using_abg:
            # With background removal the annotator returns RGBA: RGB image
            # plus the character mask in the alpha channel.
            out_image_with_mask = common_annotator_call(model, image, resolution=resolution, remove_background=True)
            out_image = out_image_with_mask[..., :3]
            mask = out_image_with_mask[..., 3:]
            mask = rearrange(mask, "n h w c -> n c h w")
        else:
            out_image = common_annotator_call(model, image, resolution=resolution, remove_background=False)
            # No ABG mask available: return an all-ones mask of matching shape.
            N, H, W, C = out_image.shape
            mask = torch.ones(N, C, H, W)
        # Free the model eagerly; it is re-loaded per execution.
        del model
        return (out_image, mask)
|
||||
|
||||
# ComfyUI registration tables: node ID -> class, and node ID -> UI label.
NODE_CLASS_MAPPINGS = {
    "AnimeFace_SemSegPreprocessor": AnimeFace_SemSegPreprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "AnimeFace_SemSegPreprocessor": "Anime Face Segmentor"
}
|
||||
87
custom_nodes/comfyui_controlnet_aux/node_wrappers/anyline.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import comfy.model_management as model_management
|
||||
import comfy.utils
|
||||
|
||||
# Requires comfyui_controlnet_aux funcsions and classes
|
||||
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
|
||||
|
||||
def get_intensity_mask(image_array, lower_bound, upper_bound):
    """Keep first-channel intensities inside [lower_bound, upper_bound].

    Values outside the band are zeroed, and the single channel is replicated
    to 3 channels so the result can be blended with RGB layers.
    """
    channel = image_array[:, :, 0]
    in_band = (channel >= lower_bound) & (channel <= upper_bound)
    banded = np.where(in_band, channel, 0)
    return np.expand_dims(banded, 2).repeat(3, axis=2)
|
||||
|
||||
def combine_layers(base_layer, top_layer):
    """Screen-blend *top_layer* onto *base_layer* where the top is nonzero.

    Pixels where the top layer is zero keep the base value; elsewhere the
    screen formula 1 - (1 - a) * (1 - b) is applied.
    """
    top_nonzero = top_layer.astype(bool)
    screened = 1 - (1 - top_layer) * (1 - base_layer)
    return base_layer * (~top_nonzero) + screened * top_nonzero
|
||||
|
||||
class AnyLinePreprocessor:
    """AnyLine lineart: blends an MTEED edge map with a second lineart pass.

    The lineart layer is band-pass filtered by intensity, cleaned of small
    specks with skimage morphology, then screen-blended over the MTEED result.
    """

    @classmethod
    def INPUT_TYPES(s):
        # NOTE(review): "lineart_realisitic" is misspelled but is a functional
        # key (it must match the dict lookup in get_anyline) -- keep as-is.
        return define_preprocessor_inputs(
            merge_with_lineart=INPUT.COMBO(["lineart_standard", "lineart_realisitic", "lineart_anime", "manga_line"], default="lineart_standard"),
            resolution=INPUT.RESOLUTION(default=1280, step=8),
            lineart_lower_bound=INPUT.FLOAT(default=0),
            lineart_upper_bound=INPUT.FLOAT(default=1),
            object_min_size=INPUT.INT(default=36, min=1),
            object_connectivity=INPUT.INT(default=1, min=1)
        )

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)

    FUNCTION = "get_anyline"
    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    def __init__(self):
        self.device = model_management.get_torch_device()

    def get_anyline(self, image, merge_with_lineart="lineart_standard", resolution=512, lineart_lower_bound=0, lineart_upper_bound=1, object_min_size=36, object_connectivity=1):
        from custom_controlnet_aux.teed import TEDDetector
        from skimage import morphology
        pbar = comfy.utils.ProgressBar(3)

        # Process the image with MTEED model
        mteed_model = TEDDetector.from_pretrained("TheMistoAI/MistoLine", "MTEED.pth", subfolder="Anyline").to(self.device)
        mteed_result = common_annotator_call(mteed_model, image, resolution=resolution, show_pbar=False)
        mteed_result = mteed_result.numpy()
        del mteed_model
        pbar.update(1)

        # Process the image with the lineart standard preprocessor
        if merge_with_lineart == "lineart_standard":
            from custom_controlnet_aux.lineart_standard import LineartStandardDetector
            lineart_standard_detector = LineartStandardDetector()
            # NOTE(review): "guassian_sigma" is the detector's own (misspelled)
            # keyword -- presumably matches the upstream signature; verify.
            lineart_result = common_annotator_call(lineart_standard_detector, image, guassian_sigma=2, intensity_threshold=3, resolution=resolution, show_pbar=False).numpy()
            del lineart_standard_detector
        else:
            from custom_controlnet_aux.lineart import LineartDetector
            from custom_controlnet_aux.lineart_anime import LineartAnimeDetector
            from custom_controlnet_aux.manga_line import LineartMangaDetector
            # Map the combo choice to its detector class, then instantiate it.
            lineart_detector = dict(lineart_realisitic=LineartDetector, lineart_anime=LineartAnimeDetector, manga_line=LineartMangaDetector)[merge_with_lineart]
            lineart_detector = lineart_detector.from_pretrained().to(self.device)
            lineart_result = common_annotator_call(lineart_detector, image, resolution=resolution, show_pbar=False).numpy()
            del lineart_detector
        pbar.update(1)

        # Per-frame: band-pass the lineart, drop small specks, then blend.
        final_result = []
        for i in range(len(image)):
            _lineart_result = get_intensity_mask(lineart_result[i], lower_bound=lineart_lower_bound, upper_bound=lineart_upper_bound)
            _cleaned = morphology.remove_small_objects(_lineart_result.astype(bool), min_size=object_min_size, connectivity=object_connectivity)
            _lineart_result = _lineart_result * _cleaned
            _mteed_result = mteed_result[i]

            # Combine the results
            final_result.append(torch.from_numpy(combine_layers(_mteed_result, _lineart_result)))
        pbar.update(1)
        return (torch.stack(final_result),)
|
||||
|
||||
# ComfyUI registration tables: node ID -> class, and node ID -> UI label.
NODE_CLASS_MAPPINGS = {
    "AnyLineArtPreprocessor_aux": AnyLinePreprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "AnyLineArtPreprocessor_aux": "AnyLine Lineart"
}
|
||||
29
custom_nodes/comfyui_controlnet_aux/node_wrappers/binary.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class Binary_Preprocessor:
    """Node wrapper around the binary-threshold line extractor."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            bin_threshold=INPUT.INT(default=100, max=255),
            resolution=INPUT.RESOLUTION()
        )

    def execute(self, image, bin_threshold=100, resolution=512, **kwargs):
        # Lazy import keeps node registration cheap and dependency-tolerant.
        from custom_controlnet_aux.binary import BinaryDetector

        detector = BinaryDetector()
        result = common_annotator_call(detector, image, bin_threshold=bin_threshold, resolution=resolution)
        return (result,)
|
||||
|
||||
|
||||
|
||||
# ComfyUI registration tables: node ID -> class, and node ID -> UI label.
NODE_CLASS_MAPPINGS = {
    "BinaryPreprocessor": Binary_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "BinaryPreprocessor": "Binary Lines"
}
|
||||
30
custom_nodes/comfyui_controlnet_aux/node_wrappers/canny.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class Canny_Edge_Preprocessor:
    """Node wrapper around the classic Canny edge detector."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            low_threshold=INPUT.INT(default=100, max=255),
            high_threshold=INPUT.INT(default=200, max=255),
            resolution=INPUT.RESOLUTION()
        )

    def execute(self, image, low_threshold=100, high_threshold=200, resolution=512, **kwargs):
        # Lazy import keeps node registration cheap and dependency-tolerant.
        from custom_controlnet_aux.canny import CannyDetector

        detector = CannyDetector()
        result = common_annotator_call(detector, image, low_threshold=low_threshold, high_threshold=high_threshold, resolution=resolution)
        return (result,)
|
||||
|
||||
|
||||
|
||||
# ComfyUI registration tables: node ID -> class, and node ID -> UI label.
NODE_CLASS_MAPPINGS = {
    "CannyEdgePreprocessor": Canny_Edge_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "CannyEdgePreprocessor": "Canny Edge"
}
|
||||
26
custom_nodes/comfyui_controlnet_aux/node_wrappers/color.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class Color_Preprocessor:
    """Node wrapper around the color-palette (T2I-Adapter) preprocessor."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/T2IAdapter-only"

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())

    def execute(self, image, resolution=512, **kwargs):
        # Lazy import keeps node registration cheap and dependency-tolerant.
        from custom_controlnet_aux.color import ColorDetector

        result = common_annotator_call(ColorDetector(), image, resolution=resolution)
        return (result,)
|
||||
|
||||
|
||||
|
||||
# ComfyUI registration tables: node ID -> class, and node ID -> UI label.
# NOTE(review): "Pallete" is misspelled, but the label may already be known
# to users/search -- change deliberately if at all.
NODE_CLASS_MAPPINGS = {
    "ColorPreprocessor": Color_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "ColorPreprocessor": "Color Pallete"
}
|
||||
@@ -0,0 +1,31 @@
|
||||
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class DensePose_Preprocessor:
    """Node wrapper around the DensePose body-surface estimator."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            model=INPUT.COMBO(["densepose_r50_fpn_dl.torchscript", "densepose_r101_fpn_dl.torchscript"]),
            cmap=INPUT.COMBO(["Viridis (MagicAnimate)", "Parula (CivitAI)"]),
            resolution=INPUT.RESOLUTION()
        )

    def execute(self, image, model="densepose_r50_fpn_dl.torchscript", cmap="Viridis (MagicAnimate)", resolution=512):
        from custom_controlnet_aux.densepose import DenseposeDetector

        detector = DenseposeDetector.from_pretrained(filename=model).to(model_management.get_torch_device())
        # The UI label carries the colormap family; the annotator wants the
        # lowercase matplotlib-style name.
        colormap = "viridis" if "Viridis" in cmap else "parula"
        return (common_annotator_call(detector, image, cmap=colormap, resolution=resolution), )
|
||||
|
||||
|
||||
# ComfyUI registration tables: node ID -> class, and node ID -> UI label.
NODE_CLASS_MAPPINGS = {
    "DensePosePreprocessor": DensePose_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DensePosePreprocessor": "DensePose Estimator"
}
|
||||
@@ -0,0 +1,55 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class Depth_Anything_Preprocessor:
    """Node wrapper around the Depth Anything (v1) relative-depth estimator."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            ckpt_name=INPUT.COMBO(
                ["depth_anything_vitl14.pth", "depth_anything_vitb14.pth", "depth_anything_vits14.pth"]
            ),
            resolution=INPUT.RESOLUTION()
        )

    def execute(self, image, ckpt_name="depth_anything_vitl14.pth", resolution=512, **kwargs):
        # Lazy import keeps node registration cheap and dependency-tolerant.
        from custom_controlnet_aux.depth_anything import DepthAnythingDetector

        detector = DepthAnythingDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
        depth_map = common_annotator_call(detector, image, resolution=resolution)
        # Release the model eagerly; it is re-loaded per execution.
        del detector
        return (depth_map, )
|
||||
|
||||
class Zoe_Depth_Anything_Preprocessor:
    """Node wrapper around ZoeDepth-style metric Depth Anything checkpoints."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            environment=INPUT.COMBO(["indoor", "outdoor"]),
            resolution=INPUT.RESOLUTION()
        )

    def execute(self, image, environment="indoor", resolution=512, **kwargs):
        # Lazy import keeps node registration cheap and dependency-tolerant.
        from custom_controlnet_aux.zoe import ZoeDepthAnythingDetector

        # Indoor and outdoor scenes use separately trained metric checkpoints.
        if environment == "indoor":
            ckpt_name = "depth_anything_metric_depth_indoor.pt"
        else:
            ckpt_name = "depth_anything_metric_depth_outdoor.pt"
        detector = ZoeDepthAnythingDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
        depth_map = common_annotator_call(detector, image, resolution=resolution)
        del detector
        return (depth_map, )
|
||||
|
||||
# ComfyUI registration tables: node ID -> class, and node ID -> UI label.
NODE_CLASS_MAPPINGS = {
    "DepthAnythingPreprocessor": Depth_Anything_Preprocessor,
    "Zoe_DepthAnythingPreprocessor": Zoe_Depth_Anything_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DepthAnythingPreprocessor": "Depth Anything",
    "Zoe_DepthAnythingPreprocessor": "Zoe Depth Anything"
}
|
||||
@@ -0,0 +1,56 @@
|
||||
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class Depth_Anything_V2_Preprocessor:
    """Node wrapper around the Depth Anything V2 relative-depth estimator."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            ckpt_name=INPUT.COMBO(
                ["depth_anything_v2_vitg.pth", "depth_anything_v2_vitl.pth", "depth_anything_v2_vitb.pth", "depth_anything_v2_vits.pth"],
                default="depth_anything_v2_vitl.pth"
            ),
            resolution=INPUT.RESOLUTION()
        )

    def execute(self, image, ckpt_name="depth_anything_v2_vitl.pth", resolution=512, **kwargs):
        # Lazy import keeps node registration cheap and dependency-tolerant.
        from custom_controlnet_aux.depth_anything_v2 import DepthAnythingV2Detector

        detector = DepthAnythingV2Detector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
        # max_depth=1 keeps the output relative (normalized), not metric.
        depth_map = common_annotator_call(detector, image, resolution=resolution, max_depth=1)
        del detector
        return (depth_map, )
|
||||
|
||||
""" class Depth_Anything_Metric_V2_Preprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return create_node_input_types(
|
||||
environment=(["indoor", "outdoor"], {"default": "indoor"}),
|
||||
max_depth=("FLOAT", {"min": 0, "max": 100, "default": 20.0, "step": 0.01})
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
|
||||
|
||||
def execute(self, image, environment, resolution=512, max_depth=20.0, **kwargs):
|
||||
from custom_controlnet_aux.depth_anything_v2 import DepthAnythingV2Detector
|
||||
filename = dict(indoor="depth_anything_v2_metric_hypersim_vitl.pth", outdoor="depth_anything_v2_metric_vkitti_vitl.pth")[environment]
|
||||
model = DepthAnythingV2Detector.from_pretrained(filename=filename).to(model_management.get_torch_device())
|
||||
out = common_annotator_call(model, image, resolution=resolution, max_depth=max_depth)
|
||||
del model
|
||||
return (out, ) """
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DepthAnythingV2Preprocessor": Depth_Anything_V2_Preprocessor,
|
||||
#"Metric_DepthAnythingV2Preprocessor": Depth_Anything_Metric_V2_Preprocessor
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"DepthAnythingV2Preprocessor": "Depth Anything V2 - Relative",
|
||||
#"Metric_DepthAnythingV2Preprocessor": "Depth Anything V2 - Metric"
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, run_script
|
||||
import comfy.model_management as model_management
|
||||
import sys
|
||||
|
||||
def install_deps():
    """Install scikit-learn on first use if it is not already importable."""
    try:
        import sklearn  # noqa: F401 -- imported only to probe availability
    except ImportError:
        # Narrowed from a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit and unrelated import-time errors.
        run_script([sys.executable, '-s', '-m', 'pip', 'install', 'scikit-learn'])
|
||||
|
||||
class DiffusionEdge_Preprocessor:
    """Node wrapper around the DiffusionEdge diffusion-based edge detector."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    @classmethod
    def INPUT_TYPES(s):
        # NOTE(review): "natrual" is misspelled but must match the checkpoint
        # filename pattern used below -- keep as-is.
        return define_preprocessor_inputs(
            environment=INPUT.COMBO(["indoor", "urban", "natrual"]),
            patch_batch_size=INPUT.INT(default=4, min=1, max=16),
            resolution=INPUT.RESOLUTION()
        )

    def execute(self, image, environment="indoor", patch_batch_size=4, resolution=512, **kwargs):
        # Ensure scikit-learn is available before the detector imports it.
        install_deps()
        from custom_controlnet_aux.diffusion_edge import DiffusionEdgeDetector

        checkpoint = f"diffusion_edge_{environment}.pt"
        detector = DiffusionEdgeDetector.from_pretrained(filename=checkpoint).to(model_management.get_torch_device())
        edges = common_annotator_call(detector, image, resolution=resolution, patch_batch_size=patch_batch_size)
        del detector
        return (edges, )
|
||||
|
||||
# ComfyUI registration tables: node ID -> class, and node ID -> UI label.
NODE_CLASS_MAPPINGS = {
    "DiffusionEdge_Preprocessor": DiffusionEdge_Preprocessor,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DiffusionEdge_Preprocessor": "Diffusion Edge (batch size ↑ => speed ↑, VRAM ↑)",
}
|
||||
31
custom_nodes/comfyui_controlnet_aux/node_wrappers/dsine.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class DSINE_Normal_Map_Preprocessor:
    """Node wrapper around the DSINE surface-normal estimator."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    @classmethod
    def INPUT_TYPES(s):
        # NOTE(review): fov max of 365.0 looks like a typo for 360.0, but it
        # is an existing UI bound -- confirm before tightening.
        return define_preprocessor_inputs(
            fov=INPUT.FLOAT(max=365.0, default=60.0),
            iterations=INPUT.INT(min=1, max=20, default=5),
            resolution=INPUT.RESOLUTION()
        )

    def execute(self, image, fov=60.0, iterations=5, resolution=512, **kwargs):
        # Lazy import keeps node registration cheap and dependency-tolerant.
        from custom_controlnet_aux.dsine import DsineDetector

        detector = DsineDetector.from_pretrained().to(model_management.get_torch_device())
        normal_map = common_annotator_call(detector, image, fov=fov, iterations=iterations, resolution=resolution)
        del detector
        return (normal_map,)
|
||||
|
||||
# ComfyUI registration tables: node ID -> class, and node ID -> UI label.
NODE_CLASS_MAPPINGS = {
    "DSINE-NormalMapPreprocessor": DSINE_Normal_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DSINE-NormalMapPreprocessor": "DSINE Normal Map"
}
|
||||
166
custom_nodes/comfyui_controlnet_aux/node_wrappers/dwpose.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
import numpy as np
|
||||
import warnings
|
||||
from ..src.custom_controlnet_aux.dwpose import DwposeDetector, AnimalposeDetector
|
||||
import os
|
||||
import json
|
||||
|
||||
# Default HuggingFace repo hosting the DWPose detector/estimator weights.
DWPOSE_MODEL_NAME = "yzd-v/DWPose"
#Trigger startup caching for onnxruntime
# Execution providers that count as "accelerated" for onnxruntime.
GPU_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"]
|
||||
def check_ort_gpu():
    """Return True when onnxruntime is importable and exposes at least one
    accelerated execution provider from GPU_PROVIDERS."""
    try:
        import onnxruntime as ort
        available = ort.get_available_providers()
    except Exception:
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit). Missing or broken onnxruntime
        # simply means "no acceleration" -- the caller falls back to CPU.
        return False
    return any(provider in available for provider in GPU_PROVIDERS)
|
||||
|
||||
# One-time (per process tree) startup probe: warn if onnxruntime cannot
# accelerate DWPose. The env var both records that the check ran and lets
# child processes skip it.
if not os.environ.get("DWPOSE_ONNXRT_CHECKED"):
    if check_ort_gpu():
        print("DWPose: Onnxruntime with acceleration providers detected")
    else:
        warnings.warn("DWPose: Onnxruntime not found or doesn't come with acceleration providers, switch to OpenCV with CPU device. DWPose might run very slowly")
        # Empty provider list signals the aux package to use its CPU path.
        os.environ['AUX_ORT_PROVIDERS'] = ''
    os.environ["DWPOSE_ONNXRT_CHECKED"] = '1'
|
||||
|
||||
class DWPose_Preprocessor:
    """Whole-body (body/hand/face) pose estimation via DWPose.

    Returns the rendered pose image plus the raw OpenPose-format keypoint
    dicts (also surfaced in the UI as JSON).
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            detect_hand=INPUT.COMBO(["enable", "disable"]),
            detect_body=INPUT.COMBO(["enable", "disable"]),
            detect_face=INPUT.COMBO(["enable", "disable"]),
            resolution=INPUT.RESOLUTION(),
            bbox_detector=INPUT.COMBO(
                ["None"] + ["yolox_l.torchscript.pt", "yolox_l.onnx", "yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"],
                default="yolox_l.onnx"
            ),
            pose_estimator=INPUT.COMBO(
                ["dw-ll_ucoco_384_bs5.torchscript.pt", "dw-ll_ucoco_384.onnx", "dw-ll_ucoco.onnx"],
                default="dw-ll_ucoco_384_bs5.torchscript.pt"
            ),
            scale_stick_for_xinsr_cn=INPUT.COMBO(["disable", "enable"])
        )

    RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
    FUNCTION = "estimate_pose"

    CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"

    def estimate_pose(self, image, detect_hand="enable", detect_body="enable", detect_face="enable", resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="dw-ll_ucoco_384.onnx", scale_stick_for_xinsr_cn="disable", **kwargs):
        # Map the chosen checkpoint filename to the HuggingFace repo it
        # lives in; unknown filenames are a hard error.
        if bbox_detector == "None":
            yolo_repo = DWPOSE_MODEL_NAME
        elif bbox_detector == "yolox_l.onnx":
            yolo_repo = DWPOSE_MODEL_NAME
        elif "yolox" in bbox_detector:
            yolo_repo = "hr16/yolox-onnx"
        elif "yolo_nas" in bbox_detector:
            yolo_repo = "hr16/yolo-nas-fp16"
        else:
            raise NotImplementedError(f"Download mechanism for {bbox_detector}")

        if pose_estimator == "dw-ll_ucoco_384.onnx":
            pose_repo = DWPOSE_MODEL_NAME
        elif pose_estimator.endswith(".onnx"):
            pose_repo = "hr16/UnJIT-DWPose"
        elif pose_estimator.endswith(".torchscript.pt"):
            pose_repo = "hr16/DWPose-TorchScript-BatchSize5"
        else:
            raise NotImplementedError(f"Download mechanism for {pose_estimator}")

        model = DwposeDetector.from_pretrained(
            pose_repo,
            yolo_repo,
            det_filename=(None if bbox_detector == "None" else bbox_detector), pose_filename=pose_estimator,
            torchscript_device=model_management.get_torch_device()
        )
        # Convert the enable/disable combo strings to booleans.
        detect_hand = detect_hand == "enable"
        detect_body = detect_body == "enable"
        detect_face = detect_face == "enable"
        scale_stick_for_xinsr_cn = scale_stick_for_xinsr_cn == "enable"
        # Accumulates one keypoint dict per processed frame via the closure.
        self.openpose_dicts = []
        def func(image, **kwargs):
            pose_img, openpose_dict = model(image, **kwargs)
            self.openpose_dicts.append(openpose_dict)
            return pose_img

        out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, resolution=resolution, xinsr_stick_scaling=scale_stick_for_xinsr_cn)
        del model
        # The 'ui' entry shows the keypoint JSON in the ComfyUI node output.
        return {
            'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
            "result": (out, self.openpose_dicts)
        }
|
||||
|
||||
class AnimalPose_Preprocessor:
    """Animal pose estimation (AP10K) via RTMPose.

    Returns the rendered pose image plus raw keypoint dicts, mirroring
    DWPose_Preprocessor's output contract.
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            bbox_detector = INPUT.COMBO(
                ["None"] + ["yolox_l.torchscript.pt", "yolox_l.onnx", "yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"],
                default="yolox_l.torchscript.pt"
            ),
            pose_estimator = INPUT.COMBO(
                ["rtmpose-m_ap10k_256_bs5.torchscript.pt", "rtmpose-m_ap10k_256.onnx"],
                default="rtmpose-m_ap10k_256_bs5.torchscript.pt"
            ),
            resolution = INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
    FUNCTION = "estimate_pose"

    CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"

    def estimate_pose(self, image, resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="rtmpose-m_ap10k_256.onnx", **kwargs):
        # Map the chosen checkpoint filename to the HuggingFace repo hosting
        # it (same selection logic as DWPose_Preprocessor).
        if bbox_detector == "None":
            yolo_repo = DWPOSE_MODEL_NAME
        elif bbox_detector == "yolox_l.onnx":
            yolo_repo = DWPOSE_MODEL_NAME
        elif "yolox" in bbox_detector:
            yolo_repo = "hr16/yolox-onnx"
        elif "yolo_nas" in bbox_detector:
            yolo_repo = "hr16/yolo-nas-fp16"
        else:
            raise NotImplementedError(f"Download mechanism for {bbox_detector}")

        if pose_estimator == "dw-ll_ucoco_384.onnx":
            pose_repo = DWPOSE_MODEL_NAME
        elif pose_estimator.endswith(".onnx"):
            pose_repo = "hr16/UnJIT-DWPose"
        elif pose_estimator.endswith(".torchscript.pt"):
            pose_repo = "hr16/DWPose-TorchScript-BatchSize5"
        else:
            raise NotImplementedError(f"Download mechanism for {pose_estimator}")

        model = AnimalposeDetector.from_pretrained(
            pose_repo,
            yolo_repo,
            det_filename=(None if bbox_detector == "None" else bbox_detector), pose_filename=pose_estimator,
            torchscript_device=model_management.get_torch_device()
        )

        # Accumulates one keypoint dict per processed frame via the closure.
        self.openpose_dicts = []
        def func(image, **kwargs):
            pose_img, openpose_dict = model(image, **kwargs)
            self.openpose_dicts.append(openpose_dict)
            return pose_img

        out = common_annotator_call(func, image, image_and_json=True, resolution=resolution)
        del model
        # The 'ui' entry shows the keypoint JSON in the ComfyUI node output.
        return {
            'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
            "result": (out, self.openpose_dicts)
        }
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DWPreprocessor": DWPose_Preprocessor,
|
||||
"AnimalPosePreprocessor": AnimalPose_Preprocessor
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"DWPreprocessor": "DWPose Estimator",
|
||||
"AnimalPosePreprocessor": "AnimalPose Estimator (AP10K)"
|
||||
}
|
||||
53
custom_nodes/comfyui_controlnet_aux/node_wrappers/hed.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class HED_Preprocessor:
    """HED (Holistically-nested Edge Detection) soft-edge line extractor.

    Returns the edge map rendered as an IMAGE batch.
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            safe=INPUT.COMBO(["enable", "disable"]),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    def execute(self, image, resolution=512, **kwargs):
        from custom_controlnet_aux.hed import HEDdetector

        model = HEDdetector.from_pretrained().to(model_management.get_torch_device())
        # .get() with "enable" (the combo's first/default entry) avoids a KeyError
        # when the optional "safe" input is omitted by API-driven workflows.
        out = common_annotator_call(model, image, resolution=resolution, safe=kwargs.get("safe", "enable") == "enable")
        del model
        return (out, )


class Fake_Scribble_Preprocessor:
    """Scribble-style lines derived from HED output (aka scribble_hed)."""

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            safe=INPUT.COMBO(["enable", "disable"]),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    def execute(self, image, resolution=512, **kwargs):
        from custom_controlnet_aux.hed import HEDdetector

        model = HEDdetector.from_pretrained().to(model_management.get_torch_device())
        # scribble=True post-processes the HED edges into scribble-like strokes.
        # .get() guards the optional "safe" input the same way as HED_Preprocessor.
        out = common_annotator_call(model, image, resolution=resolution, scribble=True, safe=kwargs.get("safe", "enable") == "enable")
        del model
        return (out, )


NODE_CLASS_MAPPINGS = {
    "HEDPreprocessor": HED_Preprocessor,
    "FakeScribblePreprocessor": Fake_Scribble_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "HEDPreprocessor": "HED Soft-Edge Lines",
    "FakeScribblePreprocessor": "Fake Scribble Lines (aka scribble_hed)"
}
|
||||
32
custom_nodes/comfyui_controlnet_aux/node_wrappers/inpaint.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from ..utils import INPUT
|
||||
|
||||
class InpaintPreprocessor:
    """Prepare an image for inpaint ControlNets by painting masked pixels.

    Masked pixels are replaced with -1.0 (the value classic inpaint ControlNets
    expect) or with 0.0 when black_pixel_for_xinsir_cn is set.
    """

    @classmethod
    def INPUT_TYPES(s):
        return dict(
            required=dict(image=INPUT.IMAGE(), mask=INPUT.MASK()),
            optional=dict(black_pixel_for_xinsir_cn=INPUT.BOOLEAN(False))
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "preprocess"

    CATEGORY = "ControlNet Preprocessors/others"

    def preprocess(self, image, mask, black_pixel_for_xinsir_cn=False):
        """Return a copy of *image* with pixels under *mask* set to a sentinel value."""
        # Resize the mask to the image's spatial dims (NCHW layout for interpolate).
        batched = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
        resized = torch.nn.functional.interpolate(
            batched, size=(image.shape[1], image.shape[2]), mode="bilinear"
        )
        # Back to NHWC and broadcast over the three colour channels.
        rgb_mask = resized.movedim(1, -1).expand((-1, -1, -1, 3))
        fill_value = 0.0 if black_pixel_for_xinsir_cn else -1.0
        result = image.clone()
        result[rgb_mask > 0.5] = fill_value
        return (result,)


NODE_CLASS_MAPPINGS = {
    "InpaintPreprocessor": InpaintPreprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "InpaintPreprocessor": "Inpaint Preprocessor"
}
|
||||
32
custom_nodes/comfyui_controlnet_aux/node_wrappers/leres.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class LERES_Depth_Map_Preprocessor:
    """LeReS monocular depth estimation.

    rm_nearest / rm_background are forwarded to the detector as thr_a / thr_b
    (depth clipping thresholds — confirm exact semantics in
    custom_controlnet_aux.leres); "boost" enables the slower LeReS++ path.
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            rm_nearest=INPUT.FLOAT(max=100.0),
            rm_background=INPUT.FLOAT(max=100.0),
            boost=INPUT.COMBO(["disable", "enable"]),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, rm_nearest=0, rm_background=0, resolution=512, boost="disable", **kwargs):
        from custom_controlnet_aux.leres import LeresDetector

        model = LeresDetector.from_pretrained().to(model_management.get_torch_device())
        # thr_a/thr_b are the detector's parameter names for the two thresholds.
        out = common_annotator_call(model, image, resolution=resolution, thr_a=rm_nearest, thr_b=rm_background, boost=boost == "enable")
        del model
        return (out, )


NODE_CLASS_MAPPINGS = {
    "LeReS-DepthMapPreprocessor": LERES_Depth_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LeReS-DepthMapPreprocessor": "LeReS Depth Map (enable boost for leres++)"
}
|
||||
30
custom_nodes/comfyui_controlnet_aux/node_wrappers/lineart.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class LineArt_Preprocessor:
    """Realistic line-art extraction (LineartDetector).

    "coarse" is forwarded to the detector (presumably selects its coarser line
    model — confirm in custom_controlnet_aux.lineart).
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            coarse=INPUT.COMBO((["disable", "enable"])),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    def execute(self, image, resolution=512, **kwargs):
        from custom_controlnet_aux.lineart import LineartDetector

        model = LineartDetector.from_pretrained().to(model_management.get_torch_device())
        # .get() with "disable" (the combo's first/default entry) avoids a KeyError
        # when the optional "coarse" input is omitted by API-driven workflows.
        out = common_annotator_call(model, image, resolution=resolution, coarse=kwargs.get("coarse", "disable") == "enable")
        del model
        return (out, )


NODE_CLASS_MAPPINGS = {
    "LineArtPreprocessor": LineArt_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LineArtPreprocessor": "Realistic Lineart"
}
|
||||
@@ -0,0 +1,27 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class AnimeLineArt_Preprocessor:
    """Anime-style line-art extraction (LineartAnimeDetector)."""

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    def execute(self, image, resolution=512, **kwargs):
        from custom_controlnet_aux.lineart_anime import LineartAnimeDetector

        detector = LineartAnimeDetector.from_pretrained().to(model_management.get_torch_device())
        result = common_annotator_call(detector, image, resolution=resolution)
        # Drop the reference promptly so the weights can be garbage-collected.
        del detector
        return (result, )


NODE_CLASS_MAPPINGS = {
    "AnimeLineArtPreprocessor": AnimeLineArt_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "AnimeLineArtPreprocessor": "Anime Lineart"
}
|
||||
@@ -0,0 +1,27 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class Lineart_Standard_Preprocessor:
    """Training-free "standard" lineart (no pretrained weights are loaded).

    NOTE(review): "guassian_sigma" is a typo, but it is part of the node's public
    input name — renaming it would break existing workflows, so it is kept.
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            guassian_sigma=INPUT.FLOAT(default=6.0, max=100.0),
            intensity_threshold=INPUT.INT(default=8, max=16),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    def execute(self, image, guassian_sigma=6, intensity_threshold=8, resolution=512, **kwargs):
        from custom_controlnet_aux.lineart_standard import LineartStandardDetector

        # The detector is cheap to construct (no checkpoint), so it is built inline.
        return (common_annotator_call(LineartStandardDetector(), image, guassian_sigma=guassian_sigma, intensity_threshold=intensity_threshold, resolution=resolution), )


NODE_CLASS_MAPPINGS = {
    "LineartStandardPreprocessor": Lineart_Standard_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LineartStandardPreprocessor": "Standard Lineart"
}
|
||||
@@ -0,0 +1,27 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class Manga2Anime_LineArt_Preprocessor:
    """Manga line extraction (aka lineart_anime_denoise)."""

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    def execute(self, image, resolution=512, **kwargs):
        from custom_controlnet_aux.manga_line import LineartMangaDetector

        detector = LineartMangaDetector.from_pretrained().to(model_management.get_torch_device())
        result = common_annotator_call(detector, image, resolution=resolution)
        # Drop the reference promptly so the weights can be garbage-collected.
        del detector
        return (result, )


NODE_CLASS_MAPPINGS = {
    "Manga2Anime_LineArt_Preprocessor": Manga2Anime_LineArt_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "Manga2Anime_LineArt_Preprocessor": "Manga Lineart (aka lineart_anime_denoise)"
}
|
||||
@@ -0,0 +1,39 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, run_script
|
||||
import comfy.model_management as model_management
|
||||
import os, sys
|
||||
import subprocess, threading
|
||||
|
||||
def install_deps():
    """Install mediapipe (plus a compatible protobuf) on first use."""
    try:
        import mediapipe
    except ImportError:
        run_script([sys.executable, '-s', '-m', 'pip', 'install', 'mediapipe'])
        run_script([sys.executable, '-s', '-m', 'pip', 'install', '--upgrade', 'protobuf'])


class Media_Pipe_Face_Mesh_Preprocessor:
    """Face mesh annotation via Google MediaPipe."""

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            max_faces=INPUT.INT(default=10, min=1, max=50), #Which image has more than 50 detectable faces?
            min_confidence=INPUT.FLOAT(default=0.5, min=0.1),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "detect"

    CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"

    def detect(self, image, max_faces=10, min_confidence=0.5, resolution=512):
        #Ref: https://github.com/Fannovel16/comfy_controlnet_preprocessors/issues/70#issuecomment-1677967369
        # mediapipe is installed lazily, on the first execution of this node.
        install_deps()
        from custom_controlnet_aux.mediapipe_face import MediapipeFaceDetector
        return (common_annotator_call(MediapipeFaceDetector(), image, max_faces=max_faces, min_confidence=min_confidence, resolution=resolution), )


NODE_CLASS_MAPPINGS = {
    "MediaPipe-FaceMeshPreprocessor": Media_Pipe_Face_Mesh_Preprocessor
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "MediaPipe-FaceMeshPreprocessor": "MediaPipe Face Mesh"
}
|
||||
@@ -0,0 +1,158 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, MAX_RESOLUTION, run_script
|
||||
import comfy.model_management as model_management
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
import os, sys
|
||||
import subprocess, threading
|
||||
import scipy.ndimage
|
||||
import cv2
|
||||
import torch.nn.functional as F
|
||||
|
||||
def install_deps():
    """Install mediapipe and trimesh on demand (first execution of the node)."""
    try:
        import mediapipe
    except ImportError:
        run_script([sys.executable, '-s', '-m', 'pip', 'install', 'mediapipe'])
        run_script([sys.executable, '-s', '-m', 'pip', 'install', '--upgrade', 'protobuf'])

    try:
        import trimesh
    except ImportError:
        run_script([sys.executable, '-s', '-m', 'pip', 'install', 'trimesh[easy]'])
|
||||
|
||||
#Sauce: https://github.com/comfyanonymous/ComfyUI/blob/8c6493578b3dda233e9b9a953feeaf1e6ca434ad/comfy_extras/nodes_mask.py#L309
|
||||
def expand_mask(mask, expand, tapered_corners):
    """Grow (expand > 0) or shrink (expand < 0) a mask by |expand| morphology steps.

    Adapted from ComfyUI's GrowMask node.  With tapered_corners the structuring
    element is a 3x3 cross; otherwise a full 3x3 square.  Accepts any tensor whose
    last two dims are (H, W) and returns a (N, H, W) float tensor.
    """
    corner = 0 if tapered_corners else 1
    footprint = np.array([[corner, 1, corner],
                          [1, 1, 1],
                          [corner, 1, corner]])
    flat = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
    # Pick the morphological operation once; the sign of `expand` never changes.
    morph = scipy.ndimage.grey_erosion if expand < 0 else scipy.ndimage.grey_dilation
    grown = []
    for layer in flat:
        arr = layer.numpy()
        for _ in range(abs(expand)):
            arr = morph(arr, footprint=footprint)
        grown.append(torch.from_numpy(arr))
    return torch.stack(grown, dim=0)
|
||||
|
||||
class Mesh_Graphormer_Depth_Map_Preprocessor:
    """MeshGraphormer hand refiner.

    Produces a depth map of detected hands (IMAGE) plus an inpainting MASK.
    mask_type selects how the mask is built:
      - "based_on_depth": every non-zero depth pixel
      - "tight_bboxes":   hand bounding boxes reported by the detector
      - "original":       the detector's own mask, untouched
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            mask_bbox_padding=("INT", {"default": 30, "min": 0, "max": 100}),
            resolution=INPUT.RESOLUTION(),
            mask_type=INPUT.COMBO(["based_on_depth", "tight_bboxes", "original"]),
            mask_expand=INPUT.INT(default=5, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
            rand_seed=INPUT.INT(default=88, min=0, max=0xffffffffffffffff),
            detect_thr=INPUT.FLOAT(default=0.6, min=0.1),
            presence_thr=INPUT.FLOAT(default=0.6, min=0.1)
        )

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("IMAGE", "INPAINTING_MASK")
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, mask_bbox_padding=30, mask_type="based_on_depth", mask_expand=5, resolution=512, rand_seed=88, detect_thr=0.6, presence_thr=0.6, **kwargs):
        install_deps()
        from custom_controlnet_aux.mesh_graphormer import MeshGraphormerDetector
        # A caller (e.g. the ImpactDetector wrapper in this file) may pass a
        # pre-built model via kwargs to avoid reloading weights for every crop.
        model = kwargs["model"] if "model" in kwargs \
            else MeshGraphormerDetector.from_pretrained(detect_thr=detect_thr, presence_thr=presence_thr).to(model_management.get_torch_device())

        depth_map_list = []
        mask_list = []
        for single_image in image:
            # IMAGE tensors are float 0..1; the detector consumes uint8 HWC arrays.
            np_image = np.asarray(single_image.cpu() * 255., dtype=np.uint8)
            depth_map, mask, info = model(np_image, output_type="np", detect_resolution=resolution, mask_bbox_padding=mask_bbox_padding, seed=rand_seed)
            if mask_type == "based_on_depth":
                H, W = mask.shape[:2]
                # Any rendered depth pixel counts as "hand" for inpainting purposes.
                mask = cv2.resize(depth_map.copy(), (W, H))
                mask[mask > 0] = 255

            elif mask_type == "tight_bboxes":
                mask = np.zeros_like(mask)
                # (info or {}) guards against the detector returning None.
                hand_bboxes = (info or {}).get("abs_boxes") or []
                for hand_bbox in hand_bboxes:
                    x_min, x_max, y_min, y_max = hand_bbox
                    mask[y_min:y_max+1, x_min:x_max+1, :] = 255 #HWC

            mask = mask[:, :, :1]  # keep a single channel
            depth_map_list.append(torch.from_numpy(depth_map.astype(np.float32) / 255.0))
            mask_list.append(torch.from_numpy(mask.astype(np.float32) / 255.0))
        depth_maps, masks = torch.stack(depth_map_list, dim=0), rearrange(torch.stack(mask_list, dim=0), "n h w 1 -> n 1 h w")
        # mask_expand grows (or, if negative, shrinks) the mask before returning.
        return depth_maps, expand_mask(masks, mask_expand, tapered_corners=True)
|
||||
|
||||
def normalize_size_base_64(w, h):
    """Round the shorter of (w, h) up to the nearest multiple of 64."""
    short_side = min(w, h)
    leftover = short_side % 64
    if leftover == 0:
        return short_side
    return short_side + (64 - leftover)
|
||||
|
||||
class Mesh_Graphormer_With_ImpactDetector_Depth_Map_Preprocessor:
    """MeshGraphormer hand refiner driven by an external Impact-Pack BBOX_DETECTOR.

    Each frame is scanned for hands by the supplied detector; MeshGraphormer is
    then run only on the cropped regions and the results are composited back
    into full-frame depth/mask canvases.
    """

    @classmethod
    def INPUT_TYPES(s):
        types = define_preprocessor_inputs(
            # Impact pack
            bbox_threshold=INPUT.FLOAT(default=0.5, min=0.1),
            bbox_dilation=INPUT.INT(default=10, min=-512, max=512),
            bbox_crop_factor=INPUT.FLOAT(default=3.0, min=1.0, max=10.0),
            drop_size=INPUT.INT(default=10, min=1, max=MAX_RESOLUTION),
            # Mesh Graphormer
            mask_bbox_padding=INPUT.INT(default=30, min=0, max=100),
            mask_type=INPUT.COMBO(["based_on_depth", "tight_bboxes", "original"]),
            mask_expand=INPUT.INT(default=5, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
            rand_seed=INPUT.INT(default=88, min=0, max=0xffffffffffffffff),
            resolution=INPUT.RESOLUTION()
        )
        # The detector is a required link-type input, not a widget.
        types["required"]["bbox_detector"] = ("BBOX_DETECTOR", )
        return types

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("IMAGE", "INPAINTING_MASK")
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, bbox_detector, bbox_threshold=0.5, bbox_dilation=10, bbox_crop_factor=3.0, drop_size=10, resolution=512, **mesh_graphormer_kwargs):
        install_deps()
        from custom_controlnet_aux.mesh_graphormer import MeshGraphormerDetector
        mesh_graphormer_node = Mesh_Graphormer_Depth_Map_Preprocessor()
        model = MeshGraphormerDetector.from_pretrained(detect_thr=0.6, presence_thr=0.6).to(model_management.get_torch_device())
        # Share one loaded model across every crop and frame.
        mesh_graphormer_kwargs["model"] = model

        frames = image
        depth_maps, masks = [], []
        for idx in range(len(frames)):
            frame = frames[idx:idx+1,...] #Impact Pack's BBOX_DETECTOR only supports single batch image
            bbox_detector.setAux('face') # make default prompt as 'face' if empty prompt for CLIPSeg
            _, segs = bbox_detector.detect(frame, bbox_threshold, bbox_dilation, bbox_crop_factor, drop_size)
            bbox_detector.setAux(None)

            n, h, w, _ = frame.shape
            # Full-frame canvases the per-crop results are written into.
            depth_map, mask = torch.zeros_like(frame), torch.zeros(n, 1, h, w)
            for i, seg in enumerate(segs):
                x1, y1, x2, y2 = seg.crop_region
                cropped_image = frame[:, y1:y2, x1:x2, :] # Never use seg.cropped_image to handle overlapping area
                mesh_graphormer_kwargs["resolution"] = 0 #Disable resizing
                sub_depth_map, sub_mask = mesh_graphormer_node.execute(cropped_image, **mesh_graphormer_kwargs)
                depth_map[:, y1:y2, x1:x2, :] = sub_depth_map
                mask[:, :, y1:y2, x1:x2] = sub_mask

            depth_maps.append(depth_map)
            masks.append(mask)

        return (torch.cat(depth_maps), torch.cat(masks))


# Registration tables consumed by ComfyUI: internal node id -> class / display name.
NODE_CLASS_MAPPINGS = {
    "MeshGraphormer-DepthMapPreprocessor": Mesh_Graphormer_Depth_Map_Preprocessor,
    "MeshGraphormer+ImpactDetector-DepthMapPreprocessor": Mesh_Graphormer_With_ImpactDetector_Depth_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "MeshGraphormer-DepthMapPreprocessor": "MeshGraphormer Hand Refiner",
    "MeshGraphormer+ImpactDetector-DepthMapPreprocessor": "MeshGraphormer Hand Refiner With External Detector"
}
|
||||
@@ -0,0 +1,62 @@
|
||||
import os
|
||||
# Disable NPU device initialization and problematic MMCV ops to prevent RuntimeError
|
||||
os.environ['NPU_DEVICE_COUNT'] = '0'
|
||||
os.environ['MMCV_WITH_OPS'] = '0'
|
||||
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, MAX_RESOLUTION
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class Metric3D_Depth_Map_Preprocessor:
    """Metric3D depth estimation.

    backbone selects the ViT checkpoint size; fx/fy are forwarded to the model
    (presumably camera focal lengths in pixels — confirm in
    custom_controlnet_aux.metric3d).
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            backbone=INPUT.COMBO(["vit-small", "vit-large", "vit-giant2"]),
            fx=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
            fy=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, backbone="vit-small", fx=1000, fy=1000, resolution=512):
        from custom_controlnet_aux.metric3d import Metric3DDetector
        # Checkpoint file name is derived from the backbone choice,
        # e.g. "vit-small" -> metric_depth_vit_small_800k.pth
        model = Metric3DDetector.from_pretrained(filename=f"metric_depth_{backbone.replace('-', '_')}_800k.pth").to(model_management.get_torch_device())
        # With depth_and_normal=True the detector yields a pair; [0] keeps the depth map.
        cb = lambda image, **kwargs: model(image, **kwargs)[0]
        out = common_annotator_call(cb, image, resolution=resolution, fx=fx, fy=fy, depth_and_normal=True)
        del model
        return (out, )


class Metric3D_Normal_Map_Preprocessor:
    """Metric3D surface-normal estimation (same model; picks the normal output)."""

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            backbone=INPUT.COMBO(["vit-small", "vit-large", "vit-giant2"]),
            fx=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
            fy=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, backbone="vit-small", fx=1000, fy=1000, resolution=512):
        from custom_controlnet_aux.metric3d import Metric3DDetector
        model = Metric3DDetector.from_pretrained(filename=f"metric_depth_{backbone.replace('-', '_')}_800k.pth").to(model_management.get_torch_device())
        # With depth_and_normal=True the detector yields a pair; [1] keeps the normal map.
        cb = lambda image, **kwargs: model(image, **kwargs)[1]
        out = common_annotator_call(cb, image, resolution=resolution, fx=fx, fy=fy, depth_and_normal=True)
        del model
        return (out, )


NODE_CLASS_MAPPINGS = {
    "Metric3D-DepthMapPreprocessor": Metric3D_Depth_Map_Preprocessor,
    "Metric3D-NormalMapPreprocessor": Metric3D_Normal_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "Metric3D-DepthMapPreprocessor": "Metric3D Depth Map",
    "Metric3D-NormalMapPreprocessor": "Metric3D Normal Map"
}
|
||||
59
custom_nodes/comfyui_controlnet_aux/node_wrappers/midas.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
import numpy as np
|
||||
|
||||
class MIDAS_Normal_Map_Preprocessor:
    """Surface normals derived from MiDaS depth estimation.

    "a" and "bg_threshold" are forwarded to the detector (as a / bg_th);
    see custom_controlnet_aux.midas for their exact semantics.
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            a=INPUT.FLOAT(default=np.pi * 2.0, min=0.0, max=np.pi * 5.0),
            bg_threshold=INPUT.FLOAT(default=0.1),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, a=np.pi * 2.0, bg_threshold=0.1, resolution=512, **kwargs):
        from custom_controlnet_aux.midas import MidasDetector

        model = MidasDetector.from_pretrained().to(model_management.get_torch_device())
        #Dirty hack :))
        # With depth_and_normal=True the detector yields a pair; [1] keeps the normal map.
        cb = lambda image, **kargs: model(image, **kargs)[1]
        out = common_annotator_call(cb, image, resolution=resolution, a=a, bg_th=bg_threshold, depth_and_normal=True)
        del model
        return (out, )


class MIDAS_Depth_Map_Preprocessor:
    """MiDaS monocular depth estimation."""

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            a=INPUT.FLOAT(default=np.pi * 2.0, min=0.0, max=np.pi * 5.0),
            bg_threshold=INPUT.FLOAT(default=0.1),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, a=np.pi * 2.0, bg_threshold=0.1, resolution=512, **kwargs):
        from custom_controlnet_aux.midas import MidasDetector

        # Ref: https://github.com/lllyasviel/ControlNet/blob/main/gradio_depth2image.py
        model = MidasDetector.from_pretrained().to(model_management.get_torch_device())
        out = common_annotator_call(model, image, resolution=resolution, a=a, bg_th=bg_threshold)
        del model
        return (out, )


NODE_CLASS_MAPPINGS = {
    "MiDaS-NormalMapPreprocessor": MIDAS_Normal_Map_Preprocessor,
    "MiDaS-DepthMapPreprocessor": MIDAS_Depth_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "MiDaS-NormalMapPreprocessor": "MiDaS Normal Map",
    "MiDaS-DepthMapPreprocessor": "MiDaS Depth Map"
}
|
||||
31
custom_nodes/comfyui_controlnet_aux/node_wrappers/mlsd.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
import numpy as np
|
||||
|
||||
class MLSD_Preprocessor:
    """M-LSD straight-line detection.

    score_threshold / dist_threshold are forwarded to the detector as
    thr_v / thr_d (its native parameter names).
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            score_threshold=INPUT.FLOAT(default=0.1, min=0.01, max=2.0),
            dist_threshold=INPUT.FLOAT(default=0.1, min=0.01, max=20.0),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    def execute(self, image, score_threshold=0.1, dist_threshold=0.1, resolution=512, **kwargs):
        from custom_controlnet_aux.mlsd import MLSDdetector

        model = MLSDdetector.from_pretrained().to(model_management.get_torch_device())
        out = common_annotator_call(model, image, resolution=resolution, thr_v=score_threshold, thr_d=dist_threshold)
        # Release the model promptly, matching the other preprocessor nodes.
        del model
        return (out, )


NODE_CLASS_MAPPINGS = {
    "M-LSDPreprocessor": MLSD_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "M-LSDPreprocessor": "M-LSD Lines"
}
}
|
||||
@@ -0,0 +1,27 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class BAE_Normal_Map_Preprocessor:
    """Surface-normal estimation via the BAE (NormalBae) detector."""

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, resolution=512, **kwargs):
        from custom_controlnet_aux.normalbae import NormalBaeDetector

        detector = NormalBaeDetector.from_pretrained().to(model_management.get_torch_device())
        result = common_annotator_call(detector, image, resolution=resolution)
        # Drop the reference promptly so the weights can be garbage-collected.
        del detector
        return (result,)


NODE_CLASS_MAPPINGS = {
    "BAE-NormalMapPreprocessor": BAE_Normal_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "BAE-NormalMapPreprocessor": "BAE Normal Map"
}
|
||||
@@ -0,0 +1,50 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class OneFormer_COCO_SemSegPreprocessor:
    """Semantic segmentation with OneFormer trained on COCO."""

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "semantic_segmentate"

    CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"

    def semantic_segmentate(self, image, resolution=512):
        from custom_controlnet_aux.oneformer import OneformerSegmentor

        # Checkpoint name encodes the training config (Swin-L, COCO, 100 epochs).
        model = OneformerSegmentor.from_pretrained(filename="150_16_swin_l_oneformer_coco_100ep.pth")
        model = model.to(model_management.get_torch_device())
        out = common_annotator_call(model, image, resolution=resolution)
        del model
        return (out,)


class OneFormer_ADE20K_SemSegPreprocessor:
    """Semantic segmentation with OneFormer trained on ADE20K."""

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "semantic_segmentate"

    CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"

    def semantic_segmentate(self, image, resolution=512):
        from custom_controlnet_aux.oneformer import OneformerSegmentor

        # Checkpoint name encodes the training config (Swin-L, ADE20K, 160k iters).
        model = OneformerSegmentor.from_pretrained(filename="250_16_swin_l_oneformer_ade20k_160k.pth")
        model = model.to(model_management.get_torch_device())
        out = common_annotator_call(model, image, resolution=resolution)
        del model
        return (out,)


NODE_CLASS_MAPPINGS = {
    "OneFormer-COCO-SemSegPreprocessor": OneFormer_COCO_SemSegPreprocessor,
    "OneFormer-ADE20K-SemSegPreprocessor": OneFormer_ADE20K_SemSegPreprocessor
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "OneFormer-COCO-SemSegPreprocessor": "OneFormer COCO Segmentor",
    "OneFormer-ADE20K-SemSegPreprocessor": "OneFormer ADE20K Segmentor"
}
|
||||
@@ -0,0 +1,48 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
import json
|
||||
|
||||
class OpenPose_Preprocessor:
    """OpenPose body/hand/face pose estimation.

    Returns the rendered pose image (IMAGE) plus the raw keypoint dicts
    (POSE_KEYPOINT); the JSON is also surfaced in the UI for copy/paste.
    """

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            detect_hand=INPUT.COMBO(["enable", "disable"]),
            detect_body=INPUT.COMBO(["enable", "disable"]),
            detect_face=INPUT.COMBO(["enable", "disable"]),
            resolution=INPUT.RESOLUTION(),
            scale_stick_for_xinsr_cn=INPUT.COMBO(["disable", "enable"])
        )

    RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
    FUNCTION = "estimate_pose"

    CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"

    def estimate_pose(self, image, detect_hand="enable", detect_body="enable", detect_face="enable", scale_stick_for_xinsr_cn="disable", resolution=512, **kwargs):
        from custom_controlnet_aux.open_pose import OpenposeDetector

        # Convert the UI's enable/disable combos into booleans.
        detect_hand = detect_hand == "enable"
        detect_body = detect_body == "enable"
        detect_face = detect_face == "enable"
        scale_stick_for_xinsr_cn = scale_stick_for_xinsr_cn == "enable"

        model = OpenposeDetector.from_pretrained().to(model_management.get_torch_device())
        # Side-channel: collect the per-frame keypoint dicts emitted with each render.
        self.openpose_dicts = []
        def func(image, **kwargs):
            pose_img, openpose_dict = model(image, **kwargs)
            self.openpose_dicts.append(openpose_dict)
            return pose_img

        out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, xinsr_stick_scaling=scale_stick_for_xinsr_cn, resolution=resolution)
        del model  # release model memory as soon as inference is done
        return {
            'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
            "result": (out, self.openpose_dicts)
        }


NODE_CLASS_MAPPINGS = {
    "OpenposePreprocessor": OpenPose_Preprocessor,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "OpenposePreprocessor": "OpenPose Pose",
}
|
||||
30
custom_nodes/comfyui_controlnet_aux/node_wrappers/pidinet.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class PIDINET_Preprocessor:
    """PiDiNet soft-edge line extraction."""

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            safe=INPUT.COMBO(["enable", "disable"]),
            resolution=INPUT.RESOLUTION()
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    # "enable" default mirrors the combo's first entry; every sibling node
    # defaults its widget inputs, and omitting "safe" in an API call would
    # otherwise raise TypeError.
    def execute(self, image, safe="enable", resolution=512, **kwargs):
        from custom_controlnet_aux.pidi import PidiNetDetector

        model = PidiNetDetector.from_pretrained().to(model_management.get_torch_device())
        out = common_annotator_call(model, image, resolution=resolution, safe = safe == "enable")
        del model
        return (out, )


NODE_CLASS_MAPPINGS = {
    "PiDiNetPreprocessor": PIDINET_Preprocessor,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "PiDiNetPreprocessor": "PiDiNet Soft-Edge Lines"
}
|
||||
@@ -0,0 +1,340 @@
|
||||
import folder_paths
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import ImageColor
|
||||
from einops import rearrange
|
||||
import torch
|
||||
import itertools
|
||||
|
||||
from ..src.custom_controlnet_aux.dwpose import draw_poses, draw_animalposes, decode_json_as_poses
|
||||
|
||||
|
||||
"""
|
||||
Format of POSE_KEYPOINT (AP10K keypoints):
|
||||
[{
|
||||
"version": "ap10k",
|
||||
"animals": [
|
||||
[[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
|
||||
[[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
|
||||
...
|
||||
],
|
||||
"canvas_height": 512,
|
||||
"canvas_width": 768
|
||||
},...]
|
||||
Format of POSE_KEYPOINT (OpenPose keypoints):
|
||||
[{
|
||||
"people": [
|
||||
{
|
||||
'pose_keypoints_2d': [[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]]
|
||||
"face_keypoints_2d": [[x1, y1, 1], [x2, y2, 1],..., [x68, y68, 1]],
|
||||
"hand_left_keypoints_2d": [[x1, y1, 1], [x2, y2, 1],..., [x21, y21, 1]],
|
||||
"hand_right_keypoints_2d":[[x1, y1, 1], [x2, y2, 1],..., [x21, y21, 1]],
|
||||
}
|
||||
],
|
||||
"canvas_height": canvas_height,
|
||||
"canvas_width": canvas_width,
|
||||
},...]
|
||||
"""
|
||||
|
||||
class SavePoseKpsAsJsonFile:
    """Output node that serializes POSE_KEYPOINT data to a JSON file in the output directory."""
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pose_kps": ("POSE_KEYPOINT",),
                "filename_prefix": ("STRING", {"default": "PoseKeypoint"})
            }
        }
    RETURN_TYPES = ()
    FUNCTION = "save_pose_kps"
    OUTPUT_NODE = True
    CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
    def __init__(self):
        # Mirror ComfyUI's SaveImage conventions so files land in the standard output tree.
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""
    def save_pose_kps(self, pose_kps, filename_prefix):
        """Write `pose_kps` (list of per-frame keypoint dicts) to <prefix>_<counter>.json.

        The first frame's canvas size is passed to get_save_image_path, which
        resolves the output folder, de-duplicated filename, and counter.
        """
        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = \
            folder_paths.get_save_image_path(filename_prefix, self.output_dir, pose_kps[0]["canvas_width"], pose_kps[0]["canvas_height"])
        # BUG FIX: use the resolved `filename` in the saved name; the previous
        # literal placeholder made every save collide on the same file name.
        file = f"{filename}_{counter:05}.json"
        with open(os.path.join(full_output_folder, file), 'w') as f:
            json.dump(pose_kps, f)
        return {}
|
||||
|
||||
#COCO-Wholebody doesn't have eyebrows as it inherits 68 keypoints format
|
||||
#Perhaps eyebrows can be estimated tho
|
||||
FACIAL_PARTS = ["skin", "left_eye", "right_eye", "nose", "upper_lip", "inner_mouth", "lower_lip"]
|
||||
LAPA_COLORS = dict(
|
||||
skin="rgb(0, 153, 255)",
|
||||
left_eye="rgb(0, 204, 153)",
|
||||
right_eye="rgb(255, 153, 0)",
|
||||
nose="rgb(255, 102, 255)",
|
||||
upper_lip="rgb(102, 0, 51)",
|
||||
inner_mouth="rgb(255, 204, 255)",
|
||||
lower_lip="rgb(255, 0, 102)"
|
||||
)
|
||||
|
||||
#One-based index
|
||||
def kps_idxs(start, end):
    """Return zero-based indices covering the one-based keypoint range [start, end], inclusive.

    Supports descending ranges (start > end), e.g. kps_idxs(27, 18) walks the
    eyebrow landmarks from 27 down to 18.
    """
    step = -1 if start > end else 1
    # BUG FIX: stop one past `end` in the direction of travel so `end` itself
    # is included; the previous stop of `end` dropped the last two keypoints
    # of every descending range (e.g. 19 and 18 for kps_idxs(27, 18)).
    return list(range(start - 1, end - 1 + step, step))
|
||||
|
||||
#Source: https://www.researchgate.net/profile/Fabrizio-Falchi/publication/338048224/figure/fig1/AS:837860722741255@1576772971540/68-facial-landmarks.jpg
|
||||
FACIAL_PART_RANGES = dict(
|
||||
skin=kps_idxs(1, 17) + kps_idxs(27, 18),
|
||||
nose=kps_idxs(28, 36),
|
||||
left_eye=kps_idxs(37, 42),
|
||||
right_eye=kps_idxs(43, 48),
|
||||
upper_lip=kps_idxs(49, 55) + kps_idxs(65, 61),
|
||||
lower_lip=kps_idxs(61, 68),
|
||||
inner_mouth=kps_idxs(61, 65) + kps_idxs(55, 49)
|
||||
)
|
||||
|
||||
def is_normalized(keypoints) -> bool:
    """Return True iff every non-None keypoint has |x| and |y| within [0, 1].

    An empty (or all-None) keypoint sequence is reported as NOT normalized,
    so callers treat it as already being in pixel space.
    """
    saw_point = False
    for kp in keypoints:
        if kp is None:
            continue
        saw_point = True
        if not (0 <= np.abs(kp[0]) <= 1 and 0 <= np.abs(kp[1]) <= 1):
            return False
    return saw_point
|
||||
|
||||
class FacialPartColoringFromPoseKps:
    """Render OpenPose 68-point face keypoints as a flat color-coded segmentation image.

    Each facial part (skin, eyes, nose, lips, ...) is drawn in a user-supplied
    color, defaulting to the LaPa dataset palette.
    """
    @classmethod
    def INPUT_TYPES(s):
        input_types = {
            "required": {"pose_kps": ("POSE_KEYPOINT",), "mode": (["point", "polygon"], {"default": "polygon"})}
        }
        # One color input per facial part, pre-filled with the LaPa palette.
        for facial_part in FACIAL_PARTS:
            input_types["required"][facial_part] = ("STRING", {"default": LAPA_COLORS[facial_part], "multiline": False})
        return input_types
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "colorize"
    CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
    def colorize(self, pose_kps, mode, **facial_part_colors):
        """Draw every pose frame and stack into a float image batch in [0, 1]."""
        pose_frames = pose_kps
        np_frames = [self.draw_kps(pose_frame, mode, **facial_part_colors) for pose_frame in pose_frames]
        np_frames = np.stack(np_frames, axis=0)
        return (torch.from_numpy(np_frames).float() / 255.,)

    def draw_kps(self, pose_frame, mode, **facial_part_colors):
        """Render one frame's faces onto a black canvas; returns uint8 HWC array."""
        width, height = pose_frame["canvas_width"], pose_frame["canvas_height"]
        canvas = np.zeros((height, width, 3), dtype=np.uint8)
        # Draw every (person, part) combination onto the shared canvas.
        for person, part_name in itertools.product(pose_frame["people"], FACIAL_PARTS):
            n = len(person["face_keypoints_2d"]) // 3
            # Flat [x1, y1, c1, x2, ...] list -> (n, 2) array; confidence dropped.
            facial_kps = rearrange(np.array(person["face_keypoints_2d"]), "(n c) -> n c", n=n, c=3)[:, :2]
            # Keypoints may arrive normalized to [0, 1]; scale to pixel space.
            if is_normalized(facial_kps):
                facial_kps *= (width, height)
            facial_kps = facial_kps.astype(np.int32)
            # Accept any PIL color spec ("rgb(...)", "#rrggbb", names); keep RGB only.
            part_color = ImageColor.getrgb(facial_part_colors[part_name])[:3]
            part_contours = facial_kps[FACIAL_PART_RANGES[part_name], :]
            if mode == "point":
                for pt in part_contours:
                    cv2.circle(canvas, pt, radius=2, color=part_color, thickness=-1)
            else:
                cv2.fillPoly(canvas, pts=[part_contours], color=part_color)
        return canvas
|
||||
|
||||
# https://raw.githubusercontent.com/CMU-Perceptual-Computing-Lab/openpose/master/.github/media/keypoints_pose_18.png
|
||||
BODY_PART_INDEXES = {
|
||||
"Head": (16, 14, 0, 15, 17),
|
||||
"Neck": (0, 1),
|
||||
"Shoulder": (2, 5),
|
||||
"Torso": (2, 5, 8, 11),
|
||||
"RArm": (2, 3),
|
||||
"RForearm": (3, 4),
|
||||
"LArm": (5, 6),
|
||||
"LForearm": (6, 7),
|
||||
"RThigh": (8, 9),
|
||||
"RLeg": (9, 10),
|
||||
"LThigh": (11, 12),
|
||||
"LLeg": (12, 13)
|
||||
}
|
||||
BODY_PART_DEFAULT_W_H = {
|
||||
"Head": "256, 256",
|
||||
"Neck": "100, 100",
|
||||
"Shoulder": '',
|
||||
"Torso": "350, 450",
|
||||
"RArm": "128, 256",
|
||||
"RForearm": "128, 256",
|
||||
"LArm": "128, 256",
|
||||
"LForearm": "128, 256",
|
||||
"RThigh": "128, 256",
|
||||
"RLeg": "128, 256",
|
||||
"LThigh": "128, 256",
|
||||
"LLeg": "128, 256"
|
||||
}
|
||||
|
||||
class SinglePersonProcess:
    """Extract one person's body keypoints across all pose frames and derive part bounding boxes."""
    @classmethod
    def sort_and_get_max_people(s, pose_kps):
        # Sort each frame's people by the x coordinate of their first body
        # keypoint so a given person index refers to the same person in every
        # frame (assumes people keep their left-to-right order). Also report
        # the largest per-frame person count.
        for idx in range(len(pose_kps)):
            pose_kps[idx]["people"] = sorted(pose_kps[idx]["people"], key=lambda person:person["pose_keypoints_2d"][0])
        return pose_kps, max(len(frame["people"]) for frame in pose_kps)

    def __init__(self, pose_kps, person_idx=0) -> None:
        # Canvas size is read from the first frame; assumes all frames share it — TODO confirm.
        self.width, self.height = pose_kps[0]["canvas_width"], pose_kps[0]["canvas_height"]
        # One entry per frame: the person's (n, 2) pixel-space keypoints, or
        # None for frames where this person index is absent.
        self.poses = [
            self.normalize(pose_frame["people"][person_idx]["pose_keypoints_2d"])
            if person_idx < len(pose_frame["people"])
            else None
            for pose_frame in pose_kps
        ]

    def normalize(self, pose_kps_2d):
        """Reshape a flat [x1, y1, c1, x2, ...] list into (n, 2) pixel coordinates.

        Zero-confidence keypoints are moved far off-canvas so they cannot skew
        bounding-box math; [0, 1]-normalized coordinates are scaled to pixels.
        """
        n = len(pose_kps_2d) // 3
        pose_kps_2d = rearrange(np.array(pose_kps_2d), "(n c) -> n c", n=n, c=3)
        pose_kps_2d[np.argwhere(pose_kps_2d[:,2]==0), :] = np.iinfo(np.int32).max // 2 #Safe large value
        pose_kps_2d = pose_kps_2d[:, :2]
        if is_normalized(pose_kps_2d):
            pose_kps_2d *= (self.width, self.height)
        return pose_kps_2d

    def get_xyxy_bboxes(self, part_name, bbox_size=(128, 256)):
        """Return {frame_idx: (x0, y0, x1, y1)} fixed-size boxes centered on a body part.

        The box has exactly `bbox_size` (width, height) and is centered at the
        mean of the part's keypoints. Frames where the person is missing get a
        far off-canvas box so downstream consumers effectively ignore them.
        """
        width, height = bbox_size
        xyxy_bboxes = {}
        for idx, pose in enumerate(self.poses):
            if pose is None:
                xyxy_bboxes[idx] = (np.iinfo(np.int32).max // 2,) * 4
                continue
            pts = pose[BODY_PART_INDEXES[part_name], :]

            # Earlier keypoint-extent-based sizing, kept for reference:
            #top_left = np.min(pts[:,0]), np.min(pts[:,1])
            #bottom_right = np.max(pts[:,0]), np.max(pts[:,1])
            #pad_width = np.maximum(width - (bottom_right[0]-top_left[0]), 0) / 2
            #pad_height = np.maximum(height - (bottom_right[1]-top_left[1]), 0) / 2
            #xyxy_bboxes.append((
            #    top_left[0] - pad_width, top_left[1] - pad_height,
            #    bottom_right[0] + pad_width, bottom_right[1] + pad_height,
            #))

            # Current behavior: constant-size box around the part centroid.
            x_mid, y_mid = np.mean(pts[:, 0]), np.mean(pts[:, 1])
            xyxy_bboxes[idx] = (
                x_mid - width/2, y_mid - height/2,
                x_mid + width/2, y_mid + height/2
            )
        return xyxy_bboxes
|
||||
|
||||
class UpperBodyTrackingFromPoseKps:
    """Convert pose keypoints into per-part tracking boxes for InstanceDiffusion.

    For each enabled upper-body part and each tracked person, emits fixed-size
    per-frame bounding boxes plus a matching prompt-string template.
    """
    PART_NAMES = ["Head", "Neck", "Shoulder", "Torso", "RArm", "RForearm", "LArm", "LForearm"]

    @classmethod
    def INPUT_TYPES(s):
        # One "<Part>_width_height" string input per part ("W, H"); a blank
        # value disables tracking of that part.
        return {
            "required": {
                "pose_kps": ("POSE_KEYPOINT",),
                "id_include": ("STRING", {"default": '', "multiline": False}),
                **{part_name + "_width_height": ("STRING", {"default": BODY_PART_DEFAULT_W_H[part_name], "multiline": False}) for part_name in s.PART_NAMES}
            }
        }

    RETURN_TYPES = ("TRACKING", "STRING")
    RETURN_NAMES = ("tracking", "prompt")
    FUNCTION = "convert"
    CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"

    def convert(self, pose_kps, id_include, **parts_width_height):
        """Build {part: {person_idx: {frame_idx: bbox + canvas size}}} and a prompt string.

        `id_include` is a comma-separated whitelist of person indices
        (empty string = track everyone).
        """
        # Strip the "_width_height" suffix so keys line up with PART_NAMES.
        parts_width_height = {part_name.replace("_width_height", ''): value for part_name, value in parts_width_height.items()}
        enabled_part_names = [part_name for part_name in self.PART_NAMES if len(parts_width_height[part_name].strip())]
        tracked = {part_name: {} for part_name in enabled_part_names}
        id_include = id_include.strip()
        id_include = list(map(int, id_include.split(','))) if len(id_include) else []
        prompt_string = ''
        # Sort people consistently across frames so person_idx is stable.
        pose_kps, max_people = SinglePersonProcess.sort_and_get_max_people(pose_kps)

        for person_idx in range(max_people):
            if len(id_include) and person_idx not in id_include:
                continue
            processor = SinglePersonProcess(pose_kps, person_idx)
            for part_name in enabled_part_names:
                bbox_size = tuple(map(int, parts_width_height[part_name].split(',')))
                part_bboxes = processor.get_xyxy_bboxes(part_name, bbox_size)
                # Append the canvas size to every bbox, matching the tracking
                # tuple layout consumed downstream.
                id_coordinates = {idx: part_bbox+(processor.width, processor.height) for idx, part_bbox in part_bboxes.items()}
                tracked[part_name][person_idx] = id_coordinates

        for class_name, class_data in tracked.items():
            for class_id in class_data.keys():
                class_id_str = str(class_id)
                # Use the incoming prompt for each class name and ID;
                # the L/R side prefix is dropped for the prompt text,
                # e.g. '"0.LArm": "(arm)"'.
                _class_name = class_name.replace('L', '').replace('R', '').lower()
                prompt_string += f'"{class_id_str}.{class_name}": "({_class_name})",\n'

        return (tracked, prompt_string)
|
||||
|
||||
|
||||
def numpy2torch(np_image: np.ndarray) -> torch.Tensor:
    """Convert a uint8 [H, W, C] image array to a float32 tensor of shape [1, H, W, C] in [0, 1]."""
    scaled = np_image.astype(np.float32) / 255
    return torch.from_numpy(scaled)[None, ...]
|
||||
|
||||
|
||||
class RenderPeopleKps:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"kps": ("POSE_KEYPOINT",),
|
||||
"render_body": ("BOOLEAN", {"default": True}),
|
||||
"render_hand": ("BOOLEAN", {"default": True}),
|
||||
"render_face": ("BOOLEAN", {"default": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "render"
|
||||
CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
|
||||
|
||||
def render(self, kps, render_body, render_hand, render_face) -> tuple[np.ndarray]:
|
||||
if isinstance(kps, list):
|
||||
kps = kps[0]
|
||||
|
||||
poses, _, height, width = decode_json_as_poses(kps)
|
||||
np_image = draw_poses(
|
||||
poses,
|
||||
height,
|
||||
width,
|
||||
render_body,
|
||||
render_hand,
|
||||
render_face,
|
||||
)
|
||||
return (numpy2torch(np_image),)
|
||||
|
||||
class RenderAnimalKps:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"kps": ("POSE_KEYPOINT",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "render"
|
||||
CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
|
||||
|
||||
def render(self, kps) -> tuple[np.ndarray]:
|
||||
if isinstance(kps, list):
|
||||
kps = kps[0]
|
||||
|
||||
_, poses, height, width = decode_json_as_poses(kps)
|
||||
np_image = draw_animalposes(poses, height, width)
|
||||
return (numpy2torch(np_image),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SavePoseKpsAsJsonFile": SavePoseKpsAsJsonFile,
|
||||
"FacialPartColoringFromPoseKps": FacialPartColoringFromPoseKps,
|
||||
"UpperBodyTrackingFromPoseKps": UpperBodyTrackingFromPoseKps,
|
||||
"RenderPeopleKps": RenderPeopleKps,
|
||||
"RenderAnimalKps": RenderAnimalKps,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SavePoseKpsAsJsonFile": "Save Pose Keypoints",
|
||||
"FacialPartColoringFromPoseKps": "Colorize Facial Parts from PoseKPS",
|
||||
"UpperBodyTrackingFromPoseKps": "Upper Body Tracking From PoseKps (InstanceDiffusion)",
|
||||
"RenderPeopleKps": "Render Pose JSON (Human)",
|
||||
"RenderAnimalKps": "Render Pose JSON (Animal)",
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class PyraCanny_Preprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return define_preprocessor_inputs(
|
||||
low_threshold=INPUT.INT(default=64, max=255),
|
||||
high_threshold=INPUT.INT(default=128, max=255),
|
||||
resolution=INPUT.RESOLUTION()
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/Line Extractors"
|
||||
|
||||
def execute(self, image, low_threshold=64, high_threshold=128, resolution=512, **kwargs):
|
||||
from custom_controlnet_aux.pyracanny import PyraCannyDetector
|
||||
|
||||
return (common_annotator_call(PyraCannyDetector(), image, low_threshold=low_threshold, high_threshold=high_threshold, resolution=resolution), )
|
||||
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PyraCannyPreprocessor": PyraCanny_Preprocessor
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PyraCannyPreprocessor": "PyraCanny"
|
||||
}
|
||||
46
custom_nodes/comfyui_controlnet_aux/node_wrappers/recolor.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
|
||||
class ImageLuminanceDetector:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
#https://github.com/Mikubill/sd-webui-controlnet/blob/416c345072c9c2066101e225964e3986abe6945e/scripts/processor.py#L1229
|
||||
return define_preprocessor_inputs(
|
||||
gamma_correction=INPUT.FLOAT(default=1.0, min=0.1, max=2.0),
|
||||
resolution=INPUT.RESOLUTION()
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/Recolor"
|
||||
|
||||
def execute(self, image, gamma_correction=1.0, resolution=512, **kwargs):
|
||||
from custom_controlnet_aux.recolor import Recolorizer
|
||||
return (common_annotator_call(Recolorizer(), image, mode="luminance", gamma_correction=gamma_correction , resolution=resolution), )
|
||||
|
||||
class ImageIntensityDetector:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
#https://github.com/Mikubill/sd-webui-controlnet/blob/416c345072c9c2066101e225964e3986abe6945e/scripts/processor.py#L1229
|
||||
return define_preprocessor_inputs(
|
||||
gamma_correction=INPUT.FLOAT(default=1.0, min=0.1, max=2.0),
|
||||
resolution=INPUT.RESOLUTION()
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/Recolor"
|
||||
|
||||
def execute(self, image, gamma_correction=1.0, resolution=512, **kwargs):
|
||||
from custom_controlnet_aux.recolor import Recolorizer
|
||||
return (common_annotator_call(Recolorizer(), image, mode="intensity", gamma_correction=gamma_correction , resolution=resolution), )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageLuminanceDetector": ImageLuminanceDetector,
|
||||
"ImageIntensityDetector": ImageIntensityDetector
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageLuminanceDetector": "Image Luminance",
|
||||
"ImageIntensityDetector": "Image Intensity"
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, nms
|
||||
import comfy.model_management as model_management
|
||||
import cv2
|
||||
|
||||
class Scribble_Preprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/Line Extractors"
|
||||
|
||||
def execute(self, image, resolution=512, **kwargs):
|
||||
from custom_controlnet_aux.scribble import ScribbleDetector
|
||||
|
||||
model = ScribbleDetector()
|
||||
return (common_annotator_call(model, image, resolution=resolution), )
|
||||
|
||||
class Scribble_XDoG_Preprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return define_preprocessor_inputs(
|
||||
threshold=INPUT.INT(default=32, min=1, max=64),
|
||||
resolution=INPUT.RESOLUTION()
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/Line Extractors"
|
||||
|
||||
def execute(self, image, threshold=32, resolution=512, **kwargs):
|
||||
from custom_controlnet_aux.scribble import ScribbleXDog_Detector
|
||||
|
||||
model = ScribbleXDog_Detector()
|
||||
return (common_annotator_call(model, image, resolution=resolution, thr_a=threshold), )
|
||||
|
||||
class Scribble_PiDiNet_Preprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return define_preprocessor_inputs(
|
||||
safe=(["enable", "disable"],),
|
||||
resolution=INPUT.RESOLUTION()
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/Line Extractors"
|
||||
|
||||
def execute(self, image, safe="enable", resolution=512):
|
||||
def model(img, **kwargs):
|
||||
from custom_controlnet_aux.pidi import PidiNetDetector
|
||||
pidinet = PidiNetDetector.from_pretrained().to(model_management.get_torch_device())
|
||||
result = pidinet(img, scribble=True, **kwargs)
|
||||
result = nms(result, 127, 3.0)
|
||||
result = cv2.GaussianBlur(result, (0, 0), 3.0)
|
||||
result[result > 4] = 255
|
||||
result[result < 255] = 0
|
||||
return result
|
||||
return (common_annotator_call(model, image, resolution=resolution, safe=safe=="enable"),)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ScribblePreprocessor": Scribble_Preprocessor,
|
||||
"Scribble_XDoG_Preprocessor": Scribble_XDoG_Preprocessor,
|
||||
"Scribble_PiDiNet_Preprocessor": Scribble_PiDiNet_Preprocessor
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ScribblePreprocessor": "Scribble Lines",
|
||||
"Scribble_XDoG_Preprocessor": "Scribble XDoG Lines",
|
||||
"Scribble_PiDiNet_Preprocessor": "Scribble PiDiNet Lines"
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class SAM_Preprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/others"
|
||||
|
||||
def execute(self, image, resolution=512, **kwargs):
|
||||
from custom_controlnet_aux.sam import SamDetector
|
||||
|
||||
mobile_sam = SamDetector.from_pretrained().to(model_management.get_torch_device())
|
||||
out = common_annotator_call(mobile_sam, image, resolution=resolution)
|
||||
del mobile_sam
|
||||
return (out, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SAMPreprocessor": SAM_Preprocessor
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SAMPreprocessor": "SAM Segmentor"
|
||||
}
|
||||
27
custom_nodes/comfyui_controlnet_aux/node_wrappers/shuffle.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, MAX_RESOLUTION
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class Shuffle_Preprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return define_preprocessor_inputs(
|
||||
resolution=INPUT.RESOLUTION(),
|
||||
seed=INPUT.SEED()
|
||||
)
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "preprocess"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/T2IAdapter-only"
|
||||
|
||||
def preprocess(self, image, resolution=512, seed=0):
|
||||
from custom_controlnet_aux.shuffle import ContentShuffleDetector
|
||||
|
||||
return (common_annotator_call(ContentShuffleDetector(), image, resolution=resolution, seed=seed), )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ShufflePreprocessor": Shuffle_Preprocessor
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ShufflePreprocessor": "Content Shuffle"
|
||||
}
|
||||
30
custom_nodes/comfyui_controlnet_aux/node_wrappers/teed.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class TEED_Preprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return define_preprocessor_inputs(
|
||||
safe_steps=INPUT.INT(default=2, max=10),
|
||||
resolution=INPUT.RESOLUTION()
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/Line Extractors"
|
||||
|
||||
def execute(self, image, safe_steps=2, resolution=512, **kwargs):
|
||||
from custom_controlnet_aux.teed import TEDDetector
|
||||
|
||||
model = TEDDetector.from_pretrained().to(model_management.get_torch_device())
|
||||
out = common_annotator_call(model, image, resolution=resolution, safe_steps=safe_steps)
|
||||
del model
|
||||
return (out, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
    "TEEDPreprocessor": TEED_Preprocessor,
}
# BUG FIX: the display-name key must match the class-mapping key
# ("TEEDPreprocessor", not "TEED_Preprocessor"); otherwise ComfyUI never
# finds the friendly label and shows the raw node id instead.
NODE_DISPLAY_NAME_MAPPINGS = {
    "TEEDPreprocessor": "TEED Soft-Edge Lines",
}
|
||||
73
custom_nodes/comfyui_controlnet_aux/node_wrappers/tile.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
|
||||
|
||||
class Tile_Preprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return define_preprocessor_inputs(
|
||||
pyrUp_iters=INPUT.INT(default=3, min=1, max=10),
|
||||
resolution=INPUT.RESOLUTION()
|
||||
)
|
||||
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/tile"
|
||||
|
||||
def execute(self, image, pyrUp_iters, resolution=512, **kwargs):
|
||||
from custom_controlnet_aux.tile import TileDetector
|
||||
|
||||
return (common_annotator_call(TileDetector(), image, pyrUp_iters=pyrUp_iters, resolution=resolution),)
|
||||
|
||||
class TTPlanet_TileGF_Preprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return define_preprocessor_inputs(
|
||||
scale_factor=INPUT.FLOAT(default=1.00, min=1.000, max=8.00),
|
||||
blur_strength=INPUT.FLOAT(default=2.0, min=1.0, max=10.0),
|
||||
radius=INPUT.INT(default=7, min=1, max=20),
|
||||
eps=INPUT.FLOAT(default=0.01, min=0.001, max=0.1, step=0.001),
|
||||
resolution=INPUT.RESOLUTION()
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/tile"
|
||||
|
||||
def execute(self, image, scale_factor, blur_strength, radius, eps, **kwargs):
|
||||
from custom_controlnet_aux.tile import TTPlanet_Tile_Detector_GF
|
||||
|
||||
return (common_annotator_call(TTPlanet_Tile_Detector_GF(), image, scale_factor=scale_factor, blur_strength=blur_strength, radius=radius, eps=eps),)
|
||||
|
||||
class TTPlanet_TileSimple_Preprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return define_preprocessor_inputs(
|
||||
scale_factor=INPUT.FLOAT(default=1.00, min=1.000, max=8.00),
|
||||
blur_strength=INPUT.FLOAT(default=2.0, min=1.0, max=10.0),
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/tile"
|
||||
|
||||
def execute(self, image, scale_factor, blur_strength):
|
||||
from custom_controlnet_aux.tile import TTPLanet_Tile_Detector_Simple
|
||||
|
||||
return (common_annotator_call(TTPLanet_Tile_Detector_Simple(), image, scale_factor=scale_factor, blur_strength=blur_strength),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TilePreprocessor": Tile_Preprocessor,
|
||||
"TTPlanet_TileGF_Preprocessor": TTPlanet_TileGF_Preprocessor,
|
||||
"TTPlanet_TileSimple_Preprocessor": TTPlanet_TileSimple_Preprocessor
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TilePreprocessor": "Tile",
|
||||
"TTPlanet_TileGF_Preprocessor": "TTPlanet Tile GuidedFilter",
|
||||
"TTPlanet_TileSimple_Preprocessor": "TTPlanet Tile Simple"
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
import os
|
||||
# Disable NPU device initialization and problematic MMCV ops to prevent RuntimeError
|
||||
os.environ['NPU_DEVICE_COUNT'] = '0'
|
||||
os.environ['MMCV_WITH_OPS'] = '0'
|
||||
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class Uniformer_SemSegPreprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "semantic_segmentate"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"
|
||||
|
||||
def semantic_segmentate(self, image, resolution=512):
|
||||
from custom_controlnet_aux.uniformer import UniformerSegmentor
|
||||
|
||||
model = UniformerSegmentor.from_pretrained().to(model_management.get_torch_device())
|
||||
out = common_annotator_call(model, image, resolution=resolution)
|
||||
del model
|
||||
return (out, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"UniFormer-SemSegPreprocessor": Uniformer_SemSegPreprocessor,
|
||||
"SemSegPreprocessor": Uniformer_SemSegPreprocessor,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"UniFormer-SemSegPreprocessor": "UniFormer Segmentor",
|
||||
"SemSegPreprocessor": "Semantic Segmentor (legacy, alias for UniFormer)",
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
from ..utils import common_annotator_call
|
||||
import comfy.model_management as model_management
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Unimatch_OptFlowPreprocessor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": dict(
|
||||
image=("IMAGE",),
|
||||
ckpt_name=(
|
||||
["gmflow-scale1-mixdata.pth", "gmflow-scale2-mixdata.pth", "gmflow-scale2-regrefine6-mixdata.pth"],
|
||||
{"default": "gmflow-scale2-regrefine6-mixdata.pth"}
|
||||
),
|
||||
backward_flow=("BOOLEAN", {"default": False}),
|
||||
bidirectional_flow=("BOOLEAN", {"default": False})
|
||||
)
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("OPTICAL_FLOW", "IMAGE")
|
||||
RETURN_NAMES = ("OPTICAL_FLOW", "PREVIEW_IMAGE")
|
||||
FUNCTION = "estimate"
|
||||
|
||||
CATEGORY = "ControlNet Preprocessors/Optical Flow"
|
||||
|
||||
def estimate(self, image, ckpt_name, backward_flow=False, bidirectional_flow=False):
|
||||
assert len(image) > 1, "[Unimatch] Requiring as least two frames as an optical flow estimator. Only use this node on video input."
|
||||
from custom_controlnet_aux.unimatch import UnimatchDetector
|
||||
tensor_images = image
|
||||
model = UnimatchDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
|
||||
flows, vis_flows = [], []
|
||||
for i in range(len(tensor_images) - 1):
|
||||
image0, image1 = np.asarray(image[i:i+2].cpu() * 255., dtype=np.uint8)
|
||||
flow, vis_flow = model(image0, image1, output_type="np", pred_bwd_flow=backward_flow, pred_bidir_flow=bidirectional_flow)
|
||||
flows.append(torch.from_numpy(flow).float())
|
||||
vis_flows.append(torch.from_numpy(vis_flow).float() / 255.)
|
||||
del model
|
||||
return (torch.stack(flows, dim=0), torch.stack(vis_flows, dim=0))
|
||||
|
||||
class MaskOptFlow:
    """Zero out optical flow (and its preview image) outside a per-frame mask."""
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": dict(optical_flow=("OPTICAL_FLOW",), mask=("MASK",))
        }

    RETURN_TYPES = ("OPTICAL_FLOW", "IMAGE")
    RETURN_NAMES = ("OPTICAL_FLOW", "PREVIEW_IMAGE")
    FUNCTION = "mask_opt_flow"

    CATEGORY = "ControlNet Preprocessors/Optical Flow"

    def mask_opt_flow(self, optical_flow, mask):
        """Multiply flow and its visualization by the mask; returns (masked_flow, preview)."""
        from custom_controlnet_aux.unimatch import flow_to_image
        assert len(mask) >= len(optical_flow), f"Not enough masks to mask optical flow: {len(mask)} vs {len(optical_flow)}"
        # Trim masks to the flow's frame count, then resize them to the flow's H x W.
        # NOTE(review): F.interpolate expects an (N, C, H, W) input and the
        # rearrange pattern below assumes a singleton channel dim is present —
        # confirm the MASK tensor really arrives as (n, 1, h, w) here.
        mask = mask[:optical_flow.shape[0]]
        mask = F.interpolate(mask, optical_flow.shape[1:3])
        mask = rearrange(mask, "n 1 h w -> n h w 1")
        # Build the color-wheel preview for each frame, scaled to [0, 1].
        vis_flows = torch.stack([torch.from_numpy(flow_to_image(flow)).float() / 255. for flow in optical_flow.numpy()], dim=0)
        # Zero both the preview and the flow vectors outside the mask (in place).
        vis_flows *= mask
        optical_flow *= mask
        return (optical_flow, vis_flows)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"Unimatch_OptFlowPreprocessor": Unimatch_OptFlowPreprocessor,
|
||||
"MaskOptFlow": MaskOptFlow
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"Unimatch_OptFlowPreprocessor": "Unimatch Optical Flow",
|
||||
"MaskOptFlow": "Mask Optical Flow (DragNUWA)"
|
||||
}
|
||||
27
custom_nodes/comfyui_controlnet_aux/node_wrappers/zoe.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
|
||||
import comfy.model_management as model_management
|
||||
|
||||
class Zoe_Depth_Map_Preprocessor:
    """ZoeDepth monocular depth estimation as a ControlNet hint preprocessor."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())

    def execute(self, image, resolution=512, **kwargs):
        # Lazy import keeps node registration light; the detector is only
        # needed (and its weights fetched) at execution time.
        from custom_controlnet_aux.zoe import ZoeDetector

        detector = ZoeDetector.from_pretrained().to(model_management.get_torch_device())
        depth_maps = common_annotator_call(detector, image, resolution=resolution)
        # Drop the model reference so its (V)RAM can be reclaimed.
        del detector
        return (depth_maps, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"Zoe-DepthMapPreprocessor": Zoe_Depth_Map_Preprocessor
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"Zoe-DepthMapPreprocessor": "Zoe Depth Map"
|
||||
}
|
||||
14
custom_nodes/comfyui_controlnet_aux/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "comfyui_controlnet_aux"
|
||||
description = "Plug-and-play ComfyUI node sets for making ControlNet hint images"
|
||||
|
||||
version = "1.1.3"
|
||||
dependencies = ["torch", "importlib_metadata", "huggingface_hub", "scipy", "opencv-python>=4.7.0.72", "filelock", "numpy", "Pillow", "einops", "torchvision", "pyyaml", "scikit-image", "python-dateutil", "mediapipe", "fvcore", "yapf", "omegaconf", "ftfy", "addict", "yacs", "trimesh[easy]", "albumentations", "scikit-learn", "matplotlib"]
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/Fannovel16/comfyui_controlnet_aux"
|
||||
|
||||
[tool.comfy]
|
||||
PublisherId = "fannovel16"
|
||||
DisplayName = "comfyui_controlnet_aux"
|
||||
Icon = ""
|
||||
25
custom_nodes/comfyui_controlnet_aux/requirements.txt
Normal file
@@ -0,0 +1,25 @@
|
||||
torch
|
||||
importlib_metadata
|
||||
huggingface_hub
|
||||
scipy
|
||||
opencv-python>=4.7.0.72
|
||||
filelock
|
||||
numpy
|
||||
Pillow
|
||||
einops
|
||||
torchvision
|
||||
pyyaml
|
||||
scikit-image
|
||||
python-dateutil
|
||||
mediapipe
|
||||
fvcore
|
||||
yapf
|
||||
omegaconf
|
||||
ftfy
|
||||
addict
|
||||
yacs
trimesh[easy]
|
||||
albumentations
|
||||
scikit-learn
|
||||
matplotlib
|
||||
56
custom_nodes/comfyui_controlnet_aux/search_hf_assets.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from pathlib import Path
import os
import re

# Thanks ChatGPT
# Captures HF repo ids / filenames handed to from_pretrained() calls:
#   pretrained_model_or_path=<X>, filename=<X>, and any <name>_filename=<X> kwarg.
pattern = r'\bfrom_pretrained\(.*?pretrained_model_or_path\s*=\s*(.*?)(?:,|\))|filename\s*=\s*(.*?)(?:,|\))|(\w+_filename)\s*=\s*(.*?)(?:,|\))'

aux_dir = Path(__file__).parent / 'src' / 'custom_controlnet_aux'

# Module-level constants used inside the preprocessor sources, mapped to HF repo ids.
VAR_DICT = dict(
    HF_MODEL_NAME="lllyasviel/Annotators",
    DWPOSE_MODEL_NAME="yzd-v/DWPose",
    BDS_MODEL_NAME="bdsqlsz/qinglong_controlnet-lllite",
    DENSEPOSE_MODEL_NAME="LayerNorm/DensePose-TorchScript-with-hint-image",
    MESH_GRAPHORMER_MODEL_NAME="hr16/ControlNet-HandRefiner-pruned",
    SAM_MODEL_NAME="dhkim2810/MobileSAM",
    UNIMATCH_MODEL_NAME="hr16/Unimatch",
    DEPTH_ANYTHING_MODEL_NAME="LiheYoung/Depth-Anything",  # HF Space
    DIFFUSION_EDGE_MODEL_NAME="hr16/Diffusion-Edge",
)


def format_assets_markdown(model_name, filenames):
    """Return a comma-separated Markdown list of HF blob links, one per filename."""
    return ', '.join(
        f"[{model_name}/{filename}](https://huggingface.co/{model_name}/blob/main/{filename})"
        for filename in filenames
    )


def collect_assets():
    """Scan every preprocessor package's __init__.py for from_pretrained() args.

    Returns {package_name: [repo_id_or_filename, ...]}. Packages that yield no
    regex match are printed so they can be audited by hand.
    """
    results = {}
    for preprocc in os.listdir(aux_dir):
        if preprocc in ("__pycache__", "tests"):
            continue
        if preprocc.endswith(".py"):  # plain modules, not preprocessor packages
            continue
        with open(aux_dir / preprocc / '__init__.py', 'r') as f:
            code = f.read()
        matches = re.findall(pattern, code)
        result = [match[0] or match[1] or match[3] for match in matches]
        if not result:
            print(preprocc)
            continue
        result = [el.replace("'", '').replace('"', '') for el in result]
        # Resolve constant names (e.g. HF_MODEL_NAME) to their repo ids.
        results[preprocc] = [VAR_DICT.get(el, el) for el in result]
    return results


def main():
    """Print a Markdown bullet list of model assets per preprocessor."""
    for preprocc, re_result in collect_assets().items():
        # First capture is the repo id, the rest are filenames within it.
        model_name, filenames = re_result[0], re_result[1:]
        print(f"* {preprocc}: ", end=' ')
        print(format_assets_markdown(model_name, filenames))

    # These models cannot be scraped from source, so they are listed manually.
    manual_entries = [
        ("dwpose", VAR_DICT['DWPOSE_MODEL_NAME'],
         ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]),
        ("yolo-nas", "hr16/yolo-nas-fp16",
         ["yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"]),
        ("dwpose-torchscript", "hr16/DWPose-TorchScript-BatchSize5",
         ["dw-ll_ucoco_384_bs5.torchscript.pt", "rtmpose-m_ap10k_256_bs5.torchscript.pt"]),
    ]
    for preprocc, model_name, filenames in manual_entries:
        print(f"* {preprocc}: ", end=' ')
        print(format_assets_markdown(model_name, filenames))


if __name__ == "__main__":
    main()
|
||||
1
custom_nodes/comfyui_controlnet_aux/src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
#Dummy file ensuring this package will be recognized
|
||||
@@ -0,0 +1 @@
|
||||
#Dummy file ensuring this package will be recognized
|
||||
@@ -0,0 +1,66 @@
|
||||
from .network import UNet
|
||||
from .util import seg2img
|
||||
import torch
|
||||
import os
|
||||
import cv2
|
||||
from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, BDS_MODEL_NAME
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
from .anime_segmentation import AnimeSegmentation
|
||||
import numpy as np
|
||||
|
||||
class AnimeFaceSegmentor:
    """Anime face parser built from a MobileNet-based UNet (7 classes) with an
    optional AnimeSegmentation pass that first whitens out the background.

    Call the instance with an HWC uint8 image; returns the colorized
    segmentation map (PIL image or ndarray, per ``output_type``).
    """

    def __init__(self, model, seg_model):
        self.model = model          # UNet face-parsing network
        self.seg_model = seg_model  # AnimeSegmentation background remover
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path=BDS_MODEL_NAME, filename="UNet.pth", seg_filename="isnetis.ckpt"):
        """Download (or reuse cached) weights for both networks and build the segmentor."""
        model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder="Annotators")
        seg_model_path = custom_hf_download("skytnt/anime-seg", seg_filename)

        model = UNet()
        ckpt = torch.load(model_path, map_location="cpu")
        model.load_state_dict(ckpt)
        model.eval()

        seg_model = AnimeSegmentation(seg_model_path)
        seg_model.net.eval()
        return cls(model, seg_model)

    def to(self, device):
        """Move both networks to *device*; returns self for chaining."""
        self.model.to(device)
        self.seg_model.net.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", remove_background=True, **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        with torch.no_grad():
            if remove_background:
                # Size 0 => don't resize inside seg_model; the image was resized above.
                mask, input_image = self.seg_model(input_image, 0)
            image_feed = torch.from_numpy(input_image).float().to(self.device)
            image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
            image_feed = image_feed / 255
            seg = self.model(image_feed).squeeze(dim=0)
            result = seg2img(seg.cpu().detach().numpy())

        detected_map = HWC3(result)
        detected_map = remove_pad(detected_map)
        if remove_background:
            mask = remove_pad(mask)
            H, W, C = detected_map.shape
            # uint8 buffer: np.zeros defaults to float64, which PIL's
            # Image.fromarray below cannot consume.
            tmp = np.zeros([H, W, C + 1], dtype=np.uint8)
            tmp[:, :, :C] = detected_map
            # assumes detected_map has 3 channels (HWC3 guarantees it), so the
            # mask lands in the alpha slot
            tmp[:, :, 3:] = mask
            detected_map = tmp

        if output_type == "pil":
            # Only the RGB planes are returned as a PIL image; the alpha/mask
            # channel is dropped here.
            detected_map = Image.fromarray(detected_map[..., :3])

        return detected_map
|
||||
@@ -0,0 +1,58 @@
|
||||
#https://github.com/SkyTNT/anime-segmentation/tree/main
|
||||
#Only adapt isnet_is (https://huggingface.co/skytnt/anime-seg/blob/main/isnetis.ckpt)
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from .isnet import ISNetDIS
|
||||
import numpy as np
|
||||
import cv2
|
||||
from comfy.model_management import get_torch_device
|
||||
DEVICE = get_torch_device()
|
||||
|
||||
class AnimeSegmentation:
    """Inference wrapper around ISNetDIS for anime foreground segmentation.

    Loads the ``net.*`` weights from an isnetis.ckpt checkpoint (the
    gt_encoder weights are training-only and skipped) and exposes
    ``get_mask`` / ``__call__`` for mask extraction and background removal.
    """

    def __init__(self, ckpt_path):
        # Original code called super(AnimeSegmentation).__init__(), which is
        # the classic missing-`self` no-op; fixed to the standard form.
        super().__init__()
        sd = torch.load(ckpt_path, map_location="cpu")
        self.net = ISNetDIS()
        # gt_encoder isn't used during inference
        self.net.load_state_dict({k.replace("net.", ''): v for k, v in sd.items() if k.startswith("net.")})
        self.net = self.net.to(DEVICE)
        self.net.eval()

    def _predict(self, img_input):
        """Run the net on a 1xCxHxW float batch; return the HxWx1 sigmoid mask."""
        tmpImg = torch.from_numpy(img_input).float().to(DEVICE)
        with torch.no_grad():
            # First side output, squashed to [0, 1].
            # https://github.com/SkyTNT/anime-segmentation/blob/main/train.py#L92C20-L92C47
            pred = self.net(tmpImg)[0][0].sigmoid()
            pred = pred.cpu().numpy()[0]
            return np.transpose(pred, (1, 2, 0))

    def get_mask(self, input_img, s=640):
        """Return a float mask for *input_img* (HWC, 0-255).

        s > 0: letterbox the image into an s x s square, predict, then crop and
        resize the mask back to the original size (returned as HxW).
        s == 0: predict at the image's own size (returned as HxWx1).
        """
        input_img = (input_img / 255).astype(np.float32)
        if s == 0:
            img_input = np.transpose(input_img, (2, 0, 1))[np.newaxis, :]
            return self._predict(img_input)

        h, w = h0, w0 = input_img.shape[:-1]
        # Scale the long side to s, keeping aspect ratio.
        h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
        ph, pw = s - h, s - w
        # Center the resized image inside a black s x s canvas.
        img_input = np.zeros([s, s, 3], dtype=np.float32)
        img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h))
        img_input = np.transpose(img_input, (2, 0, 1))[np.newaxis, :]

        pred = self._predict(img_input)
        # Undo the letterboxing.
        pred = pred[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
        pred = cv2.resize(pred, (w0, h0))
        return pred

    def __call__(self, np_img, img_size):
        """Return (mask, composited image) with the background pushed to white."""
        mask = self.get_mask(np_img, int(img_size))
        np_img = (mask * np_img + 255 * (1 - mask)).astype(np.uint8)
        mask = (mask * 255).astype(np.uint8)
        return mask, np_img
|
||||
|
||||
@@ -0,0 +1,619 @@
|
||||
# Codes are borrowed from
|
||||
# https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import models
|
||||
|
||||
bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
|
||||
|
||||
|
||||
def muti_loss_fusion(preds, target):
    """Sum BCE-with-logits losses of every side output in *preds* vs *target*.

    Side outputs whose spatial size differs from the target are compared
    against a bilinearly resized copy of the target.

    Returns:
        (loss0, loss): loss of the first side output, and the running total.
    """
    loss0 = 0.0
    loss = 0.0

    for idx, pred in enumerate(preds):
        if pred.shape[2] != target.shape[2] or pred.shape[3] != target.shape[3]:
            scaled_target = F.interpolate(
                target, size=pred.size()[2:], mode="bilinear", align_corners=True
            )
            loss = loss + bce_loss(pred, scaled_target)
        else:
            loss = loss + bce_loss(pred, target)
        if idx == 0:
            loss0 = loss
    return loss0, loss
|
||||
|
||||
|
||||
fea_loss = nn.MSELoss(reduction="mean")
|
||||
kl_loss = nn.KLDivLoss(reduction="mean")
|
||||
l1_loss = nn.L1Loss(reduction="mean")
|
||||
smooth_l1_loss = nn.SmoothL1Loss(reduction="mean")
|
||||
|
||||
|
||||
def muti_loss_fusion_kl(preds, target, dfs, fs, mode="MSE"):
    """BCE side-output loss plus a feature-consistency term between dfs and fs.

    *dfs* and *fs* are parallel lists of feature maps; *mode* selects the
    consistency criterion: "MSE", "KL", "MAE" or "SmoothL1".

    Returns:
        (loss0, loss): loss of the first side output, and the grand total.
    """
    loss0 = 0.0
    loss = 0.0

    for idx, pred in enumerate(preds):
        if pred.shape[2] != target.shape[2] or pred.shape[3] != target.shape[3]:
            scaled_target = F.interpolate(
                target, size=pred.size()[2:], mode="bilinear", align_corners=True
            )
            loss = loss + bce_loss(pred, scaled_target)
        else:
            loss = loss + bce_loss(pred, target)
        if idx == 0:
            loss0 = loss

    # Feature-supervision term (additional constraint on intermediate features).
    for df, fs_i in zip(dfs, fs):
        if mode == "MSE":
            loss = loss + fea_loss(df, fs_i)
        elif mode == "KL":
            loss = loss + kl_loss(F.log_softmax(df, dim=1), F.softmax(fs_i, dim=1))
        elif mode == "MAE":
            loss = loss + l1_loss(df, fs_i)
        elif mode == "SmoothL1":
            loss = loss + smooth_l1_loss(df, fs_i)

    return loss0, loss
|
||||
|
||||
|
||||
class REBNCONV(nn.Module):
    """Conv3x3 (optionally dilated) -> BatchNorm -> ReLU block used by every RSU stage."""

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(REBNCONV, self).__init__()
        # Padding tracks the dilation rate, so spatial size is preserved
        # (subject to stride).
        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3,
            padding=1 * dirate, dilation=1 * dirate, stride=stride,
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
|
||||
|
||||
|
||||
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
||||
def _upsample_like(src, tar):
|
||||
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
|
||||
|
||||
return src
|
||||
|
||||
|
||||
### RSU-7 ###
|
||||
class RSU7(nn.Module):
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
||||
super(RSU7, self).__init__()
|
||||
|
||||
self.in_ch = in_ch
|
||||
self.mid_ch = mid_ch
|
||||
self.out_ch = out_ch
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
|
||||
|
||||
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
|
||||
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||
|
||||
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
hx = x
|
||||
hxin = self.rebnconvin(hx)
|
||||
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
|
||||
hx3 = self.rebnconv3(hx)
|
||||
hx = self.pool3(hx3)
|
||||
|
||||
hx4 = self.rebnconv4(hx)
|
||||
hx = self.pool4(hx4)
|
||||
|
||||
hx5 = self.rebnconv5(hx)
|
||||
hx = self.pool5(hx5)
|
||||
|
||||
hx6 = self.rebnconv6(hx)
|
||||
|
||||
hx7 = self.rebnconv7(hx6)
|
||||
|
||||
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
||||
hx6dup = _upsample_like(hx6d, hx5)
|
||||
|
||||
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
||||
hx5dup = _upsample_like(hx5d, hx4)
|
||||
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
||||
hx4dup = _upsample_like(hx4d, hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
### RSU-6 ###
|
||||
class RSU6(nn.Module):
    """Residual U-block, depth 6 (four pooling levels plus a dilated
    bottleneck); output is added to the input projection as a residual."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        # Input projection so the final residual sum matches out_ch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck instead of a further downsample.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convs consume skip concatenations (hence mid_ch * 2 inputs).
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # ---- encoder ----
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        # ---- decoder with skip connections ----
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection.
        return hx1d + hxin
|
||||
|
||||
|
||||
### RSU-5 ###
|
||||
class RSU5(nn.Module):
    """Residual U-block, depth 5 (three pooling levels plus a dilated
    bottleneck); output is added to the input projection as a residual."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        # Input projection so the final residual sum matches out_ch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck instead of a further downsample.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convs consume skip concatenations (hence mid_ch * 2 inputs).
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # ---- encoder ----
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        # ---- decoder with skip connections ----
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection.
        return hx1d + hxin
|
||||
|
||||
|
||||
### RSU-4 ###
|
||||
class RSU4(nn.Module):
    """Residual U-block, depth 4 (two pooling levels plus a dilated
    bottleneck); output is added to the input projection as a residual."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        # Input projection so the final residual sum matches out_ch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck instead of a further downsample.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convs consume skip concatenations (hence mid_ch * 2 inputs).
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # ---- encoder ----
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        # ---- decoder with skip connections ----
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection.
        return hx1d + hxin
|
||||
|
||||
|
||||
### RSU-4F ###
|
||||
class RSU4F(nn.Module):
    """Dilation-only RSU-4 variant: keeps full resolution throughout and
    widens the receptive field with dilation rates 1/2/4/8 instead of pooling."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        # Input projection so the final residual sum matches out_ch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder convs consume skip concatenations (hence mid_ch * 2 inputs).
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        # No upsampling needed: every feature is already at input resolution.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        # Residual connection.
        return hx1d + hxin
|
||||
|
||||
|
||||
class myrebnconv(nn.Module):
    """Plain Conv2d -> BatchNorm -> ReLU block with fully configurable conv params."""

    def __init__(
        self,
        in_ch=3,
        out_ch=1,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    ):
        super(myrebnconv, self).__init__()

        self.conv = nn.Conv2d(
            in_ch, out_ch,
            kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.rl(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
class ISNetGTEncoder(nn.Module):
    """Encoder over (presumably ground-truth) masks producing six side-output
    logit maps plus the per-stage features.

    NOTE(review): per the sibling anime_segmentation loader comment, the
    gt_encoder weights are skipped at inference — this module appears to be
    training-only; confirm before relying on it at inference time.
    """

    def __init__(self, in_ch=1, out_ch=1):
        super(ISNetGTEncoder, self).__init__()

        self.conv_in = myrebnconv(
            in_ch, 16, 3, stride=2, padding=1
        )  # nn.Conv2d(in_ch,64,3,stride=2,padding=1)

        self.stage1 = RSU7(16, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 32, 128)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(128, 32, 256)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(256, 64, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 64, 512)

        # One 3x3 conv head per stage producing a side-output logit map.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    @staticmethod
    def compute_loss(args):
        # args: (preds, targets); delegates to the module-level fusion loss.
        preds, targets = args
        return muti_loss_fusion(preds, targets)

    def forward(self, x):
        hx = x

        hxin = self.conv_in(hx)
        # hx = self.pool_in(hxin)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)

        # side output: each stage's logits upsampled to the input resolution
        d1 = self.side1(hx1)
        d1 = _upsample_like(d1, x)

        d2 = self.side2(hx2)
        d2 = _upsample_like(d2, x)

        d3 = self.side3(hx3)
        d3 = _upsample_like(d3, x)

        d4 = self.side4(hx4)
        d4 = _upsample_like(d4, x)

        d5 = self.side5(hx5)
        d5 = _upsample_like(d5, x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, x)

        # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6]
        return [d1, d2, d3, d4, d5, d6], [hx1, hx2, hx3, hx4, hx5, hx6]
|
||||
|
||||
|
||||
class ISNetDIS(nn.Module):
    """ISNet (Dichotomous Image Segmentation): a U^2-Net-style encoder/decoder
    built from RSU blocks.

    forward() returns six side-output logit maps (each upsampled to the input
    size) and the intermediate decoder/encoder features used by the
    feature-supervision losses (see compute_loss_kl).
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(ISNetDIS, self).__init__()

        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # One 3x3 conv head per decoder stage producing a side-output logit map.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    @staticmethod
    def compute_loss_kl(preds, targets, dfs, fs, mode="MSE"):
        """BCE side-output loss plus a feature-consistency term; see muti_loss_fusion_kl."""
        return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)

    @staticmethod
    def compute_loss(args):
        # Accepts either (ds, dfs, labels) or (ds, dfs, labels, fs);
        # the 4-tuple form adds the feature-supervision term.
        if len(args) == 3:
            ds, dfs, labels = args
            return muti_loss_fusion(ds, labels)
        else:
            ds, dfs, labels, fs = args
            return muti_loss_fusion_kl(ds, labels, dfs, fs, mode="MSE")

    def forward(self, x):
        hx = x

        hxin = self.conv_in(hx)
        # NOTE(review): the pooled tensor below is overwritten before use —
        # stage1 consumes hxin directly. Kept as-is to match the reference code.
        hx = self.pool_in(hxin)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output: each decoder stage's logits upsampled to the input size
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, x)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, x)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, x)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, x)

        # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
        return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
|
||||
@@ -0,0 +1,100 @@
|
||||
#https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/network.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from custom_controlnet_aux.util import custom_torch_download
|
||||
|
||||
class UNet(nn.Module):
    """MobileNetV2-backed U-Net predicting a 7-class anime-face segmentation
    map (background, hair, face, eye, mouth, skin, clothes); the output is a
    Softmax2d probability map at the input's spatial size."""

    def __init__(self):
        super(UNet, self).__init__()
        self.NUM_SEG_CLASSES = 7  # Background, hair, face, eye, mouth, skin, clothes

        # Build an uninitialized MobileNetV2 and load locally-cached weights.
        # (Avoids the deprecated `pretrained=` kwarg — weights=None is the
        # default — and any download through torchvision itself.)
        mobilenet_v2 = torchvision.models.mobilenet_v2()
        mobilenet_v2.load_state_dict(torch.load(custom_torch_download(filename="mobilenet_v2-b0353104.pth")), strict=True)
        mob_blocks = mobilenet_v2.features

        # Encoder: reused MobileNetV2 feature blocks.
        self.en_block0 = nn.Sequential(  # in_ch=3 out_ch=16
            mob_blocks[0],
            mob_blocks[1]
        )
        self.en_block1 = nn.Sequential(  # in_ch=16 out_ch=24
            mob_blocks[2],
            mob_blocks[3],
        )
        self.en_block2 = nn.Sequential(  # in_ch=24 out_ch=32
            mob_blocks[4],
            mob_blocks[5],
            mob_blocks[6],
        )
        self.en_block3 = nn.Sequential(  # in_ch=32 out_ch=96
            mob_blocks[7],
            mob_blocks[8],
            mob_blocks[9],
            mob_blocks[10],
            mob_blocks[11],
            mob_blocks[12],
            mob_blocks[13],
        )
        self.en_block4 = nn.Sequential(  # in_ch=96 out_ch=160
            mob_blocks[14],
            mob_blocks[15],
            mob_blocks[16],
        )

        # Decoder: nearest-neighbour upsample + conv; skip concat doubles in_ch.
        self.de_block4 = nn.Sequential(  # in_ch=160 out_ch=96
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(160, 96, kernel_size=3, padding=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2)
        )
        self.de_block3 = nn.Sequential(  # in_ch=96x2 out_ch=32
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(96 * 2, 32, kernel_size=3, padding=1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2)
        )
        self.de_block2 = nn.Sequential(  # in_ch=32x2 out_ch=24
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(32 * 2, 24, kernel_size=3, padding=1),
            nn.InstanceNorm2d(24),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2)
        )
        self.de_block1 = nn.Sequential(  # in_ch=24x2 out_ch=16
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(24 * 2, 16, kernel_size=3, padding=1),
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2)
        )

        self.de_block0 = nn.Sequential(  # in_ch=16x2 out_ch=7
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(16 * 2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1),
            nn.Softmax2d()
        )

    def forward(self, x):
        # Encoder pass; intermediate features are kept for the skip connections.
        e0 = self.en_block0(x)
        e1 = self.en_block1(e0)
        e2 = self.en_block2(e1)
        e3 = self.en_block3(e2)
        e4 = self.en_block4(e3)

        # Decoder pass with U-Net style concatenation of encoder features.
        d4 = self.de_block4(e4)
        c4 = torch.cat((d4, e3), 1)
        d3 = self.de_block3(c4)
        c3 = torch.cat((d3, e2), 1)
        d2 = self.de_block2(c3)
        c2 = torch.cat((d2, e1), 1)
        d1 = self.de_block1(c2)
        c1 = torch.cat((d1, e0), 1)
        y = self.de_block0(c1)

        return y
|
||||
@@ -0,0 +1,40 @@
|
||||
#https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/util.py
|
||||
#The color palette is changed according to https://github.com/Mikubill/sd-webui-controlnet/blob/91f67ddcc7bc47537a6285864abfc12590f46c3f/annotator/anime_face_segment/__init__.py
|
||||
import cv2 as cv
|
||||
import glob
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
"""
|
||||
COLOR_BACKGROUND = (0,255,255)
|
||||
COLOR_HAIR = (255,0,0)
|
||||
COLOR_EYE = (0,0,255)
|
||||
COLOR_MOUTH = (255,255,255)
|
||||
COLOR_FACE = (0,255,0)
|
||||
COLOR_SKIN = (255,255,0)
|
||||
COLOR_CLOTHES = (255,0,255)
|
||||
"""
|
||||
COLOR_BACKGROUND = (255,255,0)
|
||||
COLOR_HAIR = (0,0,255)
|
||||
COLOR_EYE = (255,0,0)
|
||||
COLOR_MOUTH = (255,255,255)
|
||||
COLOR_FACE = (0,255,0)
|
||||
COLOR_SKIN = (0,255,255)
|
||||
COLOR_CLOTHES = (255,0,255)
|
||||
PALETTE = [COLOR_BACKGROUND,COLOR_HAIR,COLOR_EYE,COLOR_MOUTH,COLOR_FACE,COLOR_SKIN,COLOR_CLOTHES]
|
||||
|
||||
def img2seg(path):
    """Load a palette-coloured segmentation image and one-hot encode it.

    Fix: the output shape used to be hard-coded to (512, 512, 7); it is now
    derived from the loaded image and PALETTE, so any image size works.

    Args:
        path: filesystem path to an image whose pixels use PALETTE colours.

    Returns:
        float32 array of shape (H, W, len(PALETTE)): 1.0 in the channel of
        the matching palette colour, 0.0 elsewhere.
    """
    src = cv.imread(path)
    # NOTE(review): cv.imread yields BGR pixels; PALETTE entries are assumed
    # to be in the same channel order — confirm with the map producer.
    h, w = src.shape[:2]
    flat = src.reshape(-1, 3)
    # One channel per palette colour: exact-match mask per pixel.
    seg_list = [np.where(np.all(flat == color, axis=1), 1.0, 0.0) for color in PALETTE]
    dst = np.stack(seg_list, axis=1).reshape(h, w, len(PALETTE))
    return dst.astype(np.float32)
|
||||
|
||||
def seg2img(src):
    """Convert a (C, H, W) segmentation volume back to a colour image.

    The highest-scoring channel at each pixel selects its PALETTE colour.
    Fix: replaced the per-pixel nested Python list comprehension with an
    equivalent vectorized argmax + fancy-indexing lookup (same output,
    numpy-speed instead of an interpreted O(H*W) loop).
    """
    labels = np.argmax(src, axis=0)              # (H, W) palette indices
    palette = np.array(PALETTE, dtype=np.uint8)  # (len(PALETTE), 3)
    return palette[labels]                       # (H, W, 3) uint8
|
||||
@@ -0,0 +1,38 @@
|
||||
import warnings
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from custom_controlnet_aux.util import HWC3, resize_image_with_pad
|
||||
|
||||
class BinaryDetector:
    """Thresholds an image into a black/white (binary) control map.

    A ``bin_threshold`` of 0 or 255 switches to Otsu's automatic threshold.
    """

    def __call__(self, input_image=None, bin_threshold=0, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        # Legacy keyword support: `img` used to be the input argument name.
        if "img" in kwargs:
            warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
            input_image = kwargs.pop("img")

        if input_image is None:
            raise ValueError("input_image must be defined.")

        # PIL input is converted to ndarray; unless overridden, the output
        # type follows the input type.
        if isinstance(input_image, np.ndarray):
            output_type = output_type or "np"
        else:
            input_image = np.array(input_image, dtype=np.uint8)
            output_type = output_type or "pil"

        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        gray = cv2.cvtColor(detected_map, cv2.COLOR_RGB2GRAY)

        if bin_threshold in (0, 255):
            # Degenerate thresholds request Otsu's automatic threshold.
            otsu_threshold, binarized = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
            print("Otsu threshold:", otsu_threshold)
        else:
            _, binarized = cv2.threshold(gray, bin_threshold, 255, cv2.THRESH_BINARY_INV)

        # Back to RGB, undo the padding, and invert (edges end up white).
        detected_map = HWC3(remove_pad(255 - cv2.cvtColor(binarized, cv2.COLOR_GRAY2RGB)))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
|
||||
@@ -0,0 +1,17 @@
|
||||
import warnings
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3
|
||||
|
||||
class CannyDetector:
    """Canny edge-detection preprocessor producing an edge control map."""

    def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        padded, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        edges = cv2.Canny(padded, low_threshold, high_threshold)
        detected_map = HWC3(remove_pad(edges))
        if output_type == "pil":
            return Image.fromarray(detected_map)
        return detected_map
|
||||
@@ -0,0 +1,37 @@
|
||||
import cv2
|
||||
import warnings
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from custom_controlnet_aux.util import HWC3, safer_memory, common_input_validate
|
||||
|
||||
def cv2_resize_shortest_edge(image, size):
    """Scale so the shorter image edge equals ``size``, preserving aspect ratio."""
    h, w = image.shape[:2]
    if h < w:
        new_h, new_w = size, int(round(w / h * size))
    else:
        new_h, new_w = int(round(h / w * size)), size
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
||||
|
||||
def apply_color(img, res=512):
    """Produce a coarse colour map: shrink to 1/64 scale, then blow back up.

    Fix: the intermediate size is clamped to at least 1x1, so resolutions
    below 64 px no longer crash cv2.resize with a zero-sized target.
    """
    img = cv2_resize_shortest_edge(img, res)
    h, w = img.shape[:2]

    coarse_w, coarse_h = max(w // 64, 1), max(h // 64, 1)
    coarse = cv2.resize(img, (coarse_w, coarse_h), interpolation=cv2.INTER_CUBIC)
    # Nearest-neighbour upscale keeps the blocky colour-patch look.
    return cv2.resize(coarse, (w, h), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
#Color T2I like multiples-of-64, upscale methods are fixed.
|
||||
class ColorDetector:
    """Colour-palette preprocessor (T2I 'color' control map)."""

    def __call__(self, input_image=None, detect_resolution=512, output_type=None, **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map = HWC3(apply_color(HWC3(input_image), detect_resolution))
        if output_type == "pil":
            return Image.fromarray(detected_map)
        return detected_map
|
||||
@@ -0,0 +1,66 @@
|
||||
import torchvision # Fix issue Unknown builtin op: torchvision::nms
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, DENSEPOSE_MODEL_NAME
|
||||
from .densepose import DensePoseMaskedColormapResultsVisualizer, _extract_i_from_iuvarr, densepose_chart_predictor_output_to_result_with_confidences
|
||||
|
||||
N_PART_LABELS = 24
|
||||
|
||||
class DenseposeDetector:
    """DensePose body-surface estimator producing colormapped hint images."""

    def __init__(self, model):
        self.dense_pose_estimation = model
        self.device = "cpu"
        self.result_visualizer = DensePoseMaskedColormapResultsVisualizer(
            alpha=1,
            data_extractor=_extract_i_from_iuvarr,
            segm_extractor=_extract_i_from_iuvarr,
            val_scale=255.0 / N_PART_LABELS,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path=DENSEPOSE_MODEL_NAME, filename="densepose_r50_fpn_dl.torchscript"):
        """Download (if needed) and load the TorchScript DensePose model."""
        torchscript_model_path = custom_hf_download(pretrained_model_or_path, filename)
        return cls(torch.jit.load(torchscript_model_path, map_location="cpu"))

    def to(self, device):
        self.dense_pose_estimation.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", cmap="viridis", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        height, width = input_image.shape[:2]

        # Black RGB canvas the visualizer draws the detections onto.
        canvas = np.zeros([height, width], dtype=np.uint8)
        canvas = np.tile(canvas[:, :, np.newaxis], [1, 1, 3])

        model_input = rearrange(torch.from_numpy(input_image).to(self.device), 'h w c -> c h w')
        pred_boxes, coarse_segm, fine_segm, u, v = self.dense_pose_estimation(model_input)

        # Convert each detection's raw outputs into (box, labels, uv).
        extractor = densepose_chart_predictor_output_to_result_with_confidences
        densepose_results = [
            extractor(pred_boxes[i:i + 1], coarse_segm[i:i + 1], fine_segm[i:i + 1], u[i:i + 1], v[i:i + 1])
            for i in range(len(pred_boxes))
        ]

        if cmap == "viridis":
            self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_VIRIDIS
            hint_image = self.result_visualizer.visualize(canvas, densepose_results)
            hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)
            # Recolour pure-black background to viridis's bottom colour (68, 1, 84).
            hint_image[:, :, 0][hint_image[:, :, 0] == 0] = 68
            hint_image[:, :, 1][hint_image[:, :, 1] == 0] = 1
            hint_image[:, :, 2][hint_image[:, :, 2] == 0] = 84
        else:
            self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_PARULA
            hint_image = self.result_visualizer.visualize(canvas, densepose_results)
            hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)

        detected_map = remove_pad(HWC3(hint_image))
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map
|
||||
@@ -0,0 +1,347 @@
|
||||
from typing import Tuple
|
||||
import math
|
||||
import numpy as np
|
||||
from enum import IntEnum
|
||||
from typing import List, Tuple, Union
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import logging
|
||||
import cv2
|
||||
|
||||
# Lightweight type aliases used throughout this vendored DensePose module.
Image = np.ndarray                  # an image as an ndarray
Boxes = torch.Tensor                # N x 4 box tensor
ImageSizeType = Tuple[int, int]
_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
IntTupleBox = Tuple[int, int, int, int]
||||
|
||||
class BoxMode(IntEnum):
    """
    Enum of different ways to represent a box.
    """

    XYXY_ABS = 0
    """
    (x0, y0, x1, y1) in absolute floating points coordinates.
    The coordinates in range [0, width or height].
    """
    XYWH_ABS = 1
    """
    (x0, y0, w, h) in absolute floating points coordinates.
    """
    XYXY_REL = 2
    """
    Not yet supported!
    (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
    """
    XYWH_REL = 3
    """
    Not yet supported!
    (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
    """
    XYWHA_ABS = 4
    """
    (xc, yc, w, h, a) in absolute floating points coordinates.
    (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
    """

    @staticmethod
    def convert(box: "_RawBoxType", from_mode: "BoxMode", to_mode: "BoxMode") -> "_RawBoxType":
        """
        Args:
            box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
            from_mode, to_mode (BoxMode)

        Returns:
            The converted box of the same type.
        """
        if from_mode == to_mode:
            return box

        original_type = type(box)
        is_numpy = isinstance(box, np.ndarray)
        single_box = isinstance(box, (list, tuple))
        if single_box:
            assert len(box) in (4, 5), (
                "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
                " where k == 4 or 5"
            )
            arr = torch.tensor(box)[None, :]
        elif is_numpy:
            # Copy so the caller's array is never mutated.
            arr = torch.from_numpy(np.asarray(box)).clone()
        else:
            arr = box.clone()

        assert to_mode not in [BoxMode.XYXY_REL, BoxMode.XYWH_REL] and from_mode not in [
            BoxMode.XYXY_REL,
            BoxMode.XYWH_REL,
        ], "Relative mode not yet supported!"

        if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
            assert arr.shape[-1] == 5, "The last dimension of input shape must be 5 for XYWHA format"
            original_dtype = arr.dtype
            arr = arr.double()

            w = arr[:, 2]
            h = arr[:, 3]
            a = arr[:, 4]
            c = torch.abs(torch.cos(a * math.pi / 180.0))
            s = torch.abs(torch.sin(a * math.pi / 180.0))
            # Horizontal bounding rectangle of the rotated box.
            new_w = c * w + s * h
            new_h = c * h + s * w

            # Centre -> top-left corner, then derive bottom-right.
            arr[:, 0] -= new_w / 2.0
            arr[:, 1] -= new_h / 2.0
            arr[:, 2] = arr[:, 0] + new_w
            arr[:, 3] = arr[:, 1] + new_h

            arr = arr[:, :4].to(dtype=original_dtype)
        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
            original_dtype = arr.dtype
            arr = arr.double()
            # Top-left -> centre; angle column is zero (axis-aligned).
            arr[:, 0] += arr[:, 2] / 2.0
            arr[:, 1] += arr[:, 3] / 2.0
            angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
            arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)
        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYXY_ABS:
            arr[:, 2] += arr[:, 0]
            arr[:, 3] += arr[:, 1]
        elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
            arr[:, 2] -= arr[:, 0]
            arr[:, 3] -= arr[:, 1]
        else:
            raise NotImplementedError(
                "Conversion from BoxMode {} to {} is not supported yet".format(
                    from_mode, to_mode
                )
            )

        if single_box:
            return original_type(arr.flatten().tolist())
        return arr.numpy() if is_numpy else arr
|
||||
|
||||
class MatrixVisualizer:
    """
    Base visualizer for matrix data
    """

    def __init__(
        self,
        inplace=True,
        cmap=cv2.COLORMAP_PARULA,
        val_scale=1.0,
        alpha=0.7,
        interp_method_matrix=cv2.INTER_LINEAR,
        interp_method_mask=cv2.INTER_NEAREST,
    ):
        # Rendering configuration; `inplace` draws onto the caller's image.
        self.inplace = inplace
        self.cmap = cmap
        self.val_scale = val_scale
        self.alpha = alpha
        self.interp_method_matrix = interp_method_matrix
        self.interp_method_mask = interp_method_mask

    def visualize(self, image_bgr, mask, matrix, bbox_xywh):
        """Blend a colormapped `matrix` (where `mask` is set) into `image_bgr`."""
        self._check_image(image_bgr)
        self._check_mask_matrix(mask, matrix)
        image_target_bgr = image_bgr if self.inplace else image_bgr * 0
        x, y, w, h = (int(v) for v in bbox_xywh)
        if w <= 0 or h <= 0:
            return image_bgr
        mask, matrix = self._resize(mask, matrix, w, h)
        mask_bg = np.tile((mask == 0)[:, :, np.newaxis], [1, 1, 3])
        matrix_scaled = matrix.astype(np.float32) * self.val_scale
        _EPSILON = 1e-6
        if np.any(matrix_scaled > 255 + _EPSILON):
            logger = logging.getLogger(__name__)
            logger.warning(
                f"Matrix has values > {255 + _EPSILON} after scaling, clipping to [0..255]"
            )
        matrix_scaled_8u = matrix_scaled.clip(0, 255).astype(np.uint8)
        matrix_vis = cv2.applyColorMap(matrix_scaled_8u, self.cmap)
        # Keep the original pixels wherever the mask is empty.
        matrix_vis[mask_bg] = image_target_bgr[y : y + h, x : x + w, :][mask_bg]
        image_target_bgr[y : y + h, x : x + w, :] = (
            image_target_bgr[y : y + h, x : x + w, :] * (1.0 - self.alpha) + matrix_vis * self.alpha
        )
        return image_target_bgr.astype(np.uint8)

    def _resize(self, mask, matrix, w, h):
        # Resize each plane only when its size differs from the target box.
        if (w, h) != (mask.shape[1], mask.shape[0]):
            mask = cv2.resize(mask, (w, h), self.interp_method_mask)
        if (w, h) != (matrix.shape[1], matrix.shape[0]):
            matrix = cv2.resize(matrix, (w, h), self.interp_method_matrix)
        return mask, matrix

    def _check_image(self, image_rgb):
        assert len(image_rgb.shape) == 3
        assert image_rgb.shape[2] == 3
        assert image_rgb.dtype == np.uint8

    def _check_mask_matrix(self, mask, matrix):
        assert len(matrix.shape) == 2
        assert len(mask.shape) == 2
        assert mask.dtype == np.uint8
|
||||
|
||||
class DensePoseResultsVisualizer:
    """Template for rendering DensePose results onto a BGR image.

    Subclasses override the context hooks and `visualize_iuv_arr`.
    """

    def visualize(
        self,
        image_bgr: "Image",
        results,
    ) -> "Image":
        context = self.create_visualization_context(image_bgr)
        for result in results:
            boxes_xywh, labels, uv = result
            # Pack part labels (I) and scaled UV into one uint8 IUV array.
            iuv_array = torch.cat(
                (labels[None].type(torch.float32), uv * 255.0)
            ).type(torch.uint8)
            self.visualize_iuv_arr(context, iuv_array.cpu().numpy(), boxes_xywh)
        image_bgr = self.context_to_image_bgr(context)
        return image_bgr

    def create_visualization_context(self, image_bgr: "Image"):
        return image_bgr

    def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
        pass

    def context_to_image_bgr(self, context):
        return context

    def get_image_bgr_from_context(self, context):
        return context
|
||||
|
||||
class DensePoseMaskedColormapResultsVisualizer(DensePoseResultsVisualizer):
    """Renders a DensePose IUV plane through a masked MatrixVisualizer."""

    def __init__(
        self,
        data_extractor,
        segm_extractor,
        inplace=True,
        cmap=cv2.COLORMAP_PARULA,
        alpha=0.7,
        val_scale=1.0,
        **kwargs,
    ):
        self.mask_visualizer = MatrixVisualizer(
            inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha
        )
        # Callables extracting the matrix / segmentation planes from an IUV array.
        self.data_extractor = data_extractor
        self.segm_extractor = segm_extractor

    def context_to_image_bgr(self, context):
        return context

    def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
        image_bgr = self.get_image_bgr_from_context(context)
        matrix = self.data_extractor(iuv_arr)
        segm = self.segm_extractor(iuv_arr)
        # Visible wherever the segmentation plane is non-zero.
        mask = np.zeros(matrix.shape, dtype=np.uint8)
        mask[segm > 0] = 1
        # MatrixVisualizer draws onto `image_bgr` in place (inplace=True above).
        self.mask_visualizer.visualize(image_bgr, mask, matrix, bbox_xywh)
|
||||
|
||||
|
||||
def _extract_i_from_iuvarr(iuv_arr):
|
||||
return iuv_arr[0, :, :]
|
||||
|
||||
|
||||
def _extract_u_from_iuvarr(iuv_arr):
|
||||
return iuv_arr[1, :, :]
|
||||
|
||||
|
||||
def _extract_v_from_iuvarr(iuv_arr):
|
||||
return iuv_arr[2, :, :]
|
||||
|
||||
def make_int_box(box: torch.Tensor) -> "IntTupleBox":
    """Truncate a 4-element box tensor to a tuple of Python ints."""
    x0, y0, x1, y1 = box.long().tolist()
    return x0, y0, x1, y1
|
||||
|
||||
def densepose_chart_predictor_output_to_result_with_confidences(
    boxes: "Boxes",
    coarse_segm,
    fine_segm,
    u, v

):
    """Resample raw DensePose predictor outputs onto the first predicted box.

    Args:
        boxes: 1 x 4 tensor of XYXY_ABS box coordinates.
        coarse_segm: [1, K, H, W] coarse segmentation scores.
        fine_segm: [1, C, H, W] fine segmentation scores.
        u, v: [1, C, H, W] UV coordinate estimates.

    Returns:
        (box_xywh, labels, uv): integer XYWH box, per-pixel part labels, and
        the resampled [2, h, w] UV tensor for that box.

    Fix: dropped the unused `confidences = []` local — despite the name, this
    function never computed or returned confidences.
    """
    boxes_xyxy_abs = boxes.clone()
    boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
    box_xywh = make_int_box(boxes_xywh_abs[0])

    labels = resample_fine_and_coarse_segm_tensors_to_bbox(fine_segm, coarse_segm, box_xywh).squeeze(0)
    uv = resample_uv_tensors_to_bbox(u, v, labels, box_xywh)
    return box_xywh, labels, uv
|
||||
|
||||
def resample_fine_and_coarse_segm_tensors_to_bbox(
    fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: "IntTupleBox"
):
    """
    Resample fine and coarse segmentation tensors to the given
    bounding box and derive labels for each pixel of the bounding box

    Args:
        fine_segm: float tensor of shape [1, C, Hout, Wout]
        coarse_segm: float tensor of shape [1, K, Hout, Wout]
        box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
            corner coordinates, width (W) and height (H)
    Return:
        Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
    """
    x, y, w, h = box_xywh_abs
    # Guard against degenerate boxes: interpolate needs a size of at least 1.
    w, h = max(int(w), 1), max(int(h), 1)
    # Coarse foreground mask: which pixels belong to a person at all.
    coarse_bbox = F.interpolate(
        coarse_segm,
        (h, w),
        mode="bilinear",
        align_corners=False,
    ).argmax(dim=1)
    # Fine part labels, zeroed wherever the coarse mask says background.
    fine_labels = F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
    return fine_labels * (coarse_bbox > 0).long()
|
||||
|
||||
def resample_uv_tensors_to_bbox(
    u: torch.Tensor,
    v: torch.Tensor,
    labels: torch.Tensor,
    box_xywh_abs: "IntTupleBox",
) -> torch.Tensor:
    """
    Resamples U and V coordinate estimates for the given bounding box

    Args:
        u (tensor [1, C, H, W] of float): U coordinates
        v (tensor [1, C, H, W] of float): V coordinates
        labels (tensor [H, W] of long): labels obtained by resampling segmentation
            outputs for the given bounding box
        box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
    Return:
        Resampled U and V coordinates - a tensor [2, H, W] of float
    """
    x, y, w, h = box_xywh_abs
    # Guard against degenerate boxes: interpolate needs a size of at least 1.
    w, h = max(int(w), 1), max(int(h), 1)
    u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
    v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
    uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
    # Copy each body part's UV values only where that part's label is assigned.
    for part_id in range(1, u_bbox.size(1)):
        part_mask = labels == part_id
        uv[0][part_mask] = u_bbox[0, part_id][part_mask]
        uv[1][part_mask] = v_bbox[0, part_id][part_mask]
    return uv
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .transformers import DepthAnythingDetector
|
||||
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Modern DepthAnything implementation using HuggingFace transformers.
|
||||
Replaces legacy torch.hub.load DINOv2 backbone with transformers pipeline.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import pipeline
|
||||
|
||||
from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad
|
||||
|
||||
class DepthAnythingDetector:
    """DepthAnything depth estimation backed by a HuggingFace transformers pipeline."""

    def __init__(self, model_name="LiheYoung/depth-anything-large-hf"):
        """Build the depth-estimation pipeline for `model_name`."""
        self.pipe = pipeline(task="depth-estimation", model=model_name)
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path=None, filename="depth_anything_vitl14.pth"):
        """Create the detector, translating legacy checkpoint names to hub model ids."""
        legacy_to_hub = {
            "depth_anything_vitl14.pth": "LiheYoung/depth-anything-large-hf",
            "depth_anything_vitb14.pth": "LiheYoung/depth-anything-base-hf",
            "depth_anything_vits14.pth": "LiheYoung/depth-anything-small-hf",
        }
        # Unknown filenames fall back to the large model.
        return cls(model_name=legacy_to_hub.get(filename, "LiheYoung/depth-anything-large-hf"))

    def to(self, device):
        """Move the underlying model to `device`."""
        self.pipe.model = self.pipe.model.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        """Estimate depth and return it as a normalized uint8 map."""
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        # The pipeline consumes PIL images.
        pil_image = Image.fromarray(input_image) if isinstance(input_image, np.ndarray) else input_image

        with torch.no_grad():
            depth = self.pipe(pil_image)["depth"]

        if isinstance(depth, Image.Image):
            depth_array = np.array(depth, dtype=np.float32)
        else:
            depth_array = np.array(depth)

        # Normalize to [0, 255]; a constant map maps to all zeros.
        lo, hi = depth_array.min(), depth_array.max()
        if hi > lo:
            depth_array = (depth_array - lo) / (hi - lo) * 255.0
        else:
            depth_array = np.zeros_like(depth_array)

        detected_map = remove_pad(HWC3(depth_array.astype(np.uint8)))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
|
||||
@@ -0,0 +1,56 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import repeat
|
||||
from PIL import Image
|
||||
from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, DEPTH_ANYTHING_V2_MODEL_NAME_DICT
|
||||
from custom_controlnet_aux.depth_anything_v2.dpt import DepthAnythingV2
|
||||
import cv2
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# https://github.com/DepthAnything/Depth-Anything-V2/blob/main/app.py
# Architecture hyper-parameters for each released Depth Anything V2
# checkpoint, keyed by checkpoint filename.
model_configs = {
    'depth_anything_v2_vits.pth': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'depth_anything_v2_vitb.pth': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'depth_anything_v2_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'depth_anything_v2_vitg.pth': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
    # The metric (absolute-depth) fine-tunes reuse the ViT-L architecture.
    'depth_anything_v2_metric_vkitti_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'depth_anything_v2_metric_hypersim_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}
|
||||
|
||||
class DepthAnythingV2Detector:
    """Depth Anything V2 monocular depth estimator."""

    def __init__(self, model, filename):
        self.model = model
        self.device = "cpu"
        # Kept so __call__ can tell metric checkpoints (inverted maps) apart.
        self.filename = filename

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path=None, filename="depth_anything_v2_vits.pth"):
        """Download (if needed) and load a Depth Anything V2 checkpoint."""
        if pretrained_model_or_path is None:
            pretrained_model_or_path = DEPTH_ANYTHING_V2_MODEL_NAME_DICT[filename]
        model_path = custom_hf_download(pretrained_model_or_path, filename)
        model = DepthAnythingV2(**model_configs[filename])
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        model = model.eval()
        return cls(model, filename)

    def to(self, device):
        self.model.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", max_depth=20.0, **kwargs):
        """Estimate depth and return a normalized uint8 depth map (HWC).

        Fix: a constant (flat) depth map previously divided by zero during
        normalization and yielded a NaN-filled result; it now maps to all
        zeros, matching the transformers-based DepthAnythingDetector.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)

        depth = self.model.infer_image(cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR), input_size=518, max_depth=max_depth)
        depth_min, depth_max = depth.min(), depth.max()
        if depth_max > depth_min:
            depth = (depth - depth_min) / (depth_max - depth_min) * 255.0
        else:
            depth = np.zeros_like(depth)
        depth = depth.astype(np.uint8)
        # Metric checkpoints output distance (near = small); invert so near
        # areas are bright, consistent with the relative models.
        if 'metric' in self.filename:
            depth = 255 - depth

        detected_map = repeat(depth, "h w -> h w 3")
        detected_map, remove_pad = resize_image_with_pad(detected_map, detect_resolution, upscale_method)
        detected_map = remove_pad(detected_map)
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
|
||||
@@ -0,0 +1,415 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
import logging
|
||||
from typing import Sequence, Tuple, Union, Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn.init import trunc_normal_
|
||||
|
||||
from custom_controlnet_aux.depth_anything_v2.dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Apply ``fn(module=..., name=...)`` over a module tree.

    With ``depth_first`` children are visited before the module itself;
    ``include_root`` controls whether ``fn`` runs on ``module`` too.
    Returns the (possibly mutated) root module.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        # Children always receive include_root=True so fn visits every node.
        named_apply(fn=fn, module=child_module, name=qualified, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module
|
||||
|
||||
|
||||
class BlockChunk(nn.ModuleList):
    """A ModuleList that is itself callable: applies its blocks in sequence."""

    def forward(self, x):
        for block in self:
            x = block(x)
        return x
|
||||
|
||||
|
||||
class DinoVisionTransformer(nn.Module):
|
||||
def __init__(
    self,
    img_size=224,
    patch_size=16,
    in_chans=3,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4.0,
    qkv_bias=True,
    ffn_bias=True,
    proj_bias=True,
    drop_path_rate=0.0,
    drop_path_uniform=False,
    init_values=None,  # for layerscale: None or 0 => no layerscale
    embed_layer=PatchEmbed,
    act_layer=nn.GELU,
    block_fn=Block,
    ffn_layer="mlp",
    block_chunks=1,
    num_register_tokens=0,
    interpolate_antialias=False,
    interpolate_offset=0.1,
):
    """
    Args:
        img_size (int, tuple): input image size
        patch_size (int, tuple): patch size
        in_chans (int): number of input channels
        embed_dim (int): embedding dimension
        depth (int): depth of transformer
        num_heads (int): number of attention heads
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim
        qkv_bias (bool): enable bias for qkv if True
        proj_bias (bool): enable bias for proj in attn if True
        ffn_bias (bool): enable bias for ffn if True
        drop_path_rate (float): stochastic depth rate
        drop_path_uniform (bool): apply uniform drop rate across blocks
        init_values (float): layer-scale init values
        embed_layer (nn.Module): patch embedding layer
        act_layer (nn.Module): MLP activation layer
        block_fn (nn.Module): transformer block class
        ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
        block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
        num_register_tokens: (int) number of extra cls tokens (so-called "registers")
        interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
        interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
    """
    super().__init__()
    norm_layer = partial(nn.LayerNorm, eps=1e-6)

    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
    self.num_tokens = 1
    self.n_blocks = depth
    self.num_heads = num_heads
    self.patch_size = patch_size
    self.num_register_tokens = num_register_tokens
    self.interpolate_antialias = interpolate_antialias
    self.interpolate_offset = interpolate_offset

    self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
    num_patches = self.patch_embed.num_patches

    # Learnable class token and positional embeddings (+1 for the cls token).
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
    assert num_register_tokens >= 0
    self.register_tokens = (
        nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
    )

    # Stochastic depth: either one uniform rate, or a linear decay schedule.
    if drop_path_uniform is True:
        dpr = [drop_path_rate] * depth
    else:
        dpr = [rate.item() for rate in torch.linspace(0, drop_path_rate, depth)]

    if ffn_layer == "mlp":
        logger.info("using MLP layer as FFN")
        ffn_layer = Mlp
    elif ffn_layer in ("swiglufused", "swiglu"):
        logger.info("using SwiGLU layer as FFN")
        ffn_layer = SwiGLUFFNFused
    elif ffn_layer == "identity":
        logger.info("using Identity layer as FFN")

        def f(*args, **kwargs):
            return nn.Identity()

        ffn_layer = f
    else:
        raise NotImplementedError

    blocks_list = [
        block_fn(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            drop_path=dpr[i],
            norm_layer=norm_layer,
            act_layer=act_layer,
            ffn_layer=ffn_layer,
            init_values=init_values,
        )
        for i in range(depth)
    ]
    if block_chunks > 0:
        # Pad each chunk with Identity modules so block indices stay
        # consistent when the list is chunked for FSDP wrapping.
        self.chunked_blocks = True
        chunked_blocks = []
        chunksize = depth // block_chunks
        for i in range(0, depth, chunksize):
            chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
        self.blocks = nn.ModuleList(BlockChunk(chunk) for chunk in chunked_blocks)
    else:
        self.chunked_blocks = False
        self.blocks = nn.ModuleList(blocks_list)

    self.norm = norm_layer(embed_dim)
    self.head = nn.Identity()

    self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

    self.init_weights()
|
||||
|
||||
def init_weights(self):
    """Initialize learnable tokens and recursively apply timm-style init to submodules.

    pos_embed gets a truncated-normal init; cls (and optional register) tokens get
    a small normal init; all Linear layers are initialized by init_weights_vit_timm
    via named_apply.
    """
    trunc_normal_(self.pos_embed, std=0.02)
    nn.init.normal_(self.cls_token, std=1e-6)
    if self.register_tokens is not None:
        nn.init.normal_(self.register_tokens, std=1e-6)
    named_apply(init_weights_vit_timm, self)
|
||||
|
||||
def interpolate_pos_encoding(self, x, w, h):
    """Bicubically resize patch position embeddings to match the current token grid.

    x: token sequence (cls + patch tokens); its length fixes the target patch count.
    w, h: input image width/height in pixels.
    Returns pos embeddings cast back to x's original dtype.
    """
    previous_dtype = x.dtype
    npatch = x.shape[1] - 1  # patch tokens only (drop cls)
    N = self.pos_embed.shape[1] - 1
    if npatch == N and w == h:
        # Grid matches the trained one exactly; no interpolation needed.
        return self.pos_embed
    # Interpolate in float32 for numerical stability, cast back at the end.
    pos_embed = self.pos_embed.float()
    class_pos_embed = pos_embed[:, 0]
    patch_pos_embed = pos_embed[:, 1:]
    dim = x.shape[-1]
    w0 = w // self.patch_size
    h0 = h // self.patch_size
    # we add a small number to avoid floating point error in the interpolation
    # see discussion at https://github.com/facebookresearch/dino/issues/8
    # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
    w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
    # w0, h0 = w0 + 0.1, h0 + 0.1

    # Trained pos-embed grid is assumed square of side sqrt(N).
    sqrt_N = math.sqrt(N)
    sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
        scale_factor=(sx, sy),
        # (int(w0), int(h0)), # to solve the upsampling shape issue
        mode="bicubic",
        antialias=self.interpolate_antialias
    )

    assert int(w0) == patch_pos_embed.shape[-2]
    assert int(h0) == patch_pos_embed.shape[-1]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
||||
|
||||
def prepare_tokens_with_masks(self, x, masks=None):
    """Patchify the image, apply mask tokens, prepend cls, add positional and register tokens."""
    _, _, width, height = x.shape
    tokens = self.patch_embed(x)
    if masks is not None:
        # Replace masked patch embeddings with the learned mask token.
        mask_tok = self.mask_token.to(tokens.dtype).unsqueeze(0)
        tokens = torch.where(masks.unsqueeze(-1), mask_tok, tokens)

    cls = self.cls_token.expand(tokens.shape[0], -1, -1)
    tokens = torch.cat((cls, tokens), dim=1)
    tokens = tokens + self.interpolate_pos_encoding(tokens, width, height)

    if self.register_tokens is not None:
        # Insert register tokens right after the cls token.
        regs = self.register_tokens.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((tokens[:, :1], regs, tokens[:, 1:]), dim=1)

    return tokens
|
||||
|
||||
def forward_features_list(self, x_list, masks_list):
    """Run the backbone over a list of (image, mask) pairs processed as nested tensors."""
    tokens = [self.prepare_tokens_with_masks(img, m) for img, m in zip(x_list, masks_list)]
    for block in self.blocks:
        tokens = block(tokens)

    results = []
    for t, m in zip(tokens, masks_list):
        normed = self.norm(t)
        # Split the normed sequence into cls / register / patch tokens.
        results.append(
            {
                "x_norm_clstoken": normed[:, 0],
                "x_norm_regtokens": normed[:, 1 : self.num_register_tokens + 1],
                "x_norm_patchtokens": normed[:, self.num_register_tokens + 1 :],
                "x_prenorm": t,
                "masks": m,
            }
        )
    return results
|
||||
|
||||
def forward_features(self, x, masks=None):
    """Run the full backbone; returns a dict of normed cls/register/patch tokens.

    Accepts a single tensor or a list of tensors (delegated to forward_features_list).
    """
    if isinstance(x, list):
        return self.forward_features_list(x, masks)

    tokens = self.prepare_tokens_with_masks(x, masks)
    for block in self.blocks:
        tokens = block(tokens)

    normed = self.norm(tokens)
    return {
        "x_norm_clstoken": normed[:, 0],
        "x_norm_regtokens": normed[:, 1 : self.num_register_tokens + 1],
        "x_norm_patchtokens": normed[:, self.num_register_tokens + 1 :],
        "x_prenorm": tokens,
        "masks": masks,
    }
|
||||
|
||||
def _get_intermediate_layers_not_chunked(self, x, n=1):
    """Collect outputs of selected blocks: the last n if n is an int, else the listed indices."""
    x = self.prepare_tokens_with_masks(x)
    total_block_len = len(self.blocks)
    blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
    output = []
    for idx, blk in enumerate(self.blocks):
        x = blk(x)
        if idx in blocks_to_take:
            output.append(x)
    assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
    return output
|
||||
|
||||
def _get_intermediate_layers_chunked(self, x, n=1):
    """Chunked-blocks variant of _get_intermediate_layers_not_chunked.

    Each BlockChunk is left-padded with nn.Identity() so that global block indices
    stay consistent; `i` tracks the global block index across chunks and the
    `block_chunk[i:]` slice skips the Identity padding already counted.
    """
    x = self.prepare_tokens_with_masks(x)
    # total_block_len is read from the last chunk, which spans the full depth.
    output, i, total_block_len = [], 0, len(self.blocks[-1])
    # If n is an int, take the n last blocks. If it's a list, take them
    blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
    for block_chunk in self.blocks:
        for blk in block_chunk[i:]:  # Passing the nn.Identity()
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
            i += 1
    assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
    return output
|
||||
|
||||
def get_intermediate_layers(
    self,
    x: torch.Tensor,
    n: Union[int, Sequence] = 1,  # Layers or n last layers to take
    reshape: bool = False,
    return_class_token: bool = False,
    norm=True
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
    """Return features of selected blocks, optionally normed, reshaped to maps, and paired with cls tokens."""
    getter = (
        self._get_intermediate_layers_chunked
        if self.chunked_blocks
        else self._get_intermediate_layers_not_chunked
    )
    outputs = getter(x, n)
    if norm:
        outputs = [self.norm(out) for out in outputs]
    class_tokens = [out[:, 0] for out in outputs]
    # Strip cls and register tokens, keeping patch tokens only.
    outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
    if reshape:
        B, _, w, h = x.shape
        grid_w, grid_h = w // self.patch_size, h // self.patch_size
        outputs = [
            out.reshape(B, grid_w, grid_h, -1).permute(0, 3, 1, 2).contiguous()
            for out in outputs
        ]
    if return_class_token:
        return tuple(zip(outputs, class_tokens))
    return tuple(outputs)
|
||||
|
||||
def forward(self, *args, is_training=False, **kwargs):
    """Return the raw feature dict when is_training, else the head applied to the cls token."""
    ret = self.forward_features(*args, **kwargs)
    return ret if is_training else self.head(ret["x_norm_clstoken"])
|
||||
|
||||
|
||||
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    # Only Linear layers are touched; all other module types are left as-is.
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
||||
|
||||
|
||||
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-S backbone: embed_dim 384, depth 12, 6 heads, memory-efficient attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
||||
|
||||
|
||||
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-B backbone: embed_dim 768, depth 12, 12 heads, memory-efficient attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
||||
|
||||
|
||||
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-L backbone: embed_dim 1024, depth 24, 16 heads, memory-efficient attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
||||
|
||||
|
||||
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
||||
|
||||
|
||||
def DINOv2(model_name):
    """Build a DINOv2 backbone ('vits' | 'vitb' | 'vitl' | 'vitg') with fixed settings.

    Uses 518px images, 14px patches, LayerScale init 1.0, no register tokens, and the
    SwiGLU FFN only for the giant variant. Raises KeyError for an unknown model_name.
    """
    model_zoo = {
        "vits": vit_small,
        "vitb": vit_base,
        "vitl": vit_large,
        "vitg": vit_giant2,
    }
    builder = model_zoo[model_name]
    return builder(
        img_size=518,
        patch_size=14,
        init_values=1.0,
        ffn_layer="swiglufused" if model_name == "vitg" else "mlp",
        block_chunks=0,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    )
|
||||
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .mlp import Mlp
|
||||
from .patch_embed import PatchEmbed
|
||||
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
||||
from .block import NestedTensorBlock
|
||||
from .attention import MemEffAttention
|
||||
@@ -0,0 +1,83 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
import logging
|
||||
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention, unbind, fmha
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning("xFormers not available")
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Standard multi-head self-attention with a fused qkv projection."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Scale queries by 1/sqrt(head_dim) before the dot product.
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        batch, seq, ch = x.shape
        # (3, B, heads, seq, head_dim)
        qkv = self.qkv(x).reshape(batch, seq, 3, self.num_heads, ch // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = q * self.scale

        scores = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        scores = self.attn_drop(scores)

        merged = (scores @ v).transpose(1, 2).reshape(batch, seq, ch)
        return self.proj_drop(self.proj(merged))
|
||||
|
||||
|
||||
class MemEffAttention(Attention):
    """Attention variant that uses xFormers memory_efficient_attention when available.

    Falls back to the parent dense implementation (dense fallback supports no attn_bias).
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x)

        batch, seq, ch = x.shape
        qkv = self.qkv(x).reshape(batch, seq, 3, self.num_heads, ch // self.num_heads)
        q, k, v = unbind(qkv, 2)

        fused = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        fused = fused.reshape([batch, seq, ch])
        return self.proj_drop(self.proj(fused))
|
||||
|
||||
|
||||
@@ -0,0 +1,252 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
import logging
|
||||
from typing import Callable, List, Any, Tuple, Dict
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
from .attention import Attention, MemEffAttention
|
||||
from .drop_path import DropPath
|
||||
from .layer_scale import LayerScale
|
||||
from .mlp import Mlp
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
try:
|
||||
from xformers.ops import fmha
|
||||
from xformers.ops import scaled_index_add, index_select_cat
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning("xFormers not available")
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
class Block(nn.Module):
    """Pre-norm transformer block: attention and FFN residual branches, each with
    optional LayerScale and stochastic depth.

    Three execution modes in forward():
      - training with drop_path > 0.1: batched sample-dropping (cheaper at high rates),
      - training with 0 < drop_path <= 0.1: classic per-sample DropPath,
      - otherwise: plain residual adds.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # LayerScale only when init_values is truthy; otherwise an Identity no-op.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Fixed: the FFN branch previously reused drop_path1 (upstream "# FIXME: drop_path2").
            # drop_path1 and drop_path2 are built with the same rate, so statistics are unchanged.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
||||
|
||||
|
||||
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Apply residual_func to a random batch subset and scatter the (rescaled) residual back.

    Keeps at least one sample; the residual is scaled by batch/subset so its
    expectation matches applying the function to the whole batch.
    """
    # 1) extract subset using permutation
    batch = x.shape[0]
    sample_subset_size = max(int(batch * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(batch, device=x.device)[:sample_subset_size]

    # 2) apply residual_func to get residual
    residual = residual_func(x[brange]).flatten(1)

    residual_scale_factor = batch / sample_subset_size

    # 3) add the residual
    merged = torch.index_add(
        x.flatten(1), 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
    )
    return merged.view_as(x)
|
||||
|
||||
|
||||
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Sample a random subset of batch indices plus the compensating residual scale."""
    batch = x.shape[0]
    # Keep at least one sample even at extreme drop ratios.
    sample_subset_size = max(int(batch * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(batch, device=x.device)[:sample_subset_size]
    return brange, batch / sample_subset_size
|
||||
|
||||
|
||||
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Scatter-add scaled residuals for the sampled indices back into x.

    Without a scaling_vector the result is the FLATTENED (B, N*D) tensor (callers
    call .view_as(x)); with one, the fused xFormers scaled_index_add is used.
    """
    if scaling_vector is None:
        flat = x.flatten(1)
        return torch.index_add(
            flat, 0, brange, residual.flatten(1).to(dtype=x.dtype), alpha=residual_scale_factor
        )
    # LayerScale gamma fused into the scatter-add (requires xFormers).
    return scaled_index_add(
        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
    )
|
||||
|
||||
|
||||
attn_bias_cache: Dict[Tuple, Any] = {}
|
||||
|
||||
|
||||
def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    # Effective batch size per tensor: the sampled subset size when branges is given.
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # Build a block-diagonal mask so each (sample, sequence) only attends to itself.
        # NOTE: the module-level cache grows unboundedly with distinct shape combinations.
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Fused xFormers index-select + concat over the flattened tensors.
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        # Concatenate all sequences along dim 1 as one batch-1 nested tensor.
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
|
||||
|
||||
|
||||
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """List variant of stochastic depth: sample per-tensor subsets, run one fused
    nested pass, and scatter the scaled residuals back into each tensor."""
    # 1) generate random set of indices for dropping samples in the batch
    sampled = [get_branges_scales(t, sample_drop_ratio=sample_drop_ratio) for t in x_list]
    branges = [pair[0] for pair in sampled]
    scale_factors = [pair[1] for pair in sampled]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    return [
        add_residual(t, br, res, sc, scaling_vector).view_as(t)
        for t, br, res, sc in zip(x_list, branges, residual_list, scale_factors)
    ]
|
||||
|
||||
|
||||
class NestedTensorBlock(Block):
    """Block variant that can process a list of tensors as one nested (block-diagonal) batch.

    Requires xFormers (MemEffAttention + BlockDiagonalMask) for the list path;
    single tensors fall back to the plain Block.forward.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # LayerScale is folded into the scatter-add, so the residual funcs
            # here intentionally omit ls1/ls2.
            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Fixed: the guard previously tested isinstance(self.ls1, LayerScale)
                # while reading self.ls2.gamma; the FFN branch must check ls2.
                # (ls1/ls2 are created together, so behavior only differed if one
                # was replaced after construction.)
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            # Split the concatenated batch back into per-tensor outputs.
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
||||
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Per-sample stochastic depth: zero whole samples with probability drop_prob,
    rescaling survivors by 1/keep_prob so the expectation is unchanged."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LayerScale(nn.Module):
    """Learnable per-channel scaling (gamma), optionally applied in place."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
||||
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
||||
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout.

    hidden/out widths default to in_features when not given.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
||||
@@ -0,0 +1,89 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def make_2tuple(x):
    """Normalize an int or 2-tuple into a 2-tuple; asserts on any other input."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)

        self.img_size = image_HW
        self.patch_size = patch_HW
        # Number of patches along (H, W).
        self.patches_resolution = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Non-overlapping conv == per-patch linear projection.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        projected = self.proj(x)  # B C H W
        grid_h, grid_w = projected.size(2), projected.size(3)
        tokens = projected.flatten(2).transpose(1, 2)  # B HW C
        tokens = self.norm(tokens)
        if not self.flatten_embedding:
            tokens = tokens.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H W C
        return tokens

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
|
||||
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: fused (w1,w2) projection, silu-gated product, then w3."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # Both gate and value projections packed into one Linear.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
||||
|
||||
|
||||
try:
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SwiGLU = SwiGLUFFN
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN with hidden width scaled to 2/3 and rounded up to a multiple of 8.

    The 2/3 factor keeps parameter count comparable to a plain MLP despite the
    doubled gate projection; rounding to 8 suits the fused xFormers kernel.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        base_hidden = hidden_features or in_features
        fused_hidden = (int(base_hidden * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=fused_hidden,
            out_features=out_features,
            bias=bias,
        )
|
||||
@@ -0,0 +1,220 @@
|
||||
import cv2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from custom_controlnet_aux.depth_anything_v2.dinov2 import DINOv2
|
||||
from custom_controlnet_aux.depth_anything_v2.util.blocks import FeatureFusionBlock, _make_scratch
|
||||
from custom_controlnet_aux.depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet
|
||||
|
||||
|
||||
def _make_fusion_block(features, use_bn, size=None):
    """Build a FeatureFusionBlock configured the way DPT uses it: non-inplace ReLU,
    no deconv, no channel expansion, aligned-corner interpolation."""
    return FeatureFusionBlock(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
    """3x3 conv + batch norm + in-place ReLU, preserving spatial size."""

    def __init__(self, in_feature, out_feature):
        super().__init__()
        conv = nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1)
        self.conv_block = nn.Sequential(conv, nn.BatchNorm2d(out_feature), nn.ReLU(True))

    def forward(self, x):
        return self.conv_block(x)
|
||||
|
||||
|
||||
class DPTHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
features=256,
|
||||
use_bn=False,
|
||||
out_channels=[256, 512, 1024, 1024],
|
||||
use_clstoken=False
|
||||
):
|
||||
super(DPTHead, self).__init__()
|
||||
|
||||
self.use_clstoken = use_clstoken
|
||||
|
||||
self.projects = nn.ModuleList([
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
) for out_channel in out_channels
|
||||
])
|
||||
|
||||
self.resize_layers = nn.ModuleList([
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[0],
|
||||
out_channels=out_channels[0],
|
||||
kernel_size=4,
|
||||
stride=4,
|
||||
padding=0),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[1],
|
||||
out_channels=out_channels[1],
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0),
|
||||
nn.Identity(),
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels[3],
|
||||
out_channels=out_channels[3],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1)
|
||||
])
|
||||
|
||||
if use_clstoken:
|
||||
self.readout_projects = nn.ModuleList()
|
||||
for _ in range(len(self.projects)):
|
||||
self.readout_projects.append(
|
||||
nn.Sequential(
|
||||
nn.Linear(2 * in_channels, in_channels),
|
||||
nn.GELU()))
|
||||
|
||||
self.scratch = _make_scratch(
|
||||
out_channels,
|
||||
features,
|
||||
groups=1,
|
||||
expand=False,
|
||||
)
|
||||
|
||||
self.scratch.stem_transpose = None
|
||||
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
|
||||
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
||||
nn.ReLU(True),
|
||||
nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, out_features, patch_h, patch_w):
    """Fuse multi-level ViT features into a single-channel depth map.

    Args:
        out_features: sequence of 4 transformer outputs; each element is a
            pair ``(tokens, cls_token)`` when ``self.use_clstoken`` is True,
            otherwise a tuple whose first item is the patch-token tensor of
            shape ``(B, patch_h * patch_w, C)``.
        patch_h: number of patches along the height axis.
        patch_w: number of patches along the width axis.

    Returns:
        Tensor of shape ``(B, 1, patch_h * 14, patch_w * 14)`` — 14 is the
        ViT patch size, so the map is upsampled back toward input resolution.
    """
    out = []
    for i, x in enumerate(out_features):
        if self.use_clstoken:
            # Broadcast the CLS token across all patch tokens and fold it
            # back in through the learned linear readout projection.
            x, cls_token = x[0], x[1]
            readout = cls_token.unsqueeze(1).expand_as(x)
            x = self.readout_projects[i](torch.cat((x, readout), -1))
        else:
            x = x[0]

        # (B, N, C) -> (B, C, patch_h, patch_w): restore the 2-D patch grid.
        x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

        # Per-level 1x1 channel projection, then resize into a feature
        # pyramid (levels 0..3: 4x up, 2x up, identity, 2x down).
        x = self.projects[i](x)
        x = self.resize_layers[i](x)

        out.append(x)

    layer_1, layer_2, layer_3, layer_4 = out

    # Harmonize channel counts before RefineNet-style fusion.
    layer_1_rn = self.scratch.layer1_rn(layer_1)
    layer_2_rn = self.scratch.layer2_rn(layer_2)
    layer_3_rn = self.scratch.layer3_rn(layer_3)
    layer_4_rn = self.scratch.layer4_rn(layer_4)

    # Top-down fusion: each refinenet upsamples to the next level's spatial
    # size and merges the skip connection.
    path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
    path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
    path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
    path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

    # Output head: upsample to (patch_h*14, patch_w*14) between the two
    # conv stages, final ReLU keeps depth non-negative.
    out = self.scratch.output_conv1(path_1)
    out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
    out = self.scratch.output_conv2(out)

    return out
class DepthAnythingV2(nn.Module):
    """Depth Anything V2: a DINOv2 ViT encoder with a DPT depth head.

    Args:
        encoder: DINOv2 variant name; one of 'vits', 'vitb', 'vitl', 'vitg'.
        features: channel width of the DPT fusion blocks.
        out_channels: per-level channel counts for the four intermediate
            feature projections (tuple default avoids the mutable-default
            pitfall; any sequence of 4 ints is accepted).
        use_bn: whether fusion blocks use batch norm.
        use_clstoken: whether the encoder's CLS token is folded into the
            patch tokens via a readout projection.
    """

    def __init__(
        self,
        encoder='vitl',
        features=256,
        out_channels=(256, 512, 1024, 1024),
        use_bn=False,
        use_clstoken=False
    ):
        super().__init__()

        # Which transformer blocks feed the DPT head, per encoder size.
        self.intermediate_layer_idx = {
            'vits': [2, 5, 8, 11],
            'vitb': [2, 5, 8, 11],
            'vitl': [4, 11, 17, 23],
            'vitg': [9, 19, 29, 39]
        }

        self.encoder = encoder
        self.pretrained = DINOv2(model_name=encoder)

        # list(...) keeps downstream indexing behavior identical for callers
        # that previously relied on the list default.
        self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn,
                                  out_channels=list(out_channels),
                                  use_clstoken=use_clstoken)

    def forward(self, x, max_depth):
        """Predict metric depth for a batch of images.

        Args:
            x: image tensor of shape (B, 3, H, W); H and W are assumed to be
               multiples of the ViT patch size 14 — TODO confirm at callers.
            max_depth: scale applied to the head's normalized output.

        Returns:
            Depth tensor of shape (B, H', W') with the channel dim squeezed.
        """
        patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14

        features = self.pretrained.get_intermediate_layers(
            x, self.intermediate_layer_idx[self.encoder], return_class_token=True)

        depth = self.depth_head(features, patch_h, patch_w) * max_depth

        return depth.squeeze(1)

    @torch.no_grad()
    def infer_image(self, raw_image, input_size=518, max_depth=20.0):
        """Run inference on a single BGR uint8 image (HxWx3 numpy array).

        Resizes/normalizes the image, predicts depth, and bilinearly resizes
        the result back to the original (h, w).

        Returns:
            2-D numpy array of depth values on the CPU.
        """
        image, (h, w) = self.image2tensor(raw_image, input_size)

        depth = self.forward(image, max_depth)

        # Restore the original spatial resolution.
        depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]

        return depth.cpu().numpy()

    def image2tensor(self, raw_image, input_size=518):
        """Convert a BGR uint8 image to a normalized (1, 3, H', W') tensor.

        Args:
            raw_image: HxWx3 BGR numpy array (OpenCV convention).
            input_size: target size for the shorter side; output dims are
                rounded to multiples of 14 while keeping aspect ratio.

        Returns:
            Tuple of (tensor on the selected device, original (h, w)).
        """
        transform = Compose([
            Resize(
                width=input_size,
                height=input_size,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            # ImageNet statistics, matching the DINOv2 pretraining setup.
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])

        h, w = raw_image.shape[:2]

        image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0

        image = transform({'image': image})['image']
        image = torch.from_numpy(image).unsqueeze(0)

        # NOTE(review): device is inferred from the environment, not from the
        # model's parameters — assumes the model lives on the default
        # cuda/mps device; verify against callers.
        device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        image = image.to(device)

        return image, (h, w)