@@ -41,7 +41,6 @@ def get_weight(m):
4141 constructor_args = (10 , 8 ),
4242 input_size = (4 , 10 ),
4343 reference_fn = lambda i , p : torch .mm (i , p [0 ].t ()) + p [1 ].view (1 , - 1 ).expand (4 , 8 ),
44- test_cuda = (not TEST_WITH_ROCM )
4544 ),
4645 dict (
4746 module_name = 'Linear' ,
@@ -103,35 +102,30 @@ def get_weight(m):
103102 constructor_args = (1 ,),
104103 input_size = (10 , 20 ),
105104 reference_fn = lambda i , _ : torch .exp (i ).div (torch .exp (i ).sum (1 , True ).expand (10 , 20 )),
106- test_cuda = (not TEST_WITH_ROCM )
107105 ),
108106 dict (
109107 module_name = 'Softmax2d' ,
110108 input_size = (1 , 3 , 10 , 20 ),
111109 reference_fn = lambda i , _ : torch .exp (i ).div (torch .exp (i ).sum (1 , False )),
112- test_cuda = (not TEST_WITH_ROCM )
113110 ),
114111 dict (
115112 module_name = 'LogSoftmax' ,
116113 constructor_args = (1 ,),
117114 input_size = (10 , 20 ),
118115 reference_fn = lambda i , _ : torch .exp (i ).div_ (torch .exp (i ).sum (1 , True ).expand (10 , 20 )).log_ (),
119- test_cuda = (not TEST_WITH_ROCM )
120116 ),
121117 dict (
122118 module_name = 'LogSoftmax' ,
123119 constructor_args = (1 ,),
124120 input_size = (1 , 3 , 10 , 20 ),
125121 reference_fn = lambda i , _ : torch .exp (i ).div_ (torch .exp (i ).sum (1 , False )).log_ (),
126122 desc = 'multiparam' ,
127- test_cuda = (not TEST_WITH_ROCM )
128123 ),
129124 dict (
130125 module_name = 'ELU' ,
131126 constructor_args = (2. ,),
132127 input_size = (3 , 2 , 5 ),
133128 reference_fn = lambda x , _ : torch .where (x >= 0 , x , 2 * (x .exp () - 1 )),
134- test_cuda = (not TEST_WITH_ROCM ),
135129 ),
136130 # TODO: reference function
137131 dict (
@@ -204,7 +198,6 @@ def get_weight(m):
204198 input_size = (2 , 3 , 4 ),
205199 desc = '1d_multiparam' ,
206200 reference_fn = lambda i , p : torch .clamp (i , min = 0 ) + torch .clamp (i , max = 0 ) * p [0 ][0 ],
207- test_cuda = (not TEST_WITH_ROCM )
208201 ),
209202 dict (
210203 module_name = 'PReLU' ,
@@ -218,7 +211,6 @@ def get_weight(m):
218211 input_size = (2 , 3 , 4 , 5 ),
219212 desc = '2d_multiparam' ,
220213 reference_fn = lambda i , p : torch .clamp (i , min = 0 ) + torch .clamp (i , max = 0 ) * p [0 ][0 ],
221- test_cuda = (not TEST_WITH_ROCM )
222214 ),
223215 dict (
224216 module_name = 'PReLU' ,
@@ -232,31 +224,26 @@ def get_weight(m):
232224 input_size = (2 , 3 , 4 , 5 , 6 ),
233225 desc = '3d_multiparam' ,
234226 reference_fn = lambda i , p : torch .clamp (i , min = 0 ) + torch .clamp (i , max = 0 ) * p [0 ][0 ],
235- test_cuda = (not TEST_WITH_ROCM )
236227 ),
237228 dict (
238229 module_name = 'Softsign' ,
239230 input_size = (3 , 2 , 5 ),
240231 reference_fn = lambda i , _ : i .div (1 + torch .abs (i )),
241- test_cuda = (not TEST_WITH_ROCM )
242232 ),
243233 dict (
244234 module_name = 'Softmin' ,
245235 constructor_args = (1 ,),
246236 input_size = (10 , 20 ),
247- test_cuda = (not TEST_WITH_ROCM )
248237 ),
249238 dict (
250239 module_name = 'Softmin' ,
251240 constructor_args = (1 ,),
252241 input_size = (2 , 3 , 5 , 10 ),
253242 desc = 'multidim' ,
254- test_cuda = (not TEST_WITH_ROCM )
255243 ),
256244 dict (
257245 module_name = 'Tanhshrink' ,
258246 input_size = (2 , 3 , 4 , 5 ),
259- test_cuda = (not TEST_WITH_ROCM )
260247 ),
261248]
262249
@@ -573,7 +560,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
573560 reference_fn = lambda i , t , m :
574561 kldivloss_reference (i , t , get_reduction (m )),
575562 check_sum_reduction = True ,
576- test_cuda = (not TEST_WITH_ROCM )
577563 ),
578564 dict (
579565 module_name = 'MSELoss' ,
@@ -590,7 +576,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
590576 reference_fn = lambda i , t , m : - (t * i .log () + (1 - t ) * (1 - i ).log ()).sum () /
591577 (i .numel () if get_reduction (m ) else 1 ),
592578 check_gradgrad = False ,
593- test_cuda = (not TEST_WITH_ROCM )
594579 ),
595580 dict (
596581 module_name = 'BCELoss' ,
@@ -601,7 +586,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
601586 (i .numel () if get_reduction (m ) else 1 ),
602587 desc = 'weights' ,
603588 check_gradgrad = False ,
604- test_cuda = (not TEST_WITH_ROCM )
605589 ),
606590 dict (
607591 module_name = 'CrossEntropyLoss' ,
@@ -660,7 +644,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
660644 target_fn = lambda : torch .rand (5 , 10 ).mul (2 ).floor (),
661645 reference_fn = lambda i , t , m : - (t * i .sigmoid ().log () + (1 - t ) * (- i ).sigmoid ().log ()).sum () / i .numel (),
662646 check_gradgrad = False ,
663- test_cuda = (not TEST_WITH_ROCM )
664647 ),
665648 dict (
666649 module_name = 'MultiMarginLoss' ,
@@ -759,7 +742,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
759742 reference_fn = lambda i , t , m :
760743 marginrankingloss_reference (i [0 ], i [1 ], t , reduction = get_reduction (m )),
761744 check_sum_reduction = True ,
762- test_cuda = (not TEST_WITH_ROCM )
763745 ),
764746 dict (
765747 module_name = 'MarginRankingLoss' ,
@@ -770,7 +752,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
770752 marginrankingloss_reference (i [0 ], i [1 ], t , margin = 0.5 , reduction = get_reduction (m )),
771753 desc = 'margin' ,
772754 check_sum_reduction = True ,
773- test_cuda = (not TEST_WITH_ROCM )
774755 ),
775756]
776757
0 commit comments