Make fp8 work on older GPUs #34

Merged 5 commits on Nov 4, 2024
fp8/float8_quantize.py: 32 additions & 10 deletions
@@ -275,16 +275,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         prev_dims = x.shape[:-1]
         x = x.view(-1, self.in_features)

-        # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
-        out = torch._scaled_mm( # noqa
-            x,
-            self.float8_data.T,
-            scale_a=self.input_scale_reciprocal,
-            scale_b=self.scale_reciprocal,
-            bias=self.bias,
-            out_dtype=self.weight.dtype,
-            use_fast_accum=True,
-        )
+        device = x.device
+        if x.device.type != 'cpu' and torch.cuda.get_device_capability(x.device) >= (8, 9):
+            # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
+            out = torch._scaled_mm( # noqa
+                x,
+                self.float8_data.T,
+                scale_a=self.input_scale_reciprocal,
+                scale_b=self.scale_reciprocal,
+                bias=self.bias,
+                out_dtype=self.weight.dtype,
+                use_fast_accum=True,
+            )
+        else:
+            # Plain matrix multiplication for non-ADA devices
+            # Assuming x is in float8 and self.float8_data is in float8 as well
+            # Convert to float32, perform the multiplication, and then apply scaling and bias if necessary
+
+            # Convert float8 to float32 for the multiplication
+            x_float32 = x.to(torch.float32)
+            float8_data_float32 = self.float8_data.T.to(torch.float32)
+
+            # Regular matrix multiplication
+            out = torch.matmul(x_float32, float8_data_float32)
+
+            # Scale the output accordingly
+            out = out * (self.input_scale_reciprocal * self.scale_reciprocal)
+
+            # Add bias if it exists
+            if self.bias is not None:
+                out += self.bias
+            out = out.to(self.weight.dtype)
+
         if IS_TORCH_2_4:
             out = out[0]
         return out.view(*prev_dims, self.out_features)
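
For context, a minimal standalone sketch of the same dispatch (function and parameter names here are illustrative, not from the repo; the compute-capability cutoff and the _scaled_mm keywords are taken from the diff): torch._scaled_mm requires compute capability 8.9 or newer, so older GPUs upcast the fp8 operands to float32, run a regular matmul, and reapply both quantization scales.

import torch

def fp8_linear_sketch(x_fp8, w_fp8, scale_x, scale_w, bias=None, out_dtype=torch.bfloat16):
    # Hypothetical helper mirroring the change above; not part of the repo.
    dev = x_fp8.device
    if dev.type == "cuda" and torch.cuda.get_device_capability(dev) >= (8, 9):
        # Fused fp8 GEMM path (compute capability 8.9+).
        out = torch._scaled_mm(
            x_fp8,
            w_fp8.T,
            scale_a=scale_x,
            scale_b=scale_w,
            bias=bias,
            out_dtype=out_dtype,
            use_fast_accum=True,
        )
        # Some torch versions return (output, amax) as a tuple.
        return out[0] if isinstance(out, tuple) else out
    # Fallback for older GPUs: dequantize, multiply, rescale.
    out = torch.matmul(x_fp8.to(torch.float32), w_fp8.T.to(torch.float32))
    out = out * (scale_x * scale_w)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)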
predict.py: 26 additions & 3 deletions
@@ -165,7 +165,11 @@ def base_setup(
         self.falcon_processor = ViTImageProcessor.from_pretrained(FALCON_MODEL_NAME)

         # need > 48 GB of ram to store all models in VRAM
-        self.offload = "A40" in gpu_name
+        total_mem = torch.cuda.get_device_properties(0).total_memory
+        self.offload = total_mem < 48 * 1024**3
+        if self.offload:
+            print("GPU memory is:", total_mem / 1024**3, ", offloading models")
+            compile_fp8 = False

         device = "cuda"
         max_length = 256 if self.flow_model_name == "flux-schnell" else 512
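
The offload decision above now keys off total device memory rather than the GPU name; a minimal sketch of the same check as a helper (the helper name is hypothetical):

import torch

def should_offload(threshold_gib: int = 48) -> bool:
    # Hypothetical helper mirroring the check above: offload model components
    # when the GPU cannot hold all models in VRAM at once.
    if not torch.cuda.is_available():
        return True
    total_mem = torch.cuda.get_device_properties(0).total_memory  # bytes
    return total_mem < threshold_gib * 1024**3

In practice a 48 GB card such as the A40 reports slightly under the 48 GiB cutoff, so it continues to offload as it did with the old name-based check, while 80 GB cards keep everything resident.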
@@ -187,13 +191,32 @@ def base_setup(
             flow=None, ae=self.ae, clip=self.clip, t5=self.t5, config=None
         )

-        # fp8 only works w/compute capability >= 8.9
-        self.disable_fp8 = disable_fp8 or torch.cuda.get_device_capability() < (8, 9)
+        self.disable_fp8 = disable_fp8

         if not self.disable_fp8:
+            if compile_fp8:
+                extra_args = {
+                    "compile_whole_model": True,
+                    "compile_extras": True,
+                    "compile_blocks": True,
+                }
+            else:
+                extra_args = {
+                    "compile_whole_model": False,
+                    "compile_extras": False,
+                    "compile_blocks": False,
+                }
+
+            if self.offload:
+                extra_args |= {
+                    "offload_text_encoder": True,
+                    "offload_vae": True,
+                    "offload_flow": True,
+                }
             self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
                 f"fp8/configs/config-1-{flow_model_name}-h100.json",
                 shared_models=shared_models,
+                **extra_args,
             )

             if compile_fp8:
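
The compile and offload keyword arguments assembled above could also be expressed as a small helper; a sketch under that assumption (the helper name is hypothetical, the flag names are taken from the diff), noting that the in-place dict merge operator |= requires Python 3.9+:

def build_pipeline_kwargs(compile_fp8: bool, offload: bool) -> dict:
    # Hypothetical helper, not part of the repo: condenses the if/else above.
    kwargs = {
        "compile_whole_model": compile_fp8,
        "compile_extras": compile_fp8,
        "compile_blocks": compile_fp8,
    }
    if offload:
        # dict |= dict merges the right-hand entries in place (Python 3.9+).
        kwargs |= {
            "offload_text_encoder": True,
            "offload_vae": True,
            "offload_flow": True,
        }
    return kwargs

# Usage, mirroring the call in the diff:
# self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
#     config_path,
#     shared_models=shared_models,
#     **build_pipeline_kwargs(compile_fp8, self.offload),
# )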