diff --git a/modules/modelSetup/BaseChromaSetup.py b/modules/modelSetup/BaseChromaSetup.py index 183ea3a2f..e3ab10f81 100644 --- a/modules/modelSetup/BaseChromaSetup.py +++ b/modules/modelSetup/BaseChromaSetup.py @@ -228,7 +228,7 @@ def predict( image_seq_len = packed_latent_input.shape[1] image_attention_mask = torch.full((packed_latent_input.shape[0], image_seq_len), True, dtype=torch.bool, device=text_attention_mask.device) - attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) + attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) if not torch.all(text_attention_mask) else None packed_predicted_flow = model.transformer( hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()), diff --git a/modules/modelSetup/BaseQwenSetup.py b/modules/modelSetup/BaseQwenSetup.py index 93984fd95..6aeab5c20 100644 --- a/modules/modelSetup/BaseQwenSetup.py +++ b/modules/modelSetup/BaseQwenSetup.py @@ -147,7 +147,7 @@ def predict( #FIXME bug workaround for https://github.com/huggingface/diffusers/issues/12294 image_attention_mask=torch.ones((packed_latent_input.shape[0], packed_latent_input.shape[1]), dtype=torch.bool, device=latent_image.device) attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) - attention_mask_2d = attention_mask[:, None, None, :] * attention_mask[:, None, :, None] + attention_mask_2d = attention_mask[:, None, None, :] if not torch.all(text_attention_mask) else None packed_predicted_flow = model.transformer( hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()),