
LogisticRegress-Hat

This project aims to recreate Hogwarts' Sorting Hat using Machine Learning. We analyze a dataset containing various student features (including name, birthdate, dominant hand, and academic scores) to implement and train a logistic regression algorithm. The trained model will be able to sort new students into one of the four Hogwarts houses (Gryffindor, Hufflepuff, Ravenclaw, and Slytherin), just like the magical Sorting Hat from the Harry Potter series.

Summary:

  • Part 1: Exploratory analysis of a dataset (descriptive statistics)
  • Part 2: Visualizations (histograms, scatter plots, pair plots)
  • Part 3: Multiclass logistic regression (one-vs-all) to classify Hogwarts houses

Repository structure

./
├── 01_data-analysis/
│   └── describe.py                # Descriptive stats printed as a table
├── 02_data-visualization/
│   ├── 02_01_histogram/
│   │   └── histogram.py           # Histograms per feature/house
│   ├── 02_02_scatter-plot/
│   │   ├── scatter_plot_v1.py     # Scatter plot (v1)
│   │   ├── scatter_plot_v2.py     # Scatter plot (v2) with Pearson's correlation
│   │   └── *_featext.py           # Scatter plots (v1/v2) with feature extractions
│   └── 02_03_pair-plot/
│       ├── pair_plot.py           # Pair plot of numerical features
│       └── *_featext.py           # Pair plot of all features (with feature extractions)
├── 03_logistic-regression/
│   ├── train.py                   # Training (one-vs-all), saves weights
│   └── predict.py                 # Predictions from weights.npy → houses.csv
├── datasets/
│   ├── dataset_train.csv
│   └── dataset_test.csv
├── dataset_truth.csv              # Ground-truth labels for evaluation (test set)
├── pyproject.toml                 # Dependencies and Ruff config
├── evaluate.py                    # Evaluates houses.csv against dataset_truth.csv
└── uv.lock                        # Version lock (uv)

Requirements

  • Python >= 3.12
  • Package manager: uv (recommended) or pip

Installation

Option A — uv (recommended)

# 1) Install uv if needed (macOS/Linux)
curl -LsSf https://astral.sh/uv/install.sh | sh

# 2) Create a virtual env and install dependencies
uv venv
uv sync

Option B — pip

python -m venv .venv
source .venv/bin/activate  # macOS/Linux
pip install -r <(python - <<'PY'
import tomllib
with open('pyproject.toml', 'rb') as f:
    data = tomllib.load(f)
print('\n'.join(data['project']['dependencies']))
PY
)

Usage

All commands below assume you run them from the repository root.


Part 1 - Descriptive statistics

The describe.py script provides descriptive statistics for each numerical feature in the dataset. It calculates and displays a summary table containing:

  • Count of non-null values
  • Mean (average)
  • Standard deviation
  • Minimum value
  • 25th, 50th (median), and 75th percentiles
  • Maximum value
  • Range (max - min)
  • Variance
  • Interquartile range (IQR)

To better understand the underlying statistics, all metrics are implemented from scratch rather than using library functions. The implementation can be found in 01_data-analysis/describe.py.
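
For illustration, a minimal from-scratch sketch of these metrics (the helper names here are hypothetical, not the actual describe.py code):

import math

def percentile(sorted_values: list[float], p: float) -> float:
    """Percentile with linear interpolation, no library calls."""
    k = (len(sorted_values) - 1) * p / 100
    lo, hi = math.floor(k), math.ceil(k)
    if lo == hi:
        return sorted_values[lo]
    return sorted_values[lo] * (hi - k) + sorted_values[hi] * (k - lo)

def describe(values: list[float]) -> dict[str, float]:
    """Summary statistics for one numerical feature."""
    data = sorted(v for v in values if v == v)  # v == v drops NaN
    n = len(data)
    mean = sum(data) / n
    var = sum((v - mean) ** 2 for v in data) / (n - 1)  # sample variance
    q1, q2, q3 = (percentile(data, p) for p in (25, 50, 75))
    return {"count": n, "mean": mean, "std": math.sqrt(var),
            "min": data[0], "25%": q1, "50%": q2, "75%": q3,
            "max": data[-1], "range": data[-1] - data[0],
            "var": var, "iqr": q3 - q1}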

uv run python ./01_data-analysis/describe.py --dataset ./datasets/dataset_train.csv
# or with venv activated
python ./01_data-analysis/describe.py --dataset ./datasets/dataset_train.csv

Part 2 - Data visualization

This section provides scripts to explore feature distributions and relationships, colored by Hogwarts house.

  • 02_01_histogram/histogram.py

uv run python ./02_data-visualization/02_01_histogram/histogram.py --dataset ./datasets/dataset_train.csv

  • Builds histograms for each numerical feature, split by house (a minimal sketch follows below).

  • Helps compare distributions and spot skewness/outliers across classes.

  • Analysis: The features 'Arithmancy' and 'Care of Magical Creatures' show homogeneous distributions across houses, suggesting they may not be informative for classification purposes.
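
A minimal sketch of this kind of plot, assuming pandas, matplotlib, and a 'Hogwarts House' label column (not the repo's exact code):

import matplotlib.pyplot as plt
import pandas as pd

df = pd.read_csv("./datasets/dataset_train.csv")
feature = "Arithmancy"  # example; any numerical column works here

# One semi-transparent histogram per house, overlaid on the same axes
for house, group in df.groupby("Hogwarts House"):
    plt.hist(group[feature].dropna(), bins=30, alpha=0.5, label=house)
plt.xlabel(feature)
plt.ylabel("Number of students")
plt.legend()
plt.show()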

  • 02_02_scatter-plot/

uv run python ./02_data-visualization/02_02_scatter-plot/scatter_plot_v2.py --dataset ./datasets/dataset_train.csv

  • scatter_plot_v1.py: Basic scatter plots showing the relationship between two selected features, colored by house.

  • scatter_plot_v2.py: Same as v1, with Pearson's correlation annotated to quantify linear relationships (a from-scratch sketch of this metric follows the list below).

  • *_featext.py: Versions of v1/v2 that add extracted features: 'First Name' and 'Last Name' (tokenized), 'Birthday' (parsed as a datetime), and 'Best Hand' (binary-encoded).

  • Analysis: The features 'Astronomy' and 'Defense Against The Dark Arts' show a strong negative correlation, suggesting that students who excel in one subject tend to perform poorly in the other.
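
As referenced above, Pearson's correlation can itself be computed from scratch; a sketch (hypothetical helper, not the repo's implementation):

import math

def pearson(x: list[float], y: list[float]) -> float:
    """Pearson correlation coefficient of two equal-length samples."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

A value near -1, as observed for 'Astronomy' vs 'Defense Against The Dark Arts', indicates a strong negative linear relationship.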

  • 02_03_pair-plot/

uv run python ./02_data-visualization/02_03_pair-plot/pair_plot_featext.py --dataset ./datasets/dataset_train.csv

  • pair_plot.py: Pair plot (grid of scatter plots + diagonals) for numerical features, colored by house.
  • *_featext.py: Pair plot using the extended feature set (with feature extractions).
  • Analysis: The pair plot confirms that 'Arithmancy' and 'Care of Magical Creatures' are not informative for classification. The scatter plots involving these features show horizontal banding: their values barely differ between houses, so they contribute no meaningful information for distinguishing them. This validates their exclusion from the classification model.
  • Feature Extraction Analysis: The feature-extraction version of the script shows that the non-numerical features (First Name, Last Name, Birthday, Best Hand) have no significant patterns or correlations with house assignments. Their scatter and pair plots reveal random distributions with no clear clustering, indicating minimal predictive value, so we exclude them as well.
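
For reference, a pair plot of this kind can be sketched in a few lines with seaborn (an assumption; the repo's scripts may build the grid with matplotlib directly):

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.read_csv("./datasets/dataset_train.csv")
# Pairwise scatter plots of numerical features, colored per house,
# with each feature's distribution on the diagonal
sns.pairplot(df.select_dtypes("number").join(df["Hogwarts House"]),
             hue="Hogwarts House", corner=True, plot_kws={"s": 10})
plt.show()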

Part 3.1 - Train logistic regression

Trains 4 one-vs-all models (Gryffindor, Slytherin, Ravenclaw, Hufflepuff), reports accuracy, and saves weights to weights.npy.

uv run python ./03_logistic-regression/train.py --dataset ./datasets/dataset_train.csv
# or
python ./03_logistic-regression/train.py --dataset ./datasets/dataset_train.csv

Key outputs:

  • weights.npy: learned weight matrix (one model per house)
  • Logs: periodic loss, early stopping, final metrics
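
A minimal sketch of the one-vs-all scheme (batch gradient descent on binary cross-entropy; hyperparameters and function names are illustrative, not the repo's actual settings):

import numpy as np

def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))

def train_one_vs_all(X: np.ndarray, y: np.ndarray, houses: list[str],
                     lr: float = 0.1, epochs: int = 1000) -> np.ndarray:
    """Fit one binary logistic regression per house.

    X: (n_samples, n_features) standardized features; y: house label per row.
    Returns a (n_houses, n_features + 1) weight matrix (bias included).
    """
    Xb = np.hstack([np.ones((X.shape[0], 1)), X])  # prepend bias column
    W = np.zeros((len(houses), Xb.shape[1]))
    for i, house in enumerate(houses):
        t = (y == house).astype(float)  # 1 for this house, 0 for the rest
        for _ in range(epochs):
            p = sigmoid(Xb @ W[i])
            W[i] -= lr * Xb.T @ (p - t) / len(t)  # gradient of the BCE loss
    return W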

Part 3.2 - Predict houses on the test set

Loads weights.npy, applies the same preprocessing, and writes houses.csv.

uv run python ./03_logistic-regression/predict.py --dataset ./datasets/dataset_test.csv
# or
python ./03_logistic-regression/predict.py --dataset ./datasets/dataset_test.csv

Key outputs:

  • houses.csv: CSV file with a Hogwarts House column and an index named Index
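
Prediction then reduces to scoring each student against all four models and keeping the house with the highest probability; a sketch reusing the weight layout assumed above:

import numpy as np
import pandas as pd

HOUSES = ["Gryffindor", "Slytherin", "Ravenclaw", "Hufflepuff"]

def predict(X: np.ndarray, W: np.ndarray) -> list[str]:
    """Pick, for each row of X, the house whose model scores highest."""
    Xb = np.hstack([np.ones((X.shape[0], 1)), X])  # same bias column as training
    probs = 1.0 / (1.0 + np.exp(-(Xb @ W.T)))      # sigmoid score per house
    return [HOUSES[i] for i in probs.argmax(axis=1)]

# Typical use: W = np.load("weights.npy"); X = preprocessed test features
# pd.DataFrame({"Hogwarts House": predict(X, W)}).to_csv("houses.csv", index_label="Index")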

Part 3.3 - Evaluate predictions

Evaluates the houses.csv predictions against the ground-truth labels in dataset_truth.csv.

Requirements:

  • Ensure houses.csv and dataset_truth.csv are in the same directory as evaluate.py (project root).

Run:

uv run python ./evaluate.py
# or
python ./evaluate.py

Output:

  • Prints the accuracy score on the test set (e.g., Your score on test set: 0.XXX).
  • Friendly feedback message depending on the achieved score.
%> python ./evaluate.py
Your score on test set: 0.990
Good job! Mc Gonagall congratulates you.
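
The check itself boils down to a column comparison; a sketch of the idea (evaluate.py's exact wording and column handling may differ):

import pandas as pd

truth = pd.read_csv("dataset_truth.csv")
pred = pd.read_csv("houses.csv")

# Fraction of students whose predicted house matches the ground truth
accuracy = (pred["Hogwarts House"] == truth["Hogwarts House"]).mean()
print(f"Your score on test set: {accuracy:.3f}")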

Preprocessing notes

  • Irrelevant columns removed: the first 6 identity columns, Arithmancy, Care of Magical Creatures
  • Normalization: per-feature z-score
  • Training: rows with NaN are dropped
  • Prediction: NaNs are imputed using the feature median
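
A sketch of these steps with pandas (the function name is illustrative; at prediction time the actual scripts would reuse the training-set mean/std rather than recompute them):

import pandas as pd

UNINFORMATIVE = ["Arithmancy", "Care of Magical Creatures"]

def preprocess(df: pd.DataFrame, training: bool) -> pd.DataFrame:
    """Drop identity/uninformative columns, handle NaNs, z-score the rest."""
    X = df.iloc[:, 6:].drop(columns=UNINFORMATIVE, errors="ignore")
    if training:
        X = X.dropna()            # training: drop rows with missing values
    else:
        X = X.fillna(X.median())  # prediction: impute with feature medians
    return (X - X.mean()) / X.std()  # per-feature z-score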

Code quality

This project uses Ruff for linting/formatting (see pyproject.toml).

uv run ruff check .
uv run ruff format .

Troubleshooting

  • ModuleNotFoundError: No module named 'pandas' → make sure you installed dependencies (uv sync or pip install) and activated the virtual environment.
  • Python version issues → use Python 3.12 as specified in pyproject.toml.
  • Missing output files (weights.npy, houses.csv) → check your working directory and write permissions.

Data

Place training/test CSV files in datasets/. Default script names are dataset_train.csv and dataset_test.csv (override with --dataset). Place the ground-truth file dataset_truth.csv at the repository root to use evaluate.py.
