@@ -41,12 +41,16 @@ async def message_receiver(self, consumer: AIOKafkaConsumer, timeout: int = 3):
                    msg = await asyncio.wait_for(consumer.getone(), timeout=timeout)
                    yield msg.value
                except asyncio.TimeoutError:
-                    print_text(f"No message received within the timeout {timeout} seconds")
+                    print_text(
+                        f"No message received within the timeout {timeout} seconds"
+                    )
                    break
        finally:
            await consumer.stop()

-    async def message_sender(self, producer: AIOKafkaProducer, data: Iterable, topic: str):
+    async def message_sender(
+        self, producer: AIOKafkaProducer, data: Iterable, topic: str
+    ):
        await producer.start()
        try:
            for record in data:
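For context (not part of the diff): the receive loop being reformatted above pulls one record at a time and stops when nothing arrives within the timeout. A minimal standalone sketch of that pattern with aiokafka, assuming a placeholder broker at localhost:9092 and a hypothetical topic name:

import asyncio
import json

from aiokafka import AIOKafkaConsumer


async def consume_with_timeout(topic: str, timeout: int = 3):
    # Sketch of the same pattern: read single messages until a timeout expires.
    consumer = AIOKafkaConsumer(
        topic,
        bootstrap_servers="localhost:9092",  # placeholder broker address
        value_deserializer=lambda v: json.loads(v.decode("utf-8")),
        auto_offset_reset="earliest",
    )
    await consumer.start()
    try:
        while True:
            try:
                msg = await asyncio.wait_for(consumer.getone(), timeout=timeout)
                print(msg.value)
            except asyncio.TimeoutError:
                break  # no message within `timeout` seconds, stop consuming
    finally:
        await consumer.stop()


# Example invocation (requires a running broker):
# asyncio.run(consume_with_timeout("example-topic"))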
@@ -70,20 +74,22 @@ async def get_data_batch(self, batch_size: Optional[int]) -> InternalDataFrame:
        consumer = AIOKafkaConsumer(
            self.kafka_input_topic,
            bootstrap_servers=self.kafka_bootstrap_servers,
-            value_deserializer=lambda v: json.loads(v.decode('utf-8')),
-            auto_offset_reset='earliest',
-            group_id='adala-consumer-group'  # TODO: make it configurable based on the environment
+            value_deserializer=lambda v: json.loads(v.decode("utf-8")),
+            auto_offset_reset="earliest",
+            group_id="adala-consumer-group",  # TODO: make it configurable based on the environment
        )

        data_stream = self.message_receiver(consumer)
        batch = await self.get_next_batch(data_stream, batch_size)
-        logger.info(f"Received a batch of {len(batch)} records from Kafka topic {self.kafka_input_topic}")
+        logger.info(
+            f"Received a batch of {len(batch)} records from Kafka topic {self.kafka_input_topic}"
+        )
        return InternalDataFrame(batch)

    async def set_predictions(self, predictions: InternalDataFrame):
        producer = AIOKafkaProducer(
            bootstrap_servers=self.kafka_bootstrap_servers,
-            value_serializer=lambda v: json.dumps(v).encode('utf-8')
+            value_serializer=lambda v: json.dumps(v).encode("utf-8"),
        )
        predictions_iter = (r.to_dict() for _, r in predictions.iterrows())
        await self.message_sender(producer, predictions_iter, self.kafka_output_topic)
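A quick aside (not part of the diff): the producer's value_serializer and the consumer's value_deserializer in this file are inverses of each other, so any JSON-serializable record survives the round trip. A minimal check, independent of Kafka, using an illustrative payload:

import json


# The same (de)serialization logic configured on the producer and consumer above.
def serialize(value) -> bytes:
    return json.dumps(value).encode("utf-8")


def deserialize(raw: bytes):
    return json.loads(raw.decode("utf-8"))


record = {"index": 0, "text": "hello"}  # illustrative payload
assert deserialize(serialize(record)) == record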
@@ -109,7 +115,7 @@ def _iter_csv_local(self, csv_file_path):
        Read data from the CSV file and push it to the kafka topic.
        """

-        with open(csv_file_path, 'r') as csv_file:
+        with open(csv_file_path, "r") as csv_file:
            csv_reader = DictReader(csv_file)
            for row in csv_reader:
                yield row
@@ -120,9 +126,9 @@ def _iter_csv_s3(self, s3_uri):
        """
        # Assuming s3_uri format is "s3://bucket-name/path/to/file.csv"
        bucket_name, key = s3_uri.replace("s3://", "").split("/", 1)
-        s3 = boto3.client('s3')
+        s3 = boto3.client("s3")
        obj = s3.get_object(Bucket=bucket_name, Key=key)
-        data = obj['Body'].read().decode('utf-8')
+        data = obj["Body"].read().decode("utf-8")
        csv_reader = DictReader(StringIO(data))
        for row in csv_reader:
            yield row
@@ -140,7 +146,7 @@ async def initialize(self):

        producer = AIOKafkaProducer(
            bootstrap_servers=self.kafka_bootstrap_servers,
-            value_serializer=lambda v: json.dumps(v).encode('utf-8')
+            value_serializer=lambda v: json.dumps(v).encode("utf-8"),
        )

        await self.message_sender(producer, csv_reader, self.kafka_input_topic)
@@ -153,53 +159,79 @@ async def finalize(self):
        consumer = AIOKafkaConsumer(
            self.kafka_output_topic,
            bootstrap_servers=self.kafka_bootstrap_servers,
-            value_deserializer=lambda v: json.loads(v.decode('utf-8')),
-            auto_offset_reset='earliest',
-            group_id='consumer-group-output-topic'  # TODO: make it configurable based on the environment
+            value_deserializer=lambda v: json.loads(v.decode("utf-8")),
+            auto_offset_reset="earliest",
+            group_id="consumer-group-output-topic",  # TODO: make it configurable based on the environment
        )

        data_stream = self.message_receiver(consumer)

        if self.output_file.startswith("s3://"):
-            await self._write_to_s3(self.output_file, self.error_file, data_stream, self.pass_through_columns)
+            await self._write_to_s3(
+                self.output_file,
+                self.error_file,
+                data_stream,
+                self.pass_through_columns,
+            )
        else:
-            await self._write_to_local(self.output_file, self.error_file, data_stream, self.pass_through_columns)
-
-    async def _write_to_csv_fileobj(self, fileobj, error_fileobj, data_stream, column_names):
+            await self._write_to_local(
+                self.output_file,
+                self.error_file,
+                data_stream,
+                self.pass_through_columns,
+            )
+
+    async def _write_to_csv_fileobj(
+        self, fileobj, error_fileobj, data_stream, column_names
+    ):
        csv_writer, error_csv_writer = None, None
-        error_columns = ['index', 'message', 'details']
+        error_columns = ["index", "message", "details"]
        while True:
            try:
                record = await anext(data_stream)
-                if record.get('error') == True:
+                if record.get("error") == True:
                    logger.error(f"Error occurred while processing record: {record}")
                    if error_csv_writer is None:
-                        error_csv_writer = DictWriter(error_fileobj, fieldnames=error_columns)
+                        error_csv_writer = DictWriter(
+                            error_fileobj, fieldnames=error_columns
+                        )
                        error_csv_writer.writeheader()
-                    error_csv_writer.writerow({k: record.get(k, '') for k in error_columns})
+                    error_csv_writer.writerow(
+                        {k: record.get(k, "") for k in error_columns}
+                    )
                else:
                    if csv_writer is None:
                        if column_names is None:
                            column_names = list(record.keys())
                        csv_writer = DictWriter(fileobj, fieldnames=column_names)
                        csv_writer.writeheader()
-                    csv_writer.writerow({k: record.get(k, '') for k in column_names})
+                    csv_writer.writerow({k: record.get(k, "") for k in column_names})
            except StopAsyncIteration:
                break

-    async def _write_to_local(self, file_path: str, error_file_path: str, data_stream, column_names):
-        with open(file_path, 'w') as csv_file, open(error_file_path, 'w') as error_file:
-            await self._write_to_csv_fileobj(csv_file, error_file, data_stream, column_names)
-
-    async def _write_to_s3(self, s3_uri: str, s3_uri_errors: str, data_stream, column_names):
+    async def _write_to_local(
+        self, file_path: str, error_file_path: str, data_stream, column_names
+    ):
+        with open(file_path, "w") as csv_file, open(error_file_path, "w") as error_file:
+            await self._write_to_csv_fileobj(
+                csv_file, error_file, data_stream, column_names
+            )
+
+    async def _write_to_s3(
+        self, s3_uri: str, s3_uri_errors: str, data_stream, column_names
+    ):
        # Assuming s3_uri format is "s3://bucket-name/path/to/file.csv"
        bucket_name, key = s3_uri.replace("s3://", "").split("/", 1)
        error_bucket_name, error_key = s3_uri_errors.replace("s3://", "").split("/", 1)
-        s3 = boto3.client('s3')
+        s3 = boto3.client("s3")
        with StringIO() as csv_file, StringIO() as error_file:
-            await self._write_to_csv_fileobj(csv_file, error_file, data_stream, column_names)
+            await self._write_to_csv_fileobj(
+                csv_file, error_file, data_stream, column_names
+            )
            s3.put_object(Bucket=bucket_name, Key=key, Body=csv_file.getvalue())
-            s3.put_object(Bucket=error_bucket_name, Key=error_key, Body=error_file.getvalue())
+            s3.put_object(
+                Bucket=error_bucket_name, Key=error_key, Body=error_file.getvalue()
+            )

    async def get_feedback(
        self,
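Also for context (not part of the diff): _write_to_s3 builds the whole CSV in an in-memory StringIO buffer and uploads it in one put_object call. A reduced sketch of that pattern, assuming AWS credentials are configured and using placeholder bucket/key names and illustrative rows:

from csv import DictWriter
from io import StringIO

import boto3

rows = [{"index": 0, "output": "ok"}]  # illustrative records

with StringIO() as csv_file:
    writer = DictWriter(csv_file, fieldnames=["index", "output"])
    writer.writeheader()
    writer.writerows(rows)
    # put_object accepts a str/bytes body, so the buffer is uploaded directly
    boto3.client("s3").put_object(
        Bucket="example-bucket",  # placeholder
        Key="path/to/output.csv",  # placeholder
        Body=csv_file.getvalue(),
    )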
@@ -214,4 +246,3 @@ async def restore(self):

    async def save(self):
        raise NotImplementedError("Save is not supported in Kafka environment")
-