Skip to content

Commit 3e8ea4d

Browse files
authored
GH-731: Update output format of model sonar_core_1 (#815)
1 parent 31470ac commit 3e8ea4d

File tree

3 files changed

+22
-35
lines changed

3 files changed

+22
-35
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,11 +14,11 @@ dependencies = [
1414
"flake8>=7.3.0",
1515
'Click>=6.0',
1616
'python-crfsuite>=0.9.6',
17-
'nltk==3.8',
17+
'nltk>=3.8',
1818
'tqdm',
1919
'requests',
2020
'joblib',
21-
'scikit-learn==1.6.1',
21+
'scikit-learn>=1.6.1',
2222
'PyYAML',
2323
'underthesea_core==1.0.5'
2424
]

tests/pipeline/classification/test_sonar_core_1.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -12,30 +12,30 @@ def test_classify_null_cases(self):
1212

1313
def test_classify_simple_case(self):
1414
text = u"HLV ngoại đòi gần tỷ mỗi tháng dẫn dắt tuyển Việt Nam"
15-
actual = classify(text)[0]
15+
actual = classify(text)
1616
expected = "the_thao"
1717
self.assertEqual(actual, expected)
1818

1919
def test_classify_sports(self):
2020
text = u"Việt Nam giành chiến thắng 3-0 trước Thái Lan trong trận bán kết"
21-
actual = classify(text)[0]
21+
actual = classify(text)
2222
expected = "the_thao"
2323
self.assertEqual(actual, expected)
2424

2525
def test_classify_technology(self):
2626
text = u"Apple ra mắt iPhone mới với nhiều tính năng đột phá"
27-
actual = classify(text)[0]
27+
actual = classify(text)
2828
expected = "vi_tinh"
2929
self.assertEqual(actual, expected)
3030

3131
def test_classify_health(self):
3232
text = u"Phát hiện vaccine mới chống lại virus corona"
33-
actual = classify(text)[0]
33+
actual = classify(text)
3434
expected = "suc_khoe"
3535
self.assertEqual(actual, expected)
3636

3737
def test_classify_business(self):
3838
text = u"Thị trường chứng khoán tăng điểm mạnh trong phiên sáng nay"
39-
actual = classify(text)[0]
39+
actual = classify(text)
4040
expected = "kinh_doanh"
4141
self.assertEqual(actual, expected)

underthesea/pipeline/classification/sonar_core_1/__init__.py

Lines changed: 15 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,21 @@
11
import os
22
import sys
33
import urllib.request
4+
import warnings
45
import zipfile
56
from os.path import dirname
67

78
import joblib
89

10+
# Suppress scikit-learn version warnings
11+
warnings.filterwarnings('ignore', category=UserWarning, module='sklearn')
12+
913
sys.path.insert(0, dirname(dirname(__file__)))
1014
classifier = None
1115

1216

1317
def _ensure_model_exists():
14-
"""Get model from latest run or download from release"""
18+
"""Download and extract sonar_core_1 model if not exists"""
1519
model_dir = os.path.expanduser("~/.underthesea/models")
1620
model_file = os.path.join(model_dir, "sonar_core_1.pkl")
1721
labels_file = os.path.join(model_dir, "sonar_core_1_labels.txt")
@@ -20,26 +24,6 @@ def _ensure_model_exists():
2024
if os.path.exists(model_file) and os.path.exists(labels_file):
2125
return model_file, labels_file
2226

23-
# Try to get from latest local run first
24-
runs_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "extensions", "labs", "classify_ml", "sonar_core_1", "runs")
25-
if os.path.exists(runs_dir):
26-
import glob
27-
run_dirs = glob.glob(os.path.join(runs_dir, "[0-9]*_[0-9]*"))
28-
if run_dirs:
29-
latest_run = sorted(run_dirs)[-1]
30-
latest_model = os.path.join(latest_run, "models", "model.pkl")
31-
latest_labels = os.path.join(latest_run, "models", "labels.txt")
32-
33-
if os.path.exists(latest_model) and os.path.exists(latest_labels):
34-
print(f"Using model from latest local run: {latest_run}")
35-
os.makedirs(model_dir, exist_ok=True)
36-
37-
# Copy from latest run
38-
import shutil
39-
shutil.copy2(latest_model, model_file)
40-
shutil.copy2(latest_labels, labels_file)
41-
return model_file, labels_file
42-
4327
print("Downloading Sonar Core 1 model...")
4428

4529
# Create directories
@@ -75,7 +59,7 @@ def _ensure_model_exists():
7559

7660
def _load_labels(labels_file):
7761
"""Load label mapping from file"""
78-
with open(labels_file, 'r', encoding='utf-8') as f:
62+
with open(labels_file, encoding='utf-8') as f:
7963
labels = [line.strip() for line in f.readlines()]
8064
return labels
8165

@@ -87,7 +71,7 @@ def classify(text):
8771
text (str): Vietnamese text to classify
8872
8973
Returns:
90-
list: List containing the predicted category (for compatibility with underthesea API)
74+
str: Predicted category
9175
"""
9276
global classifier
9377

@@ -96,9 +80,9 @@ def classify(text):
9680
classifier = joblib.load(model_file)
9781
classifier.labels = _load_labels(labels_file)
9882

99-
# Make prediction
83+
# Make prediction and convert to plain string
10084
prediction = classifier.predict([text])[0]
101-
return [prediction]
85+
return str(prediction)
10286

10387

10488
def classify_with_confidence(text):
@@ -121,13 +105,16 @@ def classify_with_confidence(text):
121105
prediction = classifier.predict([text])[0]
122106
probabilities = classifier.predict_proba([text])[0]
123107

124-
# Get top 3 predictions with probabilities
108+
# Get top 3 predictions with probabilities, convert to plain strings
125109
classes = classifier.classes_
126110
prob_dict = dict(zip(classes, probabilities))
127111
top_predictions = sorted(prob_dict.items(), key=lambda x: x[1], reverse=True)[:3]
128112

113+
# Convert numpy strings to plain strings
114+
top_predictions = [(str(label), float(prob)) for label, prob in top_predictions]
115+
129116
return {
130-
'prediction': prediction,
131-
'confidence': top_predictions[0][1],
117+
'prediction': str(prediction),
118+
'confidence': float(top_predictions[0][1]),
132119
'top_3': top_predictions
133120
}

0 commit comments

Comments (0)