[Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958)

* Add API of bypass forward module

* bypass implementation

* add bypass fwd into nodes list/trainer
This commit is contained in:
Kohaku-Blueleaf
2026-01-25 11:56:22 +08:00
committed by GitHub
parent 635406e283
commit a97c98068f
12 changed files with 2039 additions and 101 deletions

View File

@@ -2,6 +2,7 @@ import logging
from typing import Optional
import torch
import torch.nn.functional as F
import comfy.model_management
from .base import (
WeightAdapterBase,
@@ -20,11 +21,7 @@ class LoraDiff(WeightAdapterTrainBase):
rank, in_dim = mat2.shape[0], mat2.shape[1]
if mid is not None:
convdim = mid.ndim - 2
layer = (
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d
)[convdim]
layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim]
else:
layer = torch.nn.Linear
self.lora_up = layer(rank, out_dim, bias=False)
@@ -51,6 +48,78 @@ class LoraDiff(WeightAdapterTrainBase):
weight = w + scale * diff.reshape(w.shape)
return weight.to(org_dtype)
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoRA training: h(x) = up(down(x)) * scale
Simple implementation using the nn.Module weights directly.
No mid/dora/reshape branches (create_train doesn't create them).
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
"""
# Compute scale = alpha / rank * multiplier
scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0)
# Get module info from bypass injection
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
# Get weights (keep in original dtype for numerical stability)
down_weight = self.lora_down.weight
up_weight = self.lora_up.weight
if is_conv:
# Conv path: use functional conv
# conv_dim: 1=conv1d, 2=conv2d, 3=conv3d
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
# Reshape 2D weights to conv format if needed
# down: [rank, in_features] -> [rank, in_channels, *kernel_size]
# up: [out_features, rank] -> [out_features, rank, 1, 1, ...]
if down_weight.dim() == 2:
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
in_channels = getattr(self, "in_channels", None)
if in_channels is not None:
down_weight = down_weight.view(
down_weight.shape[0], in_channels, *kernel_size
)
else:
# Fallback: assume 1x1 kernel
down_weight = down_weight.view(
*down_weight.shape, *([1] * conv_dim)
)
if up_weight.dim() == 2:
# up always uses 1x1 kernel
up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim))
# down conv uses stride/padding from module, up is 1x1
hidden = conv_fn(x, down_weight, **kw_dict)
# mid layer if exists (tucker decomposition)
if self.lora_mid is not None:
mid_weight = self.lora_mid.weight
if mid_weight.dim() == 2:
mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim))
hidden = conv_fn(hidden, mid_weight)
# up conv is always 1x1 (no stride/padding)
out = conv_fn(hidden, up_weight)
else:
# Linear path: simple matmul chain
hidden = F.linear(x, down_weight)
# mid layer if exists
if self.lora_mid is not None:
mid_weight = self.lora_mid.weight
hidden = F.linear(hidden, mid_weight)
out = F.linear(hidden, up_weight)
return out * scale
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
@@ -70,9 +139,7 @@ class LoRAAdapter(WeightAdapterBase):
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0)
return LoraDiff(
(mat1, mat2, alpha, None, None, None)
)
return LoraDiff((mat1, mat2, alpha, None, None, None))
def to_train(self):
return LoraDiff(self.weights)
@@ -210,3 +277,85 @@ class LoRAAdapter(WeightAdapterBase):
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoRA: h(x) = up(down(x)) * scale
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
Reference: LyCORIS functional/locon.py bypass_forward_diff
"""
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
v = self.weights
# v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape
up = v[0]
down = v[1]
alpha = v[2]
mid = v[3]
# Compute scale = alpha / rank
rank = down.shape[0]
if alpha is not None:
scale = alpha / rank
else:
scale = 1.0
scale = scale * getattr(self, "multiplier", 1.0)
# Cast dtype
up = up.to(dtype=x.dtype)
down = down.to(dtype=x.dtype)
# Use module info from bypass injection, not weight dimension
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
if is_conv:
op = FUNC_LIST[
conv_dim + 2
] # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5)
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
in_channels = getattr(self, "in_channels", None)
# Reshape 2D weights to conv format using kernel_size
# down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size]
# up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel)
if down.dim() == 2:
# down.shape[1] = in_channels * prod(kernel_size)
if in_channels is not None:
down = down.view(down.shape[0], in_channels, *kernel_size)
else:
# Fallback: assume 1x1 kernel if in_channels unknown
down = down.view(*down.shape, *([1] * conv_dim))
if up.dim() == 2:
# up always uses 1x1 kernel
up = up.view(*up.shape, *([1] * conv_dim))
if mid is not None:
mid = mid.to(dtype=x.dtype)
if mid.dim() == 2:
mid = mid.view(*mid.shape, *([1] * conv_dim))
else:
op = F.linear
kw_dict = {} # linear doesn't take stride/padding
# Simple chain: down -> mid (if tucker) -> up
if mid is not None:
if not is_conv:
mid = mid.to(dtype=x.dtype)
hidden = op(x, down)
hidden = op(hidden, mid, **kw_dict)
out = op(hidden, up)
else:
hidden = op(x, down, **kw_dict)
out = op(hidden, up)
return out * scale