From 62bd2bac66418a64476d3b07a63e2f0e37559f91 Mon Sep 17 00:00:00 2001 From: egillax Date: Mon, 25 Nov 2024 12:48:54 +0100 Subject: [PATCH] make resnet and mlp responsible to ensure embeddings are correct dims --- inst/python/MultiLayerPerceptron.py | 2 ++ inst/python/ResNet.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/inst/python/MultiLayerPerceptron.py b/inst/python/MultiLayerPerceptron.py index 46bb86f..0eb0eae 100644 --- a/inst/python/MultiLayerPerceptron.py +++ b/inst/python/MultiLayerPerceptron.py @@ -53,6 +53,8 @@ def __init__( def forward(self, input): x_cat = input["cat"] x_cat = self.embedding(x_cat) + if x_cat.dim() == 3: + x_cat = x_cat.mean(dim=1) if "num" in input.keys() and self.num_embedding is not None: x_num = input["num"] x = (x_cat + self.num_embedding(x_num).mean(dim=1)) / 2 diff --git a/inst/python/ResNet.py b/inst/python/ResNet.py index 6a4d7e7..cf381cc 100644 --- a/inst/python/ResNet.py +++ b/inst/python/ResNet.py @@ -66,6 +66,8 @@ def __init__( def forward(self, x): x_cat = x["cat"] x_cat = self.embedding(x_cat) + if x_cat.dim() == 3: + x_cat = x_cat.mean(dim=1) if ( "num" in x.keys() and x["num"] is not None