Commit febff45

jbschlosser authored and facebook-github-bot committed
Support factory kwargs in torch.nn modules (pytorch#54508)
Summary: Continuation of pytorch#53144
Pull Request resolved: pytorch#54508
Reviewed By: albanD
Differential Revision: D27939544
Pulled By: jbschlosser
fbshipit-source-id: 4bf517e5f74f093e27ca38a85e732da65e44d805
1 parent 3a4344a commit febff45

24 files changed: +883 additions, -239 deletions

test/run_test.py

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@
     'test_linalg',
     'test_logging',
     'test_mkldnn',
+    'test_module_init',
     'test_multiprocessing',
     'test_multiprocessing_spawn',
     'distributed/test_nccl',

test/test_module_init.py

Lines changed: 432 additions & 0 deletions
Large diffs are not rendered by default.

torch/nn/__init__.py

Lines changed: 41 additions & 0 deletions
@@ -3,3 +3,44 @@
 from .parallel import DataParallel
 from . import init
 from . import utils
+
+
+def factory_kwargs(kwargs):
+    r"""
+    Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed
+    to factory functions like torch.empty, or errors if unrecognized kwargs are present.
+
+    This function makes it simple to write code like this::
+
+        class MyModule(nn.Module):
+            def __init__(self, **kwargs):
+                factory_kwargs = torch.nn.factory_kwargs(kwargs)
+                self.weight = Parameter(torch.empty(10, **factory_kwargs))
+
+    Why should you use this function instead of just passing `kwargs` along directly?
+
+    1. This function does error validation, so if there are unexpected kwargs we will
+       immediately report an error, instead of deferring it to the factory call
+    2. This function supports a special `factory_kwargs` argument, which can be used to
+       explicitly specify a kwarg to be used for factory functions, in the event one of the
+       factory kwargs conflicts with an already existing argument in the signature (e.g.
+       in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory
+       functions, as distinct from the dtype argument, by saying
+       ``f(dtype1, factory_kwargs={"dtype": dtype2})``)
+    """
+    if kwargs is None:
+        return {}
+    simple_keys = {"device", "dtype", "memory_format"}
+    expected_keys = simple_keys | {"factory_kwargs"}
+    if not kwargs.keys() <= expected_keys:
+        raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}")
+
+    # guarantee no input kwargs is untouched
+    r = dict(kwargs.get("factory_kwargs", {}))
+    for k in simple_keys:
+        if k in kwargs:
+            if k in r:
+                raise TypeError(f"{k} specified twice, in **kwargs and in factory_kwargs")
+            r[k] = kwargs[k]
+
+    return r
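
As a quick illustration of the helper above (a minimal sketch; the `MyLinear` module and its sizes are made up for this example and are not part of the commit):

    import torch
    import torch.nn as nn
    from torch.nn.parameter import Parameter

    class MyLinear(nn.Module):
        # Hypothetical module, used only to demonstrate torch.nn.factory_kwargs.
        def __init__(self, in_features, out_features, **kwargs):
            super().__init__()
            factory_kwargs = nn.factory_kwargs(kwargs)
            self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))

    # device/dtype are forwarded straight to torch.empty:
    m = MyLinear(4, 3, device='cpu', dtype=torch.float64)
    assert m.weight.dtype == torch.float64

    # Unknown kwargs are rejected up front, and specifying the same key both
    # directly and inside factory_kwargs raises TypeError:
    # MyLinear(4, 3, dtype=torch.float32, factory_kwargs={'dtype': torch.float64})  # TypeError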

torch/nn/modules/activation.py

Lines changed: 14 additions & 11 deletions
@@ -872,7 +872,8 @@ class MultiheadAttention(Module):
     bias_v: Optional[torch.Tensor]
 
     def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
-                 kdim=None, vdim=None, batch_first=False):
+                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super(MultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
@@ -886,25 +887,25 @@ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=Fals
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 
         if self._qkv_same_embed_dim is False:
-            self.q_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
-            self.k_proj_weight = Parameter(torch.empty(embed_dim, self.kdim))
-            self.v_proj_weight = Parameter(torch.empty(embed_dim, self.vdim))
+            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
+            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
+            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
             self.register_parameter('in_proj_weight', None)
         else:
-            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
+            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
             self.register_parameter('q_proj_weight', None)
             self.register_parameter('k_proj_weight', None)
             self.register_parameter('v_proj_weight', None)
 
         if bias:
-            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
+            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
         else:
             self.register_parameter('in_proj_bias', None)
-        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
 
         if add_bias_kv:
-            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
-            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
+            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
+            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
         else:
             self.bias_k = self.bias_v = None
 
@@ -1057,10 +1058,12 @@ class PReLU(Module):
     __constants__ = ['num_parameters']
     num_parameters: int
 
-    def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None:
+    def __init__(self, num_parameters: int = 1, init: float = 0.25,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
         self.num_parameters = num_parameters
         super(PReLU, self).__init__()
-        self.weight = Parameter(torch.empty(num_parameters).fill_(init))
+        self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs).fill_(init))
 
     def forward(self, input: Tensor) -> Tensor:
         return F.prelu(input, self.weight)
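
With this change, MultiheadAttention and PReLU accept device/dtype at construction; a brief sketch of the resulting usage (CPU and float64 chosen arbitrarily):

    import torch
    import torch.nn as nn

    # Parameters are allocated with the requested device/dtype up front, rather
    # than being created with defaults and moved afterwards.
    attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, device='cpu', dtype=torch.float64)
    prelu = nn.PReLU(num_parameters=8, device='cpu', dtype=torch.float64)

    assert attn.in_proj_weight.dtype == torch.float64
    assert prelu.weight.dtype == torch.float64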

torch/nn/modules/adaptive.py

Lines changed: 8 additions & 4 deletions
@@ -115,8 +115,11 @@ def __init__(
         n_classes: int,
         cutoffs: Sequence[int],
         div_value: float = 4.,
-        head_bias: bool = False
+        head_bias: bool = False,
+        device=None,
+        dtype=None
     ) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super(AdaptiveLogSoftmaxWithLoss, self).__init__()
 
         cutoffs = list(cutoffs)
@@ -141,7 +144,8 @@ def __init__(
         self.n_clusters = len(self.cutoffs) - 1
         self.head_size = self.shortlist_size + self.n_clusters
 
-        self.head = Linear(self.in_features, self.head_size, bias=self.head_bias)
+        self.head = Linear(self.in_features, self.head_size, bias=self.head_bias,
+                           **factory_kwargs)
         self.tail = ModuleList()
 
         for i in range(self.n_clusters):
@@ -150,8 +154,8 @@ def __init__(
             osz = self.cutoffs[i + 1] - self.cutoffs[i]
 
             projection = Sequential(
-                Linear(self.in_features, hsz, bias=False),
-                Linear(hsz, osz, bias=False)
+                Linear(self.in_features, hsz, bias=False, **factory_kwargs),
+                Linear(hsz, osz, bias=False, **factory_kwargs),
             )
 
             self.tail.append(projection)
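
AdaptiveLogSoftmaxWithLoss now threads the same factory kwargs into its head and tail Linear layers; a short sketch (sizes and cutoffs are arbitrary):

    import torch
    import torch.nn as nn

    asm = nn.AdaptiveLogSoftmaxWithLoss(
        in_features=16, n_classes=100, cutoffs=[10, 50],
        device='cpu', dtype=torch.float64)

    # Both the head projection and the clustered tail inherit device/dtype.
    assert asm.head.weight.dtype == torch.float64
    assert asm.tail[0][0].weight.dtype == torch.float64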

torch/nn/modules/batchnorm.py

Lines changed: 28 additions & 15 deletions
@@ -31,25 +31,28 @@ def __init__(
         momentum: float = 0.1,
         affine: bool = True,
         track_running_stats: bool = True,
+        device=None,
+        dtype=None
     ) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super(_NormBase, self).__init__()
         self.num_features = num_features
         self.eps = eps
         self.momentum = momentum
         self.affine = affine
         self.track_running_stats = track_running_stats
         if self.affine:
-            self.weight = Parameter(torch.empty(num_features))
-            self.bias = Parameter(torch.empty(num_features))
+            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
+            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
         else:
             self.register_parameter("weight", None)
             self.register_parameter("bias", None)
         if self.track_running_stats:
-            self.register_buffer("running_mean", torch.zeros(num_features))
-            self.register_buffer("running_var", torch.ones(num_features))
-            self.register_buffer(
-                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
-            )
+            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
+            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
+            self.register_buffer('num_batches_tracked',
+                                 torch.tensor(0, dtype=torch.long,
+                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
         else:
             self.register_buffer("running_mean", None)
             self.register_buffer("running_var", None)
@@ -117,9 +120,12 @@ def __init__(
         momentum=0.1,
         affine=True,
         track_running_stats=True,
+        device=None,
+        dtype=None
     ):
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super(_BatchNorm, self).__init__(
-            num_features, eps, momentum, affine, track_running_stats
+            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
         )
 
     def forward(self, input: Tensor) -> Tensor:
@@ -178,7 +184,9 @@ class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
     weight: UninitializedParameter  # type: ignore[assignment]
     bias: UninitializedParameter  # type: ignore[assignment]
 
-    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
+    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super(_LazyBatchNorm, self).__init__(
             # affine and track_running_stats are hardcoded to False to
             # avoid creating tensors that will soon be overwritten.
@@ -187,16 +195,18 @@ def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True
             momentum,
             False,
             False,
+            **factory_kwargs,
         )
         self.affine = affine
         self.track_running_stats = track_running_stats
         if self.affine:
-            self.weight = UninitializedParameter()
-            self.bias = UninitializedParameter()
+            self.weight = UninitializedParameter(**factory_kwargs)
+            self.bias = UninitializedParameter(**factory_kwargs)
         if self.track_running_stats:
-            self.running_mean = UninitializedBuffer()
-            self.running_var = UninitializedBuffer()
-            self.num_batches_tracked = torch.tensor(0, dtype=torch.long)
+            self.running_mean = UninitializedBuffer(**factory_kwargs)
+            self.running_var = UninitializedBuffer(**factory_kwargs)
+            self.num_batches_tracked = torch.tensor(
+                0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
 
     def reset_parameters(self) -> None:
         if not self.has_uninitialized_params() and self.num_features != 0:
@@ -640,9 +650,12 @@ def __init__(
         affine: bool = True,
         track_running_stats: bool = True,
         process_group: Optional[Any] = None,
+        device=None,
+        dtype=None
     ) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super(SyncBatchNorm, self).__init__(
-            num_features, eps, momentum, affine, track_running_stats
+            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
         )
         self.process_group = process_group