Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/pipeline/classification/test_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_classify_simple_case(self):
actual = classify(text, domain="bank")
# Convert numpy strings to regular strings for comparison
actual = [str(label) for label in actual]
expected = ["CUSTOMER_SUPPORT"]
expected = ["CARD"]
self.assertEqual(expected, actual)

def test_classify_simple_case_2(self):
Expand All @@ -18,7 +18,7 @@ def test_classify_simple_case_2(self):
# Convert numpy strings to regular strings for comparison
actual = [str(label) for label in actual]
# Updated expectation based on new model output
expected = ["CUSTOMER_SUPPORT"]
expected = ["INTEREST_RATE"]
self.assertEqual(expected, actual)

def test_classify_simple_case_3(self):
Expand All @@ -27,5 +27,5 @@ def test_classify_simple_case_3(self):
# Convert numpy strings to regular strings for comparison
actual = [str(label) for label in actual]
# Updated expectation based on new model output
expected = ["TRADEMARK"]
expected = ["DISCOUNT"]
self.assertEqual(expected, actual)
47 changes: 31 additions & 16 deletions underthesea/pipeline/classification/bank/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,53 @@
classifier = None


def classify(X):
def _load_classifier():
global classifier

if not classifier:
# Download and load UTS2017_Bank model from Hugging Face
model_path = hf_hub_download(
repo_id="undertheseanlp/sonar_core_1",
filename="uts2017_bank_classifier_20250927_161733.joblib",
filename="uts2017_bank_classifier_20250928_060819.joblib",
)
classifier = joblib.load(model_path)
return classifier


def classify(X):
classifier = _load_classifier()

# Make prediction and convert to plain string
prediction = classifier.predict([X])[0]
# Use predict_text function for prediction
prediction, _, _ = predict_text(classifier, X)

# Return as list to maintain compatibility with existing API
return [str(prediction)]


def classify_with_confidence(X):
global classifier
classifier = _load_classifier()

if not classifier:
# Download and load UTS2017_Bank model from Hugging Face
model_path = hf_hub_download(
repo_id="undertheseanlp/sonar_core_1",
filename="uts2017_bank_classifier_20250927_161733.joblib",
)
classifier = joblib.load(model_path)
# Use predict_text function for prediction
prediction, confidence, _ = predict_text(classifier, X)

# Make prediction with probabilities and convert to plain string
prediction = classifier.predict([X])[0]
# Get full probabilities for backward compatibility
probabilities = classifier.predict_proba([X])[0]
confidence = float(max(probabilities))

return {"category": str(prediction), "confidence": confidence, "probabilities": probabilities}


def predict_text(model, text):
probabilities = model.predict_proba([text])[0]

# Get top 3 predictions sorted by probability
top_indices = probabilities.argsort()[-3:][::-1]
top_predictions = []
for idx in top_indices:
category = model.classes_[idx]
prob = probabilities[idx]
top_predictions.append((category, prob))

# The prediction should be the top category
prediction = top_predictions[0][0]
confidence = top_predictions[0][1]

return prediction, confidence, top_predictions