Skip to content

Commit

Permalink
Change the string append logic to use StringIO
Browse files Browse the repository at this point in the history
Change the string append logic to use StringIO
  • Loading branch information
[email protected] committed May 31, 2023
1 parent a797c91 commit 4e5fe0c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 35 deletions.
19 changes: 14 additions & 5 deletions src/SimpleReplay/audit_logs_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This module parses various auditlogs
"""

from io import StringIO
logger = None


Expand All @@ -16,7 +16,16 @@ def __init__(self):
self.database_name = ""
self.pid = ""
self.xid = ""
self.text = ""
self.text = StringIO()

def clear_and_set_text(self, new_value):
# Better to create a new instance, rather than truncate and seek - because it’s faster
self.text.close()
self.text = StringIO()
self.text.write(new_value)

def append_text(self, value):
self.text.write(value)

def get_filename(self):
base_name = (
Expand Down Expand Up @@ -44,7 +53,7 @@ def __str__(self):
self.database_name,
self.pid,
self.xid,
self.text,
self.text.getvalue(),
)
)

Expand All @@ -58,11 +67,11 @@ def __eq__(self, other):
and self.database_name == other.database_name
and self.pid == other.pid
and self.xid == other.xid
and self.text == other.text
and self.text.getvalue() == other.text.getvalue()
)

def __hash__(self):
return hash((str(self.pid), str(self.xid), self.text.strip("\n")))
return hash((str(self.pid), str(self.xid), self.text.getvalue().strip("\n")))


class ConnectionLog:
Expand Down
20 changes: 10 additions & 10 deletions src/SimpleReplay/extract/extractor/extract_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def _parse_user_activity_log(file, logs, databases, start_time, end_time):
if filename in logs:
# Check if duplicate. This happens with JDBC connections.
prev_query = logs[filename][-1]
if not is_duplicate(prev_query.text, user_activity_log.text):
if not is_duplicate(prev_query.text.getvalue(), user_activity_log.text.getvalue()):
if fetch_pattern.search(
prev_query.text
) and fetch_pattern.search(user_activity_log.text):
user_activity_log.text = f"--{user_activity_log.text}"
prev_query.text.getvalue()
) and fetch_pattern.search(user_activity_log.text.getvalue()):
user_activity_log.clear_and_set_text(f"--{user_activity_log.text.getvalue()}")
logs[filename].append(user_activity_log)
else:
logs[filename].append(user_activity_log)
Expand All @@ -87,9 +87,9 @@ def _parse_user_activity_log(file, logs, databases, start_time, end_time):
user_activity_log.database_name = query_information[3][3:]
user_activity_log.pid = query_information[5][4:]
user_activity_log.xid = query_information[7][4:]
user_activity_log.text = line_split[1]
user_activity_log.clear_and_set_text(line_split[1])
else:
user_activity_log.text += line
user_activity_log.append_text(line)


def _parse_start_node_log(file, logs, databases, start_time, end_time):
Expand All @@ -107,7 +107,7 @@ def _parse_start_node_log(file, logs, databases, start_time, end_time):
if filename in logs:
# Check if duplicate. This happens with JDBC connections.
prev_query = logs[filename][-1]
if not is_duplicate(prev_query.text, start_node_log.text):
if not is_duplicate(prev_query.text.getvalue(), start_node_log.text.getvalue()):
logs[filename].append(start_node_log)
else:
logs[filename] = [start_node_log]
Expand All @@ -132,14 +132,14 @@ def _parse_start_node_log(file, logs, databases, start_time, end_time):
start_node_log.username = query_information[4][3:].split(":")[0]
start_node_log.pid = query_information[5][4:]
start_node_log.xid = query_information[7][4:]
start_node_log.text = line_split[1].strip()
start_node_log.clear_and_set_text(line_split[1].strip())
else:
start_node_log.text += line
start_node_log.append_text(line)


def _parse_connection_log(file, connections, last_connections, start_time, end_time):
for line in file.readlines():

line = line.decode("utf-8")

connection_information = line.split("|")
Expand Down
30 changes: 15 additions & 15 deletions src/SimpleReplay/extract/extractor/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,33 +200,33 @@ def get_sql_connections_replacements(self, last_connections, log_items):
)
continue

query.text = remove_line_comments(query.text).strip()
query.clear_and_set_text(remove_line_comments(query.text.getvalue()).strip())

if "copy " in query.text.lower() and "from 's3:" in query.text.lower():
if "copy " in query.text.getvalue().lower() and "from 's3:" in query.text.getvalue().lower():
bucket = re.search(
r"from 's3:\/\/[^']*", query.text, re.IGNORECASE
r"from 's3:\/\/[^']*", query.text.getvalue(), re.IGNORECASE
).group()[6:]
replacements.add(bucket)
query.text = re.sub(
query.clear_and_set_text(re.sub(
r"IAM_ROLE 'arn:aws:iam::\d+:role/\S+'",
f" IAM_ROLE ''",
query.text,
query.text.getvalue(),
flags=re.IGNORECASE,
)
if "unload" in query.text.lower() and "to 's3:" in query.text.lower():
query.text = re.sub(
))
if "unload" in query.text.getvalue().lower() and "to 's3:" in query.text.getvalue().lower():
query.clear_and_set_text(re.sub(
r"IAM_ROLE 'arn:aws:iam::\d+:role/\S+'",
f" IAM_ROLE ''",
query.text,
query.text.getvalue(),
flags=re.IGNORECASE,
)
))

query.text = f"{query.text.strip()}"
if not len(query.text) == 0:
if not query.text.endswith(";"):
query.text += ";"
query.clear_and_set_text(f"{query.text.getvalue().strip()}")
if not len(query.text.getvalue()) == 0:
if not query.text.getvalue().endswith(";"):
query.append_text(";")

query_info["text"] = query.text
query_info["text"] = query.text.getvalue()
sql_json["transactions"][query.xid]["queries"].append(query_info)

if not hash((query.database_name, query.username, query.pid)) in last_connections:
Expand Down
10 changes: 5 additions & 5 deletions src/SimpleReplay/log_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ def is_valid_log(log, start_time, end_time):
if end_time and log.record_time > end_time:
return False

if any(word in log.text for word in problem_keywords):
if any(word in log.text.getvalue() for word in problem_keywords):
return False

if any(word in log.text for word in potential_problem_keywords) and not any(word in log.text for word in not_problem_keywords):
if any(word in log.text.getvalue() for word in potential_problem_keywords) and not any(word in log.text.getvalue() for word in not_problem_keywords):
return False

# filter internal statement rewrites with parameter markers
if re.search('\$\d',log.text):
if re.search('\$\d',log.text.getvalue()):
# remove \$\d in string literals ( select '$1' ) or comment blocks ( */ $1 */ )
text_without_valid_parameter_markers = re.sub("""'.*\\$\\d.*'|\\/\\*.*\\$\\d.*\\*\\/""",'',log.text,flags=re.DOTALL)
text_without_valid_parameter_markers = re.sub("""'.*\\$\\d.*'|\\/\\*.*\\$\\d.*\\*\\/""",'',log.text.getvalue(),flags=re.DOTALL)
# remove \$\d in single line quotes ( -- $1 )
if '--' in log.text:
if '--' in log.text.getvalue():
text_without_valid_parameter_markers = re.sub('^\s*--.*\$\d','',text_without_valid_parameter_markers)
# if there are still parameter markers remaining after removing them from valid cases, the query text is invalid
if re.search('\$\d',text_without_valid_parameter_markers):
Expand Down

0 comments on commit 4e5fe0c

Please sign in to comment.