Adding DGL and DGL-LifeSci support #796
Hi, that's pretty cool. I haven't worked with DGL yet. What could be really helpful is a full working example (perhaps even a Jupyter notebook added to the skorch repo), including your dataloaders and dataset classes. If it's not too much effort to make it work, I would gladly add better support for DGL. The way you changed (I assume) Regarding your second problem, I would need to understand the issue better, which is probably best achieved with the full example mentioned above.
Any updates, @BernardoHernandezAdame?
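Here is a script to reproduce.

```python
import numpy
import skorch
import torch


class model(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.lin = torch.nn.Linear(10, 1)

    def forward(self, x0, x1):
        # Tensor.dot only accepts 1-D tensors; use matrix multiplication instead.
        z = x0 @ x1  # (batch, 25) @ (25, 10) -> (batch, 10)
        return self.lin(z)


M = skorch.NeuralNetRegressor(model)
# float32 to match the default dtype of the module's parameters
X = {'x0': numpy.random.rand(1000, 25).astype(numpy.float32),
     'x1': numpy.random.rand(25, 10).astype(numpy.float32)}
y = (X['x0'].dot(X['x1']) * numpy.random.rand(10) + numpy.random.rand(10)).sum(-1)
y = y.astype(numpy.float32).reshape(-1, 1)  # the regressor expects 2-D targets
# Fails: skorch requires every value in X to have the same length (1000 vs 25).
M.fit(X, y)
```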
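Why the script fails can be illustrated with a simplified, hypothetical re-implementation of the rule skorch roughly enforces: skorch indexes every value of the X dict with the same sample indices, so all of them must share the same first dimension. The helpers `leading_lengths` and `check_consistent_lengths` below are made up for illustration and are not skorch's code; nested lists stand in for the numpy arrays.

```python
# Hypothetical, simplified version of a dict-length consistency check.
# Every value in X is indexed per sample, so all must have the same length.


def leading_lengths(X):
    """Return the first-dimension length of each value in the dict X."""
    return {key: len(value) for key, value in X.items()}


def check_consistent_lengths(X):
    """Return the common length, or raise ValueError if the values disagree."""
    lengths = leading_lengths(X)
    if len(set(lengths.values())) > 1:
        raise ValueError(f"inconsistent lengths: {lengths}")
    return next(iter(lengths.values()))


# Shapes from the reproduction script.
X_ok = {"x0": [[0.0] * 25 for _ in range(1000)]}
X_bad = {"x0": [[0.0] * 25 for _ in range(1000)],
         "x1": [[0.0] * 10 for _ in range(25)]}

print(check_consistent_lengths(X_ok))  # 1000
try:
    check_consistent_lengths(X_bad)
except ValueError as err:
    print(err)  # inconsistent lengths: {'x0': 1000, 'x1': 25}
```

Here `x1` is not per-sample data at all, which is why it cannot live in X as-is.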
I think this issue can be closed with a simple documentation update. The behaviour described can be circumvented by using a torch `Dataset` object as X, together with a collate function passed to the iterators (e.g. via `iterator_train__collate_fn`). This, plus a bit of `SliceDataset` wizardry, should make most scenarios solvable; I made it work for PyTorch Geometric.
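A minimal sketch of that workaround applied to the reproduction script, assuming `x1` is a shared matrix rather than per-sample data: the dataset yields only `(x0_row, y)` pairs, and a collate function attaches the same `x1` to every batch. `RowDataset` and `make_collate` are hypothetical names, and plain Python lists stand in for tensors so the batching logic is visible without torch.

```python
# Hypothetical sketch: per-sample dataset plus a collate function that
# attaches a shared (non-per-sample) matrix to every batch.
import random


class RowDataset:
    """Map-style dataset: item i is (x0[i], y[i]); x1 is NOT stored here."""

    def __init__(self, x0, y):
        assert len(x0) == len(y)
        self.x0, self.y = x0, y

    def __len__(self):
        return len(self.x0)

    def __getitem__(self, i):
        return self.x0[i], self.y[i]


def make_collate(x1):
    """Return a collate function that gathers per-sample rows and adds x1."""
    def collate(samples):
        rows = [row for row, _ in samples]
        targets = [target for _, target in samples]
        # Keys match the forward(x0, x1) signature of the module.
        return {"x0": rows, "x1": x1}, targets
    return collate


x0 = [[random.random() for _ in range(25)] for _ in range(1000)]
x1 = [[random.random() for _ in range(10)] for _ in range(25)]
y = [random.random() for _ in range(1000)]

dataset = RowDataset(x0, y)
collate = make_collate(x1)
# Emulate a DataLoader drawing the first batch of 128 samples.
batch_X, batch_y = collate([dataset[i] for i in range(128)])
print(len(batch_X["x0"]), len(batch_X["x1"]), len(batch_y))  # 128 25 128
```

In a real setup the dataset would subclass `torch.utils.data.Dataset`, the rows would be stacked into tensors, and the function would be passed as `iterator_train__collate_fn=collate` (and `iterator_valid__collate_fn=collate`) when constructing the net, with `fit(dataset, None)`; treat that wiring as an assumption to check against the skorch documentation.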
Hi,
I have been working with skorch for standard NN frameworks on tensors and recently started experimenting with some graph neural networks, in particular DGL and DGL-LifeSci. I have a workaround to get things working with skorch, where I add a pass for `dgl.DGLGraph()` objects in the `utils.py` check. However, skorch throws an error when measuring the batch size, because the number of nodes varies between graph batches.
Fixes are here; is DGL integration something currently being considered, or is there a separate recommended workaround?
Fixing these two issues allows me to train DGL-LifeSci models using skorch by defining my own dataloader and dataset classes.
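The batch-size problem presumably comes from how variable-size graphs are batched. Below is a pure-Python sketch of the common block-diagonal approach (in the spirit of `dgl.batch`, but not its implementation): node features are concatenated, edge endpoints are shifted by a running node offset, and a membership list records which graph each node belongs to. `collate_graphs` and the `(node_features, edges, label)` sample layout are hypothetical. The true batch size is the number of graphs, while the node dimension changes from batch to batch, so a length measured on node features is not a reliable batch size.

```python
# Hypothetical sketch of block-diagonal graph batching.
# Each sample is (node_features, edges, label):
#   node_features: one feature vector per node
#   edges: (src, dst) pairs using node indices local to that graph


def collate_graphs(samples):
    features, edges, membership, labels = [], [], [], []
    offset = 0
    for graph_id, (node_feats, graph_edges, label) in enumerate(samples):
        features.extend(node_feats)
        # Shift local node ids so they index into the concatenated features.
        edges.extend((src + offset, dst + offset) for src, dst in graph_edges)
        membership.extend([graph_id] * len(node_feats))
        labels.append(label)
        offset += len(node_feats)
    batch = {"features": features, "edges": edges, "membership": membership}
    return batch, labels


g_a = ([[1.0], [2.0], [3.0]], [(0, 1), (1, 2)], 0.5)  # 3 nodes
g_b = ([[4.0], [5.0]], [(0, 1)], 1.5)                 # 2 nodes

batch, labels = collate_graphs([g_a, g_b])
print(len(labels))             # 2: number of graphs, the actual batch size
print(len(batch["features"]))  # 5: number of nodes, varies per batch
print(batch["edges"])          # [(0, 1), (1, 2), (3, 4)]
print(batch["membership"])     # [0, 0, 0, 1, 1]
```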
Thanks!