From 9ba8c353d4eb4a879a65cd5987b273cb3b1e8454 Mon Sep 17 00:00:00 2001
From: Arman Aminian
Date: Fri, 11 Aug 2023 12:36:21 +0330
Subject: [PATCH 1/5] feat: mlflow added to requirements file.

---
 inference_server/requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/inference_server/requirements.txt b/inference_server/requirements.txt
index c10ece7..1196d1e 100644
--- a/inference_server/requirements.txt
+++ b/inference_server/requirements.txt
@@ -4,3 +4,4 @@ torch --index-url https://download.pytorch.org/whl/cpu
 torchvision --index-url https://download.pytorch.org/whl/cpu
 transformers
 qdrant_client
+mlflow

From 8c6c8724ab83c237c301462d9036d7e5163480ca Mon Sep 17 00:00:00 2001
From: Arman Aminian
Date: Fri, 11 Aug 2023 12:53:21 +0330
Subject: [PATCH 2/5] feat: loading best model name from mlflow added

---
 inference_server/main.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/inference_server/main.py b/inference_server/main.py
index d1bfb07..97da02a 100644
--- a/inference_server/main.py
+++ b/inference_server/main.py
@@ -6,12 +6,25 @@ from transformers import AutoModel, AutoTokenizer
 from fastapi.openapi.utils import get_openapi
 from fastapi.responses import JSONResponse
 
+from mlflow.tracking import MlflowClient
+from mlflow.entities import ViewType
 
 client = QdrantClient("https://qdrant-mlsd-video-search.darkube.app", port=443)
 
 app = FastAPI()
 
+MLFLOW_TRACKING_URI = "https://mlflow-mlsd-video-search.darkube.app/"
+client = MlflowClient(tracking_uri=MLFLOW_TRACKING_URI)
+runs = client.search_runs(
+    experiment_ids='2',
+    filter_string="metrics.acc_at_10 >0.2",
+    run_view_type=ViewType.ACTIVE_ONLY,
+    max_results=5,
+    order_by=["metrics.acc_at_10 DESC"]
+)
+TEXT_ENCODER_MODEL = runs[0].data.tags['text_model']
+
 text_encoder = AutoModel.from_pretrained(os.environ['TEXT_ENCODER_MODEL'])
 text_tokenizer = AutoTokenizer.from_pretrained(os.environ['TEXT_ENCODER_MODEL'])

From ec94aa034bc92bcf3d394ef315550e077b556369 Mon Sep 17 00:00:00 2001
From: Arman Aminian
Date: Fri, 11 Aug 2023 12:54:01 +0330
Subject: [PATCH 3/5] feat: loading experiment id from mlflow added

---
 inference_server/main.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/inference_server/main.py b/inference_server/main.py
index 97da02a..1254637 100644
--- a/inference_server/main.py
+++ b/inference_server/main.py
@@ -9,15 +9,16 @@ from mlflow.tracking import MlflowClient
 from mlflow.entities import ViewType
 
-
 client = QdrantClient("https://qdrant-mlsd-video-search.darkube.app", port=443)
 
 app = FastAPI()
 
 MLFLOW_TRACKING_URI = "https://mlflow-mlsd-video-search.darkube.app/"
 client = MlflowClient(tracking_uri=MLFLOW_TRACKING_URI)
+experiments = client.search_experiments()
+exp_id = list(filter(lambda e: e.name == 'clip-farsi', experiments))[0].experiment_id
 runs = client.search_runs(
-    experiment_ids='2',
+    experiment_ids=exp_id,
     filter_string="metrics.acc_at_10 >0.2",
     run_view_type=ViewType.ACTIVE_ONLY,
     max_results=5,
     order_by=["metrics.acc_at_10 DESC"]
 )
@@ -31,8 +32,9 @@
 
 @app.get("/{video_name}/")
 async def query(
-    video_name: str = Path(..., title="Video Name", description="Name of the video or 'ALL' to search in all videos"),
-    search_entry: str = Query(..., title="Search Entry", description="The search entry for text embedding"),
+    video_name: str = Path(..., title="Video Name",
+                           description="Name of the video or 'ALL' to search in all videos"),
+    search_entry: str = Query(..., title="Search Entry", description="The search entry for text embedding"),
 ):
     """
     Query for video frames based on the provided text search entry.
@@ -98,4 +100,3 @@ async def get_open_api_endpoint():
 @app.get("/docs", include_in_schema=False)
 async def get_documentation():
     return JSONResponse(content=app.openapi())
-

From d5f3a664f3ffdfc55f14b5c99b97c61274aa0389 Mon Sep 17 00:00:00 2001
From: Arman Aminian
Date: Fri, 11 Aug 2023 12:55:25 +0330
Subject: [PATCH 4/5] feat: loading models using mlflow registry

---
 inference_server/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/inference_server/main.py b/inference_server/main.py
index 1254637..e9e13a4 100644
--- a/inference_server/main.py
+++ b/inference_server/main.py
@@ -26,8 +26,8 @@
 )
 TEXT_ENCODER_MODEL = runs[0].data.tags['text_model']
 
-text_encoder = AutoModel.from_pretrained(os.environ['TEXT_ENCODER_MODEL'])
-text_tokenizer = AutoTokenizer.from_pretrained(os.environ['TEXT_ENCODER_MODEL'])
+text_encoder = AutoModel.from_pretrained(TEXT_ENCODER_MODEL)
+text_tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER_MODEL)
 
 
 @app.get("/{video_name}/")

From f1cb88bdd1f8e0794909890256ea3a923d8d13bc Mon Sep 17 00:00:00 2001
From: Arman Aminian
Date: Fri, 11 Aug 2023 12:59:34 +0330
Subject: [PATCH 5/5] feat: loading vision model from mlflow added.

---
 video_database/video_to_db.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/video_database/video_to_db.py b/video_database/video_to_db.py
index 8a1ef18..456a0c4 100644
--- a/video_database/video_to_db.py
+++ b/video_database/video_to_db.py
@@ -9,6 +9,8 @@
 import torchvision.transforms as transforms
 from qdrant_client import QdrantClient
 from qdrant_client.models import Record
+from mlflow.tracking import MlflowClient
+from mlflow.entities import ViewType
 
 
 def image_to_string(image):
@@ -25,7 +27,20 @@ def image_to_string(image):
     fps = int(video.get(cv2.CAP_PROP_FPS))
     frame_interval = fps * 5 # capture a frame every 5 seconds
 
-    image_encoder = CLIPVisionModel.from_pretrained('arman-aminian/clip-farsi-vision').eval()
+    MLFLOW_TRACKING_URI = "https://mlflow-mlsd-video-search.darkube.app/"
+    client = MlflowClient(tracking_uri=MLFLOW_TRACKING_URI)
+    experiments = client.search_experiments()
+    exp_id = list(filter(lambda e: e.name == 'clip-farsi', experiments))[0].experiment_id
+    runs = client.search_runs(
+        experiment_ids=exp_id,
+        filter_string="metrics.acc_at_10 >0.2",
+        run_view_type=ViewType.ACTIVE_ONLY,
+        max_results=5,
+        order_by=["metrics.acc_at_10 DESC"]
+    )
+    VISION_ENCODER_MODEL = runs[0].data.tags['vision_model']
+
+    image_encoder = CLIPVisionModel.from_pretrained(VISION_ENCODER_MODEL).eval()
 
     insert_data = []
     currentframe = 0
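Note on the run-selection logic these patches introduce: the same MLflow query can be exercised outside the server to confirm that the tracking server returns a usable run before deploying. The sketch below is illustrative only, not part of the patch series; it assumes the same tracking URI, experiment name ('clip-farsi'), metric (acc_at_10), and run tags (text_model, vision_model) that the patches use, and it passes the experiment ID wrapped in a list because MlflowClient.search_runs expects a list of experiment IDs.

# Standalone sketch (not part of the patches): verify that the MLflow query
# used by the server actually yields a run with the expected model tags.
from mlflow.tracking import MlflowClient
from mlflow.entities import ViewType
from transformers import AutoModel, AutoTokenizer

MLFLOW_TRACKING_URI = "https://mlflow-mlsd-video-search.darkube.app/"
client = MlflowClient(tracking_uri=MLFLOW_TRACKING_URI)

# Resolve the 'clip-farsi' experiment by name instead of hard-coding its ID.
experiments = client.search_experiments()
exp_id = next(e.experiment_id for e in experiments if e.name == "clip-farsi")

# Best run by acc_at_10, restricted to runs above the 0.2 threshold.
runs = client.search_runs(
    experiment_ids=[exp_id],  # search_runs takes a list of experiment IDs
    filter_string="metrics.acc_at_10 > 0.2",
    run_view_type=ViewType.ACTIVE_ONLY,
    max_results=5,
    order_by=["metrics.acc_at_10 DESC"],
)
assert runs, "no run matched the filter; the server would fail at startup"

best = runs[0]
print("best run:", best.info.run_id, "acc_at_10 =", best.data.metrics["acc_at_10"])

# The patches read the model names from run tags; fail fast if they are missing.
text_model = best.data.tags["text_model"]
vision_model = best.data.tags["vision_model"]
print("text encoder:", text_model, "| vision encoder:", vision_model)

# Optional: confirm the selected text encoder can actually be loaded.
text_encoder = AutoModel.from_pretrained(text_model)
text_tokenizer = AutoTokenizer.from_pretrained(text_model)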