forked from zhen8838/handson-polyhedral
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmlir_utility.py
63 lines (54 loc) · 2.07 KB
/
mlir_utility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from typing import Callable, Optional

from mlir.ir import Module, AffineMap, Context, Operation, Block, Region, Value, OpView
class IrVisitor:
    """Depth-first walker over MLIR IR that calls optional user hooks.

    For every operation the walk runs, in order:

    1. the ``beforeOperation`` hook;
    2. a recursive visit of every block in every region of the operation;
    3. the ``afterOperation`` hook.

    Each block is handled the same way, with ``beforeBlock`` and
    ``afterBlock`` around its operations.

    Each hook must return a truthy value for the walk to continue. A falsy
    result aborts the *whole* traversal, not only the current subtree, and
    every enclosing ``visit*`` call then returns ``False``. This includes the
    implicit ``None`` of a hook with no ``return`` statement. Hooks left as
    ``None`` are skipped and treated as "continue".
    """

    def __init__(
        self,
        beforeBlock: Optional[Callable[[Block], bool]] = None,
        beforeOperation: Optional[Callable[[OpView], bool]] = None,
        afterBlock: Optional[Callable[[Block], bool]] = None,
        afterOperation: Optional[Callable[[OpView], bool]] = None,
    ) -> None:
        """Store the hooks; any of them may be ``None`` to skip that event."""
        self.beforeBlockCallable = beforeBlock
        self.beforeOperationCallable = beforeOperation
        self.afterBlockCallable = afterBlock
        self.afterOperationCallable = afterOperation

    def visit(self, any) -> bool:
        """Walk the IR rooted at ``any``.

        Args:
            any: a ``Block``, ``Region``, ``Operation``, ``OpView`` or
                ``Module``. The parameter name shadows the ``any`` builtin;
                it is kept so keyword callers keep working.

        Returns:
            ``True`` if the walk completed, ``False`` if a hook aborted it.

        Raises:
            NotImplementedError: if ``any`` is not one of the types above.
        """
        if isinstance(any, Block):
            return self.visitBlock(any)
        if isinstance(any, Region):
            return self.visitRegion(any)
        if isinstance(any, Operation):
            # Normalize to the OpView form so operation hooks always receive
            # what their ``Callable[[OpView], bool]`` signature promises.
            return self.visitOperation(any.opview)
        if isinstance(any, OpView):
            return self.visitOperation(any)
        if isinstance(any, Module):
            # A module is walked starting from its top-level operation.
            return self.visitOperation(any.operation.opview)
        raise NotImplementedError(
            f"IrVisitor.visit does not support objects of type {type(any).__name__}"
        )

    def visitBlock(self, block: Block) -> bool:
        """Run the before-hook, visit each operation in order, then the after-hook.

        Returns ``False`` as soon as any hook in this subtree aborts the walk.
        """
        if not self.runBeforeBlock(block):
            return False
        for op in block.operations:
            if not self.visitOperation(op):
                return False
        if not self.runAfterBlock(block):
            return False
        return True

    def visitRegion(self, region: Region) -> bool:
        """Visit every block of ``region`` in order, stopping early on abort."""
        for block in region.blocks:
            if not self.visitBlock(block):
                return False
        return True

    def visitOperation(self, op: OpView) -> bool:
        """Run the before-hook, visit all nested regions, then the after-hook.

        Returns ``False`` as soon as any hook in this subtree aborts the walk.
        """
        if not self.runBeforeOperation(op):
            return False
        for region in op.regions:
            if not self.visitRegion(region):
                return False
        if not self.runAfterOperation(op):
            return False
        return True

    @staticmethod
    def _invoke(hook, target):
        """Return ``hook(target)``, or ``True`` ("continue") when no hook is set."""
        if hook is None:
            return True
        return hook(target)

    def runBeforeBlock(self, block: Block) -> bool:
        """Invoke the ``beforeBlock`` hook; a falsy result aborts the walk."""
        return self._invoke(self.beforeBlockCallable, block)

    def runBeforeOperation(self, op: OpView) -> bool:
        """Invoke the ``beforeOperation`` hook; a falsy result aborts the walk."""
        return self._invoke(self.beforeOperationCallable, op)

    def runAfterBlock(self, block: Block) -> bool:
        """Invoke the ``afterBlock`` hook; a falsy result aborts the walk."""
        return self._invoke(self.afterBlockCallable, block)

    def runAfterOperation(self, op: OpView) -> bool:
        """Invoke the ``afterOperation`` hook; a falsy result aborts the walk."""
        return self._invoke(self.afterOperationCallable, op)