From 83472572035b1c8231aae76a2a6d1efd973b922b Mon Sep 17 00:00:00 2001 From: Javier Asensio Date: Fri, 2 Jun 2023 11:18:15 +0100 Subject: [PATCH 1/3] Add timeout hack to mitigate timeouts --- src/databricks/sql/auth/thrift_http_client.py | 1 + src/databricks/sql/thrift_backend.py | 156 ++++++------------ 2 files changed, 48 insertions(+), 109 deletions(-) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index a924ea63..7240ff01 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -3,6 +3,7 @@ import thrift +import thrift.transport.THttpClient import urllib.parse, six, base64 diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 935c7711..d01d93e4 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -1,34 +1,27 @@ -from decimal import Decimal import errno import logging import math -import time +import os import threading -import lz4.frame +import time +from decimal import Decimal from ssl import CERT_NONE, CERT_REQUIRED, create_default_context from typing import List, Union +import databricks.sql.auth.thrift_http_client +import lz4.frame import pyarrow -import thrift.transport.THttpClient import thrift.protocol.TBinaryProtocol +import thrift.transport.THttpClient import thrift.transport.TSocket import thrift.transport.TTransport - -import databricks.sql.auth.thrift_http_client +from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes -from databricks.sql import * -from databricks.sql.thrift_api.TCLIService.TCLIService import ( - Client as TCLIServiceClient, -) - -from databricks.sql.utils import ( - ArrowQueue, - ExecuteResponse, - _bound, - RequestErrorInfo, - NoRetryReason, -) +from databricks.sql.thrift_api.TCLIService.TCLIService import \ + Client as TCLIServiceClient +from databricks.sql.utils import (ArrowQueue, ExecuteResponse, NoRetryReason, + RequestErrorInfo, _bound) logger = logging.getLogger(__name__) @@ -38,6 +31,9 @@ TIMESTAMP_AS_STRING_CONFIG = "spark.thriftserver.arrowBasedRowSet.timestampAsString" +# HACK! +THRIFT_SOCKET_TIMEOUT = os.getenv("THRIFT_SOCKET_TIMEOUT", None) + # see Connection.__init__ for parameter descriptions. # - Min/Max avoids unsustainable configs (sane values are far more constrained) # - 900s attempts-duration lines up w ODBC/JDBC drivers (for cluster startup > 10 mins) @@ -114,13 +110,9 @@ def __init__( self.staging_allowed_local_path = staging_allowed_local_path self._initialize_retry_args(kwargs) - self._use_arrow_native_complex_types = kwargs.get( - "_use_arrow_native_complex_types", True - ) + self._use_arrow_native_complex_types = kwargs.get("_use_arrow_native_complex_types", True) self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True) - self._use_arrow_native_timestamps = kwargs.get( - "_use_arrow_native_timestamps", True - ) + self._use_arrow_native_timestamps = kwargs.get("_use_arrow_native_timestamps", True) # Configure tls context ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file")) @@ -152,7 +144,10 @@ def __init__( ssl_context=ssl_context, ) - timeout = kwargs.get("_socket_timeout") + # HACK! + timeout = kwargs.get("_socket_timeout") or THRIFT_SOCKET_TIMEOUT + logger.info(f"Setting timeout HACK! to {timeout}") + # setTimeout defaults to None (i.e. 
no timeout), and is expected in ms self._transport.setTimeout(timeout and (float(timeout) * 1000.0)) @@ -175,15 +170,11 @@ def _initialize_retry_args(self, kwargs): given_or_default = type_(kwargs.get(key, default)) bound = _bound(min, max, given_or_default) setattr(self, key, bound) - logger.debug( - "retry parameter: {} given_or_default {}".format(key, given_or_default) - ) + logger.debug("retry parameter: {} given_or_default {}".format(key, given_or_default)) if bound != given_or_default: logger.warning( "Override out of policy retry parameter: " - + "{} given {}, restricted to {}".format( - key, given_or_default, bound - ) + + "{} given {}, restricted to {}".format(key, given_or_default, bound) ) # Fail on retry delay min > max; consider later adding fail on min > duration? @@ -211,9 +202,7 @@ def _extract_error_message_from_headers(headers): if THRIFT_ERROR_MESSAGE_HEADER in headers: err_msg = headers[THRIFT_ERROR_MESSAGE_HEADER] if DATABRICKS_ERROR_OR_REDIRECT_HEADER in headers: - if ( - err_msg - ): # We don't expect both to be set, but log both here just in case + if err_msg: # We don't expect both to be set, but log both here just in case err_msg = "Thriftserver error: {}, Databricks error: {}".format( err_msg, headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER] ) @@ -406,10 +395,7 @@ def _check_initial_namespace(self, catalog, schema, response): if not (catalog or schema): return - if ( - response.serverProtocolVersion - < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4 - ): + if response.serverProtocolVersion < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4: raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0." @@ -424,10 +410,7 @@ def _check_initial_namespace(self, catalog, schema, response): def _check_session_configuration(self, session_configuration): # This client expects timetampsAsString to be false, so we do not allow users to modify that - if ( - session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() - != "false" - ): + if session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() != "false": raise Error( "Invalid session configuration: {} cannot be changed " "while using the Databricks SQL connector, it must be false not {}".format( @@ -439,18 +422,14 @@ def _check_session_configuration(self, session_configuration): def open_session(self, session_configuration, catalog, schema): try: self._transport.open() - session_configuration = { - k: str(v) for (k, v) in (session_configuration or {}).items() - } + session_configuration = {k: str(v) for (k, v) in (session_configuration or {}).items()} self._check_session_configuration(session_configuration) # We want to receive proper Timestamp arrow types. # We set it also in confOverlay in TExecuteStatementReq on a per query basic, # but it doesn't hurt to also set for the whole session. 
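# Concretely, spark.thriftserver.arrowBasedRowSet.timestampAsString is pinned
# to "false" so timestamps come back as Arrow timestamp types, not strings.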
session_configuration[TIMESTAMP_AS_STRING_CONFIG] = "false" if catalog or schema: - initial_namespace = ttypes.TNamespace( - catalogName=catalog, schemaName=schema - ) + initial_namespace = ttypes.TNamespace(catalogName=catalog, schemaName=schema) else: initial_namespace = None @@ -476,9 +455,7 @@ def close_session(self, session_handle) -> None: finally: self._transport.close() - def _check_command_not_in_error_or_closed_state( - self, op_handle, get_operations_resp - ): + def _check_command_not_in_error_or_closed_state(self, op_handle, get_operations_resp): if get_operations_resp.operationState == ttypes.TOperationState.ERROR_STATE: if get_operations_resp.displayMessage: raise ServerOperationError( @@ -513,17 +490,11 @@ def _poll_for_status(self, op_handle): def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description): if t_row_set.columns is not None: - ( - arrow_table, - num_rows, - ) = ThriftBackend._convert_column_based_set_to_arrow_table( + (arrow_table, num_rows,) = ThriftBackend._convert_column_based_set_to_arrow_table( t_row_set.columns, description ) elif t_row_set.arrowBatches is not None: - ( - arrow_table, - num_rows, - ) = ThriftBackend._convert_arrow_based_set_to_arrow_table( + (arrow_table, num_rows,) = ThriftBackend._convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, schema_bytes ) else: @@ -534,9 +505,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti def _convert_decimals_in_arrow_table(table, description): for (i, col) in enumerate(table.itercolumns()): if description[i][1] == "decimal": - decimal_col = col.to_pandas().apply( - lambda v: v if v is None else Decimal(v) - ) + decimal_col = col.to_pandas().apply(lambda v: v if v is None else Decimal(v)) precision, scale = description[i][4], description[i][5] assert scale is not None assert precision is not None @@ -549,9 +518,7 @@ def _convert_decimals_in_arrow_table(table, description): return table @staticmethod - def _convert_arrow_based_set_to_arrow_table( - arrow_batches, lz4_compressed, schema_bytes - ): + def _convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes): ba = bytearray() ba += schema_bytes n_rows = 0 @@ -597,9 +564,7 @@ def _convert_column_to_arrow_array(t_col): for field in field_name_to_arrow_type.keys(): wrapper = getattr(t_col, field) if wrapper: - return ThriftBackend._create_arrow_array( - wrapper, field_name_to_arrow_type[field] - ) + return ThriftBackend._create_arrow_array(wrapper, field_name_to_arrow_type[field]) raise OperationalError("Empty TColumn instance {}".format(t_col)) @@ -654,9 +619,7 @@ def map_type(t_type_entry): else: # Current thriftserver implementation should always return a primitiveEntry, # even for complex types - raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" - ) + raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry") def convert_col(t_column_desc): return pyarrow.field( @@ -674,9 +637,7 @@ def _col_to_description(col): # Drop _TYPE suffix cleaned_type = (name[:-5] if name.endswith("_TYPE") else name).lower() else: - raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" - ) + raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry") if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: qualifiers = type_entry.primitiveEntry.typeQualifiers.qualifiers @@ -697,9 +658,7 @@ def _col_to_description(col): @staticmethod def 
_hive_schema_to_description(t_table_schema): - return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns - ] + return [ThriftBackend._col_to_description(col) for col in t_table_schema.columns] def _results_message_to_execute_response(self, resp, operation_state): if resp.directResults and resp.directResults.resultSetMetadata: @@ -726,9 +685,7 @@ def _results_message_to_execute_response(self, resp, operation_state): or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows ) - description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema - ) + description = self._hive_schema_to_description(t_result_set_metadata_resp.schema) schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) @@ -768,8 +725,7 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): op_handle, initial_operation_status_resp ) operation_state = ( - initial_operation_status_resp - and initial_operation_status_resp.operationState + initial_operation_status_resp and initial_operation_status_resp.operationState ) while not operation_state or operation_state in [ ttypes.TOperationState.RUNNING_STATE, @@ -784,21 +740,13 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): def _check_direct_results_for_error(t_spark_direct_results): if t_spark_direct_results: if t_spark_direct_results.operationStatus: - ThriftBackend._check_response_for_error( - t_spark_direct_results.operationStatus - ) + ThriftBackend._check_response_for_error(t_spark_direct_results.operationStatus) if t_spark_direct_results.resultSetMetadata: - ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSetMetadata - ) + ThriftBackend._check_response_for_error(t_spark_direct_results.resultSetMetadata) if t_spark_direct_results.resultSet: - ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSet - ) + ThriftBackend._check_response_for_error(t_spark_direct_results.resultSet) if t_spark_direct_results.closeOperation: - ThriftBackend._check_response_for_error( - t_spark_direct_results.closeOperation - ) + ThriftBackend._check_response_for_error(t_spark_direct_results.closeOperation) def execute_command( self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor @@ -817,9 +765,7 @@ def execute_command( sessionHandle=session_handle, statement=operation, runAsync=True, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), canReadArrowResult=True, canDecompressLZ4Result=lz4_compression, canDownloadResult=False, @@ -837,9 +783,7 @@ def get_catalogs(self, session_handle, max_rows, max_bytes, cursor): req = ttypes.TGetCatalogsReq( sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), ) resp = self.make_request(self._client.GetCatalogs, req) return self._handle_execute_response(resp, cursor) @@ -857,9 +801,7 @@ def get_schemas( req = ttypes.TGetSchemasReq( sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), catalogName=catalog_name, schemaName=schema_name, ) @@ -881,9 +823,7 @@ def get_tables( req = 
ttypes.TGetTablesReq( sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), catalogName=catalog_name, schemaName=schema_name, tableName=table_name, @@ -907,9 +847,7 @@ def get_columns( req = ttypes.TGetColumnsReq( sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), catalogName=catalog_name, schemaName=schema_name, tableName=table_name, From fddc9f936d4e03b8a11cdf0a51d4da8c1a942fd3 Mon Sep 17 00:00:00 2001 From: Javier Asensio Date: Fri, 2 Jun 2023 11:18:15 +0100 Subject: [PATCH 2/3] Add timeout hack to mitigate timeouts --- .github/workflows/code-quality-checks.yml | 166 ------------------ src/databricks/sql/auth/thrift_http_client.py | 1 + src/databricks/sql/thrift_backend.py | 109 ++++-------- 3 files changed, 34 insertions(+), 242 deletions(-) delete mode 100644 .github/workflows/code-quality-checks.yml diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml deleted file mode 100644 index 31c2b1fd..00000000 --- a/.github/workflows/code-quality-checks.yml +++ /dev/null @@ -1,166 +0,0 @@ -name: Code Quality Checks -on: - push: - branches: - - main - pull_request: - branches: - - main -jobs: - run-unit-tests: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.7, 3.8, 3.9, "3.10", "3.11"] - steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@v2 - - name: Set up python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: ${{ matrix.python-version == 3.7 && '1.5.1' || 'latest' }} - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v2 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install library - run: poetry install --no-interaction - #---------------------------------------------- - # run test suite - #---------------------------------------------- - - name: Run tests - run: poetry run python -m pytest tests/unit - check-linting: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.7, 3.8, 3.9, "3.10"] - steps: - 
#---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@v2 - - name: Set up python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: ${{ matrix.python-version == 3.7 && '1.5.1' || 'latest' }} - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v2 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install library - run: poetry install --no-interaction - #---------------------------------------------- - # black the code - #---------------------------------------------- - - name: Black - run: poetry run black --check src - - check-types: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.7, 3.8, 3.9, "3.10"] - steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@v2 - - name: Set up python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: ${{ matrix.python-version == 3.7 && '1.5.1' || 'latest' }} - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v2 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install library - run: poetry install --no-interaction - 
#---------------------------------------------- - # black the code - #---------------------------------------------- - - name: Mypy - run: poetry run mypy --install-types --non-interactive src diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 11589258..0a3651b9 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -5,6 +5,7 @@ import six import thrift +import thrift.transport.THttpClient logger = logging.getLogger(__name__) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 4d07d671..04a719c0 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -1,16 +1,18 @@ -from decimal import Decimal import errno import logging import math +import os import time import uuid import threading from ssl import CERT_NONE, CERT_REQUIRED, create_default_context from typing import List, Union +import databricks.sql.auth.thrift_http_client +import lz4.frame import pyarrow -import thrift.transport.THttpClient import thrift.protocol.TBinaryProtocol +import thrift.transport.THttpClient import thrift.transport.TSocket import thrift.transport.TTransport @@ -54,6 +56,10 @@ DATABRICKS_REASON_HEADER = "x-databricks-reason-phrase" TIMESTAMP_AS_STRING_CONFIG = "spark.thriftserver.arrowBasedRowSet.timestampAsString" + +# HACK! +THRIFT_SOCKET_TIMEOUT = os.getenv("THRIFT_SOCKET_TIMEOUT", None) + DEFAULT_SOCKET_TIMEOUT = float(900) # see Connection.__init__ for parameter descriptions. @@ -145,13 +151,9 @@ def __init__( self.staging_allowed_local_path = staging_allowed_local_path self._initialize_retry_args(kwargs) - self._use_arrow_native_complex_types = kwargs.get( - "_use_arrow_native_complex_types", True - ) + self._use_arrow_native_complex_types = kwargs.get("_use_arrow_native_complex_types", True) self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True) - self._use_arrow_native_timestamps = kwargs.get( - "_use_arrow_native_timestamps", True - ) + self._use_arrow_native_timestamps = kwargs.get("_use_arrow_native_timestamps", True) # Cloud fetch self.max_download_threads = kwargs.get("max_download_threads", 10) @@ -204,7 +206,7 @@ def __init__( **additional_transport_args, # type: ignore ) - timeout = kwargs.get("_socket_timeout", DEFAULT_SOCKET_TIMEOUT) + timeout = THRIFT_SOCKET_TIMEOUT or kwargs.get("_socket_timeout", DEFAULT_SOCKET_TIMEOUT) # setTimeout defaults to 15 minutes and is expected in ms self._transport.setTimeout(timeout and (float(timeout) * 1000.0)) @@ -228,15 +230,11 @@ def _initialize_retry_args(self, kwargs): given_or_default = type_(kwargs.get(key, default)) bound = _bound(min, max, given_or_default) setattr(self, key, bound) - logger.debug( - "retry parameter: {} given_or_default {}".format(key, given_or_default) - ) + logger.debug("retry parameter: {} given_or_default {}".format(key, given_or_default)) if bound != given_or_default: logger.warning( "Override out of policy retry parameter: " - + "{} given {}, restricted to {}".format( - key, given_or_default, bound - ) + + "{} given {}, restricted to {}".format(key, given_or_default, bound) ) # Fail on retry delay min > max; consider later adding fail on min > duration? 
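The timeout hunks above carry the substance of PATCH 2/3: THRIFT_SOCKET_TIMEOUT now takes precedence over the `_socket_timeout` keyword argument, with DEFAULT_SOCKET_TIMEOUT (900 s) as the fallback, whereas PATCH 1/3 resolved it the other way around (`kwargs.get("_socket_timeout") or THRIFT_SOCKET_TIMEOUT`). A minimal standalone sketch of that resolution and of the seconds-to-milliseconds conversion Thrift expects (the free function and the prints are illustrative, not part of the patch):

    import os

    DEFAULT_SOCKET_TIMEOUT = float(900)  # seconds, as in the patch

    def resolve_timeout_ms(**kwargs):
        # The env var wins over the per-connection kwarg; os.getenv returns a
        # string, so float() below must accept both str and numeric values.
        timeout = os.getenv("THRIFT_SOCKET_TIMEOUT") or kwargs.get(
            "_socket_timeout", DEFAULT_SOCKET_TIMEOUT
        )
        # setTimeout expects milliseconds; a falsy timeout (None) is passed
        # through unchanged, which Thrift treats as "no timeout".
        return timeout and (float(timeout) * 1000.0)

    print(resolve_timeout_ms())                    # 900000.0 (default)
    print(resolve_timeout_ms(_socket_timeout=30))  # 30000.0
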
@@ -264,9 +262,7 @@ def _extract_error_message_from_headers(headers): if THRIFT_ERROR_MESSAGE_HEADER in headers: err_msg = headers[THRIFT_ERROR_MESSAGE_HEADER] if DATABRICKS_ERROR_OR_REDIRECT_HEADER in headers: - if ( - err_msg - ): # We don't expect both to be set, but log both here just in case + if err_msg: # We don't expect both to be set, but log both here just in case err_msg = "Thriftserver error: {}, Databricks error: {}".format( err_msg, headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER] ) @@ -497,10 +493,7 @@ def _check_initial_namespace(self, catalog, schema, response): if not (catalog or schema): return - if ( - response.serverProtocolVersion - < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4 - ): + if response.serverProtocolVersion < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4: raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0." @@ -515,10 +508,7 @@ def _check_initial_namespace(self, catalog, schema, response): def _check_session_configuration(self, session_configuration): # This client expects timetampsAsString to be false, so we do not allow users to modify that - if ( - session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() - != "false" - ): + if session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() != "false": raise Error( "Invalid session configuration: {} cannot be changed " "while using the Databricks SQL connector, it must be false not {}".format( @@ -530,18 +520,14 @@ def _check_session_configuration(self, session_configuration): def open_session(self, session_configuration, catalog, schema): try: self._transport.open() - session_configuration = { - k: str(v) for (k, v) in (session_configuration or {}).items() - } + session_configuration = {k: str(v) for (k, v) in (session_configuration or {}).items()} self._check_session_configuration(session_configuration) # We want to receive proper Timestamp arrow types. # We set it also in confOverlay in TExecuteStatementReq on a per query basic, # but it doesn't hurt to also set for the whole session. 
session_configuration[TIMESTAMP_AS_STRING_CONFIG] = "false" if catalog or schema: - initial_namespace = ttypes.TNamespace( - catalogName=catalog, schemaName=schema - ) + initial_namespace = ttypes.TNamespace(catalogName=catalog, schemaName=schema) else: initial_namespace = None @@ -567,9 +553,7 @@ def close_session(self, session_handle) -> None: finally: self._transport.close() - def _check_command_not_in_error_or_closed_state( - self, op_handle, get_operations_resp - ): + def _check_command_not_in_error_or_closed_state(self, op_handle, get_operations_resp): if get_operations_resp.operationState == ttypes.TOperationState.ERROR_STATE: if get_operations_resp.displayMessage: raise ServerOperationError( @@ -656,9 +640,7 @@ def map_type(t_type_entry): else: # Current thriftserver implementation should always return a primitiveEntry, # even for complex types - raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" - ) + raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry") def convert_col(t_column_desc): return pyarrow.field( @@ -676,9 +658,7 @@ def _col_to_description(col): # Drop _TYPE suffix cleaned_type = (name[:-5] if name.endswith("_TYPE") else name).lower() else: - raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" - ) + raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry") if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: qualifiers = type_entry.primitiveEntry.typeQualifiers.qualifiers @@ -699,9 +679,7 @@ def _col_to_description(col): @staticmethod def _hive_schema_to_description(t_table_schema): - return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns - ] + return [ThriftBackend._col_to_description(col) for col in t_table_schema.columns] def _results_message_to_execute_response(self, resp, operation_state): if resp.directResults and resp.directResults.resultSetMetadata: @@ -729,9 +707,7 @@ def _results_message_to_execute_response(self, resp, operation_state): or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows ) - description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema - ) + description = self._hive_schema_to_description(t_result_set_metadata_resp.schema) schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) @@ -772,8 +748,7 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): op_handle, initial_operation_status_resp ) operation_state = ( - initial_operation_status_resp - and initial_operation_status_resp.operationState + initial_operation_status_resp and initial_operation_status_resp.operationState ) while not operation_state or operation_state in [ ttypes.TOperationState.RUNNING_STATE, @@ -788,21 +763,13 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): def _check_direct_results_for_error(t_spark_direct_results): if t_spark_direct_results: if t_spark_direct_results.operationStatus: - ThriftBackend._check_response_for_error( - t_spark_direct_results.operationStatus - ) + ThriftBackend._check_response_for_error(t_spark_direct_results.operationStatus) if t_spark_direct_results.resultSetMetadata: - ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSetMetadata - ) + ThriftBackend._check_response_for_error(t_spark_direct_results.resultSetMetadata) if t_spark_direct_results.resultSet: - ThriftBackend._check_response_for_error( 
- t_spark_direct_results.resultSet - ) + ThriftBackend._check_response_for_error(t_spark_direct_results.resultSet) if t_spark_direct_results.closeOperation: - ThriftBackend._check_response_for_error( - t_spark_direct_results.closeOperation - ) + ThriftBackend._check_response_for_error(t_spark_direct_results.closeOperation) def execute_command( self, @@ -828,9 +795,7 @@ def execute_command( sessionHandle=session_handle, statement=operation, runAsync=True, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), canReadArrowResult=True, canDecompressLZ4Result=lz4_compression, canDownloadResult=use_cloud_fetch, @@ -848,9 +813,7 @@ def get_catalogs(self, session_handle, max_rows, max_bytes, cursor): req = ttypes.TGetCatalogsReq( sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), ) resp = self.make_request(self._client.GetCatalogs, req) return self._handle_execute_response(resp, cursor) @@ -868,9 +831,7 @@ def get_schemas( req = ttypes.TGetSchemasReq( sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), catalogName=catalog_name, schemaName=schema_name, ) @@ -892,9 +853,7 @@ def get_tables( req = ttypes.TGetTablesReq( sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), catalogName=catalog_name, schemaName=schema_name, tableName=table_name, @@ -918,9 +877,7 @@ def get_columns( req = ttypes.TGetColumnsReq( sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), catalogName=catalog_name, schemaName=schema_name, tableName=table_name, From 254083c1c4851ed1bfcb5c9ab6c136f21434687f Mon Sep 17 00:00:00 2001 From: matt-fleming Date: Mon, 29 Jan 2024 12:18:40 +0000 Subject: [PATCH 3/3] Import threading --- src/databricks/sql/thrift_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index c7000587..d54b618f 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -2,15 +2,14 @@ import logging import math import os +import threading import uuid import time -from decimal import Decimal from ssl import CERT_NONE, CERT_REQUIRED, create_default_context from typing import List, Union import databricks.sql.auth.thrift_http_client -import lz4.frame import pyarrow import thrift.protocol.TBinaryProtocol import thrift.transport.THttpClient
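
Applied as a series, the patches add an environment-variable escape hatch for the socket timeout, and PATCH 3/3 then supplies the missing `threading` import. A hypothetical usage sketch follows; the hostname, HTTP path, and token are placeholders, and the one requirement grounded in the patch is that THRIFT_SOCKET_TIMEOUT is read via os.getenv at module import time, so it must be set before databricks.sql is imported:

    import os

    # Set before the databricks import below: the patched module evaluates
    # THRIFT_SOCKET_TIMEOUT = os.getenv(...) once, when first imported.
    os.environ["THRIFT_SOCKET_TIMEOUT"] = "1800"  # seconds

    from databricks import sql

    connection = sql.connect(
        server_hostname="example.cloud.databricks.com",  # placeholder
        http_path="/sql/1.0/warehouses/abc123",          # placeholder
        access_token="dapi-placeholder-token",           # placeholder
    )
    cursor = connection.cursor()
    cursor.execute("SELECT 1")
    print(cursor.fetchall())
    cursor.close()
    connection.close()

As the `# HACK!` markers and the commit subjects concede, raising the timeout this way mitigates the observed timeouts rather than fixing their cause.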