
jxpress/setfit-pytorch-lightning


Setfit-PyTorch-Lightning


🎉 We are happy to be featured in the official SetFit repository.


🤗 About SetFit

SetFit is a strong few-shot learning method for text classification. With SetFit, you can train a classifier with accuracy comparable to GPT-3 from only a few dozen labeled examples. See the official paper, blog, and code of SetFit for details.

If you want to try SetFit right away, you can access here and find some example notebooks.

This repository provides code for running SetFit in PyTorch Lightning, which makes it easier to manage parameters, experiments, and so on.

This repository was created from lightning-hydra-template.


🚀 How to use this repository

step 0. Create a miniconda GPU environment and run an operation check

Create a miniconda GPU environment

# [OPTIONAL] create conda environment
conda create -n myenv python=3.9
conda activate myenv
# install requirements
pip install -r requirements.txt

Operation check

Run one of the following commands to execute the sample code (sst2 classification).

make operation-check

or

python src/train.py ++trainer.fast_dev_run=true

step 1. Customize the LightningDataModule.

Data is managed in LightningDataModule. In the sample code, training data is obtained from the sst2 dataset.

If you are not familiar with PyTorch Lightning, I recommend changing only self.train_dataset, self.valid_dataset, and self.test_dataset in __init__ of the DataModule.

Parameters of the DataModule are managed in the config file.

If you want to customize further, the README of lightning-hydra-template offers useful information.
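
Since SetFit targets few-shot training, one common way to prepare the data assigned to self.train_dataset is to draw a small, class-balanced subset from a larger labeled set. The following is a minimal sketch of that idea using only the standard library. The function name `sample_few_shot`, the `{"text", "label"}` dict schema, and the toy data are assumptions made for illustration and are not part of this repository.

```python
import random
from collections import defaultdict


def sample_few_shot(examples, num_per_class, seed=0):
    """Return up to `num_per_class` examples for each label.

    `examples` is a list of dicts with "text" and "label" keys
    (a hypothetical schema chosen for this sketch).
    """
    # Group the examples by their label.
    by_label = defaultdict(list)
    for example in examples:
        by_label[example["label"]].append(example)

    rng = random.Random(seed)  # a fixed seed keeps the subset reproducible
    subset = []
    for label in sorted(by_label):
        pool = by_label[label]
        k = min(num_per_class, len(pool))  # never ask for more than exists
        subset.extend(rng.sample(pool, k))
    return subset


# Toy data standing in for a real dataset such as sst2.
toy_data = [{"text": f"sentence {i}", "label": i % 2} for i in range(20)]
few_shot_train = sample_few_shot(toy_data, num_per_class=4)
```

A list like `few_shot_train` could then be converted to whatever dataset type your DataModule expects before being assigned in __init__.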


step 2. Customize the LightningModule.

Parameters that would be passed to the original SetFit trainer and SetFitModel can be passed to the LightningModule. You can manage these parameters in the config file.

If you want to customize further, see here to find out how we implemented SetFit in PyTorch Lightning.


step 3. Customize other options such as callbacks or loggers.

PyTorch Lightning offers useful callbacks and loggers for saving models, metrics, and so on. You can manage which callbacks and loggers are used, and how, in the config files.

⚠ Note: if you want to use the ModelCheckpoint callback and the model head consists of a scikit-learn model, use SetFitModelCheckpoint to save the model, as in the sample code.

step 4. Run the training

Run

python src/train.py

Or you can override the experiment configuration from the command line, like below

python src/train.py trainer.max_epochs=1
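
Overrides such as `trainer.max_epochs=1` are dotted paths into the nested configuration. In Hydra, a plain `key=value` changes an existing key, `+key=value` adds a new one, and `++key=value` adds the key or overrides it if it already exists. The sketch below only illustrates how a dotted path maps onto a nested dictionary. It is a simplified stand-in written with the standard library, not Hydra's parser: it treats all three prefixes the same way, and the helper names and the default value of 10 are hypothetical.

```python
def parse_value(raw):
    """Convert a few common literals; everything else stays a string."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return int(raw)
    except ValueError:
        pass
    try:
        return float(raw)
    except ValueError:
        return raw


def apply_override(config, override):
    """Apply one `dotted.key=value` override to a nested dict in place."""
    key, _, raw_value = override.partition("=")
    key = key.lstrip("+")  # simplification: `+` and `++` are not distinguished
    *parents, leaf = key.split(".")
    node = config
    for name in parents:
        # Walk down the tree, creating intermediate levels when missing.
        node = node.setdefault(name, {})
    node[leaf] = parse_value(raw_value)
    return config


config = {"trainer": {"max_epochs": 10}}  # hypothetical default values
apply_override(config, "trainer.max_epochs=1")
apply_override(config, "++trainer.fast_dev_run=true")
```

After both calls, `config["trainer"]` holds `max_epochs` set to the integer 1 and `fast_dev_run` set to the boolean `True`, mirroring the two command lines shown in this README.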

step 5. Load the trained model

Because the SetFit model may include a scikit-learn head, please load the model as in this notebook.

🐾 Others

Experiment management

To manage your experiments, you can add an experiment configuration to a config file like this and run it like below

python src/train.py experiment=example

For more information, this might be useful for you.


Hyperparameter optimization

If you want to run hyperparameter optimization, just add a config file like this and run it like below

python src/train.py -m hparams_search=setfit_optuna

For more information, this might be useful for you.

😍 Welcome contributions

If you find an error or have any feedback, feel free to tell me via a PR or an Issue! Opinions on any content are welcome!

📝 Appendix

This implementation is based on our experience adapting SetFit to the JX PRESS training template code.

JX PRESS Corporation has created and used the training template code in order to enhance team development capability and development speed.

For more information on JX's training template code, see the blog posts "How we at JX PRESS Corporation devise for team development of R&D that tends to become a genus" and "PyTorch Lightning explained by a heavy user". (These posts are currently written in Japanese. If you want to read them, please translate them into your language. We would like to publish English translations someday.)
