Initial ops changes to use comfy_kitchen: Initial nvfp4 checkpoint support. (#11635)

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
This commit is contained in:
comfyanonymous
2026-01-05 18:48:58 -08:00
committed by GitHub
parent 4f3f9e72a9
commit 6da00dd899
8 changed files with 223 additions and 795 deletions

View File

@@ -103,18 +103,18 @@ class TestMixedPrecisionOps(unittest.TestCase):
# Verify weights are wrapped in QuantizedTensor
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
# Layer 2 should NOT be quantized
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
# Layer 3 should be quantized
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
self.assertEqual(model.layer3.weight._layout_cls, "TensorCoreFP8E4M3Layout")
# Verify scales were loaded
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)
self.assertEqual(model.layer1.weight._params.scale.item(), 2.0)
self.assertEqual(model.layer3.weight._params.scale.item(), 1.5)
# Forward pass
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
@@ -154,8 +154,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
# Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout")
# Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)