-
Notifications
You must be signed in to change notification settings - Fork 0
/
annotate.py
368 lines (309 loc) · 15.6 KB
/
annotate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
import argparse
import os
import subprocess
import sys
import tempfile
from collections.abc import Mapping
from typing import List, Tuple

import jsonlines
import tree_sitter_glsl as tsglsl
from tqdm.auto import tqdm
from tree_sitter import Language, Parser
from licensedcode.detection import detect_licenses
from wgpu_shadertoy.api import shader_args_from_json, _download_media_channels
from wgpu_shadertoy import BufferRenderPass, Shadertoy

from download import read_ids
# module-level tree-sitter GLSL parser; reused by check_license and parse_functions
GLSL_LANGUAGE = Language(tsglsl.language())
PARSER = Parser(GLSL_LANGUAGE)
# CLI for the annotation step: reads raw .jsonl dumps, writes flattened+annotated .jsonl
argument_parser = argparse.ArgumentParser()
argument_parser.add_argument("--input", type=str, required=False, default="./data/raw/", help="the path of raw shadertoy api returns .jsonl")
argument_parser.add_argument("--output", type=str, required=False, default="./data/annotated/", help="the path of where to store annotated shaders as .jsonl")
argument_parser.add_argument("--mode", type=str, default="update", help="mode `update` will load shaders already in the output folder and overwrite specified columns; mode `redo` will overwrite the whole file")
# fixed help-text typos: "faltten" -> "flatten", "command seperated" -> "comma separated"
argument_parser.add_argument("--columns", type=str, required=True, help="comma separated list of columns to annotate: all, license, functions, test; if empty will simply flatten the nested structure")
argument_parser.add_argument("--ids", type=str, required=False, default="", help="comma separated list or path to a .txt file of ids to update. Will do all in the output dir if left empty")
# TODO: is --mode "update" --columns "all" the same as --mode "redo"?
def annotate_shader(shader_data: dict, columns: list, access: str = "api") -> dict:
    """
    Flatten one raw shader JSON response and annotate the requested columns.

    Unwraps the optional top-level "Shader" key, flattens the renderpass
    structure, attaches a thumbnail URL and the access source, then delegates
    per-column annotation to update_shader. Returns a dataset instance dict.
    """
    # api responses wrap everything in a top-level "Shader" key
    shader_data = shader_data.get("Shader", shader_data)
    flattened = flatten_shader_data(shader_data)
    shader_id = shader_data["info"]["id"]
    flattened["thumbnail"] = f"https://www.shadertoy.com/media/shaders/{shader_id}.jpg"
    flattened["access"] = access  # api, shaders20k, ?
    # recompute/overwrite the requested annotation columns
    return update_shader(flattened, columns=columns)
def update_shader(flattened_shader: dict, columns: list) -> dict:
    """
    Return a copy of *flattened_shader* with the requested annotation columns
    recomputed via the COLUMN_MAP registry.

    "all" in *columns* expands to every registered column. Unknown column
    names raise KeyError when looked up in COLUMN_MAP.
    """
    updated_shader = dict(flattened_shader)  # don't mutate the caller's dict
    targets = list(COLUMN_MAP) if "all" in columns else list(columns)
    # TODO: set None for cols not mentioned?
    for col in targets:
        updated_shader[col] = COLUMN_MAP[col](flattened_shader)
    return updated_shader
