Skip to content

Commit

Permalink
[FEATURE] Populate account from url if not provided in SnowflakeBaseM…
Browse files Browse the repository at this point in the history
…odel (#117)
  • Loading branch information
mikita-sakalouski authored Nov 24, 2024
1 parent c34abbe commit 7fb1b27
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
13 changes: 12 additions & 1 deletion src/koheesio/integrations/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from abc import ABC
from contextlib import contextmanager
from types import ModuleType
from urllib.parse import urlparse

from koheesio import Step
from koheesio.logger import warn
Expand Down Expand Up @@ -281,7 +282,7 @@ class SnowflakeRunQueryPython(SnowflakeStep):
"""

query: str = Field(default=..., description="The query to run", alias="sql", serialization_alias="query")
account: str = Field(default=..., description="Snowflake Account Name", alias="account")
account: Optional[str] = Field(default=None, description="Snowflake Account Name", alias="account")

# for internal use
_snowflake_connector: Optional[ModuleType] = PrivateAttr(default_factory=safe_import_snowflake_connector)
Expand All @@ -291,6 +292,16 @@ class Output(SnowflakeStep.Output):

results: List = Field(default_factory=list, description="The results of the query")

@model_validator(mode="before")
@classmethod
def _validate_account(cls, values: Dict) -> Dict:
    """Populate ``account`` from the Snowflake URL when it is not provided.

    The account is taken as the first dot-separated label of the URL's host,
    e.g. ``https://host2.host1.snowflakecomputing.com`` -> ``host2`` and
    ``hostname.com`` -> ``hostname``.

    Parameters
    ----------
    values : Dict
        Raw, not-yet-validated input passed to the model.

    Returns
    -------
    Dict
        The input with ``account`` filled in when it could be derived; the
        caller's dict is never modified in place. Input that already has an
        account, is not a dict, or carries no usable URL is returned as-is so
        that regular field validation can report the problem.
    """
    # A "before" validator receives raw input: anything that is not a dict is
    # left for pydantic itself to reject.
    if not isinstance(values, dict) or values.get("account"):
        return values

    # Raw input is keyed by whatever the caller used, so the URL may sit under
    # the field name or under its alias.
    # NOTE(review): "sfURL" is the alias exercised by the SnowflakeBaseModel
    # tests; presumably it applies to this model as well - confirm.
    url = values.get("url") or values.get("sfURL")
    if not url or not isinstance(url, str):
        # Do not raise KeyError/TypeError here; pydantic does not turn those
        # into a ValidationError. The missing/invalid url is reported later.
        return values

    parsed_url = urlparse(url)
    # Without a scheme ("hostname.com") urlparse yields no hostname and puts
    # the whole string in ``path``.
    base_url = parsed_url.hostname or parsed_url.path

    # Return a new dict so the caller's input mapping is not mutated.
    return {**values, "account": base_url.split(".")[0]}

@field_validator("query")
def validate_query(cls, query: str) -> str:
"""Replace escape characters, strip whitespace, ensure it is not empty"""
Expand Down
32 changes: 23 additions & 9 deletions tests/snowflake/test_snowflake.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# flake8: noqa: F811
from copy import deepcopy
from unittest import mock

from pydantic_core._pydantic_core import ValidationError
import pytest

from pydantic import ValidationError

from koheesio.integrations.snowflake import (
GrantPrivilegesOnObject,
GrantPrivilegesOnTable,
Expand All @@ -15,8 +17,10 @@
)
from koheesio.integrations.snowflake.test_utils import mock_query

mock_query = mock_query

COMMON_OPTIONS = {
"url": "url",
"url": "hostname.com",
"user": "user",
"password": "password",
"database": "db",
Expand Down Expand Up @@ -120,13 +124,23 @@ def test_get_options(self):
"password": "password",
"role": "role",
"schema": "schema",
"url": "url",
"url": "hostname.com",
"user": "user",
"warehouse": "warehouse",
}
assert actual_options == expected_options
assert query_in_options["query"] == expected_query, "query should be returned regardless of the input"

def test_account_populated_from_url(self):
    """Account defaults to the first label of a scheme-less URL."""
    # COMMON_OPTIONS carries no explicit account, so it must come from the url
    options = dict(COMMON_OPTIONS)
    step = SnowflakeRunQueryPython(**options, sql="SELECT * FROM table")
    expected_account = "hostname"
    assert step.account == expected_account

def test_account_populated_from_url2(self):
    """Account defaults to the first host label of a full https URL."""
    # work on a copy so the shared module-level options stay untouched
    options = deepcopy(COMMON_OPTIONS)
    options.update(url="https://host2.host1.snowflakecomputing.com")
    step = SnowflakeRunQueryPython(**options, sql="SELECT * FROM table")
    expected_account = "host2"
    assert step.account == expected_account

def test_execute(self, mock_query):
# Arrange
query = "SELECT * FROM two_row_table"
Expand Down Expand Up @@ -161,7 +175,7 @@ class TestSnowflakeBaseModel:
def test_get_options_using_alias(self):
"""Test that the options are correctly generated using alias"""
k = SnowflakeBaseModel(
sfURL="url",
sfURL="hostname.com",
sfUser="user",
sfPassword="password",
sfDatabase="database",
Expand All @@ -170,7 +184,7 @@ def test_get_options_using_alias(self):
schema="schema",
)
options = k.get_options() # alias should be used by default
assert options["sfURL"] == "url"
assert options["sfURL"] == "hostname.com"
assert options["sfUser"] == "user"
assert options["sfDatabase"] == "database"
assert options["sfRole"] == "role"
Expand All @@ -180,7 +194,7 @@ def test_get_options_using_alias(self):
def test_get_options(self):
"""Test that the options are correctly generated not using alias"""
k = SnowflakeBaseModel(
url="url",
url="hostname.com",
user="user",
password="password",
database="database",
Expand All @@ -189,7 +203,7 @@ def test_get_options(self):
schema="schema",
)
options = k.get_options(by_alias=False)
assert options["url"] == "url"
assert options["url"] == "hostname.com"
assert options["user"] == "user"
assert options["database"] == "database"
assert options["role"] == "role"
Expand All @@ -203,7 +217,7 @@ def test_get_options(self):
def test_get_options_include(self):
"""Test that the options are correctly generated using include"""
k = SnowflakeBaseModel(
url="url",
url="hostname.com",
user="user",
password="password",
database="database",
Expand All @@ -215,7 +229,7 @@ def test_get_options_include(self):
options = k.get_options(include={"url", "user", "description", "options"}, by_alias=False)

# should be present
assert options["url"] == "url"
assert options["url"] == "hostname.com"
assert options["user"] == "user"
assert "description" in options

Expand Down

0 comments on commit 7fb1b27

Please sign in to comment.