[Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958)
* Add API of bypass forward module * bypass implementation * add bypass fwd into nodes list/trainer
This commit is contained in:
@@ -3,13 +3,18 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
WeightAdapterTrainBase,
|
||||
weight_decompose,
|
||||
factorization,
|
||||
)
|
||||
|
||||
|
||||
class OFTDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
# Unpack weights tuple from LoHaAdapter
|
||||
# Unpack weights tuple from OFTAdapter
|
||||
blocks, rescale, alpha, _ = weights
|
||||
|
||||
# Create trainable parameters
|
||||
@@ -52,6 +57,78 @@ class OFTDiff(WeightAdapterTrainBase):
|
||||
weight = self.rescale * weight
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def _get_orthogonal_matrix(self, device, dtype):
|
||||
"""Compute the orthogonal rotation matrix R from OFT blocks."""
|
||||
blocks = self.oft_blocks.to(device=device, dtype=dtype)
|
||||
I = torch.eye(self.block_size, device=device, dtype=dtype)
|
||||
|
||||
# Q = blocks - blocks^T (skew-symmetric)
|
||||
q = blocks - blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
|
||||
# Apply constraint if set
|
||||
if self.constraint:
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > self.constraint:
|
||||
normed_q = q * self.constraint / q_norm
|
||||
|
||||
# Cayley transform: R = (I + Q)(I - Q)^-1
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
return r.to(dtype)
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
OFT has no additive component - returns zeros matching base_out shape.
|
||||
|
||||
OFT only transforms the output via g(), it doesn't add to it.
|
||||
"""
|
||||
return torch.zeros_like(base_out)
|
||||
|
||||
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Output transformation for OFT: applies orthogonal rotation.
|
||||
|
||||
OFT transforms output channels using block-diagonal orthogonal matrices.
|
||||
"""
|
||||
r = self._get_orthogonal_matrix(y.device, y.dtype)
|
||||
|
||||
# Apply multiplier to interpolate between identity and full transform
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
I = torch.eye(self.block_size, device=y.device, dtype=y.dtype)
|
||||
r = r * multiplier + (1 - multiplier) * I
|
||||
|
||||
# Use module info from bypass injection
|
||||
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
||||
|
||||
if is_conv:
|
||||
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
||||
y = y.transpose(1, -1)
|
||||
|
||||
# y now has channels in last dim
|
||||
*batch_shape, out_features = y.shape
|
||||
|
||||
# Reshape to apply block-diagonal transform
|
||||
# (*, out_features) -> (*, block_num, block_size)
|
||||
y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size)
|
||||
|
||||
# Apply orthogonal transform: R @ y for each block
|
||||
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
|
||||
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
|
||||
|
||||
# Reshape back: (*, block_num, block_size) -> (*, out_features)
|
||||
out = out_blocked.reshape(*batch_shape, out_features)
|
||||
|
||||
# Apply rescale if present
|
||||
if self.rescaled:
|
||||
rescale = self.rescale.to(device=y.device, dtype=y.dtype)
|
||||
out = out * rescale.view(-1)
|
||||
|
||||
if is_conv:
|
||||
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
||||
out = out.transpose(1, -1)
|
||||
|
||||
return out
|
||||
|
||||
def passive_memory_usage(self):
|
||||
"""Calculates memory usage of the trainable parameters."""
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
@@ -68,10 +145,10 @@ class OFTAdapter(WeightAdapterBase):
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
block_size, block_num = factorization(out_dim, rank)
|
||||
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
|
||||
return OFTDiff(
|
||||
(block, None, alpha, None)
|
||||
block = torch.zeros(
|
||||
block_num, block_size, block_size, device=weight.device, dtype=torch.float32
|
||||
)
|
||||
return OFTDiff((block, None, alpha, None))
|
||||
|
||||
def to_train(self):
|
||||
return OFTDiff(self.weights)
|
||||
@@ -127,9 +204,13 @@ class OFTAdapter(WeightAdapterBase):
|
||||
alpha = 0
|
||||
dora_scale = v[3]
|
||||
|
||||
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||
blocks = comfy.model_management.cast_to_device(
|
||||
blocks, weight.device, intermediate_dtype
|
||||
)
|
||||
if rescale is not None:
|
||||
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
||||
rescale = comfy.model_management.cast_to_device(
|
||||
rescale, weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
block_num, block_size, *_ = blocks.shape
|
||||
|
||||
@@ -139,23 +220,108 @@ class OFTAdapter(WeightAdapterBase):
|
||||
# for Q = -Q^T
|
||||
q = blocks - blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
if alpha > 0: # alpha in oft/boft is for constraint
|
||||
if alpha > 0: # alpha in oft/boft is for constraint
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > alpha:
|
||||
normed_q = q * alpha / q_norm
|
||||
# use float() to prevent unsupported type in .inverse()
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
r = r.to(weight)
|
||||
# Create I in weight's dtype for the einsum
|
||||
I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype)
|
||||
_, *shape = weight.shape
|
||||
lora_diff = torch.einsum(
|
||||
"k n m, k n ... -> k m ...",
|
||||
(r * strength) - strength * I,
|
||||
(r * strength) - strength * I_w,
|
||||
weight.view(block_num, block_size, *shape),
|
||||
).view(-1, *shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function((strength * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def _get_orthogonal_matrix(self, device, dtype):
|
||||
"""Compute the orthogonal rotation matrix R from OFT blocks."""
|
||||
v = self.weights
|
||||
blocks = v[0].to(device=device, dtype=dtype)
|
||||
alpha = v[2]
|
||||
if alpha is None:
|
||||
alpha = 0
|
||||
|
||||
block_num, block_size, _ = blocks.shape
|
||||
I = torch.eye(block_size, device=device, dtype=dtype)
|
||||
|
||||
# Q = blocks - blocks^T (skew-symmetric)
|
||||
q = blocks - blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
|
||||
# Apply constraint if alpha > 0
|
||||
if alpha > 0:
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > alpha:
|
||||
normed_q = q * alpha / q_norm
|
||||
|
||||
# Cayley transform: R = (I + Q)(I - Q)^-1
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
return r, block_num, block_size
|
||||
|
||||
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Output transformation for OFT: applies orthogonal rotation to output.
|
||||
|
||||
OFT transforms the output channels using block-diagonal orthogonal matrices.
|
||||
|
||||
Reference: LyCORIS DiagOFTModule._bypass_forward
|
||||
"""
|
||||
v = self.weights
|
||||
rescale = v[1]
|
||||
|
||||
r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype)
|
||||
|
||||
# Apply multiplier to interpolate between identity and full transform
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
I = torch.eye(block_size, device=y.device, dtype=y.dtype)
|
||||
r = r * multiplier + (1 - multiplier) * I
|
||||
|
||||
# Use module info from bypass injection to determine conv vs linear
|
||||
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
||||
|
||||
if is_conv:
|
||||
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
||||
y = y.transpose(1, -1)
|
||||
|
||||
# y now has channels in last dim
|
||||
*batch_shape, out_features = y.shape
|
||||
|
||||
# Reshape to apply block-diagonal transform
|
||||
# (*, out_features) -> (*, block_num, block_size)
|
||||
y_blocked = y.view(*batch_shape, block_num, block_size)
|
||||
|
||||
# Apply orthogonal transform: R @ y for each block
|
||||
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
|
||||
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
|
||||
|
||||
# Reshape back: (*, block_num, block_size) -> (*, out_features)
|
||||
out = out_blocked.view(*batch_shape, out_features)
|
||||
|
||||
# Apply rescale if present
|
||||
if rescale is not None:
|
||||
rescale = rescale.to(device=y.device, dtype=y.dtype)
|
||||
out = out * rescale.view(-1)
|
||||
|
||||
if is_conv:
|
||||
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
||||
out = out.transpose(1, -1)
|
||||
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user