dequantization offload accounting (fixes Flux2 OOMs - incl TEs) (#11171)

* make setattr safe for non existent attributes

Handle the case where the attribute doesnt exist by returning a static
sentinel (distinct from None). If the sentinel is passed in as the set
value, del the attr.

* Account for dequantization and type-casts in offload costs

When measuring the cost of offload, identify weights that need a type
change or dequantization and add the size of the conversion result
to the offload cost.

This is mutually exclusive with lowvram patches which already has
a large conservative estimate and wont overlap the dequant cost so\
dont double count.

* Set the compute type on CLIP MPs

So that the loader can know the size of weights for dequant accounting.
This commit is contained in:
rattus
2025-12-09 14:21:31 +10:00
committed by GitHub
parent d50f342c90
commit e136b6dbb0
3 changed files with 22 additions and 8 deletions

View File

@@ -35,6 +35,7 @@ import comfy.model_management
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@@ -665,12 +666,18 @@ class ModelPatcher:
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if weight_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
if bias_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
def check_module_offload_mem(key):
if key in self.patches:
return low_vram_patch_estimate_vram(self.model, key)
model_dtype = getattr(self.model, "manual_cast_dtype", None)
weight, _, _ = get_key_weight(self.model, key)
if model_dtype is None or weight is None:
return 0
if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
return weight.numel() * model_dtype.itemsize
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
loading.append((module_offload_mem, module_mem, n, m, params))
return loading