def flatten_shader_data(shader_data: dict) -> dict:
    """
    Flattens all renderpasses into a single depth dict.

    Lifts useful fields out of the nested "info" dict and emits one
    ``<pass>_code``/``<pass>_inputs`` column pair per standard renderpass
    name, defaulting to ""/[] for absent passes so every flattened shader
    shares the same schema. Renderpasses with non-standard names are skipped
    (with a printed warning).

    Raises KeyError if a required "info" field or "time_retrieved" is missing.
    """

    def _col(pass_name: str) -> str:
        # "Buffer A" -> "buffer_a"; used as the column-name prefix
        return pass_name.replace(" ", "_").lower()

    if "Shader" in shader_data:
        shader_data = shader_data["Shader"]
    # we lift some information out of the "info" dict that are useful
    info = shader_data["info"]
    out_dict = {
        "id": info["id"],
        "name": info["name"],
        "author": info["username"],
        "description": info["description"],
        "tags": info["tags"],
        "likes": info["likes"],
        "viewed": info["viewed"],
        # download uses {0: "Private", 1: "Public", 2: "Unlisted", 3: "Public API", 4: "Anonymous"}
        "published": info["published"],
        "date": info["date"],  # maybe format into a readable format or at least int?
        # this one is added by us with the download.py script
        "time_retrieved": shader_data["time_retrieved"],
    }
    pass_names = [
        "Image",
        "Common",
        "Sound",
        "Buffer A",
        "Buffer B",
        "Buffer C",
        "Buffer D",
        "Cube A",
    ]
    # if there is just one pass, it has to be the Image pass.
    single_pass = len(shader_data["renderpass"]) == 1
    for rp in shader_data["renderpass"]:
        if single_pass:
            rp["name"] = "Image"
        # remove the pass name from the list of still-missing passes
        try:
            pass_names.remove(rp["name"])
        except ValueError:
            # TODO: find a solution for some of these unknown names: ('', 'Buffer @', 'none', 'Buf C', 'Text Lib', 'Buf A', 'Buf B', 'Buf D')
            print(f"Pass name not standard: {rp['name']=}, skipping...")
            continue
        prefix = _col(rp["name"])
        out_dict[f"{prefix}_code"] = rp.get("code", "")
        out_dict[f"{prefix}_inputs"] = rp.get("inputs", [])
    # fill columns for passes that weren't present, keeping the schema uniform
    for name in pass_names:
        out_dict[f"{_col(name)}_code"] = ""
        out_dict[f"{_col(name)}_inputs"] = []
    # the "Common" pass never has inputs; pop() (vs del) tolerates the key being absent
    out_dict.pop("common_inputs", None)
    return out_dict
def check_license(code_or_shader) -> str:
    """
    Returns the license mentioned if the first node is a comment.
    if none is found, or no comment, returns "CC-BY-NC-SA-3.0" as the base case.

    Accepts either a flattened shader dict (reads its "image_code") or a raw
    GLSL code string. Only the run of comment nodes at the very top of the
    file is scanned for license text.
    NOTE(review): .to_dict().get("license_expression", None) can return None
    despite the -> str annotation — confirm downstream handles that.
    """
    if isinstance(code_or_shader, dict):
        code = code_or_shader["image_code"]
    elif isinstance(code_or_shader, str):
        code = code_or_shader
    else:
        raise TypeError(f" function doesn't support {type(code_or_shader)}")
    tree = PARSER.parse(bytes(code, encoding="utf-8"))
    comment_bytes = b""
    cursor = tree.walk()
    cursor.goto_first_child()
    # collect consecutive leading comment nodes; stops at the first
    # non-comment top-level node. is a while node really a good idea?
    while cursor.node.type == "comment":
        comment_bytes += cursor.node.text
        cursor.goto_next_sibling()
    if comment_bytes:
        # take the first match of each detection reported for the header comments
        detections = [x.matches[0] for x in detect_licenses(query_string=comment_bytes.decode(encoding="utf-8"))]
        if len(detections) >= 1:
            return detections[0].to_dict().get("license_expression", None)
    # base case is capitalized for downstream analysis
    return "CC-BY-NC-SA-3.0"
