-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathgen_deid_sql.py
248 lines (217 loc) · 9.97 KB
/
gen_deid_sql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
''' Parse Oracle "create table" .sql file and generate sql to de-identify
CMS data.
'''
from collections import defaultdict, OrderedDict
from contextlib import contextmanager
from csv import DictReader
from re import findall, DOTALL
# Date-of-birth column names.  cms_deid_sql wraps DATE columns with these names
# in a CASE expression that returns the mapped birth_date when its year is
# 1900 and the date-shifted value otherwise.
DOB_COLS = ['BENE_BIRTH_DT', 'EL_DOB', 'DOB_DT']
def main(open_rd_argv, get_input_path, get_col_desc_file, get_mode,
         print_stderr):
    '''Parse the inputs and print the SQL selected by the mode argument.

    All access to the outside world is injected by the caller:

    open_rd_argv -- callable taking a path and returning a context manager
                    that yields a readable file object
    get_input_path -- callable returning the path of the "create table" file
    get_col_desc_file -- callable returning the path of a CSV file with
                    table_name, column_name and description columns
    get_mode -- callable returning 'cms_deid_sql' or 'date_events'
    print_stderr -- callable that receives the usage text

    The usage text and the SQL header comment are emitted on every run,
    before any input is read.  Raises NotImplementedError for any other mode.
    '''
    usage = ('Usage:\ngen_deid_sql.py <path/to/oracle_create.sql> '
             '<path/to/table_column_desc.csv> '
             '<cms_deid_sql|date_events>\n\n\n')
    print_stderr(usage)

    header = ('-- cms_deid.sql: Deidentify CMS data\n'
              '-- Copyright (c) 2020 University of Kansas Medical Center\n')
    print(header)

    with open_rd_argv(get_input_path()) as ddl_file:
        tables = tables_columns(ddl_file.read())

    # table name -> {column name: description}
    tdesc = defaultdict(dict)
    with open_rd_argv(get_col_desc_file()) as desc_file:
        for row in DictReader(desc_file):
            column_descriptions = tdesc[row['table_name']]
            column_descriptions[row['column_name']] = row['description']

    if get_mode() == 'cms_deid_sql':
        cms_deid_sql(tables, tdesc)
        return
    if get_mode() == 'date_events':
        date_events(tables)
        return
    raise NotImplementedError(get_mode())
def date_events(tables):
    '''Print "insert into date_events" statements covering every DATE column.

    `tables` maps a table name to a sequence of (column, type) pairs, as
    returned by tables_columns.  Tables without a DATE column are skipped.

    For a table with exactly one DATE column the generated query selects the
    column name as a literal (COL_DT) next to the column value (DT).  For a
    table with several DATE columns the query selects them all and unpivots
    them (excluding nulls) into col_date / dt pairs.

    Each statement is followed by a commit; statements are separated by a
    blank line and written to stdout with a single print call.
    '''
    def mkids(t):
        '''Return the leading select-list entries identifying a row of `t`.'''
        if t.startswith('maxdata'):
            msis_cols = ' msis_id,\n state_cd,\n'
        else:
            # Tables not named maxdata* get null placeholders instead.
            msis_cols = ' null msis_id,\n null state_cd,\n'
        return "'%s' table_name,\n bene_id,\n%s" % (t, msis_cols)

    def mk_inq(t, dc):
        '''Return the inner select over table `t` for date columns `dc`.'''
        if len(dc) == 1:
            select_cols = ["'%s' COL_DT" % (dc[0],), dc[0] + ' DT']
        else:
            select_cols = list(dc)
        select_cols = [c for c in select_cols if c.strip()]
        # Use the `t` argument (not the enclosing loop variable) so the hint
        # and FROM clause always name the table this query was built for.
        return ('select /*+ PARALLEL(%(table)s,12) */\n'
                '%(ids)s %(cols)s\nfrom %(table)s'
                % dict(ids=mkids(t),
                       cols=',\n '.join(select_cols),
                       table=t))

    sql_st = []
    for table, cols in tables.items():
        date_cols = [col for (col, ctype) in cols if ctype == 'DATE']
        if not date_cols:
            continue
        if len(date_cols) == 1:
            sql = mk_inq(table, date_cols)
        else:
            upcols = ',\n '.join(["%(col)s as '%(col)s'" % dict(col=c)
                                  for c in date_cols])
            sql = ('with dates as (\n%(inq)s\n )\n'
                   'select * from dates\nunpivot exclude nulls(\n'
                   ' dt for col_date in (\n %(upcols)s\n))' %
                   dict(inq=mk_inq(table, date_cols), upcols=upcols))
        sql_st.append(sql)

    print('\n\n'.join(['insert /*+ APPEND */ into date_events\n' +
                       s + ';\ncommit;'
                       for s in sql_st]))
