Skip to content

Commit

Permalink
Add a seed to the flux example. (#2529)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Oct 2, 2024
1 parent fd08d3d commit f479840
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions candle-examples/examples/flux/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,13 @@ struct Args {
#[arg(long, value_enum, default_value = "schnell")]
model: Model,

/// Use the faster kernels which are buggy at the moment.
/// Use the slower kernels.
#[arg(long)]
no_dmmv: bool,
use_dmmv: bool,

/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
Expand Down Expand Up @@ -91,6 +95,9 @@ fn run(args: Args) -> Result<()> {
api.repo(hf_hub::Repo::model(name.to_string()))
};
let device = candle_examples::device(cpu)?;
if let Some(seed) = args.seed {
device.set_seed(seed)?;
}
let dtype = device.bf16_default_to_f32();
let img = match decode_only {
None => {
Expand Down Expand Up @@ -250,6 +257,6 @@ fn run(args: Args) -> Result<()> {
fn main() -> Result<()> {
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(!args.no_dmmv);
candle::quantized::cuda::set_force_dmmv(args.use_dmmv);
run(args)
}

0 comments on commit f479840

Please sign in to comment.