-
Notifications
You must be signed in to change notification settings - Fork 0
/
oom_vektor_add.py
39 lines (30 loc) · 1.34 KB
/
oom_vektor_add.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from kp import Manager, Tensor, OpTensorSyncDevice, OpTensorSyncLocal, OpAlgoDispatch
from pyshader import python2shader, ivec3, f32, Array
mgr = Manager()

# Initial tensor contents may be given as a plain Python list or a NumPy array.
# The three tensors are created in order: first operand, second operand, and
# a zero-filled tensor that will receive the shader's output.
tensor_in_a, tensor_in_b, tensor_out = (
    mgr.tensor(initial) for initial in ([2, 2, 2], [1, 2, 3], [0, 0, 0])
)

# A single sequence is used for every operation in this script. The first
# operation presumably uploads the tensors' host data to the device (going by
# the operation's name) -- confirm against the Kompute documentation.
sq = mgr.sequence()
sq.eval(OpTensorSyncDevice([tensor_in_a, tensor_in_b, tensor_out]))
# Elementwise-product compute kernel, written as a Python function and handed
# to PyShader through the decorator. The kernel could equally be supplied as a
# GLSL string or as SPIR-V bytes instead of going through PyShader.
#
# Each parameter default is a PyShader resource descriptor:
#   index -- the built-in GlobalInvocationId input (ivec3)
#   data1 -- float buffer at binding 0 (first operand)
#   data2 -- float buffer at binding 1 (second operand)
#   data3 -- float buffer at binding 2 (receives the product)
@python2shader
def compute_shader_multiply(index=("input", "GlobalInvocationId", ivec3),
                            data1=("buffer", 0, Array(f32)),
                            data2=("buffer", 1, Array(f32)),
                            data3=("buffer", 2, Array(f32))):
    # One invocation handles one element: the x component of the invocation id
    # is the index used for all three buffers.
    i = index.x
    data3[i] = data1[i] * data2[i]
# Load the precompiled SPIR-V shader from the current working directory.
# NOTE(review): the file is named "addVectors.spv", yet the check at the end of
# the script expects the elementwise PRODUCT of the inputs ([2*1, 2*2, 2*3]);
# an elementwise sum would give [3.0, 4.0, 5.0]. Confirm what this shader
# actually computes.
with open("addVectors.spv", "rb") as f:
    b = f.read()

# Build the algorithm from the three tensors and the shader bytes. The tensors
# are presumably bound to the shader's buffers in list order -- confirm against
# the Kompute documentation.
algo = mgr.algorithm([tensor_in_a, tensor_in_b, tensor_out], b)
# Alternative that needs no .spv file: pass the output of
# compute_shader_multiply.to_spirv() (the PyShader function defined in this
# file) as the second argument; the same bytes can also be written to a .spv
# file for later reuse.

# Run shader operation synchronously, then sync the output tensor back so its
# data can be read on the host.
sq.eval(OpAlgoDispatch(algo))
sq.eval(OpTensorSyncLocal([tensor_out]))

# Read the result once and reuse it for both the printout and the check.
result = tensor_out.data().tolist()
print(result)
assert result == [2.0, 4.0, 6.0], "unexpected shader output: %r" % (result,)