Commit 209aa47

chore: work in progress
1 parent e9df3b7 · commit 209aa47

File tree

4 files changed: +624 −77 lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion

@@ -403,7 +403,6 @@ Available scripts:
 - `spark-tests` - run the Spark test suite.
 - `coverage` or `cov` - run the test suite with coverage.
 """
-path = ".venv"
 python = "3.12"
 template = "default"
 features = [

src/koheesio/spark/transformations/camel_to_snake.py

Lines changed: 22 additions & 76 deletions

@@ -2,89 +2,35 @@
 Class for converting DataFrame column names from camel case to snake case.
 """
 
-from typing import Optional
 import re
 
-from koheesio.models import Field, ListOfColumns
-from koheesio.spark.transformations import ColumnsTransformation
-from koheesio.spark.utils import SPARK_MINOR_VERSION
+from collections.abc import Callable
 
-camel_to_snake_re = re.compile("([a-z0-9])([A-Z])")
+from pydantic import Field
 
+from koheesio.spark.transformations.rename_columns import RenameColumns
 
-def convert_camel_to_snake(name: str) -> str:
-    """
-    Converts a string from camelCase to snake_case.
-
-    Parameters
-    ----------
-    name : str
-        The string to be converted.
-
-    Returns
-    -------
-    str
-        The converted string in snake_case.
-    """
-    return camel_to_snake_re.sub(r"\1_\2", name).lower()
-
-
-class CamelToSnakeTransformation(ColumnsTransformation):
-    """
-    Converts column names from camel case to snake case.
-
-    Parameters
-    ----------
-    columns : Optional[ListOfColumns], optional, default=None
-        The column or columns to convert. If no columns are specified, all columns
-        will be converted. A list of columns or a single column can be specified.
-        For example: `["column1", "column2"]` or `"column1"`
-
-    Example
-    -------
-    __input_df:__
 
-    | camelCaseColumn | snake_case_column |
-    |-----------------|-------------------|
-    | ...             | ...               |
 
-    ```python
-    output_df = CamelToSnakeTransformation(
-        column="camelCaseColumn"
-    ).transform(input_df)
-    ```
+def camel_to_snake(name: str) -> str:
+    """Convert a camelCase string to snake_case.
 
-    __output_df:__
-
-    | camel_case_column | snake_case_column |
-    |-------------------|-------------------|
-    | ...               | ...               |
-
-    In this example, the column `camelCaseColumn` is converted to `camel_case_column`.
-
-    > Note: the data in the columns is not changed, only the column names.
+    Args:
+        name: The camelCase string to be converted.
 
+    Returns:
+        str: The converted snake_case string.
     """
-
-    def execute(self) -> ColumnsTransformation.Output:
-        _df = self.df
-
-        # Prepare columns input:
-        columns = list(self.get_columns())
-
-        if SPARK_MINOR_VERSION < 3.4:
-            for column in columns:
-                _df = _df.withColumnRenamed(column, convert_camel_to_snake(column))
-
-        else:
-            # Rename columns using toDF for Spark versions >= 3.4
-            # Note: toDF requires all column names to be specified
-            new_column_names = []
-            for column in _df.columns:
-                if column in columns:
-                    new_column_names.append(convert_camel_to_snake(column))
-                    continue
-                new_column_names.append(column)
-            _df = _df.toDF(*new_column_names)
-
-        self.output.df = _df
+    # Insert an underscore before an uppercase letter that is followed by
+    # lowercase letters (splits boundaries such as "HTTPResponse" -> "HTTP_Response")
+    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+    # Insert an underscore between a lowercase letter or digit and an
+    # uppercase letter, then convert the result to lowercase
+    s2 = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
+    # Collapse any runs of consecutive underscores into a single one
+    res = re.sub("_+", "_", s2)
+    return res
+
+
+class CamelToSnakeTransformation(RenameColumns):
+    rename_func: Callable[[str], str] | None = Field(  # type: ignore
+        default=camel_to_snake, description="Function to convert camelCase to snake_case"
+    )
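
As a quick sanity check of the new `camel_to_snake` helper, the sketch below exercises the three regex passes on a few representative names. The function body is copied verbatim from the diff above; the expected values are worked out by hand, so treat them as illustrative rather than as committed test cases.

```python
import re


def camel_to_snake(name: str) -> str:
    # Verbatim copy of the helper added in this commit
    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    s2 = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
    return re.sub("_+", "_", s2)


assert camel_to_snake("camelCaseColumn") == "camel_case_column"
assert camel_to_snake("HTTPResponse") == "http_response"   # first pass splits the acronym boundary
assert camel_to_snake("columnA") == "column_a"             # second pass handles a trailing capital
assert camel_to_snake("already_snake") == "already_snake"  # snake_case input passes through unchanged
```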
src/koheesio/spark/transformations/rename_columns.py

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+from typing import Callable
+
+from pydantic import Field
+
+from pyspark.sql import functions as F
+from pyspark.sql.types import ArrayType, StructField, StructType
+
+from koheesio.spark import DataFrame
+from koheesio.spark.transformations import ColumnsTransformation
+
+
+class RenameColumns(ColumnsTransformation):
+    rename_func: Callable[[str], str] = Field(..., description="Function to rename columns")
+
+    def rename_schema(self, schema: StructType) -> StructType:
+        """Rename the fields of a given schema using the configured renaming function.
+
+        Args:
+            schema: The schema whose fields need to be renamed.
+
+        Returns:
+            StructType: A new schema with renamed fields.
+
+        Notes:
+            - A ValueError is raised if no renaming function is provided.
+            - Nested StructTypes, and ArrayTypes containing StructTypes, are handled recursively.
+
+        Steps:
+            1. Initialize an empty list to hold the new fields.
+            2. Raise a ValueError if no renaming function is provided.
+            3. Iterate over each field in the provided schema.
+            4. Rename each selected field using the renaming function.
+            5. If the field's data type is a StructType, recursively rename the nested schema.
+            6. If the field's data type is an ArrayType containing a StructType, recursively
+               rename the schema of the nested StructType within the ArrayType.
+            7. For other data types, simply rename the field.
+            8. Append the newly created field to the list of new fields.
+            9. Return a new StructType constructed from the list of new fields.
+        """
+        new_fields = []
+        _columns = list(self.get_columns())
+
+        if self.rename_func is None:
+            raise ValueError("rename_func must be provided")
+
+        for field in schema.fields:
+            if field.name in _columns:
+                new_name = self.rename_func(field.name)
+
+                if isinstance(field.dataType, StructType):
+                    new_field = StructField(new_name, self.rename_schema(field.dataType), field.nullable)
+                elif isinstance(field.dataType, ArrayType) and isinstance(field.dataType.elementType, StructType):
+                    new_field = StructField(
+                        new_name, ArrayType(self.rename_schema(field.dataType.elementType)), field.nullable
+                    )
+                else:
+                    new_field = StructField(new_name, field.dataType, field.nullable)
+                new_fields.append(new_field)
+            else:
+                # If the field is not in the columns to be renamed, keep it as-is
+                new_fields.append(field)
+
+        return StructType(new_fields)
+
+    def execute(self):
+        self.df: DataFrame
+        new_schema = self.rename_schema(self.df.schema)  # pylint: disable=E1102
+        _columns = list(self.get_columns())
+        _not_renamed = set(self.df.columns) - set(_columns)
+        renamed_select = [
+            F.col(c).cast(new_schema[self.rename_func(c)].dataType).alias(self.rename_func(c))  # pylint: disable=E1102
+            for c in _columns
+        ]
+        not_renamed_select = [F.col(c).alias(c) for c in _not_renamed]
+
+        # Apply the new schema by casting each column to the new type under its new name
+        self.output.df = self.df.select(renamed_select + not_renamed_select)

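A minimal usage sketch of the new `RenameColumns` transformation, following the `.transform(input_df)` pattern from the docstring that was removed from `camel_to_snake.py`. The DataFrame, column names, and the uppercasing `rename_func` are hypothetical; this assumes a live SparkSession and that `columns` accepts a list, as the old docstring describes.

```python
from pyspark.sql import SparkSession

from koheesio.spark.transformations.rename_columns import RenameColumns

spark = SparkSession.builder.getOrCreate()

# Hypothetical input: two camelCase columns, only one of which is renamed
input_df = spark.createDataFrame([(1, "EU"), (2, "US")], ["customerId", "regionCode"])

output_df = RenameColumns(
    columns=["customerId"],                  # only this column is renamed
    rename_func=lambda name: name.upper(),   # any str -> str callable
).transform(input_df)

# `customerId` becomes `CUSTOMERID`; `regionCode` is kept as-is. Note that
# execute() selects renamed columns first and derives the untouched columns
# from a Python set, so the order of untouched columns is not guaranteed.
print(output_df.columns)
```

Since `CamelToSnakeTransformation` is now just `RenameColumns` with `rename_func` defaulting to `camel_to_snake`, the same pattern without a `rename_func` argument reproduces the old camel-to-snake behavior.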