Resolution bucketing and Trainer implementation refactoring (#11117)
This commit is contained in:
@@ -1125,6 +1125,99 @@ class MergeTextListsNode(TextProcessingNode):
|
||||
# ========== Training Dataset Nodes ==========
|
||||
|
||||
|
||||
class ResolutionBucket(io.ComfyNode):
    """Bucket latents and conditions by resolution for efficient batch training.

    Each incoming latent batch is split into individual samples, the samples
    are grouped by their trailing ``(H, W)`` dimensions, and every group is
    stacked back into one batch. The condition belonging to each sample is
    carried along so the two outputs stay index-aligned.
    """

    @classmethod
    def define_schema(cls):
        """Return the node schema: list inputs in, one list entry per bucket out."""
        return io.Schema(
            node_id="ResolutionBucket",
            display_name="Resolution Bucket",
            category="dataset",
            is_experimental=True,
            # The node receives the complete input lists in a single call.
            is_input_list=True,
            inputs=[
                io.Latent.Input(
                    "latents",
                    tooltip="List of latent dicts to bucket by resolution.",
                ),
                io.Conditioning.Input(
                    "conditioning",
                    tooltip="List of conditioning lists (must match latents length).",
                ),
            ],
            outputs=[
                io.Latent.Output(
                    display_name="latents",
                    is_output_list=True,
                    tooltip="List of batched latent dicts, one per resolution bucket.",
                ),
                io.Conditioning.Output(
                    display_name="conditioning",
                    is_output_list=True,
                    tooltip="List of condition lists, one per resolution bucket.",
                ),
            ],
        )

    @classmethod
    def execute(cls, latents, conditioning):
        """Group individual samples by (H, W) and re-batch them per bucket.

        Args:
            latents: list of ``{"samples": tensor}`` dicts; the first tensor
                dimension is treated as the batch (typically size 1).
            conditioning: list of condition lists, one per latent dict; entry
                ``i`` of a condition list is paired with sample ``i`` of the
                matching latent batch.

        Returns:
            ``io.NodeOutput`` holding two parallel lists: one
            ``{"samples": tensor}`` dict per bucket (samples stacked along a
            new leading dimension) and one list of conditions per bucket.
            Buckets appear in the order their resolution was first seen.

        Raises:
            ValueError: if ``latents`` and ``conditioning`` differ in length,
                or a condition list has fewer entries than its latent batch.
        """
        if len(latents) != len(conditioning):
            raise ValueError(
                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
            )

        buckets = {}  # (H, W) -> {"latents": list, "conditions": list}
        total_samples = 0

        for index, (latent_dict, cond) in enumerate(zip(latents, conditioning)):
            samples = latent_dict["samples"]  # (B, ..., H, W)
            batch_size = samples.shape[0]

            # Fail with a clear message instead of an IndexError deep in the
            # loop when the condition list cannot cover every sample.
            if len(cond) < batch_size:
                raise ValueError(
                    f"Latent entry {index} has batch size {batch_size} "
                    f"but only {len(cond)} conditions."
                )

            for i in range(batch_size):
                sample = samples[i]  # batch dim removed: (..., H, W)
                key = (sample.shape[-2], sample.shape[-1])
                bucket = buckets.setdefault(key, {"latents": [], "conditions": []})
                bucket["latents"].append(sample)
                bucket["conditions"].append(cond[i])
                total_samples += 1

        output_latents = []  # list[{"samples": tensor}], tensor is (Bi, ..., H, W)
        output_conditions = []  # list[list[cond]], each inner list has Bi entries

        for (h, w), bucket_data in buckets.items():
            # Stack the per-sample tensors into one batch for this bucket.
            stacked_latents = torch.stack(bucket_data["latents"], dim=0)
            output_latents.append({"samples": stacked_latents})

            # Conditions stay a plain list, aligned with the stacked batch.
            output_conditions.append(bucket_data["conditions"])

            logging.info(
                f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
            )

        logging.info(f"Created {len(buckets)} resolution buckets from {total_samples} samples")
        return io.NodeOutput(output_latents, output_conditions)
|
||||
|
||||
|
||||
class MakeTrainingDataset(io.ComfyNode):
|
||||
"""Encode images with VAE and texts with CLIP to create a training dataset."""
|
||||
|
||||
@@ -1373,7 +1466,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
shard_path = os.path.join(dataset_dir, shard_file)
|
||||
|
||||
with open(shard_path, "rb") as f:
|
||||
shard_data = torch.load(f, weights_only=True)
|
||||
shard_data = torch.load(f)
|
||||
|
||||
all_latents.extend(shard_data["latents"])
|
||||
all_conditioning.extend(shard_data["conditioning"])
|
||||
@@ -1425,6 +1518,7 @@ class DatasetExtension(ComfyExtension):
|
||||
MakeTrainingDataset,
|
||||
SaveTrainingDataset,
|
||||
LoadTrainingDataset,
|
||||
ResolutionBucket,
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user