
Classifying images of flowers into 17 categories using EfficientNet-B0 and PyTorch.


Behradsadeghi/flower-classification-efficientnet


🌸 Flower Classification using EfficientNet-B0

This project demonstrates a deep learning pipeline to classify images of flowers into 17 different categories using a pre-trained EfficientNet-B0 model. It utilizes PyTorch and TIMM (PyTorch Image Models) for training, validation, and testing, implementing mixed precision training to optimize performance. The dataset used is the 17 Flower Classes dataset.

🌟 Introduction

The goal of this project is to classify flowers into 17 different classes using deep learning. We leverage transfer learning with the EfficientNet-B0 model pre-trained on ImageNet and fine-tune it on our flower dataset. The model training and evaluation are conducted using mixed precision training to reduce memory usage and improve training speed.

✨ Features

  • Transfer Learning: Utilizes the EfficientNet-B0 model pre-trained on ImageNet.
  • Mixed Precision Training: Reduces memory consumption and accelerates training.
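  • Data Preprocessing: Images are resized and normalized with the ImageNet statistics so that inputs match what the pre-trained backbone expects.
  • Gradient Accumulation: Simulates a larger effective batch size by accumulating gradients over several small batches, without exceeding GPU memory.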
  • Training, Validation, and Test Phases: Clearly separated phases for training, validation, and testing to monitor performance.
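The mixed precision and gradient accumulation features listed above could be combined in one training epoch roughly as follows. This is a sketch, not the repository's actual loop: the function name `train_one_epoch`, the `accum_steps` value, and the cross-entropy loss are assumptions. It uses `torch.cuda.amp.GradScaler`; recent PyTorch releases prefer `torch.amp.GradScaler("cuda")` and may emit a deprecation warning for the older spelling.

```python
import torch
from torch import nn


def train_one_epoch(model, loader, optimizer, device, accum_steps=4):
    """Run one epoch with mixed precision and gradient accumulation (sketch)."""
    use_amp = device.type == "cuda"  # autocast and loss scaling only help on a GPU
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    criterion = nn.CrossEntropyLoss()  # assumed loss for 17-way classification
    model.train()
    optimizer.zero_grad(set_to_none=True)
    running_loss, seen = 0.0, 0

    for step, (images, labels) in enumerate(loader, start=1):
        images, labels = images.to(device), labels.to(device)

        # Forward pass in reduced precision where it is numerically safe.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = criterion(model(images), labels)

        # Divide by accum_steps so the summed gradients approximate the
        # average over one large batch.
        scaler.scale(loss / accum_steps).backward()

        # Update weights every accum_steps micro-batches, and once more at
        # the end to flush any leftover micro-batches.
        if step % accum_steps == 0 or step == len(loader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)

    return running_loss / seen  # mean training loss per sample
```

With a micro-batch of 8 and `accum_steps=4`, for example, the optimizer sees gradients equivalent to a batch of 32 while only 8 images occupy GPU memory at a time.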

⚙️ Installation

Prerequisites

Ensure you have the following installed:

  • Python 3.7+
  • PyTorch
  • TIMM (PyTorch Image Models)
  • Other dependencies: numpy, Pillow (imported as PIL), matplotlib, torchvision, tqdm. The gc module is part of the Python standard library and needs no installation.

📂 Dataset

The dataset contains images of 17 different flower classes. It provides separate train and test directories; the validation set is split off from the training data. You can download the dataset from Kaggle using the following link: 17 Flower Classes Dataset.

  • Train Directory: Contains the images used for training the model.
  • Validation: A portion (20%) of the training dataset is set aside for validation to monitor model performance and avoid overfitting.
  • Test Directory: Contains the images used for testing the model.

🛠️ Training

The model is trained using a PyTorch script that includes functions for:

  • Loading and Transforming Data: Resize to 299x299, normalization using ImageNet mean and std.
  • Training Loop: Implements gradient accumulation for better memory management.
  • Validation: A separate validation set, comprising 20% of the training data, is used to evaluate the model after each epoch to monitor validation accuracy and loss.
  • Saving Best Model: The model with the best validation accuracy is saved.
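The validation and best-model steps above can be sketched as follows. The function names, the checkpoint filename `best_model.pth`, and the cross-entropy loss are hypothetical choices for illustration, not the repository's actual code.

```python
import torch
from torch import nn


@torch.no_grad()
def evaluate(model, loader, device):
    """Return (mean loss, accuracy) over all samples in loader."""
    criterion = nn.CrossEntropyLoss(reduction="sum")  # sum, then divide by sample count
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen


def save_if_best(model, val_acc, best_acc, path="best_model.pth"):
    """Save the weights when val_acc beats best_acc; return the updated best."""
    if val_acc > best_acc:
        torch.save(model.state_dict(), path)
        return val_acc
    return best_acc
```

After each epoch, a loop would call `evaluate` on the validation loader and pass the resulting accuracy to `save_if_best`. The same `evaluate` function can later be run on the test loader after restoring the saved weights with `model.load_state_dict(torch.load(path))`.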

📊 Results

  • The best model weights based on validation accuracy are saved during training.
  • After training, the model was evaluated on the test set to determine its classification accuracy.
  • Test Accuracy: The model achieved an accuracy of 96.47% on the test set.

📝 License

This project is licensed under the MIT License. See the LICENSE file for more details.