def parse_functions(code_or_shader) -> List[Tuple[int, int, int, int, int]]:
    """
    parses the code using tree-sitter-glsl

    returns the **byte-indices** for start_comment, start_header, end_header,
    end_docstring, end_function for every top-level function definition.
    Returns a list of 5-tuples. If the leading comment or docstring isn't
    found, those indices coincide with the next marker (start_header /
    end_header respectively).

    Accepts either a flattened shader dict (reads "image_code") or a raw
    GLSL code string.
    """
    # TODO: dry and maybe have it code_or_tree?
    if isinstance(code_or_shader, dict):
        code = code_or_shader["image_code"]
    elif isinstance(code_or_shader, str):
        code = code_or_shader
    else:
        raise TypeError(f" function doesn't support {type(code_or_shader)}")
    tree = PARSER.parse(bytes(code, encoding="utf-8"))
    root_node = tree.root_node
    funcs = []
    # lazy init: all five byte markers for the current candidate function
    start_comment = start_header = end_header = end_docstring = end_function = None
    # line number of the last seen comment's end; -2 so no real line can be adjacent
    comment_line = -2 #init with a different number?
    for child in root_node.children:
        if (child.type == "comment"):
            # only restart start_comment when this comment does NOT directly
            # continue the previous comment line (i.e. a new comment group)
            if ((comment_line + 1) != child.start_point[0]): # and child.start_point[1] == 0 # and not child.start_point[1] # so we only get whole line comments, nothing inline. but tabs or indentation might be an issue?
                start_comment = child.start_byte
            comment_line = child.end_point[0]
        elif child.type == "function_definition" and not child.has_error: #TODO: is this .has_error check causing false negatives?
            start_header = child.start_byte
            # if the preceding comment group isn't on the line directly above,
            # there is no leading comment: collapse start_comment onto the header
            if ((comment_line + 1) != child.start_point[0]): # so we can also get multi line comments at the start (but inline comments?)
                start_comment = start_header
            end_function = child.end_byte
            # children[-1] is the compound_statement body; children[0] of it is "{"
            end_header = child.children[-1].children[0].end_byte
            # inside the function body, past the "{": leading comments act as a docstring
            for sub_child in child.children[-1].children[1:]:
                if sub_child.type == "comment":
                    end_docstring = sub_child.end_byte
                else:
                    # first non-comment body node: no docstring found so far
                    if not end_docstring:
                        end_docstring = end_header
                    break # stop scanning the body after the docstring region ends
            funcs.append(tuple([start_comment, start_header, end_header, end_docstring, end_function])) #jsonlines turns this into a list again?
            # reset all markers for the next function
            start_comment = start_header = end_header = end_docstring = end_function = None
            comment_line = -2 # so our empty check can work again
    return funcs
def run_shader(shader_or_code, timeouts=10):
    """
    Tests a shader by running it in wgpu-shadertoy (in a subprocess).
    Returns one of the following disjunct classes:
    "ok" - shader ran without error
    "error" - wgpu-shadertoy threw an error (is likely still valid on the website)
    "timeout" - if after `timeouts` seconds we don't get to error or ok.
    # not implemented: "incomplete" - not yet fully supported in wgpu-shadertoy
    #   (the in-process Shadertoy completeness check was dead code and is removed)
    # not implemented: "panic" - worst case scenario. a rust panic in wgpu. This can
    #   cause the python process to terminate without recovery.

    Accepts a raw code string, a raw api response dict (with "renderpass"),
    or a flattened shader dict (with "image_code").
    Raises TypeError for unsupported input types and ValueError for a Mapping
    that matches neither known shape (previously an UnboundLocalError).
    """
    if isinstance(shader_or_code, str):
        # case 1: we only get a string of code
        shader_args = {"shader_code": shader_or_code}
    elif isinstance(shader_or_code, Mapping):
        # case 2: we get a dict, always unpack this "Shader" level
        if "Shader" in shader_or_code:
            shader_data = shader_or_code["Shader"]
        else:
            shader_data = shader_or_code
        if "renderpass" in shader_data:
            # case 2.a: a default "raw" api return
            shader_args = shader_args_from_json(shader_data)
        elif "image_code" in shader_data:  # really lazy check.
            # case 2.b: a flattened json
            buffers = {}
            for buf in "abcd":
                if shader_data[f"buffer_{buf}_code"]:
                    buffers[buf] = BufferRenderPass(buf, code=shader_data[f"buffer_{buf}_code"], inputs=_download_media_channels(shader_data[f"buffer_{buf}_inputs"])[0])
                else:
                    # because we don't handle empty code for Buffers internally.
                    buffers[buf] = ""
            shader_args = {
                "shader_code": shader_data["image_code"],
                "inputs": _download_media_channels(shader_data["image_inputs"])[0],
                "common": shader_data["common_code"],
                "buffers": buffers,
            }
        else:
            raise ValueError("Mapping has neither 'renderpass' nor 'image_code' - not a known shader shape")
    else:
        raise TypeError(f"function doesn't support {type(shader_or_code)}")
    shader_args["shader_type"] = "glsl"
    # NOTE: the previous in-process Shadertoy(**shader_args) "incomplete"/"ok" check
    # after this return was unreachable dead code and has been removed; the
    # subprocess run is the only test performed (it also isolates rust panics).
    return run_shader_in_subprocess(shader_args["shader_code"], timeout=timeouts)
