This project implements an advanced machine learning pipeline for automated ECG heartbeat classification, capable of detecting 5 types of cardiac conditions with clinical-grade accuracy. The system processes raw ECG signals and classifies heartbeats in <50 ms.
🚀 Try the demo here
Metric | CONVXGB | XGB+feat | CNN+MLP |
---|---|---|---|
Overall Accuracy | 98.51% | 98.48% | 98.48% |
Precision | 98.49% | 98.45% | 98.56% |
Recall | 98.51% | 98.48% | 98.48% |
F1-Score (weighted) | 98.48% | 98.46% | 98.51% |
F1-Score (macro) | 92.87% | 92.13% | 91.51% |
- Data preprocessing and feature extraction from raw ECG signals.
- Tuning and training of various ML models using TensorFlow.
- Model evaluation using appropriate metrics for multiclass classification.
- Experiment tracking using MLflow.
- Notebook for interactive experiments and visualization here
- DVC with a DagsHub S3 bucket for data versioning and model tracking.
- Docker + FastAPI to serve an easy-to-use interactive API.
- Continuous Integration (CI) using GitHub Actions.
There are three ways to run HeartWaveML:
If you only need the API, simply pull the Docker image (<600MB):
docker pull josegm61/heartwaveml:latest
docker run -p 8000:8000 josegm61/heartwaveml:latest
The API will be available at http://localhost:8000
You can open web/index.html
in your browser to interact with it, or browse every endpoint in the Swagger UI
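As a sketch of how to call the running API from Python: the endpoint path (`/predict`) and payload shape (a 187-sample signal) below are assumptions for illustration — check the Swagger UI at http://localhost:8000/docs for the actual schema.

```python
# Hedged example of querying the served model; endpoint name and
# payload format are assumed, not taken from the repository.
import json
import urllib.request
from urllib.error import URLError

signal = [0.0] * 187  # placeholder heartbeat segment (187 samples)
payload = json.dumps({"signal": signal}).encode()
req = urllib.request.Request(
    "http://localhost:8000/predict",  # assumed endpoint
    data=payload,
    headers={"Content-Type": "application/json"},
)
try:
    with urllib.request.urlopen(req, timeout=5) as resp:
        print(json.load(resp))
except (URLError, OSError):
    print("API not reachable - start the container first")
```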
If you want to use the trained models and datasets:
dvc pull
pip install -r requirements.txt
This will fetch the models and datasets tracked with DVC and install the dependencies (you will likely need a DagsHub account).
If you prefer to generate the dataset and train the models yourself:
pip install -r requirements.txt
python -m src.data.generate_data
python -m src.tuning.tune_convxgb
You can tune or train a different model by replacing src.tuning.tune_convxgb with the corresponding module,
and serve the API with
python -m src.api
This is the recommended option if you want to use this repo as a template to train your own models and try other combinations.
HeartWaveML/
├── .dvc/                        # DVC control files
├── .github/workflows/main.yml   # CI pipeline with GitHub Actions
├── assets/                      # Photos and videos
├── data/                        # Datasets (tracked in DVC)
├── src/                         # Source code
│   ├── data/
│   │   ├── download_dataset.py  # Script to download dataset
│   │   └── generate_data.py     # Script to generate data
│   ├── saved_models/            # Trained models (tracked in DVC)
│   ├── training/                # Training logic
│   ├── tuning/                  # Hyperparameter tuning
│   ├── api.py                   # API to serve the model
│   ├── evaluate.py              # Model evaluation
│   ├── predict.py               # Run predictions on new data
│   ├── preprocessing.py         # Data preprocessing functions
│   └── utils.py                 # Helper functions
├── web/
│   └── index.html               # Web interface
├── .dockerignore                # Ignore files in Docker builds
├── .gitignore                   # Ignore files in git
├── Dockerfile                   # Docker image definition
├── dvc.lock                     # Exact DVC state for data/pipelines
├── dvc.yaml                     # DVC pipeline definitions
├── LICENSE                      # Project license
├── README.md                    # Main documentation
├── requirements_api.txt         # API dependencies
└── requirements.txt             # Core dependencies
This model provides a scalable solution for cardiac monitoring, combining clinical-grade reliability with sub-50 ms inference.
- High-Accuracy Screening: 98.5% accuracy ensures reliable detection of 5 types of cardiac conditions, a rate comparable to human experts.
- Real-Time Analysis: With an average inference time under 50 ms per heartbeat, the system enables real-time continuous monitoring and rapid processing of massive datasets.
- Augments Professional Expertise: By automating the initial screening, the system frees healthcare professionals to focus their expertise on complex cases and direct patient care.
We use the MIT-BIH Arrhythmia Database, a widely used benchmark dataset for ECG signal classification. The dataset contains 48 half-hour recordings of two-lead ambulatory ECG signals sampled at 360 Hz. Each recording is annotated with beat labels, indicating the type of each heartbeat according to standard conventions.
Each ECG segment is resampled or cropped to 187 samples, then filtered and scaled. Filtering and scaling are essential for good model performance:
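The resample-filter-scale step can be sketched as follows. This is a minimal illustration, not the repository's exact code: the band-pass cutoffs (0.5–40 Hz), filter order, and min-max scaling are assumed parameters chosen as common defaults for ECG preprocessing.

```python
# Sketch of the preprocessing described above; filter band, order,
# and scaler are illustrative assumptions.
import numpy as np
from scipy.signal import butter, filtfilt, resample

TARGET_LEN = 187
FS = 360  # MIT-BIH sampling rate in Hz

def preprocess_segment(segment, fs=FS, target_len=TARGET_LEN):
    """Resample one ECG segment to a fixed length, band-pass filter it,
    and min-max scale it to [0, 1]."""
    x = resample(np.asarray(segment, dtype=float), target_len)
    # 0.5-40 Hz band-pass removes baseline wander and high-frequency noise
    b, a = butter(3, [0.5, 40.0], btype="band", fs=fs)
    x = filtfilt(b, a, x)
    # Min-max scale to [0, 1]
    return (x - x.min()) / (x.max() - x.min() + 1e-8)

beat = preprocess_segment(np.random.randn(250))
print(beat.shape)  # (187,)
```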
The MIT-BIH annotations include many heartbeat types:
So we map them into 5 classes:
class_mapping = {
'N': 0, '·': 0, 'L': 0, 'R': 0, 'e': 0, 'j': 0, # Normal beat
'A': 1, 'a': 1, 'J': 1, 'S': 1, # Supraventricular ectopic beat
'V': 2, 'E': 2, # Ventricular ectopic beat
'F': 3, # Fusion beat
'/': 4, 'f': 4, 'x': 4, 'Q': 4, '|': 4, '~': 4 # Unknown beat
}
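As a quick illustration of how the mapping is applied, the snippet below converts a list of annotation symbols into class labels (the mapping dict is repeated so the example runs standalone; the example symbols are hypothetical):

```python
# The 5-class mapping from above, repeated for a self-contained example.
class_mapping = {
    'N': 0, '·': 0, 'L': 0, 'R': 0, 'e': 0, 'j': 0,  # Normal beat
    'A': 1, 'a': 1, 'J': 1, 'S': 1,                  # Supraventricular ectopic beat
    'V': 2, 'E': 2,                                  # Ventricular ectopic beat
    'F': 3,                                          # Fusion beat
    '/': 4, 'f': 4, 'x': 4, 'Q': 4, '|': 4, '~': 4,  # Unknown beat
}

symbols = ['N', 'V', 'A', 'F', '/', 'L']  # example annotation symbols
# Symbols not present in the mapping are simply skipped
labels = [class_mapping[s] for s in symbols if s in class_mapping]
print(labels)  # [0, 2, 1, 3, 4, 0]
```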
And we end up with this distribution:
Class | Count |
---|---|
0 | 90608 |
1 | 2781 |
2 | 7235 |
3 | 802 |
4 | 8981 |
We also applied SMOTE to mitigate the extreme class imbalance, oversampling classes 1 and 3 to 5,000 samples each.
from imblearn.over_sampling import SMOTE

sampling_strategy_dict = {
    3: 5000, 1: 5000
}
smote = SMOTE(sampling_strategy=sampling_strategy_dict, random_state=42, k_neighbors=5)
# Oversample only the training split
X_train, y_train = smote.fit_resample(X_train, y_train)
We then split the data into train, validation, and test sets. For experimentation, we created several dataset variants:
Dataset | Description |
---|---|
base | Scaled and filtered signal |
cnn | Features extracted by CNN |
feat | Signal + Engineered features |
feat_only | Engineered features |
Moody GB, Mark RG. The impact of the MIT-BIH Arrhythmia Database. IEEE Eng in Med and Biol 20(3):45-50 (May-June 2001). (PMID: 11446209)
Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C., Mark, R., ... & Stanley, H. E. (2000). PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex physiologic signals. Circulation [Online]. 101 (23), pp. e215–e220. RRID:SCR_007345.
Contributions are welcome! Feel free to fork the repository and submit a pull request with your improvements. For any questions, suggestions, or feedback, please don't hesitate to contact me at [email protected]. Your advice and collaboration are greatly appreciated!