support blockwise fp8 matmul kernel #3267

Open · wants to merge 1 commit into main from support-fp8-blockwise
Conversation

yizhang2077 (Collaborator)

Motivation

Support an FP8 blockwise GEMM kernel (currently only scale_a block shape 1x128 and scale_b block shape 128x128, as used by DeepSeek V3), adapted mainly from vLLM; a reference sketch of the blockwise scaling semantics follows below.
Correctness:
python3 tests/test_fp8_blockwise_gemm.py
Benchmark:
python3 benchmark/bench_fp8_blockwise_gemm.py --models meta-llama/Llama-3.1-8B-Instruct
TODO: update benchmark results.
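
For context, here is a minimal PyTorch sketch of the blockwise scaling semantics: activations carry one scale per 1x128 block and weights one scale per 128x128 block, and the result is equivalent to dequantizing both operands and doing a plain matmul. This is illustrative only (not the CUTLASS kernel this PR adds), and the function and argument names are hypothetical.

```python
import torch

def ref_blockwise_fp8_gemm(a_fp8, a_scale, b_fp8, b_scale, block_n=128, block_k=128):
    """Dequantize-then-matmul reference for blockwise-scaled FP8 GEMM.

    a_fp8:   [M, K] float8_e4m3fn activations
    a_scale: [M, K // block_k]            one scale per 1x128 block of A
    b_fp8:   [N, K] float8_e4m3fn weights
    b_scale: [N // block_n, K // block_k] one scale per 128x128 block of B
    """
    # Expand each block scale over its block, then compute C = (A * s_a) @ (B * s_b)^T in fp32.
    a = a_fp8.to(torch.float32) * a_scale.repeat_interleave(block_k, dim=1)
    b = b_fp8.to(torch.float32) * b_scale.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
    return (a @ b.t()).to(torch.bfloat16)
```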

Modifications

Checklist

Member


@yizhang2077 Delete all of these; we should use the version directly from the 3rdparty CUTLASS instead.
BTW, @BBuf will integrate a higher-performance version than this baseline, so this PR won't be used directly.

yizhang2077 force-pushed the support-fp8-blockwise branch from 0cee1ef to 7b9ee19 on February 3, 2025 at 10:40.
# Clamp to the int8 range, round to nearest, and cast.
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


WEIGHT_SHAPES = {
Member


I think DeepSeek V3 is sufficient; other models use per-tensor FP8, so we don't need a benchmark in that form.
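
To make the distinction concrete: per-tensor FP8 uses a single scalar scale for the whole tensor, while the blockwise scheme used by DeepSeek V3 keeps one scale per 1x128 block. A rough PyTorch sketch, with names that are illustrative only and not part of this PR:

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

def per_tensor_fp8_quant(x: torch.Tensor):
    # One scale for the whole tensor.
    scale = x.abs().max().clamp(min=1e-12) / FP8_MAX
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8), scale

def blockwise_fp8_quant(x: torch.Tensor, block: int = 128):
    # One scale per 1 x `block` tile along the last dimension (assumes K divisible by `block`).
    M, K = x.shape
    xb = x.view(M, K // block, block)
    scale = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    q = (xb / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)
    return q.view(M, K), scale.squeeze(-1)  # scale shape: [M, K // block]
```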
