Fix chroma fp8 te being treated as fp16. (#11795)
This commit is contained in:
@@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None):
|
|||||||
if t5_quantization_metadata is not None:
|
if t5_quantization_metadata is not None:
|
||||||
model_options = model_options.copy()
|
model_options = model_options.copy()
|
||||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||||
if dtype is None:
|
if dtype_t5 is not None:
|
||||||
dtype = dtype_t5
|
dtype = dtype_t5
|
||||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
return CosmosTEModel_
|
return CosmosTEModel_
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
|
|||||||
if t5_quantization_metadata is not None:
|
if t5_quantization_metadata is not None:
|
||||||
model_options = model_options.copy()
|
model_options = model_options.copy()
|
||||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||||
if dtype is None:
|
if dtype_t5 is not None:
|
||||||
dtype = dtype_t5
|
dtype = dtype_t5
|
||||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
return MochiTEModel_
|
return MochiTEModel_
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
|
|||||||
if t5_quantization_metadata is not None:
|
if t5_quantization_metadata is not None:
|
||||||
model_options = model_options.copy()
|
model_options = model_options.copy()
|
||||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||||
if dtype is None:
|
if dtype_t5 is not None:
|
||||||
dtype = dtype_t5
|
dtype = dtype_t5
|
||||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
return PixArtTEModel_
|
return PixArtTEModel_
|
||||||
|
|||||||
Reference in New Issue
Block a user