diff --git a/modules/stage_c.py b/modules/stage_c.py index 45f8bab..8a55313 100755 --- a/modules/stage_c.py +++ b/modules/stage_c.py @@ -155,7 +155,7 @@ def gen_r_embedding(self, r, max_positions=10000): def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): clip_txt = self.clip_txt_mapper(clip_txt) if len(clip_txt_pooled.shape) == 2: - clip_txt_pool = clip_txt_pooled.unsqueeze(1) + clip_txt_pooled = clip_txt_pooled.unsqueeze(1) if len(clip_img.shape) == 2: clip_img = clip_img.unsqueeze(1) clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)