Support nested tensor denoise masks. (#11431)
This commit is contained in:
@@ -984,9 +984,6 @@ class CFGGuider:
|
|||||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
if denoise_mask is not None:
|
|
||||||
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
|
||||||
|
|
||||||
noise = noise.to(device)
|
noise = noise.to(device)
|
||||||
latent_image = latent_image.to(device)
|
latent_image = latent_image.to(device)
|
||||||
sigmas = sigmas.to(device)
|
sigmas = sigmas.to(device)
|
||||||
@@ -1013,6 +1010,24 @@ class CFGGuider:
|
|||||||
else:
|
else:
|
||||||
latent_shapes = [latent_image.shape]
|
latent_shapes = [latent_image.shape]
|
||||||
|
|
||||||
|
if denoise_mask is not None:
|
||||||
|
if denoise_mask.is_nested:
|
||||||
|
denoise_masks = denoise_mask.unbind()
|
||||||
|
denoise_masks = denoise_masks[:len(latent_shapes)]
|
||||||
|
else:
|
||||||
|
denoise_masks = [denoise_mask]
|
||||||
|
|
||||||
|
for i in range(len(denoise_masks), len(latent_shapes)):
|
||||||
|
denoise_masks.append(torch.ones(latent_shapes[i]))
|
||||||
|
|
||||||
|
for i in range(len(denoise_masks)):
|
||||||
|
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
|
||||||
|
|
||||||
|
if len(denoise_masks) > 1:
|
||||||
|
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
|
||||||
|
else:
|
||||||
|
denoise_mask = denoise_masks[0]
|
||||||
|
|
||||||
self.conds = {}
|
self.conds = {}
|
||||||
for k in self.original_conds:
|
for k in self.original_conds:
|
||||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||||
|
|||||||
Reference in New Issue
Block a user