Skip to content

Commit

Permalink
Update schema_config.py
Browse files Browse the repository at this point in the history
[FIX] `surrounding` cases
  • Loading branch information
jzsmoreno committed Apr 10, 2024
1 parent 44e46cd commit b877c52
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 25 deletions.
2 changes: 1 addition & 1 deletion merge_by_lev/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.8
0.2.9
49 changes: 25 additions & 24 deletions merge_by_lev/schema_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json

import pandas as pd
import pyarrow
import pyarrow as pa
import yaml
Expand All @@ -8,7 +9,6 @@
from pydbsmgr.main import *
from pydbsmgr.main import DataFrame
from pydbsmgr.utils.azure_sdk import *
import pandas as pd


class StandardColumns:
Expand Down Expand Up @@ -77,8 +77,9 @@ def _sql_standards(
"""
df = (self.df).copy()
if surrounding:
df.columns = [col[1:-1] for col in df.columns]
df.columns = [
col[1:-1] if col.startswith("[") and col.endswith("]") else col for col in df.columns
]
df.columns = df.columns.str.lower()
df.columns = df.columns.str.replace("_+", " ", regex=True)
df.columns = df.columns.str.title()
Expand All @@ -87,19 +88,13 @@ def _sql_standards(
df.columns = df.columns.str.replace("\n", "_")
if snake_case:
df.columns = [self._camel_to_snake(col) for col in df.columns]
df.columns = [self._truncate(col) for col in df.columns]
df.columns = df.columns.str[:128]
if sort:
df = self._sort_columns_by_length(df)
if surrounding:
df.columns = [f"[{col}]" for col in df.columns]
return df

def _truncate(self, column_name: str) -> str:
if len(column_name) >= 128:
return column_name[:128]
else:
return column_name

def _sort_columns_by_length(self, dataframe: DataFrame) -> DataFrame:
# Get the column names and sort them by length
sorted_columns = sorted(dataframe.columns, key=len, reverse=True)
Expand Down Expand Up @@ -188,9 +183,12 @@ def create_yaml(
container_name (`str`, optional): name of the container inside the storage account. By default it is set to `""`.
overwrite (`bool`, optional): boolean variable indicating whether the file is overwritten or not. By default it is set to `True`.
"""
self.df.columns = [
c.replace(" ", "_") for c in list(self.df.columns)
] # Remove spaces from column names
df_info = self._create_info()
df_info["data type"] = [str(_type) for _type in df_info["data type"].to_list()]
df_info["sql name"] = [col_name.replace(" ", "_") for col_name in df_info["column name"]]
df_info["sql name"] = df_info["column name"]

data = {}
for col_name, data_type, sql_name in zip(
Expand Down Expand Up @@ -236,7 +234,7 @@ def _create_info(self) -> DataFrame:
return info


def recursive_correction(df_: DataFrame, input_string: str) -> Table:
def recursive_correction(df_: DataFrame, input_string: str, unchanged_names: List[str]) -> Table:
pattern = r"Conversion failed for column (\w+) with type"
match = re.search(pattern, input_string)
column_name = match.group(1)
Expand All @@ -245,17 +243,20 @@ def recursive_correction(df_: DataFrame, input_string: str) -> Table:
df_[column_name] = df_[column_name].astype(str)
try:
data_handler = DataSchema(df_)
return pa.Table.from_pandas(df_, schema=data_handler.get_schema())
return (pa.Table.from_pandas(df_, schema=data_handler.get_schema())).rename_columns(
unchanged_names
)
except (pa.lib.ArrowTypeError, pyarrow.lib.ArrowInvalid) as e:
iteration_match = re.search(pattern, (str(e).split(","))[-1])
iteration_column_name = iteration_match.group(1)
if column_name != iteration_column_name:
return recursive_correction(df_, (str(e).split(","))[-1])
return recursive_correction(df_, (str(e).split(","))[-1], unchanged_names)


class DataSchema(DataFrameToYaml):
def __init__(self, df: DataFrame):
super().__init__(df)
self.cols = df.columns.to_list()

def get_schema(
self,
Expand All @@ -266,7 +267,6 @@ def get_schema(
connection_string: str = "",
container_name: str = "",
overwrite: bool = True,
preserve_order: bool = True,
):
if format_type == "yaml":
return self.create_yaml(
Expand All @@ -279,6 +279,7 @@ def get_schema(
with open(yaml_name, "r") as file:
data = yaml.safe_load(file)
fields = []

for col_name, col_info in data["database"].items():
col_type = col_info["type"][0]
if col_type == "int64":
Expand Down Expand Up @@ -320,11 +321,10 @@ def get_schema(
if os.path.exists(schema_file_path):
os.remove(schema_file_path)

if preserve_order:
try:
schema = pa.Schema.from_pandas(self.df)
except:
None
try:
schema = pa.Schema.from_pandas(self.df)
except:
None
self.schema = schema

return schema
Expand All @@ -341,22 +341,23 @@ def get_table(self) -> Table:
msg = "It was not possible to create the table\n"
msg += "Error: {%s}" % e
print(f"{warning_type}: {msg}")
return recursive_correction(self.df, (str(e).split(","))[-1])
return table
return recursive_correction(self.df, (str(e).split(","))[-1], self.cols)

return table.rename_columns(self.cols)


if __name__ == "__main__":
# Create a DataFrame
data = {
"Name": ["Dani", "John", "Alice", "Bob"],
"Age": ["32", 25, 30, 35],
"Points": ["0", 1, 2, 3],
"Points value": ["0", 1, 2, 3],
}
df = pd.DataFrame(data)
table_name = "test_table"
data_handler = DataSchema(df)
schema = data_handler.get_schema()
table = data_handler.get_table()
column_handler = StandardColumns(df)
df = column_handler.get_frame(surrounding=False)
df = column_handler.get_frame(surrounding=False, snake_case=False)
breakpoint()

0 comments on commit b877c52

Please sign in to comment.