11import os
2+ import json
23import redis
34import pandas as pd
45
6+ from box import Box
57from dataclasses import make_dataclass , field , dataclass
68from pydantic import BaseModel , Field , create_model
79from concurrent .futures import ThreadPoolExecutor
2022
2123class WidgetsMappings :
2224
23- def __init__ (self , selected_widgets : pd . DataFrame ):
25+ def __init__ (self , selected_widgets : list ):
2426
2527 self .mappings = {}
2628 self .selected_widgets = selected_widgets
2729 self .WidgetTypes = make_dataclass (
2830 'WidgetTypes' ,
29- [(widget_id , str , field (default = widget_id ))for widget_id in selected_widgets .widget_id ]
31+ [(widget , str , field (default = widget ))for widget in list (set ([element .widget_id
32+ for element in self .selected_widgets ]))]
3033 )
3134 self .create_mappings ()
3235
@@ -75,7 +78,7 @@ def __process_widget_scale(self, key: str, properties: dict, version: int):
7578
7679 def create_mappings (self ):
7780
78- for _ , d in self .selected_widgets . iterrows () :
81+ for d in self .selected_widgets :
7982
8083 if d .widget_id == self .WidgetTypes .matrix1dWidget :
8184 self .__process_widget_1d (d .key , d .properties , d .version )
@@ -106,7 +109,7 @@ class Schema:
106109 properties : dict
107110 pyd_class : BaseModel
108111
109- def __init__ (self , selected_widgets : pd . DataFrame , model_family : str = "openai" ):
112+ def __init__ (self , selected_widgets : list , model_family : str = "openai" ):
110113
111114 self .schemas = {}
112115 self .max_widget_length = 50
@@ -203,7 +206,7 @@ def __process_widget_scale(self, key: str, properties: dict, class_name: str):
203206
204207 def create_schemas (self ):
205208
206- for _ , d in self .selected_widgets . iterrows () :
209+ for d in self .selected_widgets :
207210
208211 if d .widget_id == self .mappings_instance .WidgetTypes .matrix1dWidget :
209212 self .__process_widget_1d (d .key , d .properties , "Pillar" )
@@ -229,37 +232,46 @@ class LLMTagsPrediction:
229232 AVAILABLE_WIDGETS : list = ["matrix2dWidget" , "matrix1dWidget" ] # it'll be extended to all widget types
230233 AVAILABLE_FOUNDATION_MODELS : list = ["bedrock" , "openai" ]
231234
232-
233235 def __init__ (self , analysis_framework_id : int , model_family : str = "openai" ):
234236
235237 self .af_id = analysis_framework_id
236238 self .model_family = model_family
237239
238240 assert self .model_family in self .AVAILABLE_FOUNDATION_MODELS , ValueError ("Selected model family not implemented" )
239-
240- self .cursor = self .__get_deep_db_connection ().cursor
241+
242+ # self.cursor = self.__get_deep_db_connection().cursor
241243 self .selected_widgets = self .__get_framework_widgets ()
242244 self .widgets = WidgetSchema (self .selected_widgets , self .model_family )
243245
244-
245246 def __get_deep_db_connection (self ):
246247 return connect_db ()
247248
248249 def __get_elasticache (self , port : int = 6379 ):
249-
250- return redis .StrictRedis (
251- host = env ("REDIS_HOST" ),
252- port = port ,
253- decode_responses = True
254- )
250+ return redis .Redis (host = env ("REDIS_HOST" ), port = port , decode_responses = True )
255251
256- def __get_framework_widgets (self ):
252+ def __get_framework_widgets (self , expire_time : int = 1200 ):
257253
258- #self.redis = self.__get_elasticache()
259- #afw = self.redis.get(f"af_id:{self.af_id}")
260- self .cursor .execute (af_widget_by_id .format (self .af_id ))
261- afw = pd .DataFrame (self .cursor .fetchall (), columns = [c .name for c in self .cursor .description ])
262- afw = afw [afw .widget_id .isin (self .AVAILABLE_WIDGETS )]
254+ # let's get or save the af_id widget original data on elasticache for 20 minutes
255+ # avoiding multiple db connection and executions on the same analysis framework id.
256+
257+ self .redis = self .__get_elasticache ()
258+ cached_afw = self .redis .get (f"af_id:{ self .af_id } " )
259+ if cached_afw :
260+ afw = [Box (element ) for element in json .loads (cached_afw )]
261+ else :
262+ self .cursor = self .__get_deep_db_connection ().cursor
263+ self .cursor .execute (af_widget_by_id .format (self .af_id ))
264+ fetch = self .cursor .fetchall ()
265+ if not fetch :
266+ raise ValueError (f"Not possible to retrieve framework widgets: { self .af_id } " )
267+ else :
268+ afw = [Box (dict (zip ([c .name for c in self .cursor .description ], row ))) for row in fetch ]
269+ afw = [element for element in afw if afw .widget_id in self .AVAILABLE_WIDGETS ]
270+ self .redis .set (
271+ name = f"af_id:{ self .af_id } " ,
272+ ex = expire_time ,
273+ value = json .dumps (afw )
274+ )
263275
264276 return afw
265277
0 commit comments