Skip to content

Commit

Permalink
Tweak some metal tests. (#2528)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Oct 2, 2024
1 parent a2bcc22 commit fd08d3d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 62 deletions.
5 changes: 0 additions & 5 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2372,16 +2372,11 @@ pub fn call_const_fill(
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

encoder.set_compute_pipeline_state(&pipeline);

set_params!(encoder, (output, v, length));

let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);

Ok(())
}

Expand Down
80 changes: 23 additions & 57 deletions candle-metal-kernels/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2309,66 +2309,32 @@ fn conv_transpose1d_u32() {
assert_eq!(results, expected);
}

fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
let dev = device();
let kernels = Kernels::new();
let command_queue = dev.new_command_queue();
let command_buffer = command_queue.new_command_buffer();

let buffer = dev.new_buffer(
(len * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModePrivate,
);

call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();

command_buffer.commit();
command_buffer.wait_until_completed();

read_to_vec::<T>(&buffer, len)
}

#[test]
fn const_fill() {
let fills = [
"fill_u8",
"fill_u32",
"fill_i64",
"fill_f16",
"fill_bf16",
"fill_f32",
];

for name in fills {
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
let dev = device();
let kernels = Kernels::new();
let command_queue = dev.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let buffer = dev.new_buffer(
(len * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModePrivate,
);
call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec::<T>(&buffer, len)
}
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.);

match name {
"fill_u8" => {
let v = constant_fill::<u8>(name, len, value);
assert_eq!(v, vec![value as u8; len])
}
"fill_u32" => {
let v = constant_fill::<u32>(name, len, value);
assert_eq!(v, vec![value as u32; len])
}
"fill_i64" => {
let v = constant_fill::<i64>(name, len, value);
assert_eq!(v, vec![value as i64; len])
}
"fill_f16" => {
let v = constant_fill::<f16>(name, len, value);
assert_eq!(v, vec![f16::from_f32(value); len])
}
"fill_bf16" => {
let v = constant_fill::<bf16>(name, len, value);
assert_eq!(v, vec![bf16::from_f32(value); len])
}
"fill_f32" => {
let v = constant_fill::<f32>(name, len, value);
assert_eq!(v, vec![value; len])
}
_ => unimplemented!(),
};
let v = constant_fill::<T>(name, len, value);
assert_eq!(v, vec![f(value); len])
}
test::<u8, _>("fill_u8", |v| v as u8);
test::<u32, _>("fill_u32", |v| v as u32);
test::<i64, _>("fill_i64", |v| v as i64);
test::<f16, _>("fill_f16", f16::from_f32);
test::<bf16, _>("fill_bf16", bf16::from_f32);
test::<f32, _>("fill_f32", |v| v);
}

0 comments on commit fd08d3d

Please sign in to comment.