I recently ran the DBN model on CUDA and hit a small runtime issue.
Line 52 of dbn.py:
s = (torch.rand(p.size())< p).float().to(self.dvc)
raises the following error:
Traceback (most recent call last):
File "mnist_cls.py", line 55, in
model.run(e=3, pre_e=3)
File "../core/epoch.py", line 98, in run
self.pre_batch_training(pre_e, b)
File "../core/pre_module.py", line 60, in pre_batch_training
module.batch_training(i)
File "../model/dbn.py", line 99, in batch_training
v0,h0,vk,hk = self.forward(data)
File "../model/dbn.py", line 65, in forward
ph0, h0 = self.transfrom(v0,'v2h')
File "../model/dbn.py", line 52, in transfrom
s = (torch.rand(p.size())< p).float().to(self.dvc)
RuntimeError: expected device cpu but got device cuda:0
torch.rand(p.size()) creates the random tensor on the CPU by default, while p lives on cuda:0, so the comparison mixes devices. Changing the line to
s = (torch.rand(p.size())< p.cpu()).float().to(self.dvc)
makes it run normally.
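An alternative that avoids moving p to the CPU and back is to draw the uniform noise directly on p's device. The sketch below is a suggestion under assumptions, not the code in dbn.py: `sample_bernoulli` is a hypothetical helper name, and the local `dvc` stands in for `self.dvc`. It relies only on `torch.rand_like`, which allocates noise with the same shape, dtype and device as its input.

```python
import torch


def sample_bernoulli(p: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: draw binary states s ~ Bernoulli(p) on p's device."""
    # torch.rand_like matches p's shape, dtype and device, so the comparison
    # never mixes a CPU tensor with a CUDA tensor.
    return (torch.rand_like(p) < p).float()


if __name__ == "__main__":
    # Stand-in for self.dvc: use the GPU if one is available, else the CPU.
    dvc = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    p = torch.full((4, 3), 0.5, device=dvc)
    s = sample_bernoulli(p)
    print(s.device, s.shape)
```

With this approach no .to(self.dvc) call is needed afterwards, because the result is already on the same device as p. PyTorch's built-in torch.bernoulli(p) samples the same distribution and is another device-safe option.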