Commit

Ruff format
yorickvP committed Nov 1, 2024
1 parent 51e7ee4 commit 0aac308
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions predict.py
@@ -168,7 +168,7 @@ def base_setup(
         total_mem = torch.cuda.get_device_properties(0).total_memory
         self.offload = total_mem < 48 * 1024**3
         if self.offload:
-            print("GPU memory is:", total_mem / 1024 ** 3, ", offloading models")
+            print("GPU memory is:", total_mem / 1024**3, ", offloading models")
 
         device = "cuda"
         max_length = 256 if self.flow_model_name == "flux-schnell" else 512
@@ -197,25 +197,25 @@ def base_setup(
             extra_args = {
                 "compile_whole_model": True,
                 "compile_extras": True,
-                "compile_blocks": True
+                "compile_blocks": True,
             }
         else:
             extra_args = {
                 "compile_whole_model": False,
                 "compile_extras": False,
-                "compile_blocks": False
+                "compile_blocks": False,
             }
 
         if self.offload:
             extra_args |= {
                 "offload_text_encoder": True,
                 "offload_vae": True,
-                "offload_flow": True
+                "offload_flow": True,
             }
         self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
            f"fp8/configs/config-1-{flow_model_name}-h100.json",
            shared_models=shared_models,
-            **extra_args
+            **extra_args,
        )
 
         if compile_fp8:
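
For readers unfamiliar with Ruff's formatter, the standalone sketch below (not code from predict.py; the names are hypothetical) illustrates the two Black-compatible conventions applied in this commit: the power operator is written without surrounding spaces when both operands are simple, and a trailing comma is added to the last item of a dict literal or call that spans multiple lines (the "magic trailing comma"), which keeps the formatter from collapsing it back onto one line.

# Illustrative only: hypothetical names, formatted the way `ruff format` would emit them.
total_bytes = 48 * 1024**3  # simple operands, so ** gets no surrounding spaces

options = {
    "compile": True,
    "offload": False,  # trailing comma on the last item keeps the dict expanded
}
print(total_bytes / 1024**3, options)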
