Commit b74341f

worked on transformer, new embedding layer

1 parent d1c0293
4 files changed: +66 -22 lines

R/Dataset.R (+14 -6)
@@ -29,15 +29,21 @@ Dataset <- torch::dataset(
     dataCat <- data[, !numericalIndex]
     self$cat <- torch::torch_tensor(as.matrix(dataCat), dtype=torch::torch_long())
 
+    # for (i in dim(data)[[1]])
+    #
     # comment out the sparse matrix for now, is really slow need to find
     # a better solution for converting it to dense before feeding to model
     # matrix <- as(dataCat, 'dgTMatrix') # convert to triplet sparse format
     # sparseIndices <- torch::torch_tensor(matrix(c(matrix@i + 1, matrix@j + 1), ncol=2), dtype = torch::torch_long())
     # values <- torch::torch_tensor(matrix(c(matrix@x)), dtype = torch::torch_float32())
-    # self$cat <- torch::torch_sparse_coo_tensor(indices=sparseIndices$t(),
-    #                                            values=values$squeeze(),
-    #                                            dtype=torch::torch_float32())$coalesce()
+    # self$cat <- torch::torch_sparse_coo_tensor(indices=sparseIndices$t(),
+    #                                            values=values$squeeze(),
+    #                                            dtype=torch::torch_float32())$coalesce()
+    if (sum(numericalIndex) == 0) {
+      self$num <- NULL
+    } else {
       self$num <- torch::torch_tensor(as.matrix(data[,numericalIndex, drop = F]), dtype=torch::torch_float32())
+    }
   },
 
   .getNumericalIndex = function() {
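The commented-out block keeps the open problem visible: building the categorical data as a sparse COO tensor and densifying it before it reaches the model is currently too slow. If that path is revisited, the conversion could look roughly like this sketch; the toy matrix and the `size` argument are my assumptions, not part of the commit:

    library(Matrix)
    library(torch)

    # toy triplet-format sparse matrix standing in for dataCat
    m <- sparseMatrix(i = c(1, 3), j = c(2, 1), x = c(1, 1), dims = c(3, 2), repr = "T")

    indices <- torch_tensor(rbind(m@i + 1, m@j + 1), dtype = torch_long())  # (2, nnz), 1-based as in the commit
    values  <- torch_tensor(m@x, dtype = torch_float32())
    sp <- torch_sparse_coo_tensor(indices, values, size = c(3, 2))$coalesce()
    dense <- sp$to_dense()  # densify before feeding the model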
@@ -53,9 +59,11 @@ Dataset <- torch::dataset(
   },
 
   numNumFeatures = function() {
-    return (
-      self$num$shape[2]
-    )
+    if (!is.null(self$num)) {
+      return(self$num$shape[2])
+    } else {
+      return(0)
+    }
   },
 
   .getbatch = function(item) {
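Taken together, the dataset now tolerates all-categorical input: `self$num` becomes `NULL` when no column is numerical, and `numNumFeatures()` reports 0 instead of failing on `NULL$shape`. A minimal standalone sketch of the same guard, where the toy `data` and `numericalIndex` are assumptions:

    library(torch)

    data <- matrix(sample(0:1, 12, replace = TRUE), nrow = 4)
    numericalIndex <- c(FALSE, FALSE, FALSE)  # no numerical columns in this toy example

    num <- if (sum(numericalIndex) == 0) {
      NULL  # nothing to tensorize
    } else {
      torch_tensor(as.matrix(data[, numericalIndex, drop = FALSE]), dtype = torch_float32())
    }

    numNumFeatures <- if (!is.null(num)) num$shape[2] else 0  # 0 here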

R/Estimator.R (+4 -0)
@@ -475,7 +475,11 @@ Estimator <- R6::R6Class(
       coro::loop(for (b in batchIndex) {
         self$optimizer$zero_grad()
         cat <- dataset[b]$cat$to(device=self$device)
+        if (!is.null(dataset[b]$num)) {
         num <- dataset[b]$num$to(device=self$device)
+        } else {
+          num <- dataset[b]$num
+        }
         target <- dataset[b]$target$to(device=self$device)
         out <- self$model(num, cat)
         loss <- self$criterion(out, target)
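The training loop applies the same guard: `num` is moved to the device only when it exists; otherwise the `NULL` is passed through to the model unchanged. A self-contained sketch of the pattern, where the toy batch and the device fallback are my assumptions:

    library(torch)

    device <- if (cuda_is_available()) "cuda:0" else "cpu"
    batch <- list(
      cat = torch_tensor(matrix(sample(0:3, 40, replace = TRUE), nrow = 8), dtype = torch_long()),
      num = NULL  # hypothetical batch with no numerical features
    )

    cat_t <- batch$cat$to(device = device)
    num_t <- if (!is.null(batch$num)) batch$num$to(device = device) else NULL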

R/Transformer.R (+36 -9)
@@ -65,8 +65,9 @@ Transformer <- torch::nn_module(
                         headNorm=torch::nn_layer_norm,
                         attNorm=torch::nn_layer_norm,
                         dimHidden){
-    self$embedding <- Embedding(catFeatures, dimToken)
+    self$embedding <- Embedding(catFeatures + 1, dimToken) # + 1 for padding idx
     dimToken <- dimToken + numFeatures # because I concatenate numerical features to embedding
+    self$classToken <- ClassToken(dimToken)
 
     self$layers <- torch::nn_module_list(lapply(1:numBlocks,
                                                 function(x) {
@@ -93,24 +94,33 @@ Transformer <- torch::nn_module(
   },
   forward = function(x_num, x_cat){
     x_cat <- self$embedding(x_cat)
+    if (!is.null(x_num)) {
       x <- torch::torch_cat(list(x_cat, x_num), dim=2L)
+    } else {
+      x <- x_cat
+    }
+    x <- self$classToken(x)
     for (i in 1:length(self$layers)) {
       layer <- self$layers[[i]]
       xResidual <- self$startResidual(layer, 'attention', x)
 
       if (i==length(self$layers)) {
-        xResidual <- layer$attention(xResidual[,-1], xResidual) # in final layer take only attention on CLS token
-        x <- x[,-1]
+        dims <- xResidual$shape
+        # in final layer take only attention on CLS token
+        xResidual <- layer$attention(xResidual[,-1]$view(c(dims[1], 1, dims[3])),
+                                     xResidual, xResidual)
+        xResidual <- xResidual[[1]]
+        x <- x[,-1]$view(c(dims[1], 1, dims[3]))
       } else {
         xResidual <- layer$attention(xResidual, xResidual)
       }
       x <- self$endResidual(layer, 'attention', x, xResidual)
 
-      xResidual <- self$startResidual(layer, 'ffn', x, xResidual)
+      xResidual <- self$startResidual(layer, 'ffn', x)
       xResidual <- layer$ffn(xResidual)
       x <- self$endResidual(layer, 'ffn', x, xResidual)
     }
-    x <- self$head(x)
+    x <- self$head(x)[,1] # remove singleton dimension
     return(x)
   },
   startResidual = function(layer, stage, x) {
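One detail that is easy to misread in the final-layer branch: in the torch R package, negative indices count from the end (Python-style), so `x[,-1]` selects the last token along the sequence dimension, which is exactly where `ClassToken` appends the class token. The slice also drops that dimension, hence the `$view()` calls to restore a length-1 sequence axis before attention. A small illustration:

    library(torch)

    x <- torch_randn(2, 5, 3)    # (batch, tokens, dim); class token appended last
    cls <- x[, -1]               # the LAST token, not "drop the first" as in base R; shape (2, 3)
    cls <- cls$view(c(2, 1, 3))  # restore the singleton sequence dimension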
@@ -123,7 +133,7 @@ Transformer <- torch::nn_module(
   },
   endResidual = function(layer, stage, x, xResidual) {
     dropoutKey <- paste0(stage, 'ResDropout')
-    xResidual <-layer$dropoutKey(xResidual)
+    xResidual <-layer[[dropoutKey]](xResidual)
     x <- x + xResidual
     return(x)
   }
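This small change is a genuine bug fix: in R, `layer$dropoutKey` looks up a field literally named "dropoutKey", while `layer[[dropoutKey]]` indexes by the string stored in the variable. A torch-free illustration (the list is a hypothetical stand-in for the layer module):

    layer <- list(ffnResDropout = function(x) x * 0.9)  # hypothetical stand-in
    dropoutKey <- "ffnResDropout"

    layer$dropoutKey        # NULL: no field literally called "dropoutKey"
    layer[[dropoutKey]](1)  # 0.9: indexes with the value held by dropoutKey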
@@ -167,11 +177,28 @@ Head <- torch::nn_module(
 Embedding <- torch::nn_module(
   name='Embedding',
   initialize = function(numEmbeddings, embeddingDim) {
-    self$embedding <- torch::nn_embedding(numEmbeddings, embeddingDim)
-    categoryOffsets <- torch::torch_arange(1, numEmbeddings)
+    self$embedding <- torch::nn_embedding(numEmbeddings, embeddingDim, padding_idx = 1)
+    categoryOffsets <- torch::torch_arange(1, numEmbeddings, dtype=torch::torch_long())
     self$register_buffer('categoryOffsets', categoryOffsets, persistent=FALSE)
   },
   forward = function(x_cat) {
-    x <- self$embedding(x_cat * self$categoryOffsets)
+    x <- self$embedding(x_cat * self$categoryOffsets + 1L)
   }
 )
+
+# adds a class token embedding to embeddings
+ClassToken <- torch::nn_module(
+  name='ClassToken',
+  initialize = function(dimToken) {
+    self$weight <- torch::nn_parameter(torch::torch_empty(dimToken,1))
+    torch::nn_init_kaiming_uniform_(self$weight, a=sqrt(5))
+  },
+  expand = function(dims) {
+    newDims <- vector("integer", length(dims) - 1) + 1
+    return (self$weight$view(c(newDims,-1))$expand(c(dims, -1)))
+
+  },
+  forward = function(x) {
+    return(torch::torch_cat(c(x, self$expand(c(dim(x)[[1]], 1))), dim=2))
+  }
+)
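The offset arithmetic in `Embedding$forward` pairs with the new `padding_idx = 1`: for binary `x_cat`, `x_cat * categoryOffsets + 1L` maps an absent feature to index 1, whose embedding row is fixed at zero, and a present feature j to row j + 1. This is also why the Transformer now constructs `Embedding(catFeatures + 1, dimToken)`. A small sketch with toy sizes of my choosing:

    library(torch)

    catFeatures <- 4
    emb <- nn_embedding(catFeatures + 1, embedding_dim = 3, padding_idx = 1)  # row 1 is all zeros
    offsets <- torch_arange(1, catFeatures, dtype = torch_long())

    x_cat <- torch_tensor(matrix(c(1L, 0L, 1L, 0L), nrow = 1), dtype = torch_long())
    idx <- x_cat * offsets + 1L  # [1,0,1,0] -> [2,1,4,1]: absent features hit the padding row
    tokens <- emb(idx)           # zero vectors wherever a feature is absent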

extras/example.R (+12 -7)
@@ -11,18 +11,23 @@ plpData <- simulatePlpData(
   n = sampleSize
 )
 
-
 populationSet <- PatientLevelPrediction::createStudyPopulationSettings(
   requireTimeAtRisk = F,
   riskWindowStart = 1,
   riskWindowEnd = 365)
 
-
-modelSettings <- setResNet(numLayers = 2, sizeHidden = 64, hiddenFactor = 1,
-                           residualDropout = 0, hiddenDropout = 0.2, normalization = 'BatchNorm',
-                           activation = 'RelU', sizeEmbedding = 64, weightDecay = 1e-6,
-                           learningRate = 3e-4, seed = 42, hyperParamSearch = 'random',
-                           randomSample = 1, device = 'cuda:0',batchSize = 32,epochs = 1)
+#
+# modelSettings <- setResNet(numLayers = 2, sizeHidden = 64, hiddenFactor = 1,
+#                            residualDropout = 0, hiddenDropout = 0.2, normalization = 'BatchNorm',
+#                            activation = 'RelU', sizeEmbedding = 64, weightDecay = 1e-6,
+#                            learningRate = 3e-4, seed = 42, hyperParamSearch = 'random',
+#                            randomSample = 1, device = 'cuda:0',batchSize = 32,epochs = 1)
+
+modelSettings <- setTransformer(numBlocks=1, dimToken = 12, dimOut = 1, numHeads = 1,
+                                attDropout = 0.2, ffnDropout = 0.2, resDropout = 0,
+                                dimHidden = 8, batchSize = 4, hyperParamSearch = 'random',
+                                weightDecay = 1e-6, learningRate = 3e-4, epochs = 5,
+                                device = 'cuda:0', randomSamples = 1, seed = 42)
 
 res2 <- PatientLevelPrediction::runPlp(
   plpData = plpData,
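A practical caveat for anyone running the example: `device = 'cuda:0'` presumes a CUDA-capable GPU. A hedged fallback, assuming setTransformer accepts any torch device string:

    device <- if (torch::cuda_is_available()) 'cuda:0' else 'cpu'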
