Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure safety of indirect dispatch #5714

Open
wants to merge 9 commits into
base: trunk
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
- Fix error message that is thrown in create_render_pass to no longer say `compute_pass`. By @matthew-wong1 [#6041](https://github.com/gfx-rs/wgpu/pull/6041)
- Add `VideoFrame` to `ExternalImageSource` enum. By @jprochazk in [#6170](https://github.com/gfx-rs/wgpu/pull/6170)
- Document `wgpu_hal` bounds-checking promises, and adapt `wgpu_core`'s lazy initialization logic to the slightly weaker-than-expected guarantees. By @jimblandy in [#6201](https://github.com/gfx-rs/wgpu/pull/6201)
- Ensure safety of indirect dispatch by injecting a compute shader that validates the content of the indirect buffer. By @teoxoy in [#5714](https://github.com/gfx-rs/wgpu/pull/5714)

#### GLES / OpenGL

Expand Down
244 changes: 244 additions & 0 deletions tests/tests/dispatch_workgroups_indirect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};

/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
#[gpu_test]
static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
)
.run_async(|ctx| async move {
let num_workgroups = [1, 2, 3];
let res = run_test(&ctx, &num_workgroups, false).await;
assert_eq!(res, num_workgroups);
});

/// Make sure that we discard (don't run) the dispatch if its size exceeds the device limit.
#[gpu_test]
static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_compute_workgroups_per_dimension: 10,
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
)
.run_async(|ctx| async move {
let max = ctx.device.limits().max_compute_workgroups_per_dimension;

let res = run_test(&ctx, &[max, max, max], false).await;
assert_eq!(res, [max; 3]);

let res = run_test(&ctx, &[max + 1, 1, 1], false).await;
assert_eq!(res, [0; 3]);
jimblandy marked this conversation as resolved.
Show resolved Hide resolved

let res = run_test(&ctx, &[1, max + 1, 1], false).await;
assert_eq!(res, [0; 3]);

let res = run_test(&ctx, &[1, 1, max + 1], false).await;
assert_eq!(res, [0; 3]);
});

/// Make sure that resetting the bind groups set by the validation code works properly.
#[gpu_test]
static RESET_BIND_GROUPS: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
}),
)
.run_async(|ctx| async move {
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);

let _ = run_test(&ctx, &[0, 0, 0], true).await;

let error = pollster::block_on(ctx.device.pop_error_scope());
assert!(error.map_or(false, |error| {
format!("{error}").contains("The current set ComputePipeline with '' label expects a BindGroup to be set at index 0")
}));
});

async fn run_test(
ctx: &TestingContext,
num_workgroups: &[u32; 3],
forget_to_set_bind_group: bool,
) -> [u32; 3] {
const SHADER_SRC: &str = "
struct TestOffsetPc {
inner: u32,
}

// `test_offset.inner` should always be 0; we test that resetting the push constant set by the validation code works properly.
var<push_constant> test_offset: TestOffsetPc;

@group(0) @binding(0)
var<storage, read_write> out: array<u32, 3>;

@compute @workgroup_size(1)
fn main(@builtin(num_workgroups) num_workgroups: vec3<u32>, @builtin(workgroup_id) workgroup_id: vec3<u32>) {
if (all(workgroup_id == vec3<u32>())) {
out[0] = num_workgroups.x + test_offset.inner;
out[1] = num_workgroups.y + test_offset.inner;
out[2] = num_workgroups.z + test_offset.inner;
}
}
";

let module = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
});

let bgl = ctx
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgt::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}],
});

let layout = ctx
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bgl],
push_constant_ranges: &[wgt::PushConstantRange {
stages: wgt::ShaderStages::COMPUTE,
range: 0..4,
}],
});

let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&layout),
module: &module,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});

let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});

let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: out_buffer.as_entire_binding(),
}],
});

let mut res = None;

for (indirect_offset, indirect_buffer_size) in [
// internal src buffer binding size will be buffer.size
(0, 12),
(4, 4 + 12),
(4, 8 + 12),
(256 * 2 - 4 - 12, 256 * 2 - 4),
// internal src buffer binding size will be 256 * 2 + x
(0, 256 * 2 * 2 + 4),
(256, 256 * 2 * 2 + 8),
(256 + 4, 256 * 2 * 2 + 12),
(256 * 2 + 16, 256 * 2 * 2 + 16),
(256 * 2 * 2, 256 * 2 * 2 + 32),
(256 + 12, 256 * 2 * 2 + 64),
] {
let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: indirect_buffer_size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::INDIRECT,
mapped_at_creation: false,
});

ctx.queue.write_buffer(
&indirect_buffer,
indirect_offset,
bytemuck::bytes_of(num_workgroups),
);

let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
{
let mut compute_pass =
encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
compute_pass.set_pipeline(&pipeline);
compute_pass.set_push_constants(0, &[0, 0, 0, 0]);
if !forget_to_set_bind_group {
compute_pass.set_bind_group(0, Some(&bind_group), &[]);
}
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
}

encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 12);

ctx.queue.submit(Some(encoder.finish()));

readback_buffer
.slice(..)
.map_async(wgpu::MapMode::Read, |_| {});

ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();

let view = readback_buffer.slice(..).get_mapped_range();

let current_res = *bytemuck::from_bytes(&view);
drop(view);
readback_buffer.unmap();

if let Some(past_res) = res {
assert_eq!(past_res, current_res);
} else {
res = Some(current_res);
}
}

res.unwrap()
}
1 change: 1 addition & 0 deletions tests/tests/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod clear_texture;
mod compute_pass_ownership;
mod create_surface_error;
mod device;
mod dispatch_workgroups_indirect;
mod encoder;
mod external_texture;
mod float32_filterable;
Expand Down
4 changes: 4 additions & 0 deletions wgpu-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ renderdoc = ["hal/renderdoc"]
## to the validation carried out at public APIs in all builds.
strict_asserts = ["wgt/strict_asserts"]

## Validates indirect draw/dispatch calls. This will also enable naga's
## WGSL frontend since we use a WGSL compute shader to do the validation.
indirect-validation = ["naga/wgsl-in"]

## Enables serialization via `serde` on common wgpu types.
serde = ["dep:serde", "wgt/serde", "arrayvec/serde"]

Expand Down
Loading