# this is minimal code to try a single pass shader in a subprocess (no inputs)
# dual snapshot is required since first one doesn't crash it seems.
file_template = """
from wgpu_shadertoy import Shadertoy
shader_code = '''{}'''
shader = Shadertoy(shader_code, shader_type="glsl", offscreen=True)
if __name__ == "__main__":
snap1 = shader.snapshot(12.34)
snap2 = shader.snapshot(56.78)
# shader.show()
"""
# same implementation in the metric, check for inconsistencies (and newer commits):
# https://huggingface.co/spaces/Vipitis/shadermatch/blob/c569c78182dc618b36b0883b7d66621481ca2933/shadermatch.py#L302
# the root cause for this is that some shadercode causes rust panics, which crash the python process too... there is no good solution: https://github.com/pygfx/wgpu-py/pull/603
def run_shader_in_subprocess(shader_code, timeout=10):
    """
    Run *shader_code* via the file_template in a fresh python subprocess.

    Returns "ok", "error" (non-zero exit or subprocess failure) or "timeout".

    be very careful with this function, it runs user submitted code, and it
    can easily be escaped and exploited!
    """
    status = "ok"  # default case
    # delete=False so the subprocess can open the file; delete_on_close was only added in Python 3.12
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
        f.write(file_template.format(shader_code))
        f.flush()
    try:
        # sys.executable instead of a bare "python" so we always run the same
        # interpreter/environment as this process (bare "python" may be missing
        # or point elsewhere, e.g. on Linux)
        p = subprocess.run([sys.executable, f.name], capture_output=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        status = "timeout"
    except subprocess.SubprocessError:
        status = "error"
    finally:
        # always clean up the temp file, even on an unexpected exception
        os.remove(f.name)
    if status == "ok" and p.returncode != 0:
        status = "error"
    return status
# globally map all columns to the function that calculates them. might need to REGISTER more?
COLUMN_MAP = {"license": check_license, "functions": parse_functions, "test": run_shader}
if __name__ == "__main__":
    args = argument_parser.parse_args()
    print(f"{args=}")
    input_dir = args.input
    output_dir = args.output
    columns = [col.strip() for col in args.columns.split(",")]
    print(f"{columns=}")
    if args.mode == "redo":
        # redo: annotate every raw file from scratch and overwrite the output
        print(f"annotating all .jsonlines files in {input_dir}")
        for file in tqdm(os.listdir(input_dir)):
            if not file.endswith(".jsonl"):
                tqdm.write(f"Skipping file {file}")
                continue
            source = "api"  # default?
            if file.startswith("20k"):  # should we do api_ prefix for the others?
                source = "shaders20k"
            tqdm.write(f"Annotating {file}")
            with jsonlines.open(os.path.join(input_dir, file), "r") as reader:
                shaders = list(reader)
            annotated_shaders = []
            for shader in tqdm(shaders):
                annotated_shaders.append(annotate_shader(shader, columns=columns, access=source))
            output_path = os.path.join(output_dir, file)
            with jsonlines.open(output_path, mode="w") as writer:
                for shader in annotated_shaders:
                    writer.write(shader)
            tqdm.write(f"Annotated {file} to {output_path}")
    elif args.mode == "update":
        # update: re-annotate the requested columns in already-annotated files
        if args.ids == "":
            ids = None
            print(f"updating all .jsonlines files in {output_dir}")
        elif args.ids.endswith(".txt"):
            ids = read_ids(args.ids)
        else:
            # args.ids is a comma separated list of shader ids
            # (removed a no-op `isinstance(args.ids, str)` whose result was discarded)
            ids = args.ids.split(",")
            print(f"updating {len(ids)} shaders in {output_dir}")
        for file in tqdm(os.listdir(output_dir)):
            if not file.endswith(".jsonl"):
                tqdm.write(f"Skipping file {file}")
                continue
            with jsonlines.open(os.path.join(output_dir, file), "r") as reader:
                old_annotations = list(reader)
            new_annotations = []
            for annotation in tqdm(old_annotations):
                # we still run through all of them just to find the one id we want?
                if ids is None or annotation["id"] in ids:
                    # TODO: use empty list as an early exit?
                    new_annotations.append(update_shader(annotation, columns=columns))
                else:
                    new_annotations.append(annotation)
            # TODO: DRY - don't repeat yourself?
            output_path = os.path.join(output_dir, file)
            with jsonlines.open(output_path, mode="w") as writer:
                for shader in new_annotations:
                    writer.write(shader)
            tqdm.write(f"Annotated {file} to {output_path}")
    else:
        # fixed typo in the user-facing message: "chose" -> "choose"
        print(f"unrecognized mode {args.mode}, please choose either `update` or `redo`")