Make old scaled fp8 format use the new mixed quant ops system. (#11000)
This commit is contained in:
@@ -2,6 +2,7 @@ import unittest
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
+import json
|
||||
|
||||
# Add comfy to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
@@ -15,6 +16,7 @@ if not has_gpu():
|
||||
|
||||
from comfy import ops
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
+import comfy.utils
|
||||
|
||||
|
||||
class SimpleModel(torch.nn.Module):
|
||||
@@ -94,8 +96,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
|
||||
}
|
||||
|
||||
+state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||
# Create model and load state dict (strict=False because custom loading pops keys)
|
||||
-model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
+model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Verify weights are wrapped in QuantizedTensor
|
||||
@@ -115,7 +118,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
# Forward pass
|
||||
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||
-output = model(input_tensor)
|
||||
+with torch.inference_mode():
|
||||
+    output = model(input_tensor)
|
||||
|
||||
self.assertEqual(output.shape, (5, 40))
|
||||
|
||||
@@ -141,7 +145,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
-model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
+state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||
+model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||
model.load_state_dict(state_dict1, strict=False)
|
||||
|
||||
# Save state dict
|
||||
@@ -178,7 +183,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
-model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
+state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||
+model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Add a weight function (simulating LoRA)
|
||||
@@ -215,8 +221,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
+state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||
|
||||
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
|
||||
-model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
+model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||
with self.assertRaises(KeyError):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user