`from torch.utils.serialization import load_lua` doesn't work in current PyTorch versions, because the `torch.utils.serialization` module was removed in PyTorch 1.0.
Here is a possible fix using `torchfile`:
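```python
import torchfile


class pytorch_lua_wrapper:
    def __init__(self, lua_path):
        # torchfile parses the .t7 file; tensors are returned as NumPy arrays
        self.lua_model = torchfile.load(lua_path)

    def get(self, idx):
        # Fields (e.g. weight, bias) of the idx-th module in the Lua container
        return self.lua_model._obj.modules[idx]._obj
```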
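Now you can replace this line:
```python
vgg1 = load_lua(args.vgg1)
```
with
```python
vgg1 = pytorch_lua_wrapper(args.vgg1)
```
and this line
```python
self.conv1.weight = torch.nn.Parameter(vgg1.get(0).weight.float())
```
with the following. Because `torchfile` returns NumPy arrays rather than torch tensors, each weight has to be converted with `torch.from_numpy` first:
```python
self.conv1.weight = torch.nn.Parameter(torch.from_numpy(vgg1.get(0).weight).float())
```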
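If many layers need the same conversion, the per-line change can be wrapped in a small helper that also copies biases. This is a minimal sketch, not code from the repository: `copy_conv_params` is a hypothetical name, and it assumes each object returned by `get()` exposes `weight` (and optionally `bias`) as NumPy arrays. A `SimpleNamespace` stands in for a loaded layer so the example runs without a `.t7` file.

```python
import numpy as np
import torch
from types import SimpleNamespace


def copy_conv_params(conv, lua_layer):
    """Hypothetical helper: copy weight/bias from a torchfile layer into ``conv``.

    Assumes ``lua_layer`` exposes ``weight`` (and optionally ``bias``) as
    NumPy arrays, as returned by ``pytorch_lua_wrapper.get(idx)``.
    """
    # np.array(..., dtype=np.float32) makes a fresh, writable float32 copy,
    # so the Parameter does not share memory with the loaded arrays.
    weight = torch.from_numpy(np.array(lua_layer.weight, dtype=np.float32))
    if tuple(weight.shape) != tuple(conv.weight.shape):
        raise ValueError(
            f"weight shape {tuple(weight.shape)} does not match "
            f"{tuple(conv.weight.shape)}"
        )
    conv.weight = torch.nn.Parameter(weight)

    # Skip the bias if the Lua layer has none or the conv was built with bias=False
    bias = getattr(lua_layer, "bias", None)
    if bias is not None and conv.bias is not None:
        conv.bias = torch.nn.Parameter(
            torch.from_numpy(np.array(bias, dtype=np.float32))
        )


# Usage: SimpleNamespace stands in for vgg1.get(0) so the sketch runs without a .t7 file.
fake_layer = SimpleNamespace(weight=np.ones((64, 3, 3, 3)), bias=np.zeros(64))
conv1 = torch.nn.Conv2d(3, 64, kernel_size=3)
copy_conv_params(conv1, fake_layer)
```

Copying into a new float32 array, instead of calling `.float()` on a tensor that may share memory with the loaded data, keeps the model's parameters independent of the file contents. The shape check makes a mismatched layer fail loudly rather than silently.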