@@ -26,21 +26,25 @@ def __init__(self, server_url, request_timeout, compression=True):
2626 self .local_storage = threading .local ()
2727 self .requests_session = requests .Session ()
2828 self .compression = compression if zstandard is not None else False
29- try :
30- self .input_type = (
31- "pickle"
32- if self .requests_session .get (
33- f"{ self .server_url } /meta" , params = {"is_pickle_allowed" : "" }
34- ).json ()["is_pickle_allowed" ]
35- else "msgpack"
36- if msgpack is not None
37- else "json"
38- )
39- except Exception as e :
40- self .input_type = None
29+ self ._set_input_type ()
4130
4231 self .request_timeout = request_timeout
4332
33+ def _set_input_type (self ):
34+ if self .input_type is None :
35+ try :
36+ self .input_type = (
37+ "pickle"
38+ if self .requests_session .get (
39+ f"{ self .server_url } /meta" , params = {"is_pickle_allowed" : "" }
40+ ).json ()["is_pickle_allowed" ]
41+ else "msgpack"
42+ if msgpack is not None
43+ else "json"
44+ )
45+ except Exception as e :
46+ self .input_type = None
47+
4448 @property
4549 def _compressor (self ):
4650 if self .compression is False :
@@ -79,7 +83,9 @@ def _decompressor(self):
7983
8084 def infer (self , data , unique_id = None , is_async = False ):
8185 if self .input_type is None :
82- raise ValueError ("Could not connect to server" )
86+ self ._set_input_type ()
87+ if self .input_type is None :
88+ raise ValueError ("Could not connect to server" )
8389
8490 assert isinstance (data , (list , tuple )), "Data must be of type list or tuple"
8591
0 commit comments