diff --git a/apps/example/ios/Podfile.lock b/apps/example/ios/Podfile.lock index 34836dfb3..2ff657480 100644 --- a/apps/example/ios/Podfile.lock +++ b/apps/example/ios/Podfile.lock @@ -935,9 +935,32 @@ PODS: - React-Mapbuffer (0.74.2): - glog - React-debug - - react-native-safe-area-context (4.11.0): + - react-native-safe-area-context (4.14.1): - React-Core - - react-native-wgpu (0.1.21): + - react-native-skia (1.7.2): + - DoubleConversion + - glog + - hermes-engine + - RCT-Folly (= 2024.01.01.00) + - RCTRequired + - RCTTypeSafety + - React + - React-callinvoker + - React-Codegen + - React-Core + - React-debug + - React-Fabric + - React-featureflags + - React-graphics + - React-ImageManager + - React-NativeModulesApple + - React-RCTFabric + - React-rendererdebug + - React-utils + - ReactCommon/turbomodule/bridging + - ReactCommon/turbomodule/core + - Yoga + - react-native-wgpu (0.1.22): - DoubleConversion - glog - hermes-engine @@ -1187,7 +1210,7 @@ PODS: - React-logger (= 0.74.2) - React-perflogger (= 0.74.2) - React-utils (= 0.74.2) - - ReactNativeHost (0.5.0): + - ReactNativeHost (0.5.2): - DoubleConversion - glog - hermes-engine @@ -1209,11 +1232,32 @@ PODS: - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core - Yoga - - ReactTestApp-DevSupport (3.10.8): + - ReactTestApp-DevSupport (3.10.22): - React-Core - React-jsi - ReactTestApp-Resources (1.0.0-dev) - - RNGestureHandler (2.19.0): + - RNGestureHandler (2.21.2): + - DoubleConversion + - glog + - hermes-engine + - RCT-Folly (= 2024.01.01.00) + - RCTRequired + - RCTTypeSafety + - React-Codegen + - React-Core + - React-debug + - React-Fabric + - React-featureflags + - React-graphics + - React-ImageManager + - React-NativeModulesApple + - React-RCTFabric + - React-rendererdebug + - React-utils + - ReactCommon/turbomodule/bridging + - ReactCommon/turbomodule/core + - Yoga + - RNReanimated (3.16.5): - DoubleConversion - glog - hermes-engine @@ -1233,8 +1277,10 @@ PODS: - React-utils - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core + - RNReanimated/reanimated (= 3.16.5) + - RNReanimated/worklets (= 3.16.5) - Yoga - - RNReanimated (3.15.2): + - RNReanimated/reanimated (3.16.5): - DoubleConversion - glog - hermes-engine @@ -1254,10 +1300,9 @@ PODS: - React-utils - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core - - RNReanimated/reanimated (= 3.15.2) - - RNReanimated/worklets (= 3.15.2) + - RNReanimated/reanimated/apple (= 3.16.5) - Yoga - - RNReanimated/reanimated (3.15.2): + - RNReanimated/reanimated/apple (3.16.5): - DoubleConversion - glog - hermes-engine @@ -1278,7 +1323,7 @@ PODS: - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core - Yoga - - RNReanimated/worklets (3.15.2): + - RNReanimated/worklets (3.16.5): - DoubleConversion - glog - hermes-engine @@ -1336,6 +1381,7 @@ DEPENDENCIES: - React-logger (from `../../../node_modules/react-native/ReactCommon/logger`) - React-Mapbuffer (from `../../../node_modules/react-native/ReactCommon`) - react-native-safe-area-context (from `../../../node_modules/react-native-safe-area-context`) + - "react-native-skia (from `../../../node_modules/@shopify/react-native-skia`)" - react-native-wgpu (from `../../../node_modules/react-native-wgpu`) - React-nativeconfig (from `../../../node_modules/react-native/ReactCommon`) - React-NativeModulesApple (from `../../../node_modules/react-native/ReactCommon/react/nativemodule/core/platform/ios`) @@ -1435,6 +1481,8 @@ EXTERNAL SOURCES: :path: "../../../node_modules/react-native/ReactCommon" react-native-safe-area-context: :path: "../../../node_modules/react-native-safe-area-context" + react-native-skia: + :path: "../../../node_modules/@shopify/react-native-skia" react-native-wgpu: :path: "../../../node_modules/react-native-wgpu" React-nativeconfig: @@ -1527,8 +1575,9 @@ SPEC CHECKSUMS: React-jsitracing: 0fa7f78d8fdda794667cb2e6f19c874c1cf31d7e React-logger: 29fa3e048f5f67fe396bc08af7606426d9bd7b5d React-Mapbuffer: bf56147c9775491e53122a94c423ac201417e326 - react-native-safe-area-context: 851c62c48dce80ccaa5637b6aa5991a1bc36eca9 - react-native-wgpu: 5308faeb6d85925394351968f7970cd00eead0cc + react-native-safe-area-context: 141eca0fd4e4191288dfc8b96a7c7e1c2983447a + react-native-skia: c85483f709f2b58d30a11fc005c2938ee87d6656 + react-native-wgpu: fc73fc100b757c6c89e489ffbcc927214bd270f8 React-nativeconfig: 9f223cd321823afdecf59ed00861ab2d69ee0fc1 React-NativeModulesApple: ff7efaff7098639db5631236cfd91d60abff04c0 React-perflogger: 32ed45d9cee02cf6639acae34251590dccd30994 @@ -1552,11 +1601,11 @@ SPEC CHECKSUMS: React-runtimescheduler: 56b642bf605ba5afa500d35790928fc1d51565ad React-utils: 4476b7fcbbd95cfd002f3e778616155241d86e31 ReactCommon: ecad995f26e0d1e24061f60f4e5d74782f003f12 - ReactNativeHost: 76fb17eac13a9a2200f659deffc91c054731a7e2 - ReactTestApp-DevSupport: 690d06567b7ecae4f2f98dff5e4881c8d25be8e2 + ReactNativeHost: 619621c39cdb4339c1336cea844b66cdf43c4d84 + ReactTestApp-DevSupport: 42abce6b0c88dfb47c86e80aa22831b2abcc3144 ReactTestApp-Resources: 857244f3a23f2b3157b364fa06cf3e8866deff9c - RNGestureHandler: 67e78f16895947f7e57ab91e75e914d3e9ef7239 - RNReanimated: 4c72fc2c0f4c6a9c36932e653cd68e4521b6c4ac + RNGestureHandler: 6fee3422fd8c81c5ee756fa72e3d1780e9943d9d + RNReanimated: 77bde2fb01415b61799ed173f9420010632b76e1 SocketRocket: abac6f5de4d4d62d24e11868d7a2f427e0ef940d Yoga: ae3c32c514802d30f687a04a6a35b348506d411f diff --git a/apps/example/package.json b/apps/example/package.json index 9ee8c19bf..ff8f5e47d 100644 --- a/apps/example/package.json +++ b/apps/example/package.json @@ -20,6 +20,10 @@ "@react-navigation/native": "^6.1.17", "@react-navigation/stack": "^6.4.0", "@react-three/fiber": "^8.17.6", + "@shopify/react-native-skia": "^1.7.3", + "@tensorflow/tfjs": "^4.22.0", + "@tensorflow/tfjs-backend-webgpu": "^4.22.0", + "@tensorflow/tfjs-vis": "^1.5.1", "fast-text-encoding": "^1.0.6", "react": "18.2.0", "react-native": "0.74.2", diff --git a/apps/example/src/App.tsx b/apps/example/src/App.tsx index 755ef5f25..9e5a6facc 100644 --- a/apps/example/src/App.tsx +++ b/apps/example/src/App.tsx @@ -18,6 +18,7 @@ import { RenderBundles } from "./RenderBundles"; import { ABuffer } from "./ABuffer"; import { OcclusionQuery } from "./OcclusionQuery"; import { ComputeBoids } from "./ComputeBoids"; +import { MNISTInference } from "./MNISTInference"; import { Wireframe } from "./Wireframe"; import { Resize } from "./Resize"; import { Particules } from "./Particles"; @@ -27,6 +28,7 @@ import { ReversedZ } from "./ReversedZ"; import { ThreeJS } from "./ThreeJS"; import { GradientTiles } from "./GradientTiles"; import { CanvasAPI } from "./CanvasAPI"; +import { Tensorflow } from "./Tensorflow"; // The two lines below are needed by three.js import "fast-text-encoding"; @@ -50,6 +52,7 @@ function App() { component={HelloTriangleMSAA} /> + @@ -64,6 +67,7 @@ function App() { + diff --git a/apps/example/src/ComputeBoids/ComputeBoids.tsx b/apps/example/src/ComputeBoids/ComputeBoids.tsx index 75a1b6d95..133ee1bdb 100644 --- a/apps/example/src/ComputeBoids/ComputeBoids.tsx +++ b/apps/example/src/ComputeBoids/ComputeBoids.tsx @@ -41,7 +41,7 @@ const renderBindGroupLayout = tgpu.bindGroupLayout({ const computeBindGroupLayout = tgpu.bindGroupLayout({ currentTrianglePos: { storage: TriangleDataArray }, - nextTrianglePos: { storage: TriangleDataArray, access: 'mutable' }, + nextTrianglePos: { storage: TriangleDataArray, access: "mutable" }, params: { uniform: Parameters }, }); @@ -122,7 +122,7 @@ export function ComputeBoids() { const triangleAmount = 1000; const trianglePosBuffers = Array.from({ length: 2 }, () => - root.createBuffer(TriangleDataArray(triangleAmount)).$usage("storage") + root.createBuffer(TriangleDataArray(triangleAmount)).$usage("storage"), ); randomizePositions.current = () => { @@ -234,7 +234,7 @@ export function ComputeBoids() { computePass.setPipeline(computePipeline); computePass.setBindGroup( 0, - root.unwrap(even ? computeBindGroups[0] : computeBindGroups[1]) + root.unwrap(even ? computeBindGroups[0] : computeBindGroups[1]), ); computePass.dispatchWorkgroups(triangleAmount); computePass.end(); @@ -244,7 +244,7 @@ export function ComputeBoids() { passEncoder.setVertexBuffer(0, triangleVertexBuffer.buffer); passEncoder.setBindGroup( 0, - root.unwrap(even ? renderBindGroups[1] : renderBindGroups[0]) + root.unwrap(even ? renderBindGroups[1] : renderBindGroups[0]), ); passEncoder.draw(3, triangleAmount); passEncoder.end(); diff --git a/apps/example/src/GradientTiles/GradientTiles.tsx b/apps/example/src/GradientTiles/GradientTiles.tsx index 701cd3159..b6544de91 100644 --- a/apps/example/src/GradientTiles/GradientTiles.tsx +++ b/apps/example/src/GradientTiles/GradientTiles.tsx @@ -4,7 +4,7 @@ import { Canvas, useDevice, useGPUContext } from "react-native-wgpu"; import { struct, u32 } from "typegpu/data"; import tgpu, { type TgpuBindGroup, type TgpuBuffer } from "typegpu"; -import { vertWGSL, fragWGSL } from './gradientWgsl'; +import { vertWGSL, fragWGSL } from "./gradientWgsl"; const Span = struct({ x: u32, @@ -18,7 +18,7 @@ const bindGroupLayout = tgpu.bindGroupLayout({ interface RenderingState { pipeline: GPURenderPipeline; spanBuffer: TgpuBuffer; - bindGroup: TgpuBindGroup<(typeof bindGroupLayout)['entries']>; + bindGroup: TgpuBindGroup<(typeof bindGroupLayout)["entries"]>; } function useRoot() { @@ -26,7 +26,7 @@ function useRoot() { return useMemo( () => (device ? tgpu.initFromDevice({ device }) : null), - [device] + [device], ); } diff --git a/apps/example/src/Home.tsx b/apps/example/src/Home.tsx index 99884ad07..f0901591b 100644 --- a/apps/example/src/Home.tsx +++ b/apps/example/src/Home.tsx @@ -23,6 +23,10 @@ export const examples = [ screen: "ThreeJS", title: "☘️ Three.js", }, + { + screen: "Tensorflow", + title: "🤖 tensorflow.js", + }, { screen: "Cube", title: "🧊 Cube", @@ -67,6 +71,10 @@ export const examples = [ screen: "ComputeBoids", title: "🐦‍⬛ Compute Boids", }, + { + screen: "MNISTInference", + title: "1️⃣ MNIST Inference", + }, ...(Platform.OS !== "ios" ? ([ { diff --git a/apps/example/src/MNISTInference/Lib.ts b/apps/example/src/MNISTInference/Lib.ts new file mode 100644 index 000000000..a34fa817e --- /dev/null +++ b/apps/example/src/MNISTInference/Lib.ts @@ -0,0 +1,277 @@ +import tgpu, { type TgpuBuffer, type Storage } from "typegpu"; +import { type F32, type TgpuArray, arrayOf, f32 } from "typegpu/data"; + +export const SIZE = 28; + +// Definitions for the network + +interface LayerData { + shape: readonly [number] | readonly [number, number]; + buffer: TgpuBuffer> & Storage; +} + +interface Layer { + weights: TgpuBuffer> & Storage; + biases: TgpuBuffer> & Storage; + state: TgpuBuffer> & Storage; +} + +export interface Network { + layers: Layer[]; + input: TgpuBuffer> & Storage; + output: TgpuBuffer> & Storage; + + inference(data: number[]): Promise; +} + +export const centerData = (data: Uint8Array) => { + "worklet"; + const mass = data.reduce((acc, value) => acc + value, 0); + const x = data.reduce((acc, value, i) => acc + value * (i % SIZE), 0) / mass; + const y = + data.reduce((acc, value, i) => acc + value * Math.floor(i / SIZE), 0) / + mass; + + const offsetX = Math.round(SIZE / 2 - x); + const offsetY = Math.round(SIZE / 2 - y); + + const newData = new Array(SIZE * SIZE).fill(0); + for (let i = 0; i < SIZE; i++) { + for (let j = 0; j < SIZE; j++) { + const index = i * SIZE + j; + const newIndex = (i + offsetY) * SIZE + j + offsetX; + if (newIndex >= 0 && newIndex < SIZE * SIZE) { + newData[newIndex] = data[index]; + } + } + } + + return newData; +}; + +export const createDemo = async (device: GPUDevice) => { + const root = tgpu.initFromDevice({ device }); + + // Shader code + + const layerShader = /* wgsl */ ` + @binding(0) @group(0) var input: array; + @binding(1) @group(0) var output: array; + + @binding(0) @group(1) var weights: array; + @binding(1) @group(1) var biases: array; + + fn relu(x: f32) -> f32 { + return max(0.0, x); + } + + @compute @workgroup_size(1) + fn main(@builtin(global_invocation_id) gid: vec3u) { + let inputSize = arrayLength( &input ); + + let i = gid.x; + + let weightsOffset = i * inputSize; + var sum = 0.0; + + for (var j = 0u; j < inputSize; j = j + 1) { + sum = sum + input[j] * weights[weightsOffset + j]; + } + + sum = sum + biases[i]; + output[i] = relu(sum); + } +`; + + const ReadonlyFloats = { + storage: (n: number) => arrayOf(f32, n), + access: "readonly", + } as const; + + const MutableFloats = { + storage: (n: number) => arrayOf(f32, n), + access: "mutable", + } as const; + + const ioLayout = tgpu.bindGroupLayout({ + input: ReadonlyFloats, + output: MutableFloats, + }); + + const weightsBiasesLayout = tgpu.bindGroupLayout({ + weights: ReadonlyFloats, + biases: ReadonlyFloats, + }); + + const pipeline = device.createComputePipeline({ + layout: device.createPipelineLayout({ + bindGroupLayouts: [ + root.unwrap(ioLayout), + root.unwrap(weightsBiasesLayout), + ], + }), + compute: { + module: device.createShaderModule({ + code: layerShader, + }), + }, + }); + + /** + * Creates a network from a list of pairs of weights and biases + * + * It automates the creation of state buffers that are used to store the intermediate results of the network + * as well as the input layer buffer + * + * It provides an inference function that takes an array of input data and returns an array of output data + */ + function createNetwork(layers: [LayerData, LayerData][]): Network { + const buffers = layers.map(([weights, biases]) => { + if (weights.shape[1] !== biases.shape[0]) { + throw new Error(`Shape mismatch: ${weights.shape} and ${biases.shape}`); + } + + return { + weights: weights.buffer, + biases: biases.buffer, + state: root + .createBuffer(arrayOf(f32, biases.shape[0])) + .$usage("storage"), + }; + }); + + const input = root + .createBuffer(arrayOf(f32, layers[0][0].shape[0])) + .$usage("storage"); + const output = buffers[buffers.length - 1].state; + + const ioBindGroups = buffers.map((_, i) => + ioLayout.populate({ + input: i === 0 ? input : buffers[i - 1].state, + output: buffers[i].state, + }), + ); + + const weightsBindGroups = buffers.map((layer) => + weightsBiasesLayout.populate({ + weights: layer.weights, + biases: layer.biases, + }), + ); + + async function inference(data: number[]): Promise { + // verify the length of the data matches the input layer + if (data.length !== layers[0][0].shape[0]) { + throw new Error( + `Data length ${data.length} does not match input shape ${layers[0][0].shape[0]}`, + ); + } + input.write(data); + + // Run the network + const encoder = device.createCommandEncoder(); + for (let i = 0; i < buffers.length; i++) { + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, root.unwrap(ioBindGroups[i])); + pass.setBindGroup(1, root.unwrap(weightsBindGroups[i])); + pass.dispatchWorkgroups(buffers[i].biases.dataType.elementCount); //.length + pass.end(); + } + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + + // Read the output + return await output.read(); + } + + return { + layers: buffers, + input, + output, + inference, + }; + } + + const network = createNetwork(await downloadLayers()); + + // #region Downloading weights & biases + + /** + * Create a LayerData object from a layer ArrayBuffer + * + * The function extracts the header, shape and data from the layer + * If there are any issues with the layer, an error is thrown + * + * Automatically creates appropriate buffer initialized with the data + */ + function getLayerData(layer: ArrayBuffer): LayerData { + const headerLen = new Uint16Array(layer.slice(8, 10)); + + const header = new TextDecoder().decode( + new Uint8Array(layer.slice(10, 10 + headerLen[0])), + ); + + // shape can be found in the header in the format: 'shape': (x, y) or 'shape': (x,) for bias + const shapeMatch = header.match(/'shape': \((\d+), ?(\d+)?\)/); + if (!shapeMatch) { + throw new Error("Shape not found in header"); + } + + // To accommodate .npy weirdness - if we have a 2d shape we need to switch the order + const X = Number.parseInt(shapeMatch[1], 10); + const Y = Number.parseInt(shapeMatch[2], 10); + const shape = Number.isNaN(Y) ? ([X] as const) : ([Y, X] as const); + + const data = new Float32Array(layer.slice(10 + headerLen[0])); + + // Verify the length of the data matches the shape + if (data.length !== shape[0] * (shape[1] || 1)) { + throw new Error( + `Data length ${data.length} does not match shape ${shape}`, + ); + } + + const buffer = root + .createBuffer(arrayOf(f32, data.length), [...data]) + .$usage("storage"); + + return { + shape, + buffer, + }; + } + + function downloadLayers(): Promise<[LayerData, LayerData][]> { + const downloadLayer = async (fileName: string): Promise => { + const buffer = await fetch( + `https://docs.swmansion.com/TypeGPU/assets/mnist-weights/${fileName}`, + ).then((res) => res.arrayBuffer()); + + return getLayerData(buffer); + }; + + return Promise.all( + [0, 1, 2, 3, 4, 5, 6, 7].map((layer) => + Promise.all([ + downloadLayer(`layer${layer}.weight.npy`), + downloadLayer(`layer${layer}.bias.npy`), + ]), + ), + ); + } + + // #endregion + + // #region User Interface + + // #endregion + + // #region Resource cleanup + + function onCleanup() { + root.destroy(); + } + return { network, onCleanup }; + // #endregion +}; diff --git a/apps/example/src/MNISTInference/MNISTInference.tsx b/apps/example/src/MNISTInference/MNISTInference.tsx new file mode 100644 index 000000000..a450b7a84 --- /dev/null +++ b/apps/example/src/MNISTInference/MNISTInference.tsx @@ -0,0 +1,145 @@ +import React, { useCallback, useEffect, useRef } from "react"; +import { Button, Dimensions, Platform, StyleSheet, View } from "react-native"; +import { + Canvas, + Fill, + Skia, + PaintStyle, + Path, + ColorType, + AlphaType, + matchFont, + Text, + notifyChange, + Group, +} from "@shopify/react-native-skia"; +import { Gesture, GestureDetector } from "react-native-gesture-handler"; +import type { SharedValue } from "react-native-reanimated"; +import { runOnJS, useSharedValue } from "react-native-reanimated"; +import { useDevice } from "react-native-wgpu"; + +import type { Network } from "./Lib"; +import { createDemo, centerData, SIZE } from "./Lib"; + +const { width } = Dimensions.get("window"); + +const fontFamily = Platform.select({ ios: "Helvetica", default: "serif" }); +const fontStyle = { + fontFamily, + fontSize: 200, +}; +const font = matchFont(fontStyle); + +const paint = Skia.Paint(); +paint.setColor(Skia.Color("black")); +paint.setStyle(PaintStyle.Stroke); +paint.setStrokeWidth(0.5); + +const grid = Skia.Path.Make(); +const cellSize = width / SIZE; + +grid.moveTo(0, 0); + +// Draw vertical lines +for (let i = 0; i <= SIZE; i++) { + grid.moveTo(i * cellSize, 0); + grid.lineTo(i * cellSize, width); +} + +// Draw horizontal lines +for (let i = 0; i <= SIZE; i++) { + grid.moveTo(0, i * cellSize); + grid.lineTo(width, i * cellSize); +} + +const f = 1 / cellSize; +const surface = Skia.Surface.MakeOffscreen(SIZE, SIZE)!; +const canvas = surface.getCanvas(); + +export function MNISTInference() { + const { device } = useDevice(); + const network = useRef(); + const text = useSharedValue(""); + const path = useSharedValue(Skia.Path.Make()); + const runInference = useCallback( + async (data: number[]) => { + if (network.current === undefined) { + return; + } + const certainties = await network.current.inference(data); + const max = Math.max(...certainties); + const index = certainties.indexOf(max); + text.value = `${index}`; + }, + [text], + ); + const gesture = Gesture.Pan() + .onStart((e) => { + path.value.moveTo(e.x * f, e.y * f); + }) + .onChange((e) => { + path.value.lineTo(e.x * f, e.y * f); + canvas.drawPath(path.value, paint); + const pixels = canvas.readPixels(0, 0, { + width: SIZE, + height: SIZE, + alphaType: AlphaType.Opaque, + colorType: ColorType.Alpha_8, + })!; + notifyChange(path as SharedValue); + runOnJS(runInference)( + centerData(pixels as Uint8Array).map((x) => (x / 255) * 3.24 - 0.42), + ); + }); + useEffect(() => { + (async () => { + if (device) { + const demo = await createDemo(device); + network.current = demo.network; + } + })(); + }, [device, network]); + return ( + +