 import datetime
-import time
 import json
+import time
 from typing import Any, Dict, Iterator
+
+from models.object_info import ObjectInfo
 from obsrv.common import ObsrvException
+from obsrv.connector import ConnectorContext, MetricsCollector
 from obsrv.connector.batch import ISourceConnector
-from obsrv.connector import ConnectorContext
-from obsrv.connector import MetricsCollector
-from obsrv.models import ErrorData, StatusCode, ExecutionState
+from obsrv.models import ErrorData, ExecutionState, StatusCode
 from obsrv.utils import LoggerController
-
-from pyspark.sql import SparkSession, DataFrame
+from provider.s3 import S3
 from pyspark.conf import SparkConf
+from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.functions import lit
-from pyspark.sql.types import *
-
-from provider.s3 import S3
-from models.object_info import ObjectInfo
 
 logger = LoggerController(__name__)
 
 MAX_RETRY_COUNT = 10
 
+
 class ObjectStoreConnector(ISourceConnector):
     def __init__(self):
         self.provider = None
@@ -30,21 +28,34 @@ def __init__(self):
         self.error_state = StatusCode.FAILED.value
         self.running_state = ExecutionState.RUNNING.value
         self.not_running_state = ExecutionState.NOT_RUNNING.value
-        self.queued_state = ExecutionState.QUEUED.value
-
-    def process(self, sc: SparkSession, ctx: ConnectorContext, connector_config: Dict[Any, Any], metrics_collector: MetricsCollector) -> Iterator[DataFrame]:
-        if (ctx.state.get_state("status", default_value=self.not_running_state) == self.running_state):
+        self.queued_state = ExecutionState.QUEUED.value
+
+    def process(
+        self,
+        sc: SparkSession,
+        ctx: ConnectorContext,
+        connector_config: Dict[Any, Any],
+        metrics_collector: MetricsCollector,
+    ) -> Iterator[DataFrame]:
+        if (
+            ctx.state.get_state("status", default_value=self.not_running_state)
+            == self.running_state
+        ):
             logger.info("Connector is already running. Skipping processing.")
             return
 
         ctx.state.put_state("status", self.running_state)
         ctx.state.save_state()
-        self.max_retries = connector_config["source"]["max_retries"] if "max_retries" in connector_config["source"] else MAX_RETRY_COUNT
+        self.max_retries = (
+            connector_config["source"]["max_retries"]
+            if "max_retries" in connector_config["source"]
+            else MAX_RETRY_COUNT
+        )
         self._get_provider(connector_config)
         self._get_objects_to_process(ctx, metrics_collector)
         for res in self._process_objects(sc, ctx, metrics_collector):
             yield res
-
+
         last_run_time = datetime.datetime.now()
         ctx.state.put_state("status", self.not_running_state)
         ctx.state.put_state("last_run_time", last_run_time)
@@ -54,67 +65,99 @@ def get_spark_conf(self, connector_config) -> SparkConf:
         self._get_provider(connector_config)
         if self.provider is not None:
             return self.provider.get_spark_config(connector_config)
-
+
         return SparkConf()
 
     def _get_provider(self, connector_config: Dict[Any, Any]):
-        if connector_config["source"]["type"] == "s3":
+        if connector_config["source"]["type"] == "s3":
             self.provider = S3(connector_config)
         else:
-            ObsrvException(ErrorData("INVALID_PROVIDER", "provider not supported: {}".format(connector_config["source"]["type"])))
-
-    def _get_objects_to_process(self, ctx: ConnectorContext, metrics_collector: MetricsCollector) -> None:
+            ObsrvException(
+                ErrorData(
+                    "INVALID_PROVIDER",
+                    "provider not supported: {}".format(
+                        connector_config["source"]["type"]
+                    ),
+                )
+            )
+
+    def _get_objects_to_process(
+        self, ctx: ConnectorContext, metrics_collector: MetricsCollector
+    ) -> None:
         objects = ctx.state.get_state("to_process", list())
         if ctx.building_block is not None and ctx.env is not None:
             self.dedupe_tag = "{}-{}".format(ctx.building_block, ctx.env)
         else:
-            raise ObsrvException(ErrorData("INVALID_CONTEXT", "building_block or env not found in context"))
-
-        if not len(objects):
-            num_files_discovered = ctx.stats.get_stat('num_files_discovered', 0)
+            raise ObsrvException(
+                ErrorData(
+                    "INVALID_CONTEXT", "building_block or env not found in context"
+                )
+            )
+
+        if not len(objects):
+            num_files_discovered = ctx.stats.get_stat("num_files_discovered", 0)
             objects = self.provider.fetch_objects(ctx, metrics_collector)
             objects = self._exclude_processed_objects(ctx, objects)
             metrics_collector.collect("new_objects_discovered", len(objects))
             ctx.state.put_state("to_process", objects)
             ctx.state.save_state()
             num_files_discovered += len(objects)
-            ctx.stats.put_stat("num_files_discovered", num_files_discovered)
+            ctx.stats.put_stat("num_files_discovered", num_files_discovered)
             ctx.stats.save_stats()
 
         self.objects = objects
 
-    def _process_objects(self, sc: SparkSession, ctx: ConnectorContext, metrics_collector: MetricsCollector) -> Iterator[DataFrame]:
-        num_files_processed = ctx.stats.get_stat('num_files_processed', 0)
+    def _process_objects(
+        self,
+        sc: SparkSession,
+        ctx: ConnectorContext,
+        metrics_collector: MetricsCollector,
+    ) -> Iterator[DataFrame]:
+        num_files_processed = ctx.stats.get_stat("num_files_processed", 0)
         for i in range(0, len(self.objects)):
             obj = self.objects[i]
             obj["start_processing_time"] = time.time()
-            columns = StructType([])
-            df = self.provider.read_object(obj.get("location"), sc=sc, metrics_collector=metrics_collector, file_format=ctx.data_format)
+            df = self.provider.read_object(
+                obj.get("location"),
+                sc=sc,
+                metrics_collector=metrics_collector,
+                file_format=ctx.data_format,
+            )
 
             if df is None:
                 obj["num_of_retries"] += 1
                 if obj["num_of_retries"] < self.max_retries:
                     ctx.state.put_state("to_process", self.objects[i:])
                     ctx.state.save_state()
                 else:
-                    if not self.provider.update_tag(object=obj, tags=[{"key": self.dedupe_tag, "value": self.error_state}], metrics_collector=metrics_collector):
+                    if not self.provider.update_tag(
+                        object=obj,
+                        tags=[{"key": self.dedupe_tag, "value": self.error_state}],
+                        metrics_collector=metrics_collector,
+                    ):
                         break
                 return
             else:
                 df = self._append_custom_meta(sc, df, obj)
-                obj["download_time"] = time.time()-obj.get("start_processing_time")
-                if not self.provider.update_tag(object=obj, tags=[{"key": self.dedupe_tag, "value": self.success_state}], metrics_collector=metrics_collector):
+                obj["download_time"] = time.time() - obj.get("start_processing_time")
+                if not self.provider.update_tag(
+                    object=obj,
+                    tags=[{"key": self.dedupe_tag, "value": self.success_state}],
+                    metrics_collector=metrics_collector,
+                ):
                     break
-                ctx.state.put_state("to_process", self.objects[i+1:])
+                ctx.state.put_state("to_process", self.objects[i + 1 :])
                 ctx.state.save_state()
                 num_files_processed += 1
-                ctx.stats.put_stat("num_files_processed",num_files_processed)
+                ctx.stats.put_stat("num_files_processed", num_files_processed)
                 obj["end_processing_time"] = time.time()
                 yield df
-
+
         ctx.stats.save_stats()
 
-    def _append_custom_meta(self, sc: SparkSession, df: DataFrame, object: ObjectInfo) -> DataFrame:
+    def _append_custom_meta(
+        self, sc: SparkSession, df: DataFrame, object: ObjectInfo
+    ) -> DataFrame:
         addn_meta = {
             "location": object.get("location"),
             "file_size_kb": object.get("file_size_kb"),
@@ -123,7 +166,7 @@ def _append_custom_meta(self, sc: SparkSession, df: DataFrame, object: ObjectInf
             "end_processing_time": object.get("end_processing_time"),
             "file_hash": object.get("file_hash"),
             "num_of_retries": object.get("num_of_retries"),
-            "in_time": object.get("in_time")
+            "in_time": object.get("in_time"),
         }
         df = df.withColumn("_addn_source_meta", lit(json.dumps(addn_meta, default=str)))
         return df
@@ -134,4 +177,4 @@ def _exclude_processed_objects(self, ctx: ConnectorContext, objects):
             if not any(tag["key"] == self.dedupe_tag for tag in obj.get("tags")):
                 to_be_processed.append(obj)
 
-        return to_be_processed
+        return to_be_processed
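
Aside for readers: the _append_custom_meta hunk above serializes per-object metadata into a single _addn_source_meta JSON string column via lit(json.dumps(...)). A minimal standalone sketch of that pattern, assuming only a local SparkSession and made-up metadata values (not the connector's actual wiring):

import json

from pyspark.sql import SparkSession
from pyspark.sql.functions import lit

spark = SparkSession.builder.master("local[1]").appName("addn-meta-demo").getOrCreate()

# Toy frame standing in for whatever provider.read_object() would return.
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

# Hypothetical metadata; the connector builds this dict from ObjectInfo fields.
addn_meta = {"location": "s3://bucket/key.json", "num_of_retries": 0}

# default=str keeps non-JSON-serializable values (e.g. datetimes) from raising.
df = df.withColumn("_addn_source_meta", lit(json.dumps(addn_meta, default=str)))
df.show(truncate=False)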