11import os
22import sys
33import urllib .request
4+ import warnings
45import zipfile
56from os .path import dirname
67
78import joblib
89
# Suppress scikit-learn version warnings: ignore every UserWarning raised
# from modules whose name starts with 'sklearn'.
warnings.filterwarnings('ignore', category=UserWarning, module='sklearn')

# Put the parent of this file's directory at the front of the import search
# path so that modules located there take precedence.
sys.path.insert(0, dirname(dirname(__file__)))

# Module-level holder for the loaded model; starts empty and is assigned by
# classify() once the model has been loaded with joblib.
classifier = None
1115
1216
1317def _ensure_model_exists ():
14- """Get model from latest run or download from release """
18+ """Download and extract sonar_core_1 model if not exists """
1519 model_dir = os .path .expanduser ("~/.underthesea/models" )
1620 model_file = os .path .join (model_dir , "sonar_core_1.pkl" )
1721 labels_file = os .path .join (model_dir , "sonar_core_1_labels.txt" )
@@ -20,26 +24,6 @@ def _ensure_model_exists():
2024 if os .path .exists (model_file ) and os .path .exists (labels_file ):
2125 return model_file , labels_file
2226
23- # Try to get from latest local run first
24- runs_dir = os .path .join (os .path .dirname (__file__ ), ".." , ".." , ".." , ".." , "extensions" , "labs" , "classify_ml" , "sonar_core_1" , "runs" )
25- if os .path .exists (runs_dir ):
26- import glob
27- run_dirs = glob .glob (os .path .join (runs_dir , "[0-9]*_[0-9]*" ))
28- if run_dirs :
29- latest_run = sorted (run_dirs )[- 1 ]
30- latest_model = os .path .join (latest_run , "models" , "model.pkl" )
31- latest_labels = os .path .join (latest_run , "models" , "labels.txt" )
32-
33- if os .path .exists (latest_model ) and os .path .exists (latest_labels ):
34- print (f"Using model from latest local run: { latest_run } " )
35- os .makedirs (model_dir , exist_ok = True )
36-
37- # Copy from latest run
38- import shutil
39- shutil .copy2 (latest_model , model_file )
40- shutil .copy2 (latest_labels , labels_file )
41- return model_file , labels_file
42-
4327 print ("Downloading Sonar Core 1 model..." )
4428
4529 # Create directories
@@ -75,7 +59,7 @@ def _ensure_model_exists():
7559
def _load_labels(labels_file):
    """Load label mapping from file.

    Args:
        labels_file (str): Path to a UTF-8 text file with one label per line.

    Returns:
        list: One entry per line of the file, in file order, with leading
            and trailing whitespace removed. Blank lines are kept as empty
            strings so that positions in the list match line positions in
            the file.
    """
    # Iterate the file handle directly; this yields the same lines as
    # readlines() without first building an intermediate list.
    with open(labels_file, encoding='utf-8') as f:
        return [line.strip() for line in f]
8165
@@ -87,7 +71,7 @@ def classify(text):
8771 text (str): Vietnamese text to classify
8872
8973 Returns:
90- list: List containing the predicted category (for compatibility with underthesea API)
74+ str: Predicted category
9175 """
9276 global classifier
9377
@@ -96,9 +80,9 @@ def classify(text):
9680 classifier = joblib .load (model_file )
9781 classifier .labels = _load_labels (labels_file )
9882
99- # Make prediction
83+ # Make prediction and convert to plain string
10084 prediction = classifier .predict ([text ])[0 ]
101- return [ prediction ]
85+ return str ( prediction )
10286
10387
10488def classify_with_confidence (text ):
@@ -121,13 +105,16 @@ def classify_with_confidence(text):
121105 prediction = classifier .predict ([text ])[0 ]
122106 probabilities = classifier .predict_proba ([text ])[0 ]
123107
124- # Get top 3 predictions with probabilities
108+ # Get top 3 predictions with probabilities, convert to plain strings
125109 classes = classifier .classes_
126110 prob_dict = dict (zip (classes , probabilities ))
127111 top_predictions = sorted (prob_dict .items (), key = lambda x : x [1 ], reverse = True )[:3 ]
128112
113+ # Convert numpy strings to plain strings
114+ top_predictions = [(str (label ), float (prob )) for label , prob in top_predictions ]
115+
129116 return {
130- 'prediction' : prediction ,
131- 'confidence' : top_predictions [0 ][1 ],
117+ 'prediction' : str ( prediction ) ,
118+ 'confidence' : float ( top_predictions [0 ][1 ]) ,
132119 'top_3' : top_predictions
133120 }
0 commit comments