Skip to content

Commit f94bb05

Browse files
committed
updating lock
1 parent 7075572 commit f94bb05

File tree

3 files changed

+241
-156
lines changed

3 files changed

+241
-156
lines changed

nlp_scripts/model_prediction/llm/model_prediction.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
2+
import json
23
import redis
34
import pandas as pd
45

6+
from box import Box
57
from dataclasses import make_dataclass, field, dataclass
68
from pydantic import BaseModel, Field, create_model
79
from concurrent.futures import ThreadPoolExecutor
@@ -20,13 +22,14 @@
2022

2123
class WidgetsMappings:
2224

23-
def __init__(self, selected_widgets: pd.DataFrame):
25+
def __init__(self, selected_widgets: list):
    """Keep the selected widgets, derive the widget-type lookup, build mappings.

    Args:
        selected_widgets: iterable of records exposing a ``widget_id``
            attribute; the same records are later walked by
            ``create_mappings``.
    """
    self.mappings = {}
    self.selected_widgets = selected_widgets

    # One string field per distinct widget id, each defaulting to its own
    # name, so ``self.WidgetTypes.<widget_id>`` evaluates to that id.
    unique_widget_ids = list(set([record.widget_id for record in self.selected_widgets]))
    type_fields = [
        (widget_name, str, field(default=widget_name))
        for widget_name in unique_widget_ids
    ]
    self.WidgetTypes = make_dataclass('WidgetTypes', type_fields)

    self.create_mappings()
3235

@@ -75,7 +78,7 @@ def __process_widget_scale(self, key: str, properties: dict, version: int):
7578

7679
def create_mappings(self):
7780

78-
for _, d in self.selected_widgets.iterrows():
81+
for d in self.selected_widgets:
7982

8083
if d.widget_id == self.WidgetTypes.matrix1dWidget:
8184
self.__process_widget_1d(d.key, d.properties, d.version)
@@ -106,7 +109,7 @@ class Schema:
106109
properties: dict
107110
pyd_class: BaseModel
108111

109-
def __init__(self, selected_widgets: pd.DataFrame, model_family: str = "openai"):
112+
def __init__(self, selected_widgets: list, model_family: str = "openai"):
110113

111114
self.schemas = {}
112115
self.max_widget_length = 50
@@ -203,7 +206,7 @@ def __process_widget_scale(self, key: str, properties: dict, class_name: str):
203206

204207
def create_schemas(self):
205208

206-
for _, d in self.selected_widgets.iterrows():
209+
for d in self.selected_widgets:
207210

208211
if d.widget_id == self.mappings_instance.WidgetTypes.matrix1dWidget:
209212
self.__process_widget_1d(d.key, d.properties, "Pillar")
@@ -229,37 +232,46 @@ class LLMTagsPrediction:
229232
AVAILABLE_WIDGETS: list = ["matrix2dWidget", "matrix1dWidget"] # it'll be extended to all widget types
230233
AVAILABLE_FOUNDATION_MODELS: list = ["bedrock", "openai"]
231234

232-
233235
def __init__(self, analysis_framework_id: int, model_family: str = "openai"):
    """Load the widgets of an analysis framework and wrap them in a schema.

    Args:
        analysis_framework_id: id of the analysis framework to load.
        model_family: foundation-model family; must be one of
            ``AVAILABLE_FOUNDATION_MODELS``.
    """
    self.af_id = analysis_framework_id
    self.model_family = model_family

    family_is_supported = self.model_family in self.AVAILABLE_FOUNDATION_MODELS
    assert family_is_supported, ValueError("Selected model family not implemented")

    # The widget lookup opens its own connections on demand, so no cursor
    # is created here.
    framework_widgets = self.__get_framework_widgets()
    self.selected_widgets = framework_widgets
    self.widgets = WidgetSchema(framework_widgets, self.model_family)
243245

244-
245246
def __get_deep_db_connection(self):
    """Return the object produced by a fresh ``connect_db()`` call."""
    db_handle = connect_db()
    return db_handle
247248

248249
def __get_elasticache(self, port: int = 6379):
    """Create a Redis client for the host named by the ``REDIS_HOST`` setting.

    Args:
        port: TCP port of the Redis endpoint.

    Returns:
        A ``redis.Redis`` client configured with ``decode_responses=True``.
    """
    connection_options = {
        "host": env("REDIS_HOST"),
        "port": port,
        "decode_responses": True,
    }
    return redis.Redis(**connection_options)
255251

256-
def __get_framework_widgets(self):
252+
def __get_framework_widgets(self, expire_time: int = 1200):
    """Return the supported widgets of the analysis framework ``self.af_id``.

    The widget list is looked up in Redis first (key ``af_id:<id>``). On a
    cache miss it is read from the database, reduced to the widget types
    listed in ``AVAILABLE_WIDGETS`` and written back to Redis, so repeated
    requests for the same framework avoid a new database connection and
    query.

    Args:
        expire_time: lifetime of the cache entry in seconds (default 1200,
            i.e. 20 minutes).

    Returns:
        A list of ``Box`` records, one per supported widget.

    Raises:
        ValueError: if the database returns no rows for the framework.
    """
    cache_key = f"af_id:{self.af_id}"

    self.redis = self.__get_elasticache()
    cached_afw = self.redis.get(cache_key)
    if cached_afw:
        # The cached payload was filtered before it was stored.
        return [Box(element) for element in json.loads(cached_afw)]

    self.cursor = self.__get_deep_db_connection().cursor
    # NOTE(review): the query text is assembled with str.format; this is only
    # safe while af_id is a trusted integer - confirm, or switch to a
    # parameterized query.
    self.cursor.execute(af_widget_by_id.format(self.af_id))
    fetch = self.cursor.fetchall()
    if not fetch:
        raise ValueError(f"Not possible to retrieve framework widgets: {self.af_id}")

    # Column names are the same for every row, so resolve them once.
    columns = [c.name for c in self.cursor.description]
    afw = [Box(dict(zip(columns, row))) for row in fetch]
    # Filter on each record's own widget_id; the list itself has no such
    # attribute, so testing ``afw.widget_id`` raised AttributeError.
    afw = [element for element in afw if element.widget_id in self.AVAILABLE_WIDGETS]

    self.redis.set(name=cache_key, ex=expire_time, value=json.dumps(afw))
    return afw
265277

0 commit comments

Comments
 (0)