Skip to content

Commit 8d77a40

Browse files
iwoplazawcandillon
andauthored
chore: Update to TypeGPU 0.2 (#178)
--------- Co-authored-by: William Candillon <[email protected]>
1 parent ed6b68e commit 8d77a40

File tree

5 files changed

+108
-116
lines changed

5 files changed

+108
-116
lines changed

apps/example/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
"react-native-wgpu": "*",
2929
"teapot": "^1.0.0",
3030
"three": "0.168.0",
31-
"typegpu": "^0.1.2",
31+
"typegpu": "^0.2.0",
3232
"wgpu-matrix": "^3.0.2"
3333
},
3434
"devDependencies": {

apps/example/src/ComputeBoids/ComputeBoids.tsx

+56-73
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,33 @@ type BoidsOptions = {
1818
cohesionStrength: number;
1919
};
2020

21+
const Parameters = struct({
22+
separationDistance: f32,
23+
separationStrength: f32,
24+
alignmentDistance: f32,
25+
alignmentStrength: f32,
26+
cohesionDistance: f32,
27+
cohesionStrength: f32,
28+
});
29+
30+
const TriangleData = struct({
31+
position: vec2f,
32+
velocity: vec2f,
33+
});
34+
35+
const TriangleDataArray = (n: number) => arrayOf(TriangleData, n);
36+
37+
const renderBindGroupLayout = tgpu.bindGroupLayout({
38+
trianglePos: { storage: TriangleDataArray },
39+
colorPalette: { uniform: vec3f },
40+
});
41+
42+
const computeBindGroupLayout = tgpu.bindGroupLayout({
43+
currentTrianglePos: { storage: TriangleDataArray },
44+
nextTrianglePos: { storage: TriangleDataArray, access: 'mutable' },
45+
params: { uniform: Parameters },
46+
});
47+
2148
const colorPresets = {
2249
plumTree: vec3f(1.0, 2.0, 1.0),
2350
jeans: vec3f(2.0, 1.5, 1.0),
@@ -69,28 +96,20 @@ export function ComputeBoids() {
6996
);
7097

7198
const ref = useWebGPU(({ context, device, presentationFormat }) => {
99+
const root = tgpu.initFromDevice({ device });
100+
72101
context.configure({
73102
device,
74103
format: presentationFormat,
75104
alphaMode: "premultiplied",
76105
});
77106

78-
const params = struct({
79-
separationDistance: f32,
80-
separationStrength: f32,
81-
alignmentDistance: f32,
82-
alignmentStrength: f32,
83-
cohesionDistance: f32,
84-
cohesionStrength: f32,
85-
});
86-
87-
const paramsBuffer = tgpu
88-
.createBuffer(params, presets.default)
89-
.$device(device)
90-
.$usage(tgpu.Storage);
107+
const paramsBuffer = root
108+
.createBuffer(Parameters, presets.default)
109+
.$usage("uniform");
91110

92111
const triangleSize = 0.03;
93-
const triangleVertexBuffer = tgpu
112+
const triangleVertexBuffer = root
94113
.createBuffer(arrayOf(f32, 6), [
95114
0.0,
96115
triangleSize,
@@ -99,42 +118,33 @@ export function ComputeBoids() {
99118
triangleSize / 2,
100119
-triangleSize / 2,
101120
])
102-
.$device(device)
103-
.$usage(tgpu.Vertex);
121+
.$usage("vertex");
104122

105123
const triangleAmount = 1000;
106-
const triangleInfoStruct = struct({
107-
position: vec2f,
108-
velocity: vec2f,
109-
});
110124
const trianglePosBuffers = Array.from({ length: 2 }, () =>
111-
tgpu
112-
.createBuffer(arrayOf(triangleInfoStruct, triangleAmount))
113-
.$device(device)
114-
.$usage(tgpu.Storage, tgpu.Uniform),
125+
root.createBuffer(TriangleDataArray(triangleAmount)).$usage("storage")
115126
);
116127

117128
randomizePositions.current = () => {
118129
const positions = Array.from({ length: triangleAmount }, () => ({
119130
position: vec2f(Math.random() * 2 - 1, Math.random() * 2 - 1),
120131
velocity: vec2f(Math.random() * 0.1 - 0.05, Math.random() * 0.1 - 0.05),
121132
}));
122-
tgpu.write(trianglePosBuffers[0], positions);
123-
tgpu.write(trianglePosBuffers[1], positions);
133+
trianglePosBuffers[0].write(positions);
134+
trianglePosBuffers[1].write(positions);
124135
};
125136
randomizePositions.current();
126137

127-
const colorPaletteBuffer = tgpu
138+
const colorPaletteBuffer = root
128139
.createBuffer(vec3f, colorPresets.plumTree)
129-
.$device(device)
130-
.$usage(tgpu.Uniform);
140+
.$usage("uniform");
131141

132142
updateColorPreset.current = (newColorPreset: ColorPresets) => {
133-
tgpu.write(colorPaletteBuffer, colorPresets[newColorPreset]);
143+
colorPaletteBuffer.write(colorPresets[newColorPreset]);
134144
};
135145

136146
updateParams.current = (newOptions: BoidsOptions) => {
137-
tgpu.write(paramsBuffer, newOptions);
147+
paramsBuffer.write(newOptions);
138148
};
139149

140150
const renderModule = device.createShaderModule({
@@ -146,7 +156,9 @@ export function ComputeBoids() {
146156
});
147157

148158
const pipeline = device.createRenderPipeline({
149-
layout: "auto",
159+
layout: device.createPipelineLayout({
160+
bindGroupLayouts: [root.unwrap(renderBindGroupLayout)],
161+
}),
150162
vertex: {
151163
module: renderModule,
152164
buffers: [
@@ -176,55 +188,26 @@ export function ComputeBoids() {
176188
});
177189

178190
const computePipeline = device.createComputePipeline({
179-
layout: "auto",
191+
layout: device.createPipelineLayout({
192+
bindGroupLayouts: [root.unwrap(computeBindGroupLayout)],
193+
}),
180194
compute: {
181195
module: computeModule,
182196
},
183197
});
184198

185199
const renderBindGroups = [0, 1].map((idx) =>
186-
device.createBindGroup({
187-
layout: pipeline.getBindGroupLayout(0),
188-
entries: [
189-
{
190-
binding: 0,
191-
resource: {
192-
buffer: trianglePosBuffers[idx].buffer,
193-
},
194-
},
195-
{
196-
binding: 1,
197-
resource: {
198-
buffer: colorPaletteBuffer.buffer,
199-
},
200-
},
201-
],
200+
renderBindGroupLayout.populate({
201+
trianglePos: trianglePosBuffers[idx],
202+
colorPalette: colorPaletteBuffer,
202203
}),
203204
);
204205

205206
const computeBindGroups = [0, 1].map((idx) =>
206-
device.createBindGroup({
207-
layout: computePipeline.getBindGroupLayout(0),
208-
entries: [
209-
{
210-
binding: 0,
211-
resource: {
212-
buffer: trianglePosBuffers[idx].buffer,
213-
},
214-
},
215-
{
216-
binding: 1,
217-
resource: {
218-
buffer: trianglePosBuffers[1 - idx].buffer,
219-
},
220-
},
221-
{
222-
binding: 2,
223-
resource: {
224-
buffer: paramsBuffer.buffer,
225-
},
226-
},
227-
],
207+
computeBindGroupLayout.populate({
208+
currentTrianglePos: trianglePosBuffers[idx],
209+
nextTrianglePos: trianglePosBuffers[1 - idx],
210+
params: paramsBuffer,
228211
}),
229212
);
230213

@@ -251,7 +234,7 @@ export function ComputeBoids() {
251234
computePass.setPipeline(computePipeline);
252235
computePass.setBindGroup(
253236
0,
254-
even ? computeBindGroups[0] : computeBindGroups[1],
237+
root.unwrap(even ? computeBindGroups[0] : computeBindGroups[1])
255238
);
256239
computePass.dispatchWorkgroups(triangleAmount);
257240
computePass.end();
@@ -261,7 +244,7 @@ export function ComputeBoids() {
261244
passEncoder.setVertexBuffer(0, triangleVertexBuffer.buffer);
262245
passEncoder.setBindGroup(
263246
0,
264-
even ? renderBindGroups[1] : renderBindGroups[0],
247+
root.unwrap(even ? renderBindGroups[1] : renderBindGroups[0])
265248
);
266249
passEncoder.draw(3, triangleAmount);
267250
passEncoder.end();

apps/example/src/ComputeBoids/Shaders.ts

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
const triangleAmount = 1000;
21
const triangleSize = 0.03;
32

43
export const renderCode = /* wgsl */ `
@@ -24,7 +23,7 @@ export const renderCode = /* wgsl */ `
2423
@location(1) color : vec4f,
2524
};
2625
27-
@binding(0) @group(0) var<uniform> trianglePos : array<TriangleData, ${triangleAmount}>;
26+
@binding(0) @group(0) var<storage> trianglePos : array<TriangleData>;
2827
@binding(1) @group(0) var<uniform> colorPalette : vec3f;
2928
3029
@vertex
@@ -67,9 +66,9 @@ export const computeCode = /* wgsl */ `
6766
cohesion_strength : f32,
6867
};
6968
70-
@binding(0) @group(0) var<uniform> currentTrianglePos : array<TriangleData, ${triangleAmount}>;
69+
@binding(0) @group(0) var<storage> currentTrianglePos : array<TriangleData>;
7170
@binding(1) @group(0) var<storage, read_write> nextTrianglePos : array<TriangleData>;
72-
@binding(2) @group(0) var<storage> params : Parameters;
71+
@binding(2) @group(0) var<uniform> params : Parameters;
7372
7473
@compute @workgroup_size(1)
7574
fn mainCompute(@builtin(global_invocation_id) gid: vec3u) {
@@ -80,7 +79,7 @@ export const computeCode = /* wgsl */ `
8079
var alignmentCount = 0u;
8180
var cohesion = vec2(0.0, 0.0);
8281
var cohesionCount = 0u;
83-
for (var i = 0u; i < ${triangleAmount}; i = i + 1) {
82+
for (var i = 0u; i < arrayLength(&currentTrianglePos); i = i + 1) {
8483
if (i == index) {
8584
continue;
8685
}

apps/example/src/GradientTiles/GradientTiles.tsx

+42-32
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,49 @@
1-
import { useEffect, useState } from "react";
1+
import { useEffect, useMemo, useState } from "react";
22
import { Button, PixelRatio, StyleSheet, Text, View } from "react-native";
33
import { Canvas, useDevice, useGPUContext } from "react-native-wgpu";
44
import { struct, u32 } from "typegpu/data";
5-
import tgpu from "typegpu";
5+
import tgpu, { type TgpuBindGroup, type TgpuBuffer } from "typegpu";
66

7-
import { vertWGSL, fragWGSL } from "./gradientWgsl";
7+
import { vertWGSL, fragWGSL } from './gradientWgsl';
8+
9+
const Span = struct({
10+
x: u32,
11+
y: u32,
12+
});
13+
14+
const bindGroupLayout = tgpu.bindGroupLayout({
15+
span: { uniform: Span },
16+
});
817

918
interface RenderingState {
1019
pipeline: GPURenderPipeline;
11-
// eslint-disable-next-line @typescript-eslint/no-explicit-any
12-
spanBuffer: any;
13-
bindGroup: GPUBindGroup;
20+
spanBuffer: TgpuBuffer<typeof Span>;
21+
bindGroup: TgpuBindGroup<(typeof bindGroupLayout)['entries']>;
22+
}
23+
24+
function useRoot() {
25+
const { device } = useDevice();
26+
27+
return useMemo(
28+
() => (device ? tgpu.initFromDevice({ device }) : null),
29+
[device]
30+
);
1431
}
1532

1633
export function GradientTiles() {
1734
const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
1835
const [state, setState] = useState<null | RenderingState>(null);
1936
const [spanX, setSpanX] = useState(4);
2037
const [spanY, setSpanY] = useState(4);
21-
const { device } = useDevice();
38+
const root = useRoot();
39+
const { device = null } = root ?? {};
2240
const { ref, context } = useGPUContext();
41+
2342
useEffect(() => {
24-
if (!device || !context || state !== null) {
43+
if (!device || !root || !context || state !== null) {
2544
return;
2645
}
46+
2747
const canvas = context.canvas as HTMLCanvasElement;
2848
canvas.width = canvas.clientWidth * PixelRatio.get();
2949
canvas.height = canvas.clientHeight * PixelRatio.get();
@@ -32,18 +52,14 @@ export function GradientTiles() {
3252
format: presentationFormat,
3353
});
3454

35-
const Span = struct({
36-
x: u32,
37-
y: u32,
38-
});
39-
40-
const spanBuffer = tgpu
55+
const spanBuffer = root
4156
.createBuffer(Span, { x: 10, y: 10 })
42-
.$device(device)
43-
.$usage(tgpu.Uniform);
57+
.$usage("uniform");
4458

4559
const pipeline = device.createRenderPipeline({
46-
layout: "auto",
60+
layout: device.createPipelineLayout({
61+
bindGroupLayouts: [root.unwrap(bindGroupLayout)],
62+
}),
4763
vertex: {
4864
module: device.createShaderModule({
4965
code: vertWGSL,
@@ -64,24 +80,18 @@ export function GradientTiles() {
6480
},
6581
});
6682

67-
const bindGroup = device.createBindGroup({
68-
layout: pipeline.getBindGroupLayout(0),
69-
entries: [
70-
{
71-
binding: 0,
72-
resource: {
73-
buffer: spanBuffer.buffer,
74-
},
75-
},
76-
],
83+
const bindGroup = bindGroupLayout.populate({
84+
span: spanBuffer,
7785
});
86+
7887
setState({ bindGroup, pipeline, spanBuffer });
79-
}, [context, device, presentationFormat, state]);
88+
}, [context, device, root, presentationFormat, state]);
8089

8190
useEffect(() => {
82-
if (!context || !device || !state) {
91+
if (!context || !device || !root || !state) {
8392
return;
8493
}
94+
8595
const { bindGroup, pipeline, spanBuffer } = state;
8696
const textureView = context.getCurrentTexture().createView();
8797
const renderPassDescriptor: GPURenderPassDescriptor = {
@@ -95,18 +105,18 @@ export function GradientTiles() {
95105
],
96106
};
97107

98-
tgpu.write(spanBuffer, { x: spanX, y: spanY });
108+
spanBuffer.write({ x: spanX, y: spanY });
99109

100110
const commandEncoder = device.createCommandEncoder();
101111
const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
102112
passEncoder.setPipeline(pipeline);
103-
passEncoder.setBindGroup(0, bindGroup);
113+
passEncoder.setBindGroup(0, root.unwrap(bindGroup));
104114
passEncoder.draw(4);
105115
passEncoder.end();
106116

107117
device.queue.submit([commandEncoder.finish()]);
108118
context.present();
109-
}, [context, device, spanX, spanY, state]);
119+
}, [context, device, root, spanX, spanY, state]);
110120

111121
return (
112122
<View style={style.container}>

0 commit comments

Comments
 (0)