@@ -65,8 +65,9 @@ Transformer <- torch::nn_module(
headNorm = torch::nn_layer_norm,
attNorm = torch::nn_layer_norm,
dimHidden){
- self$embedding <- Embedding(catFeatures, dimToken)
+ self$embedding <- Embedding(catFeatures + 1, dimToken) # + 1 for padding idx
dimToken <- dimToken + numFeatures # because I concatenate numerical features to embedding
+ self$classToken <- ClassToken(dimToken)

self$layers <- torch::nn_module_list(lapply(1:numBlocks,
function(x) {
@@ -93,24 +94,33 @@ Transformer <- torch::nn_module(
},
forward = function(x_num, x_cat){
x_cat <- self$embedding(x_cat)
+ if (!is.null(x_num)) {
x <- torch::torch_cat(list(x_cat, x_num), dim = 2L)
+ } else {
+ x <- x_cat
+ }
+ x <- self$classToken(x)
for (i in 1:length(self$layers)) {
layer <- self$layers[[i]]
xResidual <- self$startResidual(layer, 'attention', x)

if (i == length(self$layers)) {
- xResidual <- layer$attention(xResidual[,-1], xResidual) # in final layer take only attention on CLS token
- x <- x[,-1]
+ dims <- xResidual$shape
+ # in final layer take only attention on CLS token
+ xResidual <- layer$attention(xResidual[,-1]$view(c(dims[1], 1, dims[3])),
+ xResidual, xResidual)
+ xResidual <- xResidual[[1]]
+ x <- x[,-1]$view(c(dims[1], 1, dims[3]))
} else {
xResidual <- layer$attention(xResidual, xResidual)
}
x <- self$endResidual(layer, 'attention', x, xResidual)

- xResidual <- self$startResidual(layer, 'ffn', x, xResidual)
+ xResidual <- self$startResidual(layer, 'ffn', x)
xResidual <- layer$ffn(xResidual)
x <- self$endResidual(layer, 'ffn', x, xResidual)
}
- x <- self$head(x)
+ x <- self$head(x)[, 1] # remove singleton dimension
return(x)
},
startResidual = function(layer, stage, x) {
@@ -123,7 +133,7 @@ Transformer <- torch::nn_module(
},
endResidual = function(layer, stage, x, xResidual) {
dropoutKey <- paste0(stage, 'ResDropout')
- xResidual <- layer$dropoutKey(xResidual)
+ xResidual <- layer[[dropoutKey]](xResidual)
x <- x + xResidual
return(x)
}
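The switch to layer[[dropoutKey]] matters because R's $ looks up a member literally named "dropoutKey", while [[ evaluates the variable and indexes with its value ('attentionResDropout' or 'ffnResDropout'). A minimal sketch of the difference, using a plain list with hypothetical stand-in functions rather than the actual module:

layer <- list(
  attentionResDropout = function(x) x * 0.9,  # stand-in, not torch::nn_dropout
  ffnResDropout = function(x) x * 0.8
)
dropoutKey <- paste0("attention", "ResDropout")
layer[[dropoutKey]](1)  # works: calls the element named "attentionResDropout"
# layer$dropoutKey(1)   # fails: no element literally named "dropoutKey"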
@@ -167,11 +177,28 @@ Head <- torch::nn_module(
Embedding <- torch::nn_module(
name = 'Embedding',
initialize = function(numEmbeddings, embeddingDim) {
- self$embedding <- torch::nn_embedding(numEmbeddings, embeddingDim)
- categoryOffsets <- torch::torch_arange(1, numEmbeddings)
+ self$embedding <- torch::nn_embedding(numEmbeddings, embeddingDim, padding_idx = 1)
+ categoryOffsets <- torch::torch_arange(1, numEmbeddings, dtype = torch::torch_long())
self$register_buffer('categoryOffsets', categoryOffsets, persistent = FALSE)
},
forward = function(x_cat) {
x <- self$embedding(x_cat * self$categoryOffsets + 1L)
}
)
+
+ # adds a class token embedding to embeddings
+ ClassToken <- torch::nn_module(
+ name = 'ClassToken',
+ initialize = function(dimToken) {
+ self$weight <- torch::nn_parameter(torch::torch_empty(dimToken, 1))
+ torch::nn_init_kaiming_uniform_(self$weight, a = sqrt(5))
+ },
+ expand = function(dims) {
+ newDims <- vector("integer", length(dims) - 1) + 1
+ return(self$weight$view(c(newDims, -1))$expand(c(dims, -1)))
+
+ },
+ forward = function(x) {
+ return(torch::torch_cat(c(x, self$expand(c(dim(x)[[1]], 1))), dim = 2))
+ }
+ )
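For orientation, a minimal usage sketch of the new ClassToken module (assumes the torch R package is installed; the shapes below are illustrative, not from the patch). It appends one learned token along the token dimension, which is why the final layer above selects it with x[,-1] and then restores the singleton dimension with $view():

library(torch)
dimToken <- 8
classToken <- ClassToken(dimToken)
x <- torch_randn(4, 10, dimToken)  # (batch, tokens, dimToken)
out <- classToken(x)
out$shape  # 4 11 8: one learned CLS token appended at the end
out[,-1]   # (batch, dimToken): the CLS slice, reshaped back to (batch, 1, dimToken) in forward()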