[Pytorch] Swiglu implementation not aligned with jiterator version in probability #717
Some more setup information: torch version 2.1.0+cu122, TE version 1.3.0, Python 3.9.
Thanks @tylaar for reporting that. Let me take a look at it.
I did a little bit of digging and (at least in the case I tried) it does not seem to be a bug, but rather an artifact of the finite precision of the computations. I modified your script a little bit to make the execution deterministic and better show what happens:

```python
import torch
import transformer_engine
from transformer_engine.pytorch import cpp_extensions as tex
from transformer_engine.pytorch.constants import TE_DType
torch.manual_seed(1234)
torch.set_printoptions(precision=10)
# Adjust for different input
d1 = 256
d2 = 512
swiglu_fwd = """
template <typename T> T swiglu_fwd(T x, T y) {
return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
dy = float(x) * x_sigmoid * float(g);
}
"""
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd, num_outputs=2)
class MySwiglu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputmat):
        x1, x2 = torch.chunk(inputmat, 2, dim=-1)
        ctx.save_for_backward(x1, x2)
        return swiglu_fwd(x1, x2)

    @staticmethod
    def backward(ctx, dout):
        x1, x2 = ctx.saved_tensors
        # concatenate the two gradient halves so the shape matches inputmat
        dx1, dx2 = swiglu_bwd(x1, x2, dout)
        return torch.cat((dx1, dx2), dim=-1)
myswiglu = MySwiglu.apply
input_mat1 = torch.rand(d2, d1).bfloat16().to('cuda')
input_mat2 = input_mat1.clone().detach().bfloat16().to('cuda')
r1 = myswiglu(input_mat1)
r2 = tex.swiglu(input_mat2, None, tex.FP8FwdTensors.GEMM2_INPUT, otype=TE_DType[torch.bfloat16])
print(r1)
print(r2)
mismatches = (r1-r2).nonzero()
print(mismatches)
for m in mismatches:
    index = tuple(m)
    print(r1[index])
    print(r2[index])
    in1, in2 = torch.chunk(input_mat1, 2, dim=-1)
    x = in1.float()[index]
    y = in2.float()[index]
    out = x * y / (1 + torch.exp(-x))
    print(out)
    print(out.bfloat16())
    temp1 = 1 / (1 + torch.exp(-x))
    temp = x * temp1
    out2 = temp * y
    print(out2)
    print(out2.bfloat16())
```

I tried it on H100 with CUDA 12.4 and TE 1.4 (we did not change the swiglu logic between 1.3 and 1.4, so it should be completely equivalent). The output is:

```
tensor([[0.0128173828, 0.0776367188, 0.1162109375, ..., 0.3769531250,
0.0162353516, 0.0206298828],
[0.4765625000, 0.2636718750, 0.0295410156, ..., 0.1748046875,
0.1196289062, 0.1044921875],
[0.1738281250, 0.0634765625, 0.0056457520, ..., 0.0844726562,
0.0844726562, 0.0212402344],
...,
[0.0524902344, 0.0249023438, 0.6328125000, ..., 0.2734375000,
0.0625000000, 0.1972656250],
[0.1269531250, 0.0532226562, 0.0167236328, ..., 0.0145874023,
0.3593750000, 0.6250000000],
[0.1166992188, 0.0213623047, 0.0751953125, ..., 0.2236328125,
0.1250000000, 0.3027343750]], device='cuda:0', dtype=torch.bfloat16)
tensor([[0.0128173828, 0.0776367188, 0.1162109375, ..., 0.3769531250,
0.0162353516, 0.0206298828],
[0.4765625000, 0.2636718750, 0.0295410156, ..., 0.1748046875,
0.1196289062, 0.1044921875],
[0.1738281250, 0.0634765625, 0.0056457520, ..., 0.0844726562,
0.0844726562, 0.0212402344],
...,
[0.0524902344, 0.0249023438, 0.6328125000, ..., 0.2734375000,
0.0625000000, 0.1972656250],
[0.1269531250, 0.0532226562, 0.0167236328, ..., 0.0145874023,
0.3593750000, 0.6250000000],
[0.1166992188, 0.0213623047, 0.0751953125, ..., 0.2236328125,
0.1250000000, 0.3027343750]], device='cuda:0', dtype=torch.bfloat16)
tensor([[456, 127]], device='cuda:0')
tensor(0.4707031250, device='cuda:0', dtype=torch.bfloat16)
tensor(0.4687500000, device='cuda:0', dtype=torch.bfloat16)
tensor(0.4697265923, device='cuda:0')
tensor(0.4707031250, device='cuda:0', dtype=torch.bfloat16)
tensor(0.4697265625, device='cuda:0')
tensor(0.4687500000, device='cuda:0', dtype=torch.bfloat16)
```

There is a mismatch at one position. The relevant part of the script for that element is:

```python
in1, in2 = torch.chunk(input_mat1, 2, dim=-1)
x = in1.float()[index]
y = in2.float()[index]
out = x * y / (1 + torch.exp(-x))
print(out)
print(out.bfloat16())
temp1 = 1 / (1 + torch.exp(-x))
temp = x * temp1
out2 = temp * y
print(out2)
print(out2.bfloat16())
```

which produces

```
tensor(0.4697265923, device='cuda:0')
tensor(0.4707031250, device='cuda:0', dtype=torch.bfloat16)
tensor(0.4697265625, device='cuda:0')
tensor(0.4687500000, device='cuda:0', dtype=torch.bfloat16)
```
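The difference comes down to where the float32 intermediate lands relative to the bfloat16 rounding boundary: 0.4697265625 is exactly the midpoint between the adjacent bfloat16 values 0.46875 and 0.470703125, so round-to-nearest-even sends it down, while the slightly larger 0.4697265923 rounds up. A minimal sketch checking just that rounding step:

```python
import torch

# In [0.25, 0.5) the bfloat16 spacing is 2**-9, so the representable
# neighbours here are 240/512 = 0.46875 and 241/512 = 0.470703125,
# with midpoint 240.5/512 = 0.4697265625.
a = torch.tensor(0.4697265923)  # float32 result of x * y / (1 + exp(-x))
b = torch.tensor(0.4697265625)  # float32 result of (x * sigmoid(x)) * y
print(a.bfloat16())  # above the midpoint -> rounds up to 0.470703125
print(b.bfloat16())  # exact tie -> rounds to the even neighbour, 0.46875
```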
Thanks @ptrendx for your dedicated investigation and reply! That makes sense to me! Just to add a small detail that could be interesting at the compiler level: I made a little change in my forked TE to the swish implementation inside math.h, so that swish no longer calls the sigmoid template and instead looks like this:
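(The exact snippet from the fork is not preserved in this thread; the sketch below is a rough reconstruction of such an inlined swish, with the template signature and the `Empty` argument assumed from TE's math.h rather than copied from the fork.)

```cpp
// Rough sketch, not the exact code from the fork: swish written as a single
// expression instead of calling the sigmoid template, which lets the compiler
// order the multiply/divide differently.
template <typename OType, typename IType>
__device__ inline OType swish(const IType val, const Empty &e) {
  const float cval = val;
  return cval / (1.0f + expf(-cval));  // == cval * sigmoid(cval), inlined
}
```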
With that change, the forward diff shrinks to almost zero probability even when iterating thousands of times with d1 = 1024 and d2 = 2048. I guess it's a C++-level template type-casting difference that causes the mismatch, but I haven't looked at the PTX level yet. However, since you've explained it quite thoroughly, I don't think this is going to be an issue for me. Thanks a lot!
Hello there, sorry to bother again. During my investigation into issue #709, I found that some differences in the swiglu implementation occur when the hidden_size becomes larger. Here is a UT to reproduce it (this is the script that @ptrendx modified above).
So the thing is, I implemented a MySwiglu class, with fwd and bwd compiled by torch.cuda.jiterator, and compared it to the tex version of swiglu. The interesting thing is that if you set d1 = 128 (L6 of the script), the final diff line almost never appears, while with d1 = 256 there is some probability that a few lines of the output mismatch, and when you set d1 to a larger number, say 1024 or 2048, mismatched lines appear more often ...
I am not very sure whether this is due to an implementation bug in the jiterator swiglu fwd, or whether it's something interesting inside the tex version of swiglu ...
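A small sweep along these lines could look like the sketch below; it reuses MySwiglu and the tex.swiglu call exactly as they appear in the script above, and simply counts mismatched bfloat16 elements for a few example values of d1:

```python
# Sketch: count how many output elements differ between the two
# implementations as d1 grows (reuses MySwiglu, tex and TE_DType from above).
for d1 in (128, 256, 1024, 2048):
    d2 = 2 * d1
    inp = torch.rand(d2, d1).bfloat16().to('cuda')
    r1 = MySwiglu.apply(inp)
    r2 = tex.swiglu(inp.clone(), None, tex.FP8FwdTensors.GEMM2_INPUT,
                    otype=TE_DType[torch.bfloat16])
    n_diff = (r1 - r2).nonzero().shape[0]
    print(f"d1={d1}: {n_diff} of {r1.numel()} elements differ")
```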