Skip to content

Commit 12ddb4f

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5db6aa2 commit 12ddb4f

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

src/multigrate/module/_multivae_torch.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,7 @@ def generative(
455455
rs = [self._h_to_x(z, mod) for mod, z in enumerate(zs)]
456456
return {"rs": rs}
457457

458-
def loss(
459-
self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0
460-
) -> Tuple[
458+
def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0) -> Tuple[
461459
torch.FloatTensor,
462460
Dict[str, torch.FloatTensor],
463461
torch.FloatTensor,

0 commit comments

Comments
 (0)