def _shifted_date_sql(col, is_maxdata, is_dob):
    '''Return the select-list expression that date-shifts column `col`.

    The result starts and ends with one space so the caller can append the
    output column alias directly.  For maxdata tables the shift and birth
    date are coalesced from the bm and mp aliases; otherwise only bm is
    used.  When `is_dob` is true the shift is wrapped in a CASE that returns
    the mapped birth date unchanged when its year is 1900.
    '''
    if is_maxdata:
        shift_days = 'coalesce(bm.date_shift_days, mp.date_shift_days)'
        birth_date = 'coalesce(bm.birth_date,mp.birth_date)'
        birth_year = ('coalesce('
                      'extract(year from bm.birth_date), '
                      'extract(year from mp.birth_date))')
    else:
        shift_days = 'bm.date_shift_days'
        birth_date = 'bm.birth_date'
        birth_year = 'extract(year from bm.birth_date)'
    shifted = 'idt.%s + %s' % (col, shift_days)
    if not is_dob:
        return ' %s ' % (shifted,)
    return (' case\n'
            ' when %s = 1900\n'
            ' then %s\n'
            ' else %s\n'
            ' end ') % (birth_year, birth_date, shifted)


def _deid_column_sql(table, col, ctype, date_skip_cols, hipaa_age_limit,
                     dob_cols):
    '''Return the select-list entry (without trailing comma) for one column.'''
    if ctype == 'DATE' and col not in date_skip_cols:
        return _shifted_date_sql(col, table.startswith('maxdata'),
                                 col in dob_cols) + col
    if col == 'BENE_ID':
        return ' bm.BENE_ID_DEID %(col)s' % dict(col=col)
    if col == 'MSIS_ID':
        return ' mm.MSIS_ID_DEID %(col)s' % dict(col=col)
    if (('ZIP' in col and 'PRVDR' not in col) or
            'COUNTY' in col or
            'CNTY' in col):
        # Non-provider ZIP and any county column is blanked out entirely.
        return ' NULL %(col)s' % dict(col=col)
    if col == 'BENE_AGE_AT_END_REF_YR':
        note = (' -- If the age at end of reference year is '
                'null then leave it as-is. If we\'ve\n'
                ' -- found based on some date span (such as '
                'extract date and date of birth)\n'
                ' -- that the age appears to be > 89 then shift '
                'the age by the same amount we\n'
                ' -- moved the date of birth. If that still '
                'results in an apparent age > 89\n'
                ' -- (presumably due to noisy data), then cap the '
                'age at 89.\n')
        return note + (
            ' case\n'
            ' when idt.%(col)s is null then null\n'
            ' when extract(year from bm.birth_date) = 1900 or'
            ' idt.%(col)s > %(hipaa_age_limit)s then %(hipaa_age_limit)s\n'
            ' else idt.%(col)s\n'
            ' end %(col)s') % dict(
                col=col, hipaa_age_limit=hipaa_age_limit)
    if col == 'BENE_AGE_CNT':
        return (' case\n'
                ' when idt.%(col)s is null then null\n'
                ' when idt.%(col)s + round(months_between('
                'idt.EXTRACT_DT, idt.ADMSN_DT)/12) > '
                '%(hipaa_age_limit)s then %(hipaa_age_limit)s\n'
                ' else idt.%(col)s\n'
                ' end %(col)s') % dict(
                    col=col, hipaa_age_limit=hipaa_age_limit)
    return ' idt.%(col)s' % dict(col=col)


def cms_deid_sql(tables, tdesc, date_skip_cols=('EXTRACT_DT',),
                 hipaa_age_limit=89, dob_cols=None):
    '''Print one "insert ... select" de-identification statement per table.

    tables -- mapping of table name to a sequence of (column, type) pairs,
              as returned by tables_columns
    tdesc -- mapping of table name to {column: description}; a description,
             when present, is appended to the column's line as a SQL comment
    date_skip_cols -- DATE columns that are copied through without shifting
    hipaa_age_limit -- value the age columns are capped at
    dob_cols -- DATE columns treated as dates of birth; defaults to the
                module-level DOB_COLS

    Per column the generated select list:
      * shifts DATE columns by the mapped date_shift_days (dates of birth
        whose mapped birth year is 1900 use the mapped birth date instead),
      * replaces BENE_ID / MSIS_ID with their *_DEID mapping values,
      * nulls non-provider ZIP columns and county columns,
      * caps BENE_AGE_AT_END_REF_YR and BENE_AGE_CNT at hipaa_age_limit,
      * copies every other column unchanged.

    Each statement is followed by a commit.  Output goes to stdout; the last
    statement is stripped of surrounding whitespace before printing.
    '''
    if dob_cols is None:
        dob_cols = DOB_COLS
    last_table = len(tables) - 1
    for t_idx, (table, cols) in enumerate(tables.items()):
        # Tolerate tables that have no entry in the description mapping.
        descriptions = tdesc.get(table, {})
        parts = ['insert /*+ APPEND */ into "&&deid_schema".%(table)s\n'
                 'select /*+ PARALLEL(%(table)s,12) */ \n' %
                 dict(table=table)]
        last_col = len(cols) - 1
        for idx, (col, ctype) in enumerate(cols):
            parts.append(_deid_column_sql(table, col, ctype, date_skip_cols,
                                          hipaa_age_limit, dob_cols))
            if idx != last_col:
                parts.append(',')
            if col in descriptions:
                parts.append(' -- %s\n' % (descriptions[col],))
            else:
                parts.append('\n')
        if table.startswith('maxdata'):
            parts.append(('from %(table)s idt \n'
                          'left join bene_id_mapping bm '
                          'on bm.bene_id = idt.bene_id\n'
                          'join msis_id_mapping mm '
                          'on mm.msis_id = idt.msis_id\n'
                          'join msis_person mp on mp.msis_id = idt.msis_id '
                          'and mp.state_cd = idt.state_cd;\n'
                          'commit;\n\n') % dict(table=table))
        else:
            parts.append(('from %(table)s idt \n'
                          'left join bene_id_mapping bm '
                          'on bm.bene_id = idt.bene_id;\n'
                          'commit;\n\n') % dict(table=table))
        sql = ''.join(parts)
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(sql.strip() if t_idx == last_table else sql)
def tables_columns(sql):
    '''Map each table in Oracle "create table" DDL to its column definitions.

    Only statements of the exact form ``create table <name> ( ... );`` are
    recognised, where <name> consists of letters and underscores.  Each
    non-blank line between the parentheses is taken as one column: its
    trailing comma is dropped and the rest is split on whitespace.

    Returns an OrderedDict (in order of appearance) of
    table name -> list of [column, type] token lists.

    >>> sql = """
    ... create table outpatient_value_codes (
    ... BENE_ID VARCHAR2(15),
    ... CLM_VAL_AMT NUMBER
    ... );
    ...
    ... create table outpatient_base_claims (
    ... BENE_ID VARCHAR2(15),
    ... CLM_ID VARCHAR2(15)
    ... );
    ... """
    >>> list(tables_columns(sql).items())
    ... # doctest: +NORMALIZE_WHITESPACE
    [('outpatient_value_codes', [['BENE_ID', 'VARCHAR2(15)'],
                                 ['CLM_VAL_AMT', 'NUMBER']]),
     ('outpatient_base_claims', [['BENE_ID', 'VARCHAR2(15)'],
                                 ['CLM_ID', 'VARCHAR2(15)']])]
    '''
    td = OrderedDict()
    # Raw string: '\(' in a plain literal is an invalid escape sequence.
    pattern = r'create table ([a-zA-Z_]+) \((.*?)\);'
    for table, columns in findall(pattern, sql, flags=DOTALL):
        # Strip only the separator comma at the end of the line so commas
        # inside a type such as NUMBER(10,2) survive.
        definitions = [line.strip().rstrip(',').strip()
                       for line in columns.split('\n')]
        # split() with no argument collapses runs of spaces/tabs, so aligned
        # DDL does not produce empty tokens.
        td[table] = [d.split() for d in definitions if d]
    return td
if __name__ == '__main__':
    def _tcb():
        '''Bind command-line arguments and file access, then run main().

        The three positional arguments are read from sys.argv lazily, each
        time the corresponding accessor is called.  File access is limited
        to paths that contain one of the two file arguments.
        '''
        from sys import argv, stderr

        def argv_getter(position):
            # Look the argument up at call time, not at definition time.
            return lambda: argv[position]

        get_input_path = argv_getter(1)
        get_col_desc_file = argv_getter(2)
        get_mode = argv_getter(3)

        def print_stderr(s):
            stderr.write(s)

        @contextmanager
        def open_rd_argv(path):
            '''Open `path` for binary reading if it was named on the
            command line; raise RuntimeError otherwise.'''
            allowed = (get_input_path() in path or
                       get_col_desc_file() in path)
            if not allowed:
                raise RuntimeError('%s not in %s' % (get_input_path(), path))
            with open(path, 'rb') as fin:
                yield fin

        main(open_rd_argv, get_input_path, get_col_desc_file, get_mode,
             print_stderr)

    _tcb()