From f479840ce6d2222bd004b6f275494297f1f0ae91 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 2 Oct 2024 10:52:02 +0200 Subject: [PATCH] Add a seed to the flux example. (#2529) --- candle-examples/examples/flux/main.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs index 24b1fa2bc..943db1121 100644 --- a/candle-examples/examples/flux/main.rs +++ b/candle-examples/examples/flux/main.rs @@ -45,9 +45,13 @@ struct Args { #[arg(long, value_enum, default_value = "schnell")] model: Model, - /// Use the faster kernels which are buggy at the moment. + /// Use the slower kernels. #[arg(long)] - no_dmmv: bool, + use_dmmv: bool, + + /// The seed to use when generating random samples. + #[arg(long)] + seed: Option, } #[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)] @@ -91,6 +95,9 @@ fn run(args: Args) -> Result<()> { api.repo(hf_hub::Repo::model(name.to_string())) }; let device = candle_examples::device(cpu)?; + if let Some(seed) = args.seed { + device.set_seed(seed)?; + } let dtype = device.bf16_default_to_f32(); let img = match decode_only { None => { @@ -250,6 +257,6 @@ fn run(args: Args) -> Result<()> { fn main() -> Result<()> { let args = Args::parse(); #[cfg(feature = "cuda")] - candle::quantized::cuda::set_force_dmmv(!args.no_dmmv); + candle::quantized::cuda::set_force_dmmv(args.use_dmmv); run(args) }