I am trying to apply koila lazy eval on a Unet3D.
# defining the model
import torch
import torch.nn as nn
import torch.nn.functional as F

from koila import LazyTensor, lazy


def conv3(in_channels, out_channels, stride, norm='BatchNorm3d', act='GELU'):
    # Note: `stride` is accepted but never passed to the convolution
    # (the arguments below are kernel=3, stride=1, padding=1), so the
    # spatial resolution never changes anywhere in the network.
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, 3, 1, 1),
        getattr(nn, norm)(out_channels),
        getattr(nn, act)())


def double_conv3(in_channels, out_channels, stride):
    return nn.Sequential(
        conv3(in_channels, out_channels, 1),
        conv3(out_channels, out_channels, stride))


def merge_skip(x, skip):
    # Upsample to the skip connection's spatial size, then concatenate
    # along the channel dimension. (F.upsample is deprecated in favor of
    # F.interpolate, but behaves the same here.)
    x = F.upsample(x, size=skip.shape[-3:], mode='trilinear', align_corners=True)
    return torch.cat((x, skip), dim=1)


class Unet3D(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers=4, base=16):
        super().__init__()
        enc_channels = [in_channels] + [base * 2**i for i in range(num_layers)]
        dec_channels = [base * 2**i for i in range(num_layers - 1, -1, -1)] + [out_channels]

        self.encoders = nn.ModuleList()
        for i in range(len(enc_channels) - 1):
            cin = enc_channels[i]
            cout = enc_channels[i + 1]
            enc = double_conv3(cin, cout, 2)
            self.encoders.append(enc)

        self.decoders = nn.ModuleList()
        for i in range(len(dec_channels) - 1):
            cin_skip = enc_channels[-i - 2]
            cin_up = dec_channels[i]
            cin = cin_skip + cin_up
            cout = dec_channels[i + 1]
            dec = double_conv3(cin, cout, 1)
            self.decoders.append(dec)

    def forward(self, x, return_all=False):
        out = [x]
        for encoder in self.encoders:
            x = encoder(x)
            out.append(x)
        n = len(out)
        for i, decoder in enumerate(self.decoders):
            skip = out[n - 2 - i]
            x = merge_skip(out[-1], skip)
            x = decoder(x)
            out.append(x)
        if return_all:
            return out
        else:
            return out[-1]


# test of koila on unet
def test_lazy():
    net = Unet3D(1, 3)
    net.cuda()
    s = 64
    b, c, d, h, w = 2, 1, s, s, s
    x = torch.randn(b, c, d, h, w).cuda()
    t = torch.randint(0, 3, (b, d, h, w)).cuda()
    loss_fn = nn.CrossEntropyLoss()
    net.zero_grad()
    lazy_x, lazy_t = lazy(x, t, batch=0)
    lazy_out = net(lazy_x)
    lazy_loss = loss_fn(lazy_out, lazy_t)
    assert isinstance(lazy_loss, LazyTensor), type(lazy_loss)
    lazy_loss.backward()  # This fails


test_lazy()
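With the defaults (base=16, num_layers=4), enc_channels is [1, 16, 32, 64, 128] and dec_channels is [128, 64, 32, 16, 3]. Since conv3 never applies its stride argument, every feature map stays at 64³, which is consistent with the shapes in the traceback below.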
This fails and outputs:
tensors = (tensor([[[[[-8.9936e-02, -7.9037e-02, -1.5048e-02,  ...,  2.9969e-01,
             2.9774e-01, -1.0489e-01], ...]]], device='cuda:0',
           grad_fn=<UpsampleTrilinear3DBackward1>),
           <koila.lazy.LazyTensor object at 0x7fa21bf99880>)
dim = 1, args = (), kwargs = {}
shapes = [torch.Size([2, 128, 64, 64, 64]), (2, 64, 64, 64, 64)]
no_dim = [torch.Size([2, 64, 64, 64]), (2, 64, 64, 64)]
result_size = torch.Size([2, 64, 64, 64])
size = (2, 64, 64, 64)

    def cat(
        tensors: Sequence[TensorLike], dim: int = 0, *args: Any, **kwargs: Any
    ) -> PrePass:
        mute_unused_args(*args, **kwargs)

        if len(tensors) == 0:
            raise ValueError("Expected a sequence of tensors. Got empty sequence.")

        shapes = [t.size() for t in tensors]
        no_dim = [t[:dim] + t[dim + 1 :] for t in shapes]

        result_size = no_dim[0]
        for size in no_dim[1:]:
            if result_size != size:
                raise ValueError(
                    f"Dimension should be equal outside dim {dim}. Got {shapes}."
                )

        if len(set(interfaces.bat(t) for t in tensors)) != 1:
>           raise UnsupportedError
E           koila.errors.UnsupportedError

../miniconda3/envs/snakes/lib/python3.9/site-packages/koila/prepasses.py:423: UnsupportedError
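Reading the locals, torch.cat is receiving a plain tensor (the output of the trilinear upsample, note the grad_fn) together with a LazyTensor, so the batch check sees inconsistent batch metadata. Here is a minimal sketch of what I think is the same pattern, mirroring one merge_skip step of the model above (the shapes are illustrative and I haven't verified that this exact snippet reproduces the error):

import torch
import torch.nn.functional as F
from koila import lazy

# Mirror one merge_skip step: a decoder feature map and its skip connection.
x = torch.randn(2, 128, 64, 64, 64)
skip = torch.randn(2, 64, 64, 64, 64)
lazy_x, lazy_skip = lazy(x, skip, batch=0)

# The upsample appears to fall back to eager evaluation and return a plain
# torch.Tensor, while the skip connection is still a LazyTensor...
up = F.upsample(lazy_x, size=(64, 64, 64), mode='trilinear', align_corners=True)

# ...so cat sees one argument with batch metadata and one without, and
# koila's cat prepass raises UnsupportedError.
out = torch.cat((up, lazy_skip), dim=1)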
Hi, that means the batch sizes don't match, and the library doesn't know how to deal with that situation.
Since PyTorch's broadcasting rules are extensive, not all rules are supported yet.
I'll see what I can do about it in #18, which brings a much more modular implementation.
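For example, plain PyTorch happily broadcasts tensors with different numbers of dimensions, and a batch-tracking wrapper needs a rule for where the batch dimension ends up in each case (plain-PyTorch illustration, not koila internals):

import torch

a = torch.randn(2, 1, 4)  # batched: batch dimension at dim 0
b = torch.randn(3, 4)     # unbatched: no batch dimension at all
c = a + b                 # broadcasts to torch.Size([2, 3, 4])
print(c.shape)            # a wrapper must infer that dim 0 is still the batch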