-
Notifications
You must be signed in to change notification settings - Fork 6
/
arp.py
1477 lines (1277 loc) · 63 KB
/
arp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Single-file implementation of an autoregressive policy. It requires pytorch and timm (`pip install timm torch`).
Running this file trains a simple chunking causal transformer that generates binary MNIST images.
```
python arp.py
```
Generated images are saved in `mnist_generated_arp` folder.
"""
import torch, math, random
from collections.abc import Iterable
from collections import defaultdict
from typing import List, Tuple, TypedDict, Union, Dict, Optional, Callable, FrozenSet, Any, TypeVar, Literal
import itertools
from copy import deepcopy
from torch import Tensor, nn
import torch.nn.functional as F
from timm.models.vision_transformer import Mlp
import numpy as np
import torch.distributions as D
##
#region Chunk Transformer Layer
def modulate(x, shift, scale):
    """Apply an affine (AdaLN-style) modulation: ``x * (1 + scale) + shift``.

    x, shift, scale: broadcast-compatible tensors, typically (bs, L, d).
    """
    scaled = x * (1 + scale)
    return scaled + shift
class GELU(nn.Module):
    """GELU activation using the tanh approximation."""

    def forward(self, x):
        coeff = math.sqrt(2.0 / math.pi)
        inner = x + 0.044715 * torch.pow(x, 3.0)
        return 0.5 * x * (1.0 + torch.tanh(coeff * inner))
def clamp_dtype_min_max(v, dtype, inplace=True):
    """Clamp ``v`` to the finite range of floating dtype ``dtype``.

    With ``inplace=True`` (default) ``v`` itself is clamped and returned; otherwise a
    clamped copy is returned and ``v`` is left untouched.
    """
    info = torch.finfo(dtype)
    if inplace:
        return v.clamp_(info.min, info.max)
    return v.clamp(info.min, info.max)
class Attention(nn.Module):
    """Multi-head attention used by ChunkTransformerLayer.

    Self-attention (``cross=False``) projects q, k, v from ``x`` with one fused linear layer
    (``c_attn``). Cross-attention (``cross=True``) projects q from ``x`` (``q_attn``) and k, v
    from a context ``c`` (``kv_attn``). Head outputs are merged and passed through ``c_proj``
    followed by residual dropout.

    Args:
        n_embd: embedding width; must be divisible by ``n_head``.
        n_head: number of heads (head size is ``n_embd // n_head``).
        attn_pdrop: dropout probability applied to the attention weights.
        resid_pdrop: dropout probability applied to the projected output.
        cross: build a cross-attention module instead of a self-attention one.
        clamp_attn: clamp float16 attention logits to the finite float16 range.
    """
    def __init__(self, n_embd, n_head, attn_pdrop=0.1, resid_pdrop=0.1, cross=False, clamp_attn=False):
        super().__init__()
        self.clamp_attn = clamp_attn
        assert n_embd % n_head == 0
        self.cross = cross
        if cross:
            self.kv_attn = nn.Linear(n_embd, 2 * n_embd)
            self.q_attn = nn.Linear(n_embd, n_embd)
        else:
            self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        self.n_head = n_head
        self.n_embd = n_embd

    def attend(self, q, k):
        """Return scaled dot-product logits ``q @ k^T / sqrt(head_size)``.

        If ``clamp_attn`` is set and the logits are float16, they are clamped in place to
        the finite float16 range, so overflow does not turn into +/-inf.
        """
        attn = q @ (k.transpose(-2, -1) / math.sqrt(k.size(-1)))
        dtype = attn.dtype
        if self.clamp_attn and dtype == torch.float16:
            return clamp_dtype_min_max(attn, dtype)
        else:
            return attn

    def forward_interleave(self, xs, attn_masks, dependency_attn_mask=None):
        """Self-attention over two aligned token streams (training path).

        xs: [x_star, x_hat], each (B, T, C); both streams share the same projections.
        attn_masks: [m_star, m_hat], each (B | 1, T, T) bool, True = may attend.
            They are inverted here before ``masked_fill``.
        dependency_attn_mask: optional (B, T, T) bool, True = allowed; False entries are
            blocked for both streams.

        The star stream attends causally to itself. Each hat query attends to star keys
        allowed by m_star and to hat keys allowed by m_hat, with a single softmax over the
        merged logits.
        return: (y_star, y_hat), each (B, T, C)
        """
        assert not self.cross
        B, T, C = xs[0].size()
        dev = xs[0].device
        qkvs = []
        for x in xs:
            q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
            k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
            v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
            qkvs.append([q,k,v])
        (q_star, k_star, v_star), (q_hat, k_hat, v_hat) = qkvs
        # Invert the keep-masks so True marks blocked positions; mask_causal is True strictly above the diagonal.
        (mask_star, mask_hat), mask_causal = [~m for m in attn_masks], torch.tril(torch.ones(T, T, device=dev))[None, None, ...] == 0
        def mlp(y):
            # Merge heads and apply the output projection + dropout (not an MLP, despite the name).
            y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
            y = self.resid_dropout(self.c_proj(y))
            return y
        def apply_mask(att, mask, val=float('-inf')):
            # Fill `val` wherever `mask` is True; a 3-D (B, T, T) mask is broadcast across heads.
            if len(mask.shape) == 3: mask = mask[:, None, :, :]
            return att.masked_fill(mask, val)
        def merge_attn_logits(att_star, att_hat):
            # Per key position, use the hat-key logit wherever it is finite (unmasked), else keep the star-key logit.
            valid_pos = ~torch.isinf(att_hat)
            att_star = att_star.clone()
            att_star[valid_pos] = att_hat[valid_pos]
            return att_star
        softmax = nn.Softmax(dim=-1)
        att_star = self.attend(q_star, k_star)
        if dependency_attn_mask is not None:
            att_star = apply_mask(att_star, ~dependency_attn_mask)
        y_star = mlp(self.attn_dropout(softmax(apply_mask(att_star, mask_causal))) @ v_star)
        attn_hat = merge_attn_logits(
            apply_mask(self.attend(q_hat, k_star) , mask_star),
            apply_mask(self.attend(q_hat, k_hat) , mask_hat)
        )
        if dependency_attn_mask is not None: attn_hat = apply_mask(attn_hat, ~dependency_attn_mask)
        attn_hat = self.attn_dropout(softmax(attn_hat))
        # Route each weight to v_star where m_star allows it and to v_hat where m_hat allows it.
        # NOTE(review): this assumes m_star and m_hat are disjoint (as from train_attn_masks);
        # overlapping masks would apply the same weight to both value streams.
        y_hat = apply_mask(attn_hat, mask_star, 0) @ v_star + apply_mask(attn_hat, mask_hat, 0) @ v_hat
        y_hat = mlp(y_hat)
        return y_star, y_hat

    def forward(self, x, c=None, attn_mask=None):
        """
        x: (B, T, C) input sequence
        c: (B, Tc, C) context sequence; only read when ``cross=True``
        attn_mask: (B | 1, T, T) attention mask, or (B | 1, T, Tc) for cross-attention
            (True means keep, False means blocking)
        return: (B, T, C)
        """
        B, T, C = x.size()
        if self.cross:
            k, v = self.kv_attn(c).split(self.n_embd, dim=2)
            q = self.q_attn(x)
            Tc = c.size(1)
            k = k.view(B, Tc, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, Tc, hs)
            v = v.view(B, Tc, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, Tc, hs)
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        else:
            q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
            k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
            v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        att = self.attend(q, k)
        if attn_mask is not None:
            # add a head axis so the mask broadcasts over all heads
            attn_mask = attn_mask[:, None, :, :]
            att = att.masked_fill(~attn_mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        y = self.resid_dropout(self.c_proj(y))
        return y
class ChunkTransformerLayer(nn.Module):
    """Pre-norm transformer block (self-attention + MLP) with optional context conditioning.

    Args:
        hidden_size: model width.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        mlp_dropout: dropout inside the MLP.
        attn_kwargs / cond_attn_kwargs: extra kwargs for the self- / cross-attention modules.
        conditional: add a cross-attention (``cond_attn``) from the normalized input to a context ``c``.
        AdaLN: when conditioning is active, map the cross-attention output to shift/scale/gate
            modulations of both sub-layers instead of adding it residually. The LayerNorms then
            have no affine parameters.
        norm_before_AdaLN: apply a parameter-free LayerNorm to the cross-attention output before
            computing the modulations.
    """
    # NOTE: the dict defaults below are only unpacked with **, never mutated, so sharing them is harmless.
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, mlp_dropout=0.1,
                 attn_kwargs={}, cond_attn_kwargs={},
                 conditional=False, AdaLN=False, norm_before_AdaLN=False):
        super().__init__()
        self.ln_attn = nn.LayerNorm(hidden_size, elementwise_affine=not AdaLN, eps=1e-6)
        self.ln_mlp = nn.LayerNorm(hidden_size, elementwise_affine=not AdaLN, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu,
                       drop=mlp_dropout)
        self.attn = Attention(hidden_size, num_heads, **attn_kwargs)
        self.conditional = conditional
        self.AdaLN = AdaLN
        if conditional:
            self.norm_cond = nn.LayerNorm(hidden_size, eps=1e-6)
            self.cond_attn = Attention(hidden_size, num_heads, **cond_attn_kwargs, cross=True)
        if AdaLN:
            # produces (shift, scale, gate) for attention and for the MLP: 6 * hidden_size values
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 6 * hidden_size, bias=True)
            )
        self.norm_before_AdaLN = norm_before_AdaLN
        if norm_before_AdaLN:
            self.ln_ada = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

    # The masks built below are "keep" masks: Attention inverts them before masked_fill,
    # so False means blocked.
    @staticmethod
    def train_attn_masks(chunk_ids):
        """
        chunk_ids: (B | 1, L) chunk ids, starting from 0, and ordered
        return: (m_star, m_hat), each (B | 1, L, L) bool.
            m_star[i, j]: token j is in a strictly earlier chunk than token i.
            m_hat[i, j]: tokens i and j are in the same chunk.
        """
        m_star = chunk_ids[:, :, None] > chunk_ids[:, None, :]
        m_hat = chunk_ids[:, :, None] == chunk_ids[:, None, :]
        return m_star, m_hat

    @staticmethod
    def eval_attn_mask(chunk_ids):
        """Inference keep-mask of shape (B, L, L).

        Query rows whose chunk id equals the last token's chunk id may attend to every
        position; all other rows get a causal (lower-triangular) mask.
        """
        L = chunk_ids.size(1)
        prompt = chunk_ids[:, -1:] # (B, 1)
        m = (chunk_ids[:, :, None] == prompt[:, None, :]).repeat(1, 1, L)
        m = m | (torch.tril(torch.ones(L, L, device=chunk_ids.device))[None, ...] == 1)
        return m

    @staticmethod
    def dependency_attn_mask(tk_types, block_attn_directions):
        """
        tk_types: (B, L)
        block_attn_directions: list of (int, int), where each tuple is (curr, other) to block attention
        return: (B, L, L) bool keep-mask; entry [b, i, j] is False when token i has type `curr`
            and token j has type `other` for some listed pair.
        """
        bs, L = tk_types.shape
        mask = torch.full([bs, L, L], fill_value=True, device=tk_types.device, dtype=torch.bool)
        for from_, to in block_attn_directions:
            mask_ = (tk_types[:, :, None] == from_) & (tk_types[:, None, :] == to)
            mask = mask & (~mask_)
        return mask

    def forward_train(self, xs, c, masks, dependency_attn_mask=None):
        """Training forward over the two interleaved streams (see Attention.forward_interleave).

        xs: [(B, T, C), (B, T, C)]
        c: context for the cross-attention, or None for no conditioning
        masks: [(B | 1, T, T), (B | 1, T, T)] keep-masks, e.g. from train_attn_masks
        return: list of two (B, T, C) tensors
        """
        is_conditional = self.conditional and c is not None
        cond_attns = [self.cond_attn(self.norm_cond(x), c) if is_conditional else 0 for x in xs]
        if self.AdaLN and is_conditional:
            if self.norm_before_AdaLN: cond_attns = [self.ln_ada(cond_attn) for cond_attn in cond_attns]
            gates = [self.adaLN_modulation(cond_attn).chunk(6, dim=-1) for cond_attn in cond_attns]
            ys = self.attn.forward_interleave([modulate(self.ln_attn(x), shift_msa, scale_msa)
                                               for x, (shift_msa, scale_msa, _, _, _, _) in zip(xs, gates)], masks, dependency_attn_mask=dependency_attn_mask)
            xs = [x + gate_msa * y for x, y, (_, _, gate_msa, _, _, _) in zip(xs, ys, gates)]
            xs = [x + gate_mlp * self.mlp(modulate(self.ln_mlp(x), shift_mlp, scale_mlp))
                  for x, (_, _, _, shift_mlp, scale_mlp, gate_mlp) in zip(xs, gates)]
        else:
            # without AdaLN (or without a context) the cross-attention output is a plain residual (0 if absent)
            xs = [x + cond_attn for x, cond_attn in zip(xs, cond_attns)]
            ys = self.attn.forward_interleave([self.ln_attn(x) for x in xs], masks, dependency_attn_mask=dependency_attn_mask)
            xs = [x + y for x, y in zip(xs, ys)]
            xs = [x + self.mlp(self.ln_mlp(x)) for x in xs]
        return xs

    def forward_inference(self, x, c, mask=None):
        """
        x: (B, T, C) input sequence
        c: (B, L, C) context sequence, or None for no conditioning
        mask: optional keep-mask passed to the self-attention, e.g. from eval_attn_mask
        return: (B, T, C)
        """
        is_conditional = self.conditional and c is not None
        cond_attn = self.cond_attn(self.norm_cond(x), c) if is_conditional else 0
        if self.AdaLN and is_conditional:
            if self.norm_before_AdaLN: cond_attn = self.ln_ada(cond_attn)
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(cond_attn).chunk(6, dim=-1)
            x = x + gate_msa * self.attn(modulate(self.ln_attn(x), shift_msa, scale_msa), attn_mask=mask)
            x = x + gate_mlp * self.mlp(modulate(self.ln_mlp(x), shift_mlp, scale_mlp))
        else:
            x = x + cond_attn
            x = x + self.attn(self.ln_attn(x), attn_mask=mask)
            x = x + self.mlp(self.ln_mlp(x))
        return x
#endregion #######################
#region Interface & Utilities
class TokenType(TypedDict):
    """Description of one token type: how it is embedded, predicted and (de)quantized."""
    id: int
    name: str
    is_control: bool
    dim: int  # number of dimensions
    # NOTE: different tokens may need to be embedded or predicted differently
    embedding: str  # "discrete", "linear"
    predictor: str
    embedding_kwargs: Dict
    predictor_kwargs: Dict
    # NOTE: use for encoding
    is_continuous: bool
    bounds: List[float]  # [min_of_dim1, max_of_dim1, min_of_dim2, max...]
    dict_sizes: List[int]  # [dict_size_of_dim1, dict_size_of_dim2, ...]

    @staticmethod
    def make(**kwargs):
        """Build a token spec from defaults, overridden by ``kwargs``.

        The default name carries a random suffix in [0, 1000]; ``id`` is not set here.
        """
        spec = dict(
            is_control=False,
            name=f'token-{random.randint(0, 1000)}',
            dim=1,
            is_continuous=False,
            dict_sizes=[1],
            embedding_kwargs={},
            predictor_kwargs={},
            embedding='discrete',
            predictor='class',
        )
        spec.update(kwargs)
        return spec
class LayerType(TypedDict):
    """Configuration of one ChunkTransformerLayer."""
    name: str
    n_head: int
    attn_kwargs: dict
    cond_attn_kwargs: dict
    mlp_dropout: float
    mlp_ratio: float
    AdaLN: Union[bool, str]
    norm_before_AdaLN: bool
    condition_on: str = ""

    @staticmethod
    def make(**kwargs):
        """Return a layer config populated with defaults, then overridden by ``kwargs``.

        Fresh kwarg dicts are created on every call, so results never share state.
        """
        config = dict(
            name="",
            n_head=4,
            attn_kwargs=dict(attn_pdrop=0.1, resid_pdrop=0.1),
            cond_attn_kwargs=dict(attn_pdrop=0.1, resid_pdrop=0.1),
            mlp_dropout=0.1,
            mlp_ratio=4.0,
            norm_before_AdaLN=False,
            AdaLN=False,
            condition_on="",
        )
        config.update(kwargs)
        return config
class ModelConfig:
    """Model hyper-parameters plus the layer and token specifications.

    Any extra keyword arguments become attributes. Each token dict in ``tokens`` is
    mutated in place: its ``'id'`` is set to its index in the list.
    """
    def __init__(self, n_embd: int = 64, embd_pdrop: float = 0.1, max_seq_len: int = 1024, layer_norm_every_block: bool = True,
                 max_chunk_size: int = 1, layers: Optional[List['LayerType']] = None, tokens: Optional[List['TokenType']] = None, **kwargs):
        # None defaults: a shared `[]` default would be stored on every default-constructed
        # config, so appending to one config's list would leak into all later ones.
        if layers is None:
            layers = []
        if tokens is None:
            tokens = []
        self.n_embd: int = n_embd
        self.embd_pdrop: float = embd_pdrop
        self.max_chunk_size: int = max_chunk_size
        self.layer_norm_every_block: bool = layer_norm_every_block
        self.max_seq_len: int = max_seq_len
        self.layers: List['LayerType'] = layers
        self.tokens: List['TokenType'] = tokens
        for k, v in kwargs.items():
            setattr(self, k, v)
        for i, token in enumerate(self.tokens):
            token['id'] = i
# A token whose chunk id, token-type id and value are known; declared with the
# functional TypedDict syntax.
IncompleteToken = TypedDict('IncompleteToken', {
    'chk_id': int,
    'tk_id': int,
    'tk_val': int,
})
def _make_registry():
def chunk(name):
def func(cls):
chunk.map[name] = cls
cls.name = name
return cls
return func
chunk.map = {}
return chunk
# Name -> class registries filled by the decorators below. Embeddings are looked up by
# TokenType['embedding'] (see TokenCoder.need_encoding); predictors presumably by
# TokenType['predictor'] — confirm against the model code.
register_token_embedding = _make_registry()
register_token_predictor = _make_registry()
T = TypeVar('T')
# A value given once for all chunks, or per chunk: keys are a chunk id, a frozenset of chunk
# ids sharing one value, or 'default' (expanded by flatten_per_chunk_dict).
PerChunk = Union[Dict[Union[int, FrozenSet[int], Literal['default']], T], T]
# Sampling callback over a list of predictions (logit tensors or distributions) returning a
# tensor — NOTE(review): exact contract is defined by callers outside this view.
SampleFunctionT = Callable[[List[Union[Tensor, D.Distribution]]], Tensor]
# (curr, other) token-type pairs (by name or by id) whose attention is blocked; ids are what
# ChunkTransformerLayer.dependency_attn_mask consumes — names presumably get mapped to ids elsewhere.
AttnDirectionsType = List[Union[Tuple[str, str], Tuple[int, int]]]
#endregion Interface ###############
#region Utility ###############
def flatten_per_chunk_dict(dct):
    """Expand a per-chunk dict whose keys may be collections of chunk ids.

    Keys that are collections (e.g. ``frozenset({1, 2})``) are expanded so every member maps
    to the shared value. Scalar keys, including strings such as ``'default'``, are kept
    as-is (a str is Iterable, but must not be split into characters). Later entries win on
    duplicate keys.
    """
    flat = {}
    for keys, value in dct.items():
        if isinstance(keys, Iterable) and not isinstance(keys, (str, bytes)):
            for key in keys:
                flat[key] = value
        else:
            flat[keys] = value
    return flat
def pad_last_dim(tensor, target_size, val=0):
    """Right-pad the last dimension of ``tensor`` with ``val`` up to ``target_size``.

    Works for any rank (the original indexing hard-coded 3-D input). If the last dimension
    is already >= ``target_size``, ``tensor`` itself is returned unchanged.
    """
    current = tensor.size(-1)
    if current >= target_size:
        return tensor
    out = torch.full([*tensor.shape[:-1], target_size], fill_value=val, device=tensor.device, dtype=tensor.dtype)
    out[..., :current] = tensor
    return out
def cat_uneven_blc_tensors(*tensors):
    """Concatenate (B, L_i, C_i) tensors along dim 1 after zero-padding every C_i to the largest C."""
    widest = max(t.size(-1) for t in tensors)
    padded = [pad_last_dim(t, widest) for t in tensors]
    return torch.cat(padded, dim=1)
def map2(func, nested_list):
    """Apply ``func`` to every element of a list of lists, returning a new list of lists."""
    return [[func(item) for item in sublist] for sublist in nested_list]
#endregion ####################
#region TokenCoder
class TokenCoder(nn.Module):
    """Quantize / de-quantize token values per token type.

    A token type is quantized only when it is continuous and its registered embedding class
    sets ``NEED_ENCODED_INPUT`` (see ``need_encoding``). For such a type, each of its first
    ``token['dim']`` channels is clamped to ``[bounds[2j], bounds[2j+1]]`` and mapped to an
    integer bin in ``[0, dict_sizes[j] - 1]``; channels beyond ``dim`` are set to 0. Decoding
    maps bins back onto the same uniform grid. Other token types pass through unchanged.
    Requires ``dict_sizes[j] >= 2`` for quantized channels (the bin width divides by it minus 1).
    """
    def __init__(self, tokens: List['TokenType']):
        super().__init__()
        self.tokens: List['TokenType'] = tokens

    def encode(self, tks, tk_ids):
        """
        tks: [*, dim] token values, e.g., [B, T, dim]
        tk_ids: [*] token-type index of each token (index into ``self.tokens``)
        return: [*, dim] float codes; the input tensor is not modified
        """
        tks = tks.float().clone()
        tks_shape = tks.shape
        for i, _ in enumerate(self.tokens):
            mask = tk_ids == i
            tks[mask] = self.encode_ith(tks[mask], i, inplace=True).float()
        return tks.reshape(*tks_shape)

    def decode(self, tk_codes, tk_ids):
        """
        tk_codes: [*, dim] codes (any numeric dtype)
        tk_ids: [*] token-type index of each token
        return: [*, dim] float values; the input tensor is not modified
        """
        # .float() returns the same tensor for float input, so without .clone() the masked
        # writes below would overwrite the caller's tensor in place.
        tk_codes = tk_codes.float().clone()
        tks_shape = tk_codes.shape
        for i, _ in enumerate(self.tokens):
            mask = tk_ids == i
            tk_codes[mask] = self.decode_ith(tk_codes[mask], i, inplace=True)
        return tk_codes.reshape(*tks_shape)

    def need_encoding(self, token: 'TokenType', is_continuous: Optional[bool] = None):
        """True if ``token`` is continuous and its embedding class expects quantized input.

        ``is_continuous`` overrides ``token['is_continuous']`` when given. The embedding
        registry is only consulted for continuous tokens.
        """
        if is_continuous is None:
            is_continuous = token['is_continuous']
        return is_continuous and register_token_embedding.map[token['embedding']].NEED_ENCODED_INPUT

    def encode_ith(self, tks: Tensor, i: int, inplace=False):
        """
        tks: [*, d], where d is at least the dimension of the i-th token type
        return: [*, d] bin indices (as floats) for quantized types, otherwise ``tks`` unchanged
        """
        token = self.tokens[i]
        if self.need_encoding(token):
            out = torch.zeros_like(tks)
            tks = tks[..., :token['dim']]
            if not inplace: tks = tks.clone()
            for j in range(token['dim']):
                start, end = token['bounds'][2*j], token['bounds'][2*j+1]
                tks[..., j].clamp_(start, end)
                tks[..., j] -= start
                resolution = (end - start) / (token['dict_sizes'][j] - 1)
                tks[..., j] /= resolution
                tks[..., j].round_()
            out[..., :token['dim']] = tks
            return out
        else:
            return tks

    def decode_ith(self, tk_codes: Tensor, i: int, inplace=False):
        """Inverse of ``encode_ith``: map bin indices back to values ``start + bin * resolution``."""
        token = self.tokens[i]
        if self.need_encoding(token):
            out = torch.zeros_like(tk_codes)
            tk_codes = tk_codes[..., :token['dim']]
            if not inplace: tk_codes = tk_codes.clone()
            for j in range(token['dim']):
                start, end = token['bounds'][2*j], token['bounds'][2*j+1]
                resolution = (end - start) / (token['dict_sizes'][j] - 1)
                tk_codes[..., j] = tk_codes[..., j].float() * resolution + start
            out[..., :token['dim']] = tk_codes
            return out
        else:
            return tk_codes
#endregion TokenCoder ###############
#region Embedding
class TokenEmbeddingInterface(nn.Module):
    """Base class for per-token-type embedding modules.

    Subclasses map token codes of one token type to ``n_embd``-wide embeddings.
    ``NEED_ENCODED_INPUT`` says whether they expect quantized codes (see TokenCoder).
    """
    NEED_ENCODED_INPUT = False

    def __init__(self, n_embd: int, token: TokenType, **kwargs):
        super().__init__()
        # extra kwargs are accepted (and ignored) so subclasses can share one constructor call
        self.token = token
        self.n_embd = n_embd
@register_token_embedding('zero')
class ZeroEmbedding(TokenEmbeddingInterface):
    """Embedding that contributes nothing: zeros of shape (*tk_codes.shape[:-1], n_embd)."""

    def forward(self, tk_codes, **extra_contexts):
        out_shape = (*tk_codes.shape[:-1], self.n_embd)
        return torch.zeros(out_shape, device=tk_codes.device)
@register_token_embedding('discrete')
class DiscreteEmbedding(TokenEmbeddingInterface):
    """Lookup-table embedding for quantized (integer-coded) tokens.

    Without ``embed_from``, each of the ``token['dim']`` channels has its own ``nn.Embedding``
    of size ``dict_sizes[j]``, and the per-channel embeddings are summed. With ``embed_from``,
    the table is read at call time from ``extra_contexts[embed_from]``. It is either a shared
    2-D (V, n_embd) matrix or a 3-D (B, V, n_embd) per-sample table, and only single-channel
    tokens are allowed. The 3-D path returns a flattened (-1, n_embd) result.
    """
    NEED_ENCODED_INPUT = True

    def __init__(self, n_embd: int, token: TokenType, embed_from: Optional[str]=None, **kwargs):
        super().__init__(n_embd, token)
        self.embed_from = embed_from
        if not embed_from:
            tables = [nn.Embedding(token['dict_sizes'][j], n_embd) for j in range(token['dim'])]
            self.embed = nn.ModuleList(tables)
        else:
            assert token['dim'] == 1

    def forward(self, tk_codes, **extra_contexts):
        if not self.embed_from:
            return self._embed_with_own_tables(tk_codes)
        return self._embed_with_context_table(tk_codes, extra_contexts[self.embed_from])

    def _embed_with_own_tables(self, tk_codes):
        # Sum of per-channel embeddings (0 when the token has no channels).
        total = 0
        for j in range(self.token['dim']):
            total = total + self.embed[j](tk_codes[..., j].long())
        return total

    def _embed_with_context_table(self, tk_codes, weight):
        # Embed channel 0 using a table supplied through extra_contexts.
        if weight.dim() == 2:
            return F.embedding(tk_codes[..., 0].long(), weight)
        assert weight.dim() == 3
        bs = len(weight)
        index = tk_codes[..., 0].long().reshape(bs, -1, 1).repeat(1, 1, weight.size(-1))
        return torch.gather(weight, 1, index).reshape(-1, self.n_embd)
@register_token_embedding('position_1d')
class Position1DEmbedding(TokenEmbeddingInterface):
    """Fixed sinusoidal embedding of a scalar coordinate.

    Output layout is [sin(scale * p * f_i) for all i, cos(scale * p * f_i) for all i] with
    f_i = N ** (-2i / n_embd). tk_codes is indexed as 2-D (num_tokens, >=1).
    """
    NEED_ENCODED_INPUT = False

    def __init__(self, n_embd: int, token: TokenType, scale=1.0, N=10000, **kwargs):
        super().__init__(n_embd, token)
        assert token['dim'] == 1
        self.scale = scale
        exponents = torch.arange(0, n_embd, 2) * (-math.log(N) / n_embd)
        self.register_buffer("div_term", torch.exp(exponents)[None, :])

    def forward(self, tk_codes, **extra_contexts):
        phase = self.scale * tk_codes[:, :self.token['dim']] * self.div_term
        return torch.cat((torch.sin(phase), torch.cos(phase)), dim=1).view(-1, self.n_embd)
@register_token_embedding('position_2d')
class Position2DEmbedding(TokenEmbeddingInterface):
    """Fixed 2-D sinusoidal position embedding.

    The first half of the ``n_embd`` channels encodes coordinate 0 and the second half
    coordinate 1. Within each half, even channels hold sin and odd channels hold cos of
    ``scale * coord * div_term``. tk_codes is indexed as 2-D (num_tokens, >=2).
    """
    NEED_ENCODED_INPUT = False

    def __init__(self, n_embd: int, token: TokenType, scale=1.0, N=10000, **kwargs):
        super().__init__(n_embd, token)
        assert token['dim'] == 2
        assert n_embd % 4 == 0
        self.scale = scale
        half = n_embd // 2
        self.register_buffer("div_term", torch.exp(torch.arange(0, half, 2) * (-math.log(N) / half))[None, :])

    def forward(self, tk_codes, **extra_contexts):
        tk_codes = tk_codes[:, :self.token['dim']]
        pe = torch.zeros(tk_codes.size(0), self.n_embd, device=tk_codes.device)
        d_model = self.n_embd // 2
        phase_0 = self.scale * tk_codes[:, :1] * self.div_term
        phase_1 = self.scale * tk_codes[:, 1:] * self.div_term
        pe[:, 0:d_model:2] = torch.sin(phase_0)
        pe[:, 1:d_model:2] = torch.cos(phase_0)
        pe[:, d_model::2] = torch.sin(phase_1)
        # Odd channels of the second half. Previously this wrote to d_model::2, which
        # overwrote the sin terms and left these channels all zero.
        pe[:, d_model + 1::2] = torch.cos(phase_1)
        return pe
@register_token_embedding('linear')
class LinearEmbedding(TokenEmbeddingInterface):
    """Project the first ``token['dim']`` channels of each code with a learned linear layer."""
    NEED_ENCODED_INPUT = False

    def __init__(self, n_embd: int, token: TokenType, **kwargs):
        super().__init__(n_embd, token)
        in_features = token['dim']
        self.embed = nn.Linear(in_features, n_embd)

    def forward(self, tk_codes, **extra_contexts):
        used = tk_codes[..., :self.token['dim']]
        return self.embed(used)
@register_token_embedding('feat_grid_2d')
class FeatureGrid2DEmbedding(TokenEmbeddingInterface):
    """Embed 2-D coordinate tokens by sampling a feature grid at those coordinates.

    The grid comes from ``extra_contexts[sampling_from]`` with shape (B, C, H, W). Its
    channel count is expected to equal ``n_embd``, since the result is reshaped to ``n_embd``.
    Coordinates are divided by ``stride`` before sampling. ``token_format`` gives the
    coordinate order: ``'xy'`` = (column, row), ``'hw'`` = (row, column).
    """
    NEED_ENCODED_INPUT = False

    def __init__(self, n_embd: int, token: TokenType, sampling_from: str, stride: Union[int, Tuple[int, int]] = 1,
                 token_format='xy', **kwargs):
        super().__init__(n_embd, token)
        self.sampling_from = sampling_from
        self.token_format = token_format
        assert self.token['dim'] == 2
        if isinstance(stride, int):
            stride = (stride, stride)
        self.stride = stride

    def forward(self, tk_codes, **extra_contexts):
        """
        tk_codes: (*, >=2); only the first two channels (the coordinates) are used.
            The total number of tokens must be divisible by the grid batch size B; tokens
            are grouped per batch element in order.
        return: (*, n_embd)
        """
        assert self.sampling_from in extra_contexts, f"extra context {self.sampling_from} not found"
        feat_grid = extra_contexts[self.sampling_from]
        grid_shape = feat_grid.shape # (B, C, H, W)
        tk_codes = tk_codes[..., :self.token['dim']].clone().float()
        tk_codes[..., 0] /= self.stride[0]
        tk_codes[..., 1] /= self.stride[1]
        tk_codes_shape = tk_codes.shape
        tk_codes = tk_codes.reshape(grid_shape[0], -1, 2)
        embs = self.grid_sample(tk_codes, feat_grid)
        return embs.reshape(*tk_codes_shape[:-1], self.n_embd)

    @staticmethod
    def batched_index_select(inp, dim, index):
        """
        inp: B x * x ... x *
        dim: scalar > 0 (not batch dim)
        index: B x M
        return: ``inp`` gathered along ``dim`` at ``index`` (all other dims kept)
        """
        views = [inp.shape[0]] + [1 if i != dim else -1 for i in range(1, len(inp.shape))]
        expanse = list(inp.shape)
        expanse[0] = -1
        expanse[dim] = -1
        index = index.view(views).expand(expanse)
        return torch.gather(inp, dim, index)

    def grid_sample(
        self, points: torch.Tensor, feat_grid: torch.Tensor
    ) -> torch.Tensor:
        """
        :param points: [B, P, 2] continuous grid coordinates, ordered per ``self.token_format``
        :param feat_grid: [B, C, H, W]
        :return: [B, P, C] inverse-distance-weighted average of the four neighbouring grid
            cells; points outside the grid get all-zero features.
        :raises ValueError: if ``token_format`` is not 'xy' or 'hw'
        """
        nc, nw, h, w = feat_grid.shape
        # Extent along each coordinate column. In 'hw' format column 0 is a row index, so it
        # must be bounded by h (not w); mixing these up broke non-square grids.
        if self.token_format == 'xy':
            size0, size1 = w, h
        elif self.token_format == 'hw':
            size0, size1 = h, w
        else:
            raise ValueError("token_format should be 'xy' or 'hw'")
        npt = points.shape[1]
        points_weight = torch.ones([nc, npt]).to(feat_grid.device)
        # giving points outside the image zero weight
        points_weight[points[:, :, 0] < 0] = 0
        points_weight[points[:, :, 1] < 0] = 0
        points_weight[points[:, :, 0] > (size0 - 1)] = 0
        points_weight[points[:, :, 1] > (size1 - 1)] = 0
        points = points.unsqueeze(2).repeat([1, 1, 4, 1])
        # later used for calculating weight
        points_copy = points.detach().clone()
        # the four neighbouring integer locations: (floor, floor), (floor, ceil), (ceil, floor), (ceil, ceil)
        points[:, :, 0, 0] = torch.floor(points[:, :, 0, 0])
        points[:, :, 0, 1] = torch.floor(points[:, :, 0, 1])
        points[:, :, 1, 0] = torch.floor(points[:, :, 1, 0])
        points[:, :, 1, 1] = torch.ceil(points[:, :, 1, 1])
        points[:, :, 2, 0] = torch.ceil(points[:, :, 2, 0])
        points[:, :, 2, 1] = torch.floor(points[:, :, 2, 1])
        points[:, :, 3, 0] = torch.ceil(points[:, :, 3, 0])
        points[:, :, 3, 1] = torch.ceil(points[:, :, 3, 1])
        grid_points = points.long() # [nc, npt, 4, 2] (grid)
        # Taking the modulo maps neighbours that fall exactly on the far edge back to 0. That
        # makes their distance from the continuous location large, so their weight becomes
        # negligible and no explicit removal step is needed.
        grid_points[:, :, :, 0] = torch.fmod(grid_points[:, :, :, 0], int(size0))
        grid_points[:, :, :, 1] = torch.fmod(grid_points[:, :, :, 1], int(size1))
        grid_points[grid_points < 0] = 0
        # inverse-distance weight of each neighbour, normalised per point: [nc, npt, 4]
        points_dist = 1 / (torch.sqrt(torch.sum((points_copy - grid_points) ** 2, dim=-1)) + 1e-10)
        points_weight = points_weight.unsqueeze(-1) * points_dist
        _points_weight = torch.sum(points_weight, dim=-1, keepdim=True)
        # out-of-grid points have all-zero weights; avoid 0/0
        _points_weight[_points_weight == 0.0] = 1
        points_weight = points_weight / _points_weight # [nc, npt, 4]
        grid_points = grid_points.view(nc, 4 * npt, 2) # [nc, 4 * npt, 2]
        # transform (row, col) indices to flat indices into the h * w grid for gather
        if self.token_format == 'xy':
            grid_points = (grid_points[:, :, 1] * w) + grid_points[:, :, 0] # [nc, 4 * npt]
        else:
            grid_points = (grid_points[:, :, 0] * w) + grid_points[:, :, 1] # [nc, 4 * npt]
        feat_grid = feat_grid.permute(0, 2, 3, 1).view(nc, h * w, nw) # [nc, h * w, nw]
        points_val = self.batched_index_select(feat_grid, dim=1, index=grid_points) # [nc, 4 * npt, nw]
        points_val = points_val.view(nc, -1, 4, nw)
        # weighted sum over the four neighbours
        points_val = torch.sum(points_val * points_weight.unsqueeze(-1), dim=2) # [nc, npt, nw]
        return points_val
class ChunkEmbedding(nn.Module):
    """Sum of a within-chunk position embedding and a token-type embedding.

    The chunk embedding is shared across all chunks and acts like a relative position
    embedding inside each chunk. For example, chunk ids [0, 0, 0, 1, 2, 2] become
    positions [0, 1, 2, 0, 0, 1] before the lookup. ``max_chunk_size`` must therefore
    exceed the largest chunk length.
    """
    def __init__(self, n_embd: int, max_chunk_size: int, tokens: List['TokenType']):
        super().__init__()
        self.chunk_embed = nn.Embedding(max_chunk_size, n_embd)
        self.token_type_embed = nn.Embedding(len(tokens), n_embd)

    def forward(self, chk_ids, tk_ids):
        """
        chk_ids: (B | 1, L) chunk ids, grouped contiguously per chunk; a single row is
            broadcast to every batch element
        tk_ids: (B, L) token-type ids
        return: (B, L, n_embd)
        """
        chk_ids = chk_ids.long()
        tk_id_emb = self.token_type_embed(tk_ids.long())
        reg_indices = chk_ids.clone()
        for i in range(chk_ids.size(0)):
            reg_indices[i] = self.chk_ids_to_indices(chk_ids[i])
        if reg_indices.size(0) == 1:
            reg_indices = reg_indices.repeat(tk_ids.size(0), 1)
        reg_emb = self.chunk_embed(reg_indices)
        return reg_emb + tk_id_emb

    @staticmethod
    def chk_ids_to_indices(ids):
        """Convert chunk ids to each token's index within its chunk.

        ids: (L) chunk ids, sorted so each chunk's tokens are contiguous (gaps between ids allowed)
        return: (L) long
        input: [0,0,0, 1, 2,2], output: [0,1,2, 0, 0,1]
        input: [1,1,1, 2,2, 3,3,3, 4, 5,5,5,5,5], output: [0,1,2, 0,1, 0,1,2, 0, 0,1,2,3,4]
        input: [0,0, 3,3,3, 7], output: [0,1, 0,1,2, 0]
        """
        dev = ids.device
        # return_inverse maps every token to its chunk's rank directly, so ids need not be
        # consecutive (the old bucketize over arange(min, min + n_unique) went out of range on gaps)
        _, inverse, counts = torch.unique(ids, sorted=True, return_inverse=True, return_counts=True)
        starts = torch.cat([torch.zeros(1, dtype=torch.long, device=dev), counts[:-1].cumsum(0)])
        return torch.arange(0, len(ids), device=dev) - starts[inverse]
#endregion Embedding ###############
#region Predictor
class TokenPredictorInterface(nn.Module):
    """Base class for per-token-type prediction heads.

    ``forward`` returns a loss dict in training mode, or predictions (logits or
    distributions) otherwise. ``sample`` turns those predictions into token values.
    """
    IS_CONTINUOUS = False

    def __init__(self, n_embd: int, token: TokenType, **kwargs):
        super().__init__()
        self.token = token
        self.n_embd = n_embd

    def sample(self, predicts_of_curr_regs: List[Union[Tensor, D.Distribution]], do_sample:bool, **extra_contexts) -> Tensor:
        raise NotImplementedError

    def forward(self, embs, log_prob=False, **extra_contexts) -> Union[Dict[str, List[Tensor]], D.Distribution]:
        # default head predicts nothing
        return None
@register_token_predictor('gmm')
class GMMPredictor(TokenPredictorInterface):
    """Gaussian (``num_latents == 1``) or Gaussian-mixture head for continuous tokens.

    In training mode ``forward`` returns a loss dict: 'nll_loss', plus 'log_prob' when
    requested, computed against ``extra_contexts[label_name]``. Otherwise it returns a
    ``torch.distributions`` object, or a list of per-position distributions when
    ``split_distributions`` is set. With ``low_var_eval``, evaluation collapses the mixture
    components to near-zero variance.
    """
    IS_CONTINUOUS = True

    def __init__(self, n_embd: int, token: TokenType, num_latents=1, low_var_eval=True, label_name='label', **kwargs):
        super().__init__(n_embd, token, **kwargs)
        self.num_latents = num_latents
        self.low_var_eval = low_var_eval
        self.label_name = label_name
        if num_latents == 1:
            out_features = 2 * token['dim']
            self.gauss_nll_loss = nn.GaussianNLLLoss()
        else:
            out_features = num_latents * 2 * token['dim'] + num_latents # means, scales, logits
        self.mlp = Mlp(in_features=n_embd, hidden_features=n_embd, out_features=out_features)

    def forward(self, embs, log_prob=False, split_distributions=False, **extra_contexts) -> Union[Dict[str, List[Tensor]], D.Distribution]:
        """ embs: (*, d); ``split_distributions`` indexes dim 1 as the sequence axis
        """
        base_shape = embs.shape[:-1]
        loss_dict = defaultdict(list)
        if self.num_latents > 1:
            out = self.mlp(embs)
            logits = out[..., :self.num_latents]
            means, raw_scales = out[..., self.num_latents:].chunk(2, dim=-1)
            scales = torch.exp(0.5 * raw_scales)
            if not self.training and self.low_var_eval:
                # Near-deterministic evaluation: scale proportional to |mean|. Built out-of-place
                # (exp's backward needs its output). Exact zeros are lifted to the smallest
                # positive normal, because D.Normal rejects a scale of 0.
                scales = (1e-5 * means).abs()
                scales = scales.masked_fill(scales == 0, torch.finfo(scales.dtype).tiny)
            means = means.reshape(*base_shape, self.num_latents, self.token['dim'])
            scales = scales.reshape(*base_shape, self.num_latents, self.token['dim'])
            component_distribution = D.Normal(loc=means, scale=scales)
            component_distribution = D.Independent(component_distribution, 1)
            mixture_distribution = D.Categorical(logits=logits)
            dist = D.MixtureSameFamily(
                mixture_distribution=mixture_distribution,
                component_distribution=component_distribution,
            )
            if self.training:
                label = extra_contexts[self.label_name]
                loss_dict['nll_loss'].append(- dist.log_prob(label).mean())
                if log_prob:
                    loss_dict['log_prob'].append(dist.log_prob(label))
                return loss_dict
            else:
                if split_distributions:
                    dists = []
                    for i in range(base_shape[-1]): # seq len
                        component_distribution = D.Normal(loc=means[:, i], scale=scales[:, i])
                        component_distribution = D.Independent(component_distribution, 1)
                        mixture_distribution = D.Categorical(logits=logits[:, i])
                        dists.append(D.MixtureSameFamily(
                            mixture_distribution=mixture_distribution,
                            component_distribution=component_distribution,
                        ))
                    return dists
                else:
                    return dist
        else:
            means, raw_scales = self.mlp(embs).chunk(2, dim=-1)
            scales = torch.exp(0.5 * raw_scales)
            dist = D.Normal(means, scales)
            if self.training:
                # honour the configured label key (was hard-coded to 'label' in this branch)
                label = extra_contexts[self.label_name]
                loss_dict['nll_loss'].append(self.gauss_nll_loss(means, label, scales ** 2))
                if log_prob:
                    loss_dict['log_prob'].append(dist.log_prob(label))
                return loss_dict
            else:
                if split_distributions:
                    return [D.Normal(means[:, i], scales[:, i]) for i in range(means.shape[1])]
                else:
                    return dist

    def sample(self, predicts_of_curr_regs: Union[Tensor, D.Distribution], do_sample:bool, **extra_contexts) -> Tensor:
        """Draw one value per distribution, or take its mean.

        The mean is used when ``do_sample`` is False, or for the single-Gaussian head
        with ``low_var_eval``.
        """
        result = []
        for d in predicts_of_curr_regs:
            assert isinstance(d, D.Distribution)
            if do_sample:
                if self.low_var_eval and self.num_latents == 1:
                    samples = d.mean
                else:
                    samples = d.sample()
                result.append(samples)
            else:
                result.append(d.mean)
        return result
@register_token_predictor('class')
class ClassPredictor(TokenPredictorInterface):
    """Categorical head: one linear layer whose outputs are split into per-channel logits.

    Channel j gets ``dict_sizes[j]`` logits. In training mode ``forward`` returns a loss dict
    with per-channel cross-entropy under 'ce_loss'. With ``log_prob``, the unreduced
    cross-entropy (a negative log-likelihood) is stored under 'log_prob'. Otherwise it returns
    logits of shape (*, dim, max(dict_sizes)), right-padded with -inf.
    NOTE(review): CrossEntropyLoss treats dim 1 as the class axis, so training presumably
    expects 2-D (N, d) embeddings — confirm with callers.
    """
    IS_CONTINUOUS = False

    def __init__(self, n_embd: int, token: TokenType, label_name='label', **kwargs):
        super().__init__(n_embd, token, **kwargs)
        self.linear = nn.Linear(n_embd, sum(token['dict_sizes']))
        self.ce_loss = nn.CrossEntropyLoss()
        self.ce_loss_wo_mean = nn.CrossEntropyLoss(reduction='none')
        self.label_name = label_name

    def forward(self, embs, log_prob=False, **extra_contexts) -> Union[Dict[str, List[Tensor]], D.Distribution]:
        """ embs: (*, d)
        """
        logits = self.linear(embs) # (*, sum(dict_sizes))
        loss_dict = defaultdict(list)
        sizes = self.token['dict_sizes']
        widest = max(sizes)
        per_channel = torch.split(logits, sizes, dim=-1)
        if not self.training:
            padded = [pad_last_dim(chunk, widest, val=float('-inf')).unsqueeze(-2) for chunk in per_channel]
            return torch.cat(padded, dim=-2)
        for j, logits_j in enumerate(per_channel):
            target = extra_contexts[self.label_name][..., j].long()
            loss_dict['ce_loss'].append(self.ce_loss(logits_j, target))
            if log_prob:
                loss_dict['log_prob'].append(self.ce_loss_wo_mean(logits_j, target))
        return loss_dict

    def sample(self, predicts_of_curr_regs: List[Union[Tensor, D.Distribution]], do_sample:bool, **extra_contexts) -> Tensor:
        """Turn each logits tensor into class indices: a multinomial draw, or the argmax via topk."""
        def pick(logits):
            assert isinstance(logits, Tensor)
            probs = F.softmax(logits, dim=-1)
            if do_sample:
                drawn = torch.multinomial(probs.flatten(0, -2), num_samples=1)
                return drawn.reshape(*probs.shape[:-1])
            return torch.topk(probs, k=1, dim=-1)[1].squeeze(-1)

        return [pick(d) for d in predicts_of_curr_regs]
class ConvUpsample(nn.Module):
    """Learned convex upsampling, adapted from RVT2 and RAFT.

    Every high-resolution pixel is a softmax-weighted mix of the
    ``up_kernel x up_kernel`` neighbourhood of a low-resolution feature map;
    the mixing weights are predicted by a small convolutional head.
    """

    def __init__(self, in_dim, out_dim, up_ratio, up_kernel=3, mask_scale=0.1, hidden_dim_mult=2):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.up_ratio = up_ratio
        self.up_kernel = up_kernel
        self.mask_scale = mask_scale
        # an odd kernel lets padding=up_kernel // 2 keep h, w unchanged in F.unfold
        assert (self.up_kernel % 2) == 1
        hidden = int(hidden_dim_mult * in_dim)

        # low-resolution features to be redistributed
        self.net_out = nn.Sequential(
            nn.Conv2d(in_dim, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_dim, 3, padding=1),
        )

        # one logit per (kernel tap, sub-pixel row, sub-pixel col)
        taps = self.up_kernel ** 2
        self.net_mask = nn.Sequential(
            nn.Conv2d(in_dim, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, taps * self.up_ratio ** 2, 1, padding=0),
        )

    def forward(self, x):
        """Upsample ``x`` by ``up_ratio`` along both spatial axes.

        x: (bs, in_dim, h, w)
        return: (bs, out_dim, h*up_ratio, w*up_ratio)
        """
        bs, c, h, w = x.shape
        assert c == self.in_dim, c
        r = self.up_ratio
        taps = self.up_kernel ** 2

        low_res = self.net_out(x)
        # convex weights over the taps, separately for each of the r*r sub-pixels
        weights = (self.mask_scale * self.net_mask(x)).view(bs, 1, taps, r, r, h, w).softmax(dim=2)

        # neighbourhood of every low-res pixel: (bs, out_dim * taps, h * w)
        patches = F.unfold(
            low_res,
            kernel_size=[self.up_kernel, self.up_kernel],
            padding=self.up_kernel // 2,
        )
        patches = patches.view(bs, self.out_dim, taps, 1, 1, h, w)

        fused = (patches * weights).sum(dim=2)  # (bs, out_dim, r, r, h, w)
        # interleave sub-pixels with pixels: (bs, out_dim, h, r, w, r)
        return fused.permute(0, 1, 4, 2, 5, 3).reshape(bs, self.out_dim, h * r, w * r)
@register_token_predictor('upsample_from_2d_attn')
class Upsample2DAttnPredictor(TokenPredictorInterface):
IS_CONTINUOUS = False
def __init__(self, n_embd: int, token: TokenType, attn_with: Union[str, Tuple[int, ...]], upscale_ratio: int, token_format='xy', label_name='label',
             corr_dim=-1, hidden_dim_mult=2, **kwargs):
    """Set up the correlation + upsampling head.

    Args:
        n_embd: embedding width; the correlated feature map must have this
            many channels.
        token: token spec, forwarded to ``TokenPredictorInterface.__init__``.
        attn_with: either the ``extra_contexts`` key of the feature map to
            correlate with, or a 3-tuple shape for a learned map registered
            as parameter ``attn_with`` with shape ``(1, *attn_with)``.
        upscale_ratio: spatial upsampling factor of the ``ConvUpsample`` head.
        token_format: ``'xy'`` or ``'hw'`` coordinate order of training
            labels (only validated inside ``forward``).
        label_name: ``extra_contexts`` key of the training target.
        corr_dim: negative -> the upsampler sees the raw ``n_embd``-channel
            correlation; otherwise it is first projected to ``corr_dim``
            channels by ``corr_proj``.
        hidden_dim_mult: hidden-width multiplier for ``ConvUpsample``.
        **kwargs: forwarded to ``TokenPredictorInterface.__init__``.
    """
    super().__init__(n_embd, token, **kwargs)
    if corr_dim < 0:
        corr_dim = n_embd
    else:
        self.corr_proj = self._build_corr_proj(n_embd, corr_dim)
    self.upsample = ConvUpsample(corr_dim, 1, upscale_ratio, hidden_dim_mult=hidden_dim_mult)
    self.corr_dim = corr_dim
    if not isinstance(attn_with, str):
        assert len(attn_with) == 3
        learned_map = nn.Parameter(0.02 * torch.randn(1, *attn_with))
        self.register_parameter('attn_with', learned_map)
    else:
        self.attn_with = attn_with
    self.token_format = token_format
    self.label_name = label_name
    self._cross_entropy_loss = nn.CrossEntropyLoss()
    self._cross_entropy_loss_wo_mean = nn.CrossEntropyLoss(reduction='none')

@staticmethod
def _build_corr_proj(in_dim, out_dim):
    """Pointwise conv -> depthwise 5x5 conv -> pointwise conv, each followed by BatchNorm2d + ELU."""
    layers = []
    for conv in (
        nn.Conv2d(in_dim, out_dim, 1),
        nn.Conv2d(out_dim, out_dim, 5, padding=2, groups=out_dim),
        nn.Conv2d(out_dim, out_dim, 1),
    ):
        layers += [conv, nn.BatchNorm2d(out_dim), nn.ELU()]
    return nn.Sequential(*layers)
def forward(self, embs, log_prob=False, **extra_contexts):
    """Correlate each embedding with a 2-D feature map and upsample it into a logit map.

    Args:
        embs: (*, d) embeddings; reshaped to (-1, n_embd, 1, 1), so all
            leading dims are merged into one batch axis B.
        log_prob: training only - also return per-sample losses under
            ``'log_prob'``.
        **extra_contexts: holds the feature map when ``self.attn_with`` is a
            key, and the training target under ``self.label_name``.

    Returns:
        Eval: the logit map produced by ``self.upsample``, (B, 1, H', W').
        Training: ``{'2d_ce_loss': [mean CE over the flattened map]}``, plus
        ``'log_prob'`` when requested. The target is either a dense map with
        as many elements as the logit map, or per-sample 2-D coordinates
        ordered according to ``self.token_format``.

    Raises:
        ValueError: coordinate targets with a ``token_format`` other than
            'xy' or 'hw'.
    """
    if isinstance(self.attn_with, str):
        feats = extra_contexts[self.attn_with]  # (B, self.n_embd, H, W)
    else:
        feats = self.attn_with  # learned (1, self.n_embd, H, W), broadcast over B
    corr = embs.reshape(-1, self.n_embd, 1, 1) * feats
    if hasattr(self, 'corr_proj'):  # only built when corr_dim >= 0
        corr = self.corr_proj(corr)
    spatial_logits_map = self.upsample(corr)  # (B, 1, H', W')
    if not self.training:
        return spatial_logits_map

    label = extra_contexts[self.label_name]
    if label.numel() == spatial_logits_map.numel():
        # dense target with one value per logit; CrossEntropyLoss treats a
        # float target of matching shape as per-location class probabilities
        label = label.flatten(1)
    else:
        _, _, h, w = spatial_logits_map.shape
        if self.token_format == 'xy':
            label = (label[..., 1] * w) + label[..., 0]
        elif self.token_format == 'hw':
            label = (label[..., 0] * w) + label[..., 1]
        else:
            raise ValueError("token_format should be 'xy' or 'hw'")
        # class-index targets must be int64 for CrossEntropyLoss (as in
        # ClassPredictor); int32/float coordinates would otherwise raise.
        # No-op for labels that are already int64.
        label = label.long()
    flat_logits = spatial_logits_map.flatten(1)  # (B, H' * W')
    result = {'2d_ce_loss': [self._cross_entropy_loss(flat_logits, label)]}
    if log_prob:
        # NOTE(review): per-sample cross-entropy, i.e. -log p, whereas the
        # continuous heads earlier in this file store dist.log_prob (+log p)
        # under the same key - confirm the sign consumers expect.
        result['log_prob'] = [self._cross_entropy_loss_wo_mean(flat_logits, label)]
    return result
def sample(self, predicts_of_curr_regs: List[Union[Tensor, D.Distribution]], do_sample:bool, **extra_contexts) -> Tensor:
result = []
for d in predicts_of_curr_regs:
assert isinstance(d, Tensor)
_, h, w = d.shape
d = d.flatten(1)
probs = F.softmax(d, dim=-1)
if do_sample:
samples = torch.multinomial(probs, num_samples=1)
else:
_, samples = torch.topk(probs, k=1, dim=-1)
pred_h, pred_w = samples.flatten() // w, samples.flatten() % w