Make old scaled fp8 format use the new mixed quant ops system. (#11000)
This commit is contained in:
@@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
|
||||
class FluxClipModel_(FluxClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
return FluxClipModel_
|
||||
|
||||
@@ -159,15 +159,13 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
|
||||
out = out.reshape(out.shape[0], out.shape[1], -1)
|
||||
return out, pooled, extra
|
||||
|
||||
def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
|
||||
def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
|
||||
class Flux2TEModel_(Flux2TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if pruned:
|
||||
model_options = model_options.copy()
|
||||
|
||||
Reference in New Issue
Block a user