diff --git a/tests/pipeline/classification/test_bank.py b/tests/pipeline/classification/test_bank.py index cb398093..db77bcdf 100644 --- a/tests/pipeline/classification/test_bank.py +++ b/tests/pipeline/classification/test_bank.py @@ -9,7 +9,7 @@ def test_classify_simple_case(self): actual = classify(text, domain="bank") # Convert numpy strings to regular strings for comparison actual = [str(label) for label in actual] - expected = ["CUSTOMER_SUPPORT"] + expected = ["CARD"] self.assertEqual(expected, actual) def test_classify_simple_case_2(self): @@ -18,7 +18,7 @@ def test_classify_simple_case_2(self): # Convert numpy strings to regular strings for comparison actual = [str(label) for label in actual] # Updated expectation based on new model output - expected = ["CUSTOMER_SUPPORT"] + expected = ["INTEREST_RATE"] self.assertEqual(expected, actual) def test_classify_simple_case_3(self): @@ -27,5 +27,5 @@ def test_classify_simple_case_3(self): # Convert numpy strings to regular strings for comparison actual = [str(label) for label in actual] # Updated expectation based on new model output - expected = ["TRADEMARK"] + expected = ["DISCOUNT"] self.assertEqual(expected, actual) diff --git a/underthesea/pipeline/classification/bank/__init__.py b/underthesea/pipeline/classification/bank/__init__.py index 0cc58998..8ba1bee4 100644 --- a/underthesea/pipeline/classification/bank/__init__.py +++ b/underthesea/pipeline/classification/bank/__init__.py @@ -14,38 +14,53 @@ classifier = None -def classify(X): +def _load_classifier(): global classifier - if not classifier: # Download and load UTS2017_Bank model from Hugging Face model_path = hf_hub_download( repo_id="undertheseanlp/sonar_core_1", - filename="uts2017_bank_classifier_20250927_161733.joblib", + filename="uts2017_bank_classifier_20250928_060819.joblib", ) classifier = joblib.load(model_path) + return classifier + + +def classify(X): + classifier = _load_classifier() - # Make prediction and convert to plain string - prediction = classifier.predict([X])[0] + # Use predict_text function for prediction + prediction, _, _ = predict_text(classifier, X) # Return as list to maintain compatibility with existing API return [str(prediction)] def classify_with_confidence(X): - global classifier + classifier = _load_classifier() - if not classifier: - # Download and load UTS2017_Bank model from Hugging Face - model_path = hf_hub_download( - repo_id="undertheseanlp/sonar_core_1", - filename="uts2017_bank_classifier_20250927_161733.joblib", - ) - classifier = joblib.load(model_path) + # Use predict_text function for prediction + prediction, confidence, _ = predict_text(classifier, X) - # Make prediction with probabilities and convert to plain string - prediction = classifier.predict([X])[0] + # Get full probabilities for backward compatibility probabilities = classifier.predict_proba([X])[0] - confidence = float(max(probabilities)) return {"category": str(prediction), "confidence": confidence, "probabilities": probabilities} + + +def predict_text(model, text): + probabilities = model.predict_proba([text])[0] + + # Get top 3 predictions sorted by probability + top_indices = probabilities.argsort()[-3:][::-1] + top_predictions = [] + for idx in top_indices: + category = model.classes_[idx] + prob = probabilities[idx] + top_predictions.append((category, prob)) + + # The prediction should be the top category + prediction = top_predictions[0][0] + confidence = top_predictions[0][1] + + return prediction, confidence, top_predictions