Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support modelscope models and datasets #1481

Merged
merged 13 commits into from
Jan 7, 2025
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ For **advanced installation instructions** or if you see weird errors during ins
- Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more!
- We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even Pytorch code!
- We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
- To download models from the ModelScope community, set the environment variable `UNSLOTH_USE_MODELSCOPE=1` and install the modelscope library with `pip install modelscope -U`.

> `unsloth-cli.py` also supports `UNSLOTH_USE_MODELSCOPE=1` for downloading models and datasets. Remember to use the model and dataset IDs from the ModelScope community.

```python
from unsloth import FastLanguageModel
Expand Down
12 changes: 10 additions & 2 deletions unsloth-cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@
"""

import argparse
import os


def run(args):
import torch
from unsloth import FastLanguageModel
from datasets import load_dataset
from transformers.utils import strtobool
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
Expand Down Expand Up @@ -86,8 +89,13 @@ def formatting_prompts_func(examples):
texts.append(text)
return {"text": texts}

# Load and format dataset
dataset = load_dataset(args.dataset, split="train")
use_modelscope = strtobool(os.environ.get('UNSLOTH_USE_MODELSCOPE', 'False'))
if use_modelscope:
from modelscope import MsDataset
dataset = MsDataset.load(args.dataset, split="train")
else:
# Load and format dataset
dataset = load_dataset(args.dataset, split="train")
dataset = dataset.map(formatting_prompts_func, batched=True)
print("Data is formatted and ready!")

Expand Down
19 changes: 19 additions & 0 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@
pass
from huggingface_hub import HfFileSystem

# [TODO] Move USE_MODELSCOPE to utils
# Opt-in switch: when the UNSLOTH_USE_MODELSCOPE environment variable is set
# to a truthy value, model ids are resolved through the ModelScope hub.
# The CLI script parses the same variable with `strtobool`, so the usual
# truthy spellings are accepted here as well (not only "1"); otherwise a
# value such as "true" would enable ModelScope for datasets but not models.
# Anything else, including an unset variable, leaves the flag False.
USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0").strip().lower() in (
    "1", "true", "t", "yes", "y", "on",
)
if USE_MODELSCOPE:
    # Import the submodule explicitly: a bare `import importlib` does not
    # guarantee that `importlib.util` has been loaded.
    import importlib.util
    # Fail fast at import time with an actionable message instead of a
    # late ImportError when a model is first requested.
    if importlib.util.find_spec("modelscope") is None:
        raise ImportError('You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
    pass
pass

# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from unsloth_zoo.utils import Version
transformers_version = Version(transformers_version)
Expand Down Expand Up @@ -87,6 +96,11 @@ def from_pretrained(
old_model_name = model_name
model_name = get_model_name(model_name, load_in_4bit)

if USE_MODELSCOPE and not os.path.exists(model_name):
from modelscope import snapshot_download
model_name = snapshot_download(model_name)
pass

# First check if it's a normal model via AutoConfig
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
was_disabled = are_progress_bars_disabled()
Expand Down Expand Up @@ -366,6 +380,11 @@ def from_pretrained(
old_model_name = model_name
model_name = get_model_name(model_name, load_in_4bit)

if USE_MODELSCOPE and not os.path.exists(model_name):
from modelscope import snapshot_download
model_name = snapshot_download(model_name)
pass

# First check if it's a normal model via AutoConfig
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
was_disabled = are_progress_bars_disabled()
Expand Down