Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga wgsl-in] Allow abstract literals to be used as return values #7035

Open
wants to merge 1 commit into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ By @brodycj in [#6924](https://github.com/gfx-rs/wgpu/pull/6924).
#### Naga

- Fix some instances of functions which have a return type but don't return a value being incorrectly validated. By @jamienicol in [#7013](https://github.com/gfx-rs/wgpu/pull/7013).
- Allow abstract expressions to be used in WGSL function return statements. By @jamienicol in [#7035](https://github.com/gfx-rs/wgpu/pull/7035).

#### General

Expand Down
26 changes: 25 additions & 1 deletion naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1677,7 +1677,31 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
emitter.start(&ctx.function.expressions);

let value = value
.map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter)))
.map(|expr| {
let expr = self.expression_for_abstract(
expr,
&mut ctx.as_expression(block, &mut emitter),
)?;

if let Some(result_ty) = ctx.function.result.as_ref().map(|r| r.ty) {
let mut ctx = ctx.as_expression(block, &mut emitter);
let mut expr_ty = resolve_inner!(ctx, expr);
while let crate::TypeInner::Array { base, .. } = *expr_ty {
expr_ty = &ctx.module.types[base].inner;
}
if expr_ty.scalar().is_some_and(|s| s.is_abstract()) {
Comment on lines +1687 to +1692
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jimblandy I'm avoiding calling try_automatic_conversions() here for non-abstract types, as otherwise I saw in one of the tests that this code:

@group(0) @binding(0)
var<storage> atom: atomic<u32>;

fn return_atomic() -> atomic<u32> {
  return atom;
}

used to fail with this error:

  ┌─ in.wgsl:4:1
  │
4 │ ╭ fn return_atomic() -> atomic<u32> {
5 │ │   return atom;
  │ ╰──────────────^ naga::Function [0]
  │
  = The function's given return type cannot be returned from functions

but would now fail with:

Could not parse WGSL:
error: automatic conversions cannot convert `u32` to `atomic<u32>`
  ┌─ in.wgsl:4:10
  │
4 │   return atom;
  │          ^^^^ this expression has type u32

which is less clear.

Would it make sense to have this check as part of try_automatic_conversions()? Do we ever convert from things that are not abstract?

ctx.try_automatic_conversions(
expr,
&crate::proc::TypeResolution::Handle(result_ty),
Span::default(),
)
} else {
Ok(expr)
}
} else {
Ok(expr)
}
})
.transpose()?;
block.extend(emitter.finish(&ctx.function.expressions));

Expand Down
26 changes: 26 additions & 0 deletions naga/tests/in/abstract-types-return.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
@compute @workgroup_size(1)
fn main() {}

fn return_i32_ai() -> i32 {
return 1;
}

fn return_u32_ai() -> u32 {
return 1;
}

fn return_f32_ai() -> f32 {
return 1;
}

fn return_f32_af() -> f32 {
return 1.0;
}

fn return_vec2f32_ai() -> vec2<f32> {
return vec2(1);
}

fn return_arrf32_ai() -> array<f32, 4> {
return array(1, 1, 1, 1);
}
36 changes: 36 additions & 0 deletions naga/tests/out/glsl/abstract-types-return.main.Compute.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#version 310 es

precision highp float;
precision highp int;

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;


int return_i32_ai() {
return 1;
}

uint return_u32_ai() {
return 1u;
}

float return_f32_ai() {
return 1.0;
}

float return_f32_af() {
return 1.0;
}

vec2 return_vec2f32_ai() {
return vec2(1.0);
}

float[4] return_arrf32_ai() {
return float[4](1.0, 1.0, 1.0, 1.0);
}

void main() {
return;
}

42 changes: 42 additions & 0 deletions naga/tests/out/hlsl/abstract-types-return.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
int return_i32_ai()
{
return 1;
}

uint return_u32_ai()
{
return 1u;
}

float return_f32_ai()
{
return 1.0;
}

float return_f32_af()
{
return 1.0;
}

float2 return_vec2f32_ai()
{
return (1.0).xx;
}

typedef float ret_Constructarray4_float_[4];
ret_Constructarray4_float_ Constructarray4_float_(float arg0, float arg1, float arg2, float arg3) {
float ret[4] = { arg0, arg1, arg2, arg3 };
return ret;
}

typedef float ret_return_arrf32_ai[4];
ret_return_arrf32_ai return_arrf32_ai()
{
return Constructarray4_float_(1.0, 1.0, 1.0, 1.0);
}

[numthreads(1, 1, 1)]
void main()
{
return;
}
12 changes: 12 additions & 0 deletions naga/tests/out/hlsl/abstract-types-return.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_5_1",
),
],
)
44 changes: 44 additions & 0 deletions naga/tests/out/msl/abstract-types-return.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;

struct type_4 {
float inner[4];
};

int return_i32_ai(
) {
return 1;
}

uint return_u32_ai(
) {
return 1u;
}

float return_f32_ai(
) {
return 1.0;
}

float return_f32_af(
) {
return 1.0;
}

metal::float2 return_vec2f32_ai(
) {
return metal::float2(1.0);
}

type_4 return_arrf32_ai(
) {
return type_4 {1.0, 1.0, 1.0, 1.0};
}

kernel void main_(
) {
return;
}
70 changes: 70 additions & 0 deletions naga/tests/out/spv/abstract-types-return.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 41
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %38 "main"
OpExecutionMode %38 LocalSize 1 1 1
OpDecorate %7 ArrayStride 4
%2 = OpTypeVoid
%3 = OpTypeInt 32 1
%4 = OpTypeInt 32 0
%5 = OpTypeFloat 32
%6 = OpTypeVector %5 2
%8 = OpConstant %4 4
%7 = OpTypeArray %5 %8
%11 = OpTypeFunction %3
%12 = OpConstant %3 1
%16 = OpTypeFunction %4
%17 = OpConstant %4 1
%21 = OpTypeFunction %5
%22 = OpConstant %5 1.0
%29 = OpTypeFunction %6
%30 = OpConstantComposite %6 %22 %22
%34 = OpTypeFunction %7
%35 = OpConstantComposite %7 %22 %22 %22 %22
%39 = OpTypeFunction %2
%10 = OpFunction %3 None %11
%9 = OpLabel
OpBranch %13
%13 = OpLabel
OpReturnValue %12
OpFunctionEnd
%15 = OpFunction %4 None %16
%14 = OpLabel
OpBranch %18
%18 = OpLabel
OpReturnValue %17
OpFunctionEnd
%20 = OpFunction %5 None %21
%19 = OpLabel
OpBranch %23
%23 = OpLabel
OpReturnValue %22
OpFunctionEnd
%25 = OpFunction %5 None %21
%24 = OpLabel
OpBranch %26
%26 = OpLabel
OpReturnValue %22
OpFunctionEnd
%28 = OpFunction %6 None %29
%27 = OpLabel
OpBranch %31
%31 = OpLabel
OpReturnValue %30
OpFunctionEnd
%33 = OpFunction %7 None %34
%32 = OpLabel
OpBranch %36
%36 = OpLabel
OpReturnValue %35
OpFunctionEnd
%38 = OpFunction %2 None %39
%37 = OpLabel
OpBranch %40
%40 = OpLabel
OpReturn
OpFunctionEnd
28 changes: 28 additions & 0 deletions naga/tests/out/wgsl/abstract-types-return.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
fn return_i32_ai() -> i32 {
return 1i;
}

fn return_u32_ai() -> u32 {
return 1u;
}

fn return_f32_ai() -> f32 {
return 1f;
}

fn return_f32_af() -> f32 {
return 1f;
}

fn return_vec2f32_ai() -> vec2<f32> {
return vec2(1f);
}

fn return_arrf32_ai() -> array<f32, 4> {
return array<f32, 4>(1f, 1f, 1f, 1f);
}

@compute @workgroup_size(1, 1, 1)
fn main() {
return;
}
4 changes: 4 additions & 0 deletions naga/tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,10 @@ fn convert_wgsl() {
"abstract-types-operators",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL,
),
(
"abstract-types-return",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"int64",
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,
Expand Down