ensure safety of indirect dispatch
by injecting a compute shader that validates the content of the indirect buffer
teoxoy committed May 21, 2024
1 parent 4902e47 commit ac3f089
Showing 10 changed files with 559 additions and 3 deletions.
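
The commit's approach, per the description above, is to run a small internal compute pass that reads the three workgroup counts out of the caller's indirect buffer and rejects out-of-range values before the real dispatch happens. As a rough sketch of that idea only — this is not the shader this commit injects, and the binding layout, the max_per_dim uniform, and all names are illustrative assumptions — such a validation shader could look like the following (written here as a WGSL string constant, in the same style as the tests below):

// Illustrative sketch only; not the validation shader wgpu-core actually injects.
const VALIDATION_SHADER_SRC: &str = "
    struct DispatchArgs {
        x: u32,
        y: u32,
        z: u32,
    }

    @group(0) @binding(0) var<storage, read> src: DispatchArgs;       // caller's indirect buffer contents
    @group(0) @binding(1) var<storage, read_write> dst: DispatchArgs; // validated copy driving the real dispatch
    @group(0) @binding(2) var<uniform> max_per_dim: u32;              // max_compute_workgroups_per_dimension

    @compute @workgroup_size(1)
    fn main() {
        // Zero the whole dispatch if any dimension exceeds the device limit,
        // turning an out-of-bounds indirect dispatch into a no-op.
        if (src.x > max_per_dim || src.y > max_per_dim || src.z > max_per_dim) {
            dst = DispatchArgs(0u, 0u, 0u);
        } else {
            dst = src;
        }
    }
";

With something like this in place, the DISCARD_DISPATCH test below can simply read back its output buffer and assert that it is still zero, since a dispatch whose dimensions exceed max_compute_workgroups_per_dimension never runs the user's shader.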
159 changes: 159 additions & 0 deletions tests/tests/dispatch_workgroups_indirect.rs
@@ -0,0 +1,159 @@
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};

/// Make sure that we discard (don't run) the dispatch if its size exceeds the device limit.
#[gpu_test]
static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
    .parameters(
        TestParameters::default()
            .downlevel_flags(
                wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
            )
            .limits(wgpu::Limits {
                max_compute_workgroups_per_dimension: 10,
                ..wgpu::Limits::downlevel_defaults()
            }),
    )
    .run_async(|ctx| async move {
        let objects = setup_test(&ctx).await;

        let mut encoder = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());

        {
            let mut compute_pass =
                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            compute_pass.set_pipeline(&objects.pipeline);
            compute_pass.set_bind_group(0, &objects.bind_group, &[]);
            compute_pass.dispatch_workgroups_indirect(&objects.indirect_buffer, 0);
        }

        encoder.copy_buffer_to_buffer(&objects.out_buffer, 0, &objects.readback_buffer, 0, 4);

        ctx.queue.submit(Some(encoder.finish()));

        objects
            .readback_buffer
            .slice(..)
            .map_async(wgpu::MapMode::Read, |_| {});

        ctx.async_poll(wgpu::Maintain::wait())
            .await
            .panic_on_timeout();

        let view = objects.readback_buffer.slice(..).get_mapped_range();

        assert!(view.iter().all(|v| *v == 0));
    });

/// Make sure that unsetting the bind group set by the validation code works properly.
#[gpu_test]
static UNSET_INTERNAL_BIND_GROUP: GpuTestConfiguration = GpuTestConfiguration::new()
    .parameters(
        TestParameters::default()
            .downlevel_flags(
                wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
            )
            .limits(wgpu::Limits {
                max_compute_workgroups_per_dimension: 10,
                ..wgpu::Limits::downlevel_defaults()
            }),
    )
    .run_async(|ctx| async move {
        let objects = setup_test(&ctx).await;

        ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);
        let mut encoder = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
        {
            let mut compute_pass =
                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            compute_pass.set_pipeline(&objects.pipeline);
            compute_pass.dispatch_workgroups_indirect(&objects.indirect_buffer, 0);
        }
        let _ = encoder.finish();
        let error = pollster::block_on(ctx.device.pop_error_scope());
        assert!(error.map_or(false, |error| format!("{error}")
            .contains("Expected bind group is missing")));
    });

struct CommonTestObjects {
    pipeline: wgpu::ComputePipeline,
    bind_group: wgpu::BindGroup,
    indirect_buffer: wgpu::Buffer,
    out_buffer: wgpu::Buffer,
    readback_buffer: wgpu::Buffer,
}

async fn setup_test(ctx: &TestingContext) -> CommonTestObjects {
    const SHADER_SRC: &str = "
        @group(0) @binding(0)
        var<storage, read_write> out: u32;
        @compute @workgroup_size(1)
        fn main() {
            out = 1u;
        }
    ";

    let module = ctx
        .device
        .create_shader_module(wgpu::ShaderModuleDescriptor {
            label: None,
            source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
        });

    let pipeline = ctx
        .device
        .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: None,
            layout: None,
            module: &module,
            entry_point: "main",
            compilation_options: Default::default(),
            cache: None,
        });

    let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size: 4,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size: 4,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size: 12,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::INDIRECT,
        mapped_at_creation: false,
    });

    let max = ctx.device.limits().max_compute_workgroups_per_dimension;
    ctx.queue
        .write_buffer(&indirect_buffer, 0, bytemuck::bytes_of(&[max + 1, 1, 1]));

    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[wgpu::BindGroupEntry {
            binding: 0,
            resource: out_buffer.as_entire_binding(),
        }],
    });

    CommonTestObjects {
        pipeline,
        bind_group,
        indirect_buffer,
        out_buffer,
        readback_buffer,
    }
}
1 change: 1 addition & 0 deletions tests/tests/root.rs
@@ -14,6 +14,7 @@ mod clear_texture;
mod compute_pass_resource_ownership;
mod create_surface_error;
mod device;
+mod dispatch_workgroups_indirect;
mod encoder;
mod external_texture;
mod float32_filterable;
14 changes: 13 additions & 1 deletion wgpu-core/src/command/bind.rs
@@ -131,7 +131,7 @@ mod compat {
                        diff.push(format!("Expected {expected_bgl_type} bind group layout, got {assigned_bgl_type}"))
                    }
                } else {
-                   diff.push("Assigned bind group layout not found (internal error)".to_owned());
+                   diff.push("Expected bind group is missing".to_owned());
                }
            } else {
                diff.push("Expected bind group layout not found (internal error)".to_owned());
@@ -191,6 +191,10 @@ mod compat {
            self.make_range(index)
        }

+       pub fn unassign(&mut self, index: usize) {
+           self.entries[index].assigned = None;
+       }
+
        pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
            self.entries
                .iter()
@@ -358,6 +362,14 @@ impl<A: HalApi> Binder<A> {
        &self.payloads[bind_range]
    }

+   pub(super) fn unassign_group(&mut self, index: usize) {
+       log::trace!("\tBinding [{}] = null", index);
+
+       self.payloads[index].reset();
+
+       self.manager.unassign(index);
+   }
+
    pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup<A>>> + '_ {
        let payloads = &self.payloads;
        self.manager
(Diffs for the remaining 7 changed files are not shown here.)
