# sql_syntax.py -- from a fork of ARCH-commons/i2p-transform
'''sql_syntax - break SQL scripts into statements, etc.
'''
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Text, Tuple, Union
import re
# Type aliases used in the signatures throughout this module.
Name = Text  # an identifier, e.g. a bind parameter or table name
SQL = Text  # SQL source text: a whole script or a single statement
Environment = Dict[Name, Text]  # substitution variable name -> replacement text
Params = Dict[str, Union[str, int, datetime]]  # bind parameter name -> value
Line = int  # line number within a script (iter_statement counts from 1)
Comment = Text  # comment text collected ahead of a statement
StatementInContext = Tuple[Line, Comment, SQL]  # (line, comment, statement)
def iter_statement(txt: SQL) -> Iterable[StatementInContext]:
    r'''Split a SQL script into its `;`-separated statements.

    Each item is `(line, comment, statement)`: the line on which the
    statement text begins, any comment text that preceded it, and the
    statement itself without the terminating `;`.

    >>> list(iter_statement("drop table foo; create table foo"))
    [(1, '', 'drop table foo'), (1, '', 'create table foo')]

    Comments ahead of a statement are returned with it:

    >>> list(iter_statement("-- blah blah\ndrop table foo"))
    [(2, '-- blah blah\n', 'drop table foo')]

    Comments inside a statement are dropped; the whitespace on either
    side of them is kept:

    >>> list(iter_statement("drop /* blah blah */ table foo"))
    [(1, '', 'drop  table foo')]

    Quoted literals and quoted identifiers are copied verbatim, so a
    `;` or `--` inside them is not treated as syntax:

    >>> list(iter_statement("select '[^;]+' from dual"))
    [(1, '', "select '[^;]+' from dual")]
    >>> list(iter_statement('select "x--y" from z;'))
    [(1, '', 'select "x--y" from z')]
    >>> list(iter_statement("select 'x--y' from z;"))
    [(1, '', "select 'x--y' from z")]
    '''
    stmt = ''      # statement text accumulated so far
    remark = ''    # comment text seen before the statement began
    lineno = 1     # line number of the start of the unread text
    first_line = None  # type: Optional[int]

    def keep(piece: SQL) -> None:
        # Append to the statement; the first non-empty piece fixes the
        # line number reported for the statement.
        nonlocal stmt, first_line
        if piece and not first_line:
            first_line = lineno
        stmt += piece

    while True:
        # `txt` is always the unread remainder, so `^` in the pattern
        # anchors at the current position.
        found = SQL_SEPARATORS.search(txt)
        if found is None:
            keep(txt)
            break

        before = txt[:found.start()]
        token = found.group(0)
        txt = txt[found.end():]
        if before:
            keep(before)

        if found.group('sep'):
            if first_line:
                yield first_line, remark, stmt
            stmt = remark = ''
            first_line = None
        elif any(found.group(kind) for kind in ('lit', 'hint', 'sym')):
            # Part of the statement proper: copy through untouched.
            keep(token)
        elif found.group('space'):
            # Whitespace belongs to whichever part is in progress; before
            # any comment or statement text it is simply discarded.
            if stmt:
                keep(token)
            elif remark:
                remark += token
        elif found.group('comment') and not stmt:
            remark += token
        # (a comment in the middle of a statement falls through: dropped)

        lineno += before.count("\n") + token.count("\n")

    if first_line and (remark or stmt):
        yield first_line, remark, stmt


# Check for hint before comment since a hint looks like a comment
SQL_SEPARATORS = re.compile(
    # whitespace at the very start of the text being searched
    r'(?P<space>^\s+)'
    # optimizer hint: /*+ ... */
    r'|(?P<hint>/\*\+.*?\*/)'
    # comment: either -- to end of line, or /* ... */
    r'|(?P<comment>(--[^\n]*(?:\n|$))|(?:/\*([^\*]|(\*(?!/)))*\*/))'
    # double-quoted text
    r'|(?P<sym>"[^\"]*")'
    # single-quoted text
    r"|(?P<lit>'[^\']*')"
    # statement separator
    r'|(?P<sep>;)')
def _test_iter_statement() -> None:
    r'''Doctest holder: further `iter_statement` examples; does nothing itself.

    >>> list(iter_statement("/* blah blah */ drop table foo"))
    [(1, '/* blah blah */ ', 'drop table foo')]
    >>> [l for (l, c, s) in iter_statement('s1;\n/**********************\n'
    ... '* Medication concepts \n'
    ... '**********************/\ns2')]
    [1, 5]
    >>> list(iter_statement(""))
    []
    >>> list(iter_statement("/*...*/ "))
    []
    >>> list(iter_statement("drop table foo; "))
    [(1, '', 'drop table foo')]
    >>> list(iter_statement("drop table foo"))
    [(1, '', 'drop table foo')]
    >>> list(iter_statement("/* *** -- *** */ drop table foo"))
    [(1, '/* *** -- *** */ ', 'drop table foo')]
    >>> list(iter_statement("/* *** -- *** */ drop table x /* ... */"))
    [(1, '/* *** -- *** */ ', 'drop table x ')]
    >>> list(iter_statement('select /*+ index(ix1 ix2) */ * from some_table'))
    [(1, '', 'select /*+ index(ix1 ix2) */ * from some_table')]
    >>> len(list(iter_statement(""" /* + no space before + in hints*/
    ... /*+ index(observation_fact fact_cnpt_pat_enct_idx) */
    ... select * from some_table; """))[0][2])
    79
    >>> list(iter_statement('select 1+1; /* nothing else */; '))
    [(1, '', 'select 1+1')]
    '''
    pass  # pragma: nocover
def substitute(sql: SQL, variables: Optional[Environment]) -> SQL:
    '''Evaluate substitution variables in the style of Oracle sqlplus.

    Every `&&name` reference in `sql` is replaced by `variables[name]`.
    When `variables` is None, `sql` is returned untouched.

    :raises KeyError: if `sql` refers to a name missing from `variables`.

    >>> substitute('select &&col from dual', {'col': 'dummy'})
    'select dummy from dual'
    >>> substitute('select &&not_bound from dual', {})
    Traceback (most recent call last):
    KeyError: 'not_bound'
    '''
    if variables is None:
        return sql
    # The replacement is done with %-formatting, so any literal % in the
    # SQL must be doubled first to survive that step.
    sql_esc = sql.replace('%', '%%')
    # Raw pattern: '\w' in a plain string literal is an invalid escape.
    return re.sub(r'&&(\w+)', r'%(\1)s', sql_esc) % variables
def params_used(params: Params, statement: SQL) -> Params:
    '''Select the entries of `params` that `statement` refers to.

    :param params: candidate bind parameters, by name.
    :param statement: SQL text scanned (via `param_names`) for
        `:name` references.
    :return: a new dict with only the referenced entries, in the
        order they appear in `params`.
    '''
    # Scan the statement once, not once per candidate parameter.
    referenced = set(param_names(statement))
    return {key: value
            for key, value in params.items()
            if key in referenced}
def param_names(s: SQL) -> List[Name]:
    '''List the names used in `:name` references within `s`.

    Names are given without the colon, in order of appearance; a name
    referenced more than once is listed each time.

    >>> param_names('select 1+1 from dual')
    []
    >>> param_names('select 1+:y from dual')
    ['y']
    '''
    # With a single capture group, findall returns just the captured
    # text, i.e. the reference minus its leading colon.
    return re.findall(r':(\w+)', s)
def first_cursor(statement: SQL) -> SQL:
    '''Find argument of first obvious call to `cursor()`.

    The argument runs from just after the first `cursor(` to the
    parenthesis that closes it, so parentheses nested inside the
    argument are kept intact.  If the call is never closed, the rest
    of the statement is returned.  Parentheses inside quoted text are
    counted like any others.

    :raises IndexError: if `statement` contains no `cursor(`.

    >>> first_cursor('select * from table(f(cursor(select * from there)))')
    'select * from there'
    >>> first_cursor('select * from table(f(cursor(select g(x) from t)))')
    'select g(x) from t'
    '''
    opener = 'cursor('
    begin = statement.find(opener)
    if begin < 0:
        # Same exception type that indexing the failed split used to raise.
        raise IndexError('no cursor( call in statement')
    begin += len(opener)

    depth = 1  # the parenthesis opened by `cursor(` itself
    for pos in range(begin, len(statement)):
        char = statement[pos]
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                return statement[begin:pos]
    # Unbalanced input: best effort, return everything after `cursor(`.
    return statement[begin:]
class ObjectId(object):
    '''Identifier for a database object: a `kind` together with a name.

    Instances hash, compare equal, and sort by the pair `(kind, name)`.
    Subclasses set `kind`; here it is the empty string.
    '''
    kind = ''

    def __init__(self, name: str) -> None:
        self.name = name

    def _key(self) -> Tuple[str, str]:
        # The pair that repr, hashing and comparison are all based on.
        return (self.kind, self.name)

    def __repr__(self) -> str:
        return '{0} {1}'.format(*self._key())

    def __hash__(self) -> int:
        return hash(self._key())

    def __eq__(self, other: object) -> bool:
        # Anything that is not an ObjectId is simply unequal.
        return isinstance(other, ObjectId) and self._key() == other._key()

    def __lt__(self, other: 'ObjectId') -> bool:
        return self._key() < (other.kind, other.name)
class TableId(ObjectId):
    '''An `ObjectId` whose kind is `table`.'''
    kind = 'table'
class ViewId(ObjectId):
    '''An `ObjectId` whose kind is `view`.'''
    kind = 'view'
def created_objects(statement: SQL) -> List[ObjectId]:
    r'''Identify the table or view that a `create` statement creates.

    Only statements that begin (after surrounding whitespace is
    stripped) with `create table NAME` or `create or replace view NAME`
    are recognized; keywords are matched case-insensitively.  Any other
    statement gives an empty list.

    >>> created_objects('create table t as ...')
    [table t]
    >>> created_objects('create or replace view x\nas ...')
    [view x]
    '''
    text = statement.strip()
    # Raw patterns: '\S' and '\w' are invalid escapes in plain literals.
    m = re.search(r'(?i)^create or replace view (\S+)', text)
    views = [ViewId(m.group(1))] if m else []  # type: List[ObjectId]
    m = re.search(r'(?i)^create table (\w+)', text)
    tables = [TableId(m.group(1))] if m else []  # type: List[ObjectId]
    return tables + views
def inserted_tables(statement: SQL) -> List[Name]:
    r'''Name the table that an `insert` statement inserts into.

    The statement must start with lower-case `insert` once surrounding
    whitespace is stripped; the table is the first whitespace-delimited
    word after `into`.  The result has at most one element.

    >>> inserted_tables('create table t as ...')
    []
    >>> inserted_tables('insert into t (...) ...')
    ['t']
    >>> inserted_tables('\n  insert into t (...) ...')
    ['t']
    '''
    # Strip before the prefix test as well as before the search, so
    # leading whitespace does not hide an insert statement.
    text = statement.strip()
    if not text.startswith('insert'):
        return []
    m = re.search(r'into\s+(\S+)', text)
    return [m.group(1)] if m else []
def insert_append_table(statement: SQL) -> Optional[Name]:
    '''Name the table targeted by an insert carrying an `/*+ append` hint.

    :return: the table name found by `inserted_tables`, or None when
        the statement has no `/*+ append` hint or no inserted table
        can be identified in it.
    '''
    if '/*+ append' not in statement:
        return None
    # A hint without a recognizable insert target yields None rather
    # than failing to unpack an empty list.
    tables = inserted_tables(statement)
    return tables[0] if tables else None
def iter_blocks(module: SQL,
                separator: str = r'\r?\n/\r?\n') -> Iterable[StatementInContext]:
    r'''Split a script into blocks at each match of `separator`.

    :param module: text of the whole script.
    :param separator: regular expression matching the text between
        blocks; the default matches a line consisting of a lone `/`.
    :return: `(line, comment, block)` for every block that is not
        blank, where `line` is the 1-based line on which the block
        text starts.  Comments are not extracted; the comment slot
        always holds a fixed placeholder.

    >>> [(l, b) for (l, _, b) in iter_blocks('begin a; end;\n/\nbegin b; end;\n')]
    [(1, 'begin a; end;'), (3, 'begin b; end;\n')]
    '''
    line = 1
    start = 0  # offset in `module` where the current block begins
    for sep in re.finditer(separator, module):
        block = module[start:sep.start()]
        if block.strip():
            yield (line, '...no comment handling...', block)
        # Count the newlines actually consumed: those in the block and
        # those in the matched separator text (not in the pattern).
        line += module.count('\n', start, sep.end())
        start = sep.end()

    tail = module[start:]
    if tail.strip():
        yield (line, '...no comment handling...', tail)