from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
from typing import List, Union

-import databricks.sql.auth.thrift_http_client
-import lz4.frame
import pyarrow
import thrift.protocol.TBinaryProtocol
import thrift.transport.THttpClient
from databricks.sql.auth.authenticators import AuthProvider
from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes
from databricks.sql import *
-from databricks.sql.exc import MaxRetryDurationError
from databricks.sql.thrift_api.TCLIService.TCLIService import (
    Client as TCLIServiceClient,
)
@@ -157,9 +154,13 @@ def __init__(
        self.staging_allowed_local_path = staging_allowed_local_path
        self._initialize_retry_args(kwargs)
-        self._use_arrow_native_complex_types = kwargs.get("_use_arrow_native_complex_types", True)
+        self._use_arrow_native_complex_types = kwargs.get(
+            "_use_arrow_native_complex_types", True
+        )
        self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True)
-        self._use_arrow_native_timestamps = kwargs.get("_use_arrow_native_timestamps", True)
+        self._use_arrow_native_timestamps = kwargs.get(
+            "_use_arrow_native_timestamps", True
+        )

        # Cloud fetch
        self.max_download_threads = kwargs.get("max_download_threads", 10)
@@ -229,7 +230,12 @@ def __init__(
            **additional_transport_args,  # type: ignore
        )

-        timeout = THRIFT_SOCKET_TIMEOUT or kwargs.get("_socket_timeout", DEFAULT_SOCKET_TIMEOUT)
+        timeout = THRIFT_SOCKET_TIMEOUT or kwargs.get(
+            "_socket_timeout", DEFAULT_SOCKET_TIMEOUT
+        )
+        # HACK!
+        logger.info(f"Setting timeout HACK! to {timeout}")
+
        # setTimeout defaults to 15 minutes and is expected in ms
        self._transport.setTimeout(timeout and (float(timeout) * 1000.0))
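
The wrapped assignment above resolves the socket timeout with a simple precedence chain and then converts seconds to the milliseconds that Thrift's setTimeout expects; `timeout and ...` deliberately passes None (or 0) through as "no timeout". A minimal runnable sketch of that resolution, with the two constants given assumed values (the connector defines them elsewhere in the module):

# Sketch only: constant values are assumptions, not the connector's definitions.
THRIFT_SOCKET_TIMEOUT = None  # optional module-level override, normally unset
DEFAULT_SOCKET_TIMEOUT = float(900)  # assumed 15-minute default, in seconds

def resolve_timeout_ms(kwargs):
    # First truthy value wins: global override, then per-connection kwarg, then default.
    timeout = THRIFT_SOCKET_TIMEOUT or kwargs.get("_socket_timeout", DEFAULT_SOCKET_TIMEOUT)
    # None/0 short-circuits to itself, meaning "no timeout"; otherwise
    # convert seconds to the milliseconds setTimeout expects.
    return timeout and (float(timeout) * 1000.0)

assert resolve_timeout_ms({}) == 900000.0
assert resolve_timeout_ms({"_socket_timeout": 60}) == 60000.0
assert resolve_timeout_ms({"_socket_timeout": None}) is None
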
@@ -253,11 +259,15 @@ def _initialize_retry_args(self, kwargs):
            given_or_default = type_(kwargs.get(key, default))
            bound = _bound(min, max, given_or_default)
            setattr(self, key, bound)
-            logger.debug("retry parameter: {} given_or_default {}".format(key, given_or_default))
+            logger.debug(
+                "retry parameter: {} given_or_default {}".format(key, given_or_default)
+            )
            if bound != given_or_default:
                logger.warning(
                    "Override out of policy retry parameter: "
-                    + "{} given {}, restricted to {}".format(key, given_or_default, bound)
+                    + "{} given {}, restricted to {}".format(
+                        key, given_or_default, bound
+                    )
                )

        # Fail on retry delay min > max; consider later adding fail on min > duration?
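
The `_bound(min, max, given_or_default)` call above clamps each user-supplied retry parameter into the policy range before it is stored on the backend, which is what triggers the "Override out of policy" warning. A minimal sketch of that clamping behavior, assuming None means unbounded on that side (the helper name mirrors the call site; the example policy values are invented):

def _bound(min_x, max_x, x):
    # Clamp x into [min_x, max_x]; a None bound leaves that side open.
    if min_x is not None:
        x = max(min_x, x)
    if max_x is not None:
        x = min(max_x, x)
    return x

# Invented policy of [0.1, 60] seconds for a retry delay parameter:
assert _bound(0.1, 60, 0.01) == 0.1  # raised to the floor -> warning logged
assert _bound(0.1, 60, 5) == 5       # within policy, kept as given
assert _bound(0.1, 60, 600) == 60    # capped at the ceiling
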
@@ -285,7 +295,9 @@ def _extract_error_message_from_headers(headers):
        if THRIFT_ERROR_MESSAGE_HEADER in headers:
            err_msg = headers[THRIFT_ERROR_MESSAGE_HEADER]
        if DATABRICKS_ERROR_OR_REDIRECT_HEADER in headers:
-            if err_msg:  # We don't expect both to be set, but log both here just in case
+            if (
+                err_msg
+            ):  # We don't expect both to be set, but log both here just in case
                err_msg = "Thriftserver error: {}, Databricks error: {}".format(
                    err_msg, headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER]
                )
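
For reference, this static helper receives the HTTP headers of a failed Thrift response and prefers the Thrift server's message, combining it with the Databricks-specific one only in the unexpected case that both headers are set. A usage sketch with invented header names and values (the real names live behind the connector's header constants):

# Invented names/values for illustration; the connector uses its own constants.
headers = {
    "x-thriftserver-error-message": "RESOURCE_EXHAUSTED: query queue is full",
    "x-databricks-error-or-redirect-message": "Cluster is starting",
}
# With both headers present, the method above would produce:
#   "Thriftserver error: RESOURCE_EXHAUSTED: query queue is full, "
#   "Databricks error: Cluster is starting"
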
@@ -516,7 +528,10 @@ def _check_initial_namespace(self, catalog, schema, response):
        if not (catalog or schema):
            return

-        if response.serverProtocolVersion < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4:
+        if (
+            response.serverProtocolVersion
+            < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4
+        ):
            raise InvalidServerResponseError(
                "Setting initial namespace not supported by the DBR version, "
                "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0."
@@ -531,7 +546,10 @@ def _check_initial_namespace(self, catalog, schema, response):

    def _check_session_configuration(self, session_configuration):
        # This client expects timestampsAsString to be false, so we do not allow users to modify that
-        if session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() != "false":
+        if (
+            session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower()
+            != "false"
+        ):
            raise Error(
                "Invalid session configuration: {} cannot be changed "
                "while using the Databricks SQL connector, it must be false not {}".format(
@@ -543,14 +561,18 @@ def _check_session_configuration(self, session_configuration):

    def open_session(self, session_configuration, catalog, schema):
        try:
            self._transport.open()
-            session_configuration = {k: str(v) for (k, v) in (session_configuration or {}).items()}
+            session_configuration = {
+                k: str(v) for (k, v) in (session_configuration or {}).items()
+            }
            self._check_session_configuration(session_configuration)
            # We want to receive proper Timestamp arrow types.
            # We set it also in confOverlay in TExecuteStatementReq on a per-query basis,
            # but it doesn't hurt to also set it for the whole session.
            session_configuration[TIMESTAMP_AS_STRING_CONFIG] = "false"
            if catalog or schema:
-                initial_namespace = ttypes.TNamespace(catalogName=catalog, schemaName=schema)
+                initial_namespace = ttypes.TNamespace(
+                    catalogName=catalog, schemaName=schema
+                )
            else:
                initial_namespace = None
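
open_session stringifies every configuration value before validation, so a boolean True reaches the guard as the string "True" and fails the lowercase comparison. A small sketch of that normalization plus the guard, with the config key written out as an assumption (the connector keeps it behind the TIMESTAMP_AS_STRING_CONFIG constant):

# The key below is an assumption for illustration; the connector defines the constant.
TIMESTAMP_AS_STRING_CONFIG = "spark.thrift.server.arrowBasedRowSet.timestampAsString"

def normalize_and_check(session_configuration):
    # Stringify values first, so False, "False", and "false" all become strings.
    session_configuration = {k: str(v) for (k, v) in (session_configuration or {}).items()}
    value = session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower()
    if value != "false":
        raise ValueError("timestampAsString must stay false for this client")
    return session_configuration

assert normalize_and_check(None) == {}
assert normalize_and_check({"ansi_mode": True}) == {"ansi_mode": "True"}
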
@@ -576,7 +598,9 @@ def close_session(self, session_handle) -> None:
        finally:
            self._transport.close()

-    def _check_command_not_in_error_or_closed_state(self, op_handle, get_operations_resp):
+    def _check_command_not_in_error_or_closed_state(
+        self, op_handle, get_operations_resp
+    ):
        if get_operations_resp.operationState == ttypes.TOperationState.ERROR_STATE:
            if get_operations_resp.displayMessage:
                raise ServerOperationError(
@@ -621,7 +645,10 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
                num_rows,
            ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
        elif t_row_set.arrowBatches is not None:
-            (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
+            (
+                arrow_table,
+                num_rows,
+            ) = convert_arrow_based_set_to_arrow_table(
                t_row_set.arrowBatches, lz4_compressed, schema_bytes
            )
        else:
@@ -663,7 +690,9 @@ def map_type(t_type_entry):
            else:
                # Current thriftserver implementation should always return a primitiveEntry,
                # even for complex types
-                raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry")
+                raise OperationalError(
+                    "Thrift protocol error: t_type_entry not a primitiveEntry"
+                )

        def convert_col(t_column_desc):
            return pyarrow.field(
@@ -681,7 +710,9 @@ def _col_to_description(col):
            # Drop _TYPE suffix
            cleaned_type = (name[:-5] if name.endswith("_TYPE") else name).lower()
        else:
-            raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry")
+            raise OperationalError(
+                "Thrift protocol error: t_type_entry not a primitiveEntry"
+            )

        if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE:
            qualifiers = type_entry.primitiveEntry.typeQualifiers.qualifiers
@@ -702,7 +733,9 @@ def _col_to_description(col):

    @staticmethod
    def _hive_schema_to_description(t_table_schema):
-        return [ThriftBackend._col_to_description(col) for col in t_table_schema.columns]
+        return [
+            ThriftBackend._col_to_description(col) for col in t_table_schema.columns
+        ]

    def _results_message_to_execute_response(self, resp, operation_state):
        if resp.directResults and resp.directResults.resultSetMetadata:
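
`_hive_schema_to_description` maps each Thrift column descriptor to a PEP 249 `cursor.description` entry: a 7-tuple of name, type_code, display_size, internal_size, precision, scale, null_ok. An illustrative (invented) result for a table with an integer column and a decimal(7,2) column; only the fields this backend actually populates are non-None, with precision and scale filled from the DECIMAL_TYPE qualifiers handled below:

# Invented example of the description shape; not output captured from the connector.
description = [
    # (name, type_code, display_size, internal_size, precision, scale, null_ok)
    ("id", "int", None, None, None, None, None),
    ("price", "decimal", None, None, 7, 2, None),
]
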
@@ -730,7 +763,9 @@ def _results_message_to_execute_response(self, resp, operation_state):
                or (not direct_results.resultSet)
                or direct_results.resultSet.hasMoreRows
            )
-            description = self._hive_schema_to_description(t_result_set_metadata_resp.schema)
+            description = self._hive_schema_to_description(
+                t_result_set_metadata_resp.schema
+            )
            schema_bytes = (
                t_result_set_metadata_resp.arrowSchema
                or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
@@ -771,7 +806,8 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
                op_handle, initial_operation_status_resp
            )
        operation_state = (
-            initial_operation_status_resp and initial_operation_status_resp.operationState
+            initial_operation_status_resp
+            and initial_operation_status_resp.operationState
        )
        while not operation_state or operation_state in [
            ttypes.TOperationState.RUNNING_STATE,
@@ -786,13 +822,21 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
    def _check_direct_results_for_error(t_spark_direct_results):
        if t_spark_direct_results:
            if t_spark_direct_results.operationStatus:
-                ThriftBackend._check_response_for_error(t_spark_direct_results.operationStatus)
+                ThriftBackend._check_response_for_error(
+                    t_spark_direct_results.operationStatus
+                )
            if t_spark_direct_results.resultSetMetadata:
-                ThriftBackend._check_response_for_error(t_spark_direct_results.resultSetMetadata)
+                ThriftBackend._check_response_for_error(
+                    t_spark_direct_results.resultSetMetadata
+                )
            if t_spark_direct_results.resultSet:
-                ThriftBackend._check_response_for_error(t_spark_direct_results.resultSet)
+                ThriftBackend._check_response_for_error(
+                    t_spark_direct_results.resultSet
+                )
            if t_spark_direct_results.closeOperation:
-                ThriftBackend._check_response_for_error(t_spark_direct_results.closeOperation)
+                ThriftBackend._check_response_for_error(
+                    t_spark_direct_results.closeOperation
+                )

    def execute_command(
        self,
@@ -819,7 +863,9 @@ def execute_command(
            sessionHandle=session_handle,
            statement=operation,
            runAsync=True,
-            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
+            getDirectResults=ttypes.TSparkGetDirectResults(
+                maxRows=max_rows, maxBytes=max_bytes
+            ),
            canReadArrowResult=True,
            canDecompressLZ4Result=lz4_compression,
            canDownloadResult=use_cloud_fetch,
@@ -838,7 +884,9 @@ def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):

        req = ttypes.TGetCatalogsReq(
            sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
+            getDirectResults=ttypes.TSparkGetDirectResults(
+                maxRows=max_rows, maxBytes=max_bytes
+            ),
        )
        resp = self.make_request(self._client.GetCatalogs, req)
        return self._handle_execute_response(resp, cursor)
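
Each of these request builders attaches a TSparkGetDirectResults, which asks the server to return up to maxRows/maxBytes of the result set inline with the RPC response, sparing a separate fetch round trip when results are small. A sketch of constructing one, using the import path shown at the top of this diff (the limits here are arbitrary examples, not connector defaults):

from databricks.sql.thrift_api.TCLIService import ttypes

# Arbitrary example limits; the backend passes through caller-supplied values.
direct_results = ttypes.TSparkGetDirectResults(
    maxRows=100000, maxBytes=10 * 1024 * 1024
)
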
@@ -856,7 +904,9 @@ def get_schemas(

        req = ttypes.TGetSchemasReq(
            sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
+            getDirectResults=ttypes.TSparkGetDirectResults(
+                maxRows=max_rows, maxBytes=max_bytes
+            ),
            catalogName=catalog_name,
            schemaName=schema_name,
        )
@@ -878,7 +928,9 @@ def get_tables(

        req = ttypes.TGetTablesReq(
            sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
+            getDirectResults=ttypes.TSparkGetDirectResults(
+                maxRows=max_rows, maxBytes=max_bytes
+            ),
            catalogName=catalog_name,
            schemaName=schema_name,
            tableName=table_name,
@@ -902,7 +954,9 @@ def get_columns(

        req = ttypes.TGetColumnsReq(
            sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
+            getDirectResults=ttypes.TSparkGetDirectResults(
+                maxRows=max_rows, maxBytes=max_bytes
+            ),
            catalogName=catalog_name,
            schemaName=schema_name,
            tableName=table_name,