
End-to-end MLOps pipeline for automated ECG heartbeat classification with TensorFlow. Includes data versioning, CI, and a Docker+FastAPI deployment.


JoseGarciaMayen/HeartWaveML


🫀 HeartWaveML - Automatic ECG Heartbeat Classification


HeartWaveML

Project Overview

This project implements a machine learning pipeline for automated ECG heartbeat classification. It processes raw ECG signals and assigns each heartbeat to one of 5 classes with about 98.5% overall accuracy, classifying a heartbeat in under 50 ms.

Live Demo

👉 Try the demo here

API GIF

Model results

| Metric | CONVXGB | XGB+feat | CNN+MLP |
|---|---|---|---|
| Overall Accuracy | 98.51% | 98.48% | 98.48% |
| Precision | 98.49% | 98.45% | 98.56% |
| Recall | 98.51% | 98.48% | 98.48% |
| F1-Score (weighted) | 98.48% | 98.46% | 98.51% |
| F1-Score (macro) | 92.87% | 92.13% | 91.51% |
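
The gap between the weighted and macro F1-scores in the table is typical of imbalanced data: the weighted score is dominated by the largest class, while the macro score gives every class equal weight. The sketch below computes both from a confusion matrix. The 3-class matrix is made up for illustration and is not taken from this project's results.

```python
def f1_scores(confusion):
    """Return (per_class_f1, macro_f1, weighted_f1) for a square confusion matrix.

    confusion[i][j] is the number of samples of true class i predicted as class j.
    """
    n = len(confusion)
    per_class, supports = [], []
    for k in range(n):
        tp = confusion[k][k]
        fn = sum(confusion[k]) - tp                        # true k, predicted as another class
        fp = sum(confusion[i][k] for i in range(n)) - tp   # predicted k, actually another class
        denom = 2 * tp + fp + fn
        per_class.append(2 * tp / denom if denom else 0.0)
        supports.append(sum(confusion[k]))
    macro = sum(per_class) / n                             # every class counts equally
    weighted = sum(f * s for f, s in zip(per_class, supports)) / sum(supports)
    return per_class, macro, weighted


# Made-up example: class 0 dominates, class 2 is rare and harder to classify.
toy_confusion = [
    [950, 5, 5],
    [10, 85, 5],
    [10, 5, 25],
]
per_class, macro, weighted = f1_scores(toy_confusion)
print("per-class F1:", [round(f, 3) for f in per_class])
print("macro F1:", round(macro, 3), "weighted F1:", round(weighted, 3))
```

On data like this, a weak rare class pulls the macro score down while barely moving the weighted score.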

Features

  • Data preprocessing and feature extraction from raw ECG signals.

  • Tuning and training of various ML models using TensorFlow.

  • Model evaluation using appropriate metrics for multiclass classification.

  • Experiment tracking using MLflow.

  • Notebook for interactive experiments and visualization here.

  • DVC with a DagsHub S3 bucket for versioning data and tracking models.

  • Docker + FastAPI to serve an easy-to-use interactive API.

  • Continuous Integration (CI) using GitHub Actions.

Quick Start

There are three ways to run HeartWaveML:

1️⃣ Run only the API (via Docker)

If you only need the API, pull the Docker image (under 600 MB) and run it:

docker pull josegm61/heartwaveml:latest
docker run -p 8000:8000 josegm61/heartwaveml:latest

The API will be running at http://localhost:8000. You can open web/index.html in your browser to interact with it, or browse every endpoint in the Swagger UI.
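
If you prefer to inspect the endpoints from a script, FastAPI publishes an OpenAPI schema at /openapi.json by default. The sketch below assumes this project keeps that default route; `list_endpoints` and `fetch_spec` are hypothetical helper names, not part of the repository.

```python
import json
from urllib.request import urlopen


def list_endpoints(spec):
    """Return sorted (METHOD, path) pairs from an OpenAPI schema dict."""
    http_methods = {"get", "post", "put", "patch", "delete", "head", "options"}
    pairs = []
    for path, operations in spec.get("paths", {}).items():
        for method in operations:
            # Path items can also hold non-method keys such as "parameters".
            if method.lower() in http_methods:
                pairs.append((method.upper(), path))
    return sorted(pairs)


def fetch_spec(base_url="http://localhost:8000"):
    """Download the schema. Assumes FastAPI's default route, /openapi.json."""
    with urlopen(f"{base_url}/openapi.json", timeout=5) as resp:
        return json.load(resp)
```

With the container running, `list_endpoints(fetch_spec())` should return the routes the API exposes.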

2️⃣ Use pretrained models and datasets (via DVC)

If you want to use the trained models and datasets:

dvc pull
pip install -r requirements.txt

This fetches the models and datasets tracked with DVC and installs the dependencies. You will probably need a DagsHub account.

3️⃣ Train models from scratch

If you prefer to generate the dataset and train the models yourself:

pip install -r requirements.txt
python -m src.data.generate_data
python -m src.tuning.tune_convxgb

To tune or train a different model, replace src.tuning.tune_convxgb with the module you want to run, then start the API with:

python -m src.api

This is the recommended option if you want to use this repo as a template to train your own models and try other combinations.

Model Design

Model Architecture

Project Structure

HeartWaveML/
├── .dvc/                         # DVC control files
├── .github/workflows/main.yml    # CI pipeline with GitHub Actions
├── assets/                       # Photos and videos
├── data/                         # Datasets (tracked in DVC)
├── src/                          # Source code
│   ├── data/
│   │   ├── download_dataset.py   # Script to download dataset
│   │   └── generate_data.py      # Script to generate data
│   ├── saved_models/             # Trained models (tracked in DVC)
│   ├── training/                 # Training logic
│   ├── tuning/                   # Hyperparameter tuning
│   ├── api.py                    # API to serve the model
│   ├── evaluate.py               # Model evaluation
│   ├── predict.py                # Run predictions on new data
│   ├── preprocessing.py          # Data preprocessing functions
│   └── utils.py                  # Helper functions
├── web/
│   └── index.html                # Web interface
├── .dockerignore                 # Ignore files in Docker builds
├── .gitignore                    # Ignore files in git
├── Dockerfile                    # Docker image definition
├── dvc.lock                      # Exact DVC state for data/pipelines
├── dvc.yaml                      # DVC pipeline definitions
├── LICENSE                       # Project license
├── README.md                     # Main documentation
├── requirements_api.txt          # API dependencies
└── requirements.txt              # Core dependencies


Clinical Impact

This model offers a scalable approach to cardiac monitoring, combining high classification accuracy with low inference latency.

  • High-Accuracy Screening: 98.5% overall accuracy across the 5 heartbeat classes, with a macro F1-score of 92.87% for the best model (CONVXGB).

  • Real-Time Analysis: With an average inference time of under 50 ms per heartbeat, the system enables real-time continuous monitoring and rapid processing of large datasets.

  • Augments Professional Expertise: By automating the initial screening process, the system frees up healthcare professionals to focus their expertise on complex cases and direct patient care.

Dataset

We use the MIT-BIH Arrhythmia Database, a widely used benchmark dataset for ECG signal classification. The dataset contains 48 half-hour recordings of two-lead ambulatory ECG signals sampled at 360 Hz. Each recording is annotated with beat labels, indicating the type of each heartbeat according to standard conventions.

Each ECG segment is resampled or cropped to 187 samples, then scaled and filtered. Filtering and scaling are essential for good model performance:

Signal
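
As a rough illustration of these steps, the sketch below crops or zero-pads a segment to 187 samples, removes baseline drift with a moving-average filter, and min-max scales the result. The filter type, window size, padding strategy, and function names are all assumptions made for this example; the repository's src/preprocessing.py may work differently, and resampling is not shown.

```python
import numpy as np

TARGET_LEN = 187  # segment length stated in this README


def fix_length(beat, target_len=TARGET_LEN):
    """Crop a long segment or zero-pad a short one to target_len samples."""
    beat = np.asarray(beat, dtype=float)
    if len(beat) >= target_len:
        return beat[:target_len]
    return np.pad(beat, (0, target_len - len(beat)))


def remove_baseline(beat, window=31):
    """Subtract a moving average (a crude high-pass filter; window must be odd)."""
    kernel = np.ones(window) / window
    padded = np.pad(beat, window // 2, mode="edge")       # keep the output length unchanged
    baseline = np.convolve(padded, kernel, mode="valid")
    return beat - baseline


def min_max_scale(beat):
    """Scale to [0, 1]; a constant segment maps to all zeros."""
    span = beat.max() - beat.min()
    return (beat - beat.min()) / span if span > 0 else np.zeros_like(beat)


def preprocess(beat):
    return min_max_scale(remove_baseline(fix_length(beat)))


# Synthetic stand-in for a heartbeat: an oscillation plus slow linear drift.
t = np.linspace(0, 1, 250)
raw = np.sin(2 * np.pi * 5 * t) + 2 * t
clean = preprocess(raw)
print(clean.shape, float(clean.min()), float(clean.max()))
```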

The database annotates many heartbeat types:

Type Distribution

So we map them into 5 classes:

class_mapping = {
    'N': 0, '·': 0, 'L': 0, 'R': 0, 'e': 0, 'j': 0,            # Normal beat
    'A': 1, 'a': 1, 'J': 1, 'S': 1,                           # Supraventricular ectopic beat
    'V': 2, 'E': 2,                                           # Ventricular ectopic beat
    'F': 3,                                                   # Fusion beat
    '/': 4, 'f': 4, 'x': 4, 'Q': 4, '|': 4, '~': 4            # Unknown beat
}
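
A mapping like this is typically applied to the annotation symbols of each recording. The sketch below shows one way to do it. The symbol list is made up, and dropping symbols that are absent from the mapping is an assumption, not necessarily what the repository does.

```python
from collections import Counter

class_mapping = {
    'N': 0, '·': 0, 'L': 0, 'R': 0, 'e': 0, 'j': 0,   # Normal beat
    'A': 1, 'a': 1, 'J': 1, 'S': 1,                   # Supraventricular ectopic beat
    'V': 2, 'E': 2,                                   # Ventricular ectopic beat
    'F': 3,                                           # Fusion beat
    '/': 4, 'f': 4, 'x': 4, 'Q': 4, '|': 4, '~': 4,   # Unknown beat
}


def map_symbols(symbols, mapping=class_mapping):
    """Map annotation symbols to class ids, skipping symbols not in the mapping."""
    return [mapping[s] for s in symbols if s in mapping]


# Made-up annotation sequence; '+' is not in the mapping, so it is skipped.
symbols = ['N', 'N', 'V', 'A', '+', '/', 'F', 'N']
labels = map_symbols(symbols)
print(labels)
print(Counter(labels))
```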

And we end up with this distribution:

| Class | Count |
|---|---|
| 0 | 90608 |
| 1 | 2781 |
| 2 | 7235 |
| 3 | 802 |
| 4 | 8981 |

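We also applied SMOTE to reduce the extreme class imbalance by oversampling classes 1 and 3 to 5,000 samples each.

from imblearn.over_sampling import SMOTE

# Target sample counts for the minority classes after oversampling
sampling_strategy_dict = {
    3: 5000, 1: 5000
}

smote = SMOTE(sampling_strategy=sampling_strategy_dict, random_state=42, k_neighbors=5)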
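
To see what this sampling strategy does to the class balance, the short calculation below applies the target counts to the distribution listed above. It assumes SMOTE is run on the full set of counts shown in the table; the helper name `shares` is made up for this example.

```python
# Class counts from the distribution listed above.
counts = {0: 90608, 1: 2781, 2: 7235, 3: 802, 4: 8981}
# Target counts from sampling_strategy_dict.
sampling_strategy = {1: 5000, 3: 5000}


def shares(class_counts):
    """Return each class's share of the dataset in percent."""
    total = sum(class_counts.values())
    return {k: round(100 * v / total, 2) for k, v in class_counts.items()}


# Classes named in the strategy take their target count; the rest are unchanged.
after = {**counts, **sampling_strategy}

print("before:", sum(counts.values()), shares(counts))
print("after: ", sum(after.values()), shares(after))
```

Class 0 still dominates after oversampling, so the imbalance is reduced rather than removed.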

We then split the data into training, validation, and test sets. For experimentation, we created several dataset variants:

| Dataset | Description |
|---|---|
| base | Scaled and filtered signal |
| cnn | Features extracted by CNN |
| feat | Signal + Engineered features |
| feat_only | Engineered features |
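
The README does not state the split ratios or method. As a minimal sketch, a stratified split keeps the class proportions similar across the three sets, which matters with classes this imbalanced. The 70/15/15 fractions, the seed, and the function name below are assumptions for illustration only.

```python
import random
from collections import defaultdict


def stratified_split(labels, fractions=(0.7, 0.15, 0.15), seed=42):
    """Return (train, val, test) index lists with similar class proportions."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)
    train, val, test = [], [], []
    for indices in by_class.values():
        rng.shuffle(indices)                      # shuffle within each class
        n_train = int(len(indices) * fractions[0])
        n_val = int(len(indices) * fractions[1])
        train += indices[:n_train]
        val += indices[n_train:n_train + n_val]
        test += indices[n_train + n_val:]         # remainder goes to the test set
    return train, val, test


# Toy labels: 80 samples of class 0 and 20 of class 1.
toy_labels = [0] * 80 + [1] * 20
train_idx, val_idx, test_idx = stratified_split(toy_labels)
print(len(train_idx), len(val_idx), len(test_idx))
```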

Citation

Moody GB, Mark RG. The impact of the MIT-BIH Arrhythmia Database. IEEE Eng in Med and Biol 20(3):45-50 (May-June 2001). (PMID: 11446209)
Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C., Mark, R., ... & Stanley, H. E. (2000). PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex physiologic signals. Circulation [Online]. 101 (23), pp. e215–e220. RRID:SCR_007345.

Contributing

Contributions are welcome! Feel free to fork the repository and submit a pull request with your improvements. For any questions, suggestions, or feedback, please don't hesitate to contact me at [email protected]. Your advice and collaboration are greatly appreciated!
