INTERNAL_ASSERT: Unable to find mapped root/logical domain #3607

While working on Lightning-AI/lightning-thunder#1560, the randomness tests fail with the following repro.

Traceback:

This is the fusion. If there are patterns we should avoid, I would appreciate a hint:

```python
import torch
from thunder.executors.torchex import no_autocast

# Thunder-generated backward trace. clear_mutable_collection and the fused
# region nvFusion0 are bound by Thunder's trace-execution context at run time.
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  x, t_layer_norm_weight, input, = C0
  clear_mutable_collection(C0)
  del C0
  i67, i66, = C1
  clear_mutable_collection(C1)
  del C1
  [t118, t121, t152] = nvFusion0(t0, i67, i66, t_layer_norm_weight, input, x)
    # t112 = prims.mul(1.0, t0) # t112: "cuda:0 f32[2, 24000]"
    # t114 = prims.mul(2.0, t112) # t114: "cuda:0 f32[2, 24000]"
    # t60 = prims.uniform_philox((2, 24000), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i66, offset=i67) # t60: "cuda:0 f32[2, 24000]"
    # t61 = prims.lt(t60, 0.5) # t61: "cuda:0 b8[2, 24000]"
    # t62 = prims.convert_element_type(t61, dtypes.float32) # t62: "cuda:0 f32[2, 24000]"
    # t115 = prims.mul(t62, t114) # t115: "cuda:0 f32[2, 24000]"
    # t118 = prims.sum(t115, (0,)) # t118: "cuda:0 f32[24000]"
    # t56 = prims.broadcast_in_dim(t_layer_norm_weight, (2, 24000), (1,)) # t56: "cuda:0 f32[2, 24000]"
    # t119 = prims.mul(t56, t115) # t119: "cuda:0 f32[2, 24000]"
    # (t46, t47) = prims.var_mean(input, (1,), correction=0)
    # t49 = prims.broadcast_in_dim(t47, [2, 1], [0]) # t49: "cuda:0 f32[2, 1]"
    # t52 = prims.broadcast_in_dim(t49, (2, 24000), (0, 1)) # t52: "cuda:0 f32[2, 24000]"
    # t53 = prims.sub(input, t52) # t53: "cuda:0 f32[2, 24000]"
    # t48 = prims.broadcast_in_dim(t46, [2, 1], [0]) # t48: "cuda:0 f32[2, 1]"
    # t50 = prims.add(t48, 1e-05) # t50: "cuda:0 f32[2, 1]"
    # t51 = prims.rsqrt(t50) # t51: "cuda:0 f32[2, 1]"
    # t54 = prims.broadcast_in_dim(t51, (2, 24000), (0, 1)) # t54: "cuda:0 f32[2, 24000]"
    # t55 = prims.mul(t53, t54) # t55: "cuda:0 f32[2, 24000]"
    # t120 = prims.mul(t55, t115) # t120: "cuda:0 f32[2, 24000]"
    # t121 = prims.sum(t120, (0,)) # t121: "cuda:0 f32[24000]"
    # t122 = prims.mul(t54, t119) # t122: "cuda:0 f32[2, 24000]"
    # t123 = prims.mul(t53, t119) # t123: "cuda:0 f32[2, 24000]"
    # t124 = prims.sum(t123, (1,)) # t124: "cuda:0 f32[2]"
    # t125 = prims.broadcast_in_dim(t124, [2, 1], [0]) # t125: "cuda:0 f32[2, 1]"
    # t126 = prims.neg(t122) # t126: "cuda:0 f32[2, 24000]"
    # t127 = prims.sum(t126, (1,)) # t127: "cuda:0 f32[2]"
    # t128 = prims.broadcast_in_dim(t127, [2, 1], [0]) # t128: "cuda:0 f32[2, 1]"
    # t129 = prims.mul(-0.5, t125) # t129: "cuda:0 f32[2, 1]"
    # t130 = prims.pow(t51, 3.0) # t130: "cuda:0 f32[2, 1]"
    # t131 = prims.mul(t129, t130) # t131: "cuda:0 f32[2, 1]"
    # t133 = prims.sum(t128, (1,)) # t133: "cuda:0 f32[2]"
    # t134 = prims.sum(t131, (1,)) # t134: "cuda:0 f32[2]"
    # t137 = prims.broadcast_in_dim(t133, [2, 1], [0]) # t137: "cuda:0 f32[2, 1]"
    # t138 = prims.broadcast_in_dim(t137, (2, 24000), (0, 1)) # t138: "cuda:0 f32[2, 24000]"
    # t139 = prims.mul(4.1666666666666665e-05, t138) # t139: "cuda:0 f32[2, 24000]"
    # t140 = prims.broadcast_in_dim(t134, [2, 1], [0]) # t140: "cuda:0 f32[2, 1]"
    # t141 = prims.broadcast_in_dim(t140, (2, 24000), (0, 1)) # t141: "cuda:0 f32[2, 24000]"
    # t145 = prims.mul(2.0, t141) # t145: "cuda:0 f32[2, 24000]"
    # t147 = prims.mul(t145, t53) # t147: "cuda:0 f32[2, 24000]"
    # t148 = prims.div(t147, 24000.0) # t148: "cuda:0 f32[2, 24000]"
    # t149 = prims.add(t139, t148) # t149: "cuda:0 f32[2, 24000]"
    # t150 = prims.add(t122, t149) # t150: "cuda:0 f32[2, 24000]"
    # t42 = prims.gt(x, 0.0) # t42: "cuda:0 b8[2, 24000]"
    # t152 = prims.where(t42, t150, 0.0) # t152: "cuda:0 f32[2, 24000]"
  del t0, i67, i66, t_layer_norm_weight, input, x
  return (t152, t118, t121)
```
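
Working backwards from the trace comments, the forward pass appears to be dropout(layer_norm(relu(x)), p=0.5) on a [2, 24000] float32 tensor: t42 is the relu mask, t46-t55 the layer norm, t60-t62 a p=0.5 dropout mask drawn with uniform_philox, and the t118 reduction suggests a bias term as well. A minimal sketch of a standalone repro under that assumption (the function body, sizes, and the thunder.jit entry point are my reconstruction, not code from the issue):

```python
import torch
import thunder

# Hypothetical forward matching the backward trace above:
# relu -> layer_norm over the last (24000-wide) dim -> dropout(p=0.5).
def fn(x, weight, bias):
    h = torch.relu(x)
    h = torch.nn.functional.layer_norm(h, (24000,), weight=weight, bias=bias, eps=1e-5)
    return torch.nn.functional.dropout(h, p=0.5)

jfn = thunder.jit(fn)
x = torch.randn(2, 24000, device="cuda", requires_grad=True)
w = torch.randn(24000, device="cuda", requires_grad=True)
b = torch.randn(24000, device="cuda", requires_grad=True)
jfn(x, w, b).sum().backward()  # the backward should route through nvFusion0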

IIUC, this seems to be a simple mistake we have in the transpose scheduler; I'm surprised we haven't hit this error before. Fixed by #3619 (merged).
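
To make the shape pattern in nvFusion0 easier to see, here is a distilled eager-PyTorch transcription of its dataflow: reductions over dim 0 (t118, t121), reductions over dim 1 (t124, t127, t133), and broadcasts back to [2, 24000]. This is my reading of the trace comments, with stand-in tensors and torch calls instead of nvFuser prims, not a confirmed minimal repro:

```python
import torch

# Distilled dataflow of nvFusion0 (stand-in values; the real fusion
# executes these as nvFuser prims on the trace's intermediates).
g = torch.randn(2, 24000, device="cuda")  # upstream grad after dropout scaling (t115)
w = torch.randn(24000, device="cuda")     # t_layer_norm_weight
n = torch.randn(2, 24000, device="cuda")  # normalized activations (t55)

grad_bias = g.sum(dim=0)                  # t118: reduce over dim 0 -> [24000]
grad_weight = (n * g).sum(dim=0)          # t120/t121: reduce over dim 0 -> [24000]
h = w.unsqueeze(0) * g                    # t56/t119: broadcast w across dim 0
row = (n * h).sum(dim=1, keepdim=True)    # t123/t124-style: reduce over dim 1 -> [2, 1]
out = h + row / 24000.0                   # t148/t149-style: broadcast back to [2, 24000]
```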