Commit 1a01663

capitancambio authored and matt-fleming committed
Add timeout hack to mitigate timeouts
1 parent 2950e70 commit 1a01663

2 files changed: +84 −31 lines changed

src/databricks/sql/auth/thrift_http_client.py

+1 −2
@@ -5,7 +5,6 @@
 
 import six
 import thrift
-import thrift.transport.THttpClient
 
 logger = logging.getLogger(__name__)
 
@@ -212,4 +211,4 @@ def set_retry_command_type(self, value: CommandType):
         else:
             logger.warning(
                 "DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set."
-                )
+            )

src/databricks/sql/thrift_backend.py

+83 −29
@@ -8,8 +8,6 @@
 from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
 from typing import List, Union
 
-import databricks.sql.auth.thrift_http_client
-import lz4.frame
 import pyarrow
 import thrift.protocol.TBinaryProtocol
 import thrift.transport.THttpClient
@@ -23,7 +21,6 @@
 from databricks.sql.auth.authenticators import AuthProvider
 from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes
 from databricks.sql import *
-from databricks.sql.exc import MaxRetryDurationError
 from databricks.sql.thrift_api.TCLIService.TCLIService import (
     Client as TCLIServiceClient,
 )
@@ -157,9 +154,13 @@ def __init__(
 
         self.staging_allowed_local_path = staging_allowed_local_path
         self._initialize_retry_args(kwargs)
-        self._use_arrow_native_complex_types = kwargs.get("_use_arrow_native_complex_types", True)
+        self._use_arrow_native_complex_types = kwargs.get(
+            "_use_arrow_native_complex_types", True
+        )
         self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True)
-        self._use_arrow_native_timestamps = kwargs.get("_use_arrow_native_timestamps", True)
+        self._use_arrow_native_timestamps = kwargs.get(
+            "_use_arrow_native_timestamps", True
+        )
 
         # Cloud fetch
         self.max_download_threads = kwargs.get("max_download_threads", 10)
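Note: the underscore-prefixed `_use_arrow_native_*` options are internal toggles read straight off **kwargs, defaulting to True. A sketch of how a caller would flip one, assuming a constructor signature of roughly (server_hostname, port, http_path, http_headers, auth_provider, **kwargs); the hostname and path values are placeholders:

backend = ThriftBackend(
    "example.cloud.databricks.com",         # placeholder hostname
    443,
    "/sql/1.0/endpoints/0123456789abcdef",  # placeholder HTTP path
    http_headers=[],
    auth_provider=AuthProvider(),
    _use_arrow_native_timestamps=False,     # consumed by kwargs.get(..., True) above
)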
@@ -229,7 +230,12 @@ def __init__(
             **additional_transport_args,  # type: ignore
         )
 
-        timeout = THRIFT_SOCKET_TIMEOUT or kwargs.get("_socket_timeout", DEFAULT_SOCKET_TIMEOUT)
+        timeout = THRIFT_SOCKET_TIMEOUT or kwargs.get(
+            "_socket_timeout", DEFAULT_SOCKET_TIMEOUT
+        )
+        # HACK!
+        logger.info(f"Setting timeout HACK! to {timeout}")
+
         # setTimeout defaults to 15 minutes and is expected in ms
         self._transport.setTimeout(timeout and (float(timeout) * 1000.0))
 
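Note: the "hack" named in the commit message is just the logger.info line; the surrounding code resolves the socket timeout and converts it to milliseconds. A minimal standalone sketch of that resolution order, with assumed values for the two constants (their real definitions live elsewhere in the module):

THRIFT_SOCKET_TIMEOUT = None    # assumed: module-level override, in seconds
DEFAULT_SOCKET_TIMEOUT = 900.0  # assumed: fallback matching Thrift's 15-minute default

def resolve_timeout_ms(**kwargs):
    # Precedence: hard-coded override, then the _socket_timeout kwarg, then default.
    timeout = THRIFT_SOCKET_TIMEOUT or kwargs.get("_socket_timeout", DEFAULT_SOCKET_TIMEOUT)
    # `timeout and ...` keeps None/0 falsy, so "no timeout" passes through unchanged;
    # otherwise convert seconds to the milliseconds setTimeout expects.
    return timeout and (float(timeout) * 1000.0)

assert resolve_timeout_ms(_socket_timeout=30) == 30000.0
assert resolve_timeout_ms() == 900000.0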
@@ -253,11 +259,15 @@ def _initialize_retry_args(self, kwargs):
             given_or_default = type_(kwargs.get(key, default))
             bound = _bound(min, max, given_or_default)
             setattr(self, key, bound)
-            logger.debug("retry parameter: {} given_or_default {}".format(key, given_or_default))
+            logger.debug(
+                "retry parameter: {} given_or_default {}".format(key, given_or_default)
+            )
             if bound != given_or_default:
                 logger.warning(
                     "Override out of policy retry parameter: "
-                    + "{} given {}, restricted to {}".format(key, given_or_default, bound)
+                    + "{} given {}, restricted to {}".format(
+                        key, given_or_default, bound
+                    )
                 )
 
         # Fail on retry delay min > max; consider later adding fail on min > duration?
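Note: `_bound(min, max, given_or_default)` clamps each user-supplied retry parameter into the policy's allowed range, which is what triggers the "Override out of policy" warning above. A plausible sketch of such a helper (the real `_bound` is defined elsewhere in thrift_backend.py and shadows the min/max builtins in its parameter names):

def _bound(min_x, max_x, x):
    # Clamp x into [min_x, max_x]; a None bound leaves that side open.
    if min_x is not None:
        x = max(min_x, x)
    if max_x is not None:
        x = min(max_x, x)
    return x

assert _bound(0, 60, 3600) == 60  # out-of-policy value gets restricted
assert _bound(0, 60, 10) == 10    # in-policy value passes through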
@@ -285,7 +295,9 @@ def _extract_error_message_from_headers(headers):
         if THRIFT_ERROR_MESSAGE_HEADER in headers:
             err_msg = headers[THRIFT_ERROR_MESSAGE_HEADER]
         if DATABRICKS_ERROR_OR_REDIRECT_HEADER in headers:
-            if err_msg:  # We don't expect both to be set, but log both here just in case
+            if (
+                err_msg
+            ):  # We don't expect both to be set, but log both here just in case
                 err_msg = "Thriftserver error: {}, Databricks error: {}".format(
                     err_msg, headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER]
                 )
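Note: the hunk's context cuts off before the else branch; presumably the Databricks header is used alone when the Thrift header is absent. A compact sketch of the whole extraction under that assumption, with assumed header-name constants:

THRIFT_ERROR_MESSAGE_HEADER = "x-thriftserver-error-message"                    # assumed name
DATABRICKS_ERROR_OR_REDIRECT_HEADER = "x-databricks-error-or-redirect-message"  # assumed name

def extract_error_message(headers):
    err_msg = None
    if THRIFT_ERROR_MESSAGE_HEADER in headers:
        err_msg = headers[THRIFT_ERROR_MESSAGE_HEADER]
    if DATABRICKS_ERROR_OR_REDIRECT_HEADER in headers:
        if err_msg:  # we don't expect both to be set, but keep both just in case
            err_msg = "Thriftserver error: {}, Databricks error: {}".format(
                err_msg, headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER]
            )
        else:
            err_msg = headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER]
    return err_msg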
@@ -516,7 +528,10 @@ def _check_initial_namespace(self, catalog, schema, response):
         if not (catalog or schema):
             return
 
-        if response.serverProtocolVersion < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4:
+        if (
+            response.serverProtocolVersion
+            < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4
+        ):
             raise InvalidServerResponseError(
                 "Setting initial namespace not supported by the DBR version, "
                 "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0."
@@ -531,7 +546,10 @@ def _check_initial_namespace(self, catalog, schema, response):
 
     def _check_session_configuration(self, session_configuration):
         # This client expects timetampsAsString to be false, so we do not allow users to modify that
-        if session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() != "false":
+        if (
+            session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower()
+            != "false"
+        ):
             raise Error(
                 "Invalid session configuration: {} cannot be changed "
                 "while using the Databricks SQL connector, it must be false not {}".format(
@@ -543,14 +561,18 @@ def _check_session_configuration(self, session_configuration):
     def open_session(self, session_configuration, catalog, schema):
         try:
             self._transport.open()
-            session_configuration = {k: str(v) for (k, v) in (session_configuration or {}).items()}
+            session_configuration = {
+                k: str(v) for (k, v) in (session_configuration or {}).items()
+            }
             self._check_session_configuration(session_configuration)
             # We want to receive proper Timestamp arrow types.
             # We set it also in confOverlay in TExecuteStatementReq on a per query basic,
             # but it doesn't hurt to also set for the whole session.
             session_configuration[TIMESTAMP_AS_STRING_CONFIG] = "false"
             if catalog or schema:
-                initial_namespace = ttypes.TNamespace(catalogName=catalog, schemaName=schema)
+                initial_namespace = ttypes.TNamespace(
+                    catalogName=catalog, schemaName=schema
+                )
             else:
                 initial_namespace = None
 
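Note: the reflowed dict comprehension exists because the Thrift open-session request carries its configuration as a string-to-string map, so arbitrary Python values are stringified up front. A quick worked example (keys are illustrative, not real config names):

session_configuration = {"use_cached_result": True, "statement_timeout": 3600}
normalized = {k: str(v) for (k, v) in (session_configuration or {}).items()}
assert normalized == {"use_cached_result": "True", "statement_timeout": "3600"}
# A None configuration also works: (None or {}).items() iterates an empty dict.
assert {k: str(v) for (k, v) in (None or {}).items()} == {}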
@@ -576,7 +598,9 @@ def close_session(self, session_handle) -> None:
         finally:
             self._transport.close()
 
-    def _check_command_not_in_error_or_closed_state(self, op_handle, get_operations_resp):
+    def _check_command_not_in_error_or_closed_state(
+        self, op_handle, get_operations_resp
+    ):
         if get_operations_resp.operationState == ttypes.TOperationState.ERROR_STATE:
             if get_operations_resp.displayMessage:
                 raise ServerOperationError(
@@ -621,7 +645,10 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
                 num_rows,
             ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
         elif t_row_set.arrowBatches is not None:
-            (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
+            (
+                arrow_table,
+                num_rows,
+            ) = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, schema_bytes
            )
         else:
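Note: `convert_arrow_based_set_to_arrow_table` is only named here. A plausible sketch of what such a conversion involves, assuming each Thrift batch carries Arrow IPC record-batch bytes (lz4-frame-compressed when lz4_compressed is set) that form a readable stream once prefixed with the serialized schema; the rowCount/batch field names are assumptions:

import lz4.frame
import pyarrow

def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes):
    # Build one contiguous IPC stream: schema message first, then every batch,
    # decompressing each lz4 frame if the server compressed the payload.
    ba = bytearray(schema_bytes)
    n_rows = 0
    for arrow_batch in arrow_batches:
        n_rows += arrow_batch.rowCount  # assumed Thrift field names
        ba += (
            lz4.frame.decompress(arrow_batch.batch)
            if lz4_compressed
            else arrow_batch.batch
        )
    arrow_table = pyarrow.ipc.open_stream(ba).read_all()
    return arrow_table, n_rows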
@@ -663,7 +690,9 @@ def map_type(t_type_entry):
         else:
             # Current thriftserver implementation should always return a primitiveEntry,
             # even for complex types
-            raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry")
+            raise OperationalError(
+                "Thrift protocol error: t_type_entry not a primitiveEntry"
+            )
 
     def convert_col(t_column_desc):
         return pyarrow.field(
@@ -681,7 +710,9 @@ def _col_to_description(col):
             # Drop _TYPE suffix
             cleaned_type = (name[:-5] if name.endswith("_TYPE") else name).lower()
         else:
-            raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry")
+            raise OperationalError(
+                "Thrift protocol error: t_type_entry not a primitiveEntry"
+            )
 
         if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE:
             qualifiers = type_entry.primitiveEntry.typeQualifiers.qualifiers
@@ -702,7 +733,9 @@ def _col_to_description(col):
 
     @staticmethod
     def _hive_schema_to_description(t_table_schema):
-        return [ThriftBackend._col_to_description(col) for col in t_table_schema.columns]
+        return [
+            ThriftBackend._col_to_description(col) for col in t_table_schema.columns
+        ]
 
     def _results_message_to_execute_response(self, resp, operation_state):
         if resp.directResults and resp.directResults.resultSetMetadata:
@@ -730,7 +763,9 @@ def _results_message_to_execute_response(self, resp, operation_state):
                 or (not direct_results.resultSet)
                 or direct_results.resultSet.hasMoreRows
             )
-        description = self._hive_schema_to_description(t_result_set_metadata_resp.schema)
+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema
+        )
         schema_bytes = (
             t_result_set_metadata_resp.arrowSchema
             or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
@@ -771,7 +806,8 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
             op_handle, initial_operation_status_resp
         )
         operation_state = (
-            initial_operation_status_resp and initial_operation_status_resp.operationState
+            initial_operation_status_resp
+            and initial_operation_status_resp.operationState
        )
         while not operation_state or operation_state in [
             ttypes.TOperationState.RUNNING_STATE,
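Note: the reflowed operation_state expression feeds the polling loop the hunk truncates: keep asking the server until the operation leaves its pending states. A minimal sketch of that pattern, with a hypothetical poll_for_status callable standing in for the GetOperationStatus round trip:

import time

def wait_until_done(op_handle, poll_for_status, pending_states, interval_s=1.0):
    # Poll until the operation reports a state outside the pending set.
    state = None
    while not state or state in pending_states:
        time.sleep(interval_s)
        state = poll_for_status(op_handle)  # hypothetical status round trip
    return state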
@@ -786,13 +822,21 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
     def _check_direct_results_for_error(t_spark_direct_results):
         if t_spark_direct_results:
             if t_spark_direct_results.operationStatus:
-                ThriftBackend._check_response_for_error(t_spark_direct_results.operationStatus)
+                ThriftBackend._check_response_for_error(
+                    t_spark_direct_results.operationStatus
+                )
             if t_spark_direct_results.resultSetMetadata:
-                ThriftBackend._check_response_for_error(t_spark_direct_results.resultSetMetadata)
+                ThriftBackend._check_response_for_error(
+                    t_spark_direct_results.resultSetMetadata
+                )
             if t_spark_direct_results.resultSet:
-                ThriftBackend._check_response_for_error(t_spark_direct_results.resultSet)
+                ThriftBackend._check_response_for_error(
+                    t_spark_direct_results.resultSet
+                )
             if t_spark_direct_results.closeOperation:
-                ThriftBackend._check_response_for_error(t_spark_direct_results.closeOperation)
+                ThriftBackend._check_response_for_error(
+                    t_spark_direct_results.closeOperation
+                )
 
     def execute_command(
         self,
@@ -819,7 +863,9 @@ def execute_command(
             sessionHandle=session_handle,
             statement=operation,
             runAsync=True,
-            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
+            getDirectResults=ttypes.TSparkGetDirectResults(
+                maxRows=max_rows, maxBytes=max_bytes
+            ),
             canReadArrowResult=True,
             canDecompressLZ4Result=lz4_compression,
             canDownloadResult=use_cloud_fetch,
@@ -838,7 +884,9 @@ def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
 
         req = ttypes.TGetCatalogsReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
+            getDirectResults=ttypes.TSparkGetDirectResults(
+                maxRows=max_rows, maxBytes=max_bytes
+            ),
         )
         resp = self.make_request(self._client.GetCatalogs, req)
         return self._handle_execute_response(resp, cursor)
@@ -856,7 +904,9 @@ def get_schemas(
 
         req = ttypes.TGetSchemasReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
+            getDirectResults=ttypes.TSparkGetDirectResults(
+                maxRows=max_rows, maxBytes=max_bytes
+            ),
             catalogName=catalog_name,
             schemaName=schema_name,
         )
@@ -878,7 +928,9 @@ def get_tables(
 
         req = ttypes.TGetTablesReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
+            getDirectResults=ttypes.TSparkGetDirectResults(
+                maxRows=max_rows, maxBytes=max_bytes
+            ),
             catalogName=catalog_name,
             schemaName=schema_name,
             tableName=table_name,
@@ -902,7 +954,9 @@ def get_columns(
 
         req = ttypes.TGetColumnsReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
+            getDirectResults=ttypes.TSparkGetDirectResults(
+                maxRows=max_rows, maxBytes=max_bytes
+            ),
             catalogName=catalog_name,
             schemaName=schema_name,
             tableName=table_name,
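Note: after this commit, the identical TSparkGetDirectResults construction spans three lines at five call sites (ExecuteStatement, GetCatalogs, GetSchemas, GetTables, GetColumns). A hypothetical helper, not part of this commit, could centralize it:

def _direct_results(max_rows, max_bytes):
    # Hypothetical convenience wrapper for the pattern repeated above;
    # ttypes is the generated TCLIService module already imported here.
    return ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes)

# e.g. req = ttypes.TGetColumnsReq(
#     sessionHandle=session_handle,
#     getDirectResults=_direct_results(max_rows, max_bytes),
#     ...
# )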
