From 27181bd8a800f5a66dd1ff14c154f98ba8beabc6 Mon Sep 17 00:00:00 2001
From: Yorick van Pelt
Date: Thu, 10 Oct 2024 17:06:54 +0200
Subject: [PATCH 1/5] fp8: fall back to float32 matmul on cuda capability < 8.9

This re-enables the use of fp8 on older GPUs, which can be useful to save vram.
---
 fp8/float8_quantize.py | 42 ++++++++++++++++++++++++++++++++----------
 predict.py             |  3 +--
 2 files changed, 33 insertions(+), 12 deletions(-)

diff --git a/fp8/float8_quantize.py b/fp8/float8_quantize.py
index 3e48e91..400edc0 100644
--- a/fp8/float8_quantize.py
+++ b/fp8/float8_quantize.py
@@ -275,16 +275,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         prev_dims = x.shape[:-1]
         x = x.view(-1, self.in_features)
 
-        # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
-        out = torch._scaled_mm(  # noqa
-            x,
-            self.float8_data.T,
-            scale_a=self.input_scale_reciprocal,
-            scale_b=self.scale_reciprocal,
-            bias=self.bias,
-            out_dtype=self.weight.dtype,
-            use_fast_accum=True,
-        )
+        device = x.device
+        if x.device.type != 'cpu' and torch.cuda.get_device_capability(x.device) >= (8, 9):
+            # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
+            out = torch._scaled_mm(  # noqa
+                x,
+                self.float8_data.T,
+                scale_a=self.input_scale_reciprocal,
+                scale_b=self.scale_reciprocal,
+                bias=self.bias,
+                out_dtype=self.weight.dtype,
+                use_fast_accum=True,
+            )
+        else:
+            # Plain matrix multiplication for non-ADA devices
+            # Assuming x is in float8 and self.float8_data is in float8 as well
+            # Convert to float32, perform the multiplication, and then apply scaling and bias if necessary
+
+            # Convert float8 to float32 for the multiplication
+            x_float32 = x.to(torch.float32)
+            float8_data_float32 = self.float8_data.T.to(torch.float32)
+
+            # Regular matrix multiplication
+            out = torch.matmul(x_float32, float8_data_float32)
+
+            # Scale the output accordingly
+            out = out * (self.input_scale_reciprocal * self.scale_reciprocal)
+
+            # Add bias if it exists
+            if self.bias is not None:
+                out += self.bias
+            out = out.to(self.weight.dtype)
+
         if IS_TORCH_2_4:
             out = out[0]
         return out.view(*prev_dims, self.out_features)
diff --git a/predict.py b/predict.py
index 4132dba..8910423 100644
--- a/predict.py
+++ b/predict.py
@@ -187,8 +187,7 @@ def base_setup(
             flow=None, ae=self.ae, clip=self.clip, t5=self.t5, config=None
         )
 
-        # fp8 only works w/compute capability >= 8.9
-        self.disable_fp8 = disable_fp8 or torch.cuda.get_device_capability() < (8, 9)
+        self.disable_fp8 = disable_fp8
 
         if not self.disable_fp8:
             self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(

From f06f157649f74b8057ecbdf35af5c6b1a4ce64e8 Mon Sep 17 00:00:00 2001
From: Yorick van Pelt
Date: Thu, 10 Oct 2024 17:25:05 +0200
Subject: [PATCH 2/5] fp8: override offload/compile config based on what we're doing

---
 predict.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/predict.py b/predict.py
index 8910423..6396154 100644
--- a/predict.py
+++ b/predict.py
@@ -190,9 +190,29 @@ def base_setup(
         self.disable_fp8 = disable_fp8
 
         if not self.disable_fp8:
+            if compile_fp8:
+                extra_args = {
+                    "compile_whole_model": True,
+                    "compile_extras": True,
+                    "compile_blocks": True
+                }
+            else:
+                extra_args = {
+                    "compile_whole_model": False,
+                    "compile_extras": False,
+                    "compile_blocks": False
+                }
+
+            if self.offload:
+                extra_args |= {
+                    "offload_text_encoder": True,
+                    "offload_vae": True,
+                    "offload_flow": True
+                }
             self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
                 f"fp8/configs/config-1-{flow_model_name}-h100.json",
                 shared_models=shared_models,
+                **extra_args
             )
 
             if compile_fp8:

From 51e7ee4b61c12d73bbca869e938fca748c492ef2 Mon Sep 17 00:00:00 2001
From: Yorick van Pelt
Date: Thu, 10 Oct 2024 17:25:28 +0200
Subject: [PATCH 3/5] Check video memory to decide when to offload

---
 predict.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/predict.py b/predict.py
index 6396154..60e29ef 100644
--- a/predict.py
+++ b/predict.py
@@ -165,7 +165,10 @@ def base_setup(
         self.falcon_processor = ViTImageProcessor.from_pretrained(FALCON_MODEL_NAME)
 
         # need > 48 GB of ram to store all models in VRAM
-        self.offload = "A40" in gpu_name
+        total_mem = torch.cuda.get_device_properties(0).total_memory
+        self.offload = total_mem < 48 * 1024**3
+        if self.offload:
+            print("GPU memory is:", total_mem / 1024 ** 3, ", offloading models")
 
         device = "cuda"
         max_length = 256 if self.flow_model_name == "flux-schnell" else 512

From 0aac3088e302202de32463c8b280a2d9b4d13a55 Mon Sep 17 00:00:00 2001
From: Yorick van Pelt
Date: Thu, 10 Oct 2024 17:30:21 +0200
Subject: [PATCH 4/5] Ruff format

---
 predict.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/predict.py b/predict.py
index 60e29ef..4dcac0a 100644
--- a/predict.py
+++ b/predict.py
@@ -168,7 +168,7 @@ def base_setup(
         total_mem = torch.cuda.get_device_properties(0).total_memory
         self.offload = total_mem < 48 * 1024**3
         if self.offload:
-            print("GPU memory is:", total_mem / 1024 ** 3, ", offloading models")
+            print("GPU memory is:", total_mem / 1024**3, ", offloading models")
 
         device = "cuda"
         max_length = 256 if self.flow_model_name == "flux-schnell" else 512
@@ -197,25 +197,25 @@ def base_setup(
             extra_args = {
                 "compile_whole_model": True,
                 "compile_extras": True,
-                "compile_blocks": True
+                "compile_blocks": True,
             }
         else:
             extra_args = {
                 "compile_whole_model": False,
                 "compile_extras": False,
-                "compile_blocks": False
+                "compile_blocks": False,
             }
 
         if self.offload:
             extra_args |= {
                 "offload_text_encoder": True,
                 "offload_vae": True,
-                "offload_flow": True
+                "offload_flow": True,
             }
         self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
             f"fp8/configs/config-1-{flow_model_name}-h100.json",
             shared_models=shared_models,
-            **extra_args
+            **extra_args,
         )
 
         if compile_fp8:

From 4fcda4d71b7e539eac85c78e7dcb81f4fa56f2a9 Mon Sep 17 00:00:00 2001
From: Yorick van Pelt
Date: Fri, 1 Nov 2024 21:15:07 +0100
Subject: [PATCH 5/5] Don't compile fp8 when offloaded, it's going to be slow anyways

---
 predict.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/predict.py b/predict.py
index 4dcac0a..d0d8088 100644
--- a/predict.py
+++ b/predict.py
@@ -169,6 +169,7 @@ def base_setup(
         self.offload = total_mem < 48 * 1024**3
         if self.offload:
             print("GPU memory is:", total_mem / 1024**3, ", offloading models")
+            compile_fp8 = False
 
         device = "cuda"
         max_length = 256 if self.flow_model_name == "flux-schnell" else 512
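For context, the float32 fallback that patch 1 adds can be exercised on its own. The sketch below is illustrative only: the function and argument names are not identifiers from this repository, and it assumes a PyTorch build with float8 dtypes. It mirrors the same capability gate and the same dequantize-then-matmul arithmetic the patch uses.

    import torch

    def can_use_scaled_mm(device: torch.device) -> bool:
        # Same gate as in the patch: the fused fp8 matmul needs
        # compute capability >= 8.9 (Ada/Hopper-class GPUs).
        return device.type != "cpu" and torch.cuda.get_device_capability(device) >= (8, 9)

    def fp8_matmul_fallback(x_fp8, weight_fp8_t, input_scale_reciprocal,
                            weight_scale_reciprocal, bias=None,
                            out_dtype=torch.bfloat16):
        # Upcast both fp8 operands to float32 and do a plain matmul.
        out = torch.matmul(x_fp8.to(torch.float32), weight_fp8_t.to(torch.float32))
        # Undo both quantization scales, then add the bias.
        out = out * (input_scale_reciprocal * weight_scale_reciprocal)
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)

On a pre-Ada GPU this trades the fused fp8 kernel for an ordinary float32 matmul, so it is slower and uses more activation memory, but it keeps the fp8 weight storage savings that motivate the patch series.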