patch_embedding.py
import torch
from torch import nn


class PatchEmbedding(nn.Module):
    """Splits an image into patches, projects each patch to embed_dim, and
    prepends a learnable CLS token plus learnable position embeddings."""

    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()
        # Patchify via a strided convolution:
        # (B, C, H, W) -> (B, embed_dim, H/p, W/p), then flatten to (B, embed_dim, num_patches)
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            nn.Flatten(2),
        )
        # Single learnable classification token, shape (1, 1, embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim), requires_grad=True)
        # Learnable position embeddings for the CLS token plus every patch
        self.position_embeddings = nn.Parameter(
            torch.randn(1, num_patches + 1, embed_dim), requires_grad=True
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # Expand the CLS token across the batch: (1, 1, embed_dim) -> (B, 1, embed_dim)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        # (B, embed_dim, num_patches) -> (B, num_patches, embed_dim)
        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_token, x], dim=1)  # prepend the CLS token
        x = self.position_embeddings + x      # add position embeddings to all tokens
        x = self.dropout(x)
        return x
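

# A minimal usage sketch with illustrative, assumed hyperparameters:
# a 224x224 RGB input with 16x16 patches gives (224 // 16) ** 2 = 196 patches,
# so the output carries 196 patch tokens plus one CLS token.
if __name__ == "__main__":
    patch_embed = PatchEmbedding(
        embed_dim=768,
        patch_size=16,
        num_patches=196,
        dropout=0.1,
        in_channels=3,
    )
    images = torch.randn(2, 3, 224, 224)  # dummy batch of 2 RGB images
    tokens = patch_embed(images)
    print(tokens.shape)  # expected: torch.Size([2, 197, 768])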