Skip to content

Commit

Permalink
ensure safety of indirect dispatch
Browse files Browse the repository at this point in the history
by injecting a compute shader that validates the content of the indirect buffer
  • Loading branch information
teoxoy committed May 21, 2024
1 parent 4902e47 commit 36281af
Show file tree
Hide file tree
Showing 10 changed files with 555 additions and 3 deletions.
155 changes: 155 additions & 0 deletions tests/tests/dispatch_workgroups_indirect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};

/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
#[gpu_test]
static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits::downlevel_defaults())
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
)
.run_async(|ctx| async move {
let num_workgroups = [1, 2, 3];
let res = run_test(&ctx, &num_workgroups, false).await;
assert_eq!(res, num_workgroups);
});

/// Make sure that we discard (don't run) the dispatch if its size exceeds the device limit.
#[gpu_test]
static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_compute_workgroups_per_dimension: 10,
..wgpu::Limits::downlevel_defaults()
}),
)
.run_async(|ctx| async move {
let max = ctx.device.limits().max_compute_workgroups_per_dimension;
let res = run_test(&ctx, &[max + 1, 1, 1], false).await;
assert_eq!(res, [0; 3]);
});

/// Make sure that unsetting the bind group set by the validation code works properly.
#[gpu_test]
static UNSET_INTERNAL_BIND_GROUP: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits::downlevel_defaults()),
)
.run_async(|ctx| async move {
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);

let _ = run_test(&ctx, &[0, 0, 0], true).await;

let error = pollster::block_on(ctx.device.pop_error_scope());
assert!(error.map_or(false, |error| format!("{error}")
.contains("Expected bind group is missing")));
});

async fn run_test(
ctx: &TestingContext,
num_workgroups: &[u32; 3],
forget_to_set_bind_group: bool,
) -> [u32; 3] {
const SHADER_SRC: &str = "
@group(0) @binding(0)
var<storage, read_write> out: vec3<u32>;
@compute @workgroup_size(1)
fn main(@builtin(num_workgroups) num_workgroups: vec3<u32>, @builtin(workgroup_id) workgroup_id: vec3<u32>) {
if (all(workgroup_id == vec3<u32>())) {
out = num_workgroups;
}
}
";

let module = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
});

let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &module,
entry_point: "main",
compilation_options: Default::default(),
cache: None,
});

let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});

let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::INDIRECT,
mapped_at_creation: false,
});

ctx.queue
.write_buffer(&indirect_buffer, 0, bytemuck::bytes_of(num_workgroups));

let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: out_buffer.as_entire_binding(),
}],
});

let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());

{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
compute_pass.set_pipeline(&pipeline);
if !forget_to_set_bind_group {
compute_pass.set_bind_group(0, &bind_group, &[]);
}
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, 0);
}

encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 12);

ctx.queue.submit(Some(encoder.finish()));

readback_buffer
.slice(..)
.map_async(wgpu::MapMode::Read, |_| {});

ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();

let view = readback_buffer.slice(..).get_mapped_range();

bytemuck::from_bytes::<[u32; 3]>(&view).to_owned()
}
1 change: 1 addition & 0 deletions tests/tests/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod clear_texture;
mod compute_pass_resource_ownership;
mod create_surface_error;
mod device;
mod dispatch_workgroups_indirect;
mod encoder;
mod external_texture;
mod float32_filterable;
Expand Down
14 changes: 13 additions & 1 deletion wgpu-core/src/command/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ mod compat {
diff.push(format!("Expected {expected_bgl_type} bind group layout, got {assigned_bgl_type}"))
}
} else {
diff.push("Assigned bind group layout not found (internal error)".to_owned());
diff.push("Expected bind group is missing".to_owned());
}
} else {
diff.push("Expected bind group layout not found (internal error)".to_owned());
Expand Down Expand Up @@ -191,6 +191,10 @@ mod compat {
self.make_range(index)
}

pub fn unassign(&mut self, index: usize) {
self.entries[index].assigned = None;
}

pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
self.entries
.iter()
Expand Down Expand Up @@ -358,6 +362,14 @@ impl<A: HalApi> Binder<A> {
&self.payloads[bind_range]
}

pub(super) fn unassign_group(&mut self, index: usize) {
log::trace!("\tBinding [{}] = null", index);

self.payloads[index].reset();

self.manager.unassign(index);
}

pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup<A>>> + '_ {
let payloads = &self.payloads;
self.manager
Expand Down
Loading

0 comments on commit 36281af

Please sign in to comment.