forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into flare-2241-f3-stream-rewrite-for-main
- Loading branch information
Showing
39 changed files
with
1,202 additions
and
221 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,34 @@ | ||
name: Deploy to GitHub Pages | ||
name: Deploy to gh-pages | ||
|
||
on: | ||
# Trigger the workflow every time you push to the `main` branch | ||
# Using a different branch name? Replace `main` with your branch’s name | ||
# Trigger the workflow if any web/** files are modified | ||
push: | ||
branches: [ main ] | ||
# Allows you to run this workflow manually from the Actions tab on GitHub. | ||
branches: | ||
- "main" | ||
- "2.5" | ||
paths: | ||
- 'web/**' | ||
workflow_dispatch: | ||
|
||
env: | ||
site_path: ./web | ||
version_path: / | ||
|
||
# Allow this job to clone the repo and create a page deployment | ||
permissions: | ||
contents: read | ||
contents: write | ||
pages: write | ||
id-token: write | ||
|
||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Checkout your repository using git | ||
- name: Update version_path for non-main branches | ||
if: ${{ github.ref_type == 'branch' && github.ref_name != 'main'}} | ||
run: echo version_path=/version/${{ github.ref_name }}/ >> $GITHUB_ENV | ||
|
||
- name: Checkout your repository | ||
uses: actions/checkout@v4 | ||
|
||
- name: Setup Node | ||
|
@@ -31,28 +38,20 @@ jobs: | |
cache: npm | ||
cache-dependency-path: "${{ env.site_path }}/package-lock.json" | ||
|
||
- name: Install | ||
shell: "bash" | ||
working-directory: ${{ env.site_path }} | ||
- name: Install dependencies | ||
run: npm install | ||
|
||
- name: Build | ||
shell: "bash" | ||
working-directory: ${{ env.site_path }} | ||
|
||
- name: Build project | ||
run: npm run build | ||
env: | ||
PUBLIC_GH_BRANCH: ${{ github.ref_name }} | ||
working-directory: ${{ env.site_path }} | ||
|
||
- name: Upload Pages Artifact | ||
uses: actions/upload-pages-artifact@v3 | ||
- name: Deploy | ||
uses: JamesIves/github-pages-[email protected] | ||
with: | ||
path: "${{ env.site_path }}/dist/" | ||
|
||
deploy: | ||
needs: build | ||
runs-on: ubuntu-latest | ||
environment: | ||
name: github-pages | ||
url: ${{ steps.deployment.outputs.page_url }} | ||
steps: | ||
- name: Deploy to GitHub Pages | ||
id: deployment | ||
uses: actions/deploy-pages@v4 | ||
branch: gh-pages | ||
folder: ${{ env.site_path }}/dist | ||
target-folder: ${{ env.version_path }} | ||
clean-exclude: version |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
nvflare~=2.5.0rc | ||
openmined.psi==1.1.1 | ||
openmined-psi==2.0.5 | ||
pandas |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Federated Retrieval-Augmented Generation (RAG) | ||
The examples in this directory illustrate how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for RAG tasks, including: | ||
- federated embedding model training |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Embedding Model Tuning via SentenceTransformers Trainer | ||
This example shows how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for embedding tuning tasks, a critical component of Retrieval-Augmented Generation (RAG). | ||
|
||
It illustrates how to adapt a local training script with [SentenceTransformers](https://github.com/UKPLab/sentence-transformers) trainer to NVFlare. | ||
|
||
## Introduction | ||
[SentenceTransformers](https://sbert.net/) is a widely used framework for computing dense vector representations for texts. | ||
The models are based on transformer, achieving state-of-the-art performance in various tasks. | ||
|
||
One major application is to embed the text in vector space for later clustering and/or retrieval using similarity metrics. | ||
|
||
This example illustrates a supervised fine-tuning (SFT) scheme for an embedding model with various training datasets. | ||
|
||
## Setup | ||
Please make sure you set up virtual environment following [example root readme](../../../README.md). | ||
Install additional requirements (if you already have a specific version of nvflare installed in your environment, you may want to remove nvflare in the requirements to avoid reinstalling nvflare): | ||
``` | ||
python3 -m pip install -r requirements.txt | ||
``` | ||
Models and data will be loaded directly from Huggingface, so no need to download them manually. | ||
|
||
## Centralized Training | ||
### Single-session training | ||
Centralized trainings, as the baseline for comparison with FL results, are done with the following command: | ||
``` | ||
bash train_single_session.sh | ||
``` | ||
|
||
### Adaptation Step 1: iterative training | ||
To adapt the centralized training script to federated application, under `launch_once = true` setting, we first need to "break" the single call to `trainer.train()` into iterative calls, one for each round of training. | ||
For this purpose, we provided `utils/train_iterative.py` as an example, which is a modified version of `utils/train_single_session.py`. | ||
|
||
In the iterative training script, the `trainer.train()` call is replaced by a `for` loop, and the training epochs are split into six rounds, `unit_train_epochs = 0.25` epoch per round, in total `0.25 * 6 = 1.5` epochs, same as single session setting. | ||
|
||
The first round is trained with `trainer.train()`, then from the second round, | ||
we call `trainer.train(resume_from_checkpoint=True)` with `args.num_train_epochs` incremented by `unit_train_epochs` to continue training from the last checkpoint. | ||
|
||
To run iterative training, we use the following command: | ||
``` | ||
bash train_iterative.sh | ||
``` | ||
|
||
The training loss curves are shown below, single session and iterative scripts align with each other. | ||
|
||
![iter_single](./figs/iter_single.png) | ||
|
||
### Adaptation Step 2: federated with NVFlare | ||
Once we have the iterative training script ready with "starting model" loading capability, it can be easily adapted to a NVFlare trainer by using [Client API](../../../hello-world/ml-to-fl/pt/README.md). | ||
|
||
The major code modifications are for receiving the global model, set it as the starting point for each round's training, and returning the trained model after each local training round. | ||
|
||
## Federated Training | ||
We can use the Python JobAPI to create and run the federated training job. | ||
``` | ||
python3 train_fed.py | ||
``` | ||
|
||
## Results | ||
Below are the evaluation results on two test datasets - [stsb](https://huggingface.co/datasets/sentence-transformers/stsb) with embedding similarity evaluation, and [NLI](https://huggingface.co/datasets/sentence-transformers/all-nli) with triplet accuracy evaluation. The candidate models are: | ||
- NLI: single site training using NLI data | ||
- Squad: single site training using Squad data | ||
- Quora: single site training using Quora data | ||
- All: centralized training using the combined data (see `utils/train_single_session.py`) | ||
- Federated: three sites federated learning, each site contains its own data of NLI, Squad or Quora | ||
|
||
We listed two similarity metrics for each of the two testing datasets: | ||
```commandline | ||
bash eval_all.sh | ||
``` | ||
|
||
TrainData | STSB_pearson_cos | STSB_spearman_euc | NLI_cos_acc | NLI_euc_acc | ||
--- |------------------|-------------------|-------------| --- | ||
NLI | 0.7586 | 0.7895 | 0.8033 | 0.8045 | ||
Squad | 0.8206 | 0.8154 | 0.8051 | 0.8042 | ||
Quora | 0.8161 | 0.8121 | 0.7891 | 0.7854 | ||
All | 0.8497 | 0.8523 | 0.8426 | 0.8384 | ||
Federated | 0.8443 | 0.8367 | 0.8261 | 0.8249 | ||
|
||
As shown, the federated training results are better than individual site's, and can be close to the centralized training results, demonstrating the effectiveness of NVFlare in embedding model tuning tasks. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
for dataset_name in nli squad quora all | ||
do | ||
echo "Evaluation on model ${dataset_name}" | ||
python utils/eval_model.py --model_path /tmp/embed/cen/models_single/mpnet-base-${dataset_name}/final | ||
done | ||
|
||
echo "Evaluation on model federated" | ||
python utils/eval_model.py --model_path /tmp/embed/nvflare/workspace_api/site-1/models/mpnet-base-nli/global |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
nvflare~=2.5.0 | ||
torch | ||
datasets | ||
scikit-learn | ||
tensorboard | ||
transformers | ||
sentence-transformers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
from sentence_transformers import SentenceTransformer | ||
|
||
|
||
class SenTransModel(torch.nn.Module): | ||
def __init__(self, model_name): | ||
super(SenTransModel, self).__init__() | ||
self.model = SentenceTransformer(model_name) | ||
|
||
def forward(self, input_id): | ||
output = self.model(input_ids=input_id, return_dict=False) | ||
return output |
Oops, something went wrong.