Add custom nodes, Civitai loras (LFS), and vast.ai setup script

Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 00:55:26 +00:00
parent 2b70ab9ad0
commit f09734b0ee
2274 changed files with 748556 additions and 3 deletions
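
The commit message above describes a setup script that installs dependencies and pulls HuggingFace-hosted models on a fresh vast.ai instance. As a rough illustration only, a provisioning script of that kind might look like the sketch below; the repository URL, model repo id, filenames, and paths are hypothetical placeholders, not values taken from this commit.

```shell
#!/usr/bin/env bash
# Hypothetical vast.ai provisioning sketch -- not the script added by this commit.
set -euo pipefail

# Clone the repo and fetch the Git LFS loras referenced in the commit message.
git lfs install
git clone https://github.com/example-user/comfyui-workspace.git /workspace/ComfyUI   # placeholder URL
cd /workspace/ComfyUI
git lfs pull

# Install Python dependencies for ComfyUI and every bundled custom node.
pip install -r requirements.txt
for req in custom_nodes/*/requirements.txt; do
    if [ -f "$req" ]; then pip install -r "$req"; fi
done

# Download HuggingFace-hosted models (placeholder repo id and filename).
pip install -U "huggingface_hub[cli]"
huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 \
    sd_xl_base_1.0.safetensors --local-dir models/checkpoints
```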


@@ -0,0 +1,20 @@
__pycache__
.DS_Store
*.cache
*.ini
*.bak
wildcards/**
styles/**
workflow/**
autocomplete/**
web_beta/**
web_version/dev/**
docs/**
.vscode/
.idea/
mmb-preset.custom.txt
config.yaml
node.tar.gz
.cursorrules
tools/ComfyUI-Easy-Use.json


@@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.


@@ -0,0 +1,573 @@
![comfyui-easy-use](https://github.com/user-attachments/assets/9b7a5e44-f5e2-4c27-aed2-d0e6b50c46bb)
<div align="center">
<a href="https://space.bilibili.com/1840885116">视频介绍</a> |
<a href="https://docs.easyuse.yolain.com">文档</a> |
<a href="https://github.com/yolain/ComfyUI-Yolain-Workflows">工作流合集</a> |
<a href="#%EF%B8%8F-donation">捐助</a>
<br><br>
<a href="./README.md"><img src="https://img.shields.io/badge/🇬🇧English-e9e9e9"></a>
<a href="./README.ZH_CN.md"><img src="https://img.shields.io/badge/🇨🇳中文简体-0b8cf5"></a>
</div>
**ComfyUI-Easy-Use** is an efficiency-focused node integration package that simplifies the complex. It extends [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) and integrates and optimizes many popular node packages to make ComfyUI faster and more convenient to use, restoring the smooth image-generation experience of Stable Diffusion while preserving full flexibility.
## 👨🏻‍🎨 Features
- Follows the approach of [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes), greatly reducing the time spent setting up workflows.
- UI beautification: after the first install, if you want to use a UI theme, switch it under Settings -> Color Palette and **refresh the page**.
- Added pre-sampling parameter configuration nodes that can be separated from the sampling nodes for easier previewing.
- Prompt nodes support wildcards and Loras; Lora Block Weight usage requires [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) to be installed among your custom node packages.
- Multi-select style prompt selector. The default styles come from Fooocus's JSON; custom JSON files can be placed under `styles`, and preview images can be placed in the `samples` folder (file name must match the style name; replace spaces in image file names with underscores '_').
- Loaders can enable the A1111 prompt style mode, reproducing images nearly identical to those generated by the webui.
- Noise injection into the latent space via the `easy latentNoisy` or `easy preSamplingNoiseIn` nodes.
- Simplified workflows for SD1.x, SD2.x, SDXL, SVD, Zero123, etc.
- Simplified Stable Cascade [Example](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#1-13-stable-cascade)
- Simplified Layer Diffuse [Example](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-3-layerdiffusion)
- Simplified InstantID [Example](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-2-instantid); requires [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) to be installed among your custom node packages.
- Simplified IPAdapter; requires the latest v2 of [ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) to be installed.
- Extended XYplot usability.
- Integrated the Fooocus Inpaint feature.
- Integrated commonly used logic operations, type conversion, show-any-type nodes, etc.
- Subdirectory categorization and preview thumbnails for checkpoint and Lora models on nodes (enable nested subdirectories for the context menu in the settings).
- Background removal node based on BriaAI's RMBG-1.4 model, [reference](https://huggingface.co/briaai/RMBG-1.4)
- Supports forcibly clearing the model VRAM used by ComfyUI.
- Stable Diffusion 3 multi-account API node.
- Supports IC-Light [Example](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-5-ic-light) | [Integrated code](https://github.com/huchenlei/ComfyUI-IC-Light) | [Reference](https://github.com/lllyasviel/IC-Light)
- Automatic recognition of Chinese prompts, using the [opus-mt-zh-en model](https://huggingface.co/Helsinki-NLP/opus-mt-zh-en)
- Supports SD3 models
- Supports Kolors models
- Supports Flux models
- Supports lazy if/else conditionals and for loops
## 👨🏻‍🔧 Installation
1. Clone this repository into the **custom_nodes** directory and install the dependencies
```shell
#1. Clone via git
git clone https://github.com/yolain/ComfyUI-Easy-Use
#2. Install dependencies
Double-click install.bat to install the dependencies
```
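On Linux or macOS (for example on a vast.ai instance), `install.bat` is not available; a minimal sketch of the equivalent steps, assuming the package's bundled `requirements.txt`, is:
```shell
# Hedged alternative to install.bat for Linux/macOS environments.
cd ComfyUI/custom_nodes
git clone https://github.com/yolain/ComfyUI-Easy-Use
cd ComfyUI-Easy-Use
pip install -r requirements.txt
```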
## 📜 Changelog
**v1.3.6**
- Restored list-type support for `easy showAnything`; displaying very large data can still crash ComfyUI in some cases
- Fixed custom widgets to support subgraphs and Nodes 2.0 #942
- Added the `easy multiAngle` node
- Converted `prompt.py` to the V3 schema
- Fixed the `easy humanSegmentation` error
- Added `easy stringJoinLines`, `easy stringToIntList`, and `easy simpleMath`
- Fixed `easy ifElse` and `easy anythingIndexSwitch` failing in some environments
**v1.3.5**
- Fixed `isNone`
- Added `preview_rescale` to `easy imageChooser`
- Fixed widget hiding #910
- Added a max parameter to the `wildcardsPromptMatrix` offset #909
- Fixed the title-box style on subgraph nodes
- Added `remove_empty_lines` to `easy promptLine`
**v1.3.4**
- Fixed the `easy seedList` maximum value #879
- Added a controlnet input for xyplot #877
- `easy indexAnything` supports negative indexing
**v1.3.3**
- Removed the CSS class name `gird-cols-1` #859
- Fixed the locked seed not working in `easy promptAwait`
- Renamed the node graph
- Fixed `easy ImageChooser` outputting the wrong type #845
**v1.3.2**
- Reworked the `easy imageChooser` node for compatibility with frontend >= v1.24.2; solution adapted from [Comfyui_LG_Tools](https://github.com/LAOGOU-666/Comfyui_LG_Tools)
- Reworked the `easy stylesSelector` node; you can download [other styles files](https://github.com/yolain/EasyUse-Styles-Templates) into the `styles` folder
- Reworked the `easy humanSegmentation` node
- Fixed the `easy makeImageForICLora` node
- Added the `easy joycaption3API` node
- Added the `easy promptAwait` node
**v1.3.1**
- Rewrote drawNodeWidget to fix group-node preview issues
- Updated some XYPlot features, by [mekinney](https://github.com/mekinney)
- Added the `easy seedList` node (useful for loop nodes)
**v1.3.0**
- Set the maximum number of inputs and outputs of loop nodes to 20
- Added a `uniform width` mode to `easy makeImageForICLora`
- Added `wildcardsPromptMatrix`, a wildcard prompt matrix, contributed by [Rosmeowtis](https://github.com/Rosmeowtis)
**v1.2.9**
- Fixed Imagechooser causing workflow processing to be cancelled
- Fixed the brushnet tensor(640) error
- Fixed widgets not being hidable after frontend v1.6.0
- Fixed the image chooser not being able to select images
- Fixed the ContextMenu monkey patch affecting the Custom Scripts (pysssss) nodes
**v1.2.8**
- Fixed some bugs (😹)
- Added a multi-language directory
**v1.2.7**
- Optimized the display of managed node groups
- Added `ben2` to `easy imageRemBg`
- Added a joyCaption2 API node https://github.com/siliconflow/BizyAir
- Used a new way to display model thumbnails in loaders (supports diffusion_models, loras, and checkpoints)
**v1.2.6**
- Fixed the missing "red box" style when custom nodes are missing
- Changed the default `clip_skip` value in some simple loaders from `-1` to `-2`
- Fixed the canvas getting messed up when a connected custom node is missing in a set node
- Fixed 'easy imageChooser' not being reusable in loops
**v1.2.5**
- Added an `enable (GPU=A1111)` noise-generation mode option to `easy preSamplingCustom` and `easy preSamplingAdvanced`
- Added `easy makeImageForICLora`
- Added the `REGULAR - FLUX and SD3.5 only (high strength)` preset to `easy ipadapterApply` to support the InstantX Flux ipadapter
- Fixed brushnet not working in `--fast` mode
- Support for briaai RMBG-2.0
- Support for the mochi model
- Allowed reusing terminal-node outputs inside loop bodies (e.g. preview image, show anything, and other output nodes...)
**v1.2.4**
- Added `easy imageSplitTiles` and `easy imageTilesFromBatch` - image tiling
- `model_override`, `vae_override`, and `clip_override` can now be provided as separate inputs to `easy fullLoader`
- Added `easy saveImageLazy`
- Added `easy loadImageForLoop`
- Added `easy isFileExist`
- Added `easy saveText`
**v1.2.3**
- Added output slots to `easy showAnything` and `easy cleanGPUUsed`
- Added new human-parts segmentation to the `easy humanSegmentation` node - code integrated from [ComfyUI_Human_Parts](https://github.com/metal3d/ComfyUI_Human_Parts)
- When basicGuider is selected on the `easy preSamplingCustom` node, CFG > 0, and the current model is Flux, FluxGuidance is used
- Added `easy loraStackApply` and `easy controlnetStackApply`
**v1.2.2**
- Added `easy batchAny`
- Added `easy anythingIndexSwitch`
- Added `easy forLoopStart` and `easy forLoopEnd`
- Added `easy ifElse`
- Added the new v2 frontend code
- Added `easy fluxLoader`
- Added SD3 and HunyuanDiT support to the `controlnetApply` related nodes
- Fixed Lora models not taking effect when used after Fooocus inpaint
**v1.2.1**
- Added `easy ipadapterApplyFaceIDKolors`
- Added the **PLUS (kolors genernal)** and **FACEID PLUS KOLORS** presets to `easy ipadapterApply` and `easy ipadapterApplyADV`
- Added the **inspyrenet** option to `easy imageRemBg`
- Added `easy controlnetLoader++`
- Removed the automatic Chinese translation from prompt nodes such as `easy positive` and `easy negative`; automatic translation now only applies in loaders that do not support Chinese text encoders, such as `easy a1111Loader`
- Added `easy kolorsLoader` - Kolors loader, referencing code from [MinusZoneAI](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ) and [kijai](https://github.com/kijai/ComfyUI-KwaiKolorsWrapper)
**v1.2.0**
- Added `easy pulIDApply` and `easy pulIDApplyADV`
- Added `easy hunyuanDiTLoader` and `easy pixArtLoader`
- Added the crystools display when the new menu is positioned at the top or bottom (opening two panels is recommended; this may be removed if crystools later adapts its UI)
- Added **easy sliderControl** - a slider control node, currently usable for controlling ipadapterMS parameters (double-click the slider to reset it to the default value)
- Added the **layer_weights** attribute to the `easy ipadapterApplyADV` node
**v1.1.9**
- Added the new **gitsScheduler** scheduler
- Added `easy imageBatchToImageList` and `easy imageListToImageBatch` (fixing a small issue in the Impact versions)
- Recursive nesting of model subdirectories
- Support for SD3 models
- Added `easy applyInpaint` - an all-mode inpainting node (with more reasonable logic than the previous kSamplerInpating node)
**v1.1.8**
- Added automatic Chinese prompt translation, using the [opus-mt-zh-en model](https://huggingface.co/Helsinki-NLP/opus-mt-zh-en); wildcard and lora syntax is already handled with regexes, and any other Chinese text you want to keep can be wrapped in `@your prompt@` (if an error occurs after the dependencies are installed, restart); it uses roughly 0.3 GB of VRAM
- Added `easy controlnetStack` - controlnet stack
- Added `easy applyBrushNet` - [Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4brushnet_1.1.8.json)
- Added `easy applyPowerPaint` - [Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4powerpaint_outpaint_1.1.8.json)
**v1.1.7**
- Fixed some models (e.g. controlnet models) not being written to the cache, which forced a second model load whenever an upstream pipe parameter (e.g. the prompt) was changed
- Added `easy prompt` - subject and lighting presets, which may be adjusted later
- Added `easy icLightApply` - relighting, adapted from [ComfyUI-IC-Light](https://github.com/huchenlei/ComfyUI-IC-Light)
- Added `easy imageSplitGrid` - image grid splitting
- Added differential diffusion, brushnet, and related options to the **additional** attribute of `easy kSamplerInpainting`
- Added support for loading brushnet models - [ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet)
- Added `easy applyFooocusInpaint` - Fooocus inpainting node, replacing the original FooocusInpaintLoader
- Removed `easy fooocusInpaintLoader` - too bug-prone, no longer used
- Changed easy kSampler and other samplers so that a model connected in parallel no longer replaces the model inside the output pipe
**v1.1.6**
- Added Align Your Steps support - an **alignYourSteps** option was added to the scheduler of all pre-sampling and full sampler nodes
- Added a **Preview&Choose** option to **image_output** on `easy kSampler` and `easy fullkSampler`
- Added `easy styleAlignedBatchAlign` - style alignment [style_aligned_comfy](https://github.com/brianfitzgerald/style_aligned_comfy)
- Added `easy ckptNames`
- Added `easy controlnetNames`
- Added `easy imagesSplitimage` - split a batch into single images
- Added `easy imageCount` - image count
- Added `easy textSwitch` - text switch
<details>
<summary><b>v1.1.5</b></summary>
- Rewrote `easy cleanGPUUsed` - can forcibly clear the model VRAM used by ComfyUI
- Added `easy humanSegmentation` - multi-class and portrait segmentation
- Added `easy imageColorMatch`
- Added `easy ipadapterApplyRegional`
- Added `easy ipadapterApplyFromParams`
- Added `easy imageInterrogator` - image interrogation
- Added `easy stableDiffusion3API` - a simple Stable Diffusion 3 multi-account API node
</details>
<details>
<summary><b>v1.1.4</b></summary>
- Added `easy imageChooser` - an image chooser simplified from [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker)
- Added `easy preSamplingCustom` - custom pre-sampling, with support for cosXL-edit
- Added `easy ipadapterStyleComposition`
- Added a right-click menu on loaders for viewing checkpoint and lora information
- Fixed `easy preSamplingNoiseIn`, `easy latentNoisy`, and `easy Unsampler` for compatibility with ComfyUI revision >= 2098 [0542088e]
- Fixed the issue caused by FooocusInpaint modifying the ModelPatcher weight calculation; the ModelPatcher should be reset to its default after the model is generated
</details>
<details>
<summary><b>v1.1.3</b></summary>
- Added the **COMPOSITION** preset to `easy ipadapterApply`
- Added support for loading [ResAdapter](https://huggingface.co/jiaxiangc/res-adapter) lora models
- Added `easy promptLine`
- Added `easy promptReplace`
- Added `easy promptConcat`
- Added the **multiline_mode** attribute to `easy wildcards`
- When a node needs to download a model and the huggingface connection times out, it now switches to a mirror address to download the model
</details>
<details>
<summary><b>v1.1.2</b></summary>
- Rewrote the recommended nodes for some slots of EasyUse-related nodes
- Added the **enable context menu auto nesting subdirectories** setting, enabled by default; it categorizes subdirectories and preview thumbnails for checkpoints and loras
- Added `easy sv3dLoader`
- Added `easy dynamiCrafterLoader`
- Added `easy ipadapterApply`
- Added `easy ipadapterApplyADV`
- Added `easy ipadapterApplyEncoder`
- Added `easy ipadapterApplyEmbeds`
- Added `easy preMaskDetailerFix`
- Added the **additional** attribute to `easy kSamplerInpainting`, which can be set to Differential Diffusion or Only InpaintModelConditioning
- Fixed `easy stylesSelector` changing the original prompt when no style is selected
- Fixed `easy pipeEdit` erroring when a lora is entered in the prompt
- Fixed layerDiffuse xyplot related bugs
</details>
<details>
<summary><b>v1.1.1</b></summary>
- Fixed the seed being 0 the first time a node containing a seed is added while the current mode is control_before_generate
- Added **return_with_leftover_noise** to `easy preSamplingAdvanced`
- Fixed `easy stylesSelector` erroring when the queue is run with a custom styles file selected
- Added an optional mask input to `easy preSamplingLayerDiffusion`
- Changed all **seed_num** back to **seed**
- Patched an upstream bug: when control_mode is "before", the widget name in nodes was not changed to control_before_generate on first page load
- Removed the forced **control_before_generate** setting
- Added `easy imageRemBg` - defaults to BriaAI's RMBG-1.4 model, with better background removal and faster speed
</details>
<details>
<summary><b>v1.1.0</b></summary>
- Added `easy imageSplitList` - split every N images
- Added `easy preSamplingDiffusionADDTL` - allows configuring the additional_prompt for the foreground, background, blended image, etc.
- Added `easy preSamplingNoiseIn`, which can replace the `easy latentNoisy` node that had to be placed upstream and gives better noise injection
- Added a conditioning combine mode selection to `easy pipeEdit`: replace, combine, concat, average, or set timestep range
- Added `easy pipeEdit` - an editable pipe node that allows re-entering prompts
- Added `easy preSamplingLayerDiffusion` and `easy kSamplerLayerDiffusion` (also works when connected to `easy kSampler`)
- Added a right-click menu on loader, pre-sampling, sampler, and controlnet nodes for quickly replacing a node with another of the same type
- Added `easy instantIDApplyADV`, which can take positive and negative inputs
- Fixed `easy wildcards` failing to load a lora when the full lora path was not given, because no automatic lookup was performed
- Fixed `easy instantIDApply` not passing the correct mask value
- Fixed BREAK not taking effect in non-a1111 prompt style
</details>
<details>
<summary><b>v1.0.9</b></summary>
- Fixed errors when ComfyUI-Impact-Pack and ComfyUI_InstantID are not installed
- Fixed `easy pipeIn` - the pipe input is now optional
- Added `easy instantIDApply` - requires [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) to be installed first; see the workflow [example](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-2-instantid)
- Fixed `easy detailerFix` not being added to the list of nodes available for formatted saved-image extensions
- Fixed `easy XYInputs: PromptSR` erroring when replacing negative prompts
</details>
<details>
<summary><b>v1.0.8</b></summary>
- `easy cascadeLoader` stage_c and stage_b support checkpoint models (requires downloading the [checkpoints](https://huggingface.co/stabilityai/stable-cascade/tree/main/comfyui_checkpoints))
- Changed the `easy styleSelector` search box to case-insensitive matching
- Added **positive**, **negative**, and **latent** outputs to `easy fullLoader`
- Fixed the SDXLClipModel error in ComfyUI revision 2016[c2cb8e88] and above (the revision number is checked for backward compatibility)
- Fixed `easy detailerFix` erroring when the batch size is greater than 1
- Fixed `easy preSampling` and similar nodes not being able to generate by batch index after a latent is passed in
- Fixed the `easy svdLoader` error
- Optimized the code, removed a lot of redundancy, and improved running speed
- Removed the Chinese translation reference text
(translations are now maintained by [AIGODLIKE-COMFYUI-TRANSLATION](https://github.com/AIGODLIKE/AIGODLIKE-ComfyUI-Translation);
if you downloaded recently or are on an old version, please update AIGODLIKE-COMFYUI-TRANSLATION and this node package to the latest versions.)
</details>
<details>
<summary><b>v1.0.7</b></summary>
- Added `easy cascadeLoader` - stable cascade loader
- Added `easy preSamplingCascade` - stable cascade stage_c pre-sampling parameters
- Added `easy fullCascadeKSampler` - stable cascade stage_c full sampler
- Added `easy cascadeKSampler` - stable cascade stage-c ksampler simple
</details>
<details>
<summary><b>v1.0.6</b></summary>
- Added `easy XYInputs: Checkpoint`
- Added `easy XYInputs: Lora`
- `easy seed` can manually switch to a random seed while a fixed seed value is set
- Fixed loaders such as `easy fullLoader` automatically resizing the node when switching loras
- Removed the original ttn image-saving logic and adapted to ComfyUI's default formatted image-saving extensions
</details>
<details>
<summary><b>v1.0.5</b></summary>
- Added `easy isSDXL`
- Added prompt control to `easy svdLoader`, usable together with open_clip models
- Added **populated_text** to `easy wildcards`, which outputs the text after wildcard population
</details>
<details>
<summary><b>v1.0.4</b></summary>
- Added `easy showLoaderSettingsNames`, which can display and output the model and VAE names from loader widgets
- Added `easy promptList` - prompt list
- Added `easy fooocusInpaintLoader` - Fooocus inpainting node (only supports XL model workflows)
- Added **Logic** nodes - including types, computation, conditionals, and type conversion
- Added `easy imageSave` - an image saving node with date conversion and width/height formatting
- Added `easy joinImageBatch` - merge image batches
- Added support in `easy showAnything` for converting other types, such as tensor-type conditioning and images
- Added the **patch** input to `easy kSamplerInpainting`, for use with the Fooocus inpainting node
- Added **only_preivew** to `easy imageSave`
- Fixed xyplot erroring with pillow > 9.5
- Fixed `easy wildcards` erroring when run with the PS extension plugin
- Fixed `easy latentCompositeMaskedWithCond`
- Fixed the `easy XYInputs: ControlNet` error
- Fixed `easy loraStack` erroring when **toggle** is disabled
- Changed the first install of the node package so that it no longer automatically replaces the theme; adjust it manually and refresh the page
</details>
<details>
<summary><b>v1.0.3</b></summary>
- Added `easy stylesSelector` - style prompt selector
- Added a queue progress bar setting, disabled by default
- Added the **scale_soft_weights** parameter to `easy controlnetLoader` and `easy controlnetLoaderADV`
- Fixed the `easy XYInputs: Sampler/Scheduler` error
- Fixed the buttons in the right-side menu jumping out of place when clicked
- Fixed the styles path erroring in other environments
- Fixed the `easy comfyLoader` read error
- Fixed xyPlot erroring when connected to zero123
- Fixed loaders erroring when the prompt is a component
- Fixed `easy getNode` and `easy setNode` titles not being updated on load
- Fixed subdirectory prefixes not taking effect when saving images in all samplers
- Adjusted the UI theme
</details>
<details>
<summary><b>v1.0.2</b></summary>
- Added the **autocomplete** folder; if [ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) is installed, all txt files in this folder are merged at startup and overwrite the autocomplete.txt file in the pysssss package.
- Added `easy XYPlotAdvanced`, `easy XYInputs`, and related nodes
- Added **Alt+1 to 9** shortcuts to quickly paste node presets from Node templates (corresponding to positions 1 to 9)
- Fixed `easy imageInsetCrop` using a step of 1 when the measurement is a percentage
- Fixed the XY plot not working when `a1111_prompt_style` is enabled
- Added a `📜Groups Map(EasyUse)` item to the right-click menu
- Fixed the UI failing to load in new versions of Comfy
- Fixed the `easy pipeToBasicPipe` error
- Changed the default value of **a1111_prompt_style** in `easy fullLoader` and `easy a1111Loader` to False
- `easy XYInputs ModelMergeBlocks` supports importing values from csv files
- Replaced the font file used for XY plot generation
- Removed `easy imageRemBg`
- Removed the introduction images and workflow files from the package to reduce its size
</details>
<details>
<summary><b>v1.0.1</b></summary>
- Added `easy seed` - simple seed
- Added the optional `optional_image` input to `easy preDetailerFix`; if not provided, the image from the pipe is used by default
- Added `easy kSamplerInpainting`, a sampler for inpainting the latent space
- Added `easy pipeToBasicPipe`, for converting to certain Impact nodes
- Fixed the `easy comfyLoader` error
- Fixed all nodes that output image size so that batches are handled correctly
- Fixed `width` and `height` not being customizable in `easy svdLoader`
- Fixed the preview image URLs of all samplers (so that images can be previewed in samplers on macOS)
- Fixed `vae_name` being selected in `easy fullLoader`, `easy a1111Loader`, and `easy comfyLoader` without replacing the original vae
- Fixed errors in the outputs of `easy fullkSampler` other than pipe
- Fixed `easy hiresFix` erroring when both pipe and image/vae inputs are connected
- Fixed `model_override` in `easy fullLoader` not being applied when connected
- Fixed an action error caused by the newly added `easy seed`
- Fixed the font file path read error in `easy xyplot`
- Fixed the seed not being fixable after converting to `easy seed`
- Fixed the `easy pipeIn` value-input error
- Fixed `easy zero123Loader` and `easy svdLoader` so the model is added to the cache when it is loaded
- Changed the default `image_output` value of `easy kSampler`, `easy kSamplerTiled`, and `easy detailerFix` to Preview
- Added the `a1111_prompt_style` parameter to `easy fullLoader` and `easy a1111Loader`, which can reproduce images identical to those generated by the webui; [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) must be installed to use this feature
</details>
<details>
<summary><b>v1.0.0</b></summary>
- Added `easy positive` - simple positive prompt text
- Added `easy negative` - simple negative prompt text
- Added `easy wildcards` - prompt text with wildcard and Lora selection support
- Added `easy portraitMaster` - Portrait Master v2.2
- Added `easy loraStack` - Lora stack
- Added `easy fullLoader` - full loader
- Added `easy zero123Loader` - simple zero123 loader
- Added `easy svdLoader` - simple svd loader
- Added `easy fullkSampler` - full sampler (not separated)
- Added `easy hiresFix` - hires fix with Pipe support
- Added `easy predetailerFix` and `easy DetailerFix` - detail fixing with Pipe support
- Added `easy ultralyticsDetectorPipe` and `easy samLoaderPipe` - detection loaders (inputs for detail fixing)
- Added `easy pipein` and `easy pipeout` - Pipe input and output
- Added `easy xyPlot` - simple xyplot (more controllable parameters will be added later)
- Added `easy imageRemoveBG` - image background removal
- Added `easy imagePixelPerfect` - image pixel perfect
- Added `easy poseEditor` - pose editor
- Added a new UI theme (Obsidian) - loaded automatically by default, and can also be changed in the settings
- Fixed `easy globalSeed` not taking effect
- Fixed all `seed_num` values getting out of sync because [cg-use-everywhere](https://github.com/chrisgoringe/cg-use-everywhere) updates the graph in real time
- Fixed `easy imageSize`, `easy imageSizeBySide`, and `easy imageSizeByLongerSide` so they can be used as end nodes
- Fixed the bug where `seed_num` (the random seed value) could not be read consistently from the history
</details>
<details>
<summary><b>v0.5</b></summary>
- Added the `easy controlnetLoaderADV` node
- Added the `easy imageSizeBySide` node, with optional output of the long or short side
- Added the `easy LLLiteLoader` node; if you previously installed the kohya-ss/ControlNet-LLLite-ComfyUI package, move the model files from its models folder to ComfyUI\models\controlnet\ (Comfy's default controlnet path); do not rename the model files or they will not be found.
- Added size display to the outputs of `easy imageSize` and `easy imageSizeByLongerSize`.
- Added the `easy showSpentTime` node to display the time spent on image inference and VAE decoding.
- Added the optional `control_net` input parameter to `easy controlnetLoaderADV` and `easy controlnetLoader`
- Added the optional `image_to_latent` input parameter to `easy preSampling` and `easy preSamplingAdvanced`
- Added the `batch_size` input parameter to `easy a1111Loader` and `easy comfyLoader`
- Moved `easy controlnetLoader` under the loader category.
</details>
## Related node packages integrated or referenced
Note: I deeply respect the work of these original authors. Open source is not easy; I have only done some integration and optimization.
| Node name (search name) | Related library | Related nodes in that library |
|:-------------------------------|:----------------------------------------------------------------------------|:------------------------|
| easy setNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.SetNode |
| easy getNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.GetNode |
| easy bookmark | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | Bookmark 🔖 |
| easy portraitMaster | [comfyui-portrait-master](https://github.com/florestefano1975/comfyui-portrait-master) | Portrait Master |
| easy LLLiteLoader | [ControlNet-LLLite-ComfyUI](https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI) | LLLiteLoader |
| easy globalSeed | [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) | Global Seed (Inspire) |
| easy preSamplingDynamicCFG | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
| dynamicThresholdingFull | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
| easy imageInsetCrop | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | ImageInsetCrop |
| easy poseEditor | [ComfyUI_Custom_Nodes_AlekPet](https://github.com/AlekPet/ComfyUI_Custom_Nodes_AlekPet) | poseNode |
| easy if | [ComfyUI-Logic](https://github.com/theUpsider/ComfyUI-Logic) | IfExecute |
| easy preSamplingLayerDiffusion | [ComfyUI-layerdiffusion](https://github.com/huchenlei/ComfyUI-layerdiffusion) | LayeredDiffusionApply, etc. |
| easy dynamiCrafterLoader | [ComfyUI_Native_DynamiCrafter](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter) | Apply Dynamicrafter |
| easy imageChooser | [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) | Preview Chooser |
| easy styleAlignedBatchAlign | [style_aligned_comfy](https://github.com/brianfitzgerald/style_aligned_comfy) | styleAlignedBatchAlign |
| easy icLightApply | [ComfyUI-IC-Light](https://github.com/huchenlei/ComfyUI-IC-Light) | ICLightApply, etc. |
| easy kolorsLoader | [ComfyUI-Kolors-MZ](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ) | kolorsLoader |
## Credits
[ComfyUI](https://github.com/comfyanonymous/ComfyUI) - A powerful and modular Stable Diffusion GUI
[ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) - ComfyUI manager
[tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) - Pipe nodes (node bundles) that save users unnecessary connections
[ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) - diffus3's get/set nodes, which let users split a workflow into parts
[ComfyUI-Impact-Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack) - General node package 1
[ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) - General node package 2
[ComfyUI-Logic](https://github.com/theUpsider/ComfyUI-Logic) - Logic operations for ComfyUI
[ComfyUI-ResAdapter](https://github.com/jiaxiangc/ComfyUI-ResAdapter) - Lets models generate beyond their training resolution
[ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) - Style transfer
[ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) - Face transfer
[ComfyUI_PuLID](https://github.com/cubiq/PuLID_ComfyUI) - Face transfer
[ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) - pysssss scripts 🐍
[cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) - Image chooser
[ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet) - BrushNet inpainting nodes
[ComfyUI_ExtraModels](https://github.com/city96/ComfyUI_ExtraModels) - DiT-architecture nodes (PixArt, Hunyuan DiT, etc.)
## 免责声明
本开源项目及其内容按 “原样 ”提供,不作任何明示或暗示的保证,包括但不限于适销性、特定用途适用性和非侵权保证。在任何情况下,作者或其他版权所有者均不对因本软件或本软件的使用或其他交易而产生、引起或与之相关的任何索赔、损害或其他责任承担责任,无论是合同诉讼、侵权诉讼还是其他诉讼。
用户应自行负责确保在使用本软件或发布由本软件生成的内容时,遵守所在司法管辖区的所有适用法律和法规。作者和版权所有者不对用户在其各自所在地违反法律或法规的行为负责。
## ☕️ 投喂
**Comfyui-Easy-Use** 是一个 GPL 许可的开源项目。为了项目取得更好、可持续的发展,我希望能够获得更多的支持。 如果我的自定义节点为您的一天增添了价值,请考虑喝杯咖啡来进一步补充能量! 💖感谢您的支持,每一杯咖啡都是我创作的动力!
- [BiliBili充电](https://space.bilibili.com/1840885116)
- [Wechat/Alipay](https://github.com/user-attachments/assets/803469bd-ed6a-4fab-932d-50e5088a2d03)
感谢您的捐助,我将用这些费用来租用 GPU 或购买其他 GPT 服务,以便更好地调试和完善 ComfyUI-Easy-Use 功能
## 🌟大富大贵的人儿
我对那些慷慨的赐予一颗星的人表示感谢。非常感谢您的支持!
[![Stargazers repo roster for @yolain/ComfyUI-Easy-Use](https://reporoster.com/stars/yolain/ComfyUI-Easy-Use)](https://github.com/yolain/ComfyUI-Easy-Use/stargazers)

View File

@@ -0,0 +1,555 @@
![comfyui-easy-use](https://github.com/user-attachments/assets/9b7a5e44-f5e2-4c27-aed2-d0e6b50c46bb)
<div align="center">
<a href="https://space.bilibili.com/1840885116">Video Tutorial</a> |
<a href="https://docs.easyuse.yolain.com">Docs</a> |
<a href="https://github.com/yolain/ComfyUI-Yolain-Workflows">Workflow Collection</a> |
<a href="#%EF%B8%8F-donation">Donation</a>
<br><br>
<a href="./README.md"><img src="https://img.shields.io/badge/🇬🇧English-0b8cf5"></a>
<a href="./README.ZH_CN.md"><img src="https://img.shields.io/badge/🇨🇳中文简体-e9e9e9"></a>
</div>
**ComfyUI-Easy-Use** is an efficiency-focused custom node package extended from [TinyTerraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes). It integrates and optimizes many popular custom nodes so that ComfyUI can be used faster and more conveniently, preserving freedom of composition while restoring the smooth image-generation experience of Stable Diffusion.
## 👨🏻‍🎨 Introduce
- Inspired by [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes), which greatly reduces the time spent wiring up workflows.
- UI beautification: after installing for the first time, switch the theme in Settings -> Color Palette and refresh the page if you want to use the bundled UI theme.
- Added a node for pre-sampling parameter configuration, which can be separated from the sampling node for easier previewing
- Wildcards and LoRAs are supported; for LoRA Block Weight usage, make sure the [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) custom node package is installed
- Styled prompt selector with multi-selection support; the default is the Fooocus-style JSON. Custom JSON files can be placed under the styles folder, and preview images can be placed in the samples folder (image file names must match the style names, with spaces converted to underscores '_'); a minimal sketch of the expected JSON appears after this list
- The loader enables the A1111 prompt mode, which reproduces nearly identical images to those generated by webui.
- Noise injection into the latent space can be achieved using the `easy latentNoisy` or `easy preSamplingNoiseIn` node
- Simplified processes for SD1.x, SD2.x, SDXL, SVD, Zero123, etc. [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#StableDiffusion)
- Simplified Stable Cascade [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#StableCascade)
- Simplified Layer Diffuse [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#LayerDiffusion). The first time you use it, you may need to run `pip install -r requirements.txt` to install the required dependencies.
- Simplified InstantID [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#InstantID); make sure the [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) custom node package is installed
- Extending the usability of XYplot
- Fooocus Inpaint integration
- Integration of common logical calculations, conversion of types, display of all types, etc.
- Background removal node supporting BriaAI's RMBG-1.4 model, [BriaAI Guide](https://huggingface.co/briaai/RMBG-1.4)
- Forcibly clearing ComfyUI's model memory usage is supported
- Stable Diffusion 3 multi-account API nodes are supported
- Support SD3 models
- Support the Kolors model
- Support Flux models
- Support lazy if/else and for loops
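For the styles selector above, the JSON format follows the `your_styles.json.example` file that this package writes to the `styles` folder on first launch. A minimal sketch, assuming a hypothetical file name `my_styles.json` and placeholder style values:

```python
# Hypothetical helper that writes a custom style file for `easy stylesSelector`.
# The keys mirror the your_styles.json.example generated at startup; {prompt} is
# the placeholder replaced by your own prompt text (Fooocus style format).
import json
import os

styles = [
    {
        "name": "Example Style",                              # shown in the selector
        "name_cn": "示例样式",                                  # optional localized name
        "prompt": "(masterpiece), (best quality), {prompt}",
        "negative_prompt": "text, watermark, logo",
    },
]

os.makedirs("styles", exist_ok=True)
with open(os.path.join("styles", "my_styles.json"), "w", encoding="utf-8") as f:
    json.dump(styles, f, indent=4, ensure_ascii=False)
```

A preview image for "Example Style" would then go in the `styles/samples` folder (e.g. `Example_Style.jpg`, with spaces replaced by underscores).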
## 👨🏻‍🔧 Installation
Clone the repo into the **custom_nodes** directory and install the requirements:
```shell
#1. Clone the repo
git clone https://github.com/yolain/ComfyUI-Easy-Use
#2. Install the requirements
cd ComfyUI-Easy-Use && pip install -r requirements.txt
# (ComfyUI portable users on Windows can double-click install.bat instead)
```
## 📜 Changelog
**v1.3.6**
- Restored `easy showAnything` support for list types (but displaying large data in some cases may still cause ComfyUI to crash)
- Fix custom widgets to support subgraph and Nodes 2.0 #942
- Add `easy multiAngle` node
- Convert `prompt.py` to V3 Schema
- Fix `easy humanSegmentation` error
- Add `easy stringJoinLines`, `easy stringToIntList`, `easy simpleMath`
- Fix `easy ifElse` and `easy anythingIndexSwitch` failing in certain environments
**v1.3.5**
- Fix `isNone`
- Add `preview_rescale` to `easy imageChooser`
- Fix widget hidden #910
- Add max parameter to `wildcardsPromptMatrix` offset #909
- Fix title box style on subgraph node
- Add `remove_empty_lines` on `easy promptLine`
**v1.3.4**
- Fix `easy seedList` max_num #879
- Add controlnet input to xyplot #877
- Support negative indexing for `easy indexAnything`
**v1.3.3**
- Removed the definition of the CSS class name gird-cols-1 #859
- Fix lock seed not working in `easy promptAwait`
- Rename the nodes map
- Fix `easy imageChooser` output error type #845
**v1.3.2**
- Revamp `easy imageChooser` node to adapt frontend>=v1.24.2, solution referenced from [Comfyui_LG_Tools](https://github.com/LAOGOU-666/Comfyui_LG_Tools)
- Revamp `easy stylesSelector` node, and you can download [other styles files](https://github.com/yolain/EasyUse-Styles-Templates) to the `styles` folder
- Revamp `easy humanSegmentation` node
- Fix an `easy makeImageForICLora` issue that occurred when the heights of two images were the same during image stitching
- Add `easy joycaption3API` node
- Add `easy promptAwait` node
**v1.3.1**
- Rewrite drawNodeWidget and fix the GroupNode preview issue.
- Updated some features of XYPlot by [mekinney](https://github.com/mekinney)
- Add `easy seedList` node (It's useful for in loops)
**v1.3.0**
- Set loop nodes maximum number of inputs and outputs to 20
- Add `uniform width` method to `easy makeImageForICLora`
- Add `wildcardsPromptMatrix` Node by [Rosmeowtis](https://github.com/Rosmeowtis)
**v1.2.9**
- Fix ImageChooser causes workflow processing to cancel
- Fix brushnet tensor(640) error
- Fix widgets not hidden after v1.6.0 frontend
- Fix image chooser can not select images
- Fix contextMenu monkey patching affecting custom scripts (pysssss) nodes
**v1.2.8**
- Added the multi-language catalog
- Fix CLIP vision model download URLs for IPAdapter and DynamiCrafter
- Improve error handling for model downloads with clearer error messages and better handling of download failures
**v1.2.7**
- Optimize display of the node maps
- Added `ben2` on `easy imageRemBg`
- Use a new way to display model thumbnails in the loaders (supports diffusion_models, loras, checkpoints)
**v1.2.6**
- Fix the missing "Red Rect" styles shown when custom nodes are missing
- Adjust the default value of `clip_skip` from `-1` to `-2` in some easy loaders.
- Fix the canvas getting messed up when set nodes were connected to missing custom nodes
- Fix `easy imageChooser` not being usable in a loop
**v1.2.5**
- Added `enable (GPU=A1111)` noise mode on `easy preSamplingCustom` and `easy preSamplingAdvanced`
- Added `easy makeImageForICLora`
- Added `REGULAR - FLUX and SD3.5 only (high strength)` preset for InstantX Flux ipadapter on `easy ipadapterApply`
- Fix brushnet can not be used with startup arg `--fast` mode
- Support briaai RMBG-2.0
- Support mochi
- Implement reuse of end-node outputs in the loop body (e.g. previewImage, showAnything, etc.)
**v1.2.4**
- Added `easy imageSplitTiles` and `easy imageTilesFromBatch`
- Support `model_override`,`vae_override`,`clip_override` can be input separately to `easy fullLoader`
- Added `easy saveImageLazy`
- Added `easy loadImageForLoop`
- Added `easy isFileExist`
- Added `easy saveText`
**v1.2.3**
- `easy showAnything` and `easy cleanGPUUsed` added an output slot
- Added human parts segmentation to `easy humanSegmentation` - Code based on [ComfyUI_Human_Parts](https://github.com/metal3d/ComfyUI_Human_Parts)
- Using FluxGuidance when you are using a flux model and choose basicGuider and set the cfg>0 on `easy preSamplingCustom`
- Added `easy loraStackApply` and `easy controlnetStackApply` - Apply loraStack and controlnetStack
**v1.2.2**
- Added `easy batchAny`
- Added `easy anythingIndexSwitch`
- Added `easy forLoopStart` and `easy forLoopEnd`
- Added `easy ifElse`
- Added v2 web front-end code
- Added `easy fluxLoader`
- Added support for `controlnetApply`-related nodes with SD3 and hunyuanDiT
- Fixed all lora models becoming unusable after using `easy applyFooocusInpaint`
**v1.2.1**
- Added `easy ipadapterApplyFaceIDKolors`
- Added **inspyrenet** to `easy imageRemBg`
- Added `easy controlnetLoader++`
- Added **PLUS (kolors genernal)** and **FACEID PLUS KOLORS** presets to `easy ipadapterApply` and `easy ipadapterApplyADV` (supports the Kolors ipadapter)
- Added `easy kolorsLoader` - Code based on [MinusZoneAI](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ)'s and [kijai](https://github.com/kijai/ComfyUI-KwaiKolorsWrapper)'s repo, thanks for their contribution.
**v1.2.0**
- Added `easy pulIDApply` and `easy pulIDApplyADV`
- Added `easy huanyuanDiTLoader` and `easy pixArtLoader`
- Added **easy sliderControl** - Slider control node, which can currently be used to control the parameters of ipadapterMS (double-click the slider to reset to default)
- Added **layer_weights** in `easy ipadapterApplyADV`
**v1.1.9**
- Added **gitsScheduler**
- Added `easy imageBatchToImageList` and `easy imageListToImageBatch`
- Recursive subcategories nested for models
- Support for Stable Diffusion 3 model
- Added `easy applyInpaint` - All inpainting mode in this node
**v1.1.8**
- Added `easy controlnetStack`
- Added `easy applyBrushNet` - [Workflow Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4brushnet_1.1.8.json)
- Added `easy applyPowerPaint` - [Workflow Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4powerpaint_outpaint_1.1.8.json)
**v1.1.7**
- Added `easy prompt` - Subject and light presets, maybe adjusted later
- Added `easy icLightApply` - Light and shadow migration, Code based on [ComfyUI-IC-Light](https://github.com/huchenlei/ComfyUI-IC-Light)
- Added `easy imageSplitGrid`
- `easy kSamplerInpainting` added options such as differential diffusion and brushnet to the **additional** widget
- Support for brushnet model loading - [ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet)
- Added `easy applyFooocusInpaint` - Replace FooocusInpaintLoader
- Removed `easy fooocusInpaintLoader`
**v1.1.6**
- Added **alignYourSteps** to the **scheduler** widget in all `easy preSampling` and `easy fullkSampler` nodes
- Added **Preview&Choose** to **image_output** widget in `easy kSampler` & `easy fullkSampler`
- Added `easy styleAlignedBatchAlign` - Credit of [style_aligned_comfy](https://github.com/brianfitzgerald/style_aligned_comfy)
- Added `easy ckptNames`
- Added `easy controlnetNames`
- Added `easy imagesSplitimage` - Batch images split into single images
- Added `easy imageCount` - Get Image Count
- Added `easy textSwitch` - Text Switch
<details>
<summary><b>v1.1.5</b></summary>
- Rewrite `easy cleanGPUUsed` - ComfyUI's model memory usage can now be cleared
- Added `easy humanSegmentation` - Human Part Segmentation
- Added `easy imageColorMatch`
- Added `easy ipadapterApplyRegional`
- Added `easy ipadapterApplyFromParams`
- Added `easy imageInterrogator` - Image To Prompt
- Added `easy stableDiffusion3API` - Easy Stable Diffusion 3 Multiple accounts API Node
</details>
<details>
<summary><b>v1.1.4</b></summary>
- Added `easy preSamplingCustom` - Custom-PreSampling, can be supported cosXL-edit
- Added `easy ipadapterStyleComposition`
- Added the right-click menu to view checkpoints and lora information in all Loaders
- Fixed `easy preSamplingNoiseIn`, `easy latentNoisy`, `easy unSampler` to be compatible with ComfyUI revision >= 2098 [0542088e]
</details>
<details>
<summary><b>v1.1.3</b></summary>
- `easy ipadapterApply` Added **COMPOSITION** preset
- Supported [ResAdapter](https://huggingface.co/jiaxiangc/res-adapter) when loading a ResAdapter lora
- Added `easy promptLine`
- Added `easy promptReplace`
- Added `easy promptConcat`
- `easy wildcards` Added **multiline_mode**
</details>
<details>
<summary><b>v1.1.2</b></summary>
- Optimized some of the recommended nodes for slots related to EasyUse
- Added the **Enable ContextMenu Auto Nest Subdirectories** setting (enabled by default), which nests checkpoint and lora previews into subdirectories
- Added `easy sv3dLoader`
- Added `easy dynamiCrafterLoader`
- Added `easy ipadapterApply`
- Added `easy ipadapterApplyADV`
- Added `easy ipadapterApplyEncoder`
- Added `easy ipadapterApplyEmbeds`
- Added `easy preMaskDetailerFix`
- Fixed `easy stylesSelector` changing the prompt when no style is selected
- Fixed `easy pipeEdit` error when adding a lora to the prompt
- Fixed layerDiffuse xyplot bug
- `easy kSamplerInpainting` added the *additional* widget; you can choose 'Differential Diffusion' or 'Only InpaintModelConditioning'
</details>
<details>
<summary><b>v1.1.1</b></summary>
- Fixed the issue where the seed was 0 on the first queued prompt when a node with seed control was added and **control before generate** was set
- `easy preSamplingAdvanced` added **return_with_leftover_noise**
- Fixed `easy stylesSelector` error when choosing a custom file
- `easy preSamplingLayerDiffusion` added an optional mask input parameter
- Renamed the widget named seed_num to seed on all nodes
- Removed the forced **control_before_generate** setting. If you want to use control_before_generate, change widget_value_control_mode to before in the system settings
- Added `easy imageRemBg` - the default is BriaAI's RMBG-1.4 model, which removes backgrounds more effectively and faster
</details>
<details>
<summary><b>v1.1.0</b></summary>
- Added `easy imageSplitList` - to split every N images
- Added `easy preSamplingDiffusionADDTL` - it can modify the foreground, background, or blended additional prompt
- Added `easy preSamplingNoiseIn` - it can replace the `easy latentNoisy` node that previously had to be placed upstream, achieving better noise injection
- `easy pipeEdit` Added conditioning splicing mode selection, you can choose to replace, concat, combine, average, and set timestep range
- Added `easy pipeEdit` - nodes that can edit pipes (including re-enterable prompts)
- Added `easy preSamplingLayerDiffusion` and `easy kSamplerLayerDiffusion`
- Added a convenient menu to right-click on nodes such as Loader, Presampler, Sampler, Controlnet, etc. to quickly replace nodes of the same type
- Added `easy instantIDApplyADV` can link positive and negative
- Fixed layerDiffusion error when batch size greater than 1
- Fixed `easy wildcards` not automatically retrieving a LoRA when its name was not fully specified, which caused the LoRA to fail to load
- Fixed 'BREAK' not taking effect when the a1111 prompt style was not used
- Fixed `easy instantIDApply` mask not being passed in correctly
</details>
<details>
<summary><b>v1.0.9</b></summary>
- Fixed the error when ComfyUI-Impact-Pack and ComfyUI_InstantID were not installed
- Fixed `easy pipeIn`
- Added `easy instantIDApply` - you need to install [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) first, Workflow [Example](https://github.com/yolain/ComfyUI-Easy-Use/blob/main/README.en.md#InstantID)
- Fixed `easy detailerFix` not added to the list of nodes available for saving images formatting extensions
- Fixed `easy XYInputs: PromptSR` errors reported when replacing negative prompts
</details>
<details>
<summary><b>v1.0.8</b></summary>
- `easy cascadeLoader` stage_c and stage_b support the checkpoint model (Download [checkpoints](https://huggingface.co/stabilityai/stable-cascade/tree/main/comfyui_checkpoints) models)
- `easy styleSelector` - the search box is now case-insensitive
- `easy fullLoader` - **positive**, **negative**, and **latent** added to the output items
- Fixed the issue where 'easy preSampling' and similar nodes could not generate the latent based on the batch index after it was passed in
- Fixed `easy svdLoader` error when the positive or negative is empty
- Fixed the error of SDXLClipModel in ComfyUI revision 2016[c2cb8e88] and above (the revision number was judged to be compatible with the old revision)
- Fixed `easy detailerFix` generation error when batch size is greater than 1
- Optimize the code, reduce a lot of redundant code and improve the running speed
</details>
<details>
<summary><b>v1.0.7</b></summary>
- Added `easy cascadeLoader` - stable cascade Loader
- Added `easy preSamplingCascade` - stable cascade preSampling Settings
- Added `easy fullCascadeKSampler` - stable cascade stage-c ksampler full
- Added `easy cascadeKSampler` - stable cascade stage-c ksampler simple
- Optimize the image to image[Example](https://github.com/yolain/ComfyUI-Easy-Use/blob/main/README.en.md#image-to-image)
</details>
<details>
<summary><b>v1.0.6</b></summary>
- Added `easy XYInputs: Checkpoint`
- Added `easy XYInputs: Lora`
- `easy seed` can manually switch the random seed when increasing the fixed seed value
- Fixed `easy fullLoader` and all loaders to automatically adjust the node size when switching LoRa
- Removed the original ttn image saving logic and adapted to the default image saving format extension of ComfyUI
</details>
<details>
<summary><b>v1.0.5</b></summary>
- Added `easy isSDXL`
- Added prompt word control on `easy svdLoader`, which can be used with open_clip model
- Added **populated_text** on `easy wildcards`, wildcard populated text can be output
</details>
<details>
<summary><b>v1.0.4</b></summary>
- `easy showAnything` added support for converting other types (e.g., tensor conditions, images, etc.)
- Added `easy showLoaderSettingsNames` can display the model and VAE name in the output loader assembly
- Added `easy promptList`
- Added `easy fooocusInpaintLoader` only the process of SDXLModel is supported
- Added **Logic** nodes
- Added `easy imageSave` - image saving node with date conversion and width/height formatting
- Added `easy joinImageBatch`
- `easy kSamplerInpainting` Added the **patch** input value to be used with the FooocusInpaintLoader node
- Fixed xyplot error when with Pillow>9.5
- Fixed `easy wildcards` reporting an error when running with the PS extension
- Fixed `easy XYInputs: ControlNet` Error
- Fixed `easy loraStack` error when **toggle** is disabled
- Installing the node package for the first time no longer automatically replaces the theme; you need to adjust it manually and refresh the page
- `easy imageSave` added **only_preivew**
- Adjust the `easy latentCompositeMaskedWithCond` node
</details>
<details>
<summary><b>v1.0.3</b></summary>
- Added `easy stylesSelector`
- Added **scale_soft_weights** in `easy controlnetLoader` and `easy controlnetLoaderADV`
- Added the queue progress bar setting item, which is not enabled by default
- Fixed `easy XYInputs: Sampler/Scheduler` Error
- Fixed the right menu has a problem when clicking the button
- Fixed `easy comfyLoader` error
- Fixed xyPlot error when connecting to zero123
- Fixed the error message in the loader when the prompt word was component
- Fixed `easy getNode` and `easy setNode` the title does not change when loading
- Fixed all samplers using subdirectories to store images
- Adjust the UI theme, divided into two sets of styles: the official default background and the dark black background, which can be switched in the color palette in the settings
- Modify the styles path to be compatible with other environments
</details>
<details>
<summary><b>v1.0.2</b></summary>
- Added `easy XYPlotAdvanced` and some nodes about `easy XYInputs`
- Added **Alt+1-Alt+9** Shortcut keys to quickly paste node presets for Node templates (corresponding to 1~9 sequences)
- Added a `📜Groups Map(EasyUse)` to the context menu.
- An `autocomplete` folder has been added, If you have [ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) installed, the txt files in that folder will be merged and overwritten to the autocomplete .txt file of the pyssss package at startup.
- Fixed XYPlot is not working when `a1111_prompt_style` is True
- Fixed UI loading failure in the new version of ComfyUI
- `easy XYInputs ModelMergeBlocks` Values can be imported from CSV files
- Fixed `easy pipeToBasicPipe` Bug
- Removed `easy imageRemBg`
- Remove the introductory diagram and workflow files from the package to reduce the package size
- Replaced the font file used in the generation of XY diagrams
</details>
<details>
<summary><b>v1.0.1</b></summary>
- Fixed `easy comfyLoader` error
- Fixed All nodes that contain the value of the image size
- Added `easy kSamplerInpainting`
- Added `easy pipeToBasicPipe`
- Fixed `width` and `height` can not customize in `easy svdLoader`
- Fixed all preview image path (Previously, it was not possible to preview the image on the Mac system)
- Fixed `vae_name` is not working in `easy fullLoader` and `easy a1111Loader` and `easy comfyLoader`
- Fixed `easy fullkSampler` outputs error
- Fixed `model_override` is not working in `easy fullLoader`
- Fixed `easy hiresFix` error
- Fixed `easy xyplot` font file path error
- Fixed seed that cannot be fixed when you convert `seed_num` to `easy seed`
- Fixed `easy pipeIn` inputs bug
- `easy preDetailerFix` has added a new parameter `optional_image`
- Fixed `easy zero123Loader` and `easy svdLoader` model caching
- Added `easy seed`
- Fixed `image_output` default value is "Preview"
- `easy fullLoader` and `easy a1111Loader` have added a new parameter `a1111_prompt_style` that can reproduce the same image generated by stable-diffusion-webui in ComfyUI, but you need to install [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) to use this feature in the current version
</details>
<details>
<summary><b>v1.0.0</b></summary>
- Added `easy positive` - simple positive prompt text
- Added `easy negative` - simple negative prompt text
- Added `easy wildcards` - support for wildcards and hint text selected by Lora
- Added `easy portraitMaster` - PortraitMaster v2.2
- Added `easy loraStack` - Lora stack
- Added `easy fullLoader` - full version of the loader
- Added `easy zero123Loader` - simple zero123 loader
- Added `easy svdLoader` - easy svd loader
- Added `easy fullkSampler` - full version of the sampler (no separation)
- Added `easy hiresFix` - support for HD repair of Pipe
- Added `easy predetailerFix` and `easy DetailerFix` - support for Pipe detail fixing
- Added `easy ultralyticsDetectorPipe` and `easy samLoaderPipe` - Detect loader (detail fixed input)
- Added `easy pipein` `easy pipeout` - Pipe input and output
- Added `easy xyPlot` - simple xyplot (more controllable parameters will be updated in the future)
- Added `easy imageRemoveBG` - remove image background
- Added `easy imagePixelPerfect` - image pixel perfect
- Added `easy poseEditor` - Pose editor
- New UI Theme (Obsidian) - Auto-load UI by default, which can also be changed in the settings
- Fixed `easy globalSeed` is not working
- Fixed an issue where all `seed_num` values were out of order due to [cg-use-everywhere](https://github.com/chrisgoringe/cg-use-everywhere) updating the chart in real time
- Fixed `easy imageSize`, `easy imageSizeBySide`, `easy imageSizeByLongerSide` as end nodes
- Fixed the bug that `seed_num` (random seed value) could not be read consistently in history
</details>
<details>
<summary><b>Updated at 12/14/2023</b></summary>
- `easy a1111Loader` and `easy comfyLoader` added `batch_size` as a required input parameter
- Added the `easy controlnetLoaderADV` node
- `easy controlnetLoaderADV` and `easy controlnetLoader` added `control_net` as an optional input parameter
- `easy preSampling` and `easy preSamplingAdvanced` added `image_to_latent` as an optional input parameter
- Added the `easy imageSizeBySide` node, which can be output as a long side or a short side
</details>
<details>
<summary><b>Updated at 12/13/2023</b></summary>
- Added the `easy LLLiteLoader` node, if you have pre-installed the kohya-ss/ControlNet-LLLite-ComfyUI package, please move the model files in the models to `ComfyUI\models\controlnet\` (i.e. in the default controlnet path of comfy, please do not change the file name of the model, otherwise it will not be read).
- Modify `easy controlnetLoader` to the bottom of the loader category.
- Added size display for `easy imageSize` and `easy imageSizeByLongerSide` outputs.
</details>
<details>
<summary><b>Updated at 12/11/2023</b></summary>
- Added the `showSpentTime` node to display the time spent on image diffusion and the time spent on VAE decoding images
</details>
## The relevant node package involved
Disclaimer: Open source is not easy. I have a lot of respect for the contributions of these original authors; I just did some integration and optimization.
| Nodes Name(Search Name) | Related libraries | Library-related node |
|:-------------------------------|:----------------------------------------------------------------------------|:-------------------------|
| easy setNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.SetNode |
| easy getNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.GetNode |
| easy bookmark | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | Bookmark 🔖 |
| easy portraitMarker | [comfyui-portrait-master](https://github.com/florestefano1975/comfyui-portrait-master) | Portrait Master |
| easy LLLiteLoader | [ControlNet-LLLite-ComfyUI](https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI) | LLLiteLoader |
| easy globalSeed | [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) | Global Seed (Inspire) |
| easy preSamplingDynamicCFG | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
| dynamicThresholdingFull | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
| easy imageInsetCrop | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | ImageInsetCrop |
| easy poseEditor | [ComfyUI_Custom_Nodes_AlekPet](https://github.com/AlekPet/ComfyUI_Custom_Nodes_AlekPet) | poseNode |
| easy preSamplingLayerDiffusion | [ComfyUI-layerdiffusion](https://github.com/huchenlei/ComfyUI-layerdiffusion) | LayeredDiffusionApply... |
| easy dynamiCrafterLoader       | [ComfyUI_Native_DynamiCrafter](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter) | Apply Dynamicrafter |
| easy imageChooser | [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) | Preview Chooser |
| easy styleAlignedBatchAlign    | [style_aligned_comfy](https://github.com/brianfitzgerald/style_aligned_comfy) | styleAlignedBatchAlign |
| easy kolorsLoader | [ComfyUI-Kolors-MZ](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ) | kolorsLoader |
## Credits
[ComfyUI](https://github.com/comfyanonymous/ComfyUI) - Powerful and modular Stable Diffusion GUI
[ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) - ComfyUI Manager
[tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) - Pipe nodes (node bundles) allow users to reduce unnecessary connections
[ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) - Diffus3 gets and sets points that allow the user to detach the composition of the workflow
[ComfyUI-Impact-Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack) - General modpack 1
[ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) - General Modpack 2
[ComfyUI-ResAdapter](https://github.com/jiaxiangc/ComfyUI-ResAdapter) - Make model generation independent of training resolution
[ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) - Style migration
[ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) - Face migration
[ComfyUI_PuLID](https://github.com/cubiq/PuLID_ComfyUI) - Face migration
[ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) - pyssss🐍
[cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) - Image Preview Chooser
[ComfyUI_ExtraModels](https://github.com/city96/ComfyUI_ExtraModels) - DiT custom nodes
## Disclaimer
This software is provided “as is,” without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and non-infringement. In no event shall the authors or copyright holders be liable for any claim, damages, or other liability, whether in an action of contract, tort, or otherwise, arising from, out of, or in connection with the software or the use or other dealings in the software.
Users are solely responsible for ensuring that their use of this software complies with all applicable laws and regulations in the jurisdiction where they use the software or publish content generated by it. The authors and copyright holders are not responsible for any violations of laws or regulations by users in their respective locations.
## ☕️ Donation
**Comfyui-Easy-Use** is a GPL-licensed open source project. To achieve better and sustainable development of the project, I hope to gain more backers. <br>
If my custom nodes has added value to your day, consider indulging in a coffee to fuel it further! <br>
💖You can support me in any of the following ways:
- [BiliBili](https://space.bilibili.com/1840885116)
- [Wechat / Alipay](https://github.com/user-attachments/assets/803469bd-ed6a-4fab-932d-50e5088a2d03)
## 🌟Stargazers
My gratitude extends to the generous souls who bestow a star. Your support is much appreciated!
[![Stargazers repo roster for @yolain/ComfyUI-Easy-Use](https://reporoster.com/stars/yolain/ComfyUI-Easy-Use)](https://github.com/yolain/ComfyUI-Easy-Use/stargazers)

View File

@@ -0,0 +1,95 @@
__version__ = "1.3.7"
import yaml
import json
import os
import folder_paths
import importlib
cwd_path = os.path.dirname(os.path.realpath(__file__))
comfy_path = folder_paths.base_path
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
importlib.import_module('.py.routes', __name__)
importlib.import_module('.py.server', __name__)
nodes_list = ["util", "seed", "prompt", "loaders", "adapter", "inpaint", "preSampling", "samplers", "fix", "pipe", "xyplot", "image", "logic", "api", "deprecated"]
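# Each submodule under py/nodes exposes its own NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS;
# the loop below merges them into the package-level dicts that ComfyUI reads on import.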
for module_name in nodes_list:
    imported_module = importlib.import_module(".py.nodes.{}".format(module_name), __name__)
    NODE_CLASS_MAPPINGS = {**NODE_CLASS_MAPPINGS, **imported_module.NODE_CLASS_MAPPINGS}
    NODE_DISPLAY_NAME_MAPPINGS = {**NODE_DISPLAY_NAME_MAPPINGS, **imported_module.NODE_DISPLAY_NAME_MAPPINGS}
#Wildcards
from .py.libs.wildcards import read_wildcard_dict
wildcards_path = os.path.join(os.path.dirname(__file__), "wildcards")
if not os.path.exists(wildcards_path):
    os.mkdir(wildcards_path)
# Add custom wildcards example
example_path = os.path.join(wildcards_path, "example.txt")
if not os.path.exists(example_path):
    with open(example_path, 'w') as f:
        text = "blue\nred\nyellow\ngreen\nbrown\npink\npurple\norange\nblack\nwhite"
        f.write(text)
read_wildcard_dict(wildcards_path)
#Styles
styles_path = os.path.join(os.path.dirname(__file__), "styles")
samples_path = os.path.join(os.path.dirname(__file__), "styles", "samples")
if os.path.exists(styles_path):
    if not os.path.exists(samples_path):
        os.mkdir(samples_path)
else:
    os.mkdir(styles_path)
    os.mkdir(samples_path)
# Add custom styles example
example_path = os.path.join(styles_path, "your_styles.json.example")
if not os.path.exists(example_path):
    import json
    data = [
        {
            "name": "Example Style",
            "name_cn": "示例样式",
            "prompt": "(masterpiece), (best quality), (ultra-detailed), {prompt} ",
            "negative_prompt": "text, watermark, logo"
        },
    ]
    # Write to file
    with open(example_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
web_default_version = 'v2'
# web directory
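# If config.yaml exists and contains a WEB_VERSION key (e.g. a file containing just "WEB_VERSION: v1"),
# that web_version/<name> directory is served; otherwise the default version above is written back
# to config.yaml and used, falling back to the default when the chosen directory is missing.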
config_path = os.path.join(cwd_path, "config.yaml")
if os.path.isfile(config_path):
    with open(config_path, 'r') as f:
        data = yaml.load(f, Loader=yaml.FullLoader)
    if data and "WEB_VERSION" in data:
        directory = f"web_version/{data['WEB_VERSION']}"
        with open(config_path, 'w') as f:
            yaml.dump(data, f)
    elif web_default_version != 'v1':
        if not data:
            data = {'WEB_VERSION': web_default_version}
        elif 'WEB_VERSION' not in data:
            data = {**data, 'WEB_VERSION': web_default_version}
        with open(config_path, 'w') as f:
            yaml.dump(data, f)
        directory = f"web_version/{web_default_version}"
    else:
        directory = f"web_version/v1"
    if not os.path.exists(os.path.join(cwd_path, directory)):
        print(f"web root {data['WEB_VERSION']} not found, using default")
        directory = f"web_version/{web_default_version}"
    WEB_DIRECTORY = directory
else:
    directory = f"web_version/{web_default_version}"
    WEB_DIRECTORY = directory
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', "WEB_DIRECTORY"]
print(f'\033[34m[ComfyUI-Easy-Use] server: \033[0mv{__version__} \033[92mLoaded\033[0m')
print(f'\033[34m[ComfyUI-Easy-Use] web root: \033[0m{os.path.join(cwd_path, directory)} \033[92mLoaded\033[0m')

View File

@@ -0,0 +1,26 @@
@echo off
set "requirements_txt=%~dp0\requirements.txt"
set "requirements_repair_txt=%~dp0\repair_dependency_list.txt"
set "python_exec=..\..\..\python_embeded\python.exe"
set "aki_python_exec=..\..\python\python.exe"
echo Installing EasyUse Requirements...
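REM Prefer the ComfyUI portable embedded Python, then the Aki launcher Python,
REM and fall back to the system pip when neither interpreter is found.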
if exist "%python_exec%" (
echo Installing with ComfyUI Portable
"%python_exec%" -s -m pip install -r "%requirements_txt%"
)^
else if exist "%aki_python_exec%" (
echo Installing with ComfyUI Aki
"%aki_python_exec%" -s -m pip install -r "%requirements_txt%"
for /f "delims=" %%i in (%requirements_repair_txt%) do (
%aki_python_exec% -s -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple "%%i"
)
)^
else (
echo Installing with system Python
pip install -r "%requirements_txt%"
)
pause

View File

@@ -0,0 +1,24 @@
#!/bin/bash
requirements_txt="$(dirname "$0")/requirements.txt"
requirements_repair_txt="$(dirname "$0")/repair_dependency_list.txt"
python_exec="../../../python_embeded/python.exe"
aki_python_exec="../../python/python.exe"
echo "Installing EasyUse Requirements..."
if [ -f "$python_exec" ]; then
echo "Installing with ComfyUI Portable"
"$python_exec" -s -m pip install -r "$requirements_txt"
elif [ -f "$aki_python_exec" ]; then
echo "Installing with ComfyUI Aki"
"$aki_python_exec" -s -m pip install -r "$requirements_txt"
while IFS= read -r line; do
"$aki_python_exec" -s -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple "$line"
done < "$requirements_repair_txt"
else
echo "Installing with system Python"
pip install -r "$requirements_txt"
fi
read -p "Press any key to continue..."

View File

@@ -0,0 +1,31 @@
{
"settingsCategories": {
"Hotkeys": "Hotkeys",
"Nodes": "Nodes",
"NodesMap": "NodesMap",
"StylesSelector": "StylesSelector"
},
"nodeCategories": {
"Util": "Util",
"Seed": "Seed",
"Prompt": "Prompt",
"Loaders": "Loaders",
"Adapter": "Adapter",
"Inpaint": "Inpaint",
"PreSampling": "PreSampling",
"Sampler": "Sampler",
"Fix": "Fix",
"Pipe": "Pipe",
"XY Inputs": "XY Inputs",
"Image": "Image",
"Segmentation": "Segmentation",
"\uD83D\uDEAB Deprecated": "\uD83D\uDEAB Deprecated",
"Type": "Type",
"Math": "Math",
"Switch": "Switch",
"Index Switch": "Index Switch",
"While Loop": "While Loop",
"For Loop": "For Loop",
"LoadImage": "Load Image"
}
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,67 @@
{
"EasyUse_Hotkeys_AddGroup": {
"name": "Enable Shift+g to add the selected nodes to a group",
"tooltip": "From v1.2.39, you can use Ctrl+g instead"
},
"EasyUse_Hotkeys_cleanVRAMUsed": {
"name": "Enable Shift+r to unload model and node cache"
},
"EasyUse_Hotkeys_toggleNodesMap": {
"name": "Enable Shift+m to toggle nodes map"
},
"EasyUse_Hotkeys_AlignSelectedNodes": {
"name": "Enable Shift+Up/Down/Left/Right and Shift+Ctrl+Alt+Left/Right to align selected nodes",
"tooltip": "Shift+Up/Down/Left/Right can align selected nodes, Shift+Ctrl+Alt+Left/Right can distribute nodes horizontally/vertically"
},
"EasyUse_Hotkeys_NormalizeSelectedNodes": {
"name": "Enable Shift+Ctrl+Left/Right to normalize selected nodes",
"tooltip": "Enable Shift+Ctrl+Left to normalize width and Shift+Ctrl+Right to normalize height"
},
"EasyUse_Hotkeys_NodesTemplate": {
"name": "Enable Alt+1~9 to paste node templates into the workflow"
},
"EasyUse_Hotkeys_JumpNearestNodes": {
"name": "Enable Up/Down/Left/Right to jump to the nearest node"
},
"EasyUse_ContextMenu_SubDirectories": {
"name": "Enable automatic nesting of subdirectories in the context menu"
},
"EasyUse_ContextMenu_ModelsThumbnails": {
"name": "Enable model preview thumbnails"
},
"EasyUse_ContextMenu_NodesSort": {
"name": "Enable A~Z sorting of new nodes in the context menu"
},
"EasyUse_ContextMenu_QuickOptions": {
"name": "Use three quick buttons in the context menu",
"options": {
"At the forefront": "At the forefront",
"At the end": "At the end",
"Disable": "Disable"
}
},
"EasyUse_Nodes_Runtime": {
"name": "Enable node runtime display"
},
"EasyUse_Nodes_ChainGetSet": {
"name": "Enable chaining of get and set points with the parent node"
},
"EasyUse_NodesMap_Sorting": {
"name": "Manage nodes group sorting mode",
"tooltip": "Automatically sort by default. If set to manual, groups can be drag and dropped and the order will be saved.",
"options": {
"Auto sorting": "Auto sorting",
"Manual drag&drop sorting": "Manual drag&drop sorting"
}
},
"EasyUse_NodesMap_DisplayNodeID": {
"name": "Enable node ID display"
},
"EasyUse_NodesMap_DisplayGroupOnly": {
"name": "Show groups only"
},
"EasyUse_NodesMap_Enable": {
"name": "Enable Group Map",
"tooltip": "You need to refresh the page to update successfully"
}
}

View File

@@ -0,0 +1,30 @@
{
"settingsCategories": {
"Hotkeys": "Raccourcis",
"Nodes": "Nœuds",
"NodesMap": "Carte des nœuds"
},
"nodeCategories": {
"Util": "Utilitaire",
"Seed": "Graine",
"Prompt": "Prompt",
"Loaders": "Chargeurs",
"Adapter": "Adaptateur",
"Inpaint": "Retouche",
"PreSampling": "Pré-échantillonnage",
"Sampler": "Échantillonneur",
"Fix": "Correction",
"Pipe": "Pipeline",
"XY Inputs": "Entrées XY",
"Image": "Image",
"Segmentation": "Segmentation",
"\uD83D\uDEAB Deprecated": "\uD83D\uDEAB Obsolète",
"Type": "Type",
"Math": "Mathématiques",
"Switch": "Interrupteur",
"Index Switch": "Interrupteur d'index",
"While Loop": "Boucle While",
"For Loop": "Boucle For",
"LoadImage": "Charger l'image"
}
}

View File

@@ -0,0 +1,67 @@
{
"EasyUse_Hotkeys_AddGroup": {
"name": "Activer Shift+g pour ajouter les nœuds sélectionnés à un groupe",
"tooltip": "Depuis la v1.2.39, vous pouvez utiliser Ctrl+g à la place"
},
"EasyUse_Hotkeys_cleanVRAMUsed": {
"name": "Activer Shift+r pour décharger le cache du modèle et des nœuds"
},
"EasyUse_Hotkeys_toggleNodesMap": {
"name": "Activer Shift+m pour basculer la carte des nœuds"
},
"EasyUse_Hotkeys_AlignSelectedNodes": {
"name": "Activer Shift+Up/Down/Left/Right et Shift+Ctrl+Alt+Left/Right pour aligner les nœuds sélectionnés",
"tooltip": "Shift+Up/Down/Left/Right peut aligner les nœuds sélectionnés, Shift+Ctrl+Alt+Left/Right peut les répartir horizontalement/verticalement"
},
"EasyUse_Hotkeys_NormalizeSelectedNodes": {
"name": "Activer Shift+Ctrl+Left/Right pour normaliser les nœuds sélectionnés",
"tooltip": "Activer Shift+Ctrl+Left pour normaliser la largeur et Shift+Ctrl+Right pour normaliser la hauteur"
},
"EasyUse_Hotkeys_NodesTemplate": {
"name": "Activer Alt+1~9 pour coller les modèles de nœuds dans le workflow"
},
"EasyUse_Hotkeys_JumpNearestNodes": {
"name": "Activer Up/Down/Left/Right pour passer au nœud le plus proche"
},
"EasyUse_ContextMenu_SubDirectories": {
"name": "Activer l'imbrication automatique des sous-répertoires dans le menu contextuel"
},
"EasyUse_ContextMenu_ModelsThumbnails": {
"name": "Activer les vignettes d'aperçu du modèle"
},
"EasyUse_ContextMenu_NodesSort": {
"name": "Activer le tri A~Z des nouveaux nœuds dans le menu contextuel"
},
"EasyUse_ContextMenu_QuickOptions": {
"name": "Utiliser trois boutons rapides dans le menu contextuel",
"options": {
"At the forefront": "À l'avant-plan",
"At the end": "À la fin",
"Disable": "Désactiver"
}
},
"EasyUse_Nodes_Runtime": {
"name": "Activer l'affichage du temps d'exécution des nœuds"
},
"EasyUse_Nodes_ChainGetSet": {
"name": "Activer le chaînage des points get et set avec le nœud parent"
},
"EasyUse_NodesMap_Sorting": {
"name": "Gérer le mode de tri des groupes de nœuds",
"tooltip": "Tri automatique par défaut. Si défini sur manuel, les groupes peuvent être glissés-déposés et l'ordre sera sauvegardé.",
"options": {
"Auto sorting": "Tri automatique",
"Manual drag&drop sorting": "Tri manuel par glisser-déposer"
}
},
"EasyUse_NodesMap_DisplayNodeID": {
"name": "Activer l'affichage de l'ID du nœud"
},
"EasyUse_NodesMap_DisplayGroupOnly": {
"name": "Afficher uniquement les groupes"
},
"EasyUse_NodesMap_Enable": {
"name": "Activer la carte des groupes",
"tooltip": "Vous devez actualiser la page pour mettre à jour"
}
}

View File

@@ -0,0 +1,30 @@
{
"settingsCategories": {
"Hotkeys": "ショートカットキー",
"Nodes": "ノード",
"NodesMap": "ノードマップ"
},
"nodeCategories": {
"Util": "ユーティリティ",
"Seed": "シード",
"Prompt": "プロンプト",
"Loaders": "ローダー",
"Adapter": "アダプター",
"Inpaint": "インペイント",
"PreSampling": "プリサンプリング",
"Sampler": "サンプラー",
"Fix": "フィックス",
"Pipe": "パイプ",
"XY Inputs": "XY入力",
"Image": "画像",
"Segmentation": "セグメンテーション",
"\uD83D\uDEAB Deprecated": "🚫 非推奨",
"Type": "タイプ",
"Math": "数学",
"Switch": "スイッチ",
"Index Switch": "インデックススイッチ",
"While Loop": "Whileループ",
"For Loop": "Forループ",
"LoadImage": "画像読み込み"
}
}

View File

@@ -0,0 +1,67 @@
{
"EasyUse_Hotkeys_AddGroup": {
"name": "Shift+gを使用して選択したードをグループに追加する",
"tooltip": "v1.2.39以降、Ctrl+gが使用できます"
},
"EasyUse_Hotkeys_cleanVRAMUsed": {
"name": "Shift+rを使用してモデルおよびードキャッシュをアンロードする"
},
"EasyUse_Hotkeys_toggleNodesMap": {
"name": "Shift+mを使用してードマップを表示/非表示にします"
},
"EasyUse_Hotkeys_AlignSelectedNodes": {
"name": "Shift+上/下/左/右およびShift+Ctrl+Alt+左/右を使用して選択したノードを整列する",
"tooltip": "Shift+上/下/左/右で選択したードを整列し、Shift+Ctrl+Alt+左/右で水平方向/垂直方向に分布させる"
},
"EasyUse_Hotkeys_NormalizeSelectedNodes": {
"name": "Shift+Ctrl+左/右を使用して選択したノードのサイズを正規化する",
"tooltip": "Shift+Ctrl+左で幅を、Shift+Ctrl+右で高さを正規化する"
},
"EasyUse_Hotkeys_NodesTemplate": {
"name": "Alt+1~9を使用してワークフローにードテンプレートを貼り付ける"
},
"EasyUse_Hotkeys_JumpNearestNodes": {
"name": "上/下/左/右を使用して最も近いノードにジャンプする"
},
"EasyUse_ContextMenu_SubDirectories": {
"name": "コンテキストメニューでサブディレクトリを自動でネストする"
},
"EasyUse_ContextMenu_ModelsThumbnails": {
"name": "モデルプレビューサムネイルを有効にする"
},
"EasyUse_ContextMenu_NodesSort": {
"name": "コンテキストメニューで新規ードをA~Z順に並べ替える"
},
"EasyUse_ContextMenu_QuickOptions": {
"name": "コンテキストメニューで3つのクイックボタンを使用する",
"options": {
"At the forefront": "最前面に",
"At the end": "最後に",
"Disable": "無効"
}
},
"EasyUse_Nodes_Runtime": {
"name": "ノードの実行時間表示を有効にする"
},
"EasyUse_Nodes_ChainGetSet": {
"name": "親ノードと取得/設定ポイントを連結することを有効にする"
},
"EasyUse_NodesMap_Sorting": {
"name": "ノードグループの並べ替えモードを管理する",
"tooltip": "デフォルトで自動的に並べ替えます。マニュアルに設定した場合、グループをドラッグアンドドロップで並べ替え、順序が保存されます。",
"options": {
"Auto sorting": "自動並べ替え",
"Manual drag&drop sorting": "手動ドラッグアンドドロップによる並べ替え"
}
},
"EasyUse_NodesMap_DisplayNodeID": {
"name": "ードIDの表示を有効にする"
},
"EasyUse_NodesMap_DisplayGroupOnly": {
"name": "グループのみ表示する"
},
"EasyUse_NodesMap_Enable": {
"name": "グループマップを有効にする",
"tooltip": "ページを更新する必要があります"
}
}

View File

@@ -0,0 +1,30 @@
{
"settingsCategories": {
"Hotkeys": "단축키",
"Nodes": "노드",
"NodesMap": "노드 맵"
},
"nodeCategories": {
"Util": "유틸",
"Seed": "시드",
"Prompt": "프롬프트",
"Loaders": "로더",
"Adapter": "어댑터",
"Inpaint": "인페인트",
"PreSampling": "사전 샘플링",
"Sampler": "샘플러",
"Fix": "픽스",
"Pipe": "파이프",
"XY Inputs": "XY 입력",
"Image": "이미지",
"Segmentation": "분할",
"\uD83D\uDEAB Deprecated": "\uD83D\uDEAB 사용 중단",
"Type": "유형",
"Math": "수학",
"Switch": "스위치",
"Index Switch": "인덱스 스위치",
"While Loop": "while 루프",
"For Loop": "for 루프",
"LoadImage": "이미지 로드"
}
}

View File

@@ -0,0 +1,67 @@
{
"EasyUse_Hotkeys_AddGroup": {
"name": "Shift+g 를 사용하여 선택된 노드를 그룹에 추가합니다",
"tooltip": "v1.2.39부터는 Ctrl+g 를 사용할 수 있습니다"
},
"EasyUse_Hotkeys_cleanVRAMUsed": {
"name": "Shift+r 를 사용하여 모델 및 노드 캐시를 언로드합니다"
},
"EasyUse_Hotkeys_toggleNodesMap": {
"name": "Shift+m 를 사용하여 노드 맵을 전환합니다"
},
"EasyUse_Hotkeys_AlignSelectedNodes": {
"name": "Shift+Up/Down/Left/Right 와 Shift+Ctrl+Alt+Left/Right 를 사용하여 선택된 노드를 정렬합니다",
"tooltip": "Shift+Up/Down/Left/Right 는 선택된 노드를 정렬하며, Shift+Ctrl+Alt+Left/Right 는 노드를 수평/수직으로 분배합니다"
},
"EasyUse_Hotkeys_NormalizeSelectedNodes": {
"name": "Shift+Ctrl+Left/Right 를 사용하여 선택된 노드를 정규화합니다",
"tooltip": "Shift+Ctrl+Left 는 너비를, Shift+Ctrl+Right 는 높이를 정규화합니다"
},
"EasyUse_Hotkeys_NodesTemplate": {
"name": "Alt+1~9 를 사용하여 워크플로우에 노드 템플릿을 붙여넣습니다"
},
"EasyUse_Hotkeys_JumpNearestNodes": {
"name": "Up/Down/Left/Right 를 사용하여 가장 가까운 노드로 이동합니다"
},
"EasyUse_ContextMenu_SubDirectories": {
"name": "컨텍스트 메뉴에서 자동으로 하위 디렉토리를 중첩합니다"
},
"EasyUse_ContextMenu_ModelsThumbnails": {
"name": "모델 미리보기 썸네일을 활성화합니다"
},
"EasyUse_ContextMenu_NodesSort": {
"name": "컨텍스트 메뉴에서 새로운 노드를 A~Z 순으로 정렬합니다"
},
"EasyUse_ContextMenu_QuickOptions": {
"name": "컨텍스트 메뉴에 3개의 빠른 옵션 버튼을 사용합니다",
"options": {
"At the forefront": "앞쪽에",
"At the end": "뒤쪽에",
"Disable": "비활성화"
}
},
"EasyUse_Nodes_Runtime": {
"name": "노드 실행 시간 표시를 활성화합니다"
},
"EasyUse_Nodes_ChainGetSet": {
"name": "부모 노드와 연결된 get/ set 포인트 체이닝을 활성화합니다"
},
"EasyUse_NodesMap_Sorting": {
"name": "노드 그룹 정렬 모드를 관리합니다",
"tooltip": "기본값은 자동 정렬입니다. 수동으로 설정하면 그룹을 드래그 앤 드롭할 수 있으며 순서가 저장됩니다.",
"options": {
"Auto sorting": "자동 정렬",
"Manual drag&drop sorting": "수동 드래그 앤 드롭 정렬"
}
},
"EasyUse_NodesMap_DisplayNodeID": {
"name": "노드 ID 표시를 활성화합니다"
},
"EasyUse_NodesMap_DisplayGroupOnly": {
"name": "그룹만 표시합니다"
},
"EasyUse_NodesMap_Enable": {
"name": "그룹 맵을 활성화합니다",
"tooltip": "업데이트를 위해 페이지를 새로고침해야 합니다"
}
}

View File

@@ -0,0 +1,30 @@
{
"settingsCategories": {
"Hotkeys": "Горячие клавиши",
"Nodes": "Узлы",
"NodesMap": "Карта узлов"
},
"nodeCategories": {
"Util": "Утилиты",
"Seed": "Сид",
"Prompt": "Подсказка",
"Loaders": "Загрузчики",
"Adapter": "Адаптер",
"Inpaint": "Ретушь",
"PreSampling": "Предвыборка",
"Sampler": "Сэмплер",
"Fix": "Исправление",
"Pipe": "Конвейер",
"XY Inputs": "Ввод XY",
"Image": "Изображение",
"Segmentation": "Сегментация",
"\uD83D\uDEAB Deprecated": "\uD83D\uDEAB Устарело",
"Type": "Тип",
"Math": "Математика",
"Switch": "Переключатель",
"Index Switch": "Переключатель индексов",
"While Loop": "Цикл while",
"For Loop": "Цикл for",
"LoadImage": "Загрузка изображения"
}
}

View File

@@ -0,0 +1,67 @@
{
"EasyUse_Hotkeys_AddGroup": {
"name": "Включить Shift+g для добавления выделенных узлов в группу",
"tooltip": "Начиная с версии v1.2.39, можно использовать Ctrl+g"
},
"EasyUse_Hotkeys_cleanVRAMUsed": {
"name": "Включить Shift+r для выгрузки модели и кэша узлов"
},
"EasyUse_Hotkeys_toggleNodesMap": {
"name": "Включить Shift+m для переключения карты узлов"
},
"EasyUse_Hotkeys_AlignSelectedNodes": {
"name": "Включить Shift+Стрелки для выравнивания выделенных узлов и Shift+Ctrl+Alt+Стрелки для распределения узлов по горизонтали/вертикали",
"tooltip": "Shift+Стрелки выравнивают выделенные узлы, Shift+Ctrl+Alt+Стрелки распределяют узлы по горизонтали/вертикали"
},
"EasyUse_Hotkeys_NormalizeSelectedNodes": {
"name": "Включить Shift+Ctrl+Стрелки для нормализации выделенных узлов",
"tooltip": "Включить Shift+Ctrl+Лево для нормализации ширины и Shift+Ctrl+Право для нормализации высоты"
},
"EasyUse_Hotkeys_NodesTemplate": {
"name": "Включить Alt+1~9 для вставки шаблонов узлов в рабочий процесс"
},
"EasyUse_Hotkeys_JumpNearestNodes": {
"name": "Включить Стрелки для перехода к ближайшему узлу"
},
"EasyUse_ContextMenu_SubDirectories": {
"name": "Включить автоматическое вложение подкаталогов в контекстном меню"
},
"EasyUse_ContextMenu_ModelsThumbnails": {
"name": "Включить превью миниатюр моделей"
},
"EasyUse_ContextMenu_NodesSort": {
"name": "Включить A~Z сортировку новых узлов в контекстном меню"
},
"EasyUse_ContextMenu_QuickOptions": {
"name": "Использовать три быстрых кнопки в контекстном меню",
"options": {
"At the forefront": "В начале",
"At the end": "В конце",
"Disable": "Отключено"
}
},
"EasyUse_Nodes_Runtime": {
"name": "Включить отображение времени выполнения узлов"
},
"EasyUse_Nodes_ChainGetSet": {
"name": "Включить связывание точек получения и установки с родительским узлом"
},
"EasyUse_NodesMap_Sorting": {
"name": "Управление режимом сортировки групп узлов",
"tooltip": "По умолчанию автоматическая сортировка. При ручном режиме группы можно перемещать методом перетаскивания, и порядок будет сохранён.",
"options": {
"Auto sorting": "Автоматическая сортировка",
"Manual drag&drop sorting": "Ручная сортировка перетаскиванием"
}
},
"EasyUse_NodesMap_DisplayNodeID": {
"name": "Включить отображение ID узлов"
},
"EasyUse_NodesMap_DisplayGroupOnly": {
"name": "Показывать только группы"
},
"EasyUse_NodesMap_Enable": {
"name": "Включить карту групп",
"tooltip": "Необходимо обновить страницу для успешного обновления"
}
}

View File

@@ -0,0 +1,33 @@
{
"settingsCategories": {
"Hotkeys": "快捷键",
"Nodes": "节点相关",
"NodesMap": "管理节点组",
"StylesSelector": "样式选择器",
"MultiAngle": "摄影机多角度提示词"
},
"nodeCategories": {
"Util": "工具",
"Seed": "随机种",
"Prompt": "提示词",
"Loaders": "模型加载器",
"Adapter": "模型适配器",
"Inpaint": "内补重绘",
"PreSampling": "预采样参数",
"Sampler": "采样器",
"Fix": "修复相关",
"Pipe": "节点束",
"XY Inputs": "XY图表输入项",
"Image": "图像",
"Segmentation": "分割",
"Logic": "逻辑",
"\uD83D\uDEAB Deprecated": "\uD83D\uDEAB 已弃用",
"Type": "类型",
"Math": "数学计算",
"Switch": "开关",
"Index Switch": "索引开关",
"While Loop": "While循环",
"For Loop": "For循环",
"LoadImage": "加载图像"
}
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,86 @@
{
"EasyUse_Hotkeys_AddGroup": {
"name": "启用 Shift+g 键将选中的节点添加一个组",
"tooltip": "从v1.2.39开始可以使用Ctrl+g代替"
},
"EasyUse_Hotkeys_cleanVRAMUsed": {
"name": "启用 Shift+r 键卸载模型和节点缓存"
},
"EasyUse_Hotkeys_toggleNodesMap": {
"name": "启用 Shift+m 键显隐管理节点组"
},
"EasyUse_Hotkeys_AlignSelectedNodes": {
"name": "启用 Shift+上/下/左/右 和 Shift+Ctrl+Alt+左/右 键对齐选中的节点",
"tooltip": "Shift+上/下/左/右 可以对齐选中的节点, Shift+Ctrl+Alt+左/右 可以水平/垂直分布节点"
},
"EasyUse_Hotkeys_NormalizeSelectedNodes": {
"name": "启用 Shift+Ctrl+左/右 键规范化选中的节点",
"tooltip": "启用 Shift+Ctrl+左 键规范化宽度和 Shift+Ctrl+右 键规范化高度"
},
"EasyUse_Hotkeys_NodesTemplate": {
"name": "启用 Alt+1~9 从节点模板粘贴到工作流中"
},
"EasyUse_Hotkeys_JumpNearestNodes": {
"name": "启用 上/下/左/右 键跳转到最近的前后节点"
},
"EasyUse_ContextMenu_SubDirectories": {
"name": "启用上下文菜单自动嵌套子目录"
},
"EasyUse_ContextMenu_ModelsThumbnails": {
"name": "启动模型预览图显示"
},
"EasyUse_ContextMenu_NodesSort": {
"name": "启用右键菜单中新建节点A~Z排序"
},
"EasyUse_ContextMenu_QuickOptions": {
"name": "在右键菜单中使用三个快捷按钮",
"options": {
"At the forefront": "在最前面",
"At the end": "在最后面",
"Disable": "禁用"
}
},
"EasyUse_Nodes_Runtime": {
"name": "启动节点运行时间显示"
},
"EasyUse_Nodes_ChainGetSet": {
"name": "启用将获取点和设置点与父节点链在一起"
},
"EasyUse_NodesMap_Sorting": {
"name": "管理节点组排序模式",
"tooltip": "默认自动排序,如果设置为手动,组可以拖放并保存排序结果。",
"options": {
"Auto sorting": "自动排序",
"Manual drag&drop sorting": "手动拖拽排序"
}
},
"EasyUse_NodesMap_DisplayNodeID": {
"name": "启用节点ID显示"
},
"EasyUse_NodesMap_DisplayGroupOnly": {
"name": "仅显示组"
},
"EasyUse_NodesMap_Enable": {
"name": "启用管理节点组",
"tooltip": "您需要刷新页面以成功更新"
},
"EasyUse_StylesSelector_DisplayType": {
"name": "样式选择器显示类型",
"tooltip": "样式选择器显示类型,如果设置为“网格”,则显示为网格,如果设置为“列表”,则显示为列表",
"options": {
"Grid": "网格",
"List": "列表"
}
},
"EasyUse_MultiAngle_InvertRotate": {
"name": "启用反转旋转模式",
"tooltip": "在多角度节点中启用反转旋转模式使旋转方向与大多数3D软件一致"
},
"EasyUse_MultiAngle_HollowMode": {
"name": "启用多角度镂空展示模式",
"tooltip": "在多角度节点中启用镂空展示模式,可以更直观地查看相机角度"
},
"EasyUse_MultiAngle_AddAnglePrompt": {
"name": "启用添加多角度提示词"
}
}

View File

@@ -0,0 +1,34 @@
import folder_paths
import os
def add_folder_path_and_extensions(folder_name, full_folder_paths, extensions):
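    # Register each folder with ComfyUI's folder_paths and, if the name is already known,
    # merge the new extensions into the existing entry instead of overwriting it.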
    for full_folder_path in full_folder_paths:
        folder_paths.add_model_folder_path(folder_name, full_folder_path)
    if folder_name in folder_paths.folder_names_and_paths:
        current_paths, current_extensions = folder_paths.folder_names_and_paths[folder_name]
        updated_extensions = current_extensions | extensions
        folder_paths.folder_names_and_paths[folder_name] = (current_paths, updated_extensions)
    else:
        folder_paths.folder_names_and_paths[folder_name] = (full_folder_paths, extensions)
image_suffixs = set([".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".svg", ".ico", ".apng", ".tif", ".hdr", ".exr"])
model_path = folder_paths.models_dir
add_folder_path_and_extensions("ultralytics_bbox", [os.path.join(model_path, "ultralytics", "bbox")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("ultralytics_segm", [os.path.join(model_path, "ultralytics", "segm")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("ultralytics", [os.path.join(model_path, "ultralytics")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("mmdets_bbox", [os.path.join(model_path, "mmdets", "bbox")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("mmdets_segm", [os.path.join(model_path, "mmdets", "segm")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("mmdets", [os.path.join(model_path, "mmdets")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("sams", [os.path.join(model_path, "sams")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("onnx", [os.path.join(model_path, "onnx")], {'.onnx'})
add_folder_path_and_extensions("instantid", [os.path.join(model_path, "instantid")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("pulid", [os.path.join(model_path, "pulid")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("layer_model", [os.path.join(model_path, "layer_model")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("rembg", [os.path.join(model_path, "rembg")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("ipadapter", [os.path.join(model_path, "ipadapter")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("dynamicrafter_models", [os.path.join(model_path, "dynamicrafter_models")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("mediapipe", [os.path.join(model_path, "mediapipe")], set(['.tflite','.pth']))
add_folder_path_and_extensions("inpaint", [os.path.join(model_path, "inpaint")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("prompt_generator", [os.path.join(model_path, "prompt_generator")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("t5", [os.path.join(model_path, "t5")], folder_paths.supported_pt_extensions)
add_folder_path_and_extensions("llm", [os.path.join(model_path, "LLM")], folder_paths.supported_pt_extensions)

View File

@@ -0,0 +1,6 @@
from .libs.loader import easyLoader
from .libs.sampler import easySampler
sampler = easySampler()
easyCache = easyLoader()
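# Shared module-level instances used across the easy nodes: easyCache (an easyLoader) handles
# model loading/caching, and sampler (an easySampler) provides the shared sampling helpers.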

View File

@@ -0,0 +1,407 @@
import os
import folder_paths
from pathlib import Path
BASE_RESOLUTIONS = [
("width", "height"),
(512, 512),
(512, 768),
(576, 1024),
(768, 512),
(768, 768),
(768, 1024),
(768, 1280),
(768, 1344),
(768, 1536),
(816, 1920),
(832, 1152),
(832, 1216),
(896, 1152),
(896, 1088),
(1024, 1024),
(1024, 576),
(1024, 768),
(1080, 1920),
(1440, 2560),
(1088, 896),
(1216, 832),
(1152, 832),
(1152, 896),
(1280, 768),
(1344, 768),
(1536, 640),
(1536, 768),
(1920, 816),
(1920, 1080),
(2560, 1440),
]
MAX_SEED_NUM = 1125899906842624
RESOURCES_DIR = os.path.join(Path(__file__).parent.parent, "resources")
# inpaint
INPAINT_DIR = os.path.join(folder_paths.models_dir, "inpaint")
FOOOCUS_STYLES_DIR = os.path.join(Path(__file__).parent.parent, "styles")
FOOOCUS_STYLES_SAMPLES = 'https://raw.githubusercontent.com/lllyasviel/Fooocus/main/sdxl_styles/samples/'
FOOOCUS_INPAINT_HEAD = {
"fooocus_inpaint_head": {
"model_url": "https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth"
}
}
FOOOCUS_INPAINT_PATCH = {
"inpaint_v26 (1.32GB)": {
"model_url": "https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v26.fooocus.patch"
},
"inpaint_v25 (2.58GB)": {
"model_url": "https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v25.fooocus.patch"
},
"inpaint (1.32GB)": {
"model_url": "https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch"
},
}
BRUSHNET_MODELS = {
"random_mask": {
"sd1": {
"model_url": "https://huggingface.co/Kijai/BrushNet-fp16/resolve/main/brushnet_random_mask_fp16.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/yolain/brushnet/resolve/main/brushnet_random_mask_sdxl.safetensors"
}
},
"segmentation_mask": {
"sd1": {
"model_url": "https://huggingface.co/Kijai/BrushNet-fp16/resolve/main/brushnet_segmentation_mask_fp16.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/yolain/brushnet/resolve/main/brushnet_segmentation_mask_sdxl.safetensors"
}
}
}
POWERPAINT_MODELS = {
"base_fp16": {
"model_url": "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/text_encoder/model.fp16.safetensors"
},
"v2.1": {
"model_url": "https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1/resolve/main/PowerPaint_Brushnet/diffusion_pytorch_model.safetensors",
"clip_url": "https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1/resolve/main/PowerPaint_Brushnet/pytorch_model.bin",
}
}
# layerDiffuse
LAYER_DIFFUSION_DIR = os.path.join(folder_paths.models_dir, "layer_model")
LAYER_DIFFUSION_VAE = {
"encode": {
"sdxl": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/vae_transparent_encoder.safetensors"
}
},
"decode": {
"sd1": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_vae_transparent_decoder.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/vae_transparent_decoder.safetensors"
}
}
}
LAYER_DIFFUSION = {
"Attention Injection": {
"sd1": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_transparent_attn.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_transparent_attn.safetensors"
},
},
"Conv Injection": {
"sdxl": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_transparent_conv.safetensors"
},
"sd1": {
"model_url": None
}
},
"Everything": {
"sd1": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_joint.safetensors"
},
"sdxl": {
"model_url": None
}
},
"Foreground": {
"sd1": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_fg2bg.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fg2ble.safetensors"
}
},
"Foreground to Background": {
"sd1": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_fg2bg.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fgble2bg.safetensors"
}
},
"Background": {
"sd1": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_bg2fg.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bg2ble.safetensors"
}
},
"Background to Foreground": {
"sd1": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_bg2fg.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bgble2fg.safetensors"
}
},
}
# IC Light
IC_LIGHT_MODELS = {
"Foreground": {
"sd1": {
"model_url": "https://huggingface.co/huchenlei/IC-Light-ldm/resolve/main/iclight_sd15_fc_unet_ldm.safetensors"
},
"sdxl": {
"model_url": None
}
},
"Foreground&Background": {
"sd1": {
"model_url": "https://huggingface.co/huchenlei/IC-Light-ldm/resolve/main/iclight_sd15_fbc_unet_ldm.safetensors"
},
"sdxl": {
"model_url": None
}
}
}
# REMBG
REMBG_DIR = os.path.join(folder_paths.models_dir, "rembg")
REMBG_MODELS = {
"RMBG-1.4": {
"model_url": "https://huggingface.co/briaai/RMBG-1.4/resolve/main/model.pth"
},
"RMBG-2.0": {
"model_url": "briaai/RMBG-2.0"
},
"BEN2": {
"model_url": "https://huggingface.co/PramaLLC/BEN2/resolve/main/BEN2_Base.pth"
}
}
#ipadapter
IPADAPTER_DIR = os.path.join(folder_paths.models_dir, "ipadapter")
IPADAPTER_MODELS = {
"LIGHT - SD1.5 only (low strength)": {
"sd1": {
"model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15_light_v11.bin"
},
"sdxl": {
"model_url": ""
}
},
"STANDARD (medium strength)": {
"sd1": {
"model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl_vit-h.safetensors"
}
},
"VIT-G (medium strength)": {
"sd1": {
"model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15_vit-G.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl.safetensors"
}
},
"PLUS (high strength)": {
"sd1": {
"model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus_sd15.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors"
}
},
"PLUS (kolors genernal)": {
"sd1": {
"model_url": ""
},
"sdxl": {
"model_url":"https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-Plus/resolve/main/ip_adapter_plus_general.bin"
}
},
"REGULAR - FLUX and SD3.5 only (high strength)": {
"flux": {
"model_url": "https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter/resolve/main/ip-adapter.bin",
"model_file_name": "ip-adapter_flux_1_dev.bin",
},
"sd3": {
"model_url": "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin",
"model_file_name": "ip-adapter_sd35.bin",
},
},
"PLUS FACE (portraits)": {
"sd1": {
"model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus-face_sdxl_vit-h.safetensors"
}
},
"FULL FACE - SD1.5 only (portraits stronger)": {
"sd1": {
"model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-full-face_sd15.safetensors"
},
"sdxl": {
"model_url": ""
}
},
"FACEID": {
"sd1": {
"model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid_sd15.bin",
"lora_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid_sd15_lora.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid_sdxl.bin",
"lora_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid_sdxl_lora.safetensors"
}
},
"FACEID PLUS - SD1.5 only": {
"sd1": {
"model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plus_sd15.bin",
"lora_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plus_sd15_lora.safetensors"
},
"sdxl": {
"model_url": "",
"lora_url": ""
}
},
"FACEID PLUS V2": {
"sd1": {
"model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15.bin",
"lora_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sdxl.bin",
"lora_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sdxl_lora.safetensors"
}
},
"FACEID PLUS KOLORS":{
"sd1":{
},
"sdxl":{
"model_url":"https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus/resolve/main/ipa-faceid-plus.bin"
}
},
"FACEID PORTRAIT (style transfer)": {
"sd1": {
"model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-portrait-v11_sd15.bin",
},
"sdxl": {
"model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-portrait_sdxl.bin",
}
},
"FACEID PORTRAIT UNNORM - SDXL only (strong)": {
"sd1": {
"model_url":""
},
"sdxl": {
"model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-portrait_sdxl_unnorm.bin",
}
},
"COMPOSITION": {
"sd1": {
"model_url": "https://huggingface.co/ostris/ip-composition-adapter/resolve/main/ip_plus_composition_sd15.safetensors"
},
"sdxl": {
"model_url": "https://huggingface.co/ostris/ip-composition-adapter/resolve/main/ip_plus_composition_sdxl.safetensors"
}
}
}
IPADAPTER_CLIPVISION_MODELS = {
"clip-vit-large-patch14-336":{
"model_url": "https://huggingface.co/openai/clip-vit-large-patch14-336/resolve/main/pytorch_model.bin"
},
"clip-vit-h-14-laion2B-s32B-b79K":{
"model_url": "https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_model.safetensors"
},
"sigclip_vision_patch14_384":{
"model_url": "https://huggingface.co/Comfy-Org/sigclip_vision_384/resolve/main/sigclip_vision_patch14_384.safetensors"
}
}
# dynamiCrafter
DYNAMICRAFTER_DIR = os.path.join(folder_paths.models_dir, "dynamicrafter_models")
DYNAMICRAFTER_MODELS = {
"dynamicrafter_unet_512 (2.98GB)": {
"model_url": "https://huggingface.co/ExponentialML/DynamiCrafterUNet/resolve/main/dynamicrafter_unet_512.safetensors",
"vae_url": "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors",
"clip_url": "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/text_encoder/model.safetensors",
"clip_vision_url": "https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_model.safetensors",
},
"dynamicrafter_unet_512_interp (2.98GB)": {
"model_url": "https://huggingface.co/ExponentialML/DynamiCrafterUNet/resolve/main/dynamicrafter_unet_512_interp.safetensors"
},
"dynamicrafter_unet_1024 (2.98GB)": {
"model_url": "https://huggingface.co/ExponentialML/DynamiCrafterUNet/resolve/main/dynamicrafter_unet_1024.safetensors"
},
"dynamicrafter_unet_256 (2.98GB)": {
"model_url": "https://huggingface.co/ExponentialML/DynamiCrafterUNet/resolve/main/dynamicrafter_unet_256.safetensors"
},
}
#humanParsing
HUMANPARSING_MODELS = {
"parsing_lip": {
"model_url": "https://huggingface.co/levihsu/OOTDiffusion/resolve/main/checkpoints/humanparsing/parsing_lip.onnx",
},
"human-parts":{
"model_url":"https://huggingface.co/Metal3d/deeplabv3p-resnet50-human/resolve/main/deeplabv3p-resnet50-human.onnx",
},
"segformer_b3_clothes":{
"model_name": "sayeed99/segformer_b3_clothes",
},
"segformer_b3_fashion":{
"model_name": "sayeed99/segformer-b3-fashion",
},
"face_parsing":{
"model_name": "jonathandinu/face-parsing"
}
}
#mediapipe
MEDIAPIPE_DIR = os.path.join(folder_paths.models_dir, "mediapipe")
MEDIAPIPE_MODELS = {
"selfie_multiclass_256x256": {
"model_url": "https://huggingface.co/yolain/selfie_multiclass_256x256/resolve/main/selfie_multiclass_256x256.tflite"
}
}
#prompt template
PROMPT_TEMPLATE = {
"prefix": ["Detailed photo of", "Amateur photo of", "Flicker 2008 photo of", "Fantastic artwork of",
"Vintage photograph of", "Unreal 5 render of", "Surrealist painting of",
"Professional advertising design of"],
"subject": ["a man", "a woman", "a young man", "a young woman", "a handsome man", "a beautiful woman", "a monster", "a toy", "a product", "a buddha", "a dog", "a cat"],
"action": ["looking at viewer", "looking away", "looking up", "looking down", "looking back", "open mouth", "half-closed mouth", "closed mouth", "open eyes", "half-closed eyes", "closed eyes", "wink", "standing", "sitting", "lying", "walking", "running", "adjusting hair", "waving", "hand on hip", "crossed arms", "smile", "sad", "angry", "sleepy", "tired", "expressionless"],
"clothes": ["underwear", "clothed", "casual", "dress", "swimsuit", "uniform", "bikini", "one-piece swimsuit", "shirt", "blouse", "sweater", "hoodie", "jeans", "pants", "shorts", "skirt", "vest", "coat", "trenchoat", "jacket", "short dress", "long dress", "off-shoulder", "backless", "hairbow", "hair ribbon", "hair tie", "hairband", "cap", "beanie", "bucket hat", "sun hat", "straw hat", "rice hat", "witch hat", "crown", "chain necklace", "tooth necklace", "choker", "pendant", "bracelet", "watch", "ring", "earring", "anklet", "belt", "scarf", "gloves", "mittens", "socks", "stockings", "tights", "leggings", "boots", "sneakers", "heels", "sandals", "flip-flops", "slippers", "loafers", "mules", "oxfords", "brogues", "derbies", "monk shoes", "chelsea boots", "combat boots", "riding boots", "rain boots", "wedge heels", "platform heels", "stilettos", "block heels", "kitten heels", "moccasins", "espadrilles", "pumps", "flats", "ballet flats", "mary janes", "slingbacks", "peep-toe", "mule sandals", "gladiator sandals", "thong sandals", "slide sandals", "espadrille sandals", "wedge sandals", "platform sandals", "ankle boots", "knee-high boots", "over-the-knee boots", "thigh-high boots", "wellington boots", "chukka boots", "desert boots", "chelsea boots", "hiking boots", "work boots", "snow boots", "rain boots", "riding boots", "cowboy boots", "combat boots", "biker boots", "duck boots", "military boots", "western boots", "ankle strap heels", "block heels", "chunky heels", "cone heels", "kitten heels", "platform heels", "pumps", "slingback heels", "stiletto heels", "wedge heels", "mules", "slingbacks", "slides", "thong sandals", "gladiator sandals", "espadrilles", "wedge sandals", "platform sandals", "ankle boots", "knee-high boots", "over-the-knee boots", "thigh-high boots", "wellington boots", "chukka boots", "desert boots", "chelsea boots", "hiking boots", "work boots", "snow boots", "rain boots", "riding boots", "cowboy boots", "combat boots", "biker boots", "duck boots", "military boots", "western boots", "ankle strap heels", "block heels" ],
"environment": ["sunshine from window", "neon night, city", "sunset over sea", "golden time", "sci-fi RGB glowing, cyberpunk", "natural lighting", "warm atmosphere, at home, bedroom", "magic lit", "evil, gothic, in a cave", "light and shadow", "shadow from window", "soft studio lighting", "home atmosphere, cozy bedroom illumination", "neon, Wong Kar-wai, warm", "moonlight through curtains", "stormy sky lighting", "underwater glow, deep sea", "foggy forest at dawn", "golden hour in a meadow", "rainbow reflections, neon", "cozy candlelight", "apocalyptic, smoky atmosphere", "red glow, emergency lights", "mystical glow, enchanted forest", "campfire light", "harsh, industrial lighting", "sunrise in the mountains", "evening glow in the desert", "moonlight in a dark alley", "golden glow at a fairground", "midnight in the forest", "purple and pink hues at twilight", "foggy morning, muted light", "candle-lit room, rustic vibe", "fluorescent office lighting", "lightning flash in storm", "night, cozy warm light from fireplace", "ethereal glow, magical forest", "dusky evening on a beach", "afternoon light filtering through trees", "blue neon light, urban street", "red and blue police lights in rain", "aurora borealis glow, arctic landscape", "sunrise through foggy mountains", "golden hour on a city skyline", "mysterious twilight, heavy mist", "early morning rays, forest clearing", "colorful lantern light at festival", "soft glow through stained glass", "harsh spotlight in dark room", "mellow evening glow on a lake", "crystal reflections in a cave", "vibrant autumn lighting in a forest", "gentle snowfall at dusk", "hazy light of a winter morning", "soft, diffused foggy glow", "underwater luminescence", "rain-soaked reflections in city lights", "golden sunlight streaming through trees", "fireflies lighting up a summer night", "glowing embers from a forge", "dim candlelight in a gothic castle", "midnight sky with bright starlight", "warm sunset in a rural village", "flickering light in a haunted house", "desert sunset with mirage-like glow", "golden beams piercing through storm clouds"],
"background": ["cars and people", "a cozy bed and a lamp", "a forest clearing with mist", "a bustling marketplace", "a quiet beach at dusk", "an old, cobblestone street", "a futuristic cityscape", "a tranquil lake with mountains", "a mysterious cave entrance", "bookshelves and plants in the background", "an ancient temple in ruins", "tall skyscrapers and neon signs", "a starry sky over a desert", "a bustling café", "rolling hills and farmland", "a modern living room with a fireplace", "an abandoned warehouse", "a picturesque mountain range", "a starry night sky", "the interior of a futuristic spaceship", "the cluttered workshop of an inventor", "the glowing embers of a bonfire", "a misty lake surrounded by trees", "an ornate palace hall", "a busy street market", "a vast desert landscape", "a peaceful library corner", "bustling train station", "a mystical, enchanted forest", "an underwater reef with colorful fish", "a quiet rural village", "a sandy beach with palm trees", "a vibrant coral reef, teeming with life", "snow-capped mountains in distance", "a stormy ocean, waves crashing", "a rustic barn in open fields", "a futuristic lab with glowing screens", "a dark, abandoned castle", "the ruins of an ancient civilization", "a bustling urban street in rain", "an elegant grand ballroom", "a sprawling field of wildflowers", "a dense jungle with sunlight filtering through", "a dimly lit, vintage bar", "an ice cave with sparkling crystals", "a serene riverbank at sunset", "a narrow alley with graffiti walls", "a peaceful zen garden with koi pond", "a high-tech control room", "a quiet mountain village at dawn", "a lighthouse on a rocky coast", "a rainy street with flickering lights", "a frozen lake with ice formations", "an abandoned theme park", "a small fishing village on a pier", "rolling sand dunes in a desert", "a dense forest with towering redwoods", "a snowy cabin in the mountains", "a mystical cave with bioluminescent plants", "a castle courtyard under moonlight", "a bustling open-air night market", "an old train station with steam", "a tranquil waterfall surrounded by trees", "a vineyard in the countryside", "a quaint medieval village", "a bustling harbor with boats", "a high-tech futuristic mall", "a lush tropical rainforest"],
"nsfw": ["nude", "breast", "small breast", "middle breast", "large breast", "nipples", "clothes lift", "pussy juice trail", "pussy juice puddle", "small testicles", "medium testicles", "large testicles", "disembodied penis", "cum on body", "cum inside", "cum outside", "fingering", "handjob", "fellatio", "licking penis", "paizuri", "doggystyle", "cowgirl", "reversed cowgirl", "piledriver", "suspended congress", "full nelson",],
}
NEW_SCHEDULERS = ['align_your_steps', 'gits']

View File

@@ -0,0 +1,113 @@
import urllib.parse
from os import PathLike
from aiohttp import web
from aiohttp.web_urldispatcher import AbstractRoute, UrlDispatcher
from server import PromptServer
from pathlib import Path
# Maximum file size allowed to be served (MB)
max_size = 50
def suffix_limiter(self: web.StaticResource, request: web.Request):
suffixes = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".svg", ".ico", ".apng", ".tif", ".hdr", ".exr"}
rel_url = request.match_info["filename"]
try:
filename = Path(rel_url)
if filename.anchor:
raise web.HTTPForbidden()
filepath = self._directory.joinpath(filename).resolve()
if filepath.exists() and filepath.suffix.lower() not in suffixes:
raise web.HTTPForbidden(reason="File type is not allowed")
finally:
pass
def filesize_limiter(self: web.StaticResource, request: web.Request):
rel_url = request.match_info["filename"]
try:
filename = Path(rel_url)
filepath = self._directory.joinpath(filename).resolve()
if filepath.exists() and filepath.stat().st_size > max_size * 1024 * 1024:
raise web.HTTPForbidden(reason="File size is too large")
finally:
pass
class LimitResource(web.StaticResource):
limiters = []
def push_limiter(self, limiter):
self.limiters.append(limiter)
async def _handle(self, request: web.Request) -> web.StreamResponse:
try:
for limiter in self.limiters:
limiter(self, request)
except (ValueError, FileNotFoundError) as error:
raise web.HTTPNotFound() from error
return await super()._handle(request)
def __repr__(self) -> str:
name = "'" + self.name + "'" if self.name is not None else ""
return f'<LimitResource {name} {self._prefix} -> {self._directory!r}>'
class LimitRouter(web.StaticDef):
def __repr__(self) -> str:
info = []
for name, value in sorted(self.kwargs.items()):
info.append(f", {name}={value!r}")
return f'<LimitRouter {self.prefix} -> {self.path}{"".join(info)}>'
def register(self, router: UrlDispatcher) -> list[AbstractRoute]:
# resource = router.add_static(self.prefix, self.path, **self.kwargs)
def add_static(
self: UrlDispatcher,
prefix: str,
path: PathLike,
*,
name=None,
expect_handler=None,
chunk_size: int = 256 * 1024,
show_index: bool = False,
follow_symlinks: bool = False,
append_version: bool = False,
) -> web.AbstractResource:
assert prefix.startswith("/")
if prefix.endswith("/"):
prefix = prefix[:-1]
resource = LimitResource(
prefix,
path,
name=name,
expect_handler=expect_handler,
chunk_size=chunk_size,
show_index=show_index,
follow_symlinks=follow_symlinks,
append_version=append_version,
)
resource.push_limiter(suffix_limiter)
resource.push_limiter(filesize_limiter)
self.register_resource(resource)
return resource
resource = add_static(router, self.prefix, self.path, **self.kwargs)
routes = resource.get_info().get("routes", {})
return list(routes.values())
def path_to_url(path):
if not path:
return path
path = path.replace("\\", "/")
if not path.startswith("/"):
path = "/" + path
while path.startswith("//"):
path = path[1:]
path = path.replace("//", "/")
return path
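# Illustrative examples (added for reference, based on the normalization logic above):
#   path_to_url("web\\static")  -> "/web/static"
#   path_to_url("a//b")         -> "/a/b"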
def add_static_resource(prefix, path,limit=False):
app = PromptServer.instance.app
prefix = path_to_url(prefix)
prefix = urllib.parse.quote(prefix)
prefix = path_to_url(prefix)
if limit:
route = LimitRouter(prefix, path, {"follow_symlinks": True})
else:
route = web.static(prefix, path, follow_symlinks=True)
app.add_routes([route])
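# Illustrative usage (hypothetical prefix and path, not part of the original file):
#   add_static_resource("my_outputs", "/path/to/outputs", limit=True)
#   # -> serves image files under "/my_outputs", applying the suffix whitelist and 50MB size limit above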

View File

@@ -0,0 +1,427 @@
import torch
import numpy as np
import re
import itertools
from comfy import model_management
from comfy.sdxl_clip import SDXLClipModel, SDXLRefinerClipModel, SDXLClipG
try:
from comfy.text_encoders.sd3_clip import SD3ClipModel, T5XXLModel
except ImportError:
from comfy.sd3_clip import SD3ClipModel, T5XXLModel
from nodes import NODE_CLASS_MAPPINGS, ConditioningConcat, ConditioningZeroOut, ConditioningSetTimestepRange, ConditioningCombine
def _grouper(n, iterable):
it = iter(iterable)
while True:
chunk = list(itertools.islice(it, n))
if not chunk:
return
yield chunk
def _norm_mag(w, n):
d = w - 1
return 1 + np.sign(d) * np.sqrt(np.abs(d) ** 2 / n)
# return np.sign(w) * np.sqrt(np.abs(w)**2 / n)
def divide_length(word_ids, weights):
sums = dict(zip(*np.unique(word_ids, return_counts=True)))
sums[0] = 1
weights = [[_norm_mag(w, sums[id]) if id != 0 else 1.0
for w, id in zip(x, y)] for x, y in zip(weights, word_ids)]
return weights
def shift_mean_weight(word_ids, weights):
delta = 1 - np.mean([w for x, y in zip(weights, word_ids) for w, id in zip(x, y) if id != 0])
weights = [[w if id == 0 else w + delta
for w, id in zip(x, y)] for x, y in zip(weights, word_ids)]
return weights
def scale_to_norm(weights, word_ids, w_max):
top = np.max(weights)
w_max = min(top, w_max)
weights = [[w_max if id == 0 else (w / top) * w_max
for w, id in zip(x, y)] for x, y in zip(weights, word_ids)]
return weights
def from_zero(weights, base_emb):
weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
weight_tensor = weight_tensor.reshape(1, -1, 1).expand(base_emb.shape)
return base_emb * weight_tensor
def mask_word_id(tokens, word_ids, target_id, mask_token):
new_tokens = [[mask_token if wid == target_id else t
for t, wid in zip(x, y)] for x, y in zip(tokens, word_ids)]
mask = np.array(word_ids) == target_id
return (new_tokens, mask)
def batched_clip_encode(tokens, length, encode_func, num_chunks):
embs = []
for e in _grouper(32, tokens):
enc, pooled = encode_func(e)
enc = enc.reshape((len(e), length, -1))
embs.append(enc)
embs = torch.cat(embs)
embs = embs.reshape((len(tokens) // num_chunks, length * num_chunks, -1))
return embs
def from_masked(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
pooled_base = base_emb[0, length - 1:length, :]
wids, inds = np.unique(np.array(word_ids).reshape(-1), return_index=True)
weight_dict = dict((id, w)
for id, w in zip(wids, np.array(weights).reshape(-1)[inds])
if w != 1.0)
if len(weight_dict) == 0:
return torch.zeros_like(base_emb), base_emb[0, length - 1:length, :]
weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
weight_tensor = weight_tensor.reshape(1, -1, 1).expand(base_emb.shape)
# m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
# TODO: find most suitable masking token here
m_token = (m_token, 1.0)
ws = []
masked_tokens = []
masks = []
# create prompts
for id, w in weight_dict.items():
masked, m = mask_word_id(tokens, word_ids, id, m_token)
masked_tokens.extend(masked)
m = torch.tensor(m, dtype=base_emb.dtype, device=base_emb.device)
m = m.reshape(1, -1, 1).expand(base_emb.shape)
masks.append(m)
ws.append(w)
# batch process prompts
embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
masks = torch.cat(masks)
embs = (base_emb.expand(embs.shape) - embs)
pooled = embs[0, length - 1:length, :]
embs *= masks
embs = embs.sum(axis=0, keepdim=True)
pooled_start = pooled_base.expand(len(ws), -1)
ws = torch.tensor(ws).reshape(-1, 1).expand(pooled_start.shape)
pooled = (pooled - pooled_start) * (ws - 1)
pooled = pooled.mean(axis=0, keepdim=True)
return ((weight_tensor - 1) * embs), pooled_base + pooled
def mask_inds(tokens, inds, mask_token):
clip_len = len(tokens[0])
inds_set = set(inds)
new_tokens = [[mask_token if i * clip_len + j in inds_set else t
for j, t in enumerate(x)] for i, x in enumerate(tokens)]
return new_tokens
def down_weight(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
w, w_inv = np.unique(weights, return_inverse=True)
if np.sum(w < 1) == 0:
return base_emb, tokens, base_emb[0, length - 1:length, :]
# m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
# using the comma token as a masking token seems to work better than aos tokens for SD 1.x
m_token = (m_token, 1.0)
masked_tokens = []
masked_current = tokens
for i in range(len(w)):
if w[i] >= 1:
continue
masked_current = mask_inds(masked_current, np.where(w_inv == i)[0], m_token)
masked_tokens.extend(masked_current)
embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
embs = torch.cat([base_emb, embs])
w = w[w <= 1.0]
w_mix = np.diff([0] + w.tolist())
w_mix = torch.tensor(w_mix, dtype=embs.dtype, device=embs.device).reshape((-1, 1, 1))
weighted_emb = (w_mix * embs).sum(axis=0, keepdim=True)
return weighted_emb, masked_current, weighted_emb[0, length - 1:length, :]
def scale_emb_to_mag(base_emb, weighted_emb):
norm_base = torch.linalg.norm(base_emb)
norm_weighted = torch.linalg.norm(weighted_emb)
embeddings_final = (norm_base / norm_weighted) * weighted_emb
return embeddings_final
def recover_dist(base_emb, weighted_emb):
fixed_std = (base_emb.std() / weighted_emb.std()) * (weighted_emb - weighted_emb.mean())
embeddings_final = fixed_std + (base_emb.mean() - fixed_std.mean())
return embeddings_final
def A1111_renorm(base_emb, weighted_emb):
embeddings_final = (base_emb.mean() / weighted_emb.mean()) * weighted_emb
return embeddings_final
def advanced_encode_from_tokens(tokenized, token_normalization, weight_interpretation, encode_func, m_token=266,
length=77, w_max=1.0, return_pooled=False, apply_to_pooled=False):
tokens = [[t for t, _, _ in x] for x in tokenized]
weights = [[w for _, w, _ in x] for x in tokenized]
word_ids = [[wid for _, _, wid in x] for x in tokenized]
# weight normalization
# ====================
# distribute down/up weights over word lengths
if token_normalization.startswith("length"):
weights = divide_length(word_ids, weights)
# make mean of word tokens 1
if token_normalization.endswith("mean"):
weights = shift_mean_weight(word_ids, weights)
# weight interpretation
# =====================
pooled = None
if weight_interpretation == "comfy":
weighted_tokens = [[(t, w) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
weighted_emb, pooled_base = encode_func(weighted_tokens)
pooled = pooled_base
else:
unweighted_tokens = [[(t, 1.0) for t, _, _ in x] for x in tokenized]
base_emb, pooled_base = encode_func(unweighted_tokens)
if weight_interpretation == "A1111":
weighted_emb = from_zero(weights, base_emb)
weighted_emb = A1111_renorm(base_emb, weighted_emb)
pooled = pooled_base
if weight_interpretation == "compel":
pos_tokens = [[(t, w) if w >= 1.0 else (t, 1.0) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
weighted_emb, _ = encode_func(pos_tokens)
weighted_emb, _, pooled = down_weight(pos_tokens, weights, word_ids, weighted_emb, length, encode_func)
if weight_interpretation == "comfy++":
weighted_emb, tokens_down, _ = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
weights = [[w if w > 1.0 else 1.0 for w in x] for x in weights]
# unweighted_tokens = [[(t,1.0) for t, _,_ in x] for x in tokens_down]
embs, pooled = from_masked(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
weighted_emb += embs
if weight_interpretation == "down_weight":
weights = scale_to_norm(weights, word_ids, w_max)
weighted_emb, _, pooled = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
if return_pooled:
if apply_to_pooled:
return weighted_emb, pooled
else:
return weighted_emb, pooled_base
return weighted_emb, None
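# Note (added for reference): token_normalization is matched with startswith("length")
# (distribute up/down weights over word length) and endswith("mean") (re-center the mean
# word-token weight at 1.0); weight_interpretation selects one of the branches above:
# "comfy", "A1111", "compel", "comfy++" or "down_weight".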
def encode_token_weights_g(model, token_weight_pairs):
return model.clip_g.encode_token_weights(token_weight_pairs)
def encode_token_weights_l(model, token_weight_pairs):
l_out, pooled = model.clip_l.encode_token_weights(token_weight_pairs)
return l_out, pooled
def encode_token_weights_t5(model, token_weight_pairs):
return model.t5xxl.encode_token_weights(token_weight_pairs)
def encode_token_weights(model, token_weight_pairs, encode_func):
if model.layer_idx is not None:
# ComfyUI revision 2016 [c2cb8e88] and later removed the clip_layer method from the SDXL clip
# if compare_revision(2016):
model.cond_stage_model.set_clip_options({'layer': model.layer_idx})
# else:
# model.cond_stage_model.clip_layer(model.layer_idx)
model_management.load_model_gpu(model.patcher)
return encode_func(model.cond_stage_model, token_weight_pairs)
def prepareXL(embs_l, embs_g, pooled, clip_balance):
l_w = 1 - max(0, clip_balance - .5) * 2
g_w = 1 - max(0, .5 - clip_balance) * 2
if embs_l is not None:
return torch.cat([embs_l * l_w, embs_g * g_w], dim=-1), pooled
else:
return embs_g, pooled
def prepareSD3(out, pooled, clip_balance):
lg_w = 1 - max(0, clip_balance - .5) * 2
t5_w = 1 - max(0, .5 - clip_balance) * 2
if out.shape[0] > 1:
return torch.cat([out[0] * lg_w, out[1] * t5_w], dim=-1), pooled
else:
return out, pooled
def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5,
apply_to_pooled=True, width=1024, height=1024, crop_w=0, crop_h=0, target_width=1024, target_height=1024, a1111_prompt_style=False, steps=1):
# Use smZNodes' CLIP text encode to match A1111 prompt handling (requires ComfyUI-smzNodes to be installed)
if a1111_prompt_style:
if "smZ CLIPTextEncode" in NODE_CLASS_MAPPINGS:
cls = NODE_CLASS_MAPPINGS['smZ CLIPTextEncode']
embeddings_final, = cls().encode(clip, text, weight_interpretation, True, True, False, False, 6, 1024, 1024, 0, 0, 1024, 1024, '', '', steps)
return embeddings_final
else:
raise Exception(f"[smzNodes Not Found] you need to install 'ComfyUI-smzNodes'")
time_start = 0
time_end = 1
match = re.search(r'TIMESTEP.*$', text)
if match:
timestep = match.group()
timestep = timestep.split(' ')
timestep = timestep[0]
text = text.replace(timestep, '')
value = timestep.split(':')
if len(value) >= 3:
time_start = float(value[1])
time_end = float(value[2])
elif len(value) == 2:
time_start = float(value[1])
time_end = 1
elif len(value) == 1:
time_start = 0.1
time_end = 1
pass3 = [x.strip() for x in text.split("BREAK")]
pass3 = [x for x in pass3 if x != '']
if len(pass3) == 0:
pass3 = ['']
# pass3_str = [f'[{x}]' for x in pass3]
# print(f"CLIP: {str.join(' + ', pass3_str)}")
conditioning = None
for text in pass3:
tokenized = clip.tokenize(text, return_word_ids=True)
if SD3ClipModel and isinstance(clip.cond_stage_model, SD3ClipModel):
lg_out = None
pooled = None
out = None
if len(tokenized['l']) > 0 or len(tokenized['g']) > 0:
if clip.cond_stage_model.clip_l is not None:
lg_out, l_pooled = advanced_encode_from_tokens(tokenized['l'],
token_normalization,
weight_interpretation,
lambda x: encode_token_weights(clip, x, encode_token_weights_l),
w_max=w_max, return_pooled=True,)
else:
l_pooled = torch.zeros((1, 768), device=model_management.intermediate_device())
if clip.cond_stage_model.clip_g is not None:
g_out, g_pooled = advanced_encode_from_tokens(tokenized['g'],
token_normalization,
weight_interpretation,
lambda x: encode_token_weights(clip, x, encode_token_weights_g),
w_max=w_max, return_pooled=True)
if lg_out is not None:
lg_out = torch.cat([lg_out, g_out], dim=-1)
else:
lg_out = torch.nn.functional.pad(g_out, (768, 0))
else:
g_out = None
g_pooled = torch.zeros((1, 1280), device=model_management.intermediate_device())
if lg_out is not None:
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
out = lg_out
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
# t5xxl
if 't5xxl' in tokenized:
t5_out, t5_pooled = advanced_encode_from_tokens(tokenized['t5xxl'],
token_normalization,
weight_interpretation,
lambda x: encode_token_weights(clip, x, encode_token_weights_t5),
w_max=w_max, return_pooled=True)
if lg_out is not None:
out = torch.cat([lg_out, t5_out], dim=-2)
else:
out = t5_out
if out is None:
out = torch.zeros((1, 77, 4096), device=model_management.intermediate_device())
if pooled is None:
pooled = torch.zeros((1, 768 + 1280), device=model_management.intermediate_device())
embeddings_final, pooled = prepareSD3(out, pooled, clip_balance)
cond = [[embeddings_final, {"pooled_output": pooled}]]
elif isinstance(clip.cond_stage_model, (SDXLClipModel, SDXLRefinerClipModel, SDXLClipG)):
embs_l = None
embs_g = None
pooled = None
if 'l' in tokenized and isinstance(clip.cond_stage_model, SDXLClipModel):
embs_l, _ = advanced_encode_from_tokens(tokenized['l'],
token_normalization,
weight_interpretation,
lambda x: encode_token_weights(clip, x, encode_token_weights_l),
w_max=w_max,
return_pooled=False)
if 'g' in tokenized:
embs_g, pooled = advanced_encode_from_tokens(tokenized['g'],
token_normalization,
weight_interpretation,
lambda x: encode_token_weights(clip, x,
encode_token_weights_g),
w_max=w_max,
return_pooled=True,
apply_to_pooled=apply_to_pooled)
embeddings_final, pooled = prepareXL(embs_l, embs_g, pooled, clip_balance)
cond = [[embeddings_final, {"pooled_output": pooled}]]
# cond = [[embeddings_final,
# {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w,
# "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]]
else:
embeddings_final, pooled = advanced_encode_from_tokens(tokenized['l'],
token_normalization,
weight_interpretation,
lambda x: encode_token_weights(clip, x, encode_token_weights_l),
w_max=w_max,return_pooled=True,)
cond = [[embeddings_final, {"pooled_output": pooled}]]
if conditioning is not None:
conditioning = ConditioningConcat().concat(conditioning, cond)[0]
else:
conditioning = cond
# setTimeStepRange
if time_start > 0 or time_end < 1:
conditioning_2, = ConditioningSetTimestepRange().set_range(conditioning, 0, time_start)
conditioning_1, = ConditioningZeroOut().zero_out(conditioning)
conditioning_1, = ConditioningSetTimestepRange().set_range(conditioning_1, time_start, time_end)
conditioning, = ConditioningCombine().combine(conditioning_1, conditioning_2)
return conditioning
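# Illustrative prompt syntax handled above (hypothetical prompt text):
#   "a castle on a hill BREAK golden hour lighting TIMESTEP:0.1:0.8"
# "BREAK" splits the prompt into chunks that are encoded separately and concatenated, while the
# trailing "TIMESTEP:start:end" tag is stripped from the text and parsed into the
# (time_start, time_end) range used by ConditioningSetTimestepRange above.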

View File

@@ -0,0 +1,372 @@
import yaml
import pathlib
import base64
import io
import json
import os
import pickle
import zlib
import urllib.parse
import urllib.request
import urllib.error
from enum import Enum
from functools import singledispatch
from typing import Any, List, Union
import numpy as np
import torch
from PIL import Image
root_path = pathlib.Path(__file__).parent.parent.parent.parent
config_path = os.path.join(root_path, 'config.yaml')
class BizyAIRAPI:
def __init__(self):
self.base_url = 'https://bizyair-api.siliconflow.cn/x/v1'
self.api_key = None
def getAPIKey(self):
if self.api_key is None:
if os.path.isfile(config_path):
with open(config_path, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
if 'BIZYAIR_API_KEY' not in data:
raise Exception("Please add BIZYAIR_API_KEY to config.yaml")
self.api_key = data['BIZYAIR_API_KEY']
else:
raise Exception("Please add config.yaml to root path")
return self.api_key
def send_post_request(self, url, payload, headers):
try:
data = json.dumps(payload).encode("utf-8")
req = urllib.request.Request(url, data=data, headers=headers, method="POST")
with urllib.request.urlopen(req) as response:
response_data = response.read().decode("utf-8")
return response_data
except urllib.error.URLError as e:
if "Unauthorized" in str(e):
raise Exception(
"Key is invalid, please refer to https://cloud.siliconflow.cn to get the API key.\n"
"If you have the key, please click the 'BizyAir Key' button at the bottom right to set the key."
)
else:
raise Exception(
f"Failed to connect to the server: {e}. If you have no key, "
"please refer to https://cloud.siliconflow.cn to get the API key."
)
# joycaption
def joyCaption(self, payload, image, apikey_override=None, API_URL='/supernode/joycaption2'):
if apikey_override is not None:
api_key = apikey_override
else:
api_key = self.getAPIKey()
url = f"{self.base_url}{API_URL}"
print('Sending request to:', url)
auth = f"Bearer {api_key}"
headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": auth,
}
input_image = encode_data(image, disable_image_marker=True)
payload["image"] = input_image
ret: str = self.send_post_request(url=url, payload=payload, headers=headers)
ret = json.loads(ret)
try:
if "result" in ret:
ret = json.loads(ret["result"])
except Exception as e:
raise Exception(f"Unexpected response: {ret} {e=}")
if ret["type"] == "error":
raise Exception(ret["message"])
msg = ret["data"]
if msg["type"] not in ("comfyair", "bizyair",):
raise Exception(f"Unexpected response type: {msg}")
caption = msg["data"]
return caption
bizyairAPI = BizyAIRAPI()
BIZYAIR_DEBUG = True
# Marker to identify base64-encoded tensors
TENSOR_MARKER = "TENSOR:"
IMAGE_MARKER = "IMAGE:"
class TaskStatus(Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
def convert_image_to_rgb(image: Image.Image) -> Image.Image:
if image.mode != "RGB":
return image.convert("RGB")
return image
def encode_image_to_base64(
image: Image.Image, format: str = "png", quality: int = 100, lossless=False
) -> str:
image = convert_image_to_rgb(image)
with io.BytesIO() as output:
image.save(output, format=format, quality=quality, lossless=lossless)
output.seek(0)
img_bytes = output.getvalue()
if BIZYAIR_DEBUG:
print(f"encode_image_to_base64: {format_bytes(len(img_bytes))}")
return base64.b64encode(img_bytes).decode("utf-8")
def decode_base64_to_np(img_data: str, format: str = "png") -> np.ndarray:
img_bytes = base64.b64decode(img_data)
if BIZYAIR_DEBUG:
print(f"decode_base64_to_np: {format_bytes(len(img_bytes))}")
with io.BytesIO(img_bytes) as input_buffer:
img = Image.open(input_buffer)
# https://github.com/comfyanonymous/ComfyUI/blob/a178e25912b01abf436eba1cfaab316ba02d272d/nodes.py#L1511
img = img.convert("RGB")
return np.array(img)
def decode_base64_to_image(img_data: str) -> Image.Image:
img_bytes = base64.b64decode(img_data)
with io.BytesIO(img_bytes) as input_buffer:
img = Image.open(input_buffer)
if BIZYAIR_DEBUG:
format_info = img.format.upper() if img.format else "Unknown"
print(f"decode image format: {format_info}")
return img
def format_bytes(num_bytes: int) -> str:
"""
Converts a number of bytes to a human-readable string with units (B, KB, or MB).
:param num_bytes: The number of bytes to convert.
:return: A string representing the number of bytes in a human-readable format.
"""
if num_bytes < 1024:
return f"{num_bytes} B"
elif num_bytes < 1024 * 1024:
return f"{num_bytes / 1024:.2f} KB"
else:
return f"{num_bytes / (1024 * 1024):.2f} MB"
def _legacy_encode_comfy_image(image: torch.Tensor, image_format="png") -> str:
input_image = image.cpu().detach().numpy()
i = 255.0 * input_image[0]
input_image = np.clip(i, 0, 255).astype(np.uint8)
base64ed_image = encode_image_to_base64(
Image.fromarray(input_image), format=image_format
)
return base64ed_image
def _legacy_decode_comfy_image(
img_data: Union[List, str], image_format="png"
) -> torch.tensor:
if isinstance(img_data, List):
decoded_imgs = [decode_comfy_image(x, old_version=True) for x in img_data]
combined_imgs = torch.cat(decoded_imgs, dim=0)
return combined_imgs
out = decode_base64_to_np(img_data, format=image_format)
out = np.array(out).astype(np.float32) / 255.0
output = torch.from_numpy(out)[None,]
return output
def _new_encode_comfy_image(images: torch.Tensor, image_format="WEBP", **kwargs) -> str:
"""https://docs.comfy.org/essentials/custom_node_snippets#save-an-image-batch
Encode a batch of images to base64 strings.
Args:
images (torch.Tensor): A batch of images.
image_format (str, optional): The format of the images. Defaults to "WEBP".
Returns:
str: A JSON string containing the base64-encoded images.
"""
results = {}
for batch_number, image in enumerate(images):
i = 255.0 * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
base64ed_image = encode_image_to_base64(img, format=image_format, **kwargs)
results[batch_number] = base64ed_image
return json.dumps(results)
def _new_decode_comfy_image(img_datas: str, image_format="WEBP") -> torch.tensor:
"""
Decode a batch of base64-encoded images.
Args:
img_datas (str): A JSON string containing the base64-encoded images.
image_format (str, optional): The format of the images. Defaults to "WEBP".
Returns:
torch.Tensor: A tensor containing the decoded images.
"""
img_datas = json.loads(img_datas)
decoded_imgs = []
for img_data in img_datas.values():
decoded_image = decode_base64_to_np(img_data, format=image_format)
decoded_image = np.array(decoded_image).astype(np.float32) / 255.0
decoded_imgs.append(torch.from_numpy(decoded_image)[None,])
return torch.cat(decoded_imgs, dim=0)
def encode_comfy_image(
image: torch.Tensor, image_format="WEBP", old_version=False, lossless=False
) -> str:
if old_version:
return _legacy_encode_comfy_image(image, image_format)
return _new_encode_comfy_image(image, image_format, lossless=lossless)
def decode_comfy_image(
img_data: Union[List, str], image_format="WEBP", old_version=False
) -> torch.tensor:
if old_version:
return _legacy_decode_comfy_image(img_data, image_format)
return _new_decode_comfy_image(img_data, image_format)
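# Minimal round-trip sketch (illustrative; assumes a ComfyUI IMAGE batch of shape [B, H, W, 3]):
#   batch = torch.rand(2, 64, 64, 3)
#   payload = encode_comfy_image(batch, image_format="WEBP")  # JSON dict of base64 strings, one per batch index
#   restored = decode_comfy_image(payload)                    # tensor of shape [2, 64, 64, 3]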
def tensor_to_base64(tensor: torch.Tensor, compress=True) -> str:
tensor_np = tensor.cpu().detach().numpy()
tensor_bytes = pickle.dumps(tensor_np)
if compress:
tensor_bytes = zlib.compress(tensor_bytes)
tensor_b64 = base64.b64encode(tensor_bytes).decode("utf-8")
return tensor_b64
def base64_to_tensor(tensor_b64: str, compress=True) -> torch.Tensor:
tensor_bytes = base64.b64decode(tensor_b64)
if compress:
tensor_bytes = zlib.decompress(tensor_bytes)
tensor_np = pickle.loads(tensor_bytes)
tensor = torch.from_numpy(tensor_np)
return tensor
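# Illustrative round trip (arbitrary tensor, not part of the original file):
#   b64 = tensor_to_base64(torch.arange(6).reshape(2, 3))
#   restored = base64_to_tensor(b64)  # equal to the original tensor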
@singledispatch
def decode_data(input, old_version=False):
raise NotImplementedError(f"Unsupported type: {type(input)}")
@decode_data.register(int)
@decode_data.register(float)
@decode_data.register(bool)
@decode_data.register(type(None))
def _(input, **kwargs):
return input
@decode_data.register(dict)
def _(input, **kwargs):
return {k: decode_data(v, **kwargs) for k, v in input.items()}
@decode_data.register(list)
def _(input, **kwargs):
return [decode_data(x, **kwargs) for x in input]
@decode_data.register(str)
def _(input: str, **kwargs):
if input.startswith(TENSOR_MARKER):
tensor_b64 = input[len(TENSOR_MARKER) :]
return base64_to_tensor(tensor_b64)
elif input.startswith(IMAGE_MARKER):
tensor_b64 = input[len(IMAGE_MARKER) :]
old_version = kwargs.get("old_version", False)
return decode_comfy_image(tensor_b64, old_version=old_version)
return input
@singledispatch
def encode_data(output, disable_image_marker=False, old_version=False):
raise NotImplementedError(f"Unsupported type: {type(output)}")
@encode_data.register(dict)
def _(output, **kwargs):
return {k: encode_data(v, **kwargs) for k, v in output.items()}
@encode_data.register(list)
def _(output, **kwargs):
return [encode_data(x, **kwargs) for x in output]
def is_image_tensor(tensor) -> bool:
"""https://docs.comfy.org/essentials/custom_node_datatypes#image
Check if the given tensor is in the format of an IMAGE (shape [B, H, W, C] where C=3).
`Args`:
tensor (torch.Tensor): The tensor to check.
`Returns`:
bool: True if the tensor is in the IMAGE format, False otherwise.
"""
try:
if not isinstance(tensor, torch.Tensor):
return False
if len(tensor.shape) != 4:
return False
B, H, W, C = tensor.shape
if C != 3:
return False
return True
except:
return False
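# Illustrative checks (added for reference):
#   is_image_tensor(torch.zeros(1, 512, 512, 3))  -> True
#   is_image_tensor(torch.zeros(1, 3, 512, 512))  -> False (channels-first, not the IMAGE layout)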
@encode_data.register(torch.Tensor)
def _(output, **kwargs):
if is_image_tensor(output) and not kwargs.get("disable_image_marker", False):
old_version = kwargs.get("old_version", False)
lossless = kwargs.get("lossless", True)
return IMAGE_MARKER + encode_comfy_image(
output, image_format="WEBP", old_version=old_version, lossless=lossless
)
return TENSOR_MARKER + tensor_to_base64(output)
@encode_data.register(int)
@encode_data.register(float)
@encode_data.register(bool)
@encode_data.register(type(None))
def _(output, **kwargs):
return output
@encode_data.register(str)
def _(output, **kwargs):
return output
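# Round-trip sketch (illustrative): encode_data wraps IMAGE tensors with IMAGE_MARKER and other
# tensors with TENSOR_MARKER, so decode_data can restore the original structure:
#   payload = encode_data({"image": torch.rand(1, 64, 64, 3), "strength": 0.5})
#   restored = decode_data(payload)  # dict with the image decoded back to a tensor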

View File

@@ -0,0 +1,51 @@
import json
import os
import yaml
import requests
import pathlib
from aiohttp import web
root_path = pathlib.Path(__file__).parent.parent.parent.parent
config_path = os.path.join(root_path,'config.yaml')
class FluxAIAPI:
def __init__(self):
self.api_url = "https://fluxaiimagegenerator.com/api"
self.origin = "https://fluxaiimagegenerator.com"
self.user_agent = None
self.cookie = None
def promptGenerate(self, text, cookies=None):
cookie = self.cookie if cookies is None else cookies
if cookie is None:
if os.path.isfile(config_path):
with open(config_path, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
if 'FLUXAI_COOKIE' not in data:
raise Exception("Please add FLUXAI_COOKIE to config.yaml")
if "FLUXAI_USER_AGENT" in data:
self.user_agent = data["FLUXAI_USER_AGENT"]
self.cookie = cookie = data['FLUXAI_COOKIE']
headers = {
"Cookie": cookie,
"Referer": "https://fluxaiimagegenerator.com/flux-prompt-generator",
"Origin": self.origin,
"Content-Type": "application/json",
}
if self.user_agent is not None:
headers['User-Agent'] = self.user_agent
url = self.api_url + '/prompt'
payload = {
"prompt": text
}
response = requests.post(url, json=payload, headers=headers)
res = response.json()
if "error" in res:
return res['error']
elif "data" in res and "prompt" in res['data']:
return res['data']['prompt']
fluxaiAPI = FluxAIAPI()

View File

@@ -0,0 +1,200 @@
import json
import os
import yaml
import requests
import pathlib
from aiohttp import web
from server import PromptServer
from ..image import tensor2pil, pil2tensor, image2base64, pil2byte
from ..log import log_node_error
root_path = pathlib.Path(__file__).parent.parent.parent.parent
config_path = os.path.join(root_path,'config.yaml')
default_key = [{'name':'Default', 'key':''}]
class StabilityAPI:
def __init__(self):
self.api_url = "https://api.stability.ai"
self.api_keys = None
self.api_current = 0
self.user_info = {}
def getErrors(self, code):
errors = {
400: "Bad Request",
403: "ApiKey Forbidden",
413: "Your request was larger than 10MiB.",
429: "You have made more than 150 requests in 10 seconds.",
500: "Internal Server Error",
}
return errors.get(code, "Unknown Error")
def getAPIKeys(self):
if os.path.isfile(config_path):
with open(config_path, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
if not data:
data = {'STABILITY_API_KEY': default_key, 'STABILITY_API_DEFAULT':0}
with open(config_path, 'w') as f:
yaml.dump(data, f)
if 'STABILITY_API_KEY' not in data:
data['STABILITY_API_KEY'] = default_key
data['STABILITY_API_DEFAULT'] = 0
with open(config_path, 'w') as f:
yaml.dump(data, f)
api_keys = data['STABILITY_API_KEY']
self.api_current = data['STABILITY_API_DEFAULT']
self.api_keys = api_keys
return api_keys
else:
# create a yaml file
with open(config_path, 'w') as f:
data = {'STABILITY_API_KEY': default_key, 'STABILITY_API_DEFAULT':0}
yaml.dump(data, f)
return data['STABILITY_API_KEY']
pass
def setAPIKeys(self, api_keys):
if len(api_keys) > 0:
self.api_keys = api_keys
# load and save the yaml file
with open(config_path, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data['STABILITY_API_KEY'] = api_keys
with open(config_path, 'w') as f:
yaml.dump(data, f)
return True
def setAPIDefault(self, current):
if current is not None:
self.api_current = current
# load and save the yaml file
with open(config_path, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data['STABILITY_API_DEFAULT'] = current
with open(config_path, 'w') as f:
yaml.dump(data, f)
return True
def generate_sd3_image(self, prompt, negative_prompt, aspect_ratio, model, seed, mode='text-to-image', image=None, strength=1, output_format='png', node_name='easy stableDiffusion3API'):
url = f"{self.api_url}/v2beta/stable-image/generate/sd3"
api_key = self.api_keys[self.api_current]['key']
files = None
data = {
"prompt": prompt,
"mode": mode,
"model": model,
"seed": seed,
"output_format": output_format,
}
if model == 'sd3':
data['negative_prompt'] = negative_prompt
if mode == 'text-to-image':
files = {"none": ''}
data['aspect_ratio'] = aspect_ratio
elif mode == 'image-to-image':
pil_image = tensor2pil(image)
image_byte = pil2byte(pil_image)
files = {"image": ("output.png", image_byte, 'image/png')}
data['strength'] = strength
response = requests.post(url,
headers={"authorization": f"{api_key}", "accept": "application/json"},
files=files,
data=data,
)
if response.status_code == 200:
PromptServer.instance.send_sync('stable-diffusion-api-generate-succeed',{"model":model})
json_data = response.json()
image_base64 = json_data['image']
image_data = image2base64(image_base64)
output_t = pil2tensor(image_data)
return output_t
else:
if 'application/json' in response.headers['Content-Type']:
error_info = response.json()
log_node_error(node_name, error_info.get('name', 'No name provided'))
log_node_error(node_name, error_info.get('errors', ['No details provided']))
error_status_text = self.getErrors(response.status_code)
PromptServer.instance.send_sync('easyuse-toast',{"type": "error", "content": error_status_text})
raise Exception(f"Failed to generate image: {error_status_text}")
# get user account
async def getUserAccount(self, cache=True):
url = f"{self.api_url}/v1/user/account"
api_key = self.api_keys[self.api_current]['key']
name = self.api_keys[self.api_current]['name']
if cache and name in self.user_info:
return self.user_info[name]
else:
response = requests.get(url, headers={"Authorization": f"Bearer {api_key}"})
if response.status_code == 200:
user_info = response.json()
self.user_info[name] = user_info
return user_info
else:
PromptServer.instance.send_sync('easyuse-toast',{'type': 'error', 'content': self.getErrors(response.status_code)})
return None
# get user balance
async def getUserBalance(self):
url = f"{self.api_url}/v1/user/balance"
api_key = self.api_keys[self.api_current]['key']
response = requests.get(url, headers={
"Authorization": f"Bearer {api_key}"
})
if response.status_code == 200:
return response.json()
else:
PromptServer.instance.send_sync('easyuse-toast', {'type': 'error', 'content': self.getErrors(response.status_code)})
return None
stableAPI = StabilityAPI()
@PromptServer.instance.routes.get("/easyuse/stability/api_keys")
async def get_stability_api_keys(request):
stableAPI.getAPIKeys()
return web.json_response({"keys": stableAPI.api_keys, "current": stableAPI.api_current})
@PromptServer.instance.routes.post("/easyuse/stability/set_api_keys")
async def set_stability_api_keys(request):
post = await request.post()
api_keys = post.get("api_keys")
current = post.get('current')
if api_keys is not None:
api_keys = json.loads(api_keys)
stableAPI.setAPIKeys(api_keys)
if current is not None:
print(current)
stableAPI.setAPIDefault(int(current))
account = await stableAPI.getUserAccount()
balance = await stableAPI.getUserBalance()
return web.json_response({'account': account, 'balance': balance})
else:
return web.json_response({'status': 'ok'})
else:
return web.Response(status=400)
@PromptServer.instance.routes.post("/easyuse/stability/set_apikey_default")
async def set_stability_api_default(request):
post = await request.post()
current = post.get("current")
if current is not None and int(current) < len(stableAPI.api_keys):
stableAPI.api_current = int(current)
return web.json_response({'status': 'ok'})
else:
return web.Response(status=400)
@PromptServer.instance.routes.get("/easyuse/stability/user_info")
async def get_account_info(request):
account = await stableAPI.getUserAccount()
balance = await stableAPI.getUserBalance()
return web.json_response({'account': account, 'balance': balance})
@PromptServer.instance.routes.get("/easyuse/stability/balance")
async def get_balance_info(request):
balance = await stableAPI.getUserBalance()
return web.json_response({'balance': balance})

View File

@@ -0,0 +1,86 @@
import itertools
from typing import Optional
class TaggedCache:
def __init__(self, tag_settings: Optional[dict]=None):
self._tag_settings = tag_settings or {} # tag cache size
self._data = {}
def __getitem__(self, key):
for tag_data in self._data.values():
if key in tag_data:
return tag_data[key]
raise KeyError(f'Key `{key}` does not exist')
def __setitem__(self, key, value: tuple):
# value: (tag: str, (islist: bool, data: *))
# if key already exists, pop old value
for tag_data in self._data.values():
if key in tag_data:
tag_data.pop(key, None)
break
tag = value[0]
if tag not in self._data:
try:
from cachetools import LRUCache
default_size = 20
if 'ckpt' in tag:
default_size = 5
elif tag in ['latent', 'image']:
default_size = 100
self._data[tag] = LRUCache(maxsize=self._tag_settings.get(tag, default_size))
except (ImportError, ModuleNotFoundError):
# TODO: implement a simple lru dict
self._data[tag] = {}
self._data[tag][key] = value
def __delitem__(self, key):
for tag_data in self._data.values():
if key in tag_data:
del tag_data[key]
return
raise KeyError(f'Key `{key}` does not exist')
def __contains__(self, key):
return any(key in tag_data for tag_data in self._data.values())
def items(self):
yield from itertools.chain(*map(lambda x: x.items(), self._data.values()))
def get(self, key, default=None):
"""D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
for tag_data in self._data.values():
if key in tag_data:
return tag_data[key]
return default
def clear(self):
# clear all cache
self._data = {}
cache_settings = {}
cache = TaggedCache(cache_settings)
cache_count = {}
def update_cache(k, tag, v):
cache[k] = (tag, v)
cnt = cache_count.get(k)
if cnt is None:
cnt = 0
cache_count[k] = cnt
else:
cache_count[k] += 1
def remove_cache(key):
global cache
if key == '*':
cache = TaggedCache(cache_settings)
elif key in cache:
del cache[key]
else:
print(f"invalid {key}")

View File

@@ -0,0 +1,153 @@
from threading import Event
import torch
from server import PromptServer
from aiohttp import web
from comfy import model_management as mm
from comfy_execution.graph import ExecutionBlocker
import time
class ChooserCancelled(Exception):
pass
def get_chooser_cache():
"""获取选择器缓存"""
if not hasattr(PromptServer.instance, '_easyuse_chooser_node'):
PromptServer.instance._easyuse_chooser_node = {}
return PromptServer.instance._easyuse_chooser_node
def cleanup_session_data(node_id):
"""清理会话数据"""
node_data = get_chooser_cache()
if node_id in node_data:
session_keys = ["event", "selected", "images", "total_count", "cancelled"]
for key in session_keys:
if key in node_data[node_id]:
del node_data[node_id][key]
def wait_for_chooser(id, images, mode, period=0.1):
try:
node_data = get_chooser_cache()
images = [images[i:i + 1, ...] for i in range(images.shape[0])]
if mode == "Keep Last Selection":
if id in node_data and "last_selection" in node_data[id]:
last_selection = node_data[id]["last_selection"]
if last_selection and len(last_selection) > 0:
valid_indices = [idx for idx in last_selection if 0 <= idx < len(images)]
if valid_indices:
try:
PromptServer.instance.send_sync("easyuse-image-keep-selection", {
"id": id,
"selected": valid_indices
})
except Exception as e:
pass
cleanup_session_data(id)
indices_str = ','.join(str(i) for i in valid_indices)
images = [images[idx] for idx in valid_indices]
images = torch.cat(images, dim=0)
return {"result": (images,)}
if id in node_data:
del node_data[id]
event = Event()
node_data[id] = {
"event": event,
"images": images,
"selected": None,
"total_count": len(images),
"cancelled": False,
}
while id in node_data:
node_info = node_data[id]
if node_info.get("cancelled", False):
cleanup_session_data(id)
raise ChooserCancelled("Manual selection cancelled")
if "selected" in node_info and node_info["selected"] is not None:
break
time.sleep(period)
if id in node_data:
node_info = node_data[id]
selected_indices = node_info.get("selected")
if selected_indices is not None and len(selected_indices) > 0:
valid_indices = [idx for idx in selected_indices if 0 <= idx < len(images)]
if valid_indices:
selected_images = [images[idx] for idx in valid_indices]
if id not in node_data:
node_data[id] = {}
node_data[id]["last_selection"] = valid_indices
cleanup_session_data(id)
selected_images = torch.cat(selected_images, dim=0)
return {"result": (selected_images,)}
else:
cleanup_session_data(id)
return {"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}
else:
cleanup_session_data(id)
return {
"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}
else:
return {"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}
except ChooserCancelled:
raise mm.InterruptProcessingException()
except Exception as e:
node_data = get_chooser_cache()
if id in node_data:
cleanup_session_data(id)
if 'images' in locals() and len(images) > 0:
return {"result": (images[0],)}
else:
return {"result": (ExecutionBlocker(None),)}
@PromptServer.instance.routes.post('/easyuse/image_chooser_message')
async def handle_image_selection(request):
try:
data = await request.json()
node_id = data.get("node_id")
selected = data.get("selected", [])
action = data.get("action")
node_data = get_chooser_cache()
if node_id not in node_data:
return web.json_response({"code": -1, "error": "Node data does not exist"})
try:
node_info = node_data[node_id]
if "total_count" not in node_info:
return web.json_response({"code": -1, "error": "The node has been processed"})
if action == "cancel":
node_info["cancelled"] = True
node_info["selected"] = []
elif action == "select" and isinstance(selected, list):
valid_indices = [idx for idx in selected if isinstance(idx, int) and 0 <= idx < node_info["total_count"]]
if valid_indices:
node_info["selected"] = valid_indices
node_info["cancelled"] = False
else:
return web.json_response({"code": -1, "error": "Invalid Selection Index"})
else:
return web.json_response({"code": -1, "error": "Invalid operation"})
node_info["event"].set()
return web.json_response({"code": 1})
except Exception as e:
if node_id in node_data and "event" in node_data[node_id]:
node_data[node_id]["event"].set()
return web.json_response({"code": -1, "message": "Processing Failed"})
except Exception as e:
return web.json_response({"code": -1, "message": "Request Failed"})

View File

@@ -0,0 +1,115 @@
import torch
from PIL import Image
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import ToTensor, ToPILImage
def adain_color_fix(target: Image, source: Image):
# Convert images to tensors
to_tensor = ToTensor()
target_tensor = to_tensor(target).unsqueeze(0)
source_tensor = to_tensor(source).unsqueeze(0)
# Apply adaptive instance normalization
result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
# Convert tensor back to image
to_image = ToPILImage()
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
return result_image
def wavelet_color_fix(target: Image, source: Image):
source = source.resize(target.size, resample=Image.Resampling.LANCZOS)
# Convert images to tensors
to_tensor = ToTensor()
target_tensor = to_tensor(target).unsqueeze(0)
source_tensor = to_tensor(source).unsqueeze(0)
# Apply wavelet reconstruction
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
# Convert tensor back to image
to_image = ToPILImage()
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
return result_image
def calc_mean_std(feat: Tensor, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
"""Adaptive instance normalization.
Adjust the reference features to have the similar color and illuminations
as those in the degradate features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degradate features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
def wavelet_blur(image: Tensor, radius: int):
"""
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
# add channel dimensions to the kernel to make it a 4D tensor
kernel = kernel[None, None]
# repeat the kernel across all input channels
kernel = kernel.repeat(3, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
# apply convolution
output = F.conv2d(image, kernel, groups=3, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels=5):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq += (image - low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
"""
Apply wavelet decomposition, so that the content will have the same color as the style.
"""
# calculate the wavelet decomposition of the content feature
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
# reconstruct the content feature with the style's high frequency
return content_high_freq + style_low_freq
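Both color-fix helpers take PIL images; a small sketch (the file names are placeholders) of restoring the original palette after an upscaler has shifted colors:
from PIL import Image
upscaled = Image.open("upscaled.png")              # image whose colors drifted
reference = Image.open("source.png")               # color reference
fixed = wavelet_color_fix(upscaled, reference)     # or adain_color_fix(upscaled, reference)
fixed.save("upscaled_colorfixed.png")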

View File

@@ -0,0 +1,57 @@
from .utils import find_wildcards_seed, find_nearest_steps, is_linked_styles_selector
from .log import log_node_warn
from .translate import zh_to_en, has_chinese
from .wildcards import process_with_loras
from .adv_encode import advanced_encode
from nodes import ConditioningConcat, ConditioningCombine, ConditioningAverage, ConditioningSetTimestepRange, CLIPTextEncode
def prompt_to_cond(type, model, clip, clip_skip, lora_stack, text, prompt_token_normalization, prompt_weight_interpretation, a1111_prompt_style, my_unique_id, prompt, easyCache, can_load_lora=True, steps=None, model_type=None):
styles_selector = is_linked_styles_selector(prompt, my_unique_id, type)
title = "Positive encoding" if type == 'positive' else "Negative encoding"
# Translate cn to en
if model_type not in ['hydit'] and text is not None and has_chinese(text):
text = zh_to_en([text])[0]
if model_type in ['hydit', 'flux', 'mochi']:
log_node_warn(title + "...")
embeddings_final, = CLIPTextEncode().encode(clip, text) if text is not None else (None,)
return (embeddings_final, "", model, clip)
log_node_warn(title + "...")
positive_seed = find_wildcards_seed(my_unique_id, text, prompt)
model, clip, text, cond_decode, show_prompt, pipe_lora_stack = process_with_loras(
text, model, clip, type, positive_seed, can_load_lora, lora_stack, easyCache)
wildcard_prompt = cond_decode if show_prompt or styles_selector else ""
clipped = clip.clone()
# Clip skip can only be applied when the clip model has no t5xxl component
if not hasattr(clip.cond_stage_model, 't5xxl'):
if clip_skip != 0:
clipped.clip_layer(clip_skip)
steps = steps if steps is not None else find_nearest_steps(my_unique_id, prompt)
return (advanced_encode(clipped, text, prompt_token_normalization,
prompt_weight_interpretation, w_max=1.0,
apply_to_pooled='enable',
a1111_prompt_style=a1111_prompt_style, steps=steps) if text is not None else None, wildcard_prompt, model, clipped)
def set_cond(old_cond, new_cond, mode, average_strength, old_cond_start, old_cond_end, new_cond_start, new_cond_end):
if not old_cond:
return new_cond
else:
if mode == "replace":
return new_cond
elif mode == "concat":
return ConditioningConcat().concat(new_cond, old_cond)[0]
elif mode == "combine":
return ConditioningCombine().combine(old_cond, new_cond)[0]
elif mode == 'average':
return ConditioningAverage().addWeighted(new_cond, old_cond, average_strength)[0]
elif mode == 'timestep':
cond_1 = ConditioningSetTimestepRange().set_range(old_cond, old_cond_start, old_cond_end)[0]
cond_2 = ConditioningSetTimestepRange().set_range(new_cond, new_cond_start, new_cond_end)[0]
return ConditioningCombine().combine(cond_1, cond_2)[0]
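set_cond merges a freshly encoded conditioning into an existing one; a hedged sketch of the 'timestep' mode, which keeps the old conditioning for the first half of sampling and the new one for the second half (old_cond and new_cond stand for outputs of two prompt_to_cond calls, and the split points are illustrative):
merged = set_cond(old_cond, new_cond, mode='timestep', average_strength=0.5,
old_cond_start=0.0, old_cond_end=0.5, new_cond_start=0.5, new_cond_end=1.0)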

View File

@@ -0,0 +1,93 @@
import folder_paths
import comfy.controlnet
import comfy.model_management
from nodes import NODE_CLASS_MAPPINGS
union_controlnet_types = {"auto": -1, "openpose": 0, "depth": 1, "hed/pidi/scribble/ted": 2, "canny/lineart/anime_lineart/mlsd": 3, "normal": 4, "segment": 5, "tile": 6, "repaint": 7}
class easyControlnet:
def __init__(self):
pass
def apply(self, control_net_name, image, positive, negative, strength, start_percent=0, end_percent=1, control_net=None, scale_soft_weights=1, mask=None, union_type=None, easyCache=None, use_cache=True, model=None, vae=None):
if strength == 0:
return (positive, negative)
# kolors controlnet patch
from ..modules.kolors.loader import is_kolors_model, applyKolorsUnet
if is_kolors_model(model):
from ..modules.kolors.model_patch import patch_controlnet
if control_net is None:
with applyKolorsUnet():
control_net = easyCache.load_controlnet(control_net_name, scale_soft_weights, use_cache)
control_net = patch_controlnet(model, control_net)
else:
if control_net is None:
if easyCache is not None:
control_net = easyCache.load_controlnet(control_net_name, scale_soft_weights, use_cache)
else:
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
control_net = comfy.controlnet.load_controlnet(controlnet_path)
# union controlnet
if union_type is not None:
control_net = control_net.copy()
type_number = union_controlnet_types[union_type]
if type_number >= 0:
control_net.set_extra_arg("control_type", [type_number])
else:
control_net.set_extra_arg("control_type", [])
if mask is not None:
mask = mask.to(comfy.model_management.get_torch_device())
if mask is not None and len(mask.shape) < 3:
mask = mask.unsqueeze(0)
control_hint = image.movedim(-1, 1)
is_cond = True
if negative is None:
p = []
for t in positive:
n = [t[0], t[1].copy()]
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
if 'control' in t[1]:
c_net.set_previous_controlnet(t[1]['control'])
n[1]['control'] = c_net
n[1]['control_apply_to_uncond'] = True
if mask is not None:
n[1]['mask'] = mask
n[1]['set_area_to_bounds'] = False
p.append(n)
positive = p
else:
cnets = {}
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
prev_cnet = d.get('control', None)
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
d['control'] = c_net
d['control_apply_to_uncond'] = False
if mask is not None:
d['mask'] = mask
d['set_area_to_bounds'] = False
n = [t[0], d]
c.append(n)
out.append(c)
positive = out[0]
negative = out[1]
return (positive, negative)
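A hedged sketch of using the helper above (the controlnet file name, image tensor, conditionings, and vae are assumptions): with easyCache=None it falls back to comfy.controlnet.load_controlnet, and passing both positive and negative applies the hint to each conditioning separately.
cn = easyControlnet()
positive, negative = cn.apply("control_v11p_sd15_canny.pth", canny_image,
positive, negative, strength=0.8, start_percent=0.0, end_percent=1.0,
easyCache=None, vae=vae)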

View File

@@ -0,0 +1,167 @@
import torch, math
######################### DynThresh Core #########################
class DynThresh:
Modes = ["Constant", "Linear Down", "Cosine Down", "Half Cosine Down", "Linear Up", "Cosine Up", "Half Cosine Up", "Power Up", "Power Down", "Linear Repeating", "Cosine Repeating", "Sawtooth"]
Startpoints = ["MEAN", "ZERO"]
Variabilities = ["AD", "STD"]
def __init__(self, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, max_steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi):
self.mimic_scale = mimic_scale
self.threshold_percentile = threshold_percentile
self.mimic_mode = mimic_mode
self.cfg_mode = cfg_mode
self.max_steps = max_steps
self.cfg_scale_min = cfg_scale_min
self.mimic_scale_min = mimic_scale_min
self.experiment_mode = experiment_mode
self.sched_val = sched_val
self.sep_feat_channels = separate_feature_channels
self.scaling_startpoint = scaling_startpoint
self.variability_measure = variability_measure
self.interpolate_phi = interpolate_phi
def interpret_scale(self, scale, mode, min):
scale -= min
max = self.max_steps - 1
frac = self.step / max
if mode == "Constant":
pass
elif mode == "Linear Down":
scale *= 1.0 - frac
elif mode == "Half Cosine Down":
scale *= math.cos(frac)
elif mode == "Cosine Down":
scale *= math.cos(frac * 1.5707)
elif mode == "Linear Up":
scale *= frac
elif mode == "Half Cosine Up":
scale *= 1.0 - math.cos(frac)
elif mode == "Cosine Up":
scale *= 1.0 - math.cos(frac * 1.5707)
elif mode == "Power Up":
scale *= math.pow(frac, self.sched_val)
elif mode == "Power Down":
scale *= 1.0 - math.pow(frac, self.sched_val)
elif mode == "Linear Repeating":
portion = (frac * self.sched_val) % 1.0
scale *= (0.5 - portion) * 2 if portion < 0.5 else (portion - 0.5) * 2
elif mode == "Cosine Repeating":
scale *= math.cos(frac * 6.28318 * self.sched_val) * 0.5 + 0.5
elif mode == "Sawtooth":
scale *= (frac * self.sched_val) % 1.0
scale += min
return scale
def dynthresh(self, cond, uncond, cfg_scale, weights):
mimic_scale = self.interpret_scale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min)
cfg_scale = self.interpret_scale(cfg_scale, self.cfg_mode, self.cfg_scale_min)
# uncond shape is (batch, 4, height, width)
conds_per_batch = cond.shape[0] / uncond.shape[0]
assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches"
cond_stacked = cond.reshape((-1, int(conds_per_batch)) + uncond.shape[1:])
### Normal first part of the CFG Scale logic, basically
diff = cond_stacked - uncond.unsqueeze(1)
if weights is not None:
diff = diff * weights
relative = diff.sum(1)
### Get the normal result for both mimic and normal scale
mim_target = uncond + relative * mimic_scale
cfg_target = uncond + relative * cfg_scale
### If we weren't doing mimic scale, we'd just return cfg_target here
### Now recenter the values relative to their average rather than absolute, to allow scaling from average
mim_flattened = mim_target.flatten(2)
cfg_flattened = cfg_target.flatten(2)
mim_means = mim_flattened.mean(dim=2).unsqueeze(2)
cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2)
mim_centered = mim_flattened - mim_means
cfg_centered = cfg_flattened - cfg_means
if self.sep_feat_channels:
if self.variability_measure == 'STD':
mim_scaleref = mim_centered.std(dim=2).unsqueeze(2)
cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2)
else: # 'AD'
mim_scaleref = mim_centered.abs().max(dim=2).values.unsqueeze(2)
cfg_scaleref = torch.quantile(cfg_centered.abs(), self.threshold_percentile, dim=2).unsqueeze(2)
else:
if self.variability_measure == 'STD':
mim_scaleref = mim_centered.std()
cfg_scaleref = cfg_centered.std()
else: # 'AD'
mim_scaleref = mim_centered.abs().max()
cfg_scaleref = torch.quantile(cfg_centered.abs(), self.threshold_percentile)
if self.scaling_startpoint == 'ZERO':
scaling_factor = mim_scaleref / cfg_scaleref
result = cfg_flattened * scaling_factor
else: # 'MEAN'
if self.variability_measure == 'STD':
cfg_renormalized = (cfg_centered / cfg_scaleref) * mim_scaleref
else: # 'AD'
### Get the maximum value of all datapoints (with an optional threshold percentile on the uncond)
max_scaleref = torch.maximum(mim_scaleref, cfg_scaleref)
### Clamp to the max
cfg_clamped = cfg_centered.clamp(-max_scaleref, max_scaleref)
### Now shrink from the max to normalize and grow to the mimic scale (instead of the CFG scale)
cfg_renormalized = (cfg_clamped / max_scaleref) * mim_scaleref
### Now add it back onto the averages to get into real scale again and return
result = cfg_renormalized + cfg_means
actual_res = result.unflatten(2, mim_target.shape[2:])
if self.interpolate_phi != 1.0:
actual_res = actual_res * self.interpolate_phi + cfg_target * (1.0 - self.interpolate_phi)
if self.experiment_mode == 1:
num = actual_res.cpu().numpy()
for y in range(0, 64):
for x in range (0, 64):
if num[0][0][y][x] > 1.0:
num[0][1][y][x] *= 0.5
if num[0][1][y][x] > 1.0:
num[0][1][y][x] *= 0.5
if num[0][2][y][x] > 1.5:
num[0][2][y][x] *= 0.5
actual_res = torch.from_numpy(num).to(device=uncond.device)
elif self.experiment_mode == 2:
num = actual_res.cpu().numpy()
for y in range(0, 64):
for x in range (0, 64):
over_scale = False
for z in range(0, 4):
if abs(num[0][z][y][x]) > 1.5:
over_scale = True
if over_scale:
for z in range(0, 4):
num[0][z][y][x] *= 0.7
actual_res = torch.from_numpy(num).to(device=uncond.device)
elif self.experiment_mode == 3:
coefs = torch.tensor([
# R G B W
[0.298, 0.207, 0.208, 0.0], # L1
[0.187, 0.286, 0.173, 0.0], # L2
[-0.158, 0.189, 0.264, 0.0], # L3
[-0.184, -0.271, -0.473, 1.0], # L4
], device=uncond.device)
res_rgb = torch.einsum("laxy,ab -> lbxy", actual_res, coefs)
max_r, max_g, max_b, max_w = res_rgb[0][0].max(), res_rgb[0][1].max(), res_rgb[0][2].max(), res_rgb[0][3].max()
max_rgb = max(max_r, max_g, max_b)
print(f"test max = r={max_r}, g={max_g}, b={max_b}, w={max_w}, rgb={max_rgb}")
if self.step / (self.max_steps - 1) > 0.2:
if max_rgb < 2.0 and max_w < 3.0:
res_rgb /= max_rgb / 2.4
else:
if max_rgb > 2.4 and max_w > 3.0:
res_rgb /= max_rgb / 2.4
actual_res = torch.einsum("laxy,ab -> lbxy", res_rgb, coefs.inverse())
return actual_res
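DynThresh expects the caller to set self.step before each call; a small sketch with illustrative settings (the random tensors stand in for the cond/uncond model outputs):
dt = DynThresh(mimic_scale=7.0, threshold_percentile=0.98, mimic_mode="Constant",
mimic_scale_min=0.0, cfg_mode="Constant", cfg_scale_min=0.0, sched_val=1.0,
experiment_mode=0, max_steps=30, separate_feature_channels=True,
scaling_startpoint="MEAN", variability_measure="AD", interpolate_phi=1.0)
dt.step = 10                              # current sampling step, normally set by the sampler hook
cond = torch.randn(2, 4, 64, 64)          # conditional prediction
uncond = torch.randn(2, 4, 64, 64)        # unconditional prediction
out = dt.dynthresh(cond, uncond, cfg_scale=7.0, weights=None)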

View File

@@ -0,0 +1,27 @@
def easyIn(t: float) -> float:
return t * t
def easyOut(t: float) -> float:
return -(t * (t - 2))
def easyInOut(t: float) -> float:
if t < 0.5:
return 2 * t * t
else:
return (-2 * t * t) + (4 * t) - 1
class EasingBase:
def easing(self, t: float, function='linear') -> float:
if function == 'easyIn':
return easyIn(t)
elif function == 'easyOut':
return easyOut(t)
elif function == 'easyInOut':
return easyInOut(t)
else:
return t
def ease(self, start, end, t) -> float:
return end * t + start * (1 - t)
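A quick sketch of the easing helpers: easing() picks the curve and ease() interpolates between two values (the numbers are arbitrary).
eb = EasingBase()
t = eb.easing(0.25, function='easyInOut')   # 2 * 0.25**2 = 0.125
value = eb.ease(start=0.0, end=10.0, t=t)   # 10 * 0.125 + 0 * 0.875 = 1.25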

View File

@@ -0,0 +1,273 @@
import torch
from torchvision.transforms.functional import gaussian_blur
from comfy.k_diffusion.sampling import default_noise_sampler, get_ancestral_step, to_d, BrownianTreeNoiseSampler
from tqdm.auto import trange
@torch.no_grad()
def sample_euler_ancestral(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
eta=1.0,
s_noise=1.0,
noise_sampler=None,
upscale_ratio=2.0,
start_step=5,
end_step=15,
upscale_n_step=3,
unsharp_kernel_size=3,
unsharp_sigma=0.5,
unsharp_strength=0.0,
):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
# make upscale info
upscale_steps = []
step = start_step - 1
while step < end_step - 1:
upscale_steps.append(step)
step += upscale_n_step
height, width = x.shape[2:]
upscale_shapes = [
(int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
for i in reversed(range(1, len(upscale_steps) + 1))
]
upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
if sigmas[i + 1] > 0:
# Resize
if i in upscale_info:
x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
if unsharp_strength > 0:
blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
x = x + unsharp_strength * (x - blurred)
noise_sampler = default_noise_sampler(x)
noise = noise_sampler(sigmas[i], sigmas[i + 1])
x = x + noise * sigma_up * s_noise
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
eta=1.0,
s_noise=1.0,
noise_sampler=None,
upscale_ratio=2.0,
start_step=5,
end_step=15,
upscale_n_step=3,
unsharp_kernel_size=3,
unsharp_sigma=0.5,
unsharp_strength=0.0,
):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
# make upscale info
upscale_steps = []
step = start_step - 1
while step < end_step - 1:
upscale_steps.append(step)
step += upscale_n_step
height, width = x.shape[2:]
upscale_shapes = [
(int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
for i in reversed(range(1, len(upscale_steps) + 1))
]
upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
# Resize
if i in upscale_info:
x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
if unsharp_strength > 0:
blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
x = x + unsharp_strength * (x - blurred)
noise_sampler = default_noise_sampler(x)
noise = noise_sampler(sigmas[i], sigmas[i + 1])
x = x + noise * sigma_up * s_noise
return x
@torch.no_grad()
def sample_dpmpp_2m_sde(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
eta=1.0,
s_noise=1.0,
noise_sampler=None,
solver_type="midpoint",
upscale_ratio=2.0,
start_step=5,
end_step=15,
upscale_n_step=3,
unsharp_kernel_size=3,
unsharp_sigma=0.5,
unsharp_strength=0.0,
):
"""DPM-Solver++(2M) SDE."""
if solver_type not in {"heun", "midpoint"}:
raise ValueError("solver_type must be 'heun' or 'midpoint'")
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
s_in = x.new_ones([x.shape[0]])
old_denoised = None
h_last = None
h = None
# make upscale info
upscale_steps = []
step = start_step - 1
while step < end_step - 1:
upscale_steps.append(step)
step += upscale_n_step
height, width = x.shape[2:]
upscale_shapes = [
(int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
for i in reversed(range(1, len(upscale_steps) + 1))
]
upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
# DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
eta_h = eta * h
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
if old_denoised is not None:
r = h_last / h
if solver_type == "heun":
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == "midpoint":
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
if eta:
# Resize
if i in upscale_info:
x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
if unsharp_strength > 0:
blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
x = x + unsharp_strength * (x - blurred)
denoised = None # its size no longer matches the next step after resizing, so drop it for now
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True)
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
return x
@torch.no_grad()
def sample_lcm(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
noise_sampler=None,
eta=None,
s_noise=None,
upscale_ratio=2.0,
start_step=5,
end_step=15,
upscale_n_step=3,
unsharp_kernel_size=3,
unsharp_sigma=0.5,
unsharp_strength=0.0,
):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
# make upscale info
upscale_steps = []
step = start_step - 1
while step < end_step - 1:
upscale_steps.append(step)
step += upscale_n_step
height, width = x.shape[2:]
upscale_shapes = [
(int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
for i in reversed(range(1, len(upscale_steps) + 1))
]
upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
x = denoised
if sigmas[i + 1] > 0:
# Resize
if i in upscale_info:
x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
if unsharp_strength > 0:
blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
x = x + unsharp_strength * (x - blurred)
noise_sampler = default_noise_sampler(x)
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
return x
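All four samplers share the same gradual-upscale schedule; a small sketch of how the intermediate latent sizes are derived for illustrative settings (start_step=5, end_step=15, upscale_n_step=3, upscale_ratio=2.0, 64x64 latent):
upscale_ratio, start_step, end_step, upscale_n_step = 2.0, 5, 15, 3
height = width = 64
upscale_steps = list(range(start_step - 1, end_step - 1, upscale_n_step))   # [4, 7, 10, 13]
upscale_shapes = [(int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
for i in reversed(range(1, len(upscale_steps) + 1))]
print(dict(zip(upscale_steps, upscale_shapes)))   # {4: (80, 80), 7: (85, 85), 10: (96, 96), 13: (128, 128)}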

View File

@@ -0,0 +1,227 @@
import os
import base64
import torch
import numpy as np
from enum import Enum
from PIL import Image
from io import BytesIO
from typing import List, Union
import folder_paths
from .utils import install_package
# PIL to Tensor
def pil2tensor(image):
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
# Tensor to PIL
def tensor2pil(image):
return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
# np to Tensor
def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor:
if isinstance(img_np, list):
return torch.cat([np2tensor(img) for img in img_np], dim=0)
return torch.from_numpy(img_np.astype(np.float32) / 255.0).unsqueeze(0)
# Tensor to np
def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]:
if len(tensor.shape) == 3: # Single image
return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8)
else: # Batch of images
return [np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor]
def pil2byte(pil_image, format='PNG'):
byte_arr = BytesIO()
pil_image.save(byte_arr, format=format)
byte_arr.seek(0)
return byte_arr
# Decode a base64 string into a PIL image (despite its name, this goes base64 -> image)
def image2base64(image_base64):
image_bytes = base64.b64decode(image_base64)
image_data = Image.open(BytesIO(image_bytes))
return image_data
# Get new bounds
def get_new_bounds(width, height, left, right, top, bottom):
"""Returns the new bounds for an image with inset crop data."""
left = 0 + left
right = width - right
top = 0 + top
bottom = height - bottom
return (left, right, top, bottom)
def RGB2RGBA(image: Image, mask: Image) -> Image:
(R, G, B) = image.convert('RGB').split()
return Image.merge('RGBA', (R, G, B, mask.convert('L')))
def image2mask(image: Image) -> torch.Tensor:
_image = image.convert('RGBA')
alpha = _image.split()[0]
bg = Image.new("L", _image.size)
_image = Image.merge('RGBA', (bg, bg, bg, alpha))
ret_mask = torch.tensor([pil2tensor(_image)[0, :, :, 3].tolist()])
return ret_mask
def mask2image(mask: torch.Tensor) -> Image:
masks = tensor2np(mask)
for m in masks:
_mask = Image.fromarray(m).convert("L")
_image = Image.new("RGBA", _mask.size, color='white')
_image = Image.composite(
_image, Image.new("RGBA", _mask.size, color='black'), _mask)
return _image
# Image blending
class blendImage:
def g(self, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
def blend_mode(self, img1, img2, mode):
if mode == "normal":
return img2
elif mode == "multiply":
return img1 * img2
elif mode == "screen":
return 1 - (1 - img1) * (1 - img2)
elif mode == "overlay":
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1),
img1 + (2 * img2 - 1) * (self.g(img1) - img1))
elif mode == "difference":
return img1 - img2
else:
raise ValueError(f"Unsupported blend mode: {mode}")
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str = 'normal'):
image2 = image2.to(image1.device)
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic',
crop='center')
image2 = image2.permute(0, 2, 3, 1)
blended_image = self.blend_mode(image1, image2, blend_mode)
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
blended_image = torch.clamp(blended_image, 0, 1)
return blended_image
def empty_image(width, height, batch_size=1, color=0):
r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
return torch.cat((r, g, b), dim=-1)
class ResizeMode(Enum):
RESIZE = "Just Resize"
INNER_FIT = "Crop and Resize"
OUTER_FIT = "Resize and Fill"
def int_value(self):
if self == ResizeMode.RESIZE:
return 0
elif self == ResizeMode.INNER_FIT:
return 1
elif self == ResizeMode.OUTER_FIT:
return 2
assert False, "NOTREACHED"
# credit by https://github.com/chflame163/ComfyUI_LayerStyle/blob/main/py/imagefunc.py#L591C1-L617C22
def fit_resize_image(image: Image, target_width: int, target_height: int, fit: str, resize_sampler: str,
background_color: str = '#000000') -> Image:
image = image.convert('RGB')
orig_width, orig_height = image.size
if image is not None:
if fit == 'letterbox':
if orig_width / orig_height > target_width / target_height: # wider than target: letterbox top and bottom
fit_width = target_width
fit_height = int(target_width / orig_width * orig_height)
else: # taller than target: letterbox left and right
fit_height = target_height
fit_width = int(target_height / orig_height * orig_width)
fit_image = image.resize((fit_width, fit_height), resize_sampler)
ret_image = Image.new('RGB', size=(target_width, target_height), color=background_color)
ret_image.paste(fit_image, box=((target_width - fit_width) // 2, (target_height - fit_height) // 2))
elif fit == 'crop':
if orig_width / orig_height > target_width / target_height: # wider than target: crop left and right
fit_width = int(orig_height * target_width / target_height)
fit_image = image.crop(
((orig_width - fit_width) // 2, 0, (orig_width - fit_width) // 2 + fit_width, orig_height))
else: # taller than target: crop top and bottom
fit_height = int(orig_width * target_height / target_width)
fit_image = image.crop(
(0, (orig_height - fit_height) // 2, orig_width, (orig_height - fit_height) // 2 + fit_height))
ret_image = fit_image.resize((target_width, target_height), resize_sampler)
else:
ret_image = image.resize((target_width, target_height), resize_sampler)
return ret_image
# CLIP interrogate (image to prompt)
import comfy.utils
from torchvision import transforms
Config, Interrogator = None, None
class CI_Inference:
ci_model = None
cache_path: str
def __init__(self):
self.ci_model = None
self.low_vram = False
self.cache_path = os.path.join(folder_paths.models_dir, "clip_interrogator")
def _load_model(self, model_name, low_vram=False):
if not (self.ci_model and model_name == self.ci_model.config.clip_model_name and self.low_vram == low_vram):
self.low_vram = low_vram
print(f"Load model: {model_name}")
config = Config(
device="cuda" if torch.cuda.is_available() else "cpu",
download_cache=True,
clip_model_name=model_name,
clip_model_path=self.cache_path,
cache_path=self.cache_path,
caption_model_name='blip-large'
)
if low_vram:
config.apply_low_vram_defaults()
self.ci_model = Interrogator(config)
def _interrogate(self, image, mode, caption=None):
if mode == 'best':
prompt = self.ci_model.interrogate(image, caption=caption)
elif mode == 'classic':
prompt = self.ci_model.interrogate_classic(image, caption=caption)
elif mode == 'fast':
prompt = self.ci_model.interrogate_fast(image, caption=caption)
elif mode == 'negative':
prompt = self.ci_model.interrogate_negative(image)
else:
raise Exception(f"Unknown mode {mode}")
return prompt
def image_to_prompt(self, image, mode, model_name='ViT-L-14/openai', low_vram=False):
global Config, Interrogator
try:
from clip_interrogator import Config, Interrogator
except (ImportError, ModuleNotFoundError):
install_package("clip_interrogator", "0.6.0")
from clip_interrogator import Config, Interrogator
pbar = comfy.utils.ProgressBar(len(image))
self._load_model(model_name, low_vram)
prompt = []
for i in range(len(image)):
im = image[i]
im = tensor2pil(im)
im = im.convert('RGB')
_prompt = self._interrogate(im, mode)
pbar.update(1)
prompt.append(_prompt)
return prompt
ci = CI_Inference()
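A round-trip sketch for the tensor/PIL helpers above (the image size is arbitrary): ComfyUI image tensors are BHWC float32 in [0, 1].
from PIL import Image
pil_img = Image.new('RGB', (64, 64), color='red')
t = pil2tensor(pil_img)      # shape (1, 64, 64, 3), float32 in [0, 1]
back = tensor2pil(t)         # back to a PIL image
mask = image2mask(back)      # (1, 64, 64) mask tensor built from the first channel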

View File

@@ -0,0 +1,237 @@
import math
import torch
import comfy
def extra_options_to_module_prefix(extra_options):
# extra_options = {'transformer_index': 2, 'block_index': 8, 'original_shape': [2, 4, 128, 128], 'block': ('input', 7), 'n_heads': 20, 'dim_head': 64}
# block is: [('input', 4), ('input', 5), ('input', 7), ('input', 8), ('middle', 0),
# ('output', 0), ('output', 1), ('output', 2), ('output', 3), ('output', 4), ('output', 5)]
# transformer_index is: [0, 1, 2, 3, 4, 5, 6, 7, 8], for each block
# block_index is: 0-1 or 0-9, depends on the block
# input 7 and 8, middle has 10 blocks
# make module name from extra_options
block = extra_options["block"]
block_index = extra_options["block_index"]
if block[0] == "input":
module_pfx = f"lllite_unet_input_blocks_{block[1]}_1_transformer_blocks_{block_index}"
elif block[0] == "middle":
module_pfx = f"lllite_unet_middle_block_1_transformer_blocks_{block_index}"
elif block[0] == "output":
module_pfx = f"lllite_unet_output_blocks_{block[1]}_1_transformer_blocks_{block_index}"
else:
raise Exception("invalid block name")
return module_pfx
def load_control_net_lllite_patch(path, cond_image, multiplier, num_steps, start_percent, end_percent):
# calculate start and end step
start_step = math.floor(num_steps * start_percent * 0.01) if start_percent > 0 else 0
end_step = math.floor(num_steps * end_percent * 0.01) if end_percent > 0 else num_steps
# load weights
ctrl_sd = comfy.utils.load_torch_file(path, safe_load=True)
# split each weights for each module
module_weights = {}
for key, value in ctrl_sd.items():
fragments = key.split(".")
module_name = fragments[0]
weight_name = ".".join(fragments[1:])
if module_name not in module_weights:
module_weights[module_name] = {}
module_weights[module_name][weight_name] = value
# load each module
modules = {}
for module_name, weights in module_weights.items():
# TODO: this automatic depth detection could be improved
if "conditioning1.4.weight" in weights:
depth = 3
elif weights["conditioning1.2.weight"].shape[-1] == 4:
depth = 2
else:
depth = 1
module = LLLiteModule(
name=module_name,
is_conv2d=weights["down.0.weight"].ndim == 4,
in_dim=weights["down.0.weight"].shape[1],
depth=depth,
cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
mlp_dim=weights["down.0.weight"].shape[0],
multiplier=multiplier,
num_steps=num_steps,
start_step=start_step,
end_step=end_step,
)
info = module.load_state_dict(weights)
modules[module_name] = module
if len(modules) == 1:
module.is_first = True
print(f"loaded {path} successfully, {len(modules)} modules")
# set the cond image on every module
cond_image = cond_image.permute(0, 3, 1, 2) # b,h,w,3 -> b,3,h,w
cond_image = cond_image * 2.0 - 1.0 # 0-1 -> -1-+1
for module in modules.values():
module.set_cond_image(cond_image)
class control_net_lllite_patch:
def __init__(self, modules):
self.modules = modules
def __call__(self, q, k, v, extra_options):
module_pfx = extra_options_to_module_prefix(extra_options)
is_attn1 = q.shape[-1] == k.shape[-1] # self attention
if is_attn1:
module_pfx = module_pfx + "_attn1"
else:
module_pfx = module_pfx + "_attn2"
module_pfx_to_q = module_pfx + "_to_q"
module_pfx_to_k = module_pfx + "_to_k"
module_pfx_to_v = module_pfx + "_to_v"
if module_pfx_to_q in self.modules:
q = q + self.modules[module_pfx_to_q](q)
if module_pfx_to_k in self.modules:
k = k + self.modules[module_pfx_to_k](k)
if module_pfx_to_v in self.modules:
v = v + self.modules[module_pfx_to_v](v)
return q, k, v
def to(self, device):
for d in self.modules.keys():
self.modules[d] = self.modules[d].to(device)
return self
return control_net_lllite_patch(modules)
class LLLiteModule(torch.nn.Module):
def __init__(
self,
name: str,
is_conv2d: bool,
in_dim: int,
depth: int,
cond_emb_dim: int,
mlp_dim: int,
multiplier: int,
num_steps: int,
start_step: int,
end_step: int,
):
super().__init__()
self.name = name
self.is_conv2d = is_conv2d
self.multiplier = multiplier
self.num_steps = num_steps
self.start_step = start_step
self.end_step = end_step
self.is_first = False
modules = []
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2
if depth == 1:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
elif depth == 2:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
elif depth == 3:
# kernel size 8 is too large, so use 4 instead
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
self.conditioning1 = torch.nn.Sequential(*modules)
if self.is_conv2d:
self.down = torch.nn.Sequential(
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
)
self.mid = torch.nn.Sequential(
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
)
self.up = torch.nn.Sequential(
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
)
else:
self.down = torch.nn.Sequential(
torch.nn.Linear(in_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
self.mid = torch.nn.Sequential(
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
self.up = torch.nn.Sequential(
torch.nn.Linear(mlp_dim, in_dim),
)
self.depth = depth
self.cond_image = None
self.cond_emb = None
self.current_step = 0
# @torch.inference_mode()
def set_cond_image(self, cond_image):
# print("set_cond_image", self.name)
self.cond_image = cond_image
self.cond_emb = None
self.current_step = 0
def forward(self, x):
if self.num_steps > 0:
if self.current_step < self.start_step:
self.current_step += 1
return torch.zeros_like(x)
elif self.current_step >= self.end_step:
if self.is_first and self.current_step == self.end_step:
print(f"end LLLite: step {self.current_step}")
self.current_step += 1
if self.current_step >= self.num_steps:
self.current_step = 0 # reset
return torch.zeros_like(x)
else:
if self.is_first and self.current_step == self.start_step:
print(f"start LLLite: step {self.current_step}")
self.current_step += 1
if self.current_step >= self.num_steps:
self.current_step = 0 # reset
if self.cond_emb is None:
# print(f"cond_emb is None, {self.name}")
cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype))
if not self.is_conv2d:
# reshape / b,c,h,w -> b,h*w,c
n, c, h, w = cx.shape
cx = cx.view(n, c, h * w).permute(0, 2, 1)
self.cond_emb = cx
cx = self.cond_emb
# print(f"forward {self.name}, {cx.shape}, {x.shape}")
# with cond/uncond batched together, x has twice the batch size of cx
if x.shape[0] != cx.shape[0]:
if self.is_conv2d:
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
else:
# print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0])
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
cx = self.mid(cx)
cx = self.up(cx)
return cx * self.multiplier
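A hedged sketch of wiring the returned patch into a model (the .safetensors path is a placeholder, and the ModelPatcher attn-patch calls reflect how the upstream ControlNet-LLLite node applies it): the same callable is registered for both self- and cross-attention, and the start/end percents are on a 0-100 scale.
patch = load_control_net_lllite_patch("controllllite_canny.safetensors", hint_image,
multiplier=1.0, num_steps=20, start_percent=0, end_percent=100)
m = model.clone()                   # `model` is a comfy ModelPatcher, `hint_image` a BHWC tensor in [0, 1]
m.set_model_attn1_patch(patch)
m.set_model_attn2_patch(patch)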

View File

@@ -0,0 +1,570 @@
import re, time, os, psutil
import folder_paths
import comfy.utils
import comfy.sd
import comfy.controlnet
from comfy.model_patcher import ModelPatcher
from nodes import NODE_CLASS_MAPPINGS
from collections import defaultdict
from .log import log_node_info, log_node_error
from ..modules.dit.pixArt.loader import load_pixart
diffusion_loaders = ["easy fullLoader", "easy a1111Loader", "easy fluxLoader", "easy comfyLoader", "easy hunyuanDiTLoader", "easy zero123Loader", "easy svdLoader"]
stable_cascade_loaders = ["easy cascadeLoader"]
dit_loaders = ['easy pixArtLoader']
controlnet_loaders = ["easy controlnetLoader", "easy controlnetLoaderADV", "easy controlnetLoader++"]
instant_loaders = ["easy instantIDApply", "easy instantIDApplyADV"]
cascade_vae_node = ["easy preSamplingCascade", "easy fullCascadeKSampler"]
model_merge_node = ["easy XYInputs: ModelMergeBlocks"]
lora_widget = ["easy fullLoader", "easy a1111Loader", "easy comfyLoader", "easy fluxLoader"]
class easyLoader:
def __init__(self):
self.loaded_objects = {
"ckpt": defaultdict(tuple), # {ckpt_name: (model, ...)}
"unet": defaultdict(tuple),
"clip": defaultdict(tuple),
"clip_vision": defaultdict(tuple),
"bvae": defaultdict(tuple),
"vae": defaultdict(object),
"lora": defaultdict(dict), # {lora_name: {UID: (model_lora, clip_lora)}}
"controlnet": defaultdict(dict),
"t5": defaultdict(tuple),
"chatglm3": defaultdict(tuple),
}
self.memory_threshold = self.determine_memory_threshold(1)
self.lora_name_cache = []
def clean_values(self, values: str):
original_values = values.split("; ")
cleaned_values = []
for value in original_values:
cleaned_value = value.strip(';').strip()
if cleaned_value == "":
continue
try:
cleaned_value = int(cleaned_value)
except ValueError:
try:
cleaned_value = float(cleaned_value)
except ValueError:
pass
cleaned_values.append(cleaned_value)
return cleaned_values
def clear_unused_objects(self, desired_names: set, object_type: str):
keys = set(self.loaded_objects[object_type].keys())
for key in keys - desired_names:
del self.loaded_objects[object_type][key]
def get_input_value(self, entry, key, prompt=None):
val = entry["inputs"][key]
if isinstance(val, str):
return val
elif isinstance(val, list):
if prompt is not None and val[0]:
return prompt[val[0]]['inputs'][key]
else:
return val[0]
else:
return str(val)
def process_pipe_loader(self, entry, desired_ckpt_names, desired_vae_names, desired_lora_names, desired_lora_settings, num_loras=3, suffix=""):
for idx in range(1, num_loras + 1):
lora_name_key = f"{suffix}lora{idx}_name"
desired_lora_names.add(self.get_input_value(entry, lora_name_key))
setting = f'{self.get_input_value(entry, lora_name_key)};{entry["inputs"][f"{suffix}lora{idx}_model_strength"]};{entry["inputs"][f"{suffix}lora{idx}_clip_strength"]}'
desired_lora_settings.add(setting)
desired_ckpt_names.add(self.get_input_value(entry, f"{suffix}ckpt_name"))
desired_vae_names.add(self.get_input_value(entry, f"{suffix}vae_name"))
def update_loaded_objects(self, prompt):
desired_ckpt_names = set()
desired_unet_names = set()
desired_clip_names = set()
desired_vae_names = set()
desired_lora_names = set()
desired_lora_settings = set()
desired_controlnet_names = set()
desired_t5_names = set()
desired_glm3_names = set()
for entry in prompt.values():
class_type = entry["class_type"]
if class_type in lora_widget:
lora_name = self.get_input_value(entry, "lora_name")
desired_lora_names.add(lora_name)
setting = f'{lora_name};{entry["inputs"]["lora_model_strength"]};{entry["inputs"]["lora_clip_strength"]}'
desired_lora_settings.add(setting)
if class_type in diffusion_loaders:
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name", prompt))
desired_vae_names.add(self.get_input_value(entry, "vae_name"))
elif class_type in ['easy kolorsLoader']:
desired_unet_names.add(self.get_input_value(entry, "unet_name"))
desired_vae_names.add(self.get_input_value(entry, "vae_name"))
desired_glm3_names.add(self.get_input_value(entry, "chatglm3_name"))
elif class_type in dit_loaders:
t5_name = self.get_input_value(entry, "mt5_name") if "mt5_name" in entry["inputs"] else None
clip_name = self.get_input_value(entry, "clip_name") if "clip_name" in entry["inputs"] else None
model_name = self.get_input_value(entry, "model_name")
ckpt_name = self.get_input_value(entry, "ckpt_name", prompt)
if t5_name:
desired_t5_names.add(t5_name)
if clip_name:
desired_clip_names.add(clip_name)
desired_ckpt_names.add(ckpt_name+'_'+model_name)
elif class_type in stable_cascade_loaders:
desired_unet_names.add(self.get_input_value(entry, "stage_c"))
desired_unet_names.add(self.get_input_value(entry, "stage_b"))
desired_clip_names.add(self.get_input_value(entry, "clip_name"))
desired_vae_names.add(self.get_input_value(entry, "stage_a"))
elif class_type in cascade_vae_node:
encode_vae_name = self.get_input_value(entry, "encode_vae_name")
decode_vae_name = self.get_input_value(entry, "decode_vae_name")
if encode_vae_name and encode_vae_name != 'None':
desired_vae_names.add(encode_vae_name)
if decode_vae_name and decode_vae_name != 'None':
desired_vae_names.add(decode_vae_name)
elif class_type in controlnet_loaders:
control_net_name = self.get_input_value(entry, "control_net_name", prompt)
scale_soft_weights = self.get_input_value(entry, "scale_soft_weights")
desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
elif class_type in instant_loaders:
control_net_name = self.get_input_value(entry, "control_net_name", prompt)
scale_soft_weights = self.get_input_value(entry, "cn_soft_weights")
desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
elif class_type in model_merge_node:
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_1"))
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_2"))
vae_use = self.get_input_value(entry, "vae_use")
if vae_use != 'Use Model 1' and vae_use != 'Use Model 2':
desired_vae_names.add(vae_use)
object_types = ["ckpt", "unet", "clip", "bvae", "vae", "lora", "controlnet", "t5"]
for object_type in object_types:
if object_type == 'unet':
desired_names = desired_unet_names
elif object_type in ["ckpt", "clip", "bvae"]:
if object_type == 'clip':
desired_names = desired_ckpt_names.union(desired_clip_names)
else:
desired_names = desired_ckpt_names
elif object_type == "vae":
desired_names = desired_vae_names
elif object_type == "controlnet":
desired_names = desired_controlnet_names
elif object_type == "t5":
desired_names = desired_t5_names
elif object_type == "chatglm3":
desired_names = desired_glm3_names
else:
desired_names = desired_lora_names
self.clear_unused_objects(desired_names, object_type)
def add_to_cache(self, obj_type, key, value):
"""
Add an item to the cache with the current timestamp.
"""
timestamped_value = (value, time.time())
self.loaded_objects[obj_type][key] = timestamped_value
def determine_memory_threshold(self, percentage=0.8):
"""
Determines the memory threshold as a percentage of the total available memory.
Args:
- percentage (float): The fraction of total memory to use as the threshold.
Should be a value between 0 and 1. Default is 0.8 (80%).
Returns:
- memory_threshold (int): Memory threshold in bytes.
"""
total_memory = psutil.virtual_memory().total
memory_threshold = total_memory * percentage
return memory_threshold
def get_memory_usage(self):
"""
Returns the memory usage of the current process in bytes.
"""
process = psutil.Process(os.getpid())
return process.memory_info().rss
def eviction_based_on_memory(self):
"""
Evicts objects from cache based on memory usage and priority.
"""
current_memory = self.get_memory_usage()
if current_memory < self.memory_threshold:
return
eviction_order = ["vae", "lora", "bvae", "clip", "ckpt", "controlnet", "unet", "t5", "chatglm3"]
for obj_type in eviction_order:
if current_memory < self.memory_threshold:
break
# Sort items based on age (using the timestamp)
items = list(self.loaded_objects[obj_type].items())
items.sort(key=lambda x: x[1][1]) # Sorting by timestamp
for item in items:
if current_memory < self.memory_threshold:
break
del self.loaded_objects[obj_type][item[0]]
current_memory = self.get_memory_usage()
def load_checkpoint(self, ckpt_name, config_name=None, load_vision=False):
cache_name = ckpt_name
if config_name not in [None, "Default"]:
cache_name = ckpt_name + "_" + config_name
if cache_name in self.loaded_objects["ckpt"]:
clip_vision = self.loaded_objects["clip_vision"][cache_name][0] if load_vision else None
clip = self.loaded_objects["clip"][cache_name][0] if not load_vision else None
return self.loaded_objects["ckpt"][cache_name][0], clip, self.loaded_objects["bvae"][cache_name][0], clip_vision
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
output_clip = False if load_vision else True
output_clipvision = True if load_vision else False
if config_name not in [None, "Default"]:
config_path = folder_paths.get_full_path("configs", config_name)
loaded_ckpt = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
else:
model_options = {}
if re.search("nf4", ckpt_name):
from ..modules.bitsandbytes_NF4 import OPS
model_options = {"custom_operations": OPS}
loaded_ckpt = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=output_clip, output_clipvision=output_clipvision, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options)
self.add_to_cache("ckpt", cache_name, loaded_ckpt[0])
self.add_to_cache("bvae", cache_name, loaded_ckpt[2])
clip = loaded_ckpt[1]
clip_vision = loaded_ckpt[3]
if clip:
self.add_to_cache("clip", cache_name, clip)
if clip_vision:
self.add_to_cache("clip_vision", cache_name, clip_vision)
self.eviction_based_on_memory()
return loaded_ckpt[0], clip, loaded_ckpt[2], clip_vision
def load_vae(self, vae_name):
if vae_name in self.loaded_objects["vae"]:
return self.loaded_objects["vae"][vae_name][0]
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
loaded_vae = comfy.sd.VAE(sd=sd)
self.add_to_cache("vae", vae_name, loaded_vae)
self.eviction_based_on_memory()
return loaded_vae
def load_unet(self, unet_name):
if unet_name in self.loaded_objects["unet"]:
log_node_info("Load UNet", f"{unet_name} cached")
return self.loaded_objects["unet"][unet_name][0]
unet_path = folder_paths.get_full_path("unet", unet_name)
model = comfy.sd.load_unet(unet_path)
self.add_to_cache("unet", unet_name, model)
self.eviction_based_on_memory()
return model
def load_controlnet(self, control_net_name, scale_soft_weights=1, use_cache=True):
unique_id = f'{control_net_name};{str(scale_soft_weights)}'
if use_cache and unique_id in self.loaded_objects["controlnet"]:
return self.loaded_objects["controlnet"][unique_id][0]
if scale_soft_weights < 1:
if "ScaledSoftControlNetWeights" in NODE_CLASS_MAPPINGS:
soft_weight_cls = NODE_CLASS_MAPPINGS['ScaledSoftControlNetWeights']
(weights, timestep_keyframe) = soft_weight_cls().load_weights(scale_soft_weights, False)
cn_adv_cls = NODE_CLASS_MAPPINGS['ControlNetLoaderAdvanced']
control_net, = cn_adv_cls().load_controlnet(control_net_name, timestep_keyframe)
else:
raise Exception(f"[Advanced-ControlNet Not Found] you need to install 'COMFYUI-Advanced-ControlNet'")
else:
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
control_net = comfy.controlnet.load_controlnet(controlnet_path)
if use_cache:
self.add_to_cache("controlnet", unique_id, control_net)
self.eviction_based_on_memory()
return control_net
def load_clip(self, clip_name, type='stable_diffusion', load_clip=None):
if clip_name in self.loaded_objects["clip"]:
return self.loaded_objects["clip"][clip_name][0]
if type == 'stable_diffusion':
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == 'stable_cascade':
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
elif type == 'sd3':
clip_type = comfy.sd.CLIPType.SD3
elif type == 'flux':
clip_type = comfy.sd.CLIPType.FLUX
elif type == 'stable_audio':
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_path = folder_paths.get_full_path("clip", clip_name)
load_clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
self.add_to_cache("clip", clip_name, load_clip)
self.eviction_based_on_memory()
return load_clip
def load_lora(self, lora, model=None, clip=None, type=None , use_cache=True):
lora_name = lora["lora_name"]
model = model if model is not None else lora["model"]
clip = clip if clip is not None else lora["clip"]
model_strength = lora["model_strength"]
clip_strength = lora["clip_strength"]
lbw = lora["lbw"] if "lbw" in lora else None
lbw_a = lora["lbw_a"] if "lbw_a" in lora else None
lbw_b = lora["lbw_b"] if "lbw_b" in lora else None
model_hash = str(model)[44:-1]
clip_hash = str(clip)[25:-1] if clip else ''
unique_id = f'{model_hash};{clip_hash};{lora_name};{model_strength};{clip_strength}'
if use_cache and unique_id in self.loaded_objects["lora"]:
log_node_info("Load LORA",f"{lora_name} cached")
return self.loaded_objects["lora"][unique_id][0]
orig_lora_name = lora_name
lora_name = self.resolve_lora_name(lora_name)
if lora_name is not None:
lora_path = folder_paths.get_full_path("loras", lora_name)
else:
lora_path = None
if lora_path is not None:
log_node_info("Load LORA",f"{lora_name}: model={model_strength:.3f}, clip={clip_strength:.3f}, LBW={lbw}, A={lbw_a}, B={lbw_b}")
if lbw:
lbw = lora["lbw"]
lbw_a = lora["lbw_a"]
lbw_b = lora["lbw_b"]
if 'LoraLoaderBlockWeight //Inspire' not in NODE_CLASS_MAPPINGS:
raise Exception('[InspirePack Not Found] you need to install ComfyUI-Inspire-Pack')
cls = NODE_CLASS_MAPPINGS['LoraLoaderBlockWeight //Inspire']
model, clip, _ = cls().doit(model, clip, lora_name, model_strength, clip_strength, False, 0,
lbw_a, lbw_b, "", lbw)
else:
_lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
keys = _lora.keys()
if "down_blocks.0.resnets.0.norm1.bias" in keys:
print('Using LORA for Resadapter')
key_map = {}
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
mapping_norm = {}
for key in keys:
if ".weight" in key:
key_name_in_ori_sd = key_map[key.replace(".weight", "")]
mapping_norm[key_name_in_ori_sd] = _lora[key]
elif ".bias" in key:
key_name_in_ori_sd = key_map[key.replace(".bias", "")]
mapping_norm[key_name_in_ori_sd.replace(".weight", ".bias")] = _lora[
key
]
else:
print("===>Unexpected key", key)
mapping_norm[key] = _lora[key]
for k in mapping_norm.keys():
if k not in model.model.state_dict():
print("===>Missing key:", k)
model.model.load_state_dict(mapping_norm, strict=False)
return (model, clip)
# PixArt
if type is not None and type == 'PixArt':
from ..modules.dit.pixArt.loader import load_pixart_lora
model = load_pixart_lora(model, _lora, lora_path, model_strength)
else:
model, clip = comfy.sd.load_lora_for_models(model, clip, _lora, model_strength, clip_strength)
if use_cache:
self.add_to_cache("lora", unique_id, (model, clip))
self.eviction_based_on_memory()
else:
log_node_error(f"LORA NOT FOUND", orig_lora_name)
return model, clip
def resolve_lora_name(self, name):
if os.path.exists(name):
return name
else:
if len(self.lora_name_cache) == 0:
loras = folder_paths.get_filename_list("loras")
self.lora_name_cache.extend(loras)
for x in self.lora_name_cache:
if x.endswith(name):
return x
            # If a lora was added after the page was refreshed, it is not in the cache yet and takes this path
log_node_info("LORA NOT IN CACHE", f"{name}")
loras = folder_paths.get_filename_list("loras")
for x in loras:
if x.endswith(name):
self.lora_name_cache.append(x)
return x
return None
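    # Illustrative behaviour (a sketch; exact results depend on the local loras folder):
    # resolve_lora_name("detail.safetensors") returns the first cached filename that ends with
    # that name, e.g. "sdxl/detail.safetensors", refreshing the cache once if nothing matched,
    # and returns None when no installed lora matches at all.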
def load_main(self, ckpt_name, config_name, vae_name, lora_name, lora_model_strength, lora_clip_strength, optional_lora_stack, model_override, clip_override, vae_override, prompt, nf4=False):
model: ModelPatcher | None = None
clip: comfy.sd.CLIP | None = None
vae: comfy.sd.VAE | None = None
clip_vision = None
lora_stack = []
# Check for model override
can_load_lora = True
        # Determine whether a model or lora xyplot node is present; if so, prioritize caching the first model.
xy_model_id = next((x for x in prompt if str(prompt[x]["class_type"]) in ["easy XYInputs: ModelMergeBlocks",
"easy XYInputs: Checkpoint"]), None)
        # This will find nodes that aren't actively connected to anything, and skip loading loras for them.
xy_lora_id = next((x for x in prompt if str(prompt[x]["class_type"]) == "easy XYInputs: Lora"), None)
if xy_lora_id is not None:
can_load_lora = False
if xy_model_id is not None:
node = prompt[xy_model_id]
if "ckpt_name_1" in node["inputs"]:
ckpt_name_1 = node["inputs"]["ckpt_name_1"]
model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name_1)
can_load_lora = False
elif model_override is not None and clip_override is not None and vae_override is not None:
model = model_override
clip = clip_override
vae = vae_override
else:
model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name, config_name)
if model_override is not None:
model = model_override
if vae_override is not None:
vae = vae_override
elif clip_override is not None:
clip = clip_override
if optional_lora_stack is not None and can_load_lora:
for lora in optional_lora_stack:
# This is a subtle bit of code because it uses the model created by the last call, and passes it to the next call.
lora = {"lora_name": lora[0], "model": model, "clip": clip, "model_strength": lora[1],
"clip_strength": lora[2]}
model, clip = self.load_lora(lora)
lora['model'] = model
lora['clip'] = clip
lora_stack.append(lora)
if lora_name != "None" and can_load_lora:
lora = {"lora_name": lora_name, "model": model, "clip": clip, "model_strength": lora_model_strength,
"clip_strength": lora_clip_strength}
model, clip = self.load_lora(lora)
lora_stack.append(lora)
# Check for custom VAE
if vae_name not in ["Baked VAE", "Baked-VAE"]:
vae = self.load_vae(vae_name)
# CLIP skip
if not clip:
raise Exception("No CLIP found")
return model, clip, vae, clip_vision, lora_stack
# Kolors
def load_kolors_unet(self, unet_name):
if unet_name in self.loaded_objects["unet"]:
log_node_info("Load Kolors UNet", f"{unet_name} cached")
return self.loaded_objects["unet"][unet_name][0]
else:
from ..modules.kolors.loader import applyKolorsUnet
with applyKolorsUnet():
unet_path = folder_paths.get_full_path("unet", unet_name)
sd = comfy.utils.load_torch_file(unet_path)
model = comfy.sd.load_unet_state_dict(sd)
if model is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
self.add_to_cache("unet", unet_name, model)
self.eviction_based_on_memory()
return model
def load_chatglm3(self, chatglm3_name):
from ..modules.kolors.loader import load_chatglm3
if chatglm3_name in self.loaded_objects["chatglm3"]:
log_node_info("Load ChatGLM3", f"{chatglm3_name} cached")
return self.loaded_objects["chatglm3"][chatglm3_name][0]
chatglm_model = load_chatglm3(model_path=folder_paths.get_full_path("llm", chatglm3_name))
self.add_to_cache("chatglm3", chatglm3_name, chatglm_model)
self.eviction_based_on_memory()
return chatglm_model
# DiT
def load_dit_ckpt(self, ckpt_name, model_name, **kwargs):
if (ckpt_name+'_'+model_name) in self.loaded_objects["ckpt"]:
return self.loaded_objects["ckpt"][ckpt_name+'_'+model_name][0]
model = None
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
model_type = kwargs['model_type'] if "model_type" in kwargs else 'PixArt'
if model_type == 'PixArt':
pixart_conf = kwargs['pixart_conf']
model_conf = pixart_conf[model_name]
model = load_pixart(ckpt_path, model_conf)
if model:
self.add_to_cache("ckpt", ckpt_name + '_' + model_name, model)
self.eviction_based_on_memory()
return model
def load_t5_from_sd3_clip(self, sd3_clip, padding):
try:
from comfy.text_encoders.sd3_clip import SD3Tokenizer, SD3ClipModel
except:
from comfy.sd3_clip import SD3Tokenizer, SD3ClipModel
import copy
clip = sd3_clip.clone()
assert clip.cond_stage_model.t5xxl is not None, "CLIP must have T5 loaded!"
# remove transformer
transformer = clip.cond_stage_model.t5xxl.transformer
clip.cond_stage_model.t5xxl.transformer = None
# clone object
tmp = SD3ClipModel(clip_l=False, clip_g=False, t5=False)
tmp.t5xxl = copy.deepcopy(clip.cond_stage_model.t5xxl)
# put transformer back
clip.cond_stage_model.t5xxl.transformer = transformer
tmp.t5xxl.transformer = transformer
# override special tokens
tmp.t5xxl.special_tokens = copy.deepcopy(clip.cond_stage_model.t5xxl.special_tokens)
tmp.t5xxl.special_tokens.pop("end") # make sure empty tokens match
# tokenizer
tok = SD3Tokenizer()
tok.t5xxl.min_length = padding
clip.cond_stage_model = tmp
clip.tokenizer = tok
return clip


@@ -0,0 +1,77 @@
COLORS_FG = {
'BLACK': '\33[30m',
'RED': '\33[31m',
'GREEN': '\33[32m',
'YELLOW': '\33[33m',
'BLUE': '\33[34m',
'MAGENTA': '\33[35m',
'CYAN': '\33[36m',
'WHITE': '\33[37m',
'GREY': '\33[90m',
'BRIGHT_RED': '\33[91m',
'BRIGHT_GREEN': '\33[92m',
'BRIGHT_YELLOW': '\33[93m',
'BRIGHT_BLUE': '\33[94m',
'BRIGHT_MAGENTA': '\33[95m',
'BRIGHT_CYAN': '\33[96m',
'BRIGHT_WHITE': '\33[97m',
}
COLORS_STYLE = {
'RESET': '\33[0m',
'BOLD': '\33[1m',
'NORMAL': '\33[22m',
'ITALIC': '\33[3m',
'UNDERLINE': '\33[4m',
'BLINK': '\33[5m',
'BLINK2': '\33[6m',
'SELECTED': '\33[7m',
}
COLORS_BG = {
'BLACK': '\33[40m',
'RED': '\33[41m',
'GREEN': '\33[42m',
'YELLOW': '\33[43m',
'BLUE': '\33[44m',
'MAGENTA': '\33[45m',
'CYAN': '\33[46m',
'WHITE': '\33[47m',
'GREY': '\33[100m',
'BRIGHT_RED': '\33[101m',
'BRIGHT_GREEN': '\33[102m',
'BRIGHT_YELLOW': '\33[103m',
'BRIGHT_BLUE': '\33[104m',
'BRIGHT_MAGENTA': '\33[105m',
'BRIGHT_CYAN': '\33[106m',
'BRIGHT_WHITE': '\33[107m',
}
def log_node_success(node_name, message=None):
"""Logs a success message."""
_log_node(COLORS_FG["GREEN"], node_name, message)
def log_node_info(node_name, message=None):
"""Logs an info message."""
_log_node(COLORS_FG["CYAN"], node_name, message)
def log_node_warn(node_name, message=None):
"""Logs an warn message."""
_log_node(COLORS_FG["YELLOW"], node_name, message)
def log_node_error(node_name, message=None):
"""Logs an warn message."""
_log_node(COLORS_FG["RED"], node_name, message)
def log_node(node_name, message=None):
"""Logs a message."""
_log_node(COLORS_FG["CYAN"], node_name, message)
def _log_node(color, node_name, message=None, prefix=''):
print(_get_log_msg(color, node_name, message, prefix=prefix))
def _get_log_msg(color, node_name, message=None, prefix=''):
msg = f'{COLORS_STYLE["BOLD"]}{color}{prefix}[EasyUse] {node_name.replace(" (EasyUse)", "")}'
msg += f':{COLORS_STYLE["RESET"]} {message}' if message is not None else f'{COLORS_STYLE["RESET"]}'
return msg
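# Illustrative usage (a minimal sketch; actual colours depend on the terminal):
#   log_node_info("easy fullLoader", "model cached")
#   -> prints "[EasyUse] easy fullLoader" in bold cyan, then ": model cached" after the style reset.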


@@ -0,0 +1,133 @@
"""
Math utility functions for formula evaluation
"""
import math
import re
def evaluate_formula(formula: str, a=0, b=0, c=0, d=0) -> float:
"""
计算字符串数学公式
支持的运算符和函数:
- 基本运算:+, -, *, /, //, %, **
- 比较运算:>, <, >=, <=, ==, !=
- 数学函数abs, pow, round, ceil, floor, sqrt, exp, log, log10
- 三角函数sin, cos, tan, asin, acos, atan
- 常量pi, e
Args:
formula: 数学公式字符串可以使用变量a、b、c、d
a: 变量a的值
b: 变量b的值
c: 变量c的值
d: 变量d的值
Returns:
计算结果
Examples:
>>> evaluate_formula("a + b", 1, 2)
3.0
>>> evaluate_formula("pow(a, 2)", 5)
25.0
>>> evaluate_formula("ceil(a / b)", 5, 2)
3.0
>>> evaluate_formula("(a>b)*b+(a<=b)*a", 5, 3)
3.0
>>> evaluate_formula("(a>b)*b+(a<=b)*a", 2, 3)
2.0
"""
    # Whitelist of safe math functions
safe_dict = {
        # Basic operations
'abs': abs,
'pow': pow,
'round': round,
        # Math functions
'ceil': math.ceil,
'floor': math.floor,
'sqrt': math.sqrt,
'exp': math.exp,
'log': math.log,
'log10': math.log10,
        # Trigonometric functions
'sin': math.sin,
'cos': math.cos,
'tan': math.tan,
'asin': math.asin,
'acos': math.acos,
'atan': math.atan,
        # Constants
'pi': math.pi,
'e': math.e,
        # Variables
'a': float(a),
'b': float(b),
'c': float(c),
'd': float(d),
}
try:
        # Evaluate the formula with eval, restricting the available functions and variables
result = eval(formula, {"__builtins__": {}}, safe_dict)
return float(result)
except Exception as e:
raise ValueError(f"公式计算错误: {str(e)}")
def ceil_value(value: float) -> int:
    """Round up to the nearest integer."""
    return math.ceil(value)
def floor_value(value: float) -> int:
    """Round down to the nearest integer."""
    return math.floor(value)
def round_value(value: float, decimals: int = 0) -> float:
    """
    Round to the given number of decimal places.
    Args:
        value: the value to round
        decimals: number of decimal places to keep
    Returns:
        The rounded value.
    """
    return round(value, decimals)
def power(base: float, exponent: float) -> float:
    """Compute base raised to the power of exponent."""
    return math.pow(base, exponent)
def sqrt_value(value: float) -> float:
    """Compute the square root."""
    if value < 0:
        raise ValueError("Cannot take the square root of a negative number")
    return math.sqrt(value)
def add(a: float, b: float) -> float:
    """Addition."""
    return a + b
def subtract(a: float, b: float) -> float:
    """Subtraction."""
    return a - b
def multiply(a: float, b: float) -> float:
    """Multiplication."""
    return a * b
def divide(a: float, b: float) -> float:
    """Division."""
    if b == 0:
        raise ValueError("Division by zero is not allowed")
    return a / b


@@ -0,0 +1,55 @@
from server import PromptServer
from aiohttp import web
import time
import json
class MessageCancelled(Exception):
pass
class Message:
stash = {}
messages = {}
cancelled = False
@classmethod
def addMessage(cls, id, message):
if message == '__cancel__':
cls.messages = {}
cls.cancelled = True
elif message == '__start__':
cls.messages = {}
cls.stash = {}
cls.cancelled = False
else:
cls.messages[str(id)] = message
@classmethod
def waitForMessage(cls, id, period=0.1, asList=False):
sid = str(id)
while not (sid in cls.messages) and not ("-1" in cls.messages):
if cls.cancelled:
cls.cancelled = False
raise MessageCancelled()
time.sleep(period)
if cls.cancelled:
cls.cancelled = False
raise MessageCancelled()
message = cls.messages.pop(str(id), None) or cls.messages.pop("-1")
try:
if asList:
return [str(x.strip()) for x in message.split(",")]
else:
try:
return json.loads(message)
except ValueError:
return message
except ValueError:
print( f"ERROR IN MESSAGE - failed to parse '${message}' as ${'comma separated list of strings' if asList else 'string'}")
return [message] if asList else message
@PromptServer.instance.routes.post('/easyuse/message_callback')
async def message_callback(request):
post = await request.post()
Message.addMessage(post.get("id"), post.get("message"))
return web.json_response({})
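# Round-trip sketch (identifiers are illustrative only):
#   node side:  choice = Message.waitForMessage(unique_id, asList=True)
#   front end:  POST /easyuse/message_callback with form fields id=<unique_id>, message="1,3"
# waitForMessage polls every `period` seconds until the id (or the wildcard id "-1") appears in
# Message.messages, raises MessageCancelled if "__cancel__" was received, and otherwise returns
# the payload parsed as JSON, as a comma-separated list, or as a plain string.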


@@ -0,0 +1,58 @@
import json
import os
import folder_paths
import server
from .utils import find_tags
class easyModelManager:
def __init__(self):
self.img_suffixes = [".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".tiff", ".svg", ".tif", ".tiff"]
self.default_suffixes = [".ckpt", ".pt", ".bin", ".pth", ".safetensors"]
self.models_config = {
"checkpoints": {"suffix": self.default_suffixes},
"loras": {"suffix": self.default_suffixes},
"unet": {"suffix": self.default_suffixes},
}
self.model_lists = {}
def find_thumbnail(self, model_type, name):
file_no_ext = os.path.splitext(name)[0]
for ext in self.img_suffixes:
full_path = folder_paths.get_full_path(model_type, file_no_ext + ext)
if os.path.isfile(str(full_path)):
return full_path
return None
def get_model_lists(self, model_type):
if model_type not in self.models_config:
return []
filenames = folder_paths.get_filename_list(model_type)
model_lists = []
for name in filenames:
model_suffix = os.path.splitext(name)[-1]
if model_suffix not in self.models_config[model_type]["suffix"]:
continue
else:
cfg = {
"name": os.path.basename(os.path.splitext(name)[0]),
"full_name": name,
"remark": '',
"file_path": folder_paths.get_full_path(model_type, name),
"type": model_type,
"suffix": model_suffix,
"dir_tags": find_tags(name),
"cover": self.find_thumbnail(model_type, name),
"metadata": None,
"sha256": None
}
model_lists.append(cfg)
return model_lists
def get_model_info(self, model_type, model_name):
pass
# if __name__ == "__main__":
# manager = easyModelManager()
# print(manager.get_model_lists("checkpoints"))

File diff suppressed because it is too large


@@ -0,0 +1,148 @@
import torch
import torch.nn as nn
from comfy.model_patcher import ModelPatcher
from typing import Union
T = torch.Tensor
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d
class StyleAlignedArgs:
def __init__(self, share_attn: str) -> None:
self.adain_keys = "k" in share_attn
self.adain_values = "v" in share_attn
self.adain_queries = "q" in share_attn
share_attention: bool = True
adain_queries: bool = True
adain_keys: bool = True
adain_values: bool = True
def expand_first(
feat: T,
scale=1.0,
) -> T:
"""
Expand the first element so it has the same shape as the rest of the batch.
"""
b = feat.shape[0]
feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
if scale == 1:
feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
else:
feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
return feat_style.reshape(*feat.shape)
def concat_first(feat: T, dim=2, scale=1.0) -> T:
"""
    Concatenate the feature with the style feature expanded above.
"""
feat_style = expand_first(feat, scale=scale)
return torch.cat((feat, feat_style), dim=dim)
def calc_mean_std(feat, eps: float = 1e-5) -> "tuple[T, T]":
feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
feat_mean = feat.mean(dim=-2, keepdims=True)
return feat_mean, feat_std
def adain(feat: T) -> T:
feat_mean, feat_std = calc_mean_std(feat)
feat_style_mean = expand_first(feat_mean)
feat_style_std = expand_first(feat_std)
feat = (feat - feat_mean) / feat_std
feat = feat * feat_style_std + feat_style_mean
return feat
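# In effect, every token's content features are standardised and then re-scaled with the statistics
# of the style (first) element: x_out = ((x - mean(x)) / std(x)) * std(style) + mean(style),
# computed along dim=-2 — the classic AdaIN normalisation used for style transfer.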
class SharedAttentionProcessor:
def __init__(self, args: StyleAlignedArgs, scale: float):
self.args = args
self.scale = scale
def __call__(self, q, k, v, extra_options):
if self.args.adain_queries:
q = adain(q)
if self.args.adain_keys:
k = adain(k)
if self.args.adain_values:
v = adain(v)
if self.args.share_attention:
k = concat_first(k, -2, scale=self.scale)
v = concat_first(v, -2)
return q, k, v
def get_norm_layers(
layer: nn.Module,
norm_layers_: "dict[str, list[Union[nn.GroupNorm, nn.LayerNorm]]]",
share_layer_norm: bool,
share_group_norm: bool,
):
if isinstance(layer, nn.LayerNorm) and share_layer_norm:
norm_layers_["layer"].append(layer)
if isinstance(layer, nn.GroupNorm) and share_group_norm:
norm_layers_["group"].append(layer)
else:
for child_layer in layer.children():
get_norm_layers(
child_layer, norm_layers_, share_layer_norm, share_group_norm
)
def register_norm_forward(
norm_layer: Union[nn.GroupNorm, nn.LayerNorm],
) -> Union[nn.GroupNorm, nn.LayerNorm]:
if not hasattr(norm_layer, "orig_forward"):
setattr(norm_layer, "orig_forward", norm_layer.forward)
orig_forward = norm_layer.orig_forward
def forward_(hidden_states: T) -> T:
n = hidden_states.shape[-2]
hidden_states = concat_first(hidden_states, dim=-2)
hidden_states = orig_forward(hidden_states) # type: ignore
return hidden_states[..., :n, :]
norm_layer.forward = forward_ # type: ignore
return norm_layer
def register_shared_norm(
model: ModelPatcher,
share_group_norm: bool = True,
share_layer_norm: bool = True,
):
norm_layers = {"group": [], "layer": []}
get_norm_layers(model.model, norm_layers, share_layer_norm, share_group_norm)
print(
f"Patching {len(norm_layers['group'])} group norms, {len(norm_layers['layer'])} layer norms."
)
return [register_norm_forward(layer) for layer in norm_layers["group"]] + [
register_norm_forward(layer) for layer in norm_layers["layer"]
]
SHARE_NORM_OPTIONS = ["both", "group", "layer", "disabled"]
SHARE_ATTN_OPTIONS = ["q+k", "q+k+v", "disabled"]
def styleAlignBatch(model, share_norm, share_attn, scale=1.0):
m = model.clone()
share_group_norm = share_norm in ["group", "both"]
share_layer_norm = share_norm in ["layer", "both"]
register_shared_norm(model, share_group_norm, share_layer_norm)
args = StyleAlignedArgs(share_attn)
m.set_model_attn1_patch(SharedAttentionProcessor(args, scale))
return m


@@ -0,0 +1,247 @@
#credit to shadowcz007 for this module
#from https://github.com/shadowcz007/comfyui-mixlab-nodes/blob/main/nodes/TextGenerateNode.py
import re
import os
import folder_paths
import comfy.utils
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from .utils import install_package
try:
from lark import Lark, Transformer, v_args
except:
print('install lark...')
install_package('lark')
from lark import Lark, Transformer, v_args
model_path = os.path.join(folder_paths.models_dir, 'prompt_generator')
zh_en_model_path = os.path.join(model_path, 'opus-mt-zh-en')
zh_en_model, zh_en_tokenizer = None, None
def correct_prompt_syntax(prompt=""):
# print("input prompt",prompt)
corrected_elements = []
    # Normalize Chinese full-width punctuation to ASCII equivalents
    prompt = prompt.replace('(', '(').replace(')', ')').replace(',', ',').replace(';', ',').replace('。', '.').replace(':', ':').replace('\\', ',')
    # Collapse redundant whitespace
    prompt = re.sub(r'\s+', ' ', prompt).strip()
    prompt = prompt.replace("< ","<").replace(" >",">").replace("( ","(").replace(" )",")").replace("[ ","[").replace(' ]',']')
    # Split into comma-separated elements
prompt_elements = prompt.split(',')
def balance_brackets(element, open_bracket, close_bracket):
open_brackets_count = element.count(open_bracket)
close_brackets_count = element.count(close_bracket)
return element + close_bracket * (open_brackets_count - close_brackets_count)
for element in prompt_elements:
element = element.strip()
        # Skip empty elements
if not element:
continue
        # Check and balance parentheses, square brackets and angle brackets
if element[0] in '([':
corrected_element = balance_brackets(element, '(', ')') if element[0] == '(' else balance_brackets(element, '[', ']')
elif element[0] == '<':
corrected_element = balance_brackets(element, '<', '>')
else:
            # Strip any leading closing parentheses or square brackets
corrected_element = element.lstrip(')]')
corrected_elements.append(corrected_element)
    # Reassemble the corrected prompt
return ','.join(corrected_elements)
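# Illustrative examples:
#   correct_prompt_syntax("masterpiece;best quality") -> "masterpiece,best quality"
#   correct_prompt_syntax("(blue eyes, 1girl")        -> "(blue eyes),1girl"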
def detect_language(input_str):
    # Count Chinese and English characters
count_cn = count_en = 0
for char in input_str:
if '\u4e00' <= char <= '\u9fff':
count_cn += 1
elif char.isalpha():
count_en += 1
    # Decide the dominant language from the character counts
if count_cn > count_en:
return "cn"
elif count_en > count_cn:
return "en"
else:
return "unknow"
def has_chinese(text):
has_cn = False
_text = text
_text = re.sub(r'<.*?>', '', _text)
_text = re.sub(r'__.*?__', '', _text)
_text = re.sub(r'embedding:.*?$', '', _text)
for char in _text:
if '\u4e00' <= char <= '\u9fff':
has_cn = True
break
elif char.isalpha():
continue
return has_cn
def translate(text):
global zh_en_model_path, zh_en_model, zh_en_tokenizer
if not os.path.exists(zh_en_model_path):
zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'
if zh_en_model is None:
zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)
zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
encoded = zh_en_tokenizer([text], return_tensors="pt")
encoded.to(zh_en_model.device)
sequences = zh_en_model.generate(**encoded)
return zh_en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
@v_args(inline=True) # Decorator to flatten the tree directly into the function arguments
class ChinesePromptTranslate(Transformer):
def sentence(self, *args):
return ", ".join(args)
def phrase(self, *args):
return "".join(args)
def emphasis(self, *args):
# Reconstruct the emphasis with translated content
return "(" + "".join(args) + ")"
def weak_emphasis(self, *args):
print('weak_emphasis:', args)
return "[" + "".join(args) + "]"
def embedding(self, *args):
print('prompt embedding', args[0])
if len(args) == 1:
embedding_name = str(args[0])
return f"embedding:{embedding_name}"
elif len(args) > 1:
embedding_name, *numbers = args
if len(numbers) == 2:
return f"embedding:{embedding_name}:{numbers[0]}:{numbers[1]}"
elif len(numbers) == 1:
return f"embedding:{embedding_name}:{numbers[0]}"
else:
return f"embedding:{embedding_name}"
def lora(self, *args):
if len(args) == 1:
return f"<lora:{args[0]}>"
elif len(args) > 1:
# print('lora', args)
            _, lora_name, *numbers = args
            lora_name = str(lora_name).strip()
            if len(numbers) == 2:
                return f"<lora:{lora_name}:{numbers[0]}:{numbers[1]}>"
            elif len(numbers) == 1:
                return f"<lora:{lora_name}:{numbers[0]}>"
            else:
                return f"<lora:{lora_name}>"
def weight(self, word, number):
translated_word = translate(str(word)).rstrip('.')
return f"({translated_word}:{str(number).strip()})"
def schedule(self, *args):
print('prompt schedule', args)
data = [str(arg).strip() for arg in args]
return f"[{':'.join(data)}]"
def word(self, word):
# Translate each word using the dictionary
word = str(word)
match_cn = re.search(r'@.*?@', word)
if re.search(r'__.*?__', word):
return word.rstrip('.')
elif match_cn:
chinese = match_cn.group()
before = word.split('@', 1)
before = before[0] if len(before) > 0 else ''
before = translate(str(before)).rstrip('.') if before else ''
after = word.rsplit('@', 1)
after = after[len(after)-1] if len(after) > 1 else ''
after = translate(after).rstrip('.') if after else ''
return before + chinese.replace('@', '').rstrip('.') + after
elif detect_language(word) == "cn":
return translate(word).rstrip('.')
else:
return word.rstrip('.')
# Define the prompt grammar
grammar = r"""
start: sentence
sentence: phrase ("," phrase)*
phrase: emphasis | weight | word | lora | embedding | schedule
emphasis: "(" sentence ")" -> emphasis
| "[" sentence "]" -> weak_emphasis
weight: "(" word ":" NUMBER ")"
schedule: "[" word ":" word ":" NUMBER "]"
lora: "<" WORD ":" WORD (":" NUMBER)? (":" NUMBER)? ">"
embedding: "embedding" ":" WORD (":" NUMBER)? (":" NUMBER)?
word: WORD
NUMBER: /\s*-?\d+(\.\d+)?\s*/
WORD: /[^,:\(\)\[\]<>]+/
"""
def zh_to_en(text):
global zh_en_model_path, zh_en_model, zh_en_tokenizer
    # Progress bar
pbar = comfy.utils.ProgressBar(len(text) + 1)
texts = [correct_prompt_syntax(t) for t in text]
install_package('sentencepiece', '0.2.0')
if not os.path.exists(zh_en_model_path):
zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'
if zh_en_model is None:
zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)
zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")
prompt_result = []
en_texts = []
for t in texts:
if t:
# translated_text = translated_word = translate(zh_en_tokenizer,zh_en_model,str(t))
parser = Lark(grammar, start="start", parser="lalr", transformer=ChinesePromptTranslate())
# print('t',t)
result = parser.parse(t).children
# print('en_result',result)
# en_text=translate(zh_en_tokenizer,zh_en_model,text_without_syntax)
en_texts.append(result[0])
zh_en_model.to('cpu')
# print("test en_text", en_texts)
# en_text.to("cuda" if torch.cuda.is_available() else "cpu")
pbar.update(1)
for t in en_texts:
prompt_result.append(t)
pbar.update(1)
# print('prompt_result', prompt_result, )
if len(prompt_result) == 0:
prompt_result = [""]
return prompt_result


@@ -0,0 +1,282 @@
class AlwaysEqualProxy(str):
def __eq__(self, _):
return True
def __ne__(self, _):
return False
class TautologyStr(str):
def __ne__(self, other):
return False
class ByPassTypeTuple(tuple):
def __getitem__(self, index):
if index>0:
index=0
item = super().__getitem__(index)
if isinstance(item, str):
return TautologyStr(item)
return item
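# These helpers let a node advertise a "match anything" socket type: AlwaysEqualProxy("*") compares
# equal to every type string it is checked against, and ByPassTypeTuple keeps returning its first
# (wildcard) entry no matter which output index is requested.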
comfy_ui_revision = None
def get_comfyui_revision():
try:
import git
import os
import folder_paths
repo = git.Repo(os.path.dirname(folder_paths.__file__))
comfy_ui_revision = len(list(repo.iter_commits('HEAD')))
except:
comfy_ui_revision = "Unknown"
return comfy_ui_revision
import sys
import importlib.util
import importlib.metadata
import comfy.model_management as mm
import gc
from packaging import version
from server import PromptServer
def is_package_installed(package):
try:
module = importlib.util.find_spec(package)
return module is not None
except ImportError as e:
print(e)
return False
def install_package(package, v=None, compare=True, compare_version=None):
run_install = True
if is_package_installed(package):
try:
installed_version = importlib.metadata.version(package)
if v is not None:
if compare_version is None:
compare_version = v
if not compare or version.parse(installed_version) >= version.parse(compare_version):
run_install = False
else:
run_install = False
except:
run_install = False
if run_install:
import subprocess
package_command = package + '==' + v if v is not None else package
PromptServer.instance.send_sync("easyuse-toast", {'content': f"Installing {package_command}...", 'duration': 5000})
result = subprocess.run([sys.executable, '-s', '-m', 'pip', 'install', package_command], capture_output=True, text=True)
if result.returncode == 0:
PromptServer.instance.send_sync("easyuse-toast", {'content': f"{package} installed successfully", 'type': 'success', 'duration': 5000})
print(f"Package {package} installed successfully")
return True
else:
            PromptServer.instance.send_sync("easyuse-toast", {'content': f"{package} failed to install", 'type': 'error', 'duration': 5000})
            print(f"Package {package} failed to install")
return False
else:
return False
def compare_revision(num):
global comfy_ui_revision
if not comfy_ui_revision:
comfy_ui_revision = get_comfyui_revision()
return True if comfy_ui_revision == 'Unknown' or int(comfy_ui_revision) >= num else False
def find_tags(string: str, sep="/") -> list[str]:
"""
find tags from string use the sep for split
Note: string may contain the \\ or / for path separator
"""
if not string:
return []
string = string.replace("\\", "/")
while "//" in string:
string = string.replace("//", "/")
if string and sep in string:
return string.split(sep)[:-1]
return []
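# e.g. find_tags("sdxl/characters/model.safetensors") -> ["sdxl", "characters"]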
from comfy.model_base import BaseModel
import comfy.supported_models
import comfy.supported_models_base
def get_sd_version(model):
base: BaseModel = model.model
model_config: comfy.supported_models.supported_models_base.BASE = base.model_config
if isinstance(model_config, comfy.supported_models.SDXL):
return 'sdxl'
elif isinstance(model_config, comfy.supported_models.SDXLRefiner):
return 'sdxl_refiner'
elif isinstance(
model_config, (comfy.supported_models.SD15, comfy.supported_models.SD20)
):
return 'sd1'
elif isinstance(
model_config, (comfy.supported_models.SVD_img2vid)
):
return 'svd'
elif isinstance(model_config, comfy.supported_models.SD3):
return 'sd3'
elif isinstance(model_config, comfy.supported_models.HunyuanDiT):
return 'hydit'
elif isinstance(model_config, comfy.supported_models.Flux):
return 'flux'
elif isinstance(model_config, comfy.supported_models.GenmoMochi):
return 'mochi'
else:
return 'unknown'
def find_nearest_steps(clip_id, prompt):
"""Find the nearest KSampler or preSampling node that references the given id."""
def check_link_to_clip(node_id, clip_id, visited=None, node=None):
"""Check if a given node links directly or indirectly to a loader node."""
if visited is None:
visited = set()
if node_id in visited:
return False
visited.add(node_id)
if "pipe" in node["inputs"]:
link_ids = node["inputs"]["pipe"]
for id in link_ids:
if id != 0 and id == str(clip_id):
return True
return False
for id in prompt:
node = prompt[id]
if "Sampler" in node["class_type"] or "sampler" in node["class_type"] or "Sampling" in node["class_type"]:
# Check if this KSampler node directly or indirectly references the given CLIPTextEncode node
if check_link_to_clip(id, clip_id, None, node):
steps = node["inputs"]["steps"] if "steps" in node["inputs"] else 1
return steps
return 1
def find_wildcards_seed(clip_id, text, prompt):
""" Find easy wildcards seed value"""
def find_link_clip_id(id, seed, wildcard_id):
node = prompt[id]
if "positive" in node['inputs']:
link_ids = node["inputs"]["positive"]
if type(link_ids) == list:
for id in link_ids:
if id != 0:
if id == wildcard_id:
wildcard_node = prompt[wildcard_id]
seed = wildcard_node["inputs"]["seed"] if "seed" in wildcard_node["inputs"] else None
if seed is None:
seed = wildcard_node["inputs"]["seed_num"] if "seed_num" in wildcard_node["inputs"] else None
return seed
else:
return find_link_clip_id(id, seed, wildcard_id)
else:
return None
else:
return None
if "__" in text:
seed = None
for id in prompt:
node = prompt[id]
if "wildcards" in node["class_type"]:
wildcard_id = id
return find_link_clip_id(str(clip_id), seed, wildcard_id)
return seed
else:
return None
def is_linked_styles_selector(prompt, unique_id, prompt_type='positive'):
unique_id = unique_id.split('.')[len(unique_id.split('.')) - 1] if "." in unique_id else unique_id
inputs_values = prompt[unique_id]['inputs'][prompt_type] if prompt_type in prompt[unique_id][
'inputs'] else None
if type(inputs_values) == list and inputs_values != 'undefined' and inputs_values[0]:
return True if prompt[inputs_values[0]] and prompt[inputs_values[0]]['class_type'] == 'easy stylesSelector' else False
else:
return False
use_mirror = False
def get_local_filepath(url, dirname, local_file_name=None):
"""Get local file path when is already downloaded or download it"""
import os
from server import PromptServer
from urllib.parse import urlparse
from torch.hub import download_url_to_file
global use_mirror
if not os.path.exists(dirname):
os.makedirs(dirname)
if not local_file_name:
parsed_url = urlparse(url)
local_file_name = os.path.basename(parsed_url.path)
destination = os.path.join(dirname, local_file_name)
if not os.path.exists(destination):
try:
if use_mirror:
url = url.replace('huggingface.co', 'hf-mirror.com')
print(f'downloading {url} to {destination}')
PromptServer.instance.send_sync("easyuse-toast", {'content': f'Downloading model to {destination}, please wait...', 'duration': 10000})
download_url_to_file(url, destination)
except Exception as e:
use_mirror = True
url = url.replace('huggingface.co', 'hf-mirror.com')
print(f'Unable to download from huggingface, trying mirror: {url}')
PromptServer.instance.send_sync("easyuse-toast", {'content': f'Unable to connect to huggingface, trying mirror: {url}', 'duration': 10000})
try:
download_url_to_file(url, destination)
except Exception as err:
error_msg = str(err.args[0]) if err.args else str(err)
PromptServer.instance.send_sync("easyuse-toast",
{'content': f'Unable to download model from {url}', 'type':'error'})
raise Exception(f'Download failed. Original URL and mirror both failed.\nError: {error_msg}')
return destination
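# Illustrative call (URL and target folder are examples only):
#   path = get_local_filepath(
#       "https://huggingface.co/some-org/some-repo/resolve/main/model.safetensors",
#       os.path.join(folder_paths.models_dir, "checkpoints"))
# The file is downloaded once and reused on later calls; if huggingface.co is unreachable, the same
# path is retried against hf-mirror.com before giving up.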
def to_lora_patch_dict(state_dict: dict) -> dict:
""" Convert raw lora state_dict to patch_dict that can be applied on
modelpatcher."""
patch_dict = {}
for k, w in state_dict.items():
model_key, patch_type, weight_index = k.split('::')
if model_key not in patch_dict:
patch_dict[model_key] = {}
if patch_type not in patch_dict[model_key]:
patch_dict[model_key][patch_type] = [None] * 16
patch_dict[model_key][patch_type][int(weight_index)] = w
patch_flat = {}
for model_key, v in patch_dict.items():
for patch_type, weight_list in v.items():
patch_flat[model_key] = (patch_type, weight_list)
return patch_flat
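# Example with a hypothetical key: an entry "diffusion_model.input_blocks.0::diff::3" becomes
# patch_flat["diffusion_model.input_blocks.0"] = ("diff", [None, None, None, w, None, ...]),
# i.e. keys are split on "::" into model key, patch type and an index into a 16-slot weight list.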
def easySave(images, filename_prefix, output_type, prompt=None, extra_pnginfo=None):
"""Save or Preview Image"""
from nodes import PreviewImage, SaveImage
if output_type in ["Hide", "None"]:
return list()
elif output_type in ["Preview", "Preview&Choose"]:
filename_prefix = 'easyPreview'
results = PreviewImage().save_images(images, filename_prefix, prompt, extra_pnginfo)
return results['ui']['images']
else:
results = SaveImage().save_images(images, filename_prefix, prompt, extra_pnginfo)
return results['ui']['images']
def getMetadata(filepath):
with open(filepath, "rb") as file:
# https://github.com/huggingface/safetensors#format
# 8 bytes: N, an unsigned little-endian 64-bit integer, containing the size of the header
header_size = int.from_bytes(file.read(8), "little", signed=False)
if header_size <= 0:
raise BufferError("Invalid header size")
header = file.read(header_size)
        if len(header) < header_size:
            raise BufferError("Invalid header")
return header
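# The returned bytes are the raw JSON header of a safetensors file; callers can decode it with
# json.loads(header) to inspect "__metadata__" and the stored tensor shapes and dtypes.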
def cleanGPUUsedForce():
gc.collect()
mm.unload_all_models()
mm.soft_empty_cache()


@@ -0,0 +1,476 @@
import json
import os
import random
import re
from math import prod
import yaml
import folder_paths
from .log import log_node_info
easy_wildcard_dict = {}
def get_wildcard_list():
return [f"__{x}__" for x in easy_wildcard_dict.keys()]
def wildcard_normalize(x):
return x.replace("\\", "/").lower()
def read_wildcard(k, v):
if isinstance(v, list):
k = wildcard_normalize(k)
easy_wildcard_dict[k] = v
elif isinstance(v, dict):
for k2, v2 in v.items():
new_key = f"{k}/{k2}"
new_key = wildcard_normalize(new_key)
read_wildcard(new_key, v2)
def read_wildcard_dict(wildcard_path):
global easy_wildcard_dict
for root, directories, files in os.walk(wildcard_path, followlinks=True):
for file in files:
if file.endswith('.txt'):
file_path = os.path.join(root, file)
rel_path = os.path.relpath(file_path, wildcard_path)
key = os.path.splitext(rel_path)[0].replace('\\', '/').lower()
try:
with open(file_path, 'r', encoding="UTF-8", errors="ignore") as f:
lines = f.read().splitlines()
easy_wildcard_dict[key] = lines
except UnicodeDecodeError:
with open(file_path, 'r', encoding="ISO-8859-1") as f:
lines = f.read().splitlines()
easy_wildcard_dict[key] = lines
elif file.endswith('.yaml'):
file_path = os.path.join(root, file)
with open(file_path, 'r') as f:
yaml_data = yaml.load(f, Loader=yaml.FullLoader)
for k, v in yaml_data.items():
read_wildcard(k, v)
elif file.endswith('.json'):
file_path = os.path.join(root, file)
try:
with open(file_path, 'r') as f:
json_data = json.load(f)
for key, value in json_data.items():
key = wildcard_normalize(key)
easy_wildcard_dict[key] = value
except ValueError:
                    print(f'Failed to load wildcard json file: {file_path}')
return easy_wildcard_dict
def process(text, seed=None):
if seed is not None:
random.seed(seed)
def replace_options(string):
replacements_found = False
def replace_option(match):
nonlocal replacements_found
options = match.group(1).split('|')
multi_select_pattern = options[0].split('$$')
select_range = None
select_sep = ' '
range_pattern = r'(\d+)(-(\d+))?'
range_pattern2 = r'-(\d+)'
if len(multi_select_pattern) > 1:
r = re.match(range_pattern, options[0])
                if r is None:
                    r = re.match(range_pattern2, options[0])
                    a = '1'
                    # guard against the '-N' pattern not matching at all
                    b = r.group(1).strip() if r is not None else None
                else:
                    a = r.group(1).strip()
                    # group(3) is missing for patterns like '2$$', so guard before stripping
                    b = r.group(3).strip() if r.group(3) is not None else None
if r is not None:
if b is not None and is_numeric_string(a) and is_numeric_string(b):
# PATTERN: num1-num2
select_range = int(a), int(b)
elif is_numeric_string(a):
# PATTERN: num
x = int(a)
select_range = (x, x)
if select_range is not None and len(multi_select_pattern) == 2:
# PATTERN: count$$
options[0] = multi_select_pattern[1]
elif select_range is not None and len(multi_select_pattern) == 3:
# PATTERN: count$$ sep $$
select_sep = multi_select_pattern[1]
options[0] = multi_select_pattern[2]
adjusted_probabilities = []
total_prob = 0
for option in options:
parts = option.split('::', 1)
if len(parts) == 2 and is_numeric_string(parts[0].strip()):
config_value = float(parts[0].strip())
else:
config_value = 1 # Default value if no configuration is provided
adjusted_probabilities.append(config_value)
total_prob += config_value
normalized_probabilities = [prob / total_prob for prob in adjusted_probabilities]
if select_range is None:
select_count = 1
else:
select_count = random.randint(select_range[0], select_range[1])
if select_count > len(options):
selected_items = options
else:
selected_items = random.choices(options, weights=normalized_probabilities, k=select_count)
selected_items = set(selected_items)
try_count = 0
while len(selected_items) < select_count and try_count < 10:
remaining_count = select_count - len(selected_items)
additional_items = random.choices(options, weights=normalized_probabilities, k=remaining_count)
selected_items |= set(additional_items)
try_count += 1
selected_items2 = [re.sub(r'^\s*[0-9.]+::', '', x, 1) for x in selected_items]
replacement = select_sep.join(selected_items2)
if '::' in replacement:
pass
replacements_found = True
return replacement
pattern = r'{([^{}]*?)}'
replaced_string = re.sub(pattern, replace_option, string)
return replaced_string, replacements_found
def replace_wildcard(string):
global easy_wildcard_dict
pattern = r"__([\w\s.\-+/*\\]+?)__"
matches = re.findall(pattern, string)
replacements_found = False
for match in matches:
keyword = match.lower()
keyword = wildcard_normalize(keyword)
if keyword in easy_wildcard_dict:
replacement = random.choice(easy_wildcard_dict[keyword])
replacements_found = True
string = string.replace(f"__{match}__", replacement, 1)
elif '*' in keyword:
subpattern = keyword.replace('*', '.*').replace('+', r'\+')
total_patterns = []
found = False
for k, v in easy_wildcard_dict.items():
if re.match(subpattern, k) is not None:
total_patterns += v
found = True
if found:
replacement = random.choice(total_patterns)
replacements_found = True
string = string.replace(f"__{match}__", replacement, 1)
elif '/' not in keyword:
string_fallback = string.replace(f"__{match}__", f"__*/{match}__", 1)
string, replacements_found = replace_wildcard(string_fallback)
return string, replacements_found
replace_depth = 100
stop_unwrap = False
while not stop_unwrap and replace_depth > 1:
replace_depth -= 1 # prevent infinite loop
# pass1: replace options
pass1, is_replaced1 = replace_options(text)
while is_replaced1:
pass1, is_replaced1 = replace_options(pass1)
# pass2: replace wildcards
text, is_replaced2 = replace_wildcard(pass1)
stop_unwrap = not is_replaced1 and not is_replaced2
return text
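# Illustrative behaviour (results vary with the seed and the loaded wildcard files):
#   process("{red|green|blue} hair", seed=1)   -> e.g. "green hair"
#   process("{2$$ and $$red|green|blue} hair") -> e.g. "red and blue hair"
#   process("__colors__ dress")                -> a random line from the "colors" wildcard,
#                                                 assuming wildcards/colors.txt has been loaded.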
def is_numeric_string(input_str):
return re.match(r'^-?\d+(\.\d+)?$', input_str) is not None
def safe_float(x):
if is_numeric_string(x):
return float(x)
else:
return 1.0
def extract_lora_values(string):
pattern = r'<lora:([^>]+)>'
matches = re.findall(pattern, string)
def touch_lbw(text):
return re.sub(r'LBW=[A-Za-z][A-Za-z0-9_-]*:', r'LBW=', text)
items = [touch_lbw(match.strip(':')) for match in matches]
added = set()
result = []
for item in items:
item = item.split(':')
lora = None
a = None
b = None
lbw = None
lbw_a = None
lbw_b = None
if len(item) > 0:
lora = item[0]
for sub_item in item[1:]:
if is_numeric_string(sub_item):
if a is None:
a = float(sub_item)
elif b is None:
b = float(sub_item)
elif sub_item.startswith("LBW="):
for lbw_item in sub_item[4:].split(';'):
if lbw_item.startswith("A="):
lbw_a = safe_float(lbw_item[2:].strip())
elif lbw_item.startswith("B="):
lbw_b = safe_float(lbw_item[2:].strip())
elif lbw_item.strip() != '':
lbw = lbw_item
if a is None:
a = 1.0
if b is None:
b = 1.0
if lora is not None and lora not in added:
result.append((lora, a, b, lbw, lbw_a, lbw_b))
added.add(lora)
return result
def remove_lora_tags(string):
pattern = r'<lora:[^>]+>'
result = re.sub(pattern, '', string)
return result
def process_with_loras(wildcard_opt, model, clip, title="Positive", seed=None, can_load_lora=True, pipe_lora_stack=[], easyCache=None):
pass1 = process(wildcard_opt, seed)
loras = extract_lora_values(pass1)
pass2 = remove_lora_tags(pass1)
has_noodle_key = True if "__" in wildcard_opt else False
has_loras = True if loras != [] else False
show_wildcard_prompt = True if has_noodle_key or has_loras else False
if can_load_lora and has_loras:
for lora_name, model_weight, clip_weight, lbw, lbw_a, lbw_b in loras:
if (lora_name.split('.')[-1]) not in folder_paths.supported_pt_extensions:
lora_name = lora_name+".safetensors"
lora = {
"lora_name": lora_name, "model": model, "clip": clip, "model_strength": model_weight,
"clip_strength": clip_weight,
"lbw_a": lbw_a,
"lbw_b": lbw_b,
"lbw": lbw
}
model, clip = easyCache.load_lora(lora)
lora["model"] = model
lora["clip"] = clip
pipe_lora_stack.append(lora)
log_node_info("easy wildcards",f"{title}: {pass2}")
if pass1 != pass2:
log_node_info("easy wildcards",f'{title}_decode: {pass1}')
return model, clip, pass2, pass1, show_wildcard_prompt, pipe_lora_stack
def expand_wildcard(keyword: str) -> tuple[str]:
"""传入文件通配符的关键词,从 easy_wildcard_dict 中获取通配符的所有选项。"""
global easy_wildcard_dict
if keyword in easy_wildcard_dict:
return tuple(easy_wildcard_dict[keyword])
elif '*' in keyword:
subpattern = keyword.replace('*', '.*').replace('+', r"\+")
total_pattern = []
for k, v in easy_wildcard_dict.items():
if re.match(subpattern, k) is not None:
total_pattern.extend(v)
if total_pattern:
return tuple(total_pattern)
elif '/' not in keyword:
return expand_wildcard(f"*/{keyword}")
def expand_options(options: str) -> tuple[str]:
"""传入去掉 {} 的选项。
展开选项通配符,返回该选项中的每一项,这里的每一项都是一个替换项。
不会对选项内容进行任何处理,即便存在空格或特殊符号,也会原样返回。"""
return tuple(options.split("|"))
def decimal_to_irregular(n, bases):
"""
将十进制数转换为不规则进制
:param n: 十进制数
:param bases: 各位置的基数列表,从低位到高位
:return: 不规则进制表示的列表,从低位到高位
"""
if n == 0:
return [0] * len(bases) if bases else [0]
digits = []
remaining = n
    # Process digits from least significant to most significant
for base in bases:
digit = remaining % base
digits.append(digit)
remaining = remaining // base
return digits
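# Worked example: decimal_to_irregular(5, [2, 3]) -> [1, 2]
#   5 % 2 = 1 (low digit), 5 // 2 = 2, then 2 % 3 = 2 (high digit).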
class WildcardProcessor:
"""通配符处理器
通配符格式:
+ option : {a|b}
+ wildcard: __keyword__ 通配符内容将从 Easy-Use 插件提供的 easy_wildcard_dict 中获取
"""
RE_OPTIONS = re.compile(r"{([^{}]*?)}")
RE_WILDCARD = re.compile(r"__([\w\s.\-+/*\\]+?)__")
RE_REPLACER = re.compile(r"{([^{}]*?)}|__([\w\s.\-+/*\\]+?)__")
    # The input prompt converted into a template compatible with python str.format, where every
    # option and wildcard is replaced, in order, by placeholders such as {0}, {1}, ...
    template: str
    # Replacement lists for the options/wildcards, in order of appearance in the template;
    # identical replacement lists are stored only once
    replacers: dict[int, tuple[str]]
    # Mapping from placeholder id to replacement-list index, in order of appearance in the
    # template, so identical replacement lists can share storage
    placeholder_mapping: dict[str, int]  # placeholder_id => replacer_id
    # Number of items in each replacement list, in order of appearance, precomputed for later use
    placeholder_choices: dict[str, int]  # placeholder_id => len(replacer)
def __init__(self, text: str):
self.__make_template(text)
self.__total = None
def random(self, seed=None) -> str:
"从所有可能性中随机获取一个"
if seed is not None:
random.seed(seed)
return self.getn(random.randint(0, self.total() - 1))
def getn(self, n: int) -> str:
"从所有可能性中获取第 n 个,以 self.total() 为周期循环"
n = n % self.total()
indice = decimal_to_irregular(n, self.placeholder_choices.values())
replacements = {
placeholder_id: self.replacers[self.placeholder_mapping[placeholder_id]][i]
for placeholder_id, i in zip(self.placeholder_mapping.keys(), indice)
}
return self.template.format(**replacements)
def getmany(self, limit: int, offset: int = 0) -> list[str]:
"""返回一组可能性组成的列表,为了避免结果太长导致内存占用超限,使用 limit 限制列表的长度,使用 offset 调整偏移。
若 limit 和 offset 的设置导致预期的结果长度超过剩下的实际长度,则会回到开头。
"""
return [self.getn(n) for n in range(offset, offset + limit)]
def total(self) -> int:
"计算可能性的数目"
if self.__total is None:
self.__total = prod(self.placeholder_choices.values())
return self.__total
def __make_template(self, text: str):
"""将输入的提示词转化成符合 python str.format 要求格式的模板,
并将 option 和 wildcard 按照顺序在模板中留下 {r0}, {r1} 等占位符,
即使遇到相同的 option 或 wildcard留下的占位符编号也不同从而使每项都独立变化。
"""
self.placeholder_mapping = {}
placeholder_id = 0
replacer_id = 0
replacers_rev = {} # replacers => id
blocks = []
        # Position in the text of the end of the last processed wildcard, used to stitch the full template together
tail = 0
for match in self.RE_REPLACER.finditer(text):
            # Extract and expand the wildcard content
m = match.group(0)
if m.startswith("{"):
choices = expand_options(m[1:-1])
elif m.startswith("__"):
keyword = m[2:-2].lower()
keyword = wildcard_normalize(keyword)
choices = expand_wildcard(keyword)
else:
raise ValueError(f"{m!r} is not a wildcard or option")
            # Record the replacement list and its id; identical wildcards keep only the first copy
if choices not in replacers_rev:
replacers_rev[choices] = replacer_id
replacer_id += 1
            # Append the text that precedes this wildcard
start, end = match.span()
blocks.append(text[tail:start])
tail = end
            # Replace the wildcard with a placeholder and record its mapping to the replacement-list index
blocks.append(f"{{r{placeholder_id}}}")
self.placeholder_mapping[f"r{placeholder_id}"] = replacers_rev[choices]
placeholder_id += 1
if tail < len(text):
blocks.append(text[tail:])
self.template = "".join(blocks)
self.replacers = {v: k for k, v in replacers_rev.items()}
self.placeholder_choices = {
placeholder_id: len(self.replacers[replacer_id])
for placeholder_id, replacer_id in self.placeholder_mapping.items()
}
def test_option():
text = "{|a|b|c}"
answer = ["", "a", "b", "c"]
p = WildcardProcessor(text)
assert p.total() == len(answer)
assert p.getn(0) == answer[0]
assert p.getmany(4) == answer
assert p.getmany(4, 1) == answer[1:]
def test_same():
text = "{a|b},{a|b}"
answer = ["a,a", "b,a", "a,b", "b,b"]
p = WildcardProcessor(text)
assert p.total() == len(answer)
assert p.getn(0) == answer[0]
assert p.getmany(4) == answer
assert p.getmany(4, 1) == answer[1:]
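def test_mixed():
    # A sketch combining an option with a file wildcard; the "colors" entry is injected here
    # purely for illustration (normally it would come from read_wildcard_dict()).
    easy_wildcard_dict["colors"] = ["red", "blue"]
    p = WildcardProcessor("__colors__ {cat|dog}")
    assert p.total() == 4
    assert set(p.getmany(4)) == {"red cat", "blue cat", "red dog", "blue dog"}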


@@ -0,0 +1,697 @@
import os, torch
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from .utils import easySave, get_sd_version
from .adv_encode import advanced_encode
from .controlnet import easyControlnet
from .log import log_node_warn
from ..modules.layer_diffuse import LayerDiffuse
from ..config import RESOURCES_DIR
from nodes import CLIPTextEncode
import pprint
try:
from comfy_extras.nodes_flux import FluxGuidance
except:
FluxGuidance = None
class easyXYPlot():
def __init__(self, xyPlotData, save_prefix, image_output, prompt, extra_pnginfo, my_unique_id, sampler, easyCache):
self.x_node_type, self.x_type = sampler.safe_split(xyPlotData.get("x_axis"), ': ')
self.y_node_type, self.y_type = sampler.safe_split(xyPlotData.get("y_axis"), ': ')
self.x_values = xyPlotData.get("x_vals") if self.x_type != "None" else []
self.y_values = xyPlotData.get("y_vals") if self.y_type != "None" else []
self.custom_font = xyPlotData.get("custom_font")
self.grid_spacing = xyPlotData.get("grid_spacing")
self.latent_id = 0
self.output_individuals = xyPlotData.get("output_individuals")
self.x_label, self.y_label = [], []
self.max_width, self.max_height = 0, 0
self.latents_plot = []
self.image_list = []
self.num_cols = len(self.x_values) if len(self.x_values) > 0 else 1
self.num_rows = len(self.y_values) if len(self.y_values) > 0 else 1
self.total = self.num_cols * self.num_rows
self.num = 0
self.save_prefix = save_prefix
self.image_output = image_output
self.prompt = prompt
self.extra_pnginfo = extra_pnginfo
self.my_unique_id = my_unique_id
self.sampler = sampler
self.easyCache = easyCache
# Helper Functions
@staticmethod
def define_variable(plot_image_vars, value_type, value, index):
plot_image_vars[value_type] = value
if value_type in ["seed", "Seeds++ Batch"]:
value_label = f"seed: {value}"
else:
value_label = f"{value_type}: {value}"
if "ControlNet" in value_type:
value_label = f"ControlNet {index + 1}"
if value_type in ['Lora', 'Checkpoint']:
arr = value.split(',')
model_name = os.path.basename(os.path.splitext(arr[0])[0])
trigger_words = ' ' + arr[3] if value_type == 'Lora' and len(arr[3]) > 2 else ''
lora_weight = float(arr[1]) if value_type == 'Lora' and len(arr) > 1 else 0
lora_weight_desc = f"({lora_weight:.2f})" if lora_weight > 0 else ''
value_label = f"{model_name[:30]}{lora_weight_desc} {trigger_words}"
if value_type in ["ModelMergeBlocks"]:
if ":" in value:
line = value.split(':')
value_label = f"{line[0]}"
elif len(value) > 16:
value_label = f"ModelMergeBlocks {index + 1}"
else:
value_label = f"MMB: {value}"
if value_type in ["Pos Condition"]:
value_label = f"pos cond {index + 1}" if index>0 else f"pos cond"
if value_type in ["Neg Condition"]:
value_label = f"neg cond {index + 1}" if index>0 else f"neg cond"
if value_type in ["Positive Prompt S/R"]:
value_label = f"pos prompt {index + 1}" if index>0 else f"pos prompt"
if value_type in ["Negative Prompt S/R"]:
value_label = f"neg prompt {index + 1}" if index>0 else f"neg prompt"
if value_type in ["steps", "cfg", "denoise", "clip_skip",
"lora_model_strength", "lora_clip_strength"]:
value_label = f"{value_type}: {value}"
if value_type == "positive":
value_label = f"pos prompt {index + 1}"
elif value_type == "negative":
value_label = f"neg prompt {index + 1}"
return plot_image_vars, value_label
@staticmethod
def get_font(font_size, font_path=None):
if font_path is None:
font_path = str(Path(os.path.join(RESOURCES_DIR, 'OpenSans-Medium.ttf')))
return ImageFont.truetype(font_path, font_size)
@staticmethod
def update_label(label, value, num_items):
if len(label) < num_items:
return [*label, value]
return label
@staticmethod
def rearrange_tensors(latent, num_cols, num_rows):
new_latent = []
for i in range(num_rows):
for j in range(num_cols):
index = j * num_rows + i
new_latent.append(latent[index])
return new_latent
def calculate_background_dimensions(self):
border_size = int((self.max_width // 8) * 1.5) if self.y_type != "None" or self.x_type != "None" else 0
bg_width = self.num_cols * (self.max_width + self.grid_spacing) - self.grid_spacing + border_size * (
self.y_type != "None")
bg_height = self.num_rows * (self.max_height + self.grid_spacing) - self.grid_spacing + border_size * (
self.x_type != "None")
        # Add space at the bottom of the image for common information about the image
bg_height = bg_height + (border_size*2)
# print(f"Grid Size: width = {bg_width} height = {bg_height} border_size = {border_size}")
x_offset_initial = border_size if self.y_type != "None" else 0
y_offset = border_size if self.x_type != "None" else 0
return bg_width, bg_height, x_offset_initial, y_offset
def adjust_font_size(self, text, initial_font_size, label_width):
font = self.get_font(initial_font_size, self.custom_font)
text_width = font.getbbox(text)
# pprint.pp(f"Initial font size: {initial_font_size}, text: {text}, text_width: {text_width}")
if text_width and text_width[2]:
text_width = text_width[2]
scaling_factor = 0.9
if text_width > (label_width * scaling_factor):
# print(f"Adjusting font size from {initial_font_size} to fit text width {text_width} into label width {label_width} scaling_factor {scaling_factor}")
return int(initial_font_size * (label_width / text_width) * scaling_factor)
else:
return initial_font_size
def textsize(self, d, text, font):
_, _, width, height = d.textbbox((0, 0), text=text, font=font)
return width, height
def create_label(self, img, text, initial_font_size, is_x_label=True, max_font_size=70, min_font_size=10, label_width=0, label_height=0):
        # If the label_width is specified, leave it alone. Otherwise fall back to the old logic.
if label_width == 0:
label_width = img.width if is_x_label else img.height
text_lines = text.split('\n')
longest_line = max(text_lines, key=len)
# Adjust font size
font_size = self.adjust_font_size(longest_line, initial_font_size, label_width)
font_size = min(max_font_size, font_size) # Ensure font isn't too large
font_size = max(min_font_size, font_size) # Ensure font isn't too small
if label_height == 0:
label_height = int(font_size * 1.5) if is_x_label else font_size
label_bg = Image.new('RGBA', (label_width, label_height), color=(255, 255, 255, 0))
d = ImageDraw.Draw(label_bg)
font = self.get_font(font_size, self.custom_font)
# Check if text will fit, if not insert ellipsis and reduce text
if self.textsize(d, text, font=font)[0] > label_width:
while self.textsize(d, text + '...', font=font)[0] > label_width and len(text) > 0:
text = text[:-1]
text = text + '...'
# Compute text width and height for multi-line text
text_widths, text_heights = zip(*[self.textsize(d, line, font=font) for line in text_lines])
max_text_width = max(text_widths)
total_text_height = sum(text_heights)
# Compute position for each line of text
lines_positions = []
current_y = 0
for line, line_width, line_height in zip(text_lines, text_widths, text_heights):
text_x = (label_width - line_width) // 2
text_y = current_y + (label_height - total_text_height) // 2
current_y += line_height
lines_positions.append((line, (text_x, text_y)))
# Draw each line of text
for line, (text_x, text_y) in lines_positions:
d.text((text_x, text_y), line, fill='black', font=font)
return label_bg
def sample_plot_image(self, plot_image_vars, samples, preview_latent, latents_plot, image_list, disable_noise,
start_step, last_step, force_full_denoise, x_value=None, y_value=None):
model, clip, vae, positive, negative, seed, steps, cfg = None, None, None, None, None, None, None, None
sampler_name, scheduler, denoise = None, None, None
a1111_prompt_style = plot_image_vars['a1111_prompt_style'] if "a1111_prompt_style" in plot_image_vars else False
clip = clip if clip is not None else plot_image_vars["clip"]
steps = plot_image_vars['steps'] if "steps" in plot_image_vars else 1
sd_version = get_sd_version(plot_image_vars['model'])
        # Advanced usage
if plot_image_vars["x_node_type"] == "advanced" or plot_image_vars["y_node_type"] == "advanced":
if self.x_type == "Seeds++ Batch" or self.y_type == "Seeds++ Batch":
seed = int(x_value) if self.x_type == "Seeds++ Batch" else int(y_value)
if self.x_type == "Steps" or self.y_type == "Steps":
steps = int(x_value) if self.x_type == "Steps" else int(y_value)
if self.x_type == "StartStep" or self.y_type == "StartStep":
start_step = int(x_value) if self.x_type == "StartStep" else int(y_value)
if self.x_type == "EndStep" or self.y_type == "EndStep":
last_step = int(x_value) if self.x_type == "EndStep" else int(y_value)
if self.x_type == "CFG Scale" or self.y_type == "CFG Scale":
cfg = float(x_value) if self.x_type == "CFG Scale" else float(y_value)
if self.x_type == "Sampler" or self.y_type == "Sampler":
sampler_name = x_value if self.x_type == "Sampler" else y_value
if self.x_type == "Scheduler" or self.y_type == "Scheduler":
scheduler = x_value if self.x_type == "Scheduler" else y_value
if self.x_type == "Sampler&Scheduler" or self.y_type == "Sampler&Scheduler":
arr = x_value.split(',') if self.x_type == "Sampler&Scheduler" else y_value.split(',')
if arr[0] and arr[0]!= 'None':
sampler_name = arr[0]
if arr[1] and arr[1]!= 'None':
scheduler = arr[1]
if self.x_type == "Denoise" or self.y_type == "Denoise":
denoise = float(x_value) if self.x_type == "Denoise" else float(y_value)
if self.x_type == "Pos Condition" or self.y_type == "Pos Condition":
positive = plot_image_vars['positive_cond_stack'][int(x_value)] if self.x_type == "Pos Condition" else plot_image_vars['positive_cond_stack'][int(y_value)]
if self.x_type == "Neg Condition" or self.y_type == "Neg Condition":
negative = plot_image_vars['negative_cond_stack'][int(x_value)] if self.x_type == "Neg Condition" else plot_image_vars['negative_cond_stack'][int(y_value)]
            # Model merging (ModelMergeBlocks)
if self.x_type == "ModelMergeBlocks" or self.y_type == "ModelMergeBlocks":
ckpt_name_1, ckpt_name_2 = plot_image_vars['models']
model1, clip1, vae1, clip_vision = self.easyCache.load_checkpoint(ckpt_name_1)
model2, clip2, vae2, clip_vision = self.easyCache.load_checkpoint(ckpt_name_2)
xy_values = x_value if self.x_type == "ModelMergeBlocks" else y_value
if ":" in xy_values:
xy_line = xy_values.split(':')
xy_values = xy_line[1]
xy_arrs = xy_values.split(',')
# ModelMergeBlocks
if len(xy_arrs) == 3:
input, middle, out = xy_arrs
kwargs = {
"input": input,
"middle": middle,
"out": out
}
elif len(xy_arrs) == 30:
kwargs = {}
kwargs["time_embed."] = xy_arrs[0]
kwargs["label_emb."] = xy_arrs[1]
for i in range(12):
kwargs["input_blocks.{}.".format(i)] = xy_arrs[2+i]
for i in range(3):
kwargs["middle_block.{}.".format(i)] = xy_arrs[14+i]
for i in range(12):
kwargs["output_blocks.{}.".format(i)] = xy_arrs[17+i]
kwargs["out."] = xy_arrs[29]
else:
raise Exception("ModelMergeBlocks weight length error")
default_ratio = next(iter(kwargs.values()))
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
ratio = float(default_ratio)
k_unet = k[len("diffusion_model."):]
last_arg_size = 0
for arg in kwargs:
if k_unet.startswith(arg) and last_arg_size < len(arg):
ratio = float(kwargs[arg])
last_arg_size = len(arg)
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
vae_use = plot_image_vars['vae_use']
clip = clip2 if vae_use == 'Use Model 2' else clip1
if vae_use == 'Use Model 2':
vae = vae2
elif vae_use == 'Use Model 1':
vae = vae1
else:
vae = self.easyCache.load_vae(vae_use)
model = m
            # If a lora_stack exists, apply the loras on top
optional_lora_stack = plot_image_vars['lora_stack']
if optional_lora_stack is not None and optional_lora_stack != []:
for lora in optional_lora_stack:
model, clip = self.easyCache.load_lora(lora)
            # Handle clip
clip = clip.clone()
if plot_image_vars['clip_skip'] != 0:
clip.clip_layer(plot_image_vars['clip_skip'])
# CheckPoint
if self.x_type == "Checkpoint" or self.y_type == "Checkpoint":
xy_values = x_value if self.x_type == "Checkpoint" else y_value
ckpt_name, clip_skip, vae_name = xy_values.split(",")
ckpt_name = ckpt_name.replace('*', ',')
vae_name = vae_name.replace('*', ',')
model, clip, vae, clip_vision = self.easyCache.load_checkpoint(ckpt_name)
if vae_name != 'None':
vae = self.easyCache.load_vae(vae_name)
                # If a lora_stack exists, apply the loras on top
optional_lora_stack = plot_image_vars['lora_stack']
if optional_lora_stack is not None and optional_lora_stack != []:
for lora in optional_lora_stack:
lora['model'] = model
lora['clip'] = clip
model, clip = self.easyCache.load_lora(lora)
                # Handle clip
clip = clip.clone()
if clip_skip != 'None':
clip.clip_layer(int(clip_skip))
positive = plot_image_vars['positive']
negative = plot_image_vars['negative']
a1111_prompt_style = plot_image_vars['a1111_prompt_style']
steps = plot_image_vars['steps']
clip = clip if clip is not None else plot_image_vars["clip"]
positive = advanced_encode(clip, positive,
plot_image_vars['positive_token_normalization'],
plot_image_vars['positive_weight_interpretation'],
w_max=1.0,
apply_to_pooled="enable",
a1111_prompt_style=a1111_prompt_style, steps=steps)
negative = advanced_encode(clip, negative,
plot_image_vars['negative_token_normalization'],
plot_image_vars['negative_weight_interpretation'],
w_max=1.0,
apply_to_pooled="enable",
a1111_prompt_style=a1111_prompt_style, steps=steps)
if "positive_cond" in plot_image_vars:
positive = positive + plot_image_vars["positive_cond"]
if "negative_cond" in plot_image_vars:
negative = negative + plot_image_vars["negative_cond"]
# Lora
if self.x_type == "Lora" or self.y_type == "Lora":
# print(f"Lora: {x_value} {y_value}")
model = model if model is not None else plot_image_vars["model"]
clip = clip if clip is not None else plot_image_vars["clip"]
xy_values = x_value if self.x_type == "Lora" else y_value
lora_name, lora_model_strength, lora_clip_strength, _ = xy_values.split(",")
lora_stack = [{"lora_name": lora_name, "model": model, "clip" :clip, "model_strength": float(lora_model_strength), "clip_strength": float(lora_clip_strength)}]
# print(f"new_lora_stack: {new_lora_stack}")
if 'lora_stack' in plot_image_vars:
lora_stack = lora_stack + plot_image_vars['lora_stack']
if lora_stack is not None and lora_stack != []:
for lora in lora_stack:
# Each generation of the model, must use the reference to previously created model / clip objects.
lora['model'] = model
lora['clip'] = clip
model, clip = self.easyCache.load_lora(lora)
# Prompts (S/R replacement and encoding)
if "Positive" in self.x_type or "Positive" in self.y_type:
if self.x_type == 'Positive Prompt S/R' or self.y_type == 'Positive Prompt S/R':
positive = x_value if self.x_type == "Positive Prompt S/R" else y_value
if sd_version == 'flux':
positive, = CLIPTextEncode().encode(clip, positive)
else:
positive = advanced_encode(clip, positive,
plot_image_vars['positive_token_normalization'],
plot_image_vars['positive_weight_interpretation'],
w_max=1.0,
apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)
# if "positive_cond" in plot_image_vars:
# positive = positive + plot_image_vars["positive_cond"]
if "Negative" in self.x_type or "Negative" in self.y_type:
if self.x_type == 'Negative Prompt S/R' or self.y_type == 'Negative Prompt S/R':
negative = x_value if self.x_type == "Negative Prompt S/R" else y_value
if sd_version == 'flux':
negative, = CLIPTextEncode().encode(clip, negative)
else:
negative = advanced_encode(clip, negative,
plot_image_vars['negative_token_normalization'],
plot_image_vars['negative_weight_interpretation'],
w_max=1.0,
apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)
# if "negative_cond" in plot_image_vars:
# negative = negative + plot_image_vars["negative_cond"]
# ControlNet
if "ControlNet" in self.x_type or "ControlNet" in self.y_type:
cnet = plot_image_vars["cnet"] if "cnet" in plot_image_vars else None
positive = plot_image_vars["positive_cond"] if "positive" in plot_image_vars else None
negative = plot_image_vars["negative_cond"] if "negative" in plot_image_vars else None
if cnet:
index = x_value if "ControlNet" in self.x_type else y_value
controlnet = cnet[index]
for index, item in enumerate(controlnet):
control_net_name = item[0]
image = item[1]
strength = item[2]
start_percent = item[3]
end_percent = item[4]
provided_control_net = item[5] if len(item) > 5 else None
positive, negative = easyControlnet().apply(control_net_name, image, positive, negative, strength, start_percent, end_percent, provided_control_net, 1)
# Flux guidance
if self.x_type == "Flux Guidance" or self.y_type == "Flux Guidance":
positive = plot_image_vars["positive_cond"] if "positive" in plot_image_vars else None
flux_guidance = float(x_value) if self.x_type == "Flux Guidance" else float(y_value)
positive, = FluxGuidance().append(positive, flux_guidance)
# Simple usage (loader node axis types)
if plot_image_vars["x_node_type"] == "loader" or plot_image_vars["y_node_type"] == "loader":
if self.x_type == 'ckpt_name' or self.y_type == 'ckpt_name':
ckpt_name = x_value if self.x_type == "ckpt_name" else y_value
model, clip, vae, clip_vision = self.easyCache.load_checkpoint(ckpt_name)
if self.x_type == 'lora_name' or self.y_type == 'lora_name':
model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
lora_name = x_value if self.x_type == "lora_name" else y_value
lora = {"lora_name": lora_name, "model": model, "clip": clip, "model_strength": 1, "clip_strength": 1}
model, clip = self.easyCache.load_lora(lora)
if self.x_type == 'lora_model_strength' or self.y_type == 'lora_model_strength':
model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
lora_model_strength = float(x_value) if self.x_type == "lora_model_strength" else float(y_value)
lora = {"lora_name": plot_image_vars['lora_name'], "model": model, "clip": clip, "model_strength": lora_model_strength, "clip_strength": plot_image_vars['lora_clip_strength']}
model, clip = self.easyCache.load_lora(lora)
if self.x_type == 'lora_clip_strength' or self.y_type == 'lora_clip_strength':
model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
lora_clip_strength = float(x_value) if self.x_type == "lora_clip_strength" else float(y_value)
lora = {"lora_name": plot_image_vars['lora_name'], "model": model, "clip": clip, "model_strength": plot_image_vars['lora_model_strength'], "clip_strength": lora_clip_strength}
model, clip = self.easyCache.load_lora(lora)
# Check for custom VAE
if self.x_type == 'vae_name' or self.y_type == 'vae_name':
vae_name = x_value if self.x_type == "vae_name" else y_value
vae = self.easyCache.load_vae(vae_name)
# CLIP skip
if not clip:
raise Exception("No CLIP found")
clip = clip.clone()
clip.clip_layer(plot_image_vars['clip_skip'])
if sd_version == 'flux':
positive, = CLIPTextEncode().encode(clip, positive)
else:
positive = advanced_encode(clip, plot_image_vars['positive'],
plot_image_vars['positive_token_normalization'],
plot_image_vars['positive_weight_interpretation'], w_max=1.0,
apply_to_pooled="enable",a1111_prompt_style=a1111_prompt_style, steps=steps)
if sd_version == 'flux':
negative, = CLIPTextEncode().encode(clip, negative)
else:
negative = advanced_encode(clip, plot_image_vars['negative'],
plot_image_vars['negative_token_normalization'],
plot_image_vars['negative_weight_interpretation'], w_max=1.0,
apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)
model = model if model is not None else plot_image_vars["model"]
vae = vae if vae is not None else plot_image_vars["vae"]
positive = positive if positive is not None else plot_image_vars["positive_cond"]
negative = negative if negative is not None else plot_image_vars["negative_cond"]
seed = seed if seed is not None else plot_image_vars["seed"]
steps = steps if steps is not None else plot_image_vars["steps"]
cfg = cfg if cfg is not None else plot_image_vars["cfg"]
sampler_name = sampler_name if sampler_name is not None else plot_image_vars["sampler_name"]
scheduler = scheduler if scheduler is not None else plot_image_vars["scheduler"]
denoise = denoise if denoise is not None else plot_image_vars["denoise"]
noise_device = plot_image_vars["noise_device"] if "noise_device" in plot_image_vars else 'cpu'
# LayerDiffuse
layer_diffusion_method = plot_image_vars["layer_diffusion_method"] if "layer_diffusion_method" in plot_image_vars else None
empty_samples = plot_image_vars["empty_samples"] if "empty_samples" in plot_image_vars else None
if layer_diffusion_method:
samp_blend_samples = plot_image_vars["blend_samples"] if "blend_samples" in plot_image_vars else None
additional_cond = plot_image_vars["layer_diffusion_cond"] if "layer_diffusion_cond" in plot_image_vars else None
images = plot_image_vars["images"].movedim(-1, 1) if "images" in plot_image_vars else None
weight = plot_image_vars['layer_diffusion_weight'] if 'layer_diffusion_weight' in plot_image_vars else 1.0
model, positive, negative = LayerDiffuse().apply_layer_diffusion(model, layer_diffusion_method, weight, samples,
samp_blend_samples, positive,
negative, images, additional_cond)
samples = empty_samples if layer_diffusion_method is not None and empty_samples is not None else samples
# Sample
samples = self.sampler.common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, samples,
denoise=denoise, disable_noise=disable_noise, preview_latent=preview_latent,
start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_device=noise_device)
# Decode images and store
latent = samples["samples"]
# Add the latent tensor to the tensors list
latents_plot.append(latent)
# Decode the image
image = vae.decode(latent).cpu()
if self.output_individuals in [True, "True"]:
easySave(image, self.save_prefix, self.image_output)
# Convert the image from tensor to PIL Image and add it to the list
pil_image = self.sampler.tensor2pil(image)
image_list.append(pil_image)
# Update max dimensions
self.max_width = max(self.max_width, pil_image.width)
self.max_height = max(self.max_height, pil_image.height)
# Return the touched variables
return image_list, self.max_width, self.max_height, latents_plot
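# Note: sample_plot_image computes one grid cell: it resolves the X/Y axis values into concrete
# sampling inputs (model, conditioning, sampler settings), runs common_ksampler, decodes the
# latent with the VAE, and appends the resulting PIL image and latent to the running lists.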
# Process Functions
def validate_xy_plot(self):
if self.x_type == 'None' and self.y_type == 'None':
log_node_warn(f'#{self.my_unique_id}','No Valid Plot Types - Reverting to default sampling...')
return False
else:
return True
def get_latent(self, samples):
# Extract the 'samples' tensor from the dictionary
latent_image_tensor = samples["samples"]
# Split the tensor into individual image tensors
image_tensors = torch.split(latent_image_tensor, 1, dim=0)
# Create a list of dictionaries containing the individual image tensors
latent_list = [{'samples': image} for image in image_tensors]
# Set latent only to the first latent of batch
if self.latent_id >= len(latent_list):
log_node_warn(f'#{self.my_unique_id}',f'The selected latent_id ({self.latent_id}) is out of range.')
log_node_warn(f'#{self.my_unique_id}', f'Automatically setting the latent_id to the last image in the list (index: {len(latent_list) - 1}).')
self.latent_id = len(latent_list) - 1
return latent_list[self.latent_id]
def get_labels_and_sample(self, plot_image_vars, latent_image, preview_latent, start_step, last_step,
force_full_denoise, disable_noise):
for x_index, x_value in enumerate(self.x_values):
plot_image_vars, x_value_label = self.define_variable(plot_image_vars, self.x_type, x_value,
x_index)
self.x_label = self.update_label(self.x_label, x_value_label, len(self.x_values))
if self.y_type != 'None':
for y_index, y_value in enumerate(self.y_values):
plot_image_vars, y_value_label = self.define_variable(plot_image_vars, self.y_type, y_value,
y_index)
self.y_label = self.update_label(self.y_label, y_value_label, len(self.y_values))
# ttNl(f'{CC.GREY}X: {x_value_label}, Y: {y_value_label}').t(
# f'Plot Values {self.num}/{self.total} ->').p()
self.image_list, self.max_width, self.max_height, self.latents_plot = self.sample_plot_image(
plot_image_vars, latent_image, preview_latent, self.latents_plot, self.image_list,
disable_noise, start_step, last_step, force_full_denoise, x_value, y_value)
self.num += 1
else:
# ttNl(f'{CC.GREY}X: {x_value_label}').t(f'Plot Values {self.num}/{self.total} ->').p()
self.image_list, self.max_width, self.max_height, self.latents_plot = self.sample_plot_image(
plot_image_vars, latent_image, preview_latent, self.latents_plot, self.image_list, disable_noise,
start_step, last_step, force_full_denoise, x_value)
self.num += 1
# Rearrange latent array to match preview image grid
self.latents_plot = self.rearrange_tensors(self.latents_plot, self.num_cols, self.num_rows)
# Concatenate the tensors along the first dimension (dim=0)
self.latents_plot = torch.cat(self.latents_plot, dim=0)
return self.latents_plot
def plot_images_and_labels(self, plot_image_vars):
bg_width, bg_height, x_offset_initial, y_offset = self.calculate_background_dimensions()
background = Image.new('RGBA', (int(bg_width), int(bg_height)), color=(255, 255, 255, 255))
output_image = []
for row_index in range(self.num_rows):
x_offset = x_offset_initial
for col_index in range(self.num_cols):
index = col_index * self.num_rows + row_index
img = self.image_list[index]
output_image.append(self.sampler.pil2tensor(img))
background.paste(img, (x_offset, y_offset))
# Handle X label
if row_index == 0 and self.x_type != "None":
label_bg = self.create_label(img, self.x_label[col_index], int(48 * img.width / 512))
label_y = (y_offset - label_bg.height) // 2
background.alpha_composite(label_bg, (x_offset, label_y))
# Handle Y label
if col_index == 0 and self.y_type != "None":
label_bg = self.create_label(img, self.y_label[row_index], int(48 * img.height / 512), False)
label_bg = label_bg.rotate(90, expand=True)
label_x = (x_offset - label_bg.width) // 2
label_y = y_offset + (img.height - label_bg.height) // 2
background.alpha_composite(label_bg, (label_x, label_y))
x_offset += img.width + self.grid_spacing
y_offset += img.height + self.grid_spacing
# lookup used models in the image
common_label = ""
# TODO: factor this into a helper that takes the plot_image_vars key, the label to use, and the axis names.
# pprint.pp(plot_image_vars)
# We don't process LORAs here because there can be multiple of them.
labels = [
{"id": "ckpt_name", "id_desc": "ckpt", "axis_type" : "Checkpoint"},
{"id": "vae_name", "id_desc": '', "axis_type" : "vae_name"},
{"id": "sampler_name", "id_desc": "sampler", "axis_type" : "Sampler"},
{"id": "scheduler", "id_desc": '', "axis_type" : "Scheduler"},
{"id": "steps", "id_desc": '', "axis_type" : "Steps"},
{"id": "Flux Guidance", "id_desc": 'guidance', "axis_type" : "Flux Guidance"},
{"id": "seed", "id_desc": '', "axis_type" : "Seeds++ Batch"}
]
for item in labels:
# Only add the label if it's not one of the axis
# print(f"Checking item: {item['id']} axis_type {item['axis_type']} x_type: {self.x_type} y_type: {self.y_type}")
if self.x_type != item['axis_type'] and self.y_type != item['axis_type']:
common_label += self.add_common_label(item['id'], plot_image_vars, item['id_desc'])
common_label += f"\n"
if plot_image_vars['lora_stack'] is not None and plot_image_vars['lora_stack'] != []:
# print(f"lora_stack: {plot_image_vars['lora_stack']}")
for lora in plot_image_vars['lora_stack']:
lora_name = lora['lora_name']
lora_weight = lora['model_strength']
if lora_name is not None and len(lora_name) > 0 and lora_weight > 0:
common_label += f"LORA: {lora_name} weight: {lora_weight:.2f} \n"
common_label = common_label.strip()
if len(common_label) > 0:
label_height = background.height - y_offset
label_bg = self.create_label(background, common_label, int(48 * background.width / 512), label_width=background.width, label_height=label_height)
label_x = (background.width - label_bg.width) // 2
label_y = y_offset
# print(f"Adding common label: {common_label} x = {label_x} y = {label_y}")
background.alpha_composite(label_bg, (label_x, label_y))
return (self.sampler.pil2tensor(background), output_image)
def add_common_label(self, tag, plot_image_vars, description = ''):
label = ''
if description == '': description = tag
if tag in plot_image_vars and plot_image_vars[tag] is not None and plot_image_vars[tag] != 'None':
label += f"{description}: {plot_image_vars[tag]} "
# print(f"add_common_label: {tag} description: {description} label: {label}" )
return label

File diff suppressed because it is too large


@@ -0,0 +1,167 @@
#credit to comfyanonymous for this module
#from https://github.com/comfyanonymous/ComfyUI_bitsandbytes_NF4
import comfy.ops
import torch
import folder_paths
from ...libs.utils import install_package
try:
from bitsandbytes.nn.modules import Params4bit, QuantState
except ImportError:
Params4bit = torch.nn.Parameter
raise ImportError("Please install bitsandbytes>=0.43.3")
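# functional_linear_4bits runs a 4-bit (NF4/FP4) linear layer through bitsandbytes' matmul_4bit,
# using the quant_state attached to the packed weight, then casts the result back to the input's
# dtype and device.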
def functional_linear_4bits(x, weight, bias):
try:
install_package("bitsandbytes", "0.43.3", True, "0.43.3")
import bitsandbytes as bnb
except ImportError:
raise ImportError("Please install bitsandbytes>=0.43.3")
out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
out = out.to(x)
return out
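# copy_quant_state makes a device-local copy of a bitsandbytes QuantState, including the nested
# state2 used when compress_statistics (double quantization) is enabled.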
def copy_quant_state(state, device: torch.device = None):
if state is None:
return None
device = device or state.absmax.device
state2 = (
QuantState(
absmax=state.state2.absmax.to(device),
shape=state.state2.shape,
code=state.state2.code.to(device),
blocksize=state.state2.blocksize,
quant_type=state.state2.quant_type,
dtype=state.state2.dtype,
)
if state.nested
else None
)
return QuantState(
absmax=state.absmax.to(device),
shape=state.shape,
code=state.code.to(device),
blocksize=state.blocksize,
quant_type=state.quant_type,
dtype=state.dtype,
offset=state.offset.to(device) if state.nested else None,
state2=state2,
)
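# ForgeParams4bit overrides Parameter.to(): the first move onto a CUDA device triggers quantization
# of the raw weight; subsequent moves carry the existing quant_state along instead of re-quantizing.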
class ForgeParams4bit(Params4bit):
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None and device.type == "cuda" and not self.bnb_quantized:
return self._quantize(device)
else:
n = ForgeParams4bit(
torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad,
quant_state=copy_quant_state(self.quant_state, device),
blocksize=self.blocksize,
compress_statistics=self.compress_statistics,
quant_type=self.quant_type,
quant_storage=self.quant_storage,
bnb_quantized=self.bnb_quantized,
module=self.module
)
self.module.quant_state = n.quant_state
self.data = n.data
self.quant_state = n.quant_state
return n
class ForgeLoader4Bit(torch.nn.Module):
def __init__(self, *, device, dtype, quant_type, **kwargs):
super().__init__()
self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
self.weight = None
self.quant_state = None
self.bias = None
self.quant_type = quant_type
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
quant_state = getattr(self.weight, "quant_state", None)
if quant_state is not None:
for k, v in quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
return
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
if any('bitsandbytes' in k for k in quant_state_keys):
quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
self.weight = ForgeParams4bit().from_prequantized(
data=state_dict[prefix + 'weight'],
quantized_stats=quant_state_dict,
requires_grad=False,
device=self.dummy.device,
module=self
)
self.quant_state = self.weight.quant_state
if prefix + 'bias' in state_dict:
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
del self.dummy
elif hasattr(self, 'dummy'):
if prefix + 'weight' in state_dict:
self.weight = ForgeParams4bit(
state_dict[prefix + 'weight'].to(self.dummy),
requires_grad=False,
compress_statistics=True,
quant_type=self.quant_type,
quant_storage=torch.uint8,
module=self,
)
self.quant_state = self.weight.quant_state
if prefix + 'bias' in state_dict:
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
del self.dummy
else:
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
current_device = None
current_dtype = None
current_manual_cast_enabled = False
current_bnb_dtype = None
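# OPS mirrors comfy.ops.manual_cast but swaps in a 4-bit Linear: when manual casting is off the
# quantized weight is used directly; otherwise the weight is quantized on the fly on the input's
# CUDA device (and moved back to its original device afterwards).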
class OPS(comfy.ops.manual_cast):
class Linear(ForgeLoader4Bit):
def __init__(self, *args, device=None, dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
self.parameters_manual_cast = current_manual_cast_enabled
def forward(self, x):
self.weight.quant_state = self.quant_state
if self.bias is not None and self.bias.dtype != x.dtype:
# Maybe this can also be set to all non-bnb ops since the cost is very low.
# And it only invokes one time, and most linear does not have bias
self.bias.data = self.bias.data.to(x.dtype)
if not self.parameters_manual_cast:
return functional_linear_4bits(x, self.weight, self.bias)
elif not self.weight.bnb_quantized:
assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
layer_original_device = self.weight.device
self.weight = self.weight._quantize(x.device)
bias = self.bias.to(x.device) if self.bias is not None else None
out = functional_linear_4bits(x, self.weight, bias)
self.weight = self.weight.to(layer_original_device)
return out
else:
# NOTE: weights_manual_cast and main_stream_worker are helpers from the Forge backend; they are
# not imported in this module, so this fallback branch assumes they are provided elsewhere.
weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
with main_stream_worker(weight, bias, signal):
return functional_linear_4bits(x, weight, bias)


@@ -0,0 +1,475 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import numpy as np
class REBNCONV(nn.Module):
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
super(REBNCONV,self).__init__()
self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self,x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):
src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
return src
### RSU-7 ###
class RSU7(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
super(RSU7,self).__init__()
self.in_ch = in_ch
self.mid_ch = mid_ch
self.out_ch = out_ch
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
b, c, h, w = x.shape
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
hx6dup = _upsample_like(hx6d,hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
return hx1d + hxin
class myrebnconv(nn.Module):
def __init__(self, in_ch=3,
out_ch=1,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1):
super(myrebnconv,self).__init__()
self.conv = nn.Conv2d(in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups)
self.bn = nn.BatchNorm2d(out_ch)
self.rl = nn.ReLU(inplace=True)
def forward(self,x):
return self.rl(self.bn(self.conv(x)))
def preprocess_image(im, model_input_size: list) -> torch.Tensor:
# im = im.resize(model_input_size, Image.BILINEAR)
im_np = np.array(im)
im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8)
image = torch.divide(im_tensor,255.0)
image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
return image
def postprocess_image(result: torch.Tensor, im_size: list)-> np.ndarray:
result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
ma = torch.max(result)
mi = torch.min(result)
result = (result-mi)/(ma-mi)
im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8)
im_array = np.squeeze(im_array)
return im_array
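# Illustrative usage (a sketch, not part of this module; the 1024x1024 input size and the
# state-dict path are assumptions):
#
#   net = BriaRMBG()
#   net.load_state_dict(torch.load("model.pth", map_location="cpu"))
#   net.eval()
#   inp = preprocess_image(pil_image, [1024, 1024])            # [1, 3, H, W] normalized tensor
#   with torch.no_grad():
#       preds, _ = net(inp)
#   mask = postprocess_image(preds[0], pil_image.size[::-1])   # uint8 HxW alpha mask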
class BriaRMBG(nn.Module):
def __init__(self, config:dict={"in_ch":3,"out_ch":1}):
super(BriaRMBG,self).__init__()
in_ch = config["in_ch"]
out_ch = config["out_ch"]
self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage1 = RSU7(64,32,64)
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage2 = RSU6(64,32,128)
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage3 = RSU5(128,64,256)
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage4 = RSU4(256,128,512)
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage5 = RSU4F(512,256,512)
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage6 = RSU4F(512,256,512)
# decoder
self.stage5d = RSU4F(1024,256,512)
self.stage4d = RSU4(1024,128,256)
self.stage3d = RSU5(512,64,128)
self.stage2d = RSU6(256,32,64)
self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
def forward(self,x):
hx = x
hxin = self.conv_in(hx)
#hx = self.pool_in(hxin)
#stage 1
hx1 = self.stage1(hxin)
hx = self.pool12(hx1)
#stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
#stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
#stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
#stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
#stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6,hx5)
#-------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
#side output
d1 = self.side1(hx1d)
d1 = _upsample_like(d1,x)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2,x)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3,x)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4,x)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5,x)
d6 = self.side6(hx6)
d6 = _upsample_like(d6,x)
return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]


@@ -0,0 +1,822 @@
#credit to nullquant for this module
#from https://github.com/nullquant/ComfyUI-BrushNet
import os
import types
import torch
try:
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
except:
init_empty_weights, load_checkpoint_and_dispatch = None, None
import comfy
try:
from .model import BrushNetModel, PowerPaintModel
from .model_patch import add_model_patch_option, patch_model_function_wrapper
from .powerpaint_utils import TokenizerWrapper, add_tokens
except:
BrushNetModel, PowerPaintModel = None, None
add_model_patch_option, patch_model_function_wrapper = None, None
TokenizerWrapper, add_tokens = None, None
cwd_path = os.path.dirname(os.path.realpath(__file__))
brushnet_config_file = os.path.join(cwd_path, 'config', 'brushnet.json')
brushnet_xl_config_file = os.path.join(cwd_path, 'config', 'brushnet_xl.json')
powerpaint_config_file = os.path.join(cwd_path, 'config', 'powerpaint.json')
sd15_scaling_factor = 0.18215
sdxl_scaling_factor = 0.13025
ModelsToUnload = [comfy.sd1_clip.SD1ClipModel, comfy.ldm.models.autoencoder.AutoencoderKL]
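# BrushNet bundles the node-side logic: compatibility checks between the base model and the
# BrushNet/PowerPaint weights, image/mask preparation, conditioning-latent encoding, and the
# model patching that injects BrushNet residuals into the UNet blocks during sampling.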
class BrushNet:
# Check models compatibility
def check_compatibilty(self, model, brushnet):
is_SDXL = False
is_PP = False
if isinstance(model.model.model_config, comfy.supported_models.SD15):
print('Base model type: SD1.5')
is_SDXL = False
if brushnet["SDXL"]:
raise Exception("Base model is SD15, but BrushNet is SDXL type")
if brushnet["PP"]:
is_PP = True
elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
print('Base model type: SDXL')
is_SDXL = True
if not brushnet["SDXL"]:
raise Exception("Base model is SDXL, but BrushNet is SD15 type")
else:
print('Base model type: ', type(model.model.model_config))
raise Exception("Unsupported model type: " + str(type(model.model.model_config)))
return (is_SDXL, is_PP)
def check_image_mask(self, image, mask, name):
if len(image.shape) < 4:
# image tensor shape should be [B, H, W, C], but batch somehow is missing
image = image[None, :, :, :]
if len(mask.shape) > 3:
# mask tensor shape should be [B, H, W] but we got [B, H, W, C]; perhaps an image was passed?
# take first mask, red channel
mask = (mask[:, :, :, 0])[:, :, :]
elif len(mask.shape) < 3:
# mask tensor shape should be [B, H, W] but batch somehow is missing
mask = mask[None, :, :]
if image.shape[0] > mask.shape[0]:
print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
if mask.shape[0] == 1:
print(name, "will copy the mask to fill batch")
mask = torch.cat([mask] * image.shape[0], dim=0)
else:
print(name, "will add empty masks to fill batch")
empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
mask = torch.cat([mask, empty_mask], dim=0)
elif image.shape[0] < mask.shape[0]:
print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
mask = mask[:image.shape[0], :, :]
return (image, mask)
# Prepare image and mask
def prepare_image(self, image, mask):
image, mask = self.check_image_mask(image, mask, 'BrushNet')
print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)
if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]:
raise Exception("Image and mask should be the same size")
# As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
mask = mask.round()
masked_image = image * (1.0 - mask[:, :, :, None])
return (masked_image, mask)
# Get origin of the mask
def cut_with_mask(self, mask, width, height):
iy, ix = (mask == 1).nonzero(as_tuple=True)
h0, w0 = mask.shape
if iy.numel() == 0:
x_c = w0 / 2.0
y_c = h0 / 2.0
else:
x_min = ix.min().item()
x_max = ix.max().item()
y_min = iy.min().item()
y_max = iy.max().item()
if x_max - x_min > width or y_max - y_min > height:
raise Exception("Mask is bigger than provided dimensions")
x_c = (x_min + x_max) / 2.0
y_c = (y_min + y_max) / 2.0
width2 = width / 2.0
height2 = height / 2.0
if w0 <= width:
x0 = 0
w = w0
else:
x0 = max(0, x_c - width2)
w = width
if x0 + width > w0:
x0 = w0 - width
if h0 <= height:
y0 = 0
h = h0
else:
y0 = max(0, y_c - height2)
h = height
if y0 + height > h0:
y0 = h0 - height
return (int(x0), int(y0), int(w), int(h))
# Prepare conditioning_latents
@torch.inference_mode()
def get_image_latents(self, masked_image, mask, vae, scaling_factor):
processed_image = masked_image.to(vae.device)
image_latents = vae.encode(processed_image[:, :, :, :3]) * scaling_factor
processed_mask = 1. - mask[:, None, :, :]
interpolated_mask = torch.nn.functional.interpolate(
processed_mask,
size=(
image_latents.shape[-2],
image_latents.shape[-1]
)
)
interpolated_mask = interpolated_mask.to(image_latents.device)
conditioning_latents = [image_latents, interpolated_mask]
print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =',
interpolated_mask.shape)
return conditioning_latents
def brushnet_blocks(self, sd):
brushnet_down_block = 0
brushnet_mid_block = 0
brushnet_up_block = 0
for key in sd:
if 'brushnet_down_block' in key:
brushnet_down_block += 1
if 'brushnet_mid_block' in key:
brushnet_mid_block += 1
if 'brushnet_up_block' in key:
brushnet_up_block += 1
return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd))
def get_model_type(self, brushnet_file):
sd = comfy.utils.load_torch_file(brushnet_file)
brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = self.brushnet_blocks(sd)
del sd
if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
is_SDXL = False
if keys == 322:
is_PP = False
print('BrushNet model type: SD1.5')
else:
is_PP = True
print('PowerPaint model type: SD1.5')
elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
print('BrushNet model type: Loading SDXL')
is_SDXL = True
is_PP = False
else:
raise Exception("Unknown BrushNet model")
return is_SDXL, is_PP
def load_brushnet_model(self, brushnet_file, dtype='float16'):
is_SDXL, is_PP = self.get_model_type(brushnet_file)
with init_empty_weights():
if is_SDXL:
brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
brushnet_model = BrushNetModel.from_config(brushnet_config)
elif is_PP:
brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
brushnet_model = PowerPaintModel.from_config(brushnet_config)
else:
brushnet_config = BrushNetModel.load_config(brushnet_config_file)
brushnet_model = BrushNetModel.from_config(brushnet_config)
if is_PP:
print("PowerPaint model file:", brushnet_file)
else:
print("BrushNet model file:", brushnet_file)
if dtype == 'float16':
torch_dtype = torch.float16
elif dtype == 'bfloat16':
torch_dtype = torch.bfloat16
elif dtype == 'float32':
torch_dtype = torch.float32
else:
torch_dtype = torch.float64
brushnet_model = load_checkpoint_and_dispatch(
brushnet_model,
brushnet_file,
device_map="sequential",
max_memory=None,
offload_folder=None,
offload_state_dict=False,
dtype=torch_dtype,
force_hooks=False,
)
if is_PP:
print("PowerPaint model is loaded")
elif is_SDXL:
print("BrushNet SDXL model is loaded")
else:
print("BrushNet SD1.5 model is loaded")
return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype},)
def brushnet_model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
is_SDXL, is_PP = self.check_compatibilty(model, brushnet)
if is_PP:
raise Exception("PowerPaint model was loaded, please use PowerPaint node")
# Make a copy of the model so that we're not patching it everywhere in the workflow.
model = model.clone()
# prepare image and mask
# no batches for original image and mask
masked_image, mask = self.prepare_image(image, mask)
batch = masked_image.shape[0]
width = masked_image.shape[2]
height = masked_image.shape[1]
if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format,
'scale_factor'):
scaling_factor = model.model.model_config.latent_format.scale_factor
elif is_SDXL:
scaling_factor = sdxl_scaling_factor
else:
scaling_factor = sd15_scaling_factor
torch_dtype = brushnet['dtype']
# prepare conditioning latents
conditioning_latents = self.get_image_latents(masked_image, mask, vae, scaling_factor)
conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
# unload vae
del vae
# for loaded_model in comfy.model_management.current_loaded_models:
# if type(loaded_model.model.model) in ModelsToUnload:
# comfy.model_management.current_loaded_models.remove(loaded_model)
# loaded_model.model_unload()
# del loaded_model
# prepare embeddings
prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
if prompt_embeds.shape[1] < max_tokens:
multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:, -77:, :]] * multiplier, dim=1)
print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape,
'multiplying prompt_embeds')
if negative_prompt_embeds.shape[1] < max_tokens:
multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
negative_prompt_embeds = torch.concat(
[negative_prompt_embeds] + [negative_prompt_embeds[:, -77:, :]] * multiplier, dim=1)
print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape,
'multiplying negative_prompt_embeds')
if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
else:
print('BrushNet: positive conditioning has no pooled_output')
if is_SDXL:
print('BrushNet will not produce correct results')
pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(
brushnet['brushnet'].device)
else:
print('BrushNet: negative conditioning has no pooled_output')
if is_SDXL:
print('BrushNet will not produce correct results')
negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]],
device=brushnet['brushnet'].device).to(dtype=torch_dtype)
time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(
brushnet['brushnet'].device)
if not is_SDXL:
pooled_prompt_embeds = None
negative_pooled_prompt_embeds = None
time_ids = None
# apply patch to model
brushnet_conditioning_scale = scale
control_guidance_start = start_at
control_guidance_end = end_at
add_brushnet_patch(model,
brushnet['brushnet'],
torch_dtype,
conditioning_latents,
(brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
prompt_embeds, negative_prompt_embeds,
pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
False)
latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]],
device=brushnet['brushnet'].device)
return (model, positive, negative, {"samples": latent},)
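# Typical node flow (a sketch; the variable names and argument values are illustrative, not part
# of this module):
#
#   bn = BrushNet()
#   brushnet, = bn.load_brushnet_model(brushnet_file, dtype='float16')
#   model, positive, negative, latent = bn.brushnet_model_update(
#       model, vae, image, mask, brushnet, positive, negative,
#       scale=1.0, start_at=0, end_at=10000)
#   # the patched model and the returned latent then go to a regular KSampler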
# PowerPaint
def load_powerpaint_clip(self, base_clip_file, pp_clip_file):
pp_clip = comfy.sd.load_clip(ckpt_paths=[base_clip_file])
print('PowerPaint base CLIP file: ', base_clip_file)
pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
add_tokens(
tokenizer=pp_tokenizer,
text_encoder=pp_text_encoder,
placeholder_tokens=["P_ctxt", "P_shape", "P_obj"],
initialize_tokens=["a", "a", "a"],
num_vectors_per_token=10,
)
pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_clip_file), strict=False)
print('PowerPaint CLIP file: ', pp_clip_file)
pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
return (pp_clip,)
def powerpaint_model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):
is_SDXL, is_PP = self.check_compatibilty(model, powerpaint)
if not is_PP:
raise Exception("BrushNet model was loaded, please use BrushNet node")
# Make a copy of the model so that we're not patching it everywhere in the workflow.
model = model.clone()
# prepare image and mask
# no batches for original image and mask
masked_image, mask = self.prepare_image(image, mask)
batch = masked_image.shape[0]
# width = masked_image.shape[2]
# height = masked_image.shape[1]
if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format,
'scale_factor'):
scaling_factor = model.model.model_config.latent_format.scale_factor
else:
scaling_factor = sd15_scaling_factor
torch_dtype = powerpaint['dtype']
# prepare conditioning latents
conditioning_latents = self.get_image_latents(masked_image, mask, vae, scaling_factor)
conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
# prepare embeddings
if function == "object removal":
promptA = "P_ctxt"
promptB = "P_ctxt"
negative_promptA = "P_obj"
negative_promptB = "P_obj"
print('You should add to positive prompt: "empty scene blur"')
# positive = positive + " empty scene blur"
elif function == "context aware":
promptA = "P_ctxt"
promptB = "P_ctxt"
negative_promptA = ""
negative_promptB = ""
# positive = positive + " empty scene"
print('You should add to positive prompt: "empty scene"')
elif function == "shape guided":
promptA = "P_shape"
promptB = "P_ctxt"
negative_promptA = "P_shape"
negative_promptB = "P_ctxt"
elif function == "image outpainting":
promptA = "P_ctxt"
promptB = "P_ctxt"
negative_promptA = "P_obj"
negative_promptB = "P_obj"
# positive = positive + " empty scene"
print('You should add to positive prompt: "empty scene"')
else:
promptA = "P_obj"
promptB = "P_obj"
negative_promptA = "P_obj"
negative_promptB = "P_obj"
tokens = clip.tokenize(promptA)
prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
tokens = clip.tokenize(negative_promptA)
negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
tokens = clip.tokenize(promptB)
prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
tokens = clip.tokenize(negative_promptB)
negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(
powerpaint['brushnet'].device)
negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(
dtype=torch_dtype).to(powerpaint['brushnet'].device)
# unload vae and CLIPs
del vae
del clip
# for loaded_model in comfy.model_management.current_loaded_models:
# if type(loaded_model.model.model) in ModelsToUnload:
# comfy.model_management.current_loaded_models.remove(loaded_model)
# loaded_model.model_unload()
# del loaded_model
# apply patch to model
brushnet_conditioning_scale = scale
control_guidance_start = start_at
control_guidance_end = end_at
if save_memory != 'none':
powerpaint['brushnet'].set_attention_slice(save_memory)
add_brushnet_patch(model,
powerpaint['brushnet'],
torch_dtype,
conditioning_latents,
(brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
negative_prompt_embeds_pp, prompt_embeds_pp,
None, None, None,
False)
latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]],
device=powerpaint['brushnet'].device)
return (model, positive, negative, {"samples": latent},)
@torch.inference_mode()
def brushnet_inference(x, timesteps, transformer_options, debug):
if 'model_patch' not in transformer_options:
print('BrushNet inference: there is no model_patch key in transformer_options')
return ([], 0, [])
mp = transformer_options['model_patch']
if 'brushnet' not in mp:
print('BrushNet inference: there is no brushnet key in model_patch')
return ([], 0, [])
bo = mp['brushnet']
if 'model' not in bo:
print('BrushNet inference: there is no model key in brushnet')
return ([], 0, [])
brushnet = bo['model']
if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)):
print('BrushNet model is not a BrushNetModel class')
return ([], 0, [])
torch_dtype = bo['dtype']
cl_list = bo['latents']
brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
pe = bo['prompt_embeds']
npe = bo['negative_prompt_embeds']
ppe, nppe, time_ids = bo['add_embeds']
#do_classifier_free_guidance = mp['free_guidance']
do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1
x = x.detach().clone()
x = x.to(torch_dtype).to(brushnet.device)
timesteps = timesteps.detach().clone()
timesteps = timesteps.to(torch_dtype).to(brushnet.device)
total_steps = mp['total_steps']
step = mp['step']
added_cond_kwargs = {}
if do_classifier_free_guidance and step == 0:
print('BrushNet inference: do_classifier_free_guidance is True')
sub_idx = None
if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
sub_idx = transformer_options['ad_params']['sub_idxs']
# we have batch input images
batch = cl_list[0].shape[0]
# we have incoming latents
latents_incoming = x.shape[0]
# and we already got some
latents_got = bo['latent_id']
if step == 0 or batch > 1:
print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
% (step, batch, latents_incoming, latents_got))
image_latents = []
masks = []
prompt_embeds = []
negative_prompt_embeds = []
pooled_prompt_embeds = []
negative_pooled_prompt_embeds = []
if sub_idx:
# AnimateDiff indexes detected
if step == 0:
print('BrushNet inference: AnimateDiff indexes detected and applied')
batch = len(sub_idx)
if do_classifier_free_guidance:
for i in sub_idx:
image_latents.append(cl_list[0][i][None,:,:,:])
masks.append(cl_list[1][i][None,:,:,:])
prompt_embeds.append(pe)
negative_prompt_embeds.append(npe)
pooled_prompt_embeds.append(ppe)
negative_pooled_prompt_embeds.append(nppe)
for i in sub_idx:
image_latents.append(cl_list[0][i][None,:,:,:])
masks.append(cl_list[1][i][None,:,:,:])
else:
for i in sub_idx:
image_latents.append(cl_list[0][i][None,:,:,:])
masks.append(cl_list[1][i][None,:,:,:])
prompt_embeds.append(pe)
pooled_prompt_embeds.append(ppe)
else:
# do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
continue_batch = True
for i in range(latents_incoming):
number = latents_got + i
if number < batch:
# 1st pass, cond
image_latents.append(cl_list[0][number][None,:,:,:])
masks.append(cl_list[1][number][None,:,:,:])
prompt_embeds.append(pe)
pooled_prompt_embeds.append(ppe)
elif do_classifier_free_guidance and number < batch * 2:
# 2nd pass, uncond
image_latents.append(cl_list[0][number-batch][None,:,:,:])
masks.append(cl_list[1][number-batch][None,:,:,:])
negative_prompt_embeds.append(npe)
negative_pooled_prompt_embeds.append(nppe)
else:
# latent batch
image_latents.append(cl_list[0][0][None,:,:,:])
masks.append(cl_list[1][0][None,:,:,:])
prompt_embeds.append(pe)
pooled_prompt_embeds.append(ppe)
latents_got = -i
continue_batch = False
if continue_batch:
# we don't have full batch yet
if do_classifier_free_guidance:
if number < batch * 2 - 1:
bo['latent_id'] = number + 1
else:
bo['latent_id'] = 0
else:
if number < batch - 1:
bo['latent_id'] = number + 1
else:
bo['latent_id'] = 0
else:
bo['latent_id'] = 0
cl = []
for il, m in zip(image_latents, masks):
cl.append(torch.concat([il, m], dim=1))
cl2apply = torch.concat(cl, dim=0)
conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)
prompt_embeds.extend(negative_prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
if ppe is not None:
added_cond_kwargs = {}
added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)
pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
else:
added_cond_kwargs = None
if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
if step == 0:
print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
conditioning_latents = torch.nn.functional.interpolate(
conditioning_latents, size=(
x.shape[2],
x.shape[3],
), mode='bicubic',
).to(torch_dtype).to(brushnet.device)
if step == 0:
print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)
if debug: print('BrushNet: step =', step)
if step < control_guidance_start or step > control_guidance_end:
cond_scale = 0.0
else:
cond_scale = brushnet_conditioning_scale
return brushnet(x,
encoder_hidden_states=prompt_embeds,
brushnet_cond=conditioning_latents,
timestep = timesteps,
conditioning_scale=cond_scale,
guess_mode=False,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
debug=debug,
)
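# add_brushnet_patch wires BrushNet into a cloned ComfyUI model: it stores the BrushNet model,
# conditioning latents and embeddings in transformer_options['model_patch'], wraps the UNet
# forward so brushnet_inference runs every step, and patches each block's forward to add the
# corresponding BrushNet residual sample.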
def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
controls,
prompt_embeds, negative_prompt_embeds,
pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
debug):
is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)
if model.model.model_config.custom_operations is None:
fp8 = model.model.model_config.optimizations.get("fp8", model.model.model_config.scaled_fp8 is not None)
operations = comfy.ops.pick_operations(model.model.model_config.unet_config.get("dtype", None), model.model.manual_cast_dtype,
fp8_optimizations=fp8, scaled_fp8=model.model.model_config.scaled_fp8)
else:
# such as gguf
operations = model.model.model_config.custom_operations
if is_SDXL:
input_blocks = [[0, operations.Conv2d],
[1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
[2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
[3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
[4, comfy.ldm.modules.attention.SpatialTransformer],
[5, comfy.ldm.modules.attention.SpatialTransformer],
[6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
[7, comfy.ldm.modules.attention.SpatialTransformer],
[8, comfy.ldm.modules.attention.SpatialTransformer]]
middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
[1, comfy.ldm.modules.attention.SpatialTransformer],
[2, comfy.ldm.modules.attention.SpatialTransformer],
[2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
[3, comfy.ldm.modules.attention.SpatialTransformer],
[4, comfy.ldm.modules.attention.SpatialTransformer],
[5, comfy.ldm.modules.attention.SpatialTransformer],
[5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
[6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
[7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
[8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
else:
input_blocks = [[0, operations.Conv2d],
[1, comfy.ldm.modules.attention.SpatialTransformer],
[2, comfy.ldm.modules.attention.SpatialTransformer],
[3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
[4, comfy.ldm.modules.attention.SpatialTransformer],
[5, comfy.ldm.modules.attention.SpatialTransformer],
[6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
[7, comfy.ldm.modules.attention.SpatialTransformer],
[8, comfy.ldm.modules.attention.SpatialTransformer],
[9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
[10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
[11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
[1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
[2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
[2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
[3, comfy.ldm.modules.attention.SpatialTransformer],
[4, comfy.ldm.modules.attention.SpatialTransformer],
[5, comfy.ldm.modules.attention.SpatialTransformer],
[5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
[6, comfy.ldm.modules.attention.SpatialTransformer],
[7, comfy.ldm.modules.attention.SpatialTransformer],
[8, comfy.ldm.modules.attention.SpatialTransformer],
[8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
[9, comfy.ldm.modules.attention.SpatialTransformer],
[10, comfy.ldm.modules.attention.SpatialTransformer],
[11, comfy.ldm.modules.attention.SpatialTransformer]]
def last_layer_index(block, tp):
layer_list = []
for layer in block:
layer_list.append(type(layer))
layer_list.reverse()
if tp not in layer_list:
return -1, layer_list
return len(layer_list) - 1 - layer_list.index(tp), layer_list
def brushnet_forward(model, x, timesteps, transformer_options, control):
if 'brushnet' not in transformer_options['model_patch']:
input_samples = []
mid_sample = 0
output_samples = []
else:
# brushnet inference
input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)
# give additional samples to blocks
for i, tp in input_blocks:
idx, layer_list = last_layer_index(model.input_blocks[i], tp)
if idx < 0:
print("BrushNet can't find", tp, "layer in", i, "input block:", layer_list)
continue
model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0
idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
if idx < 0:
print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
else:
model.middle_block[idx].add_sample_after = mid_sample
for i, tp in output_blocks:
idx, layer_list = last_layer_index(model.output_blocks[i], tp)
if idx < 0:
print("BrushNet can't find", tp, "layer in", i, "outnput block:", layer_list)
continue
model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0
patch_model_function_wrapper(model, brushnet_forward)
to = add_model_patch_option(model)
mp = to['model_patch']
if 'brushnet' not in mp:
mp['brushnet'] = {}
bo = mp['brushnet']
bo['model'] = brushnet
bo['dtype'] = torch_dtype
bo['latents'] = conditioning_latents
bo['controls'] = controls
bo['prompt_embeds'] = prompt_embeds
bo['negative_prompt_embeds'] = negative_prompt_embeds
bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
bo['latent_id'] = 0
# patch layers `forward` so we can apply brushnet
def forward_patched_by_brushnet(self, x, *args, **kwargs):
h = self.original_forward(x, *args, **kwargs)
if hasattr(self, 'add_sample_after'):
to_add = self.add_sample_after
if torch.is_tensor(to_add):
# interpolate due to RAUNet
if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
h += to_add.to(h.dtype).to(h.device)
else:
h += self.add_sample_after
self.add_sample_after = 0
return h
for i, block in enumerate(model.model.diffusion_model.input_blocks):
for j, layer in enumerate(block):
if not hasattr(layer, 'original_forward'):
layer.original_forward = layer.forward
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
layer.add_sample_after = 0
for j, layer in enumerate(model.model.diffusion_model.middle_block):
if not hasattr(layer, 'original_forward'):
layer.original_forward = layer.forward
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
layer.add_sample_after = 0
for i, block in enumerate(model.model.diffusion_model.output_blocks):
for j, layer in enumerate(block):
if not hasattr(layer, 'original_forward'):
layer.original_forward = layer.forward
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
layer.add_sample_after = 0
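
The patch above works by giving every UNet block layer an `add_sample_after` residual that is added to the layer's output, with bicubic interpolation when RAUNet changes the spatial size. Below is a minimal, self-contained sketch of that forward-patching idiom; it is not part of the repository and all names are illustrative.

import types
import torch
import torch.nn as nn

def forward_with_residual(self, x, *args, **kwargs):
    # run the layer's original forward, then inject the stored residual
    h = self.original_forward(x, *args, **kwargs)
    extra = getattr(self, 'add_sample_after', 0)
    if torch.is_tensor(extra):
        # match spatial size before adding (mirrors the RAUNet interpolation above)
        if extra.shape[-2:] != h.shape[-2:]:
            extra = torch.nn.functional.interpolate(extra, size=h.shape[-2:], mode='bicubic')
        h = h + extra.to(h.dtype).to(h.device)
    else:
        h = h + extra
    self.add_sample_after = 0  # consume the residual so it is applied only once
    return h

layer = nn.Conv2d(4, 4, 3, padding=1)
layer.original_forward = layer.forward
layer.forward = types.MethodType(forward_with_residual, layer)
layer.add_sample_after = torch.zeros(1, 4, 8, 8)
out = layer(torch.randn(1, 4, 16, 16))  # residual is upscaled to 16x16 and added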

View File

@@ -0,0 +1,58 @@
{
"_class_name": "BrushNetModel",
"_diffusers_version": "0.27.0.dev0",
"_name_or_path": "runs/logs/brushnet_randommask/checkpoint-100000",
"act_fn": "silu",
"addition_embed_type": null,
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": null,
"attention_head_dim": 8,
"block_out_channels": [
320,
640,
1280,
1280
],
"brushnet_conditioning_channel_order": "rgb",
"class_embed_type": null,
"conditioning_channels": 5,
"conditioning_embedding_out_channels": [
16,
32,
96,
256
],
"cross_attention_dim": 768,
"down_block_types": [
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"encoder_hid_dim": null,
"encoder_hid_dim_type": null,
"flip_sin_to_cos": true,
"freq_shift": 0,
"global_pool_conditions": false,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "MidBlock2D",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": null,
"num_class_embeds": null,
"only_cross_attention": false,
"projection_class_embeddings_input_dim": null,
"resnet_time_scale_shift": "default",
"transformer_layers_per_block": 1,
"up_block_types": [
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D"
],
"upcast_attention": false,
"use_linear_projection": false
}

View File

@@ -0,0 +1,63 @@
{
"_class_name": "BrushNetModel",
"_diffusers_version": "0.27.0.dev0",
"_name_or_path": "runs/logs/brushnetsdxl_randommask/checkpoint-80000",
"act_fn": "silu",
"addition_embed_type": "text_time",
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": 256,
"attention_head_dim": [
5,
10,
20
],
"block_out_channels": [
320,
640,
1280
],
"brushnet_conditioning_channel_order": "rgb",
"class_embed_type": null,
"conditioning_channels": 5,
"conditioning_embedding_out_channels": [
16,
32,
96,
256
],
"cross_attention_dim": 2048,
"down_block_types": [
"DownBlock2D",
"DownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"encoder_hid_dim": null,
"encoder_hid_dim_type": null,
"flip_sin_to_cos": true,
"freq_shift": 0,
"global_pool_conditions": false,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "MidBlock2D",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": null,
"num_class_embeds": null,
"only_cross_attention": false,
"projection_class_embeddings_input_dim": 2816,
"resnet_time_scale_shift": "default",
"transformer_layers_per_block": [
1,
2,
10
],
"up_block_types": [
"UpBlock2D",
"UpBlock2D",
"UpBlock2D"
],
"upcast_attention": null,
"use_linear_projection": true
}

View File

@@ -0,0 +1,57 @@
{
"_class_name": "BrushNetModel",
"_diffusers_version": "0.27.2",
"act_fn": "silu",
"addition_embed_type": null,
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": null,
"attention_head_dim": 8,
"block_out_channels": [
320,
640,
1280,
1280
],
"brushnet_conditioning_channel_order": "rgb",
"class_embed_type": null,
"conditioning_channels": 5,
"conditioning_embedding_out_channels": [
16,
32,
96,
256
],
"cross_attention_dim": 768,
"down_block_types": [
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"encoder_hid_dim": null,
"encoder_hid_dim_type": null,
"flip_sin_to_cos": true,
"freq_shift": 0,
"global_pool_conditions": false,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": null,
"num_class_embeds": null,
"only_cross_attention": false,
"projection_class_embeddings_input_dim": null,
"resnet_time_scale_shift": "default",
"transformer_layers_per_block": 1,
"up_block_types": [
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D"
],
"upcast_attention": false,
"use_linear_projection": false
}
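
The three config.json files above all describe the same BrushNetModel class for different bases: the first two are the random-mask SD1.5 and SDXL checkpoints (cross_attention_dim 768 vs 2048, plain DownBlock2D stacks), while the third uses cross-attention down/up blocks. In every case conditioning_channels is 5, i.e. the 4-channel masked-image latent plus a 1-channel mask. A minimal sketch of reading one of these files, with a hypothetical path:

import json

# hypothetical path; the actual file ships next to the BrushNet weights
with open("brushnet/brushnet.json") as f:
    cfg = json.load(f)

latent_channels = cfg["in_channels"]                             # 4
mask_channels = cfg["conditioning_channels"] - latent_channels   # 1 (the inpainting mask)
print(cfg["_class_name"], cfg["cross_attention_dim"], latent_channels, mask_channels)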

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,137 @@
import torch
import comfy
# Check and add 'model_patch' to model.model_options['transformer_options']
def add_model_patch_option(model):
if 'transformer_options' not in model.model_options:
model.model_options['transformer_options'] = {}
to = model.model_options['transformer_options']
if "model_patch" not in to:
to["model_patch"] = {}
return to
# Patch model with model_function_wrapper
def patch_model_function_wrapper(model, forward_patch, remove=False):
def brushnet_model_function_wrapper(apply_model_method, options_dict):
to = options_dict['c']['transformer_options']
control = None
if 'control' in options_dict['c']:
control = options_dict['c']['control']
x = options_dict['input']
timestep = options_dict['timestep']
# check if there are patches to execute
if 'model_patch' not in to or 'forward' not in to['model_patch']:
return apply_model_method(x, timestep, **options_dict['c'])
mp = to['model_patch']
unet = mp['unet']
all_sigmas = mp['all_sigmas']
sigma = to['sigmas'][0].item()
total_steps = all_sigmas.shape[0] - 1
step = torch.argmin((all_sigmas - sigma).abs()).item()
mp['step'] = step
mp['total_steps'] = total_steps
# comfy.model_base.apply_model
xc = model.model.model_sampling.calculate_input(timestep, x)
if 'c_concat' in options_dict['c'] and options_dict['c']['c_concat'] is not None:
xc = torch.cat([xc] + [options_dict['c']['c_concat']], dim=1)
t = model.model.model_sampling.timestep(timestep).float()
# execute all patches
for method in mp['forward']:
method(unet, xc, t, to, control)
return apply_model_method(x, timestep, **options_dict['c'])
if "model_function_wrapper" in model.model_options and model.model_options["model_function_wrapper"]:
print('BrushNet is going to replace existing model_function_wrapper:',
model.model_options["model_function_wrapper"])
model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)
to = add_model_patch_option(model)
mp = to['model_patch']
if isinstance(model.model.model_config, comfy.supported_models.SD15):
mp['SDXL'] = False
elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
mp['SDXL'] = True
else:
print('Base model type: ', type(model.model.model_config))
raise Exception("Unsupported model type: ", type(model.model.model_config))
if 'forward' not in mp:
mp['forward'] = []
if remove:
if forward_patch in mp['forward']:
mp['forward'].remove(forward_patch)
else:
mp['forward'].append(forward_patch)
mp['unet'] = model.model.diffusion_model
mp['step'] = 0
mp['total_steps'] = 1
# apply patches to code
if comfy.samplers.sample.__doc__ is None or 'BrushNet' not in comfy.samplers.sample.__doc__:
comfy.samplers.original_sample = comfy.samplers.sample
comfy.samplers.sample = modified_sample
if comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__ is None or \
'BrushNet' not in comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__:
comfy.ldm.modules.diffusionmodules.openaimodel.original_apply_control = comfy.ldm.modules.diffusionmodules.openaimodel.apply_control
comfy.ldm.modules.diffusionmodules.openaimodel.apply_control = modified_apply_control
# Model needs current step number and cfg at inference step. It is possible to write a custom KSampler but I'd like to use ComfyUI's one.
# The first versions had modified_common_ksampler, but it broke custom KSampler nodes
def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={},
latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
''' Modified by BrushNet nodes'''
cfg_guider = comfy.samplers.CFGGuider(model)
cfg_guider.set_conds(positive, negative)
cfg_guider.set_cfg(cfg)
### Modified part ######################################################################
to = add_model_patch_option(model)
to['model_patch']['all_sigmas'] = sigmas
#######################################################################################
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
# To use Controlnet with RAUNet it is much easier to modify apply_control a little
def modified_apply_control(h, control, name):
'''Modified by BrushNet nodes'''
if control is not None and name in control and len(control[name]) > 0:
ctrl = control[name].pop()
if ctrl is not None:
if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
ctrl = torch.nn.functional.interpolate(ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic').to(
h.dtype).to(h.device)
try:
h += ctrl
except Exception:
print("BrushNet warning: control could not be applied", h.shape, ctrl.shape)
return h
def add_model_patch(model):
to = add_model_patch_option(model)
mp = to['model_patch']
if "brushnet" in mp:
if isinstance(model.model.model_config, comfy.supported_models.SD15):
mp['SDXL'] = False
elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
mp['SDXL'] = True
else:
print('Base model type: ', type(model.model.model_config))
raise Exception("Unsupported model type: ", type(model.model.model_config))
mp['unet'] = model.model.diffusion_model
mp['step'] = 0
mp['total_steps'] = 1
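
brushnet_model_function_wrapper above needs to know which sampling step it is on; the patched comfy.samplers.sample stores the full sigma schedule, and the wrapper recovers the step index by finding the closest sigma. A standalone illustration of that lookup (the schedule values here are made up):

import torch

all_sigmas = torch.tensor([14.6, 9.1, 5.4, 2.9, 1.2, 0.0])  # hypothetical schedule
sigma = 5.4                                                   # sigma seen at the current call
total_steps = all_sigmas.shape[0] - 1
step = torch.argmin((all_sigmas - sigma).abs()).item()
print(f"step {step} of {total_steps}")  # -> step 2 of 5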

View File

@@ -0,0 +1,467 @@
import copy
import random
import torch
import torch.nn as nn
from transformers import CLIPTokenizer
from typing import Any, List, Optional, Union
class TokenizerWrapper:
"""Tokenizer wrapper for CLIPTokenizer. Only support CLIPTokenizer
currently. This wrapper is modified from https://github.com/huggingface/dif
fusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.
py#L358 # noqa.
Args:
from_pretrained (Union[str, os.PathLike], optional): The *model id*
of a pretrained model or a path to a *directory* containing
model weights and config. Defaults to None.
from_config (Union[str, os.PathLike], optional): The *model id*
of a pretrained model or a path to a *directory* containing
model weights and config. Defaults to None.
*args, **kwargs: If `from_pretrained` is passed, *args and **kwargs
will be passed to `from_pretrained` function. Otherwise, *args
and **kwargs will be used to initialize the model by
`self._module_cls(*args, **kwargs)`.
"""
def __init__(self, tokenizer: CLIPTokenizer):
self.wrapped = tokenizer
self.token_map = {}
def __getattr__(self, name: str) -> Any:
if name in self.__dict__:
return getattr(self, name)
# if name == "wrapped":
# return getattr(self, 'wrapped')#super().__getattr__("wrapped")
try:
return getattr(self.wrapped, name)
except AttributeError:
raise AttributeError(
"'name' cannot be found in both "
f"'{self.__class__.__name__}' and "
f"'{self.__class__.__name__}.tokenizer'."
)
def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
"""Attempt to add tokens to the tokenizer.
Args:
tokens (Union[str, List[str]]): The tokens to be added.
"""
num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
assert num_added_tokens != 0, (
f"The tokenizer already contains the token {tokens}. Please pass "
"a different `placeholder_token` that is not already in the "
"tokenizer."
)
def get_token_info(self, token: str) -> dict:
"""Get the information of a token, including its start and end index in
the current tokenizer.
Args:
token (str): The token to be queried.
Returns:
dict: The information of the token, including its start and end
index in current tokenizer.
"""
token_ids = self.__call__(token).input_ids
start, end = token_ids[1], token_ids[-2] + 1
return {"name": token, "start": start, "end": end}
def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
"""Add placeholder tokens to the tokenizer.
Args:
placeholder_token (str): The placeholder token to be added.
num_vec_per_token (int, optional): The number of vectors of
the added placeholder token.
*args, **kwargs: The arguments for `self.wrapped.add_tokens`.
"""
output = []
if num_vec_per_token == 1:
self.try_adding_tokens(placeholder_token, *args, **kwargs)
output.append(placeholder_token)
else:
output = []
for i in range(num_vec_per_token):
ith_token = placeholder_token + f"_{i}"
self.try_adding_tokens(ith_token, *args, **kwargs)
output.append(ith_token)
for token in self.token_map:
if token in placeholder_token:
raise ValueError(
f"The tokenizer already has placeholder token {token} "
f"that can get confused with {placeholder_token} "
"keep placeholder tokens independent"
)
self.token_map[placeholder_token] = output
def replace_placeholder_tokens_in_text(
self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
) -> Union[str, List[str]]:
"""Replace the keywords in text with placeholder tokens. This function
will be called in `self.__call__` and `self.encode`.
Args:
text (Union[str, List[str]]): The text to be processed.
vector_shuffle (bool, optional): Whether to shuffle the vectors.
Defaults to False.
prop_tokens_to_load (float, optional): The proportion of tokens to
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
Returns:
Union[str, List[str]]: The processed text.
"""
if isinstance(text, list):
output = []
for i in range(len(text)):
output.append(self.replace_placeholder_tokens_in_text(text[i], vector_shuffle=vector_shuffle))
return output
for placeholder_token in self.token_map:
if placeholder_token in text:
tokens = self.token_map[placeholder_token]
tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
if vector_shuffle:
tokens = copy.copy(tokens)
random.shuffle(tokens)
text = text.replace(placeholder_token, " ".join(tokens))
return text
def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
"""Replace the placeholder tokens in text with the original keywords.
This function will be called in `self.decode`.
Args:
text (Union[str, List[str]]): The text to be processed.
Returns:
Union[str, List[str]]: The processed text.
"""
if isinstance(text, list):
output = []
for i in range(len(text)):
output.append(self.replace_text_with_placeholder_tokens(text[i]))
return output
for placeholder_token, tokens in self.token_map.items():
merged_tokens = " ".join(tokens)
if merged_tokens in text:
text = text.replace(merged_tokens, placeholder_token)
return text
def __call__(
self,
text: Union[str, List[str]],
*args,
vector_shuffle: bool = False,
prop_tokens_to_load: float = 1.0,
**kwargs,
):
"""The call function of the wrapper.
Args:
text (Union[str, List[str]]): The text to be tokenized.
vector_shuffle (bool, optional): Whether to shuffle the vectors.
Defaults to False.
prop_tokens_to_load (float, optional): The proportion of tokens to
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
*args, **kwargs: The arguments for `self.wrapped.__call__`.
"""
replaced_text = self.replace_placeholder_tokens_in_text(
text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
)
return self.wrapped.__call__(replaced_text, *args, **kwargs)
def encode(self, text: Union[str, List[str]], *args, **kwargs):
"""Encode the passed text to token index.
Args:
text (Union[str, List[str]]): The text to be encoded.
*args, **kwargs: The arguments for `self.wrapped.__call__`.
"""
replaced_text = self.replace_placeholder_tokens_in_text(text)
return self.wrapped(replaced_text, *args, **kwargs)
def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
"""Decode the token index to text.
Args:
token_ids: The token index to be decoded.
return_raw: Whether to keep the placeholder tokens in the text.
Defaults to False.
*args, **kwargs: The arguments for `self.wrapped.decode`.
Returns:
Union[str, List[str]]: The decoded text.
"""
text = self.wrapped.decode(token_ids, *args, **kwargs)
if return_raw:
return text
replaced_text = self.replace_text_with_placeholder_tokens(text)
return replaced_text
def __repr__(self):
"""The representation of the wrapper."""
s = super().__repr__()
prefix = f"Wrapped Module Class: {self._module_cls}\n"
prefix += f"Wrapped Module Name: {self._module_name}\n"
if self._from_pretrained:
prefix += f"From Pretrained: {self._from_pretrained}\n"
s = prefix + s
return s
class EmbeddingLayerWithFixes(nn.Module):
"""The revised embedding layer to support external embeddings. This design
of this class is inspired by https://github.com/AUTOMATIC1111/stable-
diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
jack.py#L224 # noqa.
Args:
wrapped (nn.Emebdding): The embedding layer to be wrapped.
external_embeddings (Union[dict, List[dict]], optional): The external
embeddings added to this layer. Defaults to None.
"""
def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
super().__init__()
self.wrapped = wrapped
self.num_embeddings = wrapped.weight.shape[0]
self.external_embeddings = []
if external_embeddings:
self.add_embeddings(external_embeddings)
self.trainable_embeddings = nn.ParameterDict()
@property
def weight(self):
"""Get the weight of wrapped embedding layer."""
return self.wrapped.weight
def check_duplicate_names(self, embeddings: List[dict]):
"""Check whether duplicate names exist in list of 'external
embeddings'.
Args:
embeddings (List[dict]): A list of embeddings to be checked.
"""
names = [emb["name"] for emb in embeddings]
assert len(names) == len(set(names)), (
"Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
)
def check_ids_overlap(self, embeddings):
"""Check whether overlap exist in token ids of 'external_embeddings'.
Args:
embeddings (List[dict]): A list of embeddings to be checked.
"""
ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
ids_range.sort() # sort by 'start'
# check if 'end' has overlapping
for idx in range(len(ids_range) - 1):
name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
assert ids_range[idx][1] <= ids_range[idx + 1][0], (
f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
)
def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
"""Add external embeddings to this layer.
Use case:
Args:
embeddings (Union[dict, list[dict]]): The external embeddings to
be added. Each dict must contain the following 4 fields: 'name'
(the name of this embedding), 'embedding' (the embedding
tensor), 'start' (the start token id of this embedding), 'end'
(the end token id of this embedding). For example:
`{name: NAME, start: START, end: END, embedding: torch.Tensor}`
"""
if isinstance(embeddings, dict):
embeddings = [embeddings]
self.external_embeddings += embeddings
self.check_duplicate_names(self.external_embeddings)
self.check_ids_overlap(self.external_embeddings)
# set for trainable
added_trainable_emb_info = []
for embedding in embeddings:
trainable = embedding.get("trainable", False)
if trainable:
name = embedding["name"]
embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
self.trainable_embeddings[name] = embedding["embedding"]
added_trainable_emb_info.append(name)
added_emb_info = [emb["name"] for emb in embeddings]
added_emb_info = ", ".join(added_emb_info)
print(f"Successfully add external embeddings: {added_emb_info}.", "current")
if added_trainable_emb_info:
added_trainable_emb_info = ", ".join(added_trainable_emb_info)
print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}", "current")
def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Replace external input ids to 0.
Args:
input_ids (torch.Tensor): The input ids to be replaced.
Returns:
torch.Tensor: The replaced input ids.
"""
input_ids_fwd = input_ids.clone()
input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
return input_ids_fwd
def replace_embeddings(
self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
) -> torch.Tensor:
"""Replace external embedding to the embedding layer. Noted that, in
this function we use `torch.cat` to avoid inplace modification.
Args:
input_ids (torch.Tensor): The original token ids. Shape like
[LENGTH, ].
embedding (torch.Tensor): The embedding of token ids after
`replace_input_ids` function.
external_embedding (dict): The external embedding to be replaced.
Returns:
torch.Tensor: The replaced embedding.
"""
new_embedding = []
name = external_embedding["name"]
start = external_embedding["start"]
end = external_embedding["end"]
target_ids_to_replace = [i for i in range(start, end)]
ext_emb = external_embedding["embedding"].to(embedding.device)
# do not need to replace
if not (input_ids == start).any():
return embedding
# start replace
s_idx, e_idx = 0, 0
while e_idx < len(input_ids):
if input_ids[e_idx] == start:
if e_idx != 0:
# append the span of embeddings that does not need replacement
new_embedding.append(embedding[s_idx:e_idx])
# check if the next embedding need to replace is valid
actually_ids_to_replace = [int(i) for i in input_ids[e_idx: e_idx + end - start]]
assert actually_ids_to_replace == target_ids_to_replace, (
f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
f"Expect '{target_ids_to_replace}' for embedding "
f"'{name}' but found '{actually_ids_to_replace}'."
)
new_embedding.append(ext_emb)
s_idx = e_idx + end - start
e_idx = s_idx + 1
else:
e_idx += 1
if e_idx == len(input_ids):
new_embedding.append(embedding[s_idx:e_idx])
return torch.cat(new_embedding, dim=0)
def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None, out_dtype = None):
"""The forward function.
Args:
input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
[LENGTH, ].
external_embeddings (Optional[List[dict]]): The external
embeddings. If not passed, only `self.external_embeddings`
will be used. Defaults to None.
input_ids: shape like [bz, LENGTH] or [LENGTH].
"""
assert input_ids.ndim in [1, 2]
if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)
if external_embeddings is None and not self.external_embeddings:
return self.wrapped(input_ids, out_dtype=out_dtype)
input_ids_fwd = self.replace_input_ids(input_ids)
inputs_embeds = self.wrapped(input_ids_fwd)
vecs = []
if external_embeddings is None:
external_embeddings = []
elif isinstance(external_embeddings, dict):
external_embeddings = [external_embeddings]
embeddings = self.external_embeddings + external_embeddings
for input_id, embedding in zip(input_ids, inputs_embeds):
new_embedding = embedding
for external_embedding in embeddings:
new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
vecs.append(new_embedding)
return torch.stack(vecs).to(out_dtype)
def add_tokens(
tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: list = None,
num_vectors_per_token: int = 1
):
"""Add token for training.
# TODO: support add tokens as dict, then we can load pretrained tokens.
"""
if initialize_tokens is not None:
assert len(initialize_tokens) == len(
placeholder_tokens
), "placeholder_token should be the same length as initialize_token"
for ii in range(len(placeholder_tokens)):
tokenizer.add_placeholder_token(placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token)
# text_encoder.set_embedding_layer()
embedding_layer = text_encoder.text_model.embeddings.token_embedding
text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(embedding_layer)
embedding_layer = text_encoder.text_model.embeddings.token_embedding
assert embedding_layer is not None, (
"Do not support get embedding layer for current text encoder. " "Please check your configuration."
)
initialize_embedding = []
if initialize_tokens is not None:
for ii in range(len(placeholder_tokens)):
init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
temp_embedding = embedding_layer.weight[init_id]
initialize_embedding.append(temp_embedding[None, ...].repeat(num_vectors_per_token, 1))
else:
for ii in range(len(placeholder_tokens)):
init_id = tokenizer("a").input_ids[1]
temp_embedding = embedding_layer.weight[init_id]
len_emb = temp_embedding.shape[0]
init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
initialize_embedding.append(init_weight)
# initialize_embedding = torch.cat(initialize_embedding,dim=0)
token_info_all = []
for ii in range(len(placeholder_tokens)):
token_info = tokenizer.get_token_info(placeholder_tokens[ii])
token_info["embedding"] = initialize_embedding[ii]
token_info["trainable"] = True
token_info_all.append(token_info)
embedding_layer.add_embeddings(token_info_all)
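
A hedged usage sketch for the wrapper and add_tokens defined above; it is not part of the repository. It wraps the standard CLIP-L tokenizer/text encoder (the one SD1.5 uses) and registers a trainable multi-vector placeholder token; adjust the checkpoint name to whatever is actually loaded.

from transformers import CLIPTokenizer, CLIPTextModel

# TokenizerWrapper and add_tokens are the ones defined in this module
tokenizer = TokenizerWrapper(CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"))
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# "<obj>" expands to <obj>_0 ... <obj>_3; the four vectors are initialised from "object"
add_tokens(tokenizer, text_encoder, ["<obj>"], initialize_tokens=["object"], num_vectors_per_token=4)

ids = tokenizer("a photo of <obj>").input_ids  # placeholder is expanded before tokenizing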

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,2 @@
#credit to city96 for this module
#from https://github.com/city96/ComfyUI_ExtraModels/

View File

@@ -0,0 +1,120 @@
"""
List of all DiT model types / settings
"""
sampling_settings = {
"beta_schedule" : "sqrt_linear",
"linear_start" : 0.0001,
"linear_end" : 0.02,
"timesteps" : 1000,
}
dit_conf = {
"XL/2": { # DiT_XL_2
"unet_config": {
"depth" : 28,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1152,
},
"sampling_settings" : sampling_settings,
},
"XL/4": { # DiT_XL_4
"unet_config": {
"depth" : 28,
"num_heads" : 16,
"patch_size" : 4,
"hidden_size" : 1152,
},
"sampling_settings" : sampling_settings,
},
"XL/8": { # DiT_XL_8
"unet_config": {
"depth" : 28,
"num_heads" : 16,
"patch_size" : 8,
"hidden_size" : 1152,
},
"sampling_settings" : sampling_settings,
},
"L/2": { # DiT_L_2
"unet_config": {
"depth" : 24,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1024,
},
"sampling_settings" : sampling_settings,
},
"L/4": { # DiT_L_4
"unet_config": {
"depth" : 24,
"num_heads" : 16,
"patch_size" : 4,
"hidden_size" : 1024,
},
"sampling_settings" : sampling_settings,
},
"L/8": { # DiT_L_8
"unet_config": {
"depth" : 24,
"num_heads" : 16,
"patch_size" : 8,
"hidden_size" : 1024,
},
"sampling_settings" : sampling_settings,
},
"B/2": { # DiT_B_2
"unet_config": {
"depth" : 12,
"num_heads" : 12,
"patch_size" : 2,
"hidden_size" : 768,
},
"sampling_settings" : sampling_settings,
},
"B/4": { # DiT_B_4
"unet_config": {
"depth" : 12,
"num_heads" : 12,
"patch_size" : 4,
"hidden_size" : 768,
},
"sampling_settings" : sampling_settings,
},
"B/8": { # DiT_B_8
"unet_config": {
"depth" : 12,
"num_heads" : 12,
"patch_size" : 8,
"hidden_size" : 768,
},
"sampling_settings" : sampling_settings,
},
"S/2": { # DiT_S_2
"unet_config": {
"depth" : 12,
"num_heads" : 6,
"patch_size" : 2,
"hidden_size" : 384,
},
"sampling_settings" : sampling_settings,
},
"S/4": { # DiT_S_4
"unet_config": {
"depth" : 12,
"num_heads" : 6,
"patch_size" : 4,
"hidden_size" : 384,
},
"sampling_settings" : sampling_settings,
},
"S/8": { # DiT_S_8
"unet_config": {
"depth" : 12,
"num_heads" : 6,
"patch_size" : 8,
"hidden_size" : 384,
},
"sampling_settings" : sampling_settings,
},
}
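
A brief sketch (not part of the module) of how a loader might consume one of these presets:

conf = dit_conf["XL/2"]
unet_config = conf["unet_config"]       # depth 28, 16 heads, patch size 2, hidden size 1152
sampling = conf["sampling_settings"]    # shared sqrt_linear schedule, 1000 timesteps
head_dim = unet_config["hidden_size"] // unet_config["num_heads"]
print(head_dim, sampling["timesteps"])  # 72, 1000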

View File

@@ -0,0 +1,661 @@
GNU AFFERO GENERAL PUBLIC LICENSE
Version 3, 19 November 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU Affero General Public License is a free, copyleft license for
software and other kinds of works, specifically designed to ensure
cooperation with the community in the case of network server software.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
our General Public Licenses are intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
Developers that use our General Public Licenses protect your rights
with two steps: (1) assert copyright on the software, and (2) offer
you this License which gives you legal permission to copy, distribute
and/or modify the software.
A secondary benefit of defending all users' freedom is that
improvements made in alternate versions of the program, if they
receive widespread use, become available for other developers to
incorporate. Many developers of free software are heartened and
encouraged by the resulting cooperation. However, in the case of
software used on network servers, this result may fail to come about.
The GNU General Public License permits making a modified version and
letting the public access it on a server without ever releasing its
source code to the public.
The GNU Affero General Public License is designed specifically to
ensure that, in such cases, the modified source code becomes available
to the community. It requires the operator of a network server to
provide the source code of the modified version running there to the
users of that server. Therefore, public use of a modified version, on
a publicly accessible server, gives the public access to the source
code of the modified version.
An older license, called the Affero General Public License and
published by Affero, was designed to accomplish similar goals. This is
a different license, not a version of the Affero GPL, but Affero has
released a new version of the Affero GPL which permits relicensing under
this license.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU Affero General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Remote Network Interaction; Use with the GNU General Public License.
Notwithstanding any other provision of this License, if you modify the
Program, your modified version must prominently offer all users
interacting with it remotely through a computer network (if your version
supports such interaction) an opportunity to receive the Corresponding
Source of your version by providing access to the Corresponding Source
from a network server at no charge, through some standard or customary
means of facilitating copying of software. This Corresponding Source
shall include the Corresponding Source for any work covered by version 3
of the GNU General Public License that is incorporated pursuant to the
following paragraph.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the work with which it is combined will remain governed by version
3 of the GNU General Public License.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU Affero General Public License from time to time. Such new versions
will be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU Affero General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU Affero General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU Affero General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If your software can interact with users remotely through a computer
network, you should also make sure that it provides a way for users to
get its source. For example, if your program is a web application, its
interface could display a "Source" link that leads users to an archive
of the code. There are many ways you could offer source, and different
solutions will be better for different programs; see section 13 for the
specific requirements.
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.
@@ -0,0 +1,139 @@
"""
List of all PixArt model types / settings
"""
sampling_settings = {
"beta_schedule" : "sqrt_linear",
"linear_start" : 0.0001,
"linear_end" : 0.02,
"timesteps" : 1000,
}
pixart_conf = {
"PixArtMS_XL_2": { # models/PixArtMS
"target": "PixArtMS",
"unet_config": {
"input_size" : 1024//8,
"depth" : 28,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1152,
"pe_interpolation": 2,
},
"sampling_settings" : sampling_settings,
},
"PixArtMS_Sigma_XL_2": {
"target": "PixArtMSSigma",
"unet_config": {
"input_size" : 1024//8,
"token_num" : 300,
"depth" : 28,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1152,
"micro_condition": False,
"pe_interpolation": 2,
"model_max_length": 300,
},
"sampling_settings" : sampling_settings,
},
"PixArtMS_Sigma_XL_2_900M": {
"target": "PixArtMSSigma",
"unet_config": {
"input_size": 1024 // 8,
"token_num": 300,
"depth": 42,
"num_heads": 16,
"patch_size": 2,
"hidden_size": 1152,
"micro_condition": False,
"pe_interpolation": 2,
"model_max_length": 300,
},
"sampling_settings": sampling_settings,
},
"PixArtMS_Sigma_XL_2_2K": {
"target": "PixArtMSSigma",
"unet_config": {
"input_size" : 2048//8,
"token_num" : 300,
"depth" : 28,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1152,
"micro_condition": False,
"pe_interpolation": 4,
"model_max_length": 300,
},
"sampling_settings" : sampling_settings,
},
"PixArt_XL_2": { # models/PixArt
"target": "PixArt",
"unet_config": {
"input_size" : 512//8,
"token_num" : 120,
"depth" : 28,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1152,
"pe_interpolation": 1,
},
"sampling_settings" : sampling_settings,
},
}
pixart_conf.update({ # controlnet models
"ControlPixArtHalf": {
"target": "ControlPixArtHalf",
"unet_config": pixart_conf["PixArt_XL_2"]["unet_config"],
"sampling_settings": pixart_conf["PixArt_XL_2"]["sampling_settings"],
},
"ControlPixArtMSHalf": {
"target": "ControlPixArtMSHalf",
"unet_config": pixart_conf["PixArtMS_XL_2"]["unet_config"],
"sampling_settings": pixart_conf["PixArtMS_XL_2"]["sampling_settings"],
}
})
pixart_res = {
"PixArtMS_XL_2": { # models/PixArtMS 1024x1024
'0.25': [512, 2048], '0.26': [512, 1984], '0.27': [512, 1920], '0.28': [512, 1856],
'0.32': [576, 1792], '0.33': [576, 1728], '0.35': [576, 1664], '0.40': [640, 1600],
'0.42': [640, 1536], '0.48': [704, 1472], '0.50': [704, 1408], '0.52': [704, 1344],
'0.57': [768, 1344], '0.60': [768, 1280], '0.68': [832, 1216], '0.72': [832, 1152],
'0.78': [896, 1152], '0.82': [896, 1088], '0.88': [960, 1088], '0.94': [960, 1024],
'1.00': [1024,1024], '1.07': [1024, 960], '1.13': [1088, 960], '1.21': [1088, 896],
'1.29': [1152, 896], '1.38': [1152, 832], '1.46': [1216, 832], '1.67': [1280, 768],
'1.75': [1344, 768], '2.00': [1408, 704], '2.09': [1472, 704], '2.40': [1536, 640],
'2.50': [1600, 640], '2.89': [1664, 576], '3.00': [1728, 576], '3.11': [1792, 576],
'3.62': [1856, 512], '3.75': [1920, 512], '3.88': [1984, 512], '4.00': [2048, 512],
},
"PixArt_XL_2": { # models/PixArt 512x512
'0.25': [256,1024], '0.26': [256, 992], '0.27': [256, 960], '0.28': [256, 928],
'0.32': [288, 896], '0.33': [288, 864], '0.35': [288, 832], '0.40': [320, 800],
'0.42': [320, 768], '0.48': [352, 736], '0.50': [352, 704], '0.52': [352, 672],
'0.57': [384, 672], '0.60': [384, 640], '0.68': [416, 608], '0.72': [416, 576],
'0.78': [448, 576], '0.82': [448, 544], '0.88': [480, 544], '0.94': [480, 512],
'1.00': [512, 512], '1.07': [512, 480], '1.13': [544, 480], '1.21': [544, 448],
'1.29': [576, 448], '1.38': [576, 416], '1.46': [608, 416], '1.67': [640, 384],
'1.75': [672, 384], '2.00': [704, 352], '2.09': [736, 352], '2.40': [768, 320],
'2.50': [800, 320], '2.89': [832, 288], '3.00': [864, 288], '3.11': [896, 288],
'3.62': [928, 256], '3.75': [960, 256], '3.88': [992, 256], '4.00': [1024,256]
},
"PixArtMS_Sigma_XL_2_2K": {
'0.25': [1024, 4096], '0.26': [1024, 3968], '0.27': [1024, 3840], '0.28': [1024, 3712],
'0.32': [1152, 3584], '0.33': [1152, 3456], '0.35': [1152, 3328], '0.40': [1280, 3200],
'0.42': [1280, 3072], '0.48': [1408, 2944], '0.50': [1408, 2816], '0.52': [1408, 2688],
'0.57': [1536, 2688], '0.60': [1536, 2560], '0.68': [1664, 2432], '0.72': [1664, 2304],
'0.78': [1792, 2304], '0.82': [1792, 2176], '0.88': [1920, 2176], '0.94': [1920, 2048],
'1.00': [2048, 2048], '1.07': [2048, 1920], '1.13': [2176, 1920], '1.21': [2176, 1792],
'1.29': [2304, 1792], '1.38': [2304, 1664], '1.46': [2432, 1664], '1.67': [2560, 1536],
'1.75': [2688, 1536], '2.00': [2816, 1408], '2.09': [2944, 1408], '2.40': [3072, 1280],
'2.50': [3200, 1280], '2.89': [3328, 1152], '3.00': [3456, 1152], '3.11': [3584, 1152],
'3.62': [3712, 1024], '3.75': [3840, 1024], '3.88': [3968, 1024], '4.00': [4096, 1024]
}
}
# These should be the same
pixart_res.update({
"PixArtMS_Sigma_XL_2": pixart_res["PixArtMS_XL_2"],
"PixArtMS_Sigma_XL_2_512": pixart_res["PixArt_XL_2"],
})
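As a minimal sketch (not part of the committed files), a caller could snap an arbitrary size to the nearest bucket in pixart_res like this; each entry maps an aspect-ratio key to a resolution pair whose first/second ratio equals the key, and the helper name closest_bucket is hypothetical:
def closest_bucket(model_name, height, width):
    """Return the bucket from pixart_res whose aspect-ratio key is nearest to height/width."""
    buckets = pixart_res[model_name]
    key = min(buckets, key=lambda k: abs(float(k) - height / width))
    return buckets[key]

# closest_bucket("PixArtMS_XL_2", 1080, 1920) -> [768, 1344]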
@@ -0,0 +1,216 @@
# For using the diffusers format weights
# Based on the original ComfyUI function +
# https://github.com/PixArt-alpha/PixArt-alpha/blob/master/tools/convert_pixart_alpha_to_diffusers.py
import torch
conversion_map_ms = [ # for multi_scale_train (MS)
# Resolution
("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
# Aspect ratio
("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
]
def get_depth(state_dict):
return sum(key.endswith('.attn1.to_k.bias') for key in state_dict.keys())
def get_lora_depth(state_dict):
return sum(key.endswith('.attn1.to_k.lora_A.weight') for key in state_dict.keys())
def get_conversion_map(state_dict):
conversion_map = [ # main SD conversion map (PixArt reference, HF Diffusers)
# Patch embeddings
("x_embedder.proj.weight", "pos_embed.proj.weight"),
("x_embedder.proj.bias", "pos_embed.proj.bias"),
# Caption projection
("y_embedder.y_embedding", "caption_projection.y_embedding"),
("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
# AdaLN-single LN
("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
# Shared norm
("t_block.1.weight", "adaln_single.linear.weight"),
("t_block.1.bias", "adaln_single.linear.bias"),
# Final block
("final_layer.linear.weight", "proj_out.weight"),
("final_layer.linear.bias", "proj_out.bias"),
("final_layer.scale_shift_table", "scale_shift_table"),
]
# Add actual transformer blocks
for depth in range(get_depth(state_dict)):
# Transformer blocks
conversion_map += [
(f"blocks.{depth}.scale_shift_table", f"transformer_blocks.{depth}.scale_shift_table"),
# Projection
(f"blocks.{depth}.attn.proj.weight", f"transformer_blocks.{depth}.attn1.to_out.0.weight"),
(f"blocks.{depth}.attn.proj.bias", f"transformer_blocks.{depth}.attn1.to_out.0.bias"),
# Feed-forward
(f"blocks.{depth}.mlp.fc1.weight", f"transformer_blocks.{depth}.ff.net.0.proj.weight"),
(f"blocks.{depth}.mlp.fc1.bias", f"transformer_blocks.{depth}.ff.net.0.proj.bias"),
(f"blocks.{depth}.mlp.fc2.weight", f"transformer_blocks.{depth}.ff.net.2.weight"),
(f"blocks.{depth}.mlp.fc2.bias", f"transformer_blocks.{depth}.ff.net.2.bias"),
# Cross-attention (proj)
(f"blocks.{depth}.cross_attn.proj.weight", f"transformer_blocks.{depth}.attn2.to_out.0.weight"),
(f"blocks.{depth}.cross_attn.proj.bias", f"transformer_blocks.{depth}.attn2.to_out.0.bias"),
]
return conversion_map
def find_prefix(state_dict, target_key):
prefix = ""
for k in state_dict.keys():
if k.endswith(target_key):
prefix = k.split(target_key)[0]
break
return prefix
def convert_state_dict(state_dict):
if "adaln_single.emb.resolution_embedder.linear_1.weight" in state_dict.keys():
cmap = get_conversion_map(state_dict) + conversion_map_ms
else:
cmap = get_conversion_map(state_dict)
missing = [k for k, v in cmap if v not in state_dict]
new_state_dict = {k: state_dict[v] for k, v in cmap if k not in missing}
matched = list(v for k, v in cmap if v in state_dict.keys())
for depth in range(get_depth(state_dict)):
for wb in ["weight", "bias"]:
# Self Attention
key = lambda a: f"transformer_blocks.{depth}.attn1.to_{a}.{wb}"
new_state_dict[f"blocks.{depth}.attn.qkv.{wb}"] = torch.cat((
state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
), dim=0)
matched += [key('q'), key('k'), key('v')]
# Cross-attention (linear)
key = lambda a: f"transformer_blocks.{depth}.attn2.to_{a}.{wb}"
new_state_dict[f"blocks.{depth}.cross_attn.q_linear.{wb}"] = state_dict[key('q')]
new_state_dict[f"blocks.{depth}.cross_attn.kv_linear.{wb}"] = torch.cat((
state_dict[key('k')], state_dict[key('v')]
), dim=0)
matched += [key('q'), key('k'), key('v')]
if len(matched) < len(state_dict):
print(f"PixArt: UNET conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
print(list(set(state_dict.keys()) - set(matched)))
if len(missing) > 0:
print(f"PixArt: UNET conversion has missing keys!")
print(missing)
return new_state_dict
# Same as above but for LoRA weights:
def convert_lora_state_dict(state_dict, peft=True):
# kohya
rep_ak = lambda x: x.replace(".weight", ".lora_down.weight")
rep_bk = lambda x: x.replace(".weight", ".lora_up.weight")
rep_pk = lambda x: x.replace(".weight", ".alpha")
if peft: # peft
rep_ap = lambda x: x.replace(".weight", ".lora_A.weight")
rep_bp = lambda x: x.replace(".weight", ".lora_B.weight")
rep_pp = lambda x: x.replace(".weight", ".alpha")
prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight")
state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
else: # OneTrainer
rep_ap = lambda x: x.replace(".", "_")[:-7] + ".lora_down.weight"
rep_bp = lambda x: x.replace(".", "_")[:-7] + ".lora_up.weight"
rep_pp = lambda x: x.replace(".", "_")[:-7] + ".alpha"
prefix = "lora_transformer_"
t5_marker = "lora_te_encoder"
t5_keys = []
for key in list(state_dict.keys()):
if key.startswith(prefix):
state_dict[key[len(prefix):]] = state_dict.pop(key)
elif t5_marker in key:
t5_keys.append(state_dict.pop(key))
if len(t5_keys) > 0:
print(f"Text Encoder not supported for PixArt LoRA, ignoring {len(t5_keys)} keys")
cmap = []
cmap_unet = get_conversion_map(state_dict) + conversion_map_ms # todo: 512 model
for k, v in cmap_unet:
if v.endswith(".weight"):
cmap.append((rep_ak(k), rep_ap(v)))
cmap.append((rep_bk(k), rep_bp(v)))
if not peft:
cmap.append((rep_pk(k), rep_pp(v)))
missing = [k for k, v in cmap if v not in state_dict]
new_state_dict = {k: state_dict[v] for k, v in cmap if k not in missing}
matched = list(v for k, v in cmap if v in state_dict.keys())
lora_depth = get_lora_depth(state_dict)
for fp, fk in ((rep_ap, rep_ak), (rep_bp, rep_bk)):
for depth in range(lora_depth):
# Self Attention
key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
new_state_dict[fk(f"blocks.{depth}.attn.qkv.weight")] = torch.cat((
state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
), dim=0)
matched += [key('q'), key('k'), key('v')]
if not peft:
akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
new_state_dict[rep_pk((f"blocks.{depth}.attn.qkv.weight"))] = state_dict[akey("q")]
matched += [akey('q'), akey('k'), akey('v')]
# Self Attention projection?
key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
new_state_dict[fk(f"blocks.{depth}.attn.proj.weight")] = state_dict[key('out.0')]
matched += [key('out.0')]
# Cross-attention (linear)
key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
new_state_dict[fk(f"blocks.{depth}.cross_attn.q_linear.weight")] = state_dict[key('q')]
new_state_dict[fk(f"blocks.{depth}.cross_attn.kv_linear.weight")] = torch.cat((
state_dict[key('k')], state_dict[key('v')]
), dim=0)
matched += [key('q'), key('k'), key('v')]
if not peft:
akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.q_linear.weight"))] = state_dict[akey("q")]
new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.kv_linear.weight"))] = state_dict[akey("k")]
matched += [akey('q'), akey('k'), akey('v')]
# Cross Attention projection?
key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
new_state_dict[fk(f"blocks.{depth}.cross_attn.proj.weight")] = state_dict[key('out.0')]
matched += [key('out.0')]
key = fp(f"transformer_blocks.{depth}.ff.net.0.proj.weight")
new_state_dict[fk(f"blocks.{depth}.mlp.fc1.weight")] = state_dict[key]
matched += [key]
key = fp(f"transformer_blocks.{depth}.ff.net.2.weight")
new_state_dict[fk(f"blocks.{depth}.mlp.fc2.weight")] = state_dict[key]
matched += [key]
if len(matched) < len(state_dict):
print(f"PixArt: LoRA conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
print(list(set(state_dict.keys()) - set(matched)))
if len(missing) > 0:
print(f"PixArt: LoRA conversion has missing keys! (probably)")
print(missing)
return new_state_dict
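A minimal usage sketch, assuming this module is importable from the extension package and using a hypothetical checkpoint path; the detection key mirrors the one used by the loader below:
import comfy.utils
from .diffusers_convert import convert_state_dict  # assuming a package-relative import

sd = comfy.utils.load_torch_file("models/PixArtMS/diffusion_pytorch_model.safetensors")  # hypothetical path
sd = sd.get("model", sd)
if "adaln_single.linear.weight" in sd:  # diffusers-style key names present
    sd = convert_state_dict(sd)         # remap to the reference PixArt key names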
@@ -0,0 +1,331 @@
import os
import json
import copy
import torch
import math
import comfy.supported_models_base
import comfy.latent_formats
import comfy.model_patcher
import comfy.model_base
import comfy.utils
import comfy.conds
from comfy import model_management
from .diffusers_convert import convert_state_dict, convert_lora_state_dict
# checkpoint loading
class EXM_PixArt(comfy.supported_models_base.BASE):
unet_config = {}
unet_extra_config = {}
latent_format = comfy.latent_formats.SD15
def __init__(self, model_conf):
self.model_target = model_conf.get("target")
self.unet_config = model_conf.get("unet_config", {})
self.sampling_settings = model_conf.get("sampling_settings", {})
self.latent_format = self.latent_format()
# UNET is handled by extension
self.unet_config["disable_unet_model_creation"] = True
def model_type(self, state_dict, prefix=""):
return comfy.model_base.ModelType.EPS
class EXM_PixArt_Model(comfy.model_base.BaseModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
img_hw = kwargs.get("img_hw", None)
if img_hw is not None:
out["img_hw"] = comfy.conds.CONDRegular(torch.tensor(img_hw))
aspect_ratio = kwargs.get("aspect_ratio", None)
if aspect_ratio is not None:
out["aspect_ratio"] = comfy.conds.CONDRegular(torch.tensor(aspect_ratio))
cn_hint = kwargs.get("cn_hint", None)
if cn_hint is not None:
out["cn_hint"] = comfy.conds.CONDRegular(cn_hint)
return out
def load_pixart(model_path, model_conf=None):
state_dict = comfy.utils.load_torch_file(model_path)
state_dict = state_dict.get("model", state_dict)
# prefix
for prefix in ["model.diffusion_model.", ]:
if any(True for x in state_dict if x.startswith(prefix)):
state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
# diffusers
if "adaln_single.linear.weight" in state_dict:
state_dict = convert_state_dict(state_dict) # Diffusers
# guess auto config
if model_conf is None:
model_conf = guess_pixart_config(state_dict)
parameters = comfy.utils.calculate_parameters(state_dict)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.unet_offload_device()
# ignore fp8/etc and use directly for now
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype:
print(f"PixArt: falling back to {manual_cast_dtype}")
unet_dtype = manual_cast_dtype
model_conf = EXM_PixArt(model_conf) # convert to object
model = EXM_PixArt_Model( # same as comfy.model_base.BaseModel
model_conf,
model_type=comfy.model_base.ModelType.EPS,
device=model_management.get_torch_device()
)
if model_conf.model_target == "PixArtMS":
from .models.PixArtMS import PixArtMS
model.diffusion_model = PixArtMS(**model_conf.unet_config)
elif model_conf.model_target == "PixArt":
from .models.PixArt import PixArt
model.diffusion_model = PixArt(**model_conf.unet_config)
elif model_conf.model_target == "PixArtMSSigma":
from .models.PixArtMS import PixArtMS
model.diffusion_model = PixArtMS(**model_conf.unet_config)
model.latent_format = comfy.latent_formats.SDXL()
elif model_conf.model_target == "ControlPixArtMSHalf":
from .models.PixArtMS import PixArtMS
from .models.pixart_controlnet import ControlPixArtMSHalf
model.diffusion_model = PixArtMS(**model_conf.unet_config)
model.diffusion_model = ControlPixArtMSHalf(model.diffusion_model)
elif model_conf.model_target == "ControlPixArtHalf":
from .models.PixArt import PixArt
from .models.pixart_controlnet import ControlPixArtHalf
model.diffusion_model = PixArt(**model_conf.unet_config)
model.diffusion_model = ControlPixArtHalf(model.diffusion_model)
else:
raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'")
m, u = model.diffusion_model.load_state_dict(state_dict, strict=False)
if len(m) > 0: print("Missing UNET keys", m)
if len(u) > 0: print("Leftover UNET keys", u)
model.diffusion_model.dtype = unet_dtype
model.diffusion_model.eval()
model.diffusion_model.to(unet_dtype)
model_patcher = comfy.model_patcher.ModelPatcher(
model,
load_device=load_device,
offload_device=offload_device,
)
return model_patcher
def guess_pixart_config(sd):
"""
Guess config based on converted state dict.
"""
# Shared settings based on DiT_XL_2 - could be enumerated
config = {
"num_heads": 16, # get from attention
"patch_size": 2, # final layer I guess?
"hidden_size": 1152, # pos_embed.shape[2]
}
config["depth"] = sum([key.endswith(".attn.proj.weight") for key in sd.keys()]) or 28
try:
# this is not present in the diffusers version for sigma?
config["model_max_length"] = sd["y_embedder.y_embedding"].shape[0]
except KeyError:
# need better logic to guess this
config["model_max_length"] = 300
if "pos_embed" in sd:
config["input_size"] = int(math.sqrt(sd["pos_embed"].shape[1])) * config["patch_size"]
config["pe_interpolation"] = config["input_size"] // (512 // 8) # dumb guess
target_arch = "PixArtMS"
if config["model_max_length"] == 300:
# Sigma
target_arch = "PixArtMSSigma"
config["micro_condition"] = False
if "input_size" not in config:
# The diffusers weights for 1K/2K are exactly the same...?
# replace patch embed logic with HyDiT?
print(f"PixArt: diffusers weights - 2K model will be broken, use manual loading!")
config["input_size"] = 1024 // 8
else:
# Alpha
if "csize_embedder.mlp.0.weight" in sd:
# MS (microconds)
target_arch = "PixArtMS"
config["micro_condition"] = True
if "input_size" not in config:
config["input_size"] = 1024 // 8
config["pe_interpolation"] = 2
else:
# PixArt
target_arch = "PixArt"
if "input_size" not in config:
config["input_size"] = 512 // 8
config["pe_interpolation"] = 1
print("PixArt guessed config:", target_arch, config)
return {
"target": target_arch,
"unet_config": config,
"sampling_settings": {
"beta_schedule": "sqrt_linear",
"linear_start": 0.0001,
"linear_end": 0.02,
"timesteps": 1000,
}
}
# lora
class EXM_PixArt_ModelPatcher(comfy.model_patcher.ModelPatcher):
def calculate_weight(self, patches, weight, key):
"""
This is almost the same as the comfy function, but stripped down to just the LoRA patch code.
The problem with the original code is the q/k/v keys being combined into one for the attention.
In the diffusers code, they're treated as separate keys, but in the reference code they're recombined (q+kv|qkv).
This means, for example, that the [1152,1152] weights become [3456,1152] in the state dict.
The issue with this is that the LoRA weights are [128,1152],[1152,128] and become [384,1152],[3456,128] instead.
This is the best thing I could think of that would fix that, but it's very fragile.
- Check key shape to determine if it needs the fallback logic
- Cut the input into parts based on the shape (undoing the torch.cat)
- Do the matrix multiplication logic
- Recombine them to match the expected shape
"""
for p in patches:
alpha = p[0]
v = p[1]
strength_model = p[2]
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key),)
if len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "lora":
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
try:
mat1 = mat1.flatten(start_dim=1)
mat2 = mat2.flatten(start_dim=1)
ch1 = mat1.shape[0] // mat2.shape[1]
ch2 = mat2.shape[0] // mat1.shape[1]
### Fallback logic for shape mismatch ###
if mat1.shape[0] != mat2.shape[1] and ch1 == ch2 and (mat1.shape[0] / mat2.shape[1]) % 1 == 0:
mat1 = mat1.chunk(ch1, dim=0)
mat2 = mat2.chunk(ch1, dim=0)
weight += torch.cat(
[alpha * torch.mm(mat1[x], mat2[x]) for x in range(ch1)],
dim=0,
).reshape(weight.shape).type(weight.dtype)
else:
weight += (alpha * torch.mm(mat1, mat2)).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
return weight
def clone(self):
n = EXM_PixArt_ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
def replace_model_patcher(model):
n = EXM_PixArt_ModelPatcher(
model=model.model,
size=model.size,
load_device=model.load_device,
offload_device=model.offload_device,
current_device=model.current_device,
weight_inplace_update=model.weight_inplace_update,
)
n.patches = {}
for k in model.patches:
n.patches[k] = model.patches[k][:]
n.object_patches = model.object_patches.copy()
n.model_options = copy.deepcopy(model.model_options)
return n
def find_peft_alpha(path):
def load_json(json_path):
with open(json_path) as f:
data = json.load(f)
alpha = data.get("lora_alpha")
alpha = alpha or data.get("alpha")
if not alpha:
print(" Found config but `lora_alpha` is missing!")
else:
print(f" Found config at {json_path} [alpha:{alpha}]")
return alpha
# For some weird reason peft doesn't include the alpha in the actual model
print("PixArt: Warning! This is a PEFT LoRA. Trying to find config...")
files = [
f"{os.path.splitext(path)[0]}.json",
f"{os.path.splitext(path)[0]}.config.json",
os.path.join(os.path.dirname(path), "adapter_config.json"),
]
for file in files:
if os.path.isfile(file):
return load_json(file)
print(" Missing config/alpha! assuming alpha of 8. Consider converting it/adding a config json to it.")
return 8.0
def load_pixart_lora(model, lora, lora_path, strength):
k_back = lambda x: x.replace(".lora_up.weight", "")
# need to convert the actual weights for this to work.
if any(True for x in lora.keys() if x.endswith("adaln_single.linear.lora_A.weight")):
lora = convert_lora_state_dict(lora, peft=True)
alpha = find_peft_alpha(lora_path)
lora.update({f"{k_back(x)}.alpha": torch.tensor(alpha) for x in lora.keys() if "lora_up" in x})
else: # OneTrainer
lora = convert_lora_state_dict(lora, peft=False)
key_map = {k_back(x): f"diffusion_model.{k_back(x)}.weight" for x in lora.keys() if "lora_up" in x} # fake
loaded = comfy.lora.load_lora(lora, key_map)
if model is not None:
# switch to custom model patcher when using LoRAs
if isinstance(model, EXM_PixArt_ModelPatcher):
new_modelpatcher = model.clone()
else:
new_modelpatcher = replace_model_patcher(model)
k = new_modelpatcher.add_patches(loaded, strength)
else:
k = ()
new_modelpatcher = None
k = set(k)
for x in loaded:
if (x not in k):
print("NOT LOADED", x)
return new_modelpatcher
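A toy sketch of the q/k/v shape fallback described in the calculate_weight docstring above, using the shapes it mentions (random tensors, alpha scaling omitted); this is illustrative only, not part of the commit:
import torch

weight = torch.zeros(3456, 1152)            # fused q+k+v projection weight
mat1 = torch.randn(3456, 128)               # three stacked lora_up matrices
mat2 = torch.randn(384, 1152)               # three stacked lora_down matrices
chunks = mat1.shape[0] // mat2.shape[1]     # 3456 // 1152 == 3
m1 = mat1.chunk(chunks, dim=0)              # 3 x [1152, 128]
m2 = mat2.chunk(chunks, dim=0)              # 3 x [128, 1152]
weight += torch.cat([m1[i] @ m2[i] for i in range(chunks)], dim=0)  # back to [3456, 1152]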
@@ -0,0 +1,250 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
import os
import numpy as np
from timm.models.layers import DropPath
from timm.models.vision_transformer import PatchEmbed, Mlp
from .utils import auto_grad_checkpoint, to_2tuple
from .PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer
class PixArtBlock(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None, sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# to stay compatible with older PyTorch versions
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
self.sampling = sampling
self.sr_ratio = sr_ratio
def forward(self, x, y, t, mask=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArt(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=1.0,
pe_precision=None,
config=None,
model_max_length=120,
qk_norm=False,
kv_compress_config=None,
**kwargs,
):
super().__init__()
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.depth = depth
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size // self.patch_size
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length
)
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.kv_compress_config = kv_compress_config
if kv_compress_config is None:
self.kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
input_size=(input_size // patch_size, input_size // patch_size),
sampling=self.kv_compress_config['sampling'],
sr_ratio=int(
self.kv_compress_config['scale_factor']
) if i in self.kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
def forward_raw(self, x, t, y, mask=None, data_info=None):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = t.to(self.dtype)
y = y.to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, 1, L, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, y=None, **kwargs):
"""
Forward pass that adapts comfy input to original forward function
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
timesteps: (N,) tensor of diffusion timesteps
context: (N, 1, 120, C) conditioning
y: extra conditioning.
"""
## The raw forward still accepts context without this extra dim, but then returns garbage, so add it here
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_raw(
x = x.to(self.dtype),
t = timesteps.to(self.dtype),
y = context.to(self.dtype),
)
## only return EPS
out = out.to(torch.float)
eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
return eps
def unpatchify(self, x):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))  # h == w here (square token grid asserted above)
return imgs
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_size = to_2tuple(grid_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / pe_interpolation
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / pe_interpolation
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed.astype(np.float32)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
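As a quick sanity check, assuming get_2d_sincos_pos_embed above is importable, the PixArtMS_XL_2 geometry (128x128 latent, patch size 2, hidden size 1152, pe_interpolation 2) yields one embedding row per patch token:
pe = get_2d_sincos_pos_embed(1152, 64, pe_interpolation=2, base_size=64)
print(pe.shape)  # (4096, 1152)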
@@ -0,0 +1,273 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import torch
import torch.nn as nn
from tqdm import tqdm
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
from .utils import auto_grad_checkpoint, to_2tuple
from .PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, SizeEmbedder
from .PixArt import PixArt, get_2d_sincos_pos_embed
class PatchEmbed(nn.Module):
"""
2D Image to Patch Embedding
"""
def __init__(
self,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
class PixArtMSBlock(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# to stay compatible with older PyTorch versions
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArtMS(PixArt):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=None,
pe_precision=None,
config=None,
model_max_length=120,
micro_condition=True,
qk_norm=False,
kv_compress_config=None,
**kwargs,
):
super().__init__(
input_size=input_size,
patch_size=patch_size,
in_channels=in_channels,
hidden_size=hidden_size,
depth=depth,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
class_dropout_prob=class_dropout_prob,
learn_sigma=learn_sigma,
pred_sigma=pred_sigma,
drop_path=drop_path,
pe_interpolation=pe_interpolation,
config=config,
model_max_length=model_max_length,
qk_norm=qk_norm,
kv_compress_config=kv_compress_config,
**kwargs,
)
self.dtype = torch.get_default_dtype()
self.h = self.w = 0
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True)
self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=model_max_length)
self.micro_conditioning = micro_condition
if self.micro_conditioning:
self.csize_embedder = SizeEmbedder(hidden_size//3) # c_size embed
self.ar_embedder = SizeEmbedder(hidden_size//3) # aspect ratio embed
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
if kv_compress_config is None:
kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtMSBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
input_size=(input_size // patch_size, input_size // patch_size),
sampling=kv_compress_config['sampling'],
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
def forward_raw(self, x, t, y, mask=None, data_info=None, **kwargs):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
bs = x.shape[0]
x = x.to(self.dtype)
timestep = t.to(self.dtype)
y = y.to(self.dtype)
pe_interpolation = self.pe_interpolation
if pe_interpolation is None or self.pe_precision is not None:
# calculate pe_interpolation on-the-fly
pe_interpolation = round((x.shape[-1]+x.shape[-2])/2.0 / (512/8.0), self.pe_precision or 0)
self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
pos_embed = torch.from_numpy(
get_2d_sincos_pos_embed(
self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=pe_interpolation,
base_size=self.base_size
)
).unsqueeze(0).to(device=x.device, dtype=self.dtype)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep) # (N, D)
if self.micro_conditioning:
c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
csize = self.csize_embedder(c_size, bs) # (N, D)
ar = self.ar_embedder(ar, bs) # (N, D)
t = t + torch.cat([csize, ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, (self.h, self.w), **kwargs) # (N, T, D) #support grad checkpoint
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, img_hw=None, aspect_ratio=None, **kwargs):
"""
Forward pass that adapts comfy input to original forward function
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
timesteps: (N,) tensor of diffusion timesteps
context: (N, 1, 120, C) conditioning
img_hw: height|width conditioning
aspect_ratio: aspect ratio conditioning
"""
## size/ar from cond with fallback based on the latent image shape.
bs = x.shape[0]
data_info = {}
if img_hw is None:
data_info["img_hw"] = torch.tensor(
[[x.shape[2]*8, x.shape[3]*8]],
dtype=self.dtype,
device=x.device
).repeat(bs, 1)
else:
data_info["img_hw"] = img_hw.to(dtype=x.dtype, device=x.device)
if aspect_ratio is None or True:  # 'or True' forces the latent-derived fallback; the explicit cond branch below is currently unreachable
data_info["aspect_ratio"] = torch.tensor(
[[x.shape[2]/x.shape[3]]],
dtype=self.dtype,
device=x.device
).repeat(bs, 1)
else:
data_info["aspect_ratio"] = aspect_ratio.to(dtype=x.dtype, device=x.device)
## The raw forward still accepts context without this extra dim, but then returns garbage, so add it here
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_raw(
x = x.to(self.dtype),
t = timesteps.to(self.dtype),
y = context.to(self.dtype),
data_info=data_info,
)
## only return EPS
out = out.to(torch.float)
eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
return eps
def unpatchify(self, x):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
assert self.h * self.w == x.shape[1]
x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
return imgs
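For reference, a sketch (not part of the commit) of the micro-conditioning tensors a caller could pass through this forward for a single 1024x1024 image; the model call is illustrative only, and as written above aspect_ratio always falls back to the latent-derived value:
import torch

img_hw = torch.tensor([[1024.0, 1024.0]])  # (N, 2) height/width in pixels
aspect_ratio = torch.tensor([[1.0]])        # (N, 1) height/width ratio
# model(latent, timesteps, context, img_hw=img_hw, aspect_ratio=aspect_ratio)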
@@ -0,0 +1,477 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import Mlp, Attention as Attention_
from einops import rearrange
from comfy import model_management
if model_management.xformers_enabled():
import xformers
import xformers.ops
else:
print("""
########################################
PixArt: Not using xformers!
Expect images to be non-deterministic!
Batch sizes > 1 are most likely broken
########################################
""")
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
class MultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., **block_kwargs):
super(MultiHeadCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model*2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# query: image tokens; key/value: condition tokens; mask: handles padding in the condition
B, N, C = x.shape
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
if model_management.xformers_enabled():
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(
q, k, v,
p=self.attn_drop.p,
attn_bias=attn_bias
)
else:
q, k, v = map(lambda t: t.permute(0, 2, 1, 3),(q, k, v),)
attn_mask = None
if mask is not None and len(mask) > 1:
# Create equivalent of xformer diagonal block mask, still only correct for square masks
# But depth doesn't matter as tensors can expand in that dimension
attn_mask_template = torch.ones(
[q.shape[2] // B, mask[0]],
dtype=torch.bool,
device=q.device
)
attn_mask = torch.block_diag(attn_mask_template)
# create a mask on the diagonal for each mask in the batch
for n in range(B - 1):
attn_mask = torch.block_diag(attn_mask, attn_mask_template)
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p
).permute(0, 2, 1, 3).contiguous()
x = x.view(B, -1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentionKVCompress(Attention_):
"""Multi-head Attention block with KV token compression and qk norm."""
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
sampling='conv',
sr_ratio=1,
qk_norm=False,
**block_kwargs,
):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
"""
super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, **block_kwargs)
self.sampling=sampling # ['conv', 'ave', 'uniform', 'uniform_every']
self.sr_ratio = sr_ratio
if sr_ratio > 1 and sampling == 'conv':
# Avg Conv Init.
self.sr = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio)
self.sr.weight.data.fill_(1/sr_ratio**2)
self.sr.bias.data.zero_()
self.norm = nn.LayerNorm(dim)
if qk_norm:
self.q_norm = nn.LayerNorm(dim)
self.k_norm = nn.LayerNorm(dim)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
if sampling is None or scale_factor == 1:
return tensor
B, N, C = tensor.shape
if sampling == 'uniform_every':
return tensor[:, ::scale_factor], int(N // scale_factor)
tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
new_H, new_W = int(H / scale_factor), int(W / scale_factor)
new_N = new_H * new_W
if sampling == 'ave':
tensor = F.interpolate(
tensor, scale_factor=1 / scale_factor, mode='nearest'
).permute(0, 2, 3, 1)
elif sampling == 'uniform':
tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
elif sampling == 'conv':
tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
tensor = self.norm(tensor)
else:
raise ValueError
return tensor.reshape(B, new_N, C).contiguous(), new_N
def forward(self, x, mask=None, HW=None, block_id=None):
B, N, C = x.shape # 2 4096 1152
new_N = N
if HW is None:
H = W = int(N ** 0.5)
else:
H, W = HW
qkv = self.qkv(x).reshape(B, N, 3, C)
q, k, v = qkv.unbind(2)
dtype = q.dtype
q = self.q_norm(q)
k = self.k_norm(k)
# KV compression
if self.sr_ratio > 1:
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
attn_bias = None
if mask is not None:
attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
# Switch between torch / xformers attention
if model_management.xformers_enabled():
x = xformers.ops.memory_efficient_attention(
q, k, v,
p=self.attn_drop.p,
attn_bias=attn_bias
)
else:
q, k, v = map(lambda t: t.transpose(1, 2),(q, k, v),)
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p,
attn_mask=attn_bias
).transpose(1, 2).contiguous()
x = x.view(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
#################################################################################
# AMP attention with fp32 softmax to fix loss NaN problem during training #
#################################################################################
class Attention(Attention_):
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
use_fp32_attention = getattr(self, 'fp32_attention', False)
if use_fp32_attention:
q, k = q.float(), k.float()
with torch.cuda.amp.autocast(enabled=not use_fp32_attention):
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class T2IFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
self.out_channels = out_channels
def forward(self, x, t):
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class MaskFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
)
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class DecoderLayer(nn.Module):
"""
The decoder layer of PixArt.
"""
def __init__(self, hidden_size, decoder_hidden_size):
super().__init__()
self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_decoder(x), shift, scale)
x = self.linear(x)
return x
#################################################################################
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq.to(t.dtype))
return t_emb
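# Hedged usage sketch (added, not part of the original file); shapes assume the defaults above:
#   emb = TimestepEmbedder(hidden_size=1152)
#   t = torch.tensor([0.0, 500.0, 999.0])        # (N,) timesteps, possibly fractional
#   freqs = emb.timestep_embedding(t, 256)       # (N, 256), cos half then sin half
#   out = emb(t)                                 # (N, 1152) after the two-layer MLP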
class SizeEmbedder(TimestepEmbedder):
"""
Embeds scalar image sizes (e.g. height and width) into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
self.outdim = hidden_size
def forward(self, s, bs):
if s.ndim == 1:
s = s[:, None]
assert s.ndim == 2
if s.shape[0] != bs:
s = s.repeat(bs//s.shape[0], 1)
assert s.shape[0] == bs
b, dims = s.shape[0], s.shape[1]
s = rearrange(s, "b d -> (b d)")
s_freq = self.timestep_embedding(s, self.frequency_embedding_size)
s_emb = self.mlp(s_freq.to(s.dtype))
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return s_emb
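# Shape note (added for clarity, not upstream): forward() flattens the (bs, d) condition tensor
# to (bs*d,), embeds every scalar like a timestep, and reshapes back to (bs, d*hidden_size);
# so an img_hw condition (d=2) becomes (bs, 2*hidden_size) and an aspect-ratio condition (d=1)
# becomes (bs, hidden_size).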
class LabelEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
class CaptionEmbedder(nn.Module):
"""
Embeds text caption features into vector representations. Also handles caption dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120):
super().__init__()
self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0)
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
self.uncond_prob = uncond_prob
def token_drop(self, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return caption
def forward(self, caption, train, force_drop_ids=None):
if train:
assert caption.shape[2:] == self.y_embedding.shape
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
caption = self.token_drop(caption, force_drop_ids)
caption = self.y_proj(caption)
return caption
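# Note (added for clarity, not upstream): during training, token_drop() replaces whole captions
# with the learned null embedding y_embedding with probability uncond_prob, which is what later
# provides the unconditional branch for classifier-free guidance.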
class CaptionEmbedderDoubleBr(nn.Module):
"""
Embeds text caption features into global and per-token representations. Also handles caption dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120):
super().__init__()
self.proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0)
self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
self.uncond_prob = uncond_prob
def token_drop(self, global_caption, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return global_caption, caption
def forward(self, caption, train, force_drop_ids=None):
assert caption.shape[2: ] == self.y_embedding.shape
global_caption = caption.mean(dim=2).squeeze()
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
y_embed = self.proj(global_caption)
return y_embed, caption

View File

@@ -0,0 +1,312 @@
import re
import torch
import torch.nn as nn
from copy import deepcopy
from torch import Tensor
from torch.nn import Module, Linear, init
from typing import Any, Mapping
from .PixArt import PixArt, get_2d_sincos_pos_embed
from .PixArtMS import PixArtMSBlock, PixArtMS
from .utils import auto_grad_checkpoint
# The implementation of the ControlNet-Half architecture
# https://github.com/lllyasviel/ControlNet/discussions/188
class ControlT2IDitBlockHalf(Module):
def __init__(self, base_block: PixArtMSBlock, block_index: int = 0) -> None:
super().__init__()
self.copied_block = deepcopy(base_block)
self.block_index = block_index
for p in self.copied_block.parameters():
p.requires_grad_(True)
self.copied_block.load_state_dict(base_block.state_dict())
self.copied_block.train()
self.hidden_size = hidden_size = base_block.hidden_size
if self.block_index == 0:
self.before_proj = Linear(hidden_size, hidden_size)
init.zeros_(self.before_proj.weight)
init.zeros_(self.before_proj.bias)
self.after_proj = Linear(hidden_size, hidden_size)
init.zeros_(self.after_proj.weight)
init.zeros_(self.after_proj.bias)
def forward(self, x, y, t, mask=None, c=None):
if self.block_index == 0:
# the first block
c = self.before_proj(c)
c = self.copied_block(x + c, y, t, mask)
c_skip = self.after_proj(c)
else:
# load from previous c and produce the c for skip connection
c = self.copied_block(c, y, t, mask)
c_skip = self.after_proj(c)
return c, c_skip
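# Note (added for clarity, not upstream): before_proj/after_proj are zero-initialized, so a
# freshly built control branch contributes nothing (c_skip == 0) and the patched model initially
# behaves exactly like the frozen base model; the projections only learn to inject the control
# signal during fine-tuning, mirroring ControlNet's zero-conv trick.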
# The implementation of ControlPixArtHalf net
class ControlPixArtHalf(Module):
# only supports single-resolution models
def __init__(self, base_model: PixArt, copy_blocks_num: int = 13) -> None:
super().__init__()
self.dtype = torch.get_default_dtype()
self.base_model = base_model.eval()
self.controlnet = []
self.copy_blocks_num = copy_blocks_num
self.total_blocks_num = len(base_model.blocks)
for p in self.base_model.parameters():
p.requires_grad_(False)
# Copy first copy_blocks_num block
for i in range(copy_blocks_num):
self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
self.controlnet = nn.ModuleList(self.controlnet)
def __getattr__(self, name: str) -> Tensor or Module:
if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
return self.__dict__[name]
elif name in ['base_model', 'controlnet']:
return super().__getattr__(name)
else:
return getattr(self.base_model, name)
def forward_c(self, c):
self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size
pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype)
return self.x_embedder(c) + pos_embed if c is not None else c
# def forward(self, x, t, c, **kwargs):
# return self.base_model(x, t, c=self.forward_c(c), **kwargs)
def forward_raw(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
# modify the original PixArtMS forward function
if c is not None:
c = c.to(self.dtype)
c = self.forward_c(c)
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, 1, L, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# define the first layer
x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs) # (N, T, D) #support grad checkpoint
if c is not None:
# update c
for index in range(1, self.copy_blocks_num + 1):
c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)
# update x
for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
else:
for index in range(1, self.total_blocks_num):
x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, cn_hint=None, **kwargs):
"""
Forward pass that adapts comfy input to original forward function
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
timesteps: (N,) tensor of diffusion timesteps
context: (N, 1, 120, C) conditioning
cn_hint: controlnet hint
"""
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_raw(
x = x.to(self.dtype),
timestep = timesteps.to(self.dtype),
y = context.to(self.dtype),
c = cn_hint,
)
## only return EPS
out = out.to(torch.float)
eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
return eps
def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs):
model_out = self.forward_raw(x, t, y, data_info=data_info, c=c, **kwargs)
return model_out.chunk(2, dim=1)[0]
# def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs):
# return self.base_model.forward_with_dpmsolver(x, t, y, data_info=data_info, c=self.forward_c(c), **kwargs)
def forward_with_cfg(self, x, t, y, cfg_scale, data_info, c, **kwargs):
return self.base_model.forward_with_cfg(x, t, y, cfg_scale, data_info, c=self.forward_c(c), **kwargs)
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
return super().load_state_dict(state_dict, strict)
else:
new_key = {}
for k in state_dict.keys():
new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
for k, v in new_key.items():
if k != v:
print(f"replace {k} to {v}")
state_dict[v] = state_dict.pop(k)
return self.base_model.load_state_dict(state_dict, strict)
def unpatchify(self, x):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
assert self.h * self.w == x.shape[1]
x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
return imgs
# @property
# def dtype(self):
## return the dtype of the model parameters
# return next(self.parameters()).dtype
# The implementation for PixArtMS_Half + 1024 resolution
class ControlPixArtMSHalf(ControlPixArtHalf):
# supports multi-scale resolution models (a multi-scale model can also be used for single-resolution training & inference)
def __init__(self, base_model: PixArtMS, copy_blocks_num: int = 13) -> None:
super().__init__(base_model=base_model, copy_blocks_num=copy_blocks_num)
def forward_raw(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
# modify the original PixArtMS forward function
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
if c is not None:
c = c.to(self.dtype)
c = self.forward_c(c)
bs = x.shape[0]
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(x.device).to(self.dtype)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep) # (N, D)
csize = self.csize_embedder(c_size, bs) # (N, D)
ar = self.ar_embedder(ar, bs) # (N, D)
t = t + torch.cat([csize, ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# define the first layer
x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs) # (N, T, D) #support grad checkpoint
if c is not None:
# update c
for index in range(1, self.copy_blocks_num + 1):
c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)
# update x
for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
else:
for index in range(1, self.total_blocks_num):
x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, img_hw=None, aspect_ratio=None, cn_hint=None, **kwargs):
"""
Forward pass that adapts comfy input to original forward function
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
timesteps: (N,) tensor of diffusion timesteps
context: (N, 1, 120, C) conditioning
img_hw: height|width conditioning
aspect_ratio: aspect ratio conditioning
cn_hint: controlnet hint
"""
## size/ar from cond with fallback based on the latent image shape.
bs = x.shape[0]
data_info = {}
if img_hw is None:
data_info["img_hw"] = torch.tensor(
[[x.shape[2]*8, x.shape[3]*8]],
dtype=self.dtype,
device=x.device
).repeat(bs, 1)
else:
data_info["img_hw"] = img_hw.to(x.dtype)
if aspect_ratio is None or True:
data_info["aspect_ratio"] = torch.tensor(
[[x.shape[2]/x.shape[3]]],
dtype=self.dtype,
device=x.device
).repeat(bs, 1)
else:
data_info["aspect_ratio"] = aspect_ratio.to(x.dtype)
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_raw(
x = x.to(self.dtype),
timestep = timesteps.to(self.dtype),
y = context.to(self.dtype),
c = cn_hint,
data_info=data_info,
)
## only return EPS
out = out.to(torch.float)
eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
return eps

View File

@@ -0,0 +1,122 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from collections.abc import Iterable
from itertools import repeat
def _ntuple(n):
def parse(x):
if isinstance(x, Iterable) and not isinstance(x, str):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
assert isinstance(model, nn.Module)
def set_attr(module):
module.grad_checkpointing = True
module.fp32_attention = use_fp32_attention
module.grad_checkpointing_step = gc_step
model.apply(set_attr)
def auto_grad_checkpoint(module, *args, **kwargs):
if getattr(module, 'grad_checkpointing', False):
if isinstance(module, Iterable):
gc_step = module[0].grad_checkpointing_step
return checkpoint_sequential(module, gc_step, *args, **kwargs)
else:
return checkpoint(module, *args, **kwargs)
return module(*args, **kwargs)
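# Hedged usage sketch (added, not upstream); the two helpers above are meant to be used together:
#   blocks = nn.Sequential(*[nn.Linear(64, 64) for _ in range(4)])
#   set_grad_checkpoint(blocks, gc_step=2)        # tags every submodule
#   x = torch.randn(8, 64, requires_grad=True)
#   y = auto_grad_checkpoint(blocks, x)           # sequential checkpointing, 2 layers per segment
#   z = auto_grad_checkpoint(blocks[0], x)        # single module -> plain checkpoint()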
def checkpoint_sequential(functions, step, input, *args, **kwargs):
# Hack for keyword-only parameter in a python 2.7-compliant way
preserve = kwargs.pop('preserve_rng_state', True)
if kwargs:
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
def run_function(start, end, functions):
def forward(input):
for j in range(start, end + 1):
input = functions[j](input, *args)
return input
return forward
if isinstance(functions, torch.nn.Sequential):
functions = list(functions.children())
# the last chunk has to be non-volatile
end = -1
segment = len(functions) // step
for start in range(0, step * (segment - 1), step):
end = start + step - 1
input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve)
return run_function(end + 1, len(functions) - 1, functions)(input)
def get_rel_pos(q_size, k_size, rel_pos):
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn

View File

@@ -0,0 +1,38 @@
import torch
from comfy import model_management
def string_to_dtype(s="none", mode=None):
s = s.lower().strip()
if s in ["default", "as-is"]:
return None
elif s in ["auto", "auto (comfy)"]:
if mode == "vae":
return model_management.vae_dtype()
elif mode == "text_encoder":
return model_management.text_encoder_dtype()
elif mode == "unet":
return model_management.unet_dtype()
else:
raise NotImplementedError(f"Unknown dtype mode '{mode}'")
elif s in ["none", "auto (hf)", "auto (hf/bnb)"]:
return None
elif s in ["fp32", "float32", "float"]:
return torch.float32
elif s in ["bf16", "bfloat16"]:
return torch.bfloat16
elif s in ["fp16", "float16", "half"]:
return torch.float16
elif "fp8" in s or "float8" in s:
if "e5m2" in s:
return torch.float8_e5m2
elif "e4m3" in s:
return torch.float8_e4m3fn
else:
raise NotImplementedError(f"Unknown 8bit dtype '{s}'")
elif "bnb" in s:
assert s in ["bnb8bit", "bnb4bit"], f"Unknown bnb mode '{s}'"
return s
elif s is None:
return None
else:
raise NotImplementedError(f"Unknown dtype '{s}'")

View File

@@ -0,0 +1,139 @@
#credit to Acly for this module
#from https://github.com/Acly/comfyui-inpaint-nodes
import torch
import torch.nn.functional as F
import comfy
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.model_management import cast_to_device
from ...libs.log import log_node_warn, log_node_error, log_node_info
class InpaintHead(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device="cpu"))
def __call__(self, x):
x = F.pad(x, (1, 1, 1, 1), "replicate")
return F.conv2d(x, weight=self.head)
# injected_model_patcher_calculate_weight = False
# original_calculate_weight = None
class applyFooocusInpaint:
def calculate_weight_patched(self, patches, weight, key, intermediate_dtype=torch.float32):
remaining = []
for p in patches:
alpha = p[0]
v = p[1]
is_fooocus_patch = isinstance(v, tuple) and len(v) == 2 and v[0] == "fooocus"
if not is_fooocus_patch:
remaining.append(p)
continue
if alpha != 0.0:
v = v[1]
w1 = cast_to_device(v[0], weight.device, torch.float32)
if w1.shape == weight.shape:
w_min = cast_to_device(v[1], weight.device, torch.float32)
w_max = cast_to_device(v[2], weight.device, torch.float32)
w1 = (w1 / 255.0) * (w_max - w_min) + w_min
weight += alpha * cast_to_device(w1, weight.device, weight.dtype)
else:
print(
f"[ApplyFooocusInpaint] Shape mismatch {key}, weight not merged ({w1.shape} != {weight.shape})"
)
if len(remaining) > 0:
return self.original_calculate_weight(remaining, weight, key, intermediate_dtype)
return weight
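# Note (added for clarity, not upstream): fooocus inpaint patches store each weight delta as a
# uint8 tensor plus per-tensor (w_min, w_max) bounds; the code above dequantizes it back to float
# via w1 / 255 * (w_max - w_min) + w_min before adding it to the base weight, and any non-fooocus
# patches are forwarded unchanged to the original calculate_weight implementation.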
def __enter__(self):
try:
print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight.calculate_weight")
self.original_calculate_weight = comfy.lora.calculate_weight
comfy.lora.calculate_weight = self.calculate_weight_patched
except AttributeError:
print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
self.original_calculate_weight = ModelPatcher.calculate_weight
ModelPatcher.calculate_weight = self.calculate_weight_patched
def __exit__(self, exc_type, exc_value, traceback):
try:
comfy.lora.calculate_weight = self.original_calculate_weight
except:
ModelPatcher.calculate_weight = self.original_calculate_weight
# def inject_patched_calculate_weight():
# global injected_model_patcher_calculate_weight
# if not injected_model_patcher_calculate_weight:
# try:
# print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight.calculate_weight")
# original_calculate_weight = comfy.lora.calculate_weight
# comfy.lora.original_calculate_weight = original_calculate_weight
# comfy.lora.calculate_weight = calculate_weight_patched
# except AttributeError:
# print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
# original_calculate_weight = ModelPatcher.calculate_weight
# ModelPatcher.original_calculate_weight = original_calculate_weight
# ModelPatcher.calculate_weight = calculate_weight_patched
# injected_model_patcher_calculate_weight = True
class InpaintWorker:
def __init__(self, node_name):
self.node_name = node_name if node_name is not None else ""
def load_fooocus_patch(self, lora: dict, to_load: dict):
patch_dict = {}
loaded_keys = set()
for key in to_load.values():
if value := lora.get(key, None):
patch_dict[key] = ("fooocus", value)
loaded_keys.add(key)
not_loaded = sum(1 for x in lora if x not in loaded_keys)
if not_loaded > 0:
log_node_info(self.node_name,
f"{len(loaded_keys)} Lora keys loaded, {not_loaded} remaining keys not found in model."
)
return patch_dict
def _input_block_patch(self, h: torch.Tensor, transformer_options: dict):
if transformer_options["block"][1] == 0:
if self._inpaint_block is None or self._inpaint_block.shape != h.shape:
assert self._inpaint_head_feature is not None
batch = h.shape[0] // self._inpaint_head_feature.shape[0]
self._inpaint_block = self._inpaint_head_feature.to(h).repeat(batch, 1, 1, 1)
h = h + self._inpaint_block
return h
def patch(self, model, latent, patch):
base_model: BaseModel = model.model
latent_pixels = base_model.process_latent_in(latent["samples"])
noise_mask = latent["noise_mask"].round()
latent_mask = F.max_pool2d(noise_mask, (8, 8)).round().to(latent_pixels)
inpaint_head_model, inpaint_lora = patch
feed = torch.cat([latent_mask, latent_pixels], dim=1)
inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
self._inpaint_head_feature = inpaint_head_model(feed)
self._inpaint_block = None
lora_keys = comfy.lora.model_lora_keys_unet(model.model, {})
lora_keys.update({x: x for x in base_model.state_dict().keys()})
loaded_lora = self.load_fooocus_patch(inpaint_lora, lora_keys)
m = model.clone()
m.set_model_input_block_patch(self._input_block_patch)
patched = m.add_patches(loaded_lora, 1.0)
m.model_options['transformer_options']['fooocus'] = True
not_patched_count = sum(1 for x in loaded_lora if x not in patched)
if not_patched_count > 0:
log_node_error(self.node_name, f"Failed to patch {not_patched_count} keys")
# inject_patched_calculate_weight()
return (m,)

View File

@@ -0,0 +1,156 @@
import torch
import numpy as np
import cv2
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from .simple_extractor_dataset import SimpleFolderDataset
from .transforms import transform_logits
from tqdm import tqdm
from PIL import Image
def get_palette(num_cls):
""" Returns the color map for visualizing the segmentation mask.
Args:
num_cls: Number of classes
Returns:
The color map
"""
n = num_cls
palette = [0] * (n * 3)
for j in range(0, n):
lab = j
palette[j * 3 + 0] = 0
palette[j * 3 + 1] = 0
palette[j * 3 + 2] = 0
i = 0
while lab:
palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
i += 1
lab >>= 3
return palette
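# Worked example (added, not upstream): the bit-interleaving above spreads each class index's bits
# across the R/G/B channels, giving a distinct VOC-style color per class, e.g.
# class 0 -> (0, 0, 0), class 1 -> (128, 0, 0), class 2 -> (0, 128, 0), class 3 -> (128, 128, 0).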
def delete_irregular(logits_result):
parsing_result = np.argmax(logits_result, axis=2)
upper_cloth = np.where(parsing_result == 4, 255, 0)
contours, hierarchy = cv2.findContours(upper_cloth.astype(np.uint8),
cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
area = []
for i in range(len(contours)):
a = cv2.contourArea(contours[i], True)
area.append(abs(a))
if len(area) != 0:
top = area.index(max(area))
M = cv2.moments(contours[top])
cY = int(M["m01"] / M["m00"])
dresses = np.where(parsing_result == 7, 255, 0)
contours_dress, hierarchy_dress = cv2.findContours(dresses.astype(np.uint8),
cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
area_dress = []
for j in range(len(contours_dress)):
a_d = cv2.contourArea(contours_dress[j], True)
area_dress.append(abs(a_d))
if len(area_dress) != 0:
top_dress = area_dress.index(max(area_dress))
M_dress = cv2.moments(contours_dress[top_dress])
cY_dress = int(M_dress["m01"] / M_dress["m00"])
wear_type = "dresses"
if len(area) != 0:
if len(area_dress) != 0 and cY_dress > cY:
irregular_list = np.array([4, 5, 6])
logits_result[:, :, irregular_list] = -1
else:
irregular_list = np.array([5, 6, 7, 8, 9, 10, 12, 13])
logits_result[:cY, :, irregular_list] = -1
wear_type = "cloth_pant"
parsing_result = np.argmax(logits_result, axis=2)
# pad border
parsing_result = np.pad(parsing_result, pad_width=1, mode='constant', constant_values=0)
return parsing_result, wear_type
def hole_fill(img):
img_copy = img.copy()
mask = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=np.uint8)
cv2.floodFill(img, mask, (0, 0), 255)
img_inverse = cv2.bitwise_not(img)
dst = cv2.bitwise_or(img_copy, img_inverse)
return dst
def refine_mask(mask):
contours, hierarchy = cv2.findContours(mask.astype(np.uint8),
cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
area = []
for j in range(len(contours)):
a_d = cv2.contourArea(contours[j], True)
area.append(abs(a_d))
refine_mask = np.zeros_like(mask).astype(np.uint8)
if len(area) != 0:
i = area.index(max(area))
cv2.drawContours(refine_mask, contours, i, color=255, thickness=-1)
# keep large area in skin case
for j in range(len(area)):
if j != i and area[i] > 2000:
cv2.drawContours(refine_mask, contours, j, color=255, thickness=-1)
return refine_mask
def refine_hole(parsing_result_filled, parsing_result, arm_mask):
filled_hole = cv2.bitwise_and(np.where(parsing_result_filled == 4, 255, 0),
np.where(parsing_result != 4, 255, 0)) - arm_mask * 255
contours, hierarchy = cv2.findContours(filled_hole, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
refine_hole_mask = np.zeros_like(parsing_result).astype(np.uint8)
for i in range(len(contours)):
a = cv2.contourArea(contours[i], True)
# keep hole > 2000 pixels
if abs(a) > 2000:
cv2.drawContours(refine_hole_mask, contours, i, color=255, thickness=-1)
return refine_hole_mask + arm_mask
def onnx_inference(lip_session, input_dir, mask_components=[0]):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
])
input_size = [473, 473]
dataset_lip = SimpleFolderDataset(root=input_dir, input_size=input_size, transform=transform)
dataloader_lip = DataLoader(dataset_lip)
palette = get_palette(20)
with torch.no_grad():
for _, batch in enumerate(tqdm(dataloader_lip)):
image, meta = batch
c = meta['center'].numpy()[0]
s = meta['scale'].numpy()[0]
w = meta['width'].numpy()[0]
h = meta['height'].numpy()[0]
output = lip_session.run(None, {"input.1": image.numpy().astype(np.float32)})
upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
upsample_output = upsample(torch.from_numpy(output[1][0]).unsqueeze(0))
upsample_output = upsample_output.squeeze()
upsample_output = upsample_output.permute(1, 2, 0) # CHW -> HWC
logits_result_lip = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h,
input_size=input_size)
parsing_result = np.argmax(logits_result_lip, axis=2)
output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
output_img.putpalette(palette)
mask = np.isin(output_img, mask_components).astype(np.uint8)
mask_image = Image.fromarray(mask * 255)
mask_image = mask_image.convert("RGB")
mask_image = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
output_img = output_img.convert('RGB')
output_img = torch.from_numpy(np.array(output_img).astype(np.float32) / 255.0).unsqueeze(0)
return output_img, mask_image

View File

@@ -0,0 +1,109 @@
import numpy as np
import torch
from PIL import Image
from .parsing_api import onnx_inference
from ...libs.utils import install_package
class HumanParsing:
def __init__(self, model_path):
self.model_path = model_path
self.session = None
def __call__(self, input_image, mask_components):
if self.session is None:
install_package('onnxruntime')
import onnxruntime as ort
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
# session_options.add_session_config_entry('gpu_id', str(gpu_id))
self.session = ort.InferenceSession(self.model_path, sess_options=session_options,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
parsed_image, mask = onnx_inference(self.session, input_image, mask_components)
return parsed_image, mask
class HumanParts:
def __init__(self, model_path):
self.model_path = model_path
self.session = None
# self.classes_dict = {
# "background": 0,
# "hair": 2,
# "glasses": 4,
# "top-clothes": 5,
# "bottom-clothes": 9,
# "torso-skin": 10,
# "face": 13,
# "left-arm": 14,
# "right-arm": 15,
# "left-leg": 16,
# "right-leg": 17,
# "left-foot": 18,
# "right-foot": 19,
# },
self.classes = [0, 13, 2, 4, 5, 9, 10, 14, 15, 16, 17, 18, 19]
def __call__(self, input_image, mask_components):
if self.session is None:
install_package('onnxruntime')
import onnxruntime as ort
self.session = ort.InferenceSession(self.model_path, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
mask, = self.get_mask(self.session, input_image, 0, mask_components)
return mask
def get_mask(self, model, image, rotation, mask_components):
image = image.squeeze(0)
image_np = image.numpy() * 255
pil_image = Image.fromarray(image_np.astype(np.uint8))
original_size = pil_image.size # to resize the mask later
# resize to 512x512 as the model expects
pil_image = pil_image.resize((512, 512))
center = (256, 256)
if rotation != 0:
pil_image = pil_image.rotate(rotation, center=center)
# normalize the image
image_np = np.array(pil_image).astype(np.float32) / 127.5 - 1
image_np = np.expand_dims(image_np, axis=0)
# use the onnx model to get the mask
input_name = model.get_inputs()[0].name
output_name = model.get_outputs()[0].name
result = model.run([output_name], {input_name: image_np})
result = np.array(result[0]).argmax(axis=3).squeeze(0)
score: int = 0
mask = np.zeros_like(result)
for class_index in mask_components:
detected = result == self.classes[class_index]
mask[detected] = 255
score += mask.sum()
# back to the original size
mask_image = Image.fromarray(mask.astype(np.uint8), mode="L")
if rotation != 0:
mask_image = mask_image.rotate(-rotation, center=center)
mask_image = mask_image.resize(original_size)
# and back to numpy...
mask = np.array(mask_image).astype(np.float32) / 255
# add 2 dimensions to match the expected output
mask = np.expand_dims(mask, axis=0)
mask = np.expand_dims(mask, axis=0)
# ensure to return a "binary mask_image"
del image_np, result # free up memory, maybe not necessary
return (torch.from_numpy(mask.astype(np.uint8)),)

View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Peike Li
@Contact : peike.li@yahoo.com
@File : dataset.py
@Time : 8/30/19 9:12 PM
@Desc : Dataset Definition
@License : This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import cv2
import numpy as np
from PIL import Image
from torch.utils import data
from .transforms import get_affine_transform
class SimpleFolderDataset(data.Dataset):
def __init__(self, root, input_size=[512, 512], transform=None):
self.root = root
self.input_size = input_size
self.transform = transform
self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
self.input_size = np.asarray(input_size)
self.is_pil_image = False
if isinstance(root, Image.Image):
self.file_list = [root]
self.is_pil_image = True
elif os.path.isfile(root):
self.file_list = [os.path.basename(root)]
self.root = os.path.dirname(root)
else:
self.file_list = os.listdir(self.root)
def __len__(self):
return len(self.file_list)
def _box2cs(self, box):
x, y, w, h = box[:4]
return self._xywh2cs(x, y, w, h)
def _xywh2cs(self, x, y, w, h):
center = np.zeros((2), dtype=np.float32)
center[0] = x + w * 0.5
center[1] = y + h * 0.5
if w > self.aspect_ratio * h:
h = w * 1.0 / self.aspect_ratio
elif w < self.aspect_ratio * h:
w = h * self.aspect_ratio
scale = np.array([w, h], dtype=np.float32)
return center, scale
def __getitem__(self, index):
if self.is_pil_image:
img = np.asarray(self.file_list[index])[:, :, [2, 1, 0]]
else:
img_name = self.file_list[index]
img_path = os.path.join(self.root, img_name)
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
h, w, _ = img.shape
# Get person center and scale
person_center, s = self._box2cs([0, 0, w - 1, h - 1])
r = 0
trans = get_affine_transform(person_center, s, r, self.input_size)
input = cv2.warpAffine(
img,
trans,
(int(self.input_size[1]), int(self.input_size[0])),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(0, 0, 0))
input = self.transform(input)
meta = {
'center': person_center,
'height': h,
'width': w,
'scale': s,
'rotation': r
}
return input, meta

View File

@@ -0,0 +1,167 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
import torch
class BRG2Tensor_transform(object):
def __call__(self, pic):
img = torch.from_numpy(pic.transpose((2, 0, 1)))
if isinstance(img, torch.ByteTensor):
return img.float()
else:
return img
class BGR2RGB_transform(object):
def __call__(self, tensor):
return tensor[[2,1,0],:,:]
def flip_back(output_flipped, matched_parts):
'''
output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
'''
assert output_flipped.ndim == 4,\
'output_flipped should be [batch_size, num_joints, height, width]'
output_flipped = output_flipped[:, :, :, ::-1]
for pair in matched_parts:
tmp = output_flipped[:, pair[0], :, :].copy()
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
output_flipped[:, pair[1], :, :] = tmp
return output_flipped
def fliplr_joints(joints, joints_vis, width, matched_parts):
"""
flip coords
"""
# Flip horizontal
joints[:, 0] = width - joints[:, 0] - 1
# Change left-right parts
for pair in matched_parts:
joints[pair[0], :], joints[pair[1], :] = \
joints[pair[1], :], joints[pair[0], :].copy()
joints_vis[pair[0], :], joints_vis[pair[1], :] = \
joints_vis[pair[1], :], joints_vis[pair[0], :].copy()
return joints*joints_vis, joints_vis
def transform_preds(coords, center, scale, input_size):
target_coords = np.zeros(coords.shape)
trans = get_affine_transform(center, scale, 0, input_size, inv=1)
for p in range(coords.shape[0]):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
return target_coords
def transform_parsing(pred, center, scale, width, height, input_size):
trans = get_affine_transform(center, scale, 0, input_size, inv=1)
target_pred = cv2.warpAffine(
pred,
trans,
(int(width), int(height)), #(int(width), int(height)),
flags=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(0))
return target_pred
def transform_logits(logits, center, scale, width, height, input_size):
trans = get_affine_transform(center, scale, 0, input_size, inv=1)
channel = logits.shape[2]
target_logits = []
for i in range(channel):
target_logit = cv2.warpAffine(
logits[:,:,i],
trans,
(int(width), int(height)), #(int(width), int(height)),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(0))
target_logits.append(target_logit)
target_logits = np.stack(target_logits,axis=2)
return target_logits
def get_affine_transform(center,
scale,
rot,
output_size,
shift=np.array([0, 0], dtype=np.float32),
inv=0):
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
print(scale)
scale = np.array([scale, scale])
scale_tmp = scale
src_w = scale_tmp[0]
dst_w = output_size[1]
dst_h = output_size[0]
rot_rad = np.pi * rot / 180
src_dir = get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, (dst_w-1) * -0.5], np.float32)
src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
dst[0, :] = [(dst_w-1) * 0.5, (dst_h-1) * 0.5]
dst[1, :] = np.array([(dst_w-1) * 0.5, (dst_h-1) * 0.5]) + dst_dir
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]
def get_3rd_point(a, b):
direct = a - b
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
def get_dir(src_point, rot_rad):
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
src_result = [0, 0]
src_result[0] = src_point[0] * cs - src_point[1] * sn
src_result[1] = src_point[0] * sn + src_point[1] * cs
return src_result
def crop(img, center, scale, output_size, rot=0):
trans = get_affine_transform(center, scale, rot, output_size)
dst_img = cv2.warpAffine(img,
trans,
(int(output_size[1]), int(output_size[0])),
flags=cv2.INTER_LINEAR)
return dst_img
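# Hedged usage note (added, not upstream): get_affine_transform maps a person box (center, scale)
# onto the model input grid, and inv=1 builds the inverse mapping used by transform_logits /
# transform_parsing to warp predictions back to the original image size. A minimal round trip,
# assuming the 473x473 input used elsewhere in this package:
#   center, scale = np.array([320., 240.]), np.array([480., 480.])
#   fwd = get_affine_transform(center, scale, 0, [473, 473])
#   inv = get_affine_transform(center, scale, 0, [473, 473], inv=1)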

View File

@@ -0,0 +1,184 @@
#credit to huchenlei for this module
#from https://github.com/huchenlei/ComfyUI-IC-Light-Native
import torch
import numpy as np
from typing import Tuple, TypedDict, Callable
import comfy.model_management
from comfy.sd import load_unet
from comfy.ldm.models.autoencoder import AutoencoderKL
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from PIL import Image
from nodes import VAEEncode
from ...libs.image import np2tensor, pil2tensor
class UnetParams(TypedDict):
input: torch.Tensor
timestep: torch.Tensor
c: dict
cond_or_uncond: torch.Tensor
class VAEEncodeArgMax(VAEEncode):
def encode(self, vae, pixels):
assert isinstance(
vae.first_stage_model, AutoencoderKL
), "ArgMax only supported for AutoencoderKL"
original_sample_mode = vae.first_stage_model.regularization.sample
vae.first_stage_model.regularization.sample = False
ret = super().encode(vae, pixels)
vae.first_stage_model.regularization.sample = original_sample_mode
return ret
class ICLight:
@staticmethod
def apply_c_concat(params: UnetParams, concat_conds) -> UnetParams:
"""Apply c_concat on unet call."""
sample = params["input"]
params["c"]["c_concat"] = torch.cat(
(
[concat_conds.to(sample.device)]
* (sample.shape[0] // concat_conds.shape[0])
),
dim=0,
)
return params
@staticmethod
def create_custom_conv(
original_conv: torch.nn.Module,
dtype: torch.dtype,
device: torch.device = None,
) -> torch.nn.Module:
with torch.no_grad():
new_conv_in = torch.nn.Conv2d(
8,
original_conv.out_channels,
original_conv.kernel_size,
original_conv.stride,
original_conv.padding,
)
new_conv_in.weight.zero_()
new_conv_in.weight[:, :4, :, :].copy_(original_conv.weight)
new_conv_in.bias = original_conv.bias
return new_conv_in.to(dtype=dtype, device=device)
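# Note (added for clarity, not upstream): the replacement conv takes 8 input channels - the 4
# latent channels plus 4 concatenated conditioning channels - and the extra weights start at zero,
# so until the IC-Light deltas are patched in, the widened conv_in behaves exactly like the
# original on the first 4 channels.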
def generate_lighting_image(self, original_image, direction):
_, image_height, image_width, _ = original_image.shape
if direction == 'Left Light':
gradient = np.linspace(255, 0, image_width)
image = np.tile(gradient, (image_height, 1))
input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
return np2tensor(input_bg)
elif direction == 'Right Light':
gradient = np.linspace(0, 255, image_width)
image = np.tile(gradient, (image_height, 1))
input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
return np2tensor(input_bg)
elif direction == 'Top Light':
gradient = np.linspace(255, 0, image_height)[:, None]
image = np.tile(gradient, (1, image_width))
input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
return np2tensor(input_bg)
elif direction == 'Bottom Light':
gradient = np.linspace(0, 255, image_height)[:, None]
image = np.tile(gradient, (1, image_width))
input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
return np2tensor(input_bg)
elif direction == 'Circle Light':
x = np.linspace(-1, 1, image_width)
y = np.linspace(-1, 1, image_height)
x, y = np.meshgrid(x, y)
r = np.sqrt(x ** 2 + y ** 2)
r = r / r.max()
color1 = np.array([0, 0, 0])[np.newaxis, np.newaxis, :]
color2 = np.array([255, 255, 255])[np.newaxis, np.newaxis, :]
gradient = (color1 * r[..., np.newaxis] + color2 * (1 - r)[..., np.newaxis]).astype(np.uint8)
image = pil2tensor(Image.fromarray(gradient))
return image
else:
image = pil2tensor(Image.new('RGB', (1, 1), (0, 0, 0)))
return image
def generate_source_image(self, original_image, source):
batch_size, image_height, image_width, _ = original_image.shape
if source == 'Use Flipped Background Image':
if batch_size < 2:
raise ValueError('Need at least 2 images to use the flipped background image.')
original_image = [img.unsqueeze(0) for img in original_image]
image = torch.flip(original_image[1], [2])
return image
elif source == 'Ambient':
input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
return np2tensor(input_bg)
elif source == 'Left Light':
gradient = np.linspace(224, 32, image_width)
image = np.tile(gradient, (image_height, 1))
input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
return np2tensor(input_bg)
elif source == 'Right Light':
gradient = np.linspace(32, 224, image_width)
image = np.tile(gradient, (image_height, 1))
input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
return np2tensor(input_bg)
elif source == 'Top Light':
gradient = np.linspace(224, 32, image_height)[:, None]
image = np.tile(gradient, (1, image_width))
input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
return np2tensor(input_bg)
elif source == 'Bottom Light':
gradient = np.linspace(32, 224, image_height)[:, None]
image = np.tile(gradient, (1, image_width))
input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
return np2tensor(input_bg)
else:
image = pil2tensor(Image.new('RGB', (1, 1), (0, 0, 0)))
return image
def apply(self, ic_model_path, model, c_concat: dict, ic_model=None) -> Tuple[ModelPatcher]:
device = comfy.model_management.get_torch_device()
dtype = comfy.model_management.unet_dtype()
work_model = model.clone()
# Apply scale factor.
base_model: BaseModel = work_model.model
scale_factor = base_model.model_config.latent_format.scale_factor
# [B, 4, H, W]
concat_conds: torch.Tensor = c_concat["samples"] * scale_factor
# [1, 4 * B, H, W]
concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
def unet_dummy_apply(unet_apply: Callable, params: UnetParams):
"""A dummy unet apply wrapper serving as the endpoint of wrapper
chain."""
return unet_apply(x=params["input"], t=params["timestep"], **params["c"])
existing_wrapper = work_model.model_options.get(
"model_function_wrapper", unet_dummy_apply
)
def wrapper_func(unet_apply: Callable, params: UnetParams):
return existing_wrapper(unet_apply, params=self.apply_c_concat(params, concat_conds))
work_model.set_model_unet_function_wrapper(wrapper_func)
if not ic_model:
ic_model = load_unet(ic_model_path)
ic_model_state_dict = ic_model.model.diffusion_model.state_dict()
work_model.add_patches(
patches={
("diffusion_model." + key): (
'diff',
[
value.to(dtype=dtype, device=device),
{"pad_weight": key == 'input_blocks.0.0.weight'}
]
)
for key, value in ic_model_state_dict.items()
}
)
return (work_model, ic_model)
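# Hedged usage sketch (added, not upstream; the checkpoint path is hypothetical). `model` is a
# ModelPatcher from a normal checkpoint loader and `c_concat` is the latent dict produced by
# encoding the foreground image, e.g. with VAEEncodeArgMax above:
#   latent = VAEEncodeArgMax().encode(vae, fg_image)[0]
#   work_model, ic_model = ICLight().apply("iclight_sd15_fc.safetensors", model, latent)
# The returned work_model carries both the unet-function wrapper (injecting c_concat on every
# call) and the diff patches built from the IC-Light unet weights.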

View File

@@ -0,0 +1,268 @@
#credit to shakker-labs and instantX for this module
#from https://github.com/Shakker-Labs/ComfyUI-IPAdapter-Flux
import torch
from PIL import Image
import numpy as np
from .attention_processor import IPAFluxAttnProcessor2_0
from .utils import is_model_pathched, FluxUpdateModules
from .sd3.resampler import TimeResampler
from .sd3.joinblock import JointBlockIPWrapper, IPAttnProcessor
image_proj_model = None
class MLPProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens
self.proj = torch.nn.Sequential(
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
torch.nn.GELU(),
torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, id_embeds):
x = self.proj(id_embeds)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
x = self.norm(x)
return x
class InstantXFluxIpadapterApply:
def __init__(self, num_tokens=128):
self.device = None
self.dtype = torch.float16
self.num_tokens = num_tokens
self.ip_ckpt = None
self.clip_vision = None
self.image_encoder = None
self.clip_image_processor = None
# state_dict
self.state_dict = None
self.joint_attention_dim = 4096
self.hidden_size = 3072
def set_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
s = flux_model.model_sampling
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
timestep_range = (percent_to_timestep_function(timestep_percent_range[0]),
percent_to_timestep_function(timestep_percent_range[1]))
ip_attn_procs = {} # 19+38=57
dsb_count = len(flux_model.diffusion_model.double_blocks)
for i in range(dsb_count):
name = f"double_blocks.{i}"
ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
hidden_size=self.hidden_size,
cross_attention_dim=self.joint_attention_dim,
num_tokens=self.num_tokens,
scale=weight,
timestep_range=timestep_range
).to(self.device, dtype=self.dtype)
ssb_count = len(flux_model.diffusion_model.single_blocks)
for i in range(ssb_count):
name = f"single_blocks.{i}"
ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
hidden_size=self.hidden_size,
cross_attention_dim=self.joint_attention_dim,
num_tokens=self.num_tokens,
scale=weight,
timestep_range=timestep_range
).to(self.device, dtype=self.dtype)
return ip_attn_procs
def load_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
global image_proj_model
image_proj_model.load_state_dict(self.state_dict["image_proj"], strict=True)
ip_attn_procs = self.set_ip_adapter(flux_model, weight, timestep_percent_range)
ip_layers = torch.nn.ModuleList(ip_attn_procs.values())
ip_layers.load_state_dict(self.state_dict["ip_adapter"], strict=True)
return ip_attn_procs
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
# outputs = self.clip_vision.encode_image(pil_image)
# clip_image_embeds = outputs['image_embeds']
# clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.dtype)
# image_prompt_embeds = self.image_proj_model(clip_image_embeds)
if pil_image is not None:
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(
clip_image.to(self.device, dtype=self.image_encoder.dtype)).pooler_output
clip_image_embeds = clip_image_embeds.to(dtype=self.dtype)
else:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.dtype)
global image_proj_model
image_prompt_embeds = image_proj_model(clip_image_embeds)
return image_prompt_embeds
def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
self.device = provider.lower()
if "clipvision" in ipadapter:
# self.clip_vision = ipadapter["clipvision"]['model']
self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
if "ipadapter" in ipadapter:
self.ip_ckpt = ipadapter["ipadapter"]['file']
self.state_dict = ipadapter["ipadapter"]['model']
# process image
pil_image = image.numpy()[0] * 255.0
pil_image = Image.fromarray(pil_image.astype(np.uint8))
# initialize ipadapter
global image_proj_model
if image_proj_model is None:
image_proj_model = MLPProjModel(
cross_attention_dim=self.joint_attention_dim, # 4096
id_embeddings_dim=1152,
num_tokens=self.num_tokens,
)
image_proj_model.to(self.device, dtype=self.dtype)
ip_attn_procs = self.load_ip_adapter(model.model, weight, (start_at, end_at))
# process control image
image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=None)
# set model
# is_patched = is_model_pathched(model.model)
bi = model.clone()
FluxUpdateModules(bi, ip_attn_procs, image_prompt_embeds)
return (bi, image)
def patch_sd3(
patcher,
ip_procs,
resampler: TimeResampler,
clip_embeds,
weight=1.0,
start=0.0,
end=1.0,
):
"""
Patches a model_sampler to add the ipadapter
"""
mmdit = patcher.model.diffusion_model
timestep_schedule_max = patcher.model.model_config.sampling_settings.get(
"timesteps", 1000
)
# hook the model's forward function
# so that when it gets called, we can grab the timestep and send it to the resampler
ip_options = {
"hidden_states": None,
"t_emb": None,
"weight": weight,
}
def ddit_wrapper(forward, args):
# this is between 0 and 1, so the adapters can calculate start_point and end_point
# actually, do we need to get the sigma value instead?
t_percent = 1 - args["timestep"].flatten()[0].cpu().item()
if start <= t_percent <= end:
batch_size = args["input"].shape[0] // len(args["cond_or_uncond"])
# if we're only doing cond or only doing uncond, only pass one of them through the resampler
embeds = clip_embeds[args["cond_or_uncond"]]
# slight efficiency optimization todo: pass the embeds through and then afterwards
# repeat to the batch size
embeds = torch.repeat_interleave(embeds, batch_size, dim=0)
# the resampler wants between 0 and MAX_STEPS
timestep = args["timestep"] * timestep_schedule_max
image_emb, t_emb = resampler(embeds, timestep, need_temb=True)
# these will need to be accessible to the IPAdapters
ip_options["hidden_states"] = image_emb
ip_options["t_emb"] = t_emb
else:
ip_options["hidden_states"] = None
ip_options["t_emb"] = None
return forward(args["input"], args["timestep"], **args["c"])
patcher.set_model_unet_function_wrapper(ddit_wrapper)
# patch each dit block
for i, block in enumerate(mmdit.joint_blocks):
wrapper = JointBlockIPWrapper(block, ip_procs[i], ip_options)
patcher.set_model_patch_replace(wrapper, "dit", "double_block", i)
class InstantXSD3IpadapterApply:
def __init__(self):
self.device = None
self.dtype = torch.float16
self.clip_image_processor = None
self.image_encoder = None
self.resampler = None
self.procs = None
@torch.inference_mode()
def encode(self, image):
clip_image = self.clip_image_processor.image_processor(image, return_tensors="pt", do_rescale=False).pixel_values
clip_image_embeds = self.image_encoder(
clip_image.to(self.device, dtype=self.image_encoder.dtype),
output_hidden_states=True,
).hidden_states[-2]
clip_image_embeds = torch.cat(
[clip_image_embeds, torch.zeros_like(clip_image_embeds)], dim=0
)
clip_image_embeds = clip_image_embeds.to(dtype=torch.float16)
return clip_image_embeds
def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
self.device = provider.lower()
if "clipvision" in ipadapter:
self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
if "ipadapter" in ipadapter:
self.ip_ckpt = ipadapter["ipadapter"]['file']
self.state_dict = ipadapter["ipadapter"]['model']
self.resampler = TimeResampler(
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=64,
embedding_dim=1152,
output_dim=2432,
ff_mult=4,
timestep_in_dim=320,
timestep_flip_sin_to_cos=True,
timestep_freq_shift=0,
)
self.resampler.eval()
self.resampler.to(self.device, dtype=self.dtype)
self.resampler.load_state_dict(self.state_dict["image_proj"])
# now we'll create the attention processors
# ip_adapter.keys looks like [0.proj, 0.to_k, ..., 1.proj, 1.to_k, ...]
n_procs = len(
set(x.split(".")[0] for x in self.state_dict["ip_adapter"].keys())
)
self.procs = torch.nn.ModuleList(
[
# this is hardcoded for SD3.5L
IPAttnProcessor(
hidden_size=2432,
cross_attention_dim=2432,
ip_hidden_states_dim=2432,
ip_encoder_hidden_states_dim=2432,
head_dim=64,
timesteps_emb_dim=1280,
).to(self.device, dtype=torch.float16)
for _ in range(n_procs)
]
)
self.procs.load_state_dict(self.state_dict["ip_adapter"])
work_model = model.clone()
embeds = self.encode(image)
patch_sd3(
work_model,
self.procs,
self.resampler,
embeds,
weight,
start_at,
end_at,
)
return (work_model, image)

View File

@@ -0,0 +1,87 @@
import numbers
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class RMSNorm(nn.Module):
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states * self.weight
else:
hidden_states = hidden_states.to(input_dtype)
return hidden_states
class IPAFluxAttnProcessor2_0(nn.Module):
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, timestep_range=None):
super().__init__()
self.hidden_size = hidden_size # 3072
self.cross_attention_dim = cross_attention_dim # 4096
self.scale = scale
self.num_tokens = num_tokens
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)
self.norm_added_v = RMSNorm(128, eps=1e-5, elementwise_affine=False)
self.timestep_range = timestep_range
def __call__(
self,
num_heads,
query,
image_emb: torch.FloatTensor,
t: torch.FloatTensor
) -> torch.FloatTensor:
# only apply IPA if timestep is within range
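# the adapter only runs while timestep_range[1] <= t <= timestep_range[0];
# outside that window __call__ returns None and the calling block skips the IPA branch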
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
return None
# `ip-adapter` projections
ip_hidden_states = image_emb
ip_hidden_states_key_proj = self.to_k_ip(ip_hidden_states)
ip_hidden_states_value_proj = self.to_v_ip(ip_hidden_states)
ip_hidden_states_key_proj = rearrange(ip_hidden_states_key_proj, 'B L (H D) -> B H L D', H=num_heads)
ip_hidden_states_value_proj = rearrange(ip_hidden_states_value_proj, 'B L (H D) -> B H L D', H=num_heads)
ip_hidden_states_key_proj = self.norm_added_k(ip_hidden_states_key_proj)
ip_hidden_states_value_proj = self.norm_added_v(ip_hidden_states_value_proj)
ip_hidden_states = F.scaled_dot_product_attention(query.to(image_emb.device).to(image_emb.dtype),
ip_hidden_states_key_proj,
ip_hidden_states_value_proj,
dropout_p=0.0, is_causal=False)
ip_hidden_states = rearrange(ip_hidden_states, "B H L D -> B L (H D)", H=num_heads)
ip_hidden_states = ip_hidden_states.to(query.dtype).to(query.device)
return self.scale * ip_hidden_states

View File

@@ -0,0 +1,153 @@
import torch
from torch import Tensor, nn
from .math import attention
from ..attention_processor import IPAFluxAttnProcessor2_0
from comfy.ldm.flux.layers import DoubleStreamBlock, SingleStreamBlock
from comfy import model_management as mm
class DoubleStreamBlockIPA(nn.Module):
def __init__(self, original_block: DoubleStreamBlock, ip_adapter, image_emb):
super().__init__()
mlp_hidden_dim = original_block.img_mlp[0].out_features
mlp_ratio = mlp_hidden_dim / original_block.hidden_size
mlp_hidden_dim = int(original_block.hidden_size * mlp_ratio)
self.num_heads = original_block.num_heads
self.hidden_size = original_block.hidden_size
self.img_mod = original_block.img_mod
self.img_norm1 = original_block.img_norm1
self.img_attn = original_block.img_attn
self.img_norm2 = original_block.img_norm2
self.img_mlp = original_block.img_mlp
self.txt_mod = original_block.txt_mod
self.txt_norm1 = original_block.txt_norm1
self.txt_attn = original_block.txt_attn
self.txt_norm2 = original_block.txt_norm2
self.txt_mlp = original_block.txt_mlp
self.flipped_img_txt = original_block.flipped_img_txt
self.ip_adapter = ip_adapter
self.image_emb = image_emb
self.device = mm.get_torch_device()
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, t: Tensor, attn_mask=None):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3,
1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3,
1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
if self.flipped_img_txt:
# run actual attention
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
for adapter, image in zip(self.ip_adapter, self.image_emb):
# this does a separate attention for each adapter
ip_hidden_states = adapter(self.num_heads, img_q, image, t)
if ip_hidden_states is not None:
ip_hidden_states = ip_hidden_states.to(self.device)
img_attn = img_attn + ip_hidden_states
# calculate the img blocks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt blocks
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
class SingleStreamBlockIPA(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and an adapted modulation interface.
"""
def __init__(self, original_block: SingleStreamBlock, ip_adapter, image_emb):
super().__init__()
self.hidden_dim = original_block.hidden_size
self.num_heads = original_block.num_heads
self.scale = original_block.scale
self.mlp_hidden_dim = original_block.mlp_hidden_dim
# qkv and mlp_in
self.linear1 = original_block.linear1
# proj and mlp_out
self.linear2 = original_block.linear2
self.norm = original_block.norm
self.hidden_size = original_block.hidden_size
self.pre_norm = original_block.pre_norm
self.mlp_act = original_block.mlp_act
self.modulation = original_block.modulation
self.ip_adapter = ip_adapter
self.image_emb = image_emb
self.device = mm.get_torch_device()
def add_adapter(self, ip_adapter: IPAFluxAttnProcessor2_0, image_emb):
self.ip_adapter.append(ip_adapter)
self.image_emb.append(image_emb)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, t: Tensor, attn_mask=None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask)
for adapter, image in zip(self.ip_adapter, self.image_emb):
# this does a separate attention for each adapter
# maybe we want a single joint attention call for all adapters?
ip_hidden_states = adapter(self.num_heads, q, image, t)
if ip_hidden_states is not None:
ip_hidden_states = ip_hidden_states.to(self.device)
attn = attn + ip_hidden_states
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x

View File

@@ -0,0 +1,35 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
device = torch.device("cpu")
else:
device = pos.device
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)
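# rope() packs [cos, -sin, sin, cos] into per-position 2x2 rotation matrices;
# apply_rope() below then rotates every (even, odd) channel pair of q and k by
# that matrix, i.e. a standard rotary positional embedding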
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@@ -0,0 +1,219 @@
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import (RMSNorm, JointBlock,)
class AdaLayerNorm(nn.Module):
"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
time_embedding_dim (`int`, *optional*): Size of the timestep embedding; defaults to `embedding_dim`.
mode (`str`): `"normal"` returns only the modulated input; `"zero"` also returns gate/shift/scale terms (adaLN-Zero).
"""
def __init__(self, embedding_dim: int, time_embedding_dim=None, mode="normal"):
super().__init__()
self.silu = nn.SiLU()
num_params_dict = dict(
zero=6,
normal=2,
)
num_params = num_params_dict[mode]
self.linear = nn.Linear(
time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True
)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
self.mode = mode
def forward(
self,
x,
hidden_dtype=None,
emb=None,
):
emb = self.linear(self.silu(emb))
if self.mode == "normal":
shift_msa, scale_msa = emb.chunk(2, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x
elif self.mode == "zero":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(
6, dim=1
)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class IPAttnProcessor(nn.Module):
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
ip_hidden_states_dim=None,
ip_encoder_hidden_states_dim=None,
head_dim=None,
timesteps_emb_dim=1280,
):
super().__init__()
self.norm_ip = AdaLayerNorm(
ip_hidden_states_dim, time_embedding_dim=timesteps_emb_dim
)
self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
self.norm_q = RMSNorm(head_dim, 1e-6)
self.norm_k = RMSNorm(head_dim, 1e-6)
self.norm_ip_k = RMSNorm(head_dim, 1e-6)
def forward(
self,
ip_hidden_states,
img_query,
img_key=None,
img_value=None,
t_emb=None,
n_heads=1,
):
if ip_hidden_states is None:
return None
if not hasattr(self, "to_k_ip") or not hasattr(self, "to_v_ip"):
return None
# norm ip input
norm_ip_hidden_states = self.norm_ip(ip_hidden_states, emb=t_emb)
# to k and v
ip_key = self.to_k_ip(norm_ip_hidden_states)
ip_value = self.to_v_ip(norm_ip_hidden_states)
# reshape
img_query = rearrange(img_query, "b l (h d) -> b h l d", h=n_heads)
img_key = rearrange(img_key, "b l (h d) -> b h l d", h=n_heads)
# note that the image is in a different shape: b l h d
# so we transpose to b h l d
# or do we have to transpose here?
img_value = torch.transpose(img_value, 1, 2)
ip_key = rearrange(ip_key, "b l (h d) -> b h l d", h=n_heads)
ip_value = rearrange(ip_value, "b l (h d) -> b h l d", h=n_heads)
# norm
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
ip_key = self.norm_ip_k(ip_key)
# cat img
key = torch.cat([img_key, ip_key], dim=2)
value = torch.cat([img_value, ip_value], dim=2)
#
ip_hidden_states = F.scaled_dot_product_attention(
img_query, key, value, dropout_p=0.0, is_causal=False
)
ip_hidden_states = rearrange(ip_hidden_states, "b h l d -> b l (h d)")
ip_hidden_states = ip_hidden_states.to(img_query.dtype)
return ip_hidden_states
class JointBlockIPWrapper:
"""To be used as a patch_replace with Comfy"""
def __init__(
self,
original_block: JointBlock,
adapter: IPAttnProcessor,
ip_options=None,
):
self.original_block = original_block
self.adapter = adapter
if ip_options is None:
ip_options = {}
self.ip_options = ip_options
def block_mixing(self, context, x, context_block, x_block, c):
"""
Comes from mmdit.py. Modified to add ipadapter attention.
"""
context_qkv, context_intermediates = context_block.pre_attention(context, c)
if x_block.x_block_self_attn:
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
else:
x_qkv, x_intermediates = x_block.pre_attention(x, c)
qkv = tuple(torch.cat((context_qkv[j], x_qkv[j]), dim=1) for j in range(3))
attn = optimized_attention(
qkv[0],
qkv[1],
qkv[2],
heads=x_block.attn.num_heads,
)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
attn[:, context_qkv[0].shape[1] :],
)
# if the current timestep is not in the ipadapter enabling range, then the resampler wasn't run
# and the hidden states will be None
if (
self.ip_options["hidden_states"] is not None
and self.ip_options["t_emb"] is not None
):
# IP-Adapter
ip_attn = self.adapter(
self.ip_options["hidden_states"],
*x_qkv,
self.ip_options["t_emb"],
x_block.attn.num_heads,
)
x_attn = x_attn + ip_attn * self.ip_options["weight"]
# Everything else is unchanged
if not context_block.pre_only:
context = context_block.post_attention(context_attn, *context_intermediates)
else:
context = None
if x_block.x_block_self_attn:
attn2 = optimized_attention(
x_qkv2[0],
x_qkv2[1],
x_qkv2[2],
heads=x_block.attn2.num_heads,
)
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
else:
x = x_block.post_attention(x_attn, *x_intermediates)
return context, x
def __call__(self, args, _):
# Code from mmdit.py:
# in this case, we're blocks_replace[("double_block", i)]
# note that although we're passed the original block,
# we can't actually get it from inside its wrapper
# (which would simplify the whole code...)
# ```
# def block_wrap(args):
# out = {}
# out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
# return out
# out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
# context = out["txt"]
# x = out["img"]
# ```
c, x = self.block_mixing(
args["txt"],
args["img"],
self.original_block.context_block,
self.original_block.x_block,
c=args["vec"],
)
return {"txt": c, "img": x}

View File

@@ -0,0 +1,385 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math
import torch
import torch.nn as nn
from typing import Optional
ACTIVATION_FUNCTIONS = {
"swish": nn.SiLU(),
"silu": nn.SiLU(),
"mish": nn.Mish(),
"gelu": nn.GELU(),
"relu": nn.ReLU(),
}
def get_activation(act_fn: str) -> nn.Module:
"""Helper function to get activation function from string.
Args:
act_fn (str): Name of activation function.
Returns:
nn.Module: Activation function.
"""
act_fn = act_fn.lower()
if act_fn in ACTIVATION_FUNCTIONS:
return ACTIVATION_FUNCTIONS[act_fn]
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
Args
timesteps (torch.Tensor):
a 1-D Tensor of N indices, one per batch element. These may be fractional.
embedding_dim (int):
the dimension of the output.
flip_sin_to_cos (bool):
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
downscale_freq_shift (float):
Controls the delta between frequencies between dimensions
scale (float):
Scaling factor applied to the embeddings.
max_period (int):
Controls the maximum frequency of the embeddings
Returns
torch.Tensor: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
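# worked example with the defaults (flip_sin_to_cos=False, downscale_freq_shift=1,
# scale=1): embedding_dim=4 yields frequencies 1 and 1e-4, so a timestep t maps to
# [sin(t), sin(t * 1e-4), cos(t), cos(t * 1e-4)]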
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
return t_emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = get_activation(act_fn)
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
if post_act_fn is None:
self.post_act = None
else:
self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents, shift=None, scale=None):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
if shift is not None and scale is not None:
latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(
-2, -1
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
*args,
**kwargs,
):
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
class TimeResampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
timestep_in_dim=320,
timestep_flip_sin_to_cos=True,
timestep_freq_shift=0,
):
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
# msa
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
# ff
FeedForward(dim=dim, mult=ff_mult),
# adaLN
nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True)),
]
)
)
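# each layer also owns an adaLN projection (dim -> 4*dim); forward() chunks its
# output into shift/scale pairs for the attention and feed-forward sub-blocks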
# time
self.time_proj = Timesteps(
timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift
)
self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu")
# adaLN
# self.adaLN_modulation = nn.Sequential(
# nn.SiLU(),
# nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True)
# )
def forward(self, x, timestep, need_temb=False):
timestep_emb = self.embedding_time(x, timestep) # bs, dim
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
x = x + timestep_emb[:, None]
for attn, ff, adaLN_modulation in self.layers:
shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(
timestep_emb
).chunk(4, dim=1)
latents = attn(x, latents, shift_msa, scale_msa) + latents
res = latents
for idx_ff in range(len(ff)):
layer_ff = ff[idx_ff]
latents = layer_ff(latents)
if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm): # adaLN
latents = latents * (
1 + scale_mlp.unsqueeze(1)
) + shift_mlp.unsqueeze(1)
latents = latents + res
# latents = ff(latents) + latents
latents = self.proj_out(latents)
latents = self.norm_out(latents)
if need_temb:
return latents, timestep_emb
else:
return latents
def embedding_time(self, sample, timestep):
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, None)
return emb

View File

@@ -0,0 +1,136 @@
import torch
from torch import Tensor
from .flux.layers import DoubleStreamBlockIPA, SingleStreamBlockIPA
from comfy.ldm.flux.layers import timestep_embedding
from types import MethodType
def FluxUpdateModules(bi, ip_attn_procs, image_emb):
flux_model = bi.model
bi.add_object_patch(f"diffusion_model.forward_orig", MethodType(forward_orig_ipa, flux_model.diffusion_model))
for i, original in enumerate(flux_model.diffusion_model.double_blocks):
patch_name = f"double_blocks.{i}"
maybe_patched_layer = bi.get_model_object(f"diffusion_model.{patch_name}")
# if there's already a patch there, collect its adapters and replace it
procs = [ip_attn_procs[patch_name]]
embs = [image_emb]
if isinstance(maybe_patched_layer, DoubleStreamBlockIPA):
procs = maybe_patched_layer.ip_adapter + procs
embs = maybe_patched_layer.image_emb + embs
# initialize ipa models with image embeddings
new_layer = DoubleStreamBlockIPA(original, procs, embs)
# for example, ComfyUI internally uses model.add_patches to add loras
bi.add_object_patch(f"diffusion_model.{patch_name}", new_layer)
for i, original in enumerate(flux_model.diffusion_model.single_blocks):
patch_name = f"single_blocks.{i}"
maybe_patched_layer = bi.get_model_object(f"diffusion_model.{patch_name}")
procs = [ip_attn_procs[patch_name]]
embs = [image_emb]
if isinstance(maybe_patched_layer, SingleStreamBlockIPA):
procs = maybe_patched_layer.ip_adapter + procs
embs = maybe_patched_layer.image_emb + embs
# initialize ipa models with image embeddings
new_layer = SingleStreamBlockIPA(original, procs, embs)
bi.add_object_patch(f"diffusion_model.{patch_name}", new_layer)
def is_model_pathched(model):
def test(mod):
if isinstance(mod, DoubleStreamBlockIPA):
return True
else:
for p in mod.children():
if test(p):
return True
return False
result = test(model)
return result
def forward_orig_ipa(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor|None = None,
control=None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
if isinstance(block, DoubleStreamBlockIPA): # ipadapter
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], t=args["timesteps"], attn_mask=args.get("attn_mask"))
else:
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "timesteps": timesteps, "attn_mask": attn_mask}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
if isinstance(block, DoubleStreamBlockIPA): # ipadapter
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, t=timesteps, attn_mask=attn_mask)
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
if isinstance(block, SingleStreamBlockIPA): # ipadapter
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], t=args["timesteps"], attn_mask=args.get("attn_mask"))
else:
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "timesteps": timesteps, "attn_mask": attn_mask}, {"original_block": block_wrap})
img = out["img"]
else:
if isinstance(block, SingleStreamBlockIPA): # ipadapter
img = block(img, vec=vec, pe=pe, t=timesteps, attn_mask=attn_mask)
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img

View File

@@ -0,0 +1,42 @@
{
"_name_or_path": "THUDM/chatglm3-6b-base",
"model_type": "chatglm",
"architectures": [
"ChatGLMModel"
],
"auto_map": {
"AutoConfig": "configuration_chatglm.ChatGLMConfig",
"AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
"AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
"AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
"AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
},
"add_bias_linear": false,
"add_qkv_bias": true,
"apply_query_key_layer_scaling": true,
"apply_residual_connection_post_layernorm": false,
"attention_dropout": 0.0,
"attention_softmax_in_fp32": true,
"bias_dropout_fusion": true,
"ffn_hidden_size": 13696,
"fp32_residual_connection": false,
"hidden_dropout": 0.0,
"hidden_size": 4096,
"kv_channels": 128,
"layernorm_epsilon": 1e-05,
"multi_query_attention": true,
"multi_query_group_num": 2,
"num_attention_heads": 32,
"num_layers": 28,
"original_rope": true,
"padded_vocab_size": 65024,
"post_layer_norm": true,
"rmsnorm": true,
"seq_length": 32768,
"use_cache": true,
"torch_dtype": "float16",
"transformers_version": "4.30.2",
"tie_word_embeddings": false,
"eos_token_id": 2,
"pad_token_id": 0
}

View File

@@ -0,0 +1,60 @@
from transformers import PretrainedConfig
class ChatGLMConfig(PretrainedConfig):
model_type = "chatglm"
def __init__(
self,
num_layers=28,
padded_vocab_size=65024,
hidden_size=4096,
ffn_hidden_size=13696,
kv_channels=128,
num_attention_heads=32,
seq_length=2048,
hidden_dropout=0.0,
classifier_dropout=None,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
rmsnorm=True,
apply_residual_connection_post_layernorm=False,
post_layer_norm=True,
add_bias_linear=False,
add_qkv_bias=False,
bias_dropout_fusion=True,
multi_query_attention=False,
multi_query_group_num=1,
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=True,
fp32_residual_connection=False,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs
):
self.num_layers = num_layers
self.vocab_size = padded_vocab_size
self.padded_vocab_size = padded_vocab_size
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
self.hidden_dropout = hidden_dropout
self.classifier_dropout = classifier_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon
self.rmsnorm = rmsnorm
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.post_layer_norm = post_layer_norm
self.add_bias_linear = add_bias_linear
self.add_qkv_bias = add_qkv_bias
self.bias_dropout_fusion = bias_dropout_fusion
self.multi_query_attention = multi_query_attention
self.multi_query_group_num = multi_query_group_num
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
self.quantization_bit = quantization_bit
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__(**kwargs)

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,300 @@
import json
import os
import re
from typing import List, Optional, Union, Dict
from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizer
from transformers.utils import logging, PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
class SPTokenizer:
def __init__(self, model_path: str):
# reload tokenizer
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
# BOS / EOS token IDs
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.unk_id()
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
self.special_tokens = {}
self.index_special_tokens = {}
for token in special_tokens:
self.special_tokens[token] = self.n_words
self.index_special_tokens[self.n_words] = token
self.n_words += 1
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
def tokenize(self, s: str, encode_special_tokens=False):
if encode_special_tokens:
last_index = 0
t = []
for match in re.finditer(self.role_special_token_expression, s):
if last_index < match.start():
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
t.append(s[match.start():match.end()])
last_index = match.end()
if last_index < len(s):
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
return t
else:
return self.sp_model.EncodeAsPieces(s)
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
assert type(s) is str
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int]) -> str:
text, buffer = "", []
for token in t:
if token in self.index_special_tokens:
if buffer:
text += self.sp_model.decode(buffer)
buffer = []
text += self.index_special_tokens[token]
else:
buffer.append(token)
if buffer:
text += self.sp_model.decode(buffer)
return text
def decode_tokens(self, tokens: List[str]) -> str:
text = self.sp_model.DecodePieces(tokens)
return text
def convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
if token in self.special_tokens:
return self.special_tokens[token]
return self.sp_model.PieceToId(token)
def convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.index_special_tokens:
return self.index_special_tokens[index]
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
return ""
return self.sp_model.IdToPiece(index)
class ChatGLMTokenizer(PreTrainedTokenizer):
vocab_files_names = {"vocab_file": "tokenizer.model"}
model_input_names = ["input_ids", "attention_mask", "position_ids"]
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
**kwargs):
self.name = "GLMTokenizer"
self.vocab_file = vocab_file
self.tokenizer = SPTokenizer(vocab_file)
self.special_tokens = {
"<bos>": self.tokenizer.bos_id,
"<eos>": self.tokenizer.eos_id,
"<pad>": self.tokenizer.pad_id
}
self.encode_special_tokens = encode_special_tokens
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
encode_special_tokens=encode_special_tokens,
**kwargs)
def get_command(self, token):
if token in self.special_tokens:
return self.special_tokens[token]
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
return self.tokenizer.special_tokens[token]
@property
def unk_token(self) -> str:
return "<unk>"
@property
def pad_token(self) -> str:
return "<unk>"
@property
def pad_token_id(self):
return self.get_command("<pad>")
@property
def eos_token(self) -> str:
return "</s>"
@property
def eos_token_id(self):
return self.get_command("<eos>")
@property
def vocab_size(self):
return self.tokenizer.n_words
def get_vocab(self):
""" Returns vocab as a dict """
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text, **kwargs):
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.tokenizer.convert_token_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.tokenizer.convert_id_to_token(index)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.tokenizer.decode_tokens(tokens)
def save_vocabulary(self, save_directory, filename_prefix=None):
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
filename_prefix (`str`, *optional*):
An optional prefix to add to the names of the saved files.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, self.vocab_files_names["vocab_file"]
)
else:
vocab_file = save_directory
with open(self.vocab_file, 'rb') as fin:
proto_str = fin.read()
with open(vocab_file, "wb") as writer:
writer.write(proto_str)
return (vocab_file,)
def get_prefix_tokens(self):
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
return prefix_tokens
def build_single_message(self, role, metadata, message):
assert role in ["system", "user", "assistant", "observation"], role
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
message_tokens = self.tokenizer.encode(message)
tokens = role_tokens + message_tokens
return tokens
def build_chat_input(self, query, history=None, role="user"):
if history is None:
history = []
input_ids = []
for item in history:
content = item["content"]
if item["role"] == "system" and "tools" in item:
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
input_ids.extend(self.build_single_message(role, "", query))
input_ids.extend([self.get_command("<|assistant|>")])
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences by concatenating and
adding special tokens. A ChatGLM sequence has the following format:
- single sequence: `[gMASK] sop X`
- pair of sequences: `[gMASK] sop A B <eos>`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
prefix_tokens = self.get_prefix_tokens()
token_ids_0 = prefix_tokens + token_ids_0
if token_ids_1 is not None:
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
return token_ids_0
def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
**kwargs
) -> dict:
"""
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
Args:
encoded_inputs:
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
max_length: maximum length of the returned list and optionally padding length (see below).
Will truncate by taking into account the special tokens.
padding_strategy: PaddingStrategy to use for padding.
- PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
`>= 7.5` (Volta).
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
# Load from model defaults
assert self.padding_side == "left"
required_input = encoded_inputs[self.model_input_names[0]]
seq_length = len(required_input)
if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(required_input)
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
# Initialize attention mask if not present.
if "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * seq_length
if "position_ids" not in encoded_inputs:
encoded_inputs["position_ids"] = list(range(seq_length))
if needs_to_be_padded:
difference = max_length - len(required_input)
if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "position_ids" in encoded_inputs:
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
return encoded_inputs

Some files were not shown because too many files have changed in this diff