Skip to content

Commit

Permalink
Fix snowflake script (#1128)
Browse files Browse the repository at this point in the history
* add snowflake support for script

* add test file
  • Loading branch information
pgrivachev authored Aug 10, 2023
1 parent b6c2e05 commit b186df0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
6 changes: 6 additions & 0 deletions e2e/projects/snowflake/models/union_relations.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- TODO: add test
select * from
{{ dbt_utils.union_relations(
relations=[ref('join_tables'), ref('join_tables')]
)
}}
24 changes: 9 additions & 15 deletions server/python/script.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import http
import os
import sys
from typing import List
from dbt.adapters.factory import FACTORY
from dbt.contracts.connection import AdapterRequiredConfig
from dbt.adapters.bigquery import BigQueryColumn
from dbt.adapters.bigquery.relation import BigQueryRelation
import google.cloud.exceptions
from dbt.adapters.base import Column as BaseColumn
from dbt.adapters.base.relation import BaseRelation
import json

# Expected arguments for this script:
Expand Down Expand Up @@ -34,7 +32,9 @@ def dbt_command(cli_args) -> None:
dbt_command(["--version"])
else:
port = sys.argv[1]
def new_get_columns_in_relation(self, relation: BigQueryRelation) -> List[BigQueryColumn]:
global_old_get_columns_in_relation = None
def new_get_columns_in_relation(self, relation: BaseRelation) -> List[BaseColumn]:
global global_old_get_columns_in_relation
db = relation.path.database if relation.path.database != None else "None"
schema = relation.path.schema if relation.path.schema != None else "None"
table = relation.path.identifier if relation.path.identifier != None else "None"
Expand All @@ -46,26 +46,20 @@ def new_get_columns_in_relation(self, relation: BigQueryRelation) -> List[BigQue
result = response.read().decode();
if result not in ["SOURCE", "NOT_FOUND"]:
data = json.loads(result)
bigquery_columns = [BigQueryColumn(column[0], BigQueryColumn.translate_type(column[1])) for column in data]

bigquery_columns = [self.Column.create(column[0], column[1]) for column in data]
return bigquery_columns

try:
table_from_bq = self.connections.get_bq_table(
database=relation.database, schema=relation.schema, identifier=relation.identifier
)
return self._get_dbt_columns_from_bq_table(table_from_bq)

except (ValueError, google.cloud.exceptions.NotFound) as e:
return []
return global_old_get_columns_in_relation(relation)

old_register_adapter = FACTORY.register_adapter.__get__(FACTORY)

def new_register_adapter(self, config: AdapterRequiredConfig) -> None:
global global_old_get_columns_in_relation
old_register_adapter(config)
credentials_type = config.credentials.type
if credentials_type in ["bigquery", "snowflake"]:
adapter = self.adapters[credentials_type]
global_old_get_columns_in_relation = adapter.get_columns_in_relation
adapter.get_columns_in_relation = new_get_columns_in_relation.__get__(adapter)

FACTORY.register_adapter = new_register_adapter.__get__(FACTORY)
Expand Down

0 comments on commit b186df0

Please sign in to comment.