diff --git a/.github/workflows/gpu_test.yaml b/.github/workflows/gpu_test.yaml index 7a664b2e29..829e9384a6 100644 --- a/.github/workflows/gpu_test.yaml +++ b/.github/workflows/gpu_test.yaml @@ -46,7 +46,7 @@ jobs: run: python -m pip install --upgrade pip - name: Install torch nightly if: ${{ matrix.torch-version == 'nightly' }} - run: python -m pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121 + run: python -m pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu126 - name: Install torch stable if: ${{ matrix.torch-version == 'stable' }} run: python -m pip install torch torchvision torchao diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3f8e3149fc..854ee0e97c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ default_language_version: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: 6306a48f7dae5861702d573c9c247e4e9498e867 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: check-ast @@ -18,7 +18,7 @@ repos: exclude: '^(.*\.svg)$' - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.5.4 + rev: v1.5.5 hooks: - id: insert-license files: \.py$|\.sh$ @@ -27,7 +27,7 @@ repos: - docs/license_header.txt - repo: https://github.com/pycqa/flake8 - rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b + rev: 7.1.1 hooks: - id: flake8 additional_dependencies: @@ -37,7 +37,7 @@ repos: args: ['--config=.flake8'] - repo: https://github.com/omnilib/ufmt - rev: v2.3.0 + rev: v2.8.0 hooks: - id: ufmt additional_dependencies: @@ -45,7 +45,7 @@ repos: - usort == 1.0.5 - repo: https://github.com/jsh9/pydoclint - rev: 94efc5f989adbea30f3534b476b2931a02c1af90 + rev: 0.5.12 hooks: - id: pydoclint args: [--config=pyproject.toml] diff --git a/README.md b/README.md index 289d433426..0d014e8d2a 100644 --- a/README.md +++ b/README.md @@ -170,7 +170,7 @@ pip install torchtune ```bash # Install PyTorch, torchvision, torchao nightlies -pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124 +pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu126 # full options are cpu/cu118/cu121/cu124/cu126 pip install --pre --upgrade torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu ``` diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index 5eb8fff358..979e57347f 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -48,10 +48,10 @@ model specific tokenizers. :toctree: generated/ :nosignatures: - tokenizers.SentencePieceBaseTokenizer - tokenizers.TikTokenBaseTokenizer - tokenizers.ModelTokenizer - tokenizers.BaseTokenizer + transforms.tokenizers.SentencePieceBaseTokenizer + transforms.tokenizers.TikTokenBaseTokenizer + transforms.tokenizers.ModelTokenizer + transforms.tokenizers.BaseTokenizer Tokenizer Utilities ------------------- @@ -61,8 +61,8 @@ These are helper methods that can be used by any tokenizer. :toctree: generated/ :nosignatures: - tokenizers.tokenize_messages_no_special_tokens - tokenizers.parse_hf_tokenizer_json + transforms.tokenizers.tokenize_messages_no_special_tokens + transforms.tokenizers.parse_hf_tokenizer_json PEFT Components diff --git a/docs/source/api_ref_rlhf.rst b/docs/source/api_ref_rlhf.rst index e08c699cb4..0042857e34 100644 --- a/docs/source/api_ref_rlhf.rst +++ b/docs/source/api_ref_rlhf.rst @@ -16,4 +16,3 @@ Components and losses for RLHF algorithms like PPO and DPO. loss.PPOLoss loss.DPOLoss loss.RSOLoss - loss.SimPOLoss diff --git a/docs/source/basics/custom_components.rst b/docs/source/basics/custom_components.rst index f252cb197e..0f742644dc 100644 --- a/docs/source/basics/custom_components.rst +++ b/docs/source/basics/custom_components.rst @@ -117,7 +117,7 @@ our models in torchtune - see :func:`~torchtune.models.llama3_2_vision.llama3_2_ # from torchtune.datasets import SFTDataset, PackedDataset from torchtune.data import InputOutputToMessages - from torchtune.modules.tokenizers import ModelTokenizer + from torchtune.modules.transforms.tokenizers import ModelTokenizer # Example builder function for a custom code instruct dataset not in torchtune, but using # different dataset building blocks from torchtune diff --git a/docs/source/basics/message_transforms.rst b/docs/source/basics/message_transforms.rst index 9f92659128..9c5f88f707 100644 --- a/docs/source/basics/message_transforms.rst +++ b/docs/source/basics/message_transforms.rst @@ -95,6 +95,7 @@ Example message transforms -------------------------- - Instruct - :class:`~torchtune.data.InputOutputToMessages` + - :class:`~torchtune.data.AlpacaToMessages` - Chat - :class:`~torchtune.data.ShareGPTToMessages` - :class:`~torchtune.data.OpenAIToMessages` diff --git a/docs/source/basics/model_transforms.rst b/docs/source/basics/model_transforms.rst index c10cb1abd8..71e7e08bd5 100644 --- a/docs/source/basics/model_transforms.rst +++ b/docs/source/basics/model_transforms.rst @@ -101,7 +101,7 @@ The following methods are required on the model transform: .. code-block:: python - from torchtune.modules.tokenizers import ModelTokenizer + from torchtune.modules.transforms.tokenizers import ModelTokenizer from torchtune.modules.transforms import Transform class MyMultimodalTransform(ModelTokenizer, Transform): diff --git a/docs/source/basics/tokenizers.rst b/docs/source/basics/tokenizers.rst index d637961c54..47be88fe0c 100644 --- a/docs/source/basics/tokenizers.rst +++ b/docs/source/basics/tokenizers.rst @@ -168,7 +168,7 @@ For example, here we change the ``"<|begin_of_text|>"`` and ``"<|end_of_text|>"` Base tokenizers --------------- -:class:`~torchtune.modules.tokenizers.BaseTokenizer` are the underlying byte-pair encoding modules that perform the actual raw string to token ID conversion and back. +:class:`~torchtune.modules.transforms.tokenizers.BaseTokenizer` are the underlying byte-pair encoding modules that perform the actual raw string to token ID conversion and back. In torchtune, they are required to implement ``encode`` and ``decode`` methods, which are called by the :ref:`model_tokenizers` to convert between raw text and token IDs. @@ -202,13 +202,13 @@ between raw text and token IDs. """ pass -If you load any :ref:`model_tokenizers`, you can see that it calls its underlying :class:`~torchtune.modules.tokenizers.BaseTokenizer` +If you load any :ref:`model_tokenizers`, you can see that it calls its underlying :class:`~torchtune.modules.transforms.tokenizers.BaseTokenizer` to do the actual encoding and decoding. .. code-block:: python from torchtune.models.mistral import mistral_tokenizer - from torchtune.modules.tokenizers import SentencePieceBaseTokenizer + from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer m_tokenizer = mistral_tokenizer("/tmp/Mistral-7B-v0.1/tokenizer.model") # Mistral uses SentencePiece for its underlying BPE @@ -227,7 +227,7 @@ to do the actual encoding and decoding. Model tokenizers ---------------- -:class:`~torchtune.modules.tokenizers.ModelTokenizer` are specific to a particular model. They are required to implement the ``tokenize_messages`` method, +:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` are specific to a particular model. They are required to implement the ``tokenize_messages`` method, which converts a list of Messages into a list of token IDs. .. code-block:: python @@ -259,7 +259,7 @@ is because they add all the necessary special tokens or prompt templates require .. code-block:: python from torchtune.models.mistral import mistral_tokenizer - from torchtune.modules.tokenizers import SentencePieceBaseTokenizer + from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer from torchtune.data import Message m_tokenizer = mistral_tokenizer("/tmp/Mistral-7B-v0.1/tokenizer.model") diff --git a/docs/source/install.rst b/docs/source/install.rst index 7b5f908da1..4a4f55fdb0 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -19,7 +19,7 @@ nightly versions with the following commands: pip install torch torchvision torchao # Or nightly install for latest features - pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124 + pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu126 # full options are cpu/cu118/cu121/cu124/cu126 Install via PyPI @@ -88,4 +88,4 @@ to the package *without* installing via ``git clone``, you can install with the If you already have PyTorch installed, torchtune will default to using that version. However, if you want to use the nightly version of PyTorch, you can append the ``--force-reinstall`` option to the above command. If you opt for this install method, you will likely need to change the "cpu" suffix in the index url to match your CUDA -version. For example, if you are running CUDA 12, your index url would be "https://download.pytorch.org/whl/nightly/cu121". +version. For example, if you are running CUDA 12, your index url would be "https://download.pytorch.org/whl/nightly/cu126". diff --git a/docs/source/recipes/dpo.rst b/docs/source/recipes/dpo.rst index 5fdb455a35..c4854ef81e 100644 --- a/docs/source/recipes/dpo.rst +++ b/docs/source/recipes/dpo.rst @@ -56,8 +56,6 @@ To use any of these, simply use the ``loss`` config entry or flag through the :r loss=torchtune.modules.loss.RSOLoss \ gamma=0.5 -.. todo (@SalmanMohammadi) point to an example repo for SimPO - For a deeper understanding of the different levers you can pull when using this recipe, see our documentation for the different PEFT training paradigms we support: diff --git a/docs/source/tutorials/e2e_flow.rst b/docs/source/tutorials/e2e_flow.rst index 21571d2e30..9f39de4b6b 100644 --- a/docs/source/tutorials/e2e_flow.rst +++ b/docs/source/tutorials/e2e_flow.rst @@ -275,18 +275,20 @@ Let's first copy over the config to our local working directory so we can make c $ tune cp generation ./custom_generation_config.yaml Copied file to custom_generation_config.yaml + $ mkdir /tmp/torchtune/llama3_2_3B/lora_single_device/out Let's modify ``custom_generation_config.yaml`` to include the following changes. Again, you only need to replace two fields: ``output_dir`` and ``checkpoint_files`` .. code-block:: yaml - output_dir: /tmp/torchtune/llama3_2_3B/lora_single_device/epoch_0 + checkpoint_dir: /tmp/torchtune/llama3_2_3B/lora_single_device/epoch_0 + output_dir: /tmp/torchtune/llama3_2_3B/lora_single_device/out # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer - path: ${output_dir}/original/tokenizer.model + path: ${checkpoint_dir}/original/tokenizer.model prompt_template: null model: @@ -295,7 +297,7 @@ Let's modify ``custom_generation_config.yaml`` to include the following changes. checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: ${output_dir} + checkpoint_dir: ${checkpoint_dir} checkpoint_files: [ ft-model-00001-of-00002.safetensors, ft-model-00002-of-00002.safetensors, @@ -312,8 +314,8 @@ Let's modify ``custom_generation_config.yaml`` to include the following changes. # Generation arguments; defaults taken from gpt-fast prompt: - system: null - user: "Tell me a joke. " + system: null + user: "Tell me a joke. " max_new_tokens: 300 temperature: 0.6 # 0.8 and 0.6 are popular values to try top_k: 300 @@ -330,7 +332,7 @@ these parameters. .. code-block:: text - $ tune run generate --config ./custom_generation_config.yaml prompt="tell me a joke. " + $ tune run generate --config ./custom_generation_config.yaml prompt.user="Tell me a joke. " Tell me a joke. Here's a joke for you: What do you call a fake noodle? diff --git a/docs/source/tutorials/llama3.rst b/docs/source/tutorials/llama3.rst index 938eafae27..6ceac07e8f 100644 --- a/docs/source/tutorials/llama3.rst +++ b/docs/source/tutorials/llama3.rst @@ -230,7 +230,7 @@ Running generation with our LoRA-finetuned model, we see the following output: .. code-block:: bash tune run generate --config ./custom_generation_config.yaml \ - prompt="Hello, my name is" + prompt.user="Hello, my name is" [generate.py:122] Hello, my name is Sarah and I am a busy working mum of two young children, living in the North East of England. ... diff --git a/pyproject.toml b/pyproject.toml index 87ed1fb89e..f94732b58a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ target-version = ["py38"] [tool.pydoclint] style = 'google' check-return-types = 'False' -exclude = 'tests/torchtune/models/(\w+)/scripts/' +exclude = 'tests/torchtune/models/(\w+)/scripts/|recipes/|torchtune/modules/_export' [tool.pytest.ini_options] addopts = ["--showlocals", "--import-mode=prepend", "--without-integration", "--without-slow-integration"] diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml index ad941803bb..7688e83b56 100644 --- a/recipes/configs/code_llama2/7B_full_low_memory.yaml +++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -64,6 +64,7 @@ optimizer: optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index 416c11fc27..fd266598c3 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -72,6 +72,7 @@ lr_scheduler: num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml index 9f3f1dbe4e..54e6c9b2ff 100644 --- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -71,6 +71,7 @@ lr_scheduler: num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/code_llama2/evaluation.yaml b/recipes/configs/code_llama2/evaluation.yaml index 596170cb1c..51f4d48e29 100644 --- a/recipes/configs/code_llama2/evaluation.yaml +++ b/recipes/configs/code_llama2/evaluation.yaml @@ -3,6 +3,8 @@ # To launch, run the following command: # tune run eleuther_eval --config code_llama2/evaluation +output_dir: ./ # Not needed + # Model arguments model: _component_: torchtune.models.code_llama2.code_llama2_7b diff --git a/recipes/configs/gemma/2B_full.yaml b/recipes/configs/gemma/2B_full.yaml index fa692e0f0d..722885f1bb 100644 --- a/recipes/configs/gemma/2B_full.yaml +++ b/recipes/configs/gemma/2B_full.yaml @@ -57,6 +57,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/gemma/2B_lora.yaml b/recipes/configs/gemma/2B_lora.yaml index 4a79abcf99..d601f82087 100644 --- a/recipes/configs/gemma/2B_lora.yaml +++ b/recipes/configs/gemma/2B_lora.yaml @@ -69,6 +69,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma/2B_lora_single_device.yaml b/recipes/configs/gemma/2B_lora_single_device.yaml index e0b473e4ec..9be8a5b4bc 100644 --- a/recipes/configs/gemma/2B_lora_single_device.yaml +++ b/recipes/configs/gemma/2B_lora_single_device.yaml @@ -68,6 +68,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma/2B_qlora_single_device.yaml b/recipes/configs/gemma/2B_qlora_single_device.yaml index 842dc2580f..fa4fb05ad6 100644 --- a/recipes/configs/gemma/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma/2B_qlora_single_device.yaml @@ -68,6 +68,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma/7B_full.yaml b/recipes/configs/gemma/7B_full.yaml index 47206ed291..f40ac5f61e 100644 --- a/recipes/configs/gemma/7B_full.yaml +++ b/recipes/configs/gemma/7B_full.yaml @@ -59,6 +59,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/gemma/7B_lora.yaml b/recipes/configs/gemma/7B_lora.yaml index 3383bae31c..cfee4c7c6a 100644 --- a/recipes/configs/gemma/7B_lora.yaml +++ b/recipes/configs/gemma/7B_lora.yaml @@ -71,6 +71,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma/7B_lora_single_device.yaml b/recipes/configs/gemma/7B_lora_single_device.yaml index e055b09bd5..5c53f71334 100644 --- a/recipes/configs/gemma/7B_lora_single_device.yaml +++ b/recipes/configs/gemma/7B_lora_single_device.yaml @@ -70,6 +70,7 @@ batch_size: 8 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma/7B_qlora_single_device.yaml b/recipes/configs/gemma/7B_qlora_single_device.yaml index 01fb823b4a..debf0d34e3 100644 --- a/recipes/configs/gemma/7B_qlora_single_device.yaml +++ b/recipes/configs/gemma/7B_qlora_single_device.yaml @@ -70,6 +70,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml index 46a31b6821..b06bed3355 100644 --- a/recipes/configs/gemma2/27B_full.yaml +++ b/recipes/configs/gemma2/27B_full.yaml @@ -56,6 +56,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/gemma2/27B_lora.yaml b/recipes/configs/gemma2/27B_lora.yaml index c8b96ed55e..4ee9cdc8fd 100644 --- a/recipes/configs/gemma2/27B_lora.yaml +++ b/recipes/configs/gemma2/27B_lora.yaml @@ -68,6 +68,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma2/27B_lora_single_device.yaml b/recipes/configs/gemma2/27B_lora_single_device.yaml index 74af4c22b5..4268f9e899 100644 --- a/recipes/configs/gemma2/27B_lora_single_device.yaml +++ b/recipes/configs/gemma2/27B_lora_single_device.yaml @@ -67,6 +67,7 @@ batch_size: 2 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma2/27B_qlora_single_device.yaml b/recipes/configs/gemma2/27B_qlora_single_device.yaml index 2f11ef13ab..eee85b3302 100644 --- a/recipes/configs/gemma2/27B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/27B_qlora_single_device.yaml @@ -67,6 +67,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma2/2B_full.yaml b/recipes/configs/gemma2/2B_full.yaml index 42b034fa2c..f04bbf8aec 100644 --- a/recipes/configs/gemma2/2B_full.yaml +++ b/recipes/configs/gemma2/2B_full.yaml @@ -58,6 +58,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/gemma2/2B_lora.yaml b/recipes/configs/gemma2/2B_lora.yaml index 3a38fc5d0c..19d73e37a1 100644 --- a/recipes/configs/gemma2/2B_lora.yaml +++ b/recipes/configs/gemma2/2B_lora.yaml @@ -70,6 +70,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma2/2B_lora_single_device.yaml b/recipes/configs/gemma2/2B_lora_single_device.yaml index 228e1b3b33..82ee52143f 100644 --- a/recipes/configs/gemma2/2B_lora_single_device.yaml +++ b/recipes/configs/gemma2/2B_lora_single_device.yaml @@ -69,6 +69,7 @@ batch_size: 8 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma2/2B_qlora_single_device.yaml b/recipes/configs/gemma2/2B_qlora_single_device.yaml index 16dd23cc51..6964901e39 100644 --- a/recipes/configs/gemma2/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/2B_qlora_single_device.yaml @@ -69,6 +69,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml index bbb31fb268..f31afea54a 100644 --- a/recipes/configs/gemma2/9B_full.yaml +++ b/recipes/configs/gemma2/9B_full.yaml @@ -56,6 +56,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/gemma2/9B_lora.yaml b/recipes/configs/gemma2/9B_lora.yaml index 3c402433a0..596415a14b 100644 --- a/recipes/configs/gemma2/9B_lora.yaml +++ b/recipes/configs/gemma2/9B_lora.yaml @@ -68,6 +68,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma2/9B_lora_single_device.yaml b/recipes/configs/gemma2/9B_lora_single_device.yaml index 7a665c1f12..b9a25f6c2e 100644 --- a/recipes/configs/gemma2/9B_lora_single_device.yaml +++ b/recipes/configs/gemma2/9B_lora_single_device.yaml @@ -67,6 +67,7 @@ batch_size: 8 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/gemma2/9B_qlora_single_device.yaml b/recipes/configs/gemma2/9B_qlora_single_device.yaml index eff7057b63..bf323daa17 100644 --- a/recipes/configs/gemma2/9B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/9B_qlora_single_device.yaml @@ -67,6 +67,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/generation.yaml b/recipes/configs/generation.yaml index c2081a1ed7..0171310aef 100644 --- a/recipes/configs/generation.yaml +++ b/recipes/configs/generation.yaml @@ -1,4 +1,9 @@ -# Config for running the InferenceRecipe in generate.py to generate output from an LLM +# Config for running the InferenceRecipe in generate.py to generate output +# from Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --ignore-patterns "*.safetensors" --hf-token # # To launch, run the following command from root torchtune directory: # tune run generate --config generation diff --git a/recipes/configs/llama2/13B_full.yaml b/recipes/configs/llama2/13B_full.yaml index 67932bbb1b..24ca1bdf76 100644 --- a/recipes/configs/llama2/13B_full.yaml +++ b/recipes/configs/llama2/13B_full.yaml @@ -61,6 +61,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/llama2/13B_lora.yaml b/recipes/configs/llama2/13B_lora.yaml index fbd8a2141d..7fa54d8075 100644 --- a/recipes/configs/llama2/13B_lora.yaml +++ b/recipes/configs/llama2/13B_lora.yaml @@ -77,6 +77,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama2/13B_qlora_single_device.yaml b/recipes/configs/llama2/13B_qlora_single_device.yaml index 69558858bd..ab4924d636 100644 --- a/recipes/configs/llama2/13B_qlora_single_device.yaml +++ b/recipes/configs/llama2/13B_qlora_single_device.yaml @@ -72,6 +72,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama2/70B_lora.yaml b/recipes/configs/llama2/70B_lora.yaml index 7fed032ec4..bd15cb0799 100644 --- a/recipes/configs/llama2/70B_lora.yaml +++ b/recipes/configs/llama2/70B_lora.yaml @@ -62,6 +62,7 @@ loss: # Training epochs: 1 max_steps_per_epoch: null +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory gradient_accumulation_steps: 1 # Use to increase effective batch size diff --git a/recipes/configs/llama2/70B_qlora.yaml b/recipes/configs/llama2/70B_qlora.yaml index b140624fc2..51f0cfd1e1 100644 --- a/recipes/configs/llama2/70B_qlora.yaml +++ b/recipes/configs/llama2/70B_qlora.yaml @@ -72,6 +72,7 @@ fsdp: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama2/7B_full.yaml b/recipes/configs/llama2/7B_full.yaml index 40fb804035..923c382728 100644 --- a/recipes/configs/llama2/7B_full.yaml +++ b/recipes/configs/llama2/7B_full.yaml @@ -60,6 +60,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/llama2/7B_full_low_memory.yaml b/recipes/configs/llama2/7B_full_low_memory.yaml index 29d157dbf6..11411374f7 100644 --- a/recipes/configs/llama2/7B_full_low_memory.yaml +++ b/recipes/configs/llama2/7B_full_low_memory.yaml @@ -65,6 +65,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training environment diff --git a/recipes/configs/llama2/7B_lora.yaml b/recipes/configs/llama2/7B_lora.yaml index fc8ec7e346..0e7666dce2 100644 --- a/recipes/configs/llama2/7B_lora.yaml +++ b/recipes/configs/llama2/7B_lora.yaml @@ -73,6 +73,7 @@ loss: # Training epochs: 1 max_steps_per_epoch: null +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory gradient_accumulation_steps: 8 # Use to increase effective batch size diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml index ac87053ae1..f36128bbe4 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -72,6 +72,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama2/7B_qat_full.yaml b/recipes/configs/llama2/7B_qat_full.yaml index a82e3f580c..da4ee51be9 100644 --- a/recipes/configs/llama2/7B_qat_full.yaml +++ b/recipes/configs/llama2/7B_qat_full.yaml @@ -56,6 +56,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/llama2/7B_qlora.yaml b/recipes/configs/llama2/7B_qlora.yaml index 49c44c27f3..5df9e0ac92 100644 --- a/recipes/configs/llama2/7B_qlora.yaml +++ b/recipes/configs/llama2/7B_qlora.yaml @@ -77,6 +77,7 @@ fsdp: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama2/7B_qlora_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml index a925ac782b..b5b39949e4 100644 --- a/recipes/configs/llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -71,6 +71,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml index 5491ae093d..ee9c914ce8 100644 --- a/recipes/configs/llama3/70B_full.yaml +++ b/recipes/configs/llama3/70B_full.yaml @@ -69,6 +69,7 @@ enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. fsdp_cpu_offload: True +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/llama3/70B_generation_distributed.yaml b/recipes/configs/llama3/70B_generation_distributed.yaml new file mode 100644 index 0000000000..78c77ba263 --- /dev/null +++ b/recipes/configs/llama3/70B_generation_distributed.yaml @@ -0,0 +1,50 @@ +# Config for running the InferenceRecipe in dev/generate_v2.py to generate output +# using a Llama3 70B Instruct model +# +# This config assumes that you've run the following command before launching: +# tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token +# +# To launch, run the following command from root torchtune directory: +# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3/70B_generation_distributed + +output_dir: ./ + +# Model arguments +model: + _component_: torchtune.models.llama3.llama3_70b + +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + +# Transform arguments +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model + prompt_template: null + max_seq_len: 8192 + +# Checkpointer +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00030" + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 + +# Device +device: cuda +dtype: bf16 +seed: 1234 +log_level: INFO + +# Generation arguments +prompt: + system: null + user: + text: Tell a joke. +max_new_tokens: 200 +temperature: 0.6 # 0.8 and 0.6 are popular values to try +top_k: 300 diff --git a/recipes/configs/llama3/70B_lora.yaml b/recipes/configs/llama3/70B_lora.yaml index da82f156f6..dbbcf2dceb 100644 --- a/recipes/configs/llama3/70B_lora.yaml +++ b/recipes/configs/llama3/70B_lora.yaml @@ -63,6 +63,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3/8B_dora.yaml b/recipes/configs/llama3/8B_dora.yaml index a68f5cf7ff..e1c19afd65 100644 --- a/recipes/configs/llama3/8B_dora.yaml +++ b/recipes/configs/llama3/8B_dora.yaml @@ -67,6 +67,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3/8B_dora_single_device.yaml b/recipes/configs/llama3/8B_dora_single_device.yaml index f27acc3a12..a79a0a0afa 100644 --- a/recipes/configs/llama3/8B_dora_single_device.yaml +++ b/recipes/configs/llama3/8B_dora_single_device.yaml @@ -69,6 +69,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index 2723d08c90..4c23331171 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -60,6 +60,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/llama3/8B_full_single_device.yaml b/recipes/configs/llama3/8B_full_single_device.yaml index ad534c62b9..d81e512b6a 100644 --- a/recipes/configs/llama3/8B_full_single_device.yaml +++ b/recipes/configs/llama3/8B_full_single_device.yaml @@ -64,6 +64,7 @@ loss: max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training environment diff --git a/recipes/configs/llama3/8B_lora.yaml b/recipes/configs/llama3/8B_lora.yaml index 7ac7bd1942..2b140ee796 100644 --- a/recipes/configs/llama3/8B_lora.yaml +++ b/recipes/configs/llama3/8B_lora.yaml @@ -72,6 +72,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml index 8b1db9d06d..689e26d92c 100644 --- a/recipes/configs/llama3/8B_lora_single_device.yaml +++ b/recipes/configs/llama3/8B_lora_single_device.yaml @@ -71,6 +71,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3/8B_qat_full.yaml b/recipes/configs/llama3/8B_qat_full.yaml index b1d9bfad5b..a25e66b22a 100644 --- a/recipes/configs/llama3/8B_qat_full.yaml +++ b/recipes/configs/llama3/8B_qat_full.yaml @@ -60,6 +60,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/llama3/8B_qat_lora.yaml b/recipes/configs/llama3/8B_qat_lora.yaml index 5f88f175ec..3979d5667a 100644 --- a/recipes/configs/llama3/8B_qat_lora.yaml +++ b/recipes/configs/llama3/8B_qat_lora.yaml @@ -68,6 +68,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3/8B_qdora_single_device.yaml b/recipes/configs/llama3/8B_qdora_single_device.yaml index 45b515476e..678ab07c89 100644 --- a/recipes/configs/llama3/8B_qdora_single_device.yaml +++ b/recipes/configs/llama3/8B_qdora_single_device.yaml @@ -70,6 +70,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3/8B_qlora_single_device.yaml b/recipes/configs/llama3/8B_qlora_single_device.yaml index 4922ada0f0..a91bd2e339 100644 --- a/recipes/configs/llama3/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3/8B_qlora_single_device.yaml @@ -70,6 +70,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_1/405B_qlora.yaml b/recipes/configs/llama3_1/405B_qlora.yaml index 044f47b48e..3bb2c28351 100644 --- a/recipes/configs/llama3_1/405B_qlora.yaml +++ b/recipes/configs/llama3_1/405B_qlora.yaml @@ -70,6 +70,7 @@ fsdp: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml index 1ecf130e1a..0c4c7fce7f 100644 --- a/recipes/configs/llama3_1/70B_full.yaml +++ b/recipes/configs/llama3_1/70B_full.yaml @@ -71,6 +71,7 @@ enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. fsdp_cpu_offload: True +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/llama3_1/70B_generation_distributed.yaml b/recipes/configs/llama3_1/70B_generation_distributed.yaml new file mode 100644 index 0000000000..d71a94f8de --- /dev/null +++ b/recipes/configs/llama3_1/70B_generation_distributed.yaml @@ -0,0 +1,50 @@ +# Config for running the InferenceRecipe in dev/generate_v2.py to generate output +# using a Llama3.1 70B Instruct model +# +# This config assumes that you've run the following command before launching: +# tune download meta-llama/Meta-Llama-3.1-70B-Instruct --output-dir /tmp/Meta-Llama-3.1-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token +# +# To launch, run the following command from root torchtune directory: +# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3_1/70B_generation_distributed + +output_dir: ./ + +# Model arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_70b + +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + +# Transform arguments +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model + prompt_template: null + max_seq_len: 8192 + +# Checkpointer +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00030" + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 + +# Device +device: cuda +dtype: bf16 +seed: 1234 +log_level: INFO + +# Generation arguments +prompt: + system: null + user: + text: Tell a joke. +max_new_tokens: 200 +temperature: 0.6 # 0.8 and 0.6 are popular values to try +top_k: 300 diff --git a/recipes/configs/llama3_1/70B_lora.yaml b/recipes/configs/llama3_1/70B_lora.yaml index 81ef5f6875..848960ed69 100644 --- a/recipes/configs/llama3_1/70B_lora.yaml +++ b/recipes/configs/llama3_1/70B_lora.yaml @@ -62,6 +62,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_1/8B_full.yaml b/recipes/configs/llama3_1/8B_full.yaml index 357d20356d..0e59a0f76f 100644 --- a/recipes/configs/llama3_1/8B_full.yaml +++ b/recipes/configs/llama3_1/8B_full.yaml @@ -62,6 +62,7 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 gradient_accumulation_steps: 1 # Use to increase effective batch size diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml index 1429b9cc2b..094a3547eb 100644 --- a/recipes/configs/llama3_1/8B_full_single_device.yaml +++ b/recipes/configs/llama3_1/8B_full_single_device.yaml @@ -64,6 +64,7 @@ loss: max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training environment diff --git a/recipes/configs/llama3_1/8B_lora.yaml b/recipes/configs/llama3_1/8B_lora.yaml index 7303194173..2f1dd3d067 100644 --- a/recipes/configs/llama3_1/8B_lora.yaml +++ b/recipes/configs/llama3_1/8B_lora.yaml @@ -75,6 +75,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml index 0f19750219..a6c7b070fa 100644 --- a/recipes/configs/llama3_1/8B_lora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -74,6 +74,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_1/8B_qat_lora.yaml b/recipes/configs/llama3_1/8B_qat_lora.yaml index 3d7c94744e..d4450f5002 100644 --- a/recipes/configs/llama3_1/8B_qat_lora.yaml +++ b/recipes/configs/llama3_1/8B_qat_lora.yaml @@ -71,6 +71,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_1/8B_qlora_single_device.yaml b/recipes/configs/llama3_1/8B_qlora_single_device.yaml index 3386601917..3793f7612c 100644 --- a/recipes/configs/llama3_1/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_qlora_single_device.yaml @@ -73,6 +73,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_2/1B_full.yaml b/recipes/configs/llama3_2/1B_full.yaml index 25c7de45c1..ecf1778c2e 100644 --- a/recipes/configs/llama3_2/1B_full.yaml +++ b/recipes/configs/llama3_2/1B_full.yaml @@ -68,6 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/llama3_2/1B_full_single_device.yaml b/recipes/configs/llama3_2/1B_full_single_device.yaml index e24fc56219..39d635fac1 100644 --- a/recipes/configs/llama3_2/1B_full_single_device.yaml +++ b/recipes/configs/llama3_2/1B_full_single_device.yaml @@ -60,6 +60,7 @@ loss: max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training environment diff --git a/recipes/configs/llama3_2/1B_lora.yaml b/recipes/configs/llama3_2/1B_lora.yaml index 15e14be3b1..33ec005530 100644 --- a/recipes/configs/llama3_2/1B_lora.yaml +++ b/recipes/configs/llama3_2/1B_lora.yaml @@ -71,6 +71,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_2/1B_lora_single_device.yaml b/recipes/configs/llama3_2/1B_lora_single_device.yaml index 6e441d4711..72c03f55d9 100644 --- a/recipes/configs/llama3_2/1B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_lora_single_device.yaml @@ -69,6 +69,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_2/1B_qat_lora.yaml b/recipes/configs/llama3_2/1B_qat_lora.yaml index bffc52a4ac..ce4d58b2b5 100644 --- a/recipes/configs/llama3_2/1B_qat_lora.yaml +++ b/recipes/configs/llama3_2/1B_qat_lora.yaml @@ -67,6 +67,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_2/1B_qlora_single_device.yaml b/recipes/configs/llama3_2/1B_qlora_single_device.yaml index 99165c806f..57413ed366 100644 --- a/recipes/configs/llama3_2/1B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_qlora_single_device.yaml @@ -69,6 +69,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index 0703437596..4de06b4340 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -68,6 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/llama3_2/3B_full_single_device.yaml b/recipes/configs/llama3_2/3B_full_single_device.yaml index 052c524019..51f046cfa2 100644 --- a/recipes/configs/llama3_2/3B_full_single_device.yaml +++ b/recipes/configs/llama3_2/3B_full_single_device.yaml @@ -62,6 +62,7 @@ loss: max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training environment diff --git a/recipes/configs/llama3_2/3B_lora.yaml b/recipes/configs/llama3_2/3B_lora.yaml index 9575df0f78..624032915b 100644 --- a/recipes/configs/llama3_2/3B_lora.yaml +++ b/recipes/configs/llama3_2/3B_lora.yaml @@ -72,6 +72,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_2/3B_lora_single_device.yaml b/recipes/configs/llama3_2/3B_lora_single_device.yaml index 451455253a..6ac050a789 100644 --- a/recipes/configs/llama3_2/3B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/3B_lora_single_device.yaml @@ -71,6 +71,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_2/3B_qat_lora.yaml b/recipes/configs/llama3_2/3B_qat_lora.yaml index 64985de1ea..6addae108e 100644 --- a/recipes/configs/llama3_2/3B_qat_lora.yaml +++ b/recipes/configs/llama3_2/3B_qat_lora.yaml @@ -68,6 +68,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_2/3B_qlora_single_device.yaml b/recipes/configs/llama3_2/3B_qlora_single_device.yaml index 3cc504f1d0..3de98fc8c7 100644 --- a/recipes/configs/llama3_2/3B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2/3B_qlora_single_device.yaml @@ -70,6 +70,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_2/8B_to_1B_KD_lora_single_device.yaml b/recipes/configs/llama3_2/8B_to_1B_KD_lora_single_device.yaml index 0a2dfea9f5..650bac96d2 100644 --- a/recipes/configs/llama3_2/8B_to_1B_KD_lora_single_device.yaml +++ b/recipes/configs/llama3_2/8B_to_1B_KD_lora_single_device.yaml @@ -91,6 +91,7 @@ kd_ratio: 0.5 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_2/evaluation.yaml b/recipes/configs/llama3_2/evaluation.yaml index 44b4ebcead..0c86e97174 100644 --- a/recipes/configs/llama3_2/evaluation.yaml +++ b/recipes/configs/llama3_2/evaluation.yaml @@ -3,6 +3,8 @@ # To launch, run the following command: # tune run eleuther_eval --config llama3_2/evaluation +output_dir: ./ # Not needed + # Model Arguments model: _component_: torchtune.models.llama3_2.llama3_2_3b diff --git a/recipes/configs/llama3_2_vision/11B_generation_v2.yaml b/recipes/configs/llama3_2_vision/11B_generation_v2.yaml index c78e0e52b6..11fd14f8d1 100644 --- a/recipes/configs/llama3_2_vision/11B_generation_v2.yaml +++ b/recipes/configs/llama3_2_vision/11B_generation_v2.yaml @@ -7,7 +7,7 @@ # To launch, run the following command from root torchtune directory: # tune run dev/generate_v2 --config llama3_2_vision/generation_v2 -output_dir: ./ # Not needed +output_dir: ./ # Model arguments model: diff --git a/recipes/configs/llama3_2_vision/11B_lora_multi_dataset.yaml b/recipes/configs/llama3_2_vision/11B_lora_multi_dataset.yaml index 87afa718b2..c947e7a28d 100644 --- a/recipes/configs/llama3_2_vision/11B_lora_multi_dataset.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora_multi_dataset.yaml @@ -1,20 +1,22 @@ -# Config for multi-device LoRA finetuning in lora_finetune_distributed_td.py +# Config for multi-device LoRA finetuning in lora_finetune_distributed_multi_dataset.py # using a Llama3.2 11B Vision Instruct model # # This config assumes that you've run the following command before launching: # tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct --ignore-patterns "original/consolidated*" # # To launch on 2 devices, run the following command from root: -# tune run --nproc_per_node 2 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td +# tune run --nproc_per_node 2 lora_finetune_distributed_multi_dataset --config llama3_2_vision/11B_lora_multi_dataset # # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training: -# tune run --nproc_per_node 2 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td checkpointer.checkpoint_dir= +# tune run --nproc_per_node 2 lora_finetune_distributed_multi_dataset --config llama3_2_vision/11B_lora_multi_dataset checkpointer.checkpoint_dir= # # This config works best when the model is being fine-tuned on 2+ GPUs. # For single device LoRA finetuning please use 11B_lora_single_device.yaml # or 11B_qlora_single_device.yaml +output_dir: /tmp/torchtune/llama3_2_vision_11B/lora_multi_dataset # /tmp may be deleted by your system. Change it to your preference. + # Model arguments model: _component_: torchtune.models.llama3_2_vision.lora_llama3_2_vision_11b @@ -44,7 +46,7 @@ checkpointer: filename_format: model-{}-of-{}.safetensors max_filename: "00005" recipe_checkpoint: null - output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/ + output_dir: ${output_dir} model_type: LLAMA3_VISION resume_from_checkpoint: False save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only. @@ -117,6 +119,6 @@ dtype: bf16 output_dir: /tmp/lora-llama3.2-vision-finetune metric_logger: _component_: torchtune.training.metric_logging.DiskLogger - log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs + log_dir: ${output_dir}/logs log_every_n_steps: 1 log_peak_memory_stats: True diff --git a/recipes/configs/llama3_3/70B_full.yaml b/recipes/configs/llama3_3/70B_full.yaml index f7ec013c15..4880a89edc 100644 --- a/recipes/configs/llama3_3/70B_full.yaml +++ b/recipes/configs/llama3_3/70B_full.yaml @@ -71,6 +71,7 @@ enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. fsdp_cpu_offload: True +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/llama3_3/70B_generation_distributed.yaml b/recipes/configs/llama3_3/70B_generation_distributed.yaml new file mode 100644 index 0000000000..d39acf45ad --- /dev/null +++ b/recipes/configs/llama3_3/70B_generation_distributed.yaml @@ -0,0 +1,50 @@ +# Config for running the InferenceRecipe in dev/generate_v2.py to generate output +# using a Llama3.1 70B Instruct model +# +# This config assumes that you've run the following command before launching: +# tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token +# +# To launch, run the following command from root torchtune directory: +# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3_3/70B_generation_distributed + +output_dir: ./ + +# Model arguments +model: + _component_: torchtune.models.llama3_3.llama3_3_70b + +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + +# Transform arguments +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Llama-3.3-70B-Instruct/original/tokenizer.model + prompt_template: null + max_seq_len: 8192 + +# Checkpointer +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00030" + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 + +# Device +device: cuda +dtype: bf16 +seed: 1234 +log_level: INFO + +# Generation arguments +prompt: + system: null + user: + text: Tell a joke. +max_new_tokens: 200 +temperature: 0.6 # 0.8 and 0.6 are popular values to try +top_k: 300 diff --git a/recipes/configs/llama3_3/70B_lora.yaml b/recipes/configs/llama3_3/70B_lora.yaml index 06c2924f5c..9a6d91d3cc 100644 --- a/recipes/configs/llama3_3/70B_lora.yaml +++ b/recipes/configs/llama3_3/70B_lora.yaml @@ -62,6 +62,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/llama3_3/70B_qlora.yaml b/recipes/configs/llama3_3/70B_qlora.yaml index 53c4a8c3b5..87d499fc38 100644 --- a/recipes/configs/llama3_3/70B_qlora.yaml +++ b/recipes/configs/llama3_3/70B_qlora.yaml @@ -62,6 +62,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/mistral/7B_full.yaml b/recipes/configs/mistral/7B_full.yaml index 15a6ec7b89..f646451a05 100644 --- a/recipes/configs/mistral/7B_full.yaml +++ b/recipes/configs/mistral/7B_full.yaml @@ -63,6 +63,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/mistral/7B_full_low_memory.yaml b/recipes/configs/mistral/7B_full_low_memory.yaml index 287a66dbd0..9aa1eaf86f 100644 --- a/recipes/configs/mistral/7B_full_low_memory.yaml +++ b/recipes/configs/mistral/7B_full_low_memory.yaml @@ -77,6 +77,7 @@ enable_activation_offloading: True # True reduces memory dtype: bf16 # Model compilation +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/mistral/7B_full_ppo_low_memory.yaml b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml index 166fbeac1d..19e17c362a 100644 --- a/recipes/configs/mistral/7B_full_ppo_low_memory.yaml +++ b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml @@ -124,10 +124,10 @@ shuffle: True device: cuda # Training arguments -batch_size: 64 +batch_size: 128 num_steps: 10000 -ppo_epochs: 2 -ppo_batch_size: 32 +ppo_epochs: 1 +ppo_batch_size: 128 gradient_accumulation_steps: 1 # Use to increase effective batch size # Memory management and performance @@ -137,13 +137,14 @@ optimizer: lr: 3e-6 optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 log_peak_memory_stats: True -enable_activation_checkpointing: True # True reduces memory +enable_activation_checkpointing: True # True reduces memory +enable_kv_cache: True # Reduced precision dtype: bf16 # batch size for forward pass during generation -forward_batch_size: 16 +forward_batch_size: 128 max_generated_tokens: 58 temperature: 0.7 top_k: null @@ -180,3 +181,27 @@ metric_logger: log_dir: ${output_dir}/logs log_every_n_steps: 1 + +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: True + with_stack: False + record_shapes: False + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 3 + num_cycles: 1 diff --git a/recipes/configs/mistral/7B_lora.yaml b/recipes/configs/mistral/7B_lora.yaml index ef3c9b0e1b..2c9b1a5453 100644 --- a/recipes/configs/mistral/7B_lora.yaml +++ b/recipes/configs/mistral/7B_lora.yaml @@ -77,6 +77,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/mistral/7B_lora_single_device.yaml b/recipes/configs/mistral/7B_lora_single_device.yaml index c98f23c840..f553b71d91 100644 --- a/recipes/configs/mistral/7B_lora_single_device.yaml +++ b/recipes/configs/mistral/7B_lora_single_device.yaml @@ -74,6 +74,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/mistral/7B_qlora_single_device.yaml b/recipes/configs/mistral/7B_qlora_single_device.yaml index 353ad54187..318b2b8735 100644 --- a/recipes/configs/mistral/7B_qlora_single_device.yaml +++ b/recipes/configs/mistral/7B_qlora_single_device.yaml @@ -75,6 +75,7 @@ batch_size: 4 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/phi3/mini_full.yaml b/recipes/configs/phi3/mini_full.yaml index 7dc954576d..b5b9972125 100644 --- a/recipes/configs/phi3/mini_full.yaml +++ b/recipes/configs/phi3/mini_full.yaml @@ -60,6 +60,7 @@ optimizer: lr: 5e-6 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/phi3/mini_full_low_memory.yaml b/recipes/configs/phi3/mini_full_low_memory.yaml index 8162e73c18..45505be144 100644 --- a/recipes/configs/phi3/mini_full_low_memory.yaml +++ b/recipes/configs/phi3/mini_full_low_memory.yaml @@ -62,6 +62,7 @@ optimizer: optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/phi3/mini_lora.yaml b/recipes/configs/phi3/mini_lora.yaml index 429a1c2a6d..b57887753c 100644 --- a/recipes/configs/phi3/mini_lora.yaml +++ b/recipes/configs/phi3/mini_lora.yaml @@ -71,6 +71,7 @@ lr_scheduler: num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/phi3/mini_lora_single_device.yaml b/recipes/configs/phi3/mini_lora_single_device.yaml index 26e5ac457f..0977175ee0 100644 --- a/recipes/configs/phi3/mini_lora_single_device.yaml +++ b/recipes/configs/phi3/mini_lora_single_device.yaml @@ -69,6 +69,7 @@ lr_scheduler: num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/phi3/mini_qlora_single_device.yaml b/recipes/configs/phi3/mini_qlora_single_device.yaml index a81e34f669..e11e72e59d 100644 --- a/recipes/configs/phi3/mini_qlora_single_device.yaml +++ b/recipes/configs/phi3/mini_qlora_single_device.yaml @@ -69,6 +69,7 @@ lr_scheduler: num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training env diff --git a/recipes/configs/qwen2/0.5B_full.yaml b/recipes/configs/qwen2/0.5B_full.yaml index 093887fb59..97f9b68224 100644 --- a/recipes/configs/qwen2/0.5B_full.yaml +++ b/recipes/configs/qwen2/0.5B_full.yaml @@ -59,6 +59,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/qwen2/0.5B_full_single_device.yaml b/recipes/configs/qwen2/0.5B_full_single_device.yaml index 4f670695ca..ac563e4d83 100644 --- a/recipes/configs/qwen2/0.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_full_single_device.yaml @@ -60,6 +60,7 @@ optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_ste max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training environment diff --git a/recipes/configs/qwen2/0.5B_lora.yaml b/recipes/configs/qwen2/0.5B_lora.yaml index f4ce567afb..b27fff5dd6 100644 --- a/recipes/configs/qwen2/0.5B_lora.yaml +++ b/recipes/configs/qwen2/0.5B_lora.yaml @@ -73,6 +73,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2/0.5B_lora_single_device.yaml b/recipes/configs/qwen2/0.5B_lora_single_device.yaml index 9bd95bfedc..c8a797120d 100644 --- a/recipes/configs/qwen2/0.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_lora_single_device.yaml @@ -71,6 +71,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2/1.5B_full.yaml b/recipes/configs/qwen2/1.5B_full.yaml index 04017db7ec..a473edbbdf 100644 --- a/recipes/configs/qwen2/1.5B_full.yaml +++ b/recipes/configs/qwen2/1.5B_full.yaml @@ -59,6 +59,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/qwen2/1.5B_full_single_device.yaml b/recipes/configs/qwen2/1.5B_full_single_device.yaml index d529629823..6a5776cad8 100644 --- a/recipes/configs/qwen2/1.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_full_single_device.yaml @@ -65,6 +65,7 @@ loss: max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training environment diff --git a/recipes/configs/qwen2/1.5B_lora.yaml b/recipes/configs/qwen2/1.5B_lora.yaml index 665e13c671..0a8ecf30cd 100644 --- a/recipes/configs/qwen2/1.5B_lora.yaml +++ b/recipes/configs/qwen2/1.5B_lora.yaml @@ -69,6 +69,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2/1.5B_lora_single_device.yaml b/recipes/configs/qwen2/1.5B_lora_single_device.yaml index 47f6afd7bd..ce62609598 100644 --- a/recipes/configs/qwen2/1.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_lora_single_device.yaml @@ -69,6 +69,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2/1.5_to_0.5B_KD_lora_single_device.yaml b/recipes/configs/qwen2/1.5_to_0.5B_KD_lora_single_device.yaml index 385c1d453a..609ecfae22 100644 --- a/recipes/configs/qwen2/1.5_to_0.5B_KD_lora_single_device.yaml +++ b/recipes/configs/qwen2/1.5_to_0.5B_KD_lora_single_device.yaml @@ -84,6 +84,7 @@ kd_ratio: 0.5 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2/7B_full.yaml b/recipes/configs/qwen2/7B_full.yaml index ec82a0d701..2efbd2abfd 100644 --- a/recipes/configs/qwen2/7B_full.yaml +++ b/recipes/configs/qwen2/7B_full.yaml @@ -62,6 +62,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/qwen2/7B_full_single_device.yaml b/recipes/configs/qwen2/7B_full_single_device.yaml index 0b01526ba4..55e4416714 100644 --- a/recipes/configs/qwen2/7B_full_single_device.yaml +++ b/recipes/configs/qwen2/7B_full_single_device.yaml @@ -64,6 +64,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training environment diff --git a/recipes/configs/qwen2/7B_lora.yaml b/recipes/configs/qwen2/7B_lora.yaml index 1da8e0de4d..ea72e59da5 100644 --- a/recipes/configs/qwen2/7B_lora.yaml +++ b/recipes/configs/qwen2/7B_lora.yaml @@ -75,6 +75,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2/7B_lora_single_device.yaml b/recipes/configs/qwen2/7B_lora_single_device.yaml index 082be9a3fd..03466e9f51 100644 --- a/recipes/configs/qwen2/7B_lora_single_device.yaml +++ b/recipes/configs/qwen2/7B_lora_single_device.yaml @@ -73,6 +73,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/0.5B_full.yaml b/recipes/configs/qwen2_5/0.5B_full.yaml index c415425d5b..e6fad21b0a 100644 --- a/recipes/configs/qwen2_5/0.5B_full.yaml +++ b/recipes/configs/qwen2_5/0.5B_full.yaml @@ -63,6 +63,7 @@ device: cuda enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/0.5B_full_single_device.yaml b/recipes/configs/qwen2_5/0.5B_full_single_device.yaml index 2ac3a79f00..19e153be61 100644 --- a/recipes/configs/qwen2_5/0.5B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/0.5B_full_single_device.yaml @@ -63,6 +63,7 @@ device: cuda enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/0.5B_lora.yaml b/recipes/configs/qwen2_5/0.5B_lora.yaml index 704aa7ca80..a0b4483860 100644 --- a/recipes/configs/qwen2_5/0.5B_lora.yaml +++ b/recipes/configs/qwen2_5/0.5B_lora.yaml @@ -71,6 +71,7 @@ device: cuda enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/0.5B_lora_single_device.yaml b/recipes/configs/qwen2_5/0.5B_lora_single_device.yaml index 20ceb6536d..5558175628 100644 --- a/recipes/configs/qwen2_5/0.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2_5/0.5B_lora_single_device.yaml @@ -71,6 +71,7 @@ device: cuda enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/1.5B_full.yaml b/recipes/configs/qwen2_5/1.5B_full.yaml index 431c1b519a..4c10559344 100644 --- a/recipes/configs/qwen2_5/1.5B_full.yaml +++ b/recipes/configs/qwen2_5/1.5B_full.yaml @@ -63,6 +63,7 @@ device: cuda enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/1.5B_full_single_device.yaml b/recipes/configs/qwen2_5/1.5B_full_single_device.yaml index d48176616d..f083026c1b 100644 --- a/recipes/configs/qwen2_5/1.5B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/1.5B_full_single_device.yaml @@ -66,6 +66,7 @@ device: cuda enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/1.5B_lora.yaml b/recipes/configs/qwen2_5/1.5B_lora.yaml index 84d9e2c9bd..a9f9043ae2 100644 --- a/recipes/configs/qwen2_5/1.5B_lora.yaml +++ b/recipes/configs/qwen2_5/1.5B_lora.yaml @@ -70,6 +70,7 @@ device: cuda enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/1.5B_lora_single_device.yaml b/recipes/configs/qwen2_5/1.5B_lora_single_device.yaml index 579c39bfec..a7a612f966 100644 --- a/recipes/configs/qwen2_5/1.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2_5/1.5B_lora_single_device.yaml @@ -70,6 +70,7 @@ device: cuda enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/14B_lora_single_device.yaml b/recipes/configs/qwen2_5/14B_lora_single_device.yaml index e918b8de09..f90171d4d1 100644 --- a/recipes/configs/qwen2_5/14B_lora_single_device.yaml +++ b/recipes/configs/qwen2_5/14B_lora_single_device.yaml @@ -70,6 +70,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/32B_lora.yaml b/recipes/configs/qwen2_5/32B_lora.yaml index 1633d59c3f..8f9b1ecb79 100644 --- a/recipes/configs/qwen2_5/32B_lora.yaml +++ b/recipes/configs/qwen2_5/32B_lora.yaml @@ -68,6 +68,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/3B_full.yaml b/recipes/configs/qwen2_5/3B_full.yaml index 217769ad8c..d2eaabbf77 100644 --- a/recipes/configs/qwen2_5/3B_full.yaml +++ b/recipes/configs/qwen2_5/3B_full.yaml @@ -60,6 +60,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/qwen2_5/3B_full_single_device.yaml b/recipes/configs/qwen2_5/3B_full_single_device.yaml index 38b1645817..72f2fdfb16 100644 --- a/recipes/configs/qwen2_5/3B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/3B_full_single_device.yaml @@ -62,6 +62,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training environment diff --git a/recipes/configs/qwen2_5/3B_lora.yaml b/recipes/configs/qwen2_5/3B_lora.yaml index 152c8da204..d31612851a 100644 --- a/recipes/configs/qwen2_5/3B_lora.yaml +++ b/recipes/configs/qwen2_5/3B_lora.yaml @@ -71,6 +71,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/3B_lora_single_device.yaml b/recipes/configs/qwen2_5/3B_lora_single_device.yaml index 98ed48f06f..56ab513be5 100644 --- a/recipes/configs/qwen2_5/3B_lora_single_device.yaml +++ b/recipes/configs/qwen2_5/3B_lora_single_device.yaml @@ -70,6 +70,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/72B_lora.yaml b/recipes/configs/qwen2_5/72B_lora.yaml index 6eabf7eca9..f6adf197bb 100644 --- a/recipes/configs/qwen2_5/72B_lora.yaml +++ b/recipes/configs/qwen2_5/72B_lora.yaml @@ -68,6 +68,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/7B_full.yaml b/recipes/configs/qwen2_5/7B_full.yaml index d071687103..dcbc217fc5 100644 --- a/recipes/configs/qwen2_5/7B_full.yaml +++ b/recipes/configs/qwen2_5/7B_full.yaml @@ -62,6 +62,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/configs/qwen2_5/7B_full_single_device.yaml b/recipes/configs/qwen2_5/7B_full_single_device.yaml index e6ebbcb8cf..d317e33478 100644 --- a/recipes/configs/qwen2_5/7B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/7B_full_single_device.yaml @@ -64,6 +64,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Training environment diff --git a/recipes/configs/qwen2_5/7B_lora.yaml b/recipes/configs/qwen2_5/7B_lora.yaml index f78c522e8a..716897b4f9 100644 --- a/recipes/configs/qwen2_5/7B_lora.yaml +++ b/recipes/configs/qwen2_5/7B_lora.yaml @@ -74,6 +74,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/configs/qwen2_5/7B_lora_single_device.yaml b/recipes/configs/qwen2_5/7B_lora_single_device.yaml index 3accf271d3..ed564076fe 100644 --- a/recipes/configs/qwen2_5/7B_lora_single_device.yaml +++ b/recipes/configs/qwen2_5/7B_lora_single_device.yaml @@ -73,6 +73,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size +clip_grad_norm: null compile: False # torch.compile the model + loss, True increases speed + decreases memory # Logging diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index 0253bf82e2..1e1491a05c 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -77,6 +77,7 @@ loss: _component_: torch.nn.CrossEntropyLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase virtual batch size +clip_grad_norm: null compile: False # pytorch compile, set to true for better perf/memory optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 5abc674356..7d8808d90d 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -653,7 +653,7 @@ def _setup_data( for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) - packed = False + packed = getattr(ds, "packed", False) else: ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) @@ -870,6 +870,7 @@ def train(self) -> None: and curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -951,7 +952,7 @@ def train(self) -> None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), max_norm=float(self._clip_grad_norm), - ) + ).full_tensor() self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) @@ -1019,6 +1020,7 @@ def train(self) -> None: == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history(enabled=None) diff --git a/recipes/dev/generate_v2.py b/recipes/dev/generate_v2.py index 3ce95a9fdf..66329e8367 100644 --- a/recipes/dev/generate_v2.py +++ b/recipes/dev/generate_v2.py @@ -39,18 +39,22 @@ def __call__(self, prompt: Dict[str, Any]) -> List[Message]: # Iterate through roles and add content for role, content in prompt.items(): - if isinstance(content, str): + if content is None: + continue + elif isinstance(content, str): new_content = [{"type": "text", "content": content}] - else: - assert ( - "image" in content.keys() - ), "Multiple entries per role expect an image key" + elif "image" in content.keys(): image_loc = content["image"] image = load_image(image_loc) new_content = [ {"type": "image", "content": image}, {"type": "text", "content": content["text"]}, ] + else: + assert ( + "text" in content.keys() + ), "Multiple entries per role expect at least a text key" + new_content = [{"type": "text", "content": content["text"]}] messages.append(Message(role=role, content=new_content)) # Finally, add an empty assistant message to kick-start generation @@ -109,11 +113,13 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None: f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec" ) self._logger.info( - f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s" - ) - self._logger.info( - f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB" + f"Bandwidth achieved: {model_size * tokens_per_second / (1024**3):.02f} GiB/s" ) + if self._device.type != "cpu": + torch_device = utils.get_torch_device_namespace() + self._logger.info( + f"Max memory allocated: {torch_device.max_memory_allocated() / (1024**3):.02f} GiB" + ) @torch.inference_mode() def generate(self, cfg: DictConfig): diff --git a/recipes/dev/generate_v2_distributed.py b/recipes/dev/generate_v2_distributed.py new file mode 100644 index 0000000000..48a147bd15 --- /dev/null +++ b/recipes/dev/generate_v2_distributed.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import itertools +import sys +import time +from typing import Any, Dict, List + +import torch +import torch.distributed as dist +from omegaconf import DictConfig, OmegaConf +from torch.distributed.tensor.parallel import parallelize_module + +from torchtune import config, training, utils +from torchtune.data import load_image, Message, padded_collate_tiled_images_and_mask + +from torchtune.generation import sample + +from torchtune.modules.transforms import Transform + + +class SingleTurnYAMLToMessages(Transform): + """ + Converts a single turn conversation in YAML format to a list of messages. + + Expects the YAML to look like: + system: You are a helpful AI assistant. + user: What is the capital of France? + + or if it includes an image: + system: You are a helpful AI assistant. + user: + image: url or path_to_image + text: Describe the image in detail. + """ + + def __call__(self, prompt: Dict[str, Any]) -> List[Message]: + messages = [] + + # Iterate through roles and add content + for role, content in prompt.items(): + if content is None: + continue + elif isinstance(content, str): + new_content = [{"type": "text", "content": content}] + elif "image" in content.keys(): + image_loc = content["image"] + image = load_image(image_loc) + new_content = [ + {"type": "image", "content": image}, + {"type": "text", "content": content["text"]}, + ] + else: + assert ( + "text" in content.keys() + ), "Multiple entries per role expect at least a text key" + new_content = [{"type": "text", "content": content["text"]}] + messages.append(Message(role=role, content=new_content)) + + # Finally, add an empty assistant message to kick-start generation + messages.append(Message(role="assistant", content="")) + return messages + + +class InferenceRecipe: + """ + Recipe for generating tokens from a dense Transformer-based LLM. + This works for text-only generation and image-text generation. + + Supports distributed inference using Tensor Paralellism(TP) for + large models that don't fit on a single GPU. For more information + on TP, see: https://pytorch.org/docs/stable/distributed.tensor.parallel.html. + + This *does not* currently support the following features: + - torch.compile + - quantization through torchao + - batch generation + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device) + self._logger = utils.get_logger(cfg.log_level) + # Set up distributed env + dist.init_process_group(backend="nccl") + _, rank = utils.get_world_size_and_rank() + self._is_rank_zero = rank == 0 + training.set_seed(seed=cfg.seed) + + def setup(self, cfg: DictConfig) -> None: + """Setup the model and transforms.""" + # Load checkpointer and state_dict + _checkpointer = config.instantiate(cfg.checkpointer) + _ckpt_dict = _checkpointer.load_checkpoint() + + # Instantiate model on meta device + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg.model) + + # Set up tensor parallel device mesh + tp_degree = dist.get_world_size() # Using all GPUs for TP + tp_mesh_shape = (tp_degree,) + tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape) + + # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor paralell + training.prepare_mha_for_tp(model, tp_device_mesh) + parallelize_module( + model, + tp_device_mesh, + parallelize_plan=config.instantiate(cfg.parallelize_plan), + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model=model, + full_sd=_ckpt_dict[training.MODEL_KEY], + device=self._device, + strict=True, + cpu_offload=False, + ) + + self.model = model + if self._is_rank_zero: + self._logger.info( + f"Model was initialized with precision {self._dtype} and TP degree {tp_degree}." + ) + + # Instantiate transforms + self.model_transform = config.instantiate(cfg.tokenizer) + self.to_messages = SingleTurnYAMLToMessages() + + def log_metrics(self, total_time: int, tokens_per_second: float) -> None: + """Logs the following metrics: total time for inference, tokens/sec, + bandwidth achieved, and max memory allocated. + + Feel free to modify this function to log additional metrics. + """ + model_size = sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain(self.model.parameters(), self.model.buffers()) + ] + ) + self._logger.info( + f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec" + ) + self._logger.info( + f"Bandwidth achieved: {model_size * tokens_per_second / (1024**3):.02f} GiB/s" + ) + if self._device.type != "cpu": + torch_device = utils.get_torch_device_namespace() + self._logger.info( + f"Max memory allocated: {torch_device.max_memory_allocated() / (1024**3):.02f} GiB" + ) + + @torch.inference_mode() + def generate(self, cfg: DictConfig): + """The main entry point for generating tokens from a prompt.""" + # 1. Convert input to messages + messages = self.to_messages(OmegaConf.to_container(cfg.prompt)) + is_multimodal_input = any([m.contains_media for m in messages]) + + # 2. Apply model transform + model_inputs = self.model_transform({"messages": messages}, inference=True) + seq_len = len(model_inputs["tokens"]) + total_response_length = seq_len + cfg.max_new_tokens + + # 3. Setup KV cache + with self._device: + self.model.setup_caches( + batch_size=1, + dtype=self._dtype, + encoder_max_seq_len=( + self.model_transform.image_seq_len if is_multimodal_input else None + ), + decoder_max_seq_len=total_response_length, + ) + + # 4. Pre-allocate causal mask and input_pos + causal_mask = torch.tril( + torch.ones( + size=(total_response_length, total_response_length), + dtype=torch.bool, + device=self._device, + ) + ) + input_pos = torch.arange(total_response_length) + + # 5. Collate to batch size of 1 and tensor-ify + batch = {} + if is_multimodal_input: + batch = padded_collate_tiled_images_and_mask( + [model_inputs], + pad_direction="left", + pad_max_images=1, + pad_max_tiles=self.model_transform.max_num_tiles, + ) + batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] + prompt = batch.pop("tokens").to(self._device) + else: + prompt = torch.tensor( + model_inputs["tokens"], device=self._device + ).unsqueeze(0) + batch["mask"] = causal_mask[None, :seq_len] + batch["input_pos"] = input_pos[None, :seq_len] + utils.batch_to_device(batch, self._device) + + # 6. Prefill step + generated_tokens = [] + t0 = time.perf_counter() + logits = self.model(prompt, **batch)[:, -1] + token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k) + generated_tokens.append(token.item()) + + if is_multimodal_input: + # Don't need image info b/c we only support 1 image and it's been + # processed by the model now + batch.pop("encoder_input") + batch["encoder_mask"] = batch["encoder_mask"][:, -1:] + + # 7. Continue generating + for i in range(cfg.max_new_tokens): + + # Update position and mask for incremental decoding + batch["input_pos"] = input_pos[None, seq_len] + batch["mask"] = causal_mask[None, seq_len, None, :] + + if token.item() in self.model_transform.stop_tokens: + break + + logits = self.model(token, **batch)[:, -1] + token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k) + generated_tokens.append(token.item()) + seq_len += 1 + + t = time.perf_counter() - t0 + + # 8. Translate tokens back to text + decoded = self.model_transform.decode(generated_tokens) + if self._is_rank_zero: + self._logger.info(f"\n\n{decoded}\n") + + # 9. Log metrics + tokens_per_second = len(generated_tokens) / t + if self._is_rank_zero: + self.log_metrics(total_time=t, tokens_per_second=tokens_per_second) + + +@config.parse +def main(cfg: DictConfig) -> None: + config.log_config(recipe_name="InferenceRecipe", cfg=cfg) + recipe = InferenceRecipe(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.generate(cfg=cfg) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py index 5693e899ad..fd1ac8f6e1 100644 --- a/recipes/eleuther_eval.py +++ b/recipes/eleuther_eval.py @@ -31,8 +31,8 @@ from torchtune.modules import TransformerDecoder from torchtune.modules.common_utils import local_kv_cache from torchtune.modules.model_fusion import DeepFusionModel -from torchtune.modules.tokenizers import ModelTokenizer from torchtune.modules.transforms import Transform +from torchtune.modules.transforms.tokenizers import ModelTokenizer from torchtune.recipe_interfaces import EvalRecipeInterface from torchtune.training import FullModelTorchTuneCheckpointer @@ -547,9 +547,11 @@ def evaluate(self) -> None: # Log metrics self.logger.info(f"Eval completed in {t1:.02f} seconds.") - self.logger.info( - f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB" - ) + if self.device.type != "cpu": + torch_device = utils.get_torch_device_namespace() + self.logger.info( + f"Max memory allocated: {torch_device.max_memory_allocated() / 1e9:.02f} GB" + ) formatted_output = make_table(output) self.logger.info(f"\n\n{formatted_output}\n") diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 745ef64eb4..34ad48e938 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -646,7 +646,7 @@ def _setup_data( for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) - packed = False + packed = getattr(ds, "packed", False) else: ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) @@ -723,6 +723,7 @@ def train(self) -> None: and curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -766,7 +767,9 @@ def train(self) -> None: if self._optimizer_in_bwd: torch.distributed.all_reduce(num_tokens) torch.distributed.all_reduce(running_loss) - current_loss = current_loss / num_tokens + + # We multiply by world_size to undo FSDP2 gradient normalization. + current_loss = current_loss * (world_size / num_tokens) current_loss.backward() @@ -778,12 +781,13 @@ def train(self) -> None: # This will ensure that the logged loss matches what we're optimizing torch.distributed.all_reduce(running_loss) # Manually scale the gradients from unnormalized loss by total # of tokens - training.scale_grads(self._model, 1 / num_tokens) + # We multiply by world_size to undo FSDP2 gradient normalization. + training.scale_grads(self._model, world_size / num_tokens) if self._clip_grad_norm is not None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), max_norm=float(self._clip_grad_norm), - ) + ).full_tensor() self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) @@ -843,6 +847,7 @@ def train(self) -> None: == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history(enabled=None) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 946e970206..0c53666dad 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -131,9 +131,9 @@ def __init__(self, cfg: DictConfig) -> None: self._log_every_n_steps = cfg.get("log_every_n_steps", 1) self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) - if self._log_peak_memory_stats and self._device.type != "cuda": + if self._log_peak_memory_stats and self._device.type == "cpu": log.info( - "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + "log_peak_memory_stats was set to True, however, training uses cpu. Setting log_peak_memory_stats=False." ) self._log_peak_memory_stats = False @@ -558,7 +558,7 @@ def _setup_data( for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) - packed = False + packed = getattr(ds, "packed", False) else: ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) @@ -685,9 +685,9 @@ def train(self) -> None: curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() - utils.batch_to_device(batch, self._device) # Calculate the number of unmasked tokens in the current batch @@ -766,6 +766,7 @@ def train(self) -> None: == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history(enabled=None) diff --git a/recipes/generate.py b/recipes/generate.py index 56723b04bd..b577e0b66d 100644 --- a/recipes/generate.py +++ b/recipes/generate.py @@ -187,7 +187,11 @@ def generate(self, cfg: DictConfig): f"Time for inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" ) logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") - logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + if self._device.type != "cpu": + torch_device = utils.get_torch_device_namespace() + logger.info( + f"Memory used: {torch_device.max_memory_allocated() / 1e9:.02f} GB" + ) @config.parse diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py index a11adabb97..4e5165eb3b 100644 --- a/recipes/knowledge_distillation_distributed.py +++ b/recipes/knowledge_distillation_distributed.py @@ -652,7 +652,7 @@ def _setup_data( for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) - packed = False + packed = getattr(ds, "packed", False) else: ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) @@ -769,7 +769,6 @@ def save_checkpoint(self, epoch: int) -> None: def _loss_step( self, batch: Dict[str, torch.Tensor] ) -> (torch.Tensor, torch.Tensor): - # Both are shape [b, s] tokens, labels = batch["tokens"], batch["labels"] @@ -847,6 +846,7 @@ def train(self) -> None: and curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -875,7 +875,8 @@ def train(self) -> None: torch.distributed.all_reduce(running_class_loss) torch.distributed.all_reduce(running_kd_loss) # Manually scale the gradients from unnormalized loss by total # of tokens - training.scale_grads(self._model, 1 / num_tokens) + # We multiply by world_size to undo FSDP2 gradient normalization. + training.scale_grads(self._model, world_size / num_tokens) class_loss_to_log = running_class_loss.item() / num_tokens kd_loss_to_log = running_kd_loss.item() / num_tokens self._optimizer.step() diff --git a/recipes/knowledge_distillation_single_device.py b/recipes/knowledge_distillation_single_device.py index 71d850d791..1571ef1f44 100644 --- a/recipes/knowledge_distillation_single_device.py +++ b/recipes/knowledge_distillation_single_device.py @@ -120,9 +120,9 @@ def __init__(self, cfg: DictConfig) -> None: self._log_every_n_steps = cfg.get("log_every_n_steps", 1) self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) - if self._log_peak_memory_stats and self._device.type != "cuda": + if self._log_peak_memory_stats and self._device.type == "cpu": log.info( - "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + "log_peak_memory_stats was set to True, however, training uses cpu. Setting log_peak_memory_stats=False." ) self._log_peak_memory_stats = False @@ -223,6 +223,10 @@ def setup(self, cfg: DictConfig) -> None: self._metric_logger.log_config(cfg) self._compile = cfg.compile + if cfg.device == "npu" and cfg.compile: + raise ValueError( + "NPU does not support model compilation. Please set `compile: False` in the config." + ) checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) teacher_checkpoint_dict = self.load_teacher_checkpoint( cfg_checkpointer=cfg.teacher_checkpointer @@ -447,7 +451,7 @@ def _setup_model( log.info(f"Student model is initialized with precision {self._dtype}.") - if self._device.type == "cuda": + if self._device.type != "cpu": log.info("Memory stats initializing student model:") memory_stats = training.get_memory_stats(device=self._device) training.log_memory_stats( @@ -476,7 +480,7 @@ def _setup_teacher_model( ) log.info(f"Teacher model is initialized with precision {self._dtype}.") - if self._device.type == "cuda": + if self._device.type != "cpu": memory_stats = training.get_memory_stats(device=self._device) training.log_memory_stats( memory_stats, message="Memory stats after teacher model init:" @@ -527,7 +531,7 @@ def _setup_data( for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) - packed = False + packed = getattr(ds, "packed", False) else: ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) @@ -698,6 +702,7 @@ def train(self) -> None: curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -753,7 +758,7 @@ def train(self) -> None: "tokens_per_second_per_gpu": num_tokens / time_per_step, } if ( - self._device.type == "cuda" + self._device.type != "cpu" and self._log_peak_memory_stats ): log_dict.update( @@ -780,6 +785,7 @@ def train(self) -> None: == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history(enabled=None) diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index 9b5dc6fb1a..c493b65602 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -98,9 +98,9 @@ def __init__(self, cfg: DictConfig) -> None: self._log_every_n_steps = cfg.get("log_every_n_steps", 1) self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) - if self._log_peak_memory_stats and self._device.type != "cuda": + if self._log_peak_memory_stats and self._device.type == "cpu": log.info( - "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + "log_peak_memory_stats was set to True, however, training uses cpu. Setting log_peak_memory_stats=False." ) self._log_peak_memory_stats = False @@ -327,7 +327,7 @@ def _setup_model( # Compile model, if enabled. if compile_model: training.compile_model(model) - if self._device == torch.device("cuda"): + if self._device.type != "cpu": memory_stats = training.get_memory_stats(device=self._device) training.log_memory_stats(memory_stats) return model diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index ac05e2060a..d5304e496e 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -591,7 +591,7 @@ def _setup_data( for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) - packed = False + packed = getattr(ds, "packed", False) else: ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) @@ -776,6 +776,7 @@ def train(self) -> None: and curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -822,12 +823,13 @@ def train(self) -> None: # This will ensure that the logged loss matches what we're optimizing torch.distributed.all_reduce(running_loss) # Manually scale the gradients from unnormalized loss by total # of tokens - training.scale_grads(self._model, 1 / num_tokens) + # We multiply by world_size to undo FSDP2 gradient normalization. + training.scale_grads(self._model, world_size / num_tokens) if self._clip_grad_norm is not None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), max_norm=float(self._clip_grad_norm), - ) + ).full_tensor() self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) self._lr_scheduler.step() @@ -879,6 +881,7 @@ def train(self) -> None: == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history(enabled=None) diff --git a/recipes/lora_finetune_distributed_multi_dataset.py b/recipes/lora_finetune_distributed_multi_dataset.py index 30ece70347..7d0d442c6c 100644 --- a/recipes/lora_finetune_distributed_multi_dataset.py +++ b/recipes/lora_finetune_distributed_multi_dataset.py @@ -805,6 +805,7 @@ def train(self) -> None: and curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -851,12 +852,13 @@ def train(self) -> None: # This will ensure that the logged loss matches what we're optimizing torch.distributed.all_reduce(running_loss) # Manually scale the gradients from unnormalized loss by total # of tokens - training.scale_grads(self._model, 1 / num_tokens) + # We multiply by world_size to undo FSDP2 gradient normalization. + training.scale_grads(self._model, world_size / num_tokens) if self._clip_grad_norm is not None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), max_norm=float(self._clip_grad_norm), - ) + ).full_tensor() self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) self._lr_scheduler.step() @@ -908,6 +910,7 @@ def train(self) -> None: == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history(enabled=None) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 9a3f3eacfb..5cf0a0f969 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -136,9 +136,9 @@ def __init__(self, cfg: DictConfig) -> None: self._log_every_n_steps = cfg.get("log_every_n_steps", 1) self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) - if self._log_peak_memory_stats and self._device.type != "cuda": + if self._log_peak_memory_stats and self._device.type == "cpu": log.info( - "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + "log_peak_memory_stats was set to True, however, training uses cpu. Setting log_peak_memory_stats=False." ) self._log_peak_memory_stats = False @@ -528,7 +528,7 @@ def _setup_data( for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) - packed = False + packed = getattr(ds, "packed", False) else: ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) @@ -688,6 +688,7 @@ def train(self) -> None: curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -735,7 +736,7 @@ def train(self) -> None: "tokens_per_second_per_gpu": num_tokens / time_per_step, } if ( - self._device.type == "cuda" + self._device.type != "cpu" and self._log_peak_memory_stats ): log_dict.update( @@ -761,6 +762,7 @@ def train(self) -> None: == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history(enabled=None) diff --git a/recipes/ppo_full_finetune_single_device.py b/recipes/ppo_full_finetune_single_device.py index cb6357c3dc..c89521ccfc 100644 --- a/recipes/ppo_full_finetune_single_device.py +++ b/recipes/ppo_full_finetune_single_device.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib import math -import os import sys +import time from functools import partial from itertools import chain -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from warnings import warn import torch @@ -20,11 +21,16 @@ from torchtune import config, generation, modules, rlhf, training, utils from torchtune.data import padded_collate from torchtune.datasets import ConcatDataset +from torchtune.modules import local_kv_cache from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.rlhf import PPOStats, Trajectory +from torchtune.training import DummyProfiler, PROFILER_KEY from tqdm import tqdm log = utils.get_logger("DEBUG") +# enabling compile results in slightly more recompiles than the default cache limit (8) +# so we set a higher limit here +torch._dynamo.config.cache_size_limit = 16 class PPOFullFinetuneRecipeSingleDevice(FTRecipeInterface): @@ -32,8 +38,8 @@ class PPOFullFinetuneRecipeSingleDevice(FTRecipeInterface): Full finetuning recipe for RLHF with PPO for dense transformer-based LLMs such as LLama2. This recipe is optimized for single GPU training. Training on CPU is not supported. - This implementation is based on `Learning to summarize from human feedback ). Features: - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` @@ -175,8 +181,9 @@ def setup(self, cfg: DictConfig) -> None: # ``_setup_model`` handles initialization and loading the state dict. This method # should be called before ``_setup_optimizer`` since transforming the optimizer # state dict requires the model - self._model_compile = cfg.compile + self.compile = cfg.compile self._optimizer_in_bwd = cfg.optimizer_in_bwd + ( self._policy_model, self._value_model, @@ -186,7 +193,7 @@ def setup(self, cfg: DictConfig) -> None: cfg_model=cfg.policy_model, cfg_reward_value_model=cfg.reward_and_value_model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, - compile_model=self._model_compile, + compile_model=self.compile, policy_state_dict=policy_model_checkpoint_dict[training.MODEL_KEY], ref_policy_state_dict=ref_policy_state_dict[training.MODEL_KEY], value_model_state_dict=value_model_checkpoint_dict[training.MODEL_KEY], @@ -213,7 +220,7 @@ def setup(self, cfg: DictConfig) -> None: log.info("Loss is initialized.") # sampler and dataloader depends on the tokenizer and should be set - # setup afterit is initialized + # setup after it is initialized self._sampler, self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, shuffle=cfg.shuffle, @@ -223,6 +230,21 @@ def setup(self, cfg: DictConfig) -> None: self._setup_training_parameters(cfg) self._setup_training_hyperparameters(cfg) + # setup a context manager for enabling KV-cacheing during + # trajectory generation if enabled in the config + self.cache_ctx_manager = lambda enable_kv_cache: ( + local_kv_cache( + self._policy_model, + batch_size=self._forward_batch_size, + dtype=self._dtype, + decoder_max_seq_len=self._tokenizer.max_seq_len + + self._max_generated_tokens, + device=self._device, + ) + if enable_kv_cache + else contextlib.nullcontext() + ) + if self._resume_from_checkpoint: self._update_recipe_state(policy_model_checkpoint_dict) @@ -233,6 +255,77 @@ def setup(self, cfg: DictConfig) -> None: * (self.batch_size // self._ppo_batch_size) ) + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + def _setup_training_hyperparameters(self, cfg) -> None: """ Sets up the training hyperparameters for the recipe. This includes the GAE hyperparameters, @@ -295,6 +388,7 @@ def _setup_training_parameters(self, cfg: DictConfig) -> None: self._ppo_backward_batch_size = ( cfg.ppo_batch_size // self._gradient_accumulation_steps ) + self.enable_kv_cache = cfg.enable_kv_cache if self.batch_size % self._forward_batch_size != 0: raise ValueError( @@ -423,6 +517,12 @@ def _setup_models( reward_model = config.instantiate(cfg_reward_value_model) value_model = config.instantiate(cfg_reward_value_model) + if compile_model: + training.compile_model(policy_model) + training.compile_model(ref_policy_model) + training.compile_model(value_model) + training.compile_model(reward_model) + if enable_activation_checkpointing: training.set_activation_checkpointing( policy_model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} @@ -450,6 +550,7 @@ def _setup_models( value_model.load_state_dict(value_model_state_dict) # Validate models were loaded in with the expected dtype. + training.validate_expected_param_dtype( value_model.named_parameters(), dtype=self._dtype ) @@ -490,16 +591,6 @@ def _setup_models( for p in ref_policy_model.parameters(): p.requires_grad = False - # Compile model, if enabled. - if compile_model: - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") - log.info("Compiling models with torch.compile...") - - policy_model.compile(backend=backend) - reward_model.compile(backend=backend) - ref_policy_model.compile(backend=backend) - value_model.compile(backend=backend) - if self._device.type == "cuda": memory_stats = training.get_memory_stats(device=self._device) training.log_memory_stats(memory_stats) @@ -585,7 +676,6 @@ def _setup_data( dataset=ds, sampler=sampler, batch_size=batch_size, - # dropping last avoids shape issues with compile + flex attention drop_last=True, collate_fn=partial( padded_collate, @@ -688,19 +778,19 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: Trajectory: An instance of :class:`~torchtune.rlhf.Trajectory` comprising the current trajectory. """ - batch_size, context_length = input_ids.shape # step 1: generate responses, and logits corresponding to the responses using the current policy - query_responses, logits = generation.generate( - model=self._policy_model, - prompt=input_ids, - max_generated_tokens=self._max_generated_tokens, - temperature=self._temperature, - top_k=self._top_k, - pad_id=self._tokenizer.pad_id, - rng=self._rng, - ) - + with self.cache_ctx_manager(self.enable_kv_cache): + query_responses, logits = generation.generate( + model=self._policy_model, + prompt=input_ids, + max_generated_tokens=self._max_generated_tokens, + temperature=self._temperature, + top_k=self._top_k, + pad_id=self._tokenizer.pad_id, + rng=self._rng, + ) + _, context_length = input_ids.shape responses = query_responses[:, context_length:].clone() query_response_padding_masks = query_responses != self._tokenizer.pad_id @@ -715,7 +805,6 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: del query_response_padding_masks # step 2. estimate logprobs of the responses using the current policy - logits = logits[:, context_length - 1 :] logprobs = rlhf.logits_to_logprobs(logits, responses, self._temperature) del logits @@ -751,7 +840,9 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: # step 5.1 the scores from the reward model are the logits for the last non-padding token in # each (query, truncated-response) pair seq_lens = training.get_unmasked_sequence_lengths(response_padding_masks) - scores = scores[torch.arange(batch_size), seq_lens + context_length].squeeze(-1) + scores = scores.gather(1, (seq_lens + context_length)[:, None, None]).squeeze( + (-1, -2) + ) # step 5.2 if configured, apply any penalties for sequences without EOS tokens # or shorter than a certain length @@ -775,11 +866,9 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: seq_lens, ) value_padding_masks = response_padding_masks.clone() - value_padding_masks[ - torch.arange(batch_size, device=value_padding_masks.device), - value_seq_idxs, - ] = False - + value_padding_masks = value_padding_masks.scatter_( + 1, value_seq_idxs.unsqueeze(-1), False + ) values[value_padding_masks] = 0.0 return Trajectory( @@ -798,8 +887,8 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: def generate_trajectory_batched(self, input_ids: torch.Tensor) -> Trajectory: """ - Generates a ``self.batch_size`` batch of trajectories using `self._forward_batch_size` batch sizes. - See ``generate_trajectory`` for more details. + Generates a self.batch_size batch of trajectories using self._forward_batch_size batch sizes. + See generate_trajectory for more details. Args: input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] @@ -814,6 +903,7 @@ def generate_trajectory_batched(self, input_ids: torch.Tensor) -> Trajectory: batch_input_ids = input_ids[ batch_start : batch_start + self._forward_batch_size ] + trajectories.append(self.generate_trajectory(batch_input_ids)) return Trajectory(*map(torch.cat, zip(*trajectories))) @@ -821,7 +911,7 @@ def train(self) -> None: """ The core training loop.""" - if self._model_compile: + if self.compile: log.info( "NOTE: torch.compile is enabled and model is compiled in first forward." "Expect a relatively slow first iteration." @@ -831,25 +921,34 @@ def train(self) -> None: self._optimizer.zero_grad() training_completed = False + self._profiler.start() pbar = tqdm(total=self._total_steps, initial=self._steps_run) for curr_epoch in range(self._epochs_run, self._total_epochs): # Update the sampler to ensure data is correctly shuffled across epochs # in case shuffle is True self._sampler.set_epoch(curr_epoch) - for _, batch in enumerate(self._dataloader): + for idx, batch in enumerate(self._dataloader): + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" + ): + torch.cuda.memory._record_memory_history() + batch = batch["tokens"].to(self._device) _, context_length = batch.shape + num_tokens = batch.numel() - # step 1. generate the trajectory using: - # - the current policy (pi_theta) - # - the current value function (V_phi) - # - the reference frozen policy model (pi_theta_0) + # step 1. generate the trajectory + t0_traj = time.perf_counter() trajectory = self.generate_trajectory_batched(batch) + traj_time = time.perf_counter() - t0_traj - # step 2. get the rewards for the current trajectory. these are based on: - # - the divergence between the current policy and the reference policy - # - the scores from the reward model + # step 2. get the rewards for the current trajectory rewards, kl, kl_rewards = rlhf.get_rewards_ppo( trajectory.scores, trajectory.logprobs, @@ -867,7 +966,8 @@ def train(self) -> None: masks=~trajectory.response_padding_masks, ) - # step 4. optimise using the PPO objective over multiple epochs + # # step 4. optimise using the PPO objective over multiple epochs + t0_ppo = time.perf_counter() ppo_stats: List[PPOStats] = [] for _ in range(self._ppo_epochs): batch_idxs = torch.randperm(self.batch_size, device=self._device) @@ -893,7 +993,7 @@ def train(self) -> None: ) ) batch_ppo_stats.append( - self._ppo_step( + self.ppo_step( batch_trajectory, advantages[backward_batch_idxs], returns[backward_batch_idxs], @@ -909,6 +1009,7 @@ def train(self) -> None: self._optimizer.zero_grad(set_to_none=True) self.global_step += 1 + ppo_time = time.perf_counter() - t0_ppo # step 5. profit self._steps_run += 1 @@ -918,11 +1019,29 @@ def train(self) -> None: PPOStats(*map(torch.stack, zip(*ppo_stats))), kl, kl_rewards, + num_tokens / traj_time, + num_tokens / ppo_time, ) self.cleanup_after_step( trajectory, ppo_stats, advantages, returns, kl, kl_rewards ) pbar.update(1) + + # Stop tracking CUDA memory now that active steps are complete + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + and self._device.type == "cuda" + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step the profiler + self._profiler.step() + if self._steps_run == self._total_steps: training_completed = True break @@ -934,9 +1053,12 @@ def train(self) -> None: curr_epoch, is_intermediate_checkpoint=not training_completed ) if training_completed: + self._profiler.stop() return - def _ppo_step( + self._profiler.stop() + + def ppo_step( self, trajectory: Trajectory, advantages: torch.Tensor, @@ -1023,6 +1145,8 @@ def log_metrics( ppo_stats: PPOStats, kl: torch.Tensor, kl_rewards: torch.Tensor, + tokens_per_second_trajectory: torch.Tensor, + tokens_per_second_loss: torch.Tensor, ) -> None: """ Log metrics and statistics for the current step to the metric logger. @@ -1040,6 +1164,8 @@ def log_metrics( "ratios": ppo_stats.ratios.mean(), "approx_policy_kl": ppo_stats.approx_policy_kls.mean(), "response_lengths": trajectory.seq_lens.float().mean(), + "tokens_per_second_per_gpu_trajectory": tokens_per_second_trajectory, + "tokens_per_second_per_gpu_ppo": tokens_per_second_loss, } if self._device.type == "cuda" and self._log_peak_memory_stats: log_dict.update(training.get_memory_stats(device=self._device)) diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index eaa2974579..f9ba25ca34 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -606,7 +606,7 @@ def _setup_data( for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) - packed = False + packed = getattr(ds, "packed", False) else: ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) @@ -773,6 +773,7 @@ def train(self) -> None: and curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -837,7 +838,9 @@ def train(self) -> None: if self._optimizer_in_bwd: torch.distributed.all_reduce(num_tokens) torch.distributed.all_reduce(running_loss) - current_loss = current_loss / num_tokens + + # We multiply by world_size to undo FSDP2 gradient normalization. + current_loss = current_loss * (world_size / num_tokens) current_loss.backward() @@ -849,12 +852,13 @@ def train(self) -> None: # This will ensure that the logged loss matches what we're optimizing torch.distributed.all_reduce(running_loss) # Manually scale the gradients from unnormalized loss by total # of tokens - training.scale_grads(self._model, 1 / num_tokens) + # We multiply by world_size to undo FSDP2 gradient normalization. + training.scale_grads(self._model, world_size / num_tokens) if self._clip_grad_norm is not None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), max_norm=float(self._clip_grad_norm), - ) + ).full_tensor() self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) @@ -910,6 +914,7 @@ def train(self) -> None: == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history(enabled=None) diff --git a/recipes/qat_lora_finetune_distributed.py b/recipes/qat_lora_finetune_distributed.py index b9080de77d..074113b216 100644 --- a/recipes/qat_lora_finetune_distributed.py +++ b/recipes/qat_lora_finetune_distributed.py @@ -633,7 +633,7 @@ def _setup_data( for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) - packed = False + packed = getattr(ds, "packed", False) else: ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) @@ -820,6 +820,7 @@ def train(self) -> None: and curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -866,12 +867,13 @@ def train(self) -> None: # This will ensure that the logged loss matches what we're optimizing torch.distributed.all_reduce(running_loss) # Manually scale the gradients from unnormalized loss by total # of tokens - training.scale_grads(self._model, 1 / num_tokens) + # We multiply by world_size to undo FSDP2 gradient normalization. + training.scale_grads(self._model, world_size / num_tokens) if self._clip_grad_norm is not None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), max_norm=float(self._clip_grad_norm), - ) + ).full_tensor() self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) self._lr_scheduler.step() @@ -923,6 +925,7 @@ def train(self) -> None: == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps + and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history(enabled=None) diff --git a/recipes/quantize.py b/recipes/quantize.py index bb28d45b87..f53abf182f 100644 --- a/recipes/quantize.py +++ b/recipes/quantize.py @@ -92,7 +92,11 @@ def quantize(self, cfg: DictConfig): self._model = self._quantizer.quantize(self._model) t = time.perf_counter() - t0 logger.info(f"Time for quantization: {t:.02f} sec") - logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + if self._device.type != "cpu": + torch_device = utils.get_torch_device_namespace() + logger.info( + f"Memory used: {torch_device.max_memory_allocated() / 1e9:.02f} GB" + ) def save_checkpoint(self, cfg: DictConfig): ckpt_dict = self._model.state_dict() diff --git a/tests/recipes/dev/test_generate_v2.py b/tests/recipes/dev/test_generate_v2.py index be3f995f58..2159e46b70 100644 --- a/tests/recipes/dev/test_generate_v2.py +++ b/tests/recipes/dev/test_generate_v2.py @@ -55,9 +55,7 @@ def test_llama2_generate_results(self, caplog, monkeypatch, tmpdir): # this is gibberish b/c the model is random weights, but it's # the expected value for what we currently have in V2 # this test should catch any changes to the generate recipe that affect output - expected_output = ( - "Country maior Connection Kohćutsójcustomulas Sometimes Security" - ) + expected_output = "Pietroместkap щotimes rivers cache НиtringindexPathNAME" logs = caplog.text assert expected_output in logs diff --git a/tests/test_utils.py b/tests/test_utils.py index 6497539869..ca28029710 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,8 +20,8 @@ import torch from torch import nn from torchtune.data import Message, PromptTemplate, truncate -from torchtune.modules.tokenizers import ModelTokenizer from torchtune.modules.transforms import Transform +from torchtune.modules.transforms.tokenizers import ModelTokenizer skip_if_cuda_not_available = unittest.skipIf( not torch.cuda.is_available(), "CUDA is not available" diff --git a/tests/torchtune/config/test_config_utils.py b/tests/torchtune/config/test_config_utils.py index bfce087dbf..b3a2baf063 100644 --- a/tests/torchtune/config/test_config_utils.py +++ b/tests/torchtune/config/test_config_utils.py @@ -28,6 +28,8 @@ }, "d": 4, "f": 8, + "g": "foo", + "h": "${g}/bar", } @@ -50,7 +52,9 @@ def test_get_component_from_path(self): ): _ = _get_component_from_path("torchtune.models.dummy") - @mock.patch("torchtune.config._parse.OmegaConf.load", return_value=_CONFIG) + @mock.patch( + "torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG) + ) def test_merge_yaml_and_cli_args(self, mock_load): parser = TuneRecipeArgumentParser("test parser") yaml_args, cli_args = parser.parse_known_args( @@ -63,6 +67,7 @@ def test_merge_yaml_and_cli_args(self, mock_load): "d=6", # Test overriding a flat param "e=7", # Test adding a new param "~f", # Test removing a param + "g=bazz", # Test interpolation happens after override ] ) conf = _merge_yaml_and_cli_args(yaml_args, cli_args) @@ -75,6 +80,7 @@ def test_merge_yaml_and_cli_args(self, mock_load): assert conf.d == 6, f"d == {conf.d}, not 6 as set in overrides." assert conf.e == 7, f"e == {conf.e}, not 7 as set in overrides." assert "f" not in conf, f"f == {conf.f}, not removed as set in overrides." + assert conf.h == "bazz/bar", f"h == {conf.h}, not bazz/bar as set in overrides." mock_load.assert_called_once() yaml_args, cli_args = parser.parse_known_args( @@ -185,5 +191,5 @@ def test_remove_key_by_dotpath(self): # Test removing non-existent param fails cfg = copy.deepcopy(_CONFIG) - with pytest.raises(KeyError, match="'g'"): - _remove_key_by_dotpath(cfg, "g") + with pytest.raises(KeyError, match="'i'"): + _remove_key_by_dotpath(cfg, "i") diff --git a/tests/torchtune/config/test_parse.py b/tests/torchtune/config/test_parse.py index c4e278acaf..e396b10864 100644 --- a/tests/torchtune/config/test_parse.py +++ b/tests/torchtune/config/test_parse.py @@ -13,7 +13,7 @@ from torchtune.config._parse import TuneRecipeArgumentParser -_CONFIG = {"a": 1, "b": 2} +_CONFIG = {"a": 1, "b": 2, "c": "foo", "d": "${c}/bar"} class TestParse: @@ -41,7 +41,9 @@ def parser(self): parser = TuneRecipeArgumentParser("Test parser") return parser - @patch("torchtune.config._parse.OmegaConf.load", return_value=_CONFIG) + @patch( + "torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG) + ) def test_parse_known_args(self, mock_load, parser): """ Test that the parser can load a config and override parameters provided on CLI. @@ -65,3 +67,11 @@ def test_parse_known_args(self, mock_load, parser): _ = parser.parse_known_args( ["--config", "test.yaml", "--b", "3"], ) + + # Test that parsing does not prematurely interpolate variables. + config_args, cli_args = parser.parse_known_args( + ["--config", "test.yaml", "c=bazz"] + ) + assert ( + config_args.d == "${c}/bar" + ), f"d == {config_args.d} not ${{c}}/bar as set in config." diff --git a/tests/torchtune/datasets/test_concat_dataset.py b/tests/torchtune/datasets/test_concat_dataset.py index 352d20e372..ffe1d4acbd 100644 --- a/tests/torchtune/datasets/test_concat_dataset.py +++ b/tests/torchtune/datasets/test_concat_dataset.py @@ -80,7 +80,7 @@ def test_invalid_index_type(self, datasets): with pytest.raises(TypeError): multi_dataset["invalid_type"] # Non-integer index - def test_packed_dataset(self, torch_datasets): + def test_single_packed_dataset(self, torch_datasets): torch_datasets[0] = PackedDataset( torch_datasets[0], max_seq_len=25, @@ -90,3 +90,33 @@ def test_packed_dataset(self, torch_datasets): with pytest.raises(ValueError): concated_dataset = ConcatDataset(torch_datasets) + + def test_all_packed_datasets(self, torch_datasets): + for i in range(len(torch_datasets)): + torch_datasets[i] = PackedDataset( + torch_datasets[i], + max_seq_len=2000, + max_packs=16, + split_across_pack=True, + ) + concated_dataset = ConcatDataset(torch_datasets) + assert concated_dataset.packed + + # 2k tokens per pack + # 1st ds has 4k tokens, 2nd ds has 8k tokens, 3rd ds has 15k tokens + # 4th ds has 16k tokens, 5th ds has 23k tokens, 6th ds has 42k tokens + + assert concated_dataset[0]["seq_lens"][0] == 4 + # 2nd packed ds starts at idx 2 + assert concated_dataset[2]["seq_lens"][0] == 8 + # 3rd packed ds starts at idx 6 + assert concated_dataset[6]["seq_lens"][0] == 15 + # 4th packed ds starts at idx 14 + assert concated_dataset[14]["seq_lens"][0] == 16 + # 5th packed ds starts at idx 22 + assert concated_dataset[22]["seq_lens"][0] == 23 + # 6th packed ds starts at idx 34 + assert concated_dataset[34]["seq_lens"][0] == 42 + + # Total length is 2 + 4 + 8 + 8 + 12 + 16 (because of max_packs) = 50 + assert len(concated_dataset) == 50 diff --git a/tests/torchtune/generation/test_generation.py b/tests/torchtune/generation/test_generation.py index 4efd1e3acd..b740e7afbf 100644 --- a/tests/torchtune/generation/test_generation.py +++ b/tests/torchtune/generation/test_generation.py @@ -245,7 +245,7 @@ def test_reproducibility(self, request, model1, model2, prompt_tokens): top_k = 100 torch.manual_seed(42) - outputs_first, _ = generate( + outputs_first, logits_first = generate( model=model1, prompt=prompt_tokens, max_generated_tokens=10, @@ -254,17 +254,15 @@ def test_reproducibility(self, request, model1, model2, prompt_tokens): ) torch.manual_seed(42) - outputs_second, _ = generate( + outputs_second, logits_second = generate( model=model2, prompt=prompt_tokens, max_generated_tokens=10, temperature=temperature, top_k=top_k, ) - - # slicing for the last 18 tokens - this is the whole sequence for unpadded inputs - # and excludes the first two tokens for padded inputs, which are padding tokens assert torch.equal(outputs_first, outputs_second) + torch.testing.assert_close(logits_first, logits_second) @pytest.mark.parametrize( "model1", @@ -303,7 +301,7 @@ def test_reproducibility_batched(self, request, model1, model2, prompt1, prompt2 top_k = 100 torch.manual_seed(42) - outputs_first, _ = generate( + outputs_first, logits_first = generate( model=model1, prompt=prompt1, max_generated_tokens=10, @@ -312,7 +310,7 @@ def test_reproducibility_batched(self, request, model1, model2, prompt1, prompt2 ) torch.manual_seed(42) - outputs_second, _ = generate( + outputs_second, logits_second = generate( model=model2, prompt=prompt2, max_generated_tokens=10, @@ -323,6 +321,8 @@ def test_reproducibility_batched(self, request, model1, model2, prompt1, prompt2 # slicing for the last 18 tokens - this is the whole sequence for unpadded inputs # and excludes the first two tokens for padded inputs, which are padding tokens assert torch.equal(outputs_first[:, -18:], outputs_second[:, -18:]) + # logits are only ever returned for the generated tokens, so no slicing needed + torch.testing.assert_close(logits_first, logits_second, atol=1e-4, rtol=1e-6) @pytest.mark.parametrize( "model", @@ -343,7 +343,8 @@ def test_stop_tokens_batched(self, request, model, prompt, expected_tokens_batch top_k = 100 # This is the first token generated by the model - # so it should stop immediately + # so it should stop immediately resulting in only a single + # token being generated stop_tokens = [3987, 3958, 3989] torch.manual_seed(42) @@ -465,7 +466,6 @@ def test_stop_tokens_batched_uneven_stopping_left_padded( [0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 3989, 0, 0], ] ) - assert torch.equal(outputs, expected_output) diff --git a/tests/torchtune/models/qwen2_5/test_tokenizer.py b/tests/torchtune/models/qwen2_5/test_tokenizer.py index 332fef2b92..8649d6d45e 100644 --- a/tests/torchtune/models/qwen2_5/test_tokenizer.py +++ b/tests/torchtune/models/qwen2_5/test_tokenizer.py @@ -24,50 +24,15 @@ def test_tokenize_messages(self): Message(role="user", content="Give me a short introduction to LLMs."), Message(role="assistant", content=""), ] + + # fmt: off expected_tokens = [ - 151644, - 82, - 88, - 479, - 94, - 56, - 119, - 230, - 98, - 374, - 494, - 1318, - 249, - 13, - 151645, - 94, - 151644, - 273, - 105, - 94, - 38, - 229, - 362, - 98, - 1695, - 310, - 1305, - 165, - 128, - 432, - 43, - 44, - 82, - 13, - 151645, - 94, - 151644, - 397, - 251, - 249, - 94, + 151644, 82, 88, 479, 94, 56, 119, 230, 98, 374, 494, 1318, 249, 13, 151645, 94, 151644, 273, 105, 94, + 38, 229, 362, 98, 1695, 310, 1305, 165, 128, 432, 43, 44, 82, 13, 151645, 94, 151644, 397, 251, 249, 94, 151643, - ] + ] # noqa + # fmt: on + expected_formatted_messages = ( "<|im_start|>system\n" "You are a helpful assistant.<|im_end|>\n" @@ -92,75 +57,15 @@ def test_tool_call(self): Message(role="ipython", content="test response"), Message(role="assistant", content=""), ] + # fmt: off expected_tokens = [ - 151644, - 82, - 88, - 479, - 94, - 64, - 151645, - 94, - 151644, - 273, - 105, - 94, - 65, - 151645, - 94, - 151644, - 397, - 251, - 249, - 94, - 151657, - 94, - 83, - 269, - 107, - 330, - 94, - 151658, - 151645, - 94, - 151644, - 273, - 105, - 94, - 27, - 83, - 1364, - 62, - 237, - 79, - 102, - 182, - 29, - 94, - 83, - 269, - 706, - 102, - 182, - 94, - 1932, - 83, - 1364, - 62, - 237, - 79, - 102, - 182, - 29, - 151645, - 94, - 151644, - 397, - 251, - 249, - 94, - 151643, - ] + 151644, 82, 88, 479, 94, 64, 151645, 94, 151644, 273, 105, 94, 65, 151645, 94, 151644, 397, 251, 249, + 94, 151657, 94, 83, 269, 107, 330, 94, 151658, 151645, 94, 151644, 273, 105, 94, 27, 83, 1364, + 62, 237, 79, 102, 182, 29, 94, 83, 269, 706, 102, 182, 94, 1932, 83, 1364, 62, 237, 79, 102, + 182, 29, 151645, 94, 151644, 397, 251, 249, 94, 151643, + ] # noqa + # fmt: on + expected_formatted_messages = ( "<|im_start|>system\n" "a<|im_end|>\n" diff --git a/tests/torchtune/models/t5/test_t5_encoder.py b/tests/torchtune/models/t5/test_t5_encoder.py index dbb8dbb472..32f306664b 100644 --- a/tests/torchtune/models/t5/test_t5_encoder.py +++ b/tests/torchtune/models/t5/test_t5_encoder.py @@ -51,24 +51,24 @@ def test_forward(self, model, inputs): expected = torch.tensor( [ [ - [0.3670, 0.2938], - [0.3692, 0.2921], - [0.3611, 0.2984], - [0.4207, 0.2437], - [0.3447, 0.3106], - [0.3383, 0.3150], - [0.3727, 0.2892], - [0.3996, 0.2653], + [0.1940, 0.5625], + [0.1893, 0.5681], + [0.2020, 0.5522], + [0.2547, 0.4681], + [0.1769, 0.5822], + [0.2737, 0.4281], + [0.2828, 0.4066], + [0.2841, 0.4033], ], [ - [0.3855, 0.2783], - [0.2627, 0.3581], - [0.3601, 0.2992], - [0.3473, 0.3087], - [0.3549, 0.3032], - [0.2871, 0.3459], - [0.2753, 0.3520], - [0.2285, 0.3728], + [0.1796, 0.5792], + [0.2020, 0.5523], + [0.2209, 0.5258], + [0.2802, 0.4128], + [0.2923, 0.3817], + [0.2677, 0.4414], + [0.2458, 0.4847], + [0.1923, 0.5645], ], ] ) diff --git a/tests/torchtune/modules/_export/test_export_position_embeddings.py b/tests/torchtune/modules/_export/test_export_position_embeddings.py index 6907ca3edd..3beb23e7ef 100644 --- a/tests/torchtune/modules/_export/test_export_position_embeddings.py +++ b/tests/torchtune/modules/_export/test_export_position_embeddings.py @@ -161,7 +161,6 @@ def test_tiled_token_positional_embedding_aoti(self): with tempfile.TemporaryDirectory() as tmpdir: path = torch._inductor.aoti_compile_and_package( tpe_ep, - (self.x, self.aspect_ratio), package_path=os.path.join(tmpdir, "tpe.pt2"), ) tpe_aoti = load_package(path) diff --git a/tests/torchtune/modules/loss/test_kd_losses.py b/tests/torchtune/modules/loss/test_kd_losses.py index ddfdd4012c..e820e841ac 100644 --- a/tests/torchtune/modules/loss/test_kd_losses.py +++ b/tests/torchtune/modules/loss/test_kd_losses.py @@ -24,7 +24,7 @@ def random(): class TestForwardKLWithChunkedOutputLoss: - def test_forward_kl_loss(self): + def setup_forward_kl_loss(self, ignore_all_tokens: bool = False): # Create a sample input and label ignore_index = -100 batch_size = 3 @@ -40,7 +40,10 @@ def test_forward_kl_loss(self): # add random ignore index to random tokens in the label random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens)) - labels[random_indices < num_tokens // 5] = ignore_index + if ignore_all_tokens: + labels[:] = ignore_index + else: + labels[random_indices < num_tokens // 5] = ignore_index # chunked FKL chunked_fkl_loss = ForwardKLWithChunkedOutputLoss( @@ -58,6 +61,29 @@ def test_forward_kl_loss(self): teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) labels = labels.reshape(-1) standard_loss = fkl_loss(logits, teacher_logits, labels) + return chunked_loss, standard_loss + + def test_forward_kl_loss(self): + + chunked_loss, standard_loss = self.setup_forward_kl_loss( + ignore_all_tokens=False + ) + + # Assert + assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) + + def test_forward_kl_loss_zero_masks(self): + + # set all masks to zero + chunked_loss, standard_loss = self.setup_forward_kl_loss(ignore_all_tokens=True) + + # Assert + assert_expected( + chunked_loss, + torch.tensor(0.0, device=chunked_loss.device), + rtol=1e-2, + atol=1e-2, + ) # Assert assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py index 65f51338a4..dc266212b5 100644 --- a/tests/torchtune/modules/low_precision/test_nf4_linear.py +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -4,12 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -try: - import bitsandbytes as bnb - - bnb_installed = True -except ImportError: - bnb_installed = False import pytest import torch from torchao.dtypes.nf4tensor import NF4Tensor @@ -22,19 +16,6 @@ def random(): set_seed(31) -def _build_bnb_linear(input_weight): - """ - Builds a bnb.nn.LinearNF4 from a given input weight - """ - param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4") - bnb_linear = bnb.nn.LinearNF4( - input_weight.size(0), input_weight.size(1), bias=False - ) - bnb_linear.weight = param - bnb_linear.cuda() - return bnb_linear - - class TestNF4Linear: """ Class for testing our NF4Linear implementation. @@ -88,7 +69,6 @@ def test_backward_dtype(self, dtype): assert inp.grad is not None and inp.grad.dtype == dtype assert nf4_linear.weight.grad is None - @pytest.mark.skipif(not bnb_installed, reason="bitsandbytes is not installed") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_nf4_reconstruction_vs_bnb(self, dtype): @@ -96,10 +76,22 @@ def test_nf4_reconstruction_vs_bnb(self, dtype): Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when reconstructing the respective original weights. """ + try: + import bitsandbytes as bnb + except ImportError: + pytest.skip("bitsandbytes is not installed") + return + dim = 512 nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype) orig_weight = nf4_linear.weight.get_original_weight().clone().detach() - bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) + + param = bnb.nn.Params4bit(orig_weight, requires_grad=False, quant_type="nf4") + bnb_nf4_linear = bnb.nn.LinearNF4( + orig_weight.size(0), orig_weight.size(1), bias=False + ) + bnb_nf4_linear.weight = param + bnb_nf4_linear.cuda() # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65 bnb_reconstruction = bnb_nf4_linear( @@ -110,7 +102,6 @@ def test_nf4_reconstruction_vs_bnb(self, dtype): bnb_reconstruction.T, nf4_linear.weight.get_original_weight(), 1e-2 ) - @pytest.mark.skipif(not bnb_installed, reason="bitsandbytes is not installed") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_nf4_bnb_linear(self, dtype): @@ -118,10 +109,23 @@ def test_nf4_bnb_linear(self, dtype): This test ensures that nf4_linear is "no worse" than BNB by ensuring the error compared to a bf16 linear is not more than BNB's implementation. """ + try: + import bitsandbytes as bnb + except ImportError: + pytest.skip("bitsandbytes is not installed") + return + dim = 512 nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype) orig_weight = nf4_linear.weight.get_original_weight().clone().detach() - bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) + + param = bnb.nn.Params4bit(orig_weight, requires_grad=False, quant_type="nf4") + bnb_nf4_linear = bnb.nn.LinearNF4( + orig_weight.size(0), orig_weight.size(1), bias=False + ) + bnb_nf4_linear.weight = param + bnb_nf4_linear.cuda() + bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype) inp = torch.randn(2, 512, dtype=dtype, device="cuda") diff --git a/tests/torchtune/modules/model_fusion/test_early_fusion.py b/tests/torchtune/modules/model_fusion/test_early_fusion.py index d7ff407289..0d97594e11 100644 --- a/tests/torchtune/modules/model_fusion/test_early_fusion.py +++ b/tests/torchtune/modules/model_fusion/test_early_fusion.py @@ -334,3 +334,17 @@ def test_state_dict_hooks(self, fused_model, state_dict): actual = fused_model.state_dict() expected = state_dict assert_expected(actual, expected) + + def test_sequential_state_dict_hooks(self, fused_model, state_dict): + """ + Test that state dict hooks work when EarlyFusion is wrapped in a larger model + """ + sequential_model = nn.Sequential(fused_model, nn.Linear(10, 10, bias=False)) + linear_weight = torch.randn(10, 10) + sequential_state_dict = {f"0.{k}": v for k, v in state_dict.items()} + sequential_state_dict.update({"1.weight": linear_weight}) + sequential_model.load_state_dict(sequential_state_dict) + actual = sequential_model.state_dict() + expected = sequential_state_dict + expected.update({"1.weight": linear_weight}) + assert_expected(actual, expected) diff --git a/tests/torchtune/modules/peft/test_dora.py b/tests/torchtune/modules/peft/test_dora.py index b364b70992..7a9df10a54 100644 --- a/tests/torchtune/modules/peft/test_dora.py +++ b/tests/torchtune/modules/peft/test_dora.py @@ -9,9 +9,12 @@ import pytest import torch + +import torch.distributed from tests.test_utils import fixed_init_model, gpu_test from torch import nn from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._tensor import DTensor, Replicate from torch.testing._internal.common_fsdp import FSDPTest from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 from torchtune import training @@ -402,4 +405,10 @@ def _test_dora_distributed_init(self, load_dora_weights): ) expected_magnitude = torch.linalg.norm(weight, axis=1).to(device=device) actual_magnitude = getattr(ffn, layer).magnitude.full_tensor() + # to explicit replicate the tensor before comparing with DTensor + if isinstance(expected_magnitude, DTensor): + device_mesh = torch.distributed.init_device_mesh("cuda", (2,)) + actual_magnitude = DTensor.from_local( + actual_magnitude, device_mesh=device_mesh, placements=[Replicate()] + ) torch.testing.assert_close(expected_magnitude, actual_magnitude) diff --git a/tests/torchtune/modules/test_position_embeddings.py b/tests/torchtune/modules/test_position_embeddings.py index 282fce085c..af9dead941 100644 --- a/tests/torchtune/modules/test_position_embeddings.py +++ b/tests/torchtune/modules/test_position_embeddings.py @@ -136,20 +136,21 @@ def test_rope_init_meta_device(self, input_params): class TestVisionRotaryPositionEmbedding: - EXPECTED_X_OUT_MEAN = tensor(0.0789793) - EXPECTED_X_OUT_SUM = tensor(25.2733822) - EXPECTED_X_OUT_MAX = tensor(3.1225626) + EXPECTED_X_OUT_MEAN = tensor(-0.00903320) + EXPECTED_X_OUT_SUM = tensor(-29.48437119) + EXPECTED_X_OUT_MAX = tensor(4.07074356) @pytest.fixture def input_params(self): bsz = 2 + max_num_tiles = 3 num_heads = 8 embed_dim = 32 head_dim = embed_dim // num_heads - seq_len = 5 patch_size = 4 tile_size = 16 - return bsz, num_heads, head_dim, seq_len, patch_size, tile_size + seq_len = ((tile_size // patch_size) ** 2 + 1) * max_num_tiles + return bsz, num_heads, head_dim, seq_len, max_num_tiles, patch_size, tile_size @pytest.fixture def input(self, input_params) -> tensor: @@ -158,9 +159,12 @@ def input(self, input_params) -> tensor: @pytest.fixture def rope(self, input_params): - _, _, head_dim, _, patch_size, tile_size = input_params + _, _, head_dim, _, max_num_tiles, patch_size, tile_size = input_params return VisionRotaryPositionalEmbeddings( - patch_size=patch_size, tile_size=tile_size, dim=head_dim // 2 + patch_size=patch_size, + tile_size=tile_size, + max_num_tiles=max_num_tiles, + dim=head_dim // 2, ) @mps_ignored_test() @@ -175,63 +179,20 @@ def test_forward(self, input, rope) -> None: # check shapes assert_expected(x_out.shape, input.shape) - @mps_ignored_test() - def test_forward_with_curr_pos(self, input, rope) -> None: - ( - _, - seq_len, - _, - _, - ) = input.shape - x_out = rope(input, input_pos=torch.arange(seq_len)) - - # these values should be exactly the same as test_forward - # since in this case input_pos covers the entire input - # sequence. This tests that input_pos works as expected i.e. - # extracts the embeddings for the relevant positions - assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4) - assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) - assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) - - # check shapes - assert_expected(x_out.shape, input.shape) - - @mps_ignored_test() - def test_forward_with_packed_pos(self, input, rope) -> None: - """ - Use input_pos to indicate positions of each token relative to its sequence - when sample is packed. - """ - ( - bsz, - seq_len, - _, - _, - ) = input.shape - x_out = rope( - input, input_pos=torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len) - ) - - # these values should be exactly the same as test_forward - # AND test_forward_with_current_pos. In this case input_pos - # covers the entire batch dim and is defined for each sample separately. - # This tests that input_pos works as expected i.e. - # extracts the embeddings for the relevant positions for each sample - assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4) - assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) - assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) - - # check shapes - assert_expected(x_out.shape, input.shape) - def test_rope_init_meta_device(self, input_params): - _, _, head_dim, _, patch_size, tile_size = input_params + _, _, head_dim, _, max_num_tiles, patch_size, tile_size = input_params rope_on_device = VisionRotaryPositionalEmbeddings( - dim=head_dim, patch_size=patch_size, tile_size=tile_size + dim=head_dim, + patch_size=patch_size, + max_num_tiles=max_num_tiles, + tile_size=tile_size, ) with torch.device("meta"): meta_rope = VisionRotaryPositionalEmbeddings( - dim=head_dim, patch_size=patch_size, tile_size=tile_size + dim=head_dim, + patch_size=patch_size, + tile_size=tile_size, + max_num_tiles=max_num_tiles, ) meta_rope.rope_init() diff --git a/tests/torchtune/modules/tokenizers/test_sentencepiece.py b/tests/torchtune/modules/transforms/tokenizers/test_sentencepiece.py similarity index 97% rename from tests/torchtune/modules/tokenizers/test_sentencepiece.py rename to tests/torchtune/modules/transforms/tokenizers/test_sentencepiece.py index d11c1b9c52..217f0bf2d8 100644 --- a/tests/torchtune/modules/tokenizers/test_sentencepiece.py +++ b/tests/torchtune/modules/transforms/tokenizers/test_sentencepiece.py @@ -7,7 +7,7 @@ import pytest from tests.common import ASSETS -from torchtune.modules.tokenizers import SentencePieceBaseTokenizer +from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer class TestSentencePieceBaseTokenizer: diff --git a/tests/torchtune/modules/tokenizers/test_tiktoken.py b/tests/torchtune/modules/transforms/tokenizers/test_tiktoken.py similarity index 98% rename from tests/torchtune/modules/tokenizers/test_tiktoken.py rename to tests/torchtune/modules/transforms/tokenizers/test_tiktoken.py index e7e69f62d3..5d3608d4bd 100644 --- a/tests/torchtune/modules/tokenizers/test_tiktoken.py +++ b/tests/torchtune/modules/transforms/tokenizers/test_tiktoken.py @@ -8,7 +8,7 @@ from tests.common import ASSETS from torchtune.models.llama3._tokenizer import CL100K_PATTERN -from torchtune.modules.tokenizers import TikTokenBaseTokenizer +from torchtune.modules.transforms.tokenizers import TikTokenBaseTokenizer class TestTikTokenBaseTokenizer: diff --git a/tests/torchtune/modules/tokenizers/test_utils.py b/tests/torchtune/modules/transforms/tokenizers/test_utils.py similarity index 94% rename from tests/torchtune/modules/tokenizers/test_utils.py rename to tests/torchtune/modules/transforms/tokenizers/test_utils.py index 2c49d82a5a..e3a11e6f36 100644 --- a/tests/torchtune/modules/tokenizers/test_utils.py +++ b/tests/torchtune/modules/transforms/tokenizers/test_utils.py @@ -9,7 +9,7 @@ from tests.test_utils import DummyTokenizer from torchtune.data import Message -from torchtune.modules.tokenizers import tokenize_messages_no_special_tokens +from torchtune.modules.transforms.tokenizers import tokenize_messages_no_special_tokens class TestTokenizerUtils: diff --git a/tests/torchtune/rlhf/loss/test_dpo_loss.py b/tests/torchtune/rlhf/loss/test_dpo_loss.py index 6c3e2dd4e0..ab1bfefa4c 100644 --- a/tests/torchtune/rlhf/loss/test_dpo_loss.py +++ b/tests/torchtune/rlhf/loss/test_dpo_loss.py @@ -6,7 +6,7 @@ import pytest import torch -from torchtune.rlhf.loss import DPOLoss, RSOLoss, SimPOLoss +from torchtune.rlhf.loss import DPOLoss, RSOLoss @pytest.fixture(autouse=True) @@ -28,14 +28,6 @@ def rso_loss(self): gamma=0.1, ) - @pytest.fixture - def simpo_loss(self): - return SimPOLoss( - beta=2.0, - gamma=0.5, - label_smoothing=0.0, - ) - @pytest.fixture def loss_inputs(self): """ @@ -102,24 +94,3 @@ def test_rso_loss(self, rso_loss, loss_inputs): losses, *_ = rso_loss(*loss_inputs) torch.testing.assert_close(losses, expected_losses, atol=1e-4, rtol=1e-5) - - def test_simpo_loss(self, simpo_loss, loss_inputs): - """ - here's the maths (see `loss_inputs`): - ratios = torch.tensor([-0.4, 20.0, 20.0]) - gamma_logratios = 0.25 - - logits is ratios - gamma_logratios - - logits = torch.tensor([-0.65, 19.75, 19.75]) - scaled_logits = beta * logits = torch.tensor([-1.3, 39.5, 39.5]) - - since label_smoothing is zero, loss is NLL with temperature scaled logits - """ - policy_chosen_logprobs, policy_rejected_logprobs, *_ = loss_inputs - exp_scaled_logits = torch.exp(torch.tensor([1.3, -39.5, -39.5])) - - expected_losses = -(1 / (1 + exp_scaled_logits)).log() - losses, *_ = simpo_loss(policy_chosen_logprobs, policy_rejected_logprobs) - - torch.testing.assert_close(losses, expected_losses, atol=1e-4, rtol=1e-5) diff --git a/tests/torchtune/rlhf/test_rewards.py b/tests/torchtune/rlhf/test_rewards.py index 4284d1d63d..ecc5422cf5 100644 --- a/tests/torchtune/rlhf/test_rewards.py +++ b/tests/torchtune/rlhf/test_rewards.py @@ -9,7 +9,7 @@ class TestGetRewards: - def test_get_rewards(self): + def test_get_rewards_ppo(self): scores = torch.tensor([1.0, 2.0, 3.0]) logprobs = torch.tensor( [ @@ -25,7 +25,7 @@ def test_get_rewards(self): [0.9, 1.0, 1.1], ] ) - kl_controller_value = 0.5 + kl_coeff = 0.5 # expected kl is logprobs - ref_logprobs expected_kl = torch.tensor( @@ -36,7 +36,7 @@ def test_get_rewards(self): ] ) - # expected kl_rewards is -kl_controller_value * kl + # expected kl_rewards is -kl_coeff * kl expected_kl_rewards = torch.tensor( [ [0.05, 0.05, 0.05], @@ -55,7 +55,22 @@ def test_get_rewards(self): ) rewards, kl, kl_rewards = rlhf.get_rewards_ppo( - scores, logprobs, ref_logprobs, kl_controller_value + scores, logprobs, ref_logprobs, kl_coeff + ) + + torch.testing.assert_close(kl, expected_kl, rtol=1e-4, atol=1e-4) + torch.testing.assert_close( + kl_rewards, expected_kl_rewards, rtol=1e-4, atol=1e-4 + ) + torch.testing.assert_close(rewards, expected_rewards, rtol=1e-4, atol=1e-4) + + # add a test to ensure valid_score_idxs works as expected + rewards, kl, kl_rewards = rlhf.get_rewards_ppo( + scores, + logprobs, + ref_logprobs, + kl_coeff, + valid_score_idxs=torch.tensor([2, 2, 2]), ) torch.testing.assert_close(kl, expected_kl, rtol=1e-4, atol=1e-4) @@ -137,7 +152,7 @@ def test_masked_var(self): mask = torch.tensor([True, True, True, False, False]) expected_var = torch.tensor(1.0) - output = rlhf.masked_var(x, mask) + output = rlhf.masked_var(x - rlhf.masked_mean(x, mask), mask) torch.testing.assert_close(output, expected_var, rtol=1e-4, atol=1e-4) diff --git a/tests/torchtune/training/test_distributed.py b/tests/torchtune/training/test_distributed.py index 2e5e16da9a..3fe2dd340d 100644 --- a/tests/torchtune/training/test_distributed.py +++ b/tests/torchtune/training/test_distributed.py @@ -10,20 +10,24 @@ import pytest import torch +import torch.distributed as dist import torch.nn as nn from packaging import version from tests.test_utils import gpu_test from torch.distributed import launcher - from torch.distributed._composable.fsdp import fully_shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointWrapper, ) +from torch.testing._internal.common_distributed import MultiProcessTestCase from torch.testing._internal.common_fsdp import FSDPTest, MLP from torchao.dtypes.nf4tensor import NF4Tensor from torchtune import modules, training from torchtune.models.llama2._component_builders import lora_llama2 -from torchtune.modules import TransformerSelfAttentionLayer +from torchtune.models.llama3_1._component_builders import llama3_mlp +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE +from torchtune.modules import RMSNorm, TransformerSelfAttentionLayer +from torchtune.modules.attention import MultiHeadAttention from torchtune.modules.peft import ( DoRALinear, get_adapter_params, @@ -379,3 +383,57 @@ def _broadcast_full_state_dict(self, full_sd): result.append(None) torch.distributed.broadcast_object_list(result, src=0) return result[0] + + +class TestTensorParalell(MultiProcessTestCase): + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_prepare_mha_for_tp(self) -> None: + """Test tensor parallelism preparation for multi-head attention.""" + # Create a device mesh for tensor parallelism + mesh = dist.init_device_mesh("cuda", mesh_shape=(2,)) + + # Parameters for TransformerSelfAttentionLayer + embed_dim = 64 + hidden_dim = 64 + num_heads = 4 + num_kv_heads = 4 + max_seq_len = 128 + rope_base = 500000 + head_dim = embed_dim // num_heads + rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + max_seq_len=max_seq_len, + attn_dropout=0.0, + ) + mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim) + decoder_layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=1e-5), + mlp_norm=RMSNorm(dim=embed_dim, eps=1e-5), + ) + + orig_num_heads = self_attn.num_heads + orig_num_kv_heads = self_attn.num_kv_heads + orig_embed_dim = self_attn.embed_dim + + # Apply tensor parallelism preparation + decoder_layer = training.prepare_mha_for_tp(decoder_layer, mesh) + + # Verify that parameters were scaled correctly + assert decoder_layer.attn.num_heads == orig_num_heads // 2 + assert decoder_layer.attn.num_kv_heads == orig_num_kv_heads // 2 + assert decoder_layer.attn.embed_dim == orig_embed_dim // 2 diff --git a/tests/torchtune/training/test_pooling.py b/tests/torchtune/training/test_pooling.py index bb1204b6bf..9f2bd956d7 100644 --- a/tests/torchtune/training/test_pooling.py +++ b/tests/torchtune/training/test_pooling.py @@ -7,7 +7,7 @@ from torchtune.training.pooling import get_unmasked_sequence_lengths -class TestGetLastUnmaskedTokenIdx: +class TestGetUnmaskedSeqenceLengths: def test_get_last_unmasked_token_idx_multi_batch(self): """ Tests that the last non-padding tokens are correctly selected for a multi-batch input. diff --git a/tests/torchtune/training/test_profiler.py b/tests/torchtune/training/test_profiler.py index 58d4c4a164..b66ee1bf05 100644 --- a/tests/torchtune/training/test_profiler.py +++ b/tests/torchtune/training/test_profiler.py @@ -39,6 +39,7 @@ def profiler_cfg(): enabled: True cpu: True cuda: True + xpu: True profile_memory: False with_stack: False record_shapes: True @@ -92,6 +93,7 @@ def reference_profiler_basic(): activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, + torch.profiler.ProfilerActivity.XPU, ], schedule=torch.profiler.schedule(wait=3, warmup=1, active=1, repeat=0), profile_memory=False, @@ -107,6 +109,7 @@ def reference_profiler_full(): activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, + torch.profiler.ProfilerActivity.XPU, ], schedule=torch.profiler.schedule(wait=3, warmup=1, active=1, repeat=0), profile_memory=True, @@ -194,10 +197,12 @@ def test_default_activities(profiler_cfg): # Test setup automatically adds CPU + CUDA tracing if neither CPU nor CUDA is specified cfg.pop("cpu") cfg.pop("cuda") + cfg.pop("xpu") profiler, updated_cfg = _setup_profiler(cfg) assert profiler.activities == DEFAULT_PROFILER_ACTIVITIES assert updated_cfg.cpu is True assert updated_cfg.cuda is True + assert updated_cfg.xpu is True def test_default_output_dir(profiler_cfg): diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 4b504315a9..1c41519712 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -433,6 +433,25 @@ class Recipe: ], supports_distributed=False, ), + Recipe( + name="dev/generate_v2_distributed", + file_path="dev/generate_v2_distributed.py", + configs=[ + Config( + name="llama3/70B_generation_distributed", + file_path="llama3/70B_generation_distributed.yaml", + ), + Config( + name="llama3_1/70B_generation_distributed", + file_path="llama3_1/70B_generation_distributed.yaml", + ), + Config( + name="llama3_3/70B_generation_distributed", + file_path="llama3_3/70B_generation_distributed.yaml", + ), + ], + supports_distributed=True, + ), Recipe( name="dev/early_exit_finetune_distributed", file_path="dev/early_exit_finetune_distributed.py", diff --git a/torchtune/config/_parse.py b/torchtune/config/_parse.py index 5a8e762333..0a29d3be22 100644 --- a/torchtune/config/_parse.py +++ b/torchtune/config/_parse.py @@ -57,7 +57,7 @@ def parse_known_args(self, *args, **kwargs) -> Tuple[Namespace, List[str]]: config = OmegaConf.load(namespace.config) assert "config" not in config, "Cannot use 'config' within a config file" - self.set_defaults(**config) + self.set_defaults(**OmegaConf.to_container(config, resolve=False)) namespace, unknown_args = super().parse_known_args(*args, **kwargs) del namespace.config diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py index 5157f4a7fa..410ad49376 100644 --- a/torchtune/data/_collate.py +++ b/torchtune/data/_collate.py @@ -81,13 +81,14 @@ def padded_collate( padding values. Returns: - torch.Tensor: The padded tensor of input ids with shape [batch_size, max_seq_len]. + torch.Tensor: The padded tensor of input ids with shape ``[batch_size, max_seq_len]``. Raises: - ValueError: if ``pad_direction`` is not one of "left" or "right". - ValueError: if ``keys_to_pad`` is empty, or is not a list, or is not a subset of keys in the batch. - ValueError: if ``padding_idx`` is provided as a dictionary, but the keys are not identical to - ``keys_to_pad``. + ValueError: + If ``pad_direction`` is not one of "left" or "right", **or** + if ``keys_to_pad`` is empty, or is not a list, **or** + if ``keys_to_pad`` is not a subset of keys in the batch, **or** + if ``padding_idx`` is provided as a dictionary, but the keys are not identical to ``keys_to_pad`` Example: >>> a = [1, 2, 3] @@ -149,9 +150,9 @@ def padded_collate( output_dict[k] = pad_fn( [torch.tensor(x[k]) for x in batch], batch_first=True, - padding_value=padding_idx[k] - if isinstance(padding_idx, dict) - else padding_idx, + padding_value=( + padding_idx[k] if isinstance(padding_idx, dict) else padding_idx + ), ) return output_dict @@ -274,8 +275,9 @@ def padded_collate_tiled_images_and_mask( - aspect_ratio: Tensor of shape (bsz, max_num_images, 2) Raises: - ValueError: if ``pad_direction`` is not one of "left" or "right". - ValueError: if pad_max_tiles is set to a value less than the largest number of tiles in an image. + ValueError: + If ``pad_direction`` is not one of "left" or "right", **or** + if pad_max_tiles is set to a value less than the largest number of tiles in an image. Example: >>> image_id = 1 diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index bbd3ae5981..170970e5c5 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -22,9 +22,10 @@ class Message: """ This class represents individual messages in a fine-tuning dataset. It supports - text-only content, text with interleaved images, and tool calls. The :class:`~torchtune.modules.tokenizers.ModelTokenizer` - will tokenize the content of the message using ``tokenize_messages`` and attach - the appropriate special tokens based on the flags set in this class. + text-only content, text with interleaved images, and tool calls. The + :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` will tokenize + the content of the message using ``tokenize_messages`` and attach the appropriate + special tokens based on the flags set in this class. Args: role (Role): role of the message writer. Can be "system" for system prompts, @@ -168,9 +169,10 @@ class InputOutputToMessages(Transform): on a remote url. For text-only, leave as None. Default is None. Raises: - ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or - ``output`` not in ``column_map``. - ValueError: If ``image_dir`` is provided but ``image`` not in ``column_map``. + ValueError: + If ``column_map`` is provided and ``input`` not in ``column_map``, or + ``output`` not in ``column_map``, **or** + if ``image_dir`` is provided but ``image`` not in ``column_map``. """ def __init__( diff --git a/torchtune/data/_utils.py b/torchtune/data/_utils.py index 812d1617a1..6a266cebbe 100644 --- a/torchtune/data/_utils.py +++ b/torchtune/data/_utils.py @@ -57,8 +57,9 @@ def load_image(image_loc: Union[Path, str]) -> "PIL.Image.Image": to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg". Raises: - ValueError: If the image cannot be loaded from remote source. - ValueError: If the image cannot be opened as a :class:`~PIL.Image.Image`. + ValueError: + If the image cannot be loaded from remote source, **or** + if the image cannot be opened as a :class:`~PIL.Image.Image`. Examples: >>> # Load from remote source diff --git a/torchtune/datasets/_alpaca.py b/torchtune/datasets/_alpaca.py index a881c149b0..c7795c8f28 100644 --- a/torchtune/datasets/_alpaca.py +++ b/torchtune/datasets/_alpaca.py @@ -12,7 +12,7 @@ from torchtune.datasets._packed import PackedDataset from torchtune.datasets._sft import SFTDataset -from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms.tokenizers import ModelTokenizer def alpaca_dataset( diff --git a/torchtune/datasets/_chat.py b/torchtune/datasets/_chat.py index f126fb3979..1e3962e14b 100644 --- a/torchtune/datasets/_chat.py +++ b/torchtune/datasets/_chat.py @@ -9,7 +9,7 @@ from torchtune.data._messages import OpenAIToMessages, ShareGPTToMessages from torchtune.datasets._packed import PackedDataset from torchtune.datasets._sft import SFTDataset -from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms.tokenizers import ModelTokenizer def chat_dataset( diff --git a/torchtune/datasets/_cnn_dailymail.py b/torchtune/datasets/_cnn_dailymail.py index d3c3af1f93..3995d46b22 100644 --- a/torchtune/datasets/_cnn_dailymail.py +++ b/torchtune/datasets/_cnn_dailymail.py @@ -8,7 +8,7 @@ from torchtune.datasets._text_completion import TextCompletionDataset -from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms.tokenizers import ModelTokenizer def cnn_dailymail_articles_dataset( diff --git a/torchtune/datasets/_concat.py b/torchtune/datasets/_concat.py index bf85bf0939..b6c5093d06 100644 --- a/torchtune/datasets/_concat.py +++ b/torchtune/datasets/_concat.py @@ -67,12 +67,12 @@ class ConcatDataset(Dataset): def __init__(self, datasets: List[Dataset]): self._datasets: List[Dataset] = datasets - for dataset in self._datasets: - if isinstance(dataset, PackedDataset): - raise ValueError( - "ConcatDataset can't process instances of PackedDataset." - ) - + is_packed = [isinstance(dataset, PackedDataset) for dataset in datasets] + if any(is_packed) and not all(is_packed): + raise ValueError( + "ConcatDataset can't process a mix of packed and non-packed datasets." + ) + self.packed = all(is_packed) self._len: int = sum(len(dataset) for dataset in datasets) self._indexes: List[Tuple[int, int, int]] = [] diff --git a/torchtune/datasets/_grammar.py b/torchtune/datasets/_grammar.py index 9e9d700ea6..02970cedef 100644 --- a/torchtune/datasets/_grammar.py +++ b/torchtune/datasets/_grammar.py @@ -10,7 +10,7 @@ from torchtune.data import InputOutputToMessages from torchtune.datasets._packed import PackedDataset from torchtune.datasets._sft import SFTDataset -from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms.tokenizers import ModelTokenizer def grammar_dataset( diff --git a/torchtune/datasets/_hh_rlhf_helpful.py b/torchtune/datasets/_hh_rlhf_helpful.py index e466a8a4fd..8eea7e1a46 100644 --- a/torchtune/datasets/_hh_rlhf_helpful.py +++ b/torchtune/datasets/_hh_rlhf_helpful.py @@ -8,7 +8,7 @@ from torchtune.data import ChosenRejectedToMessages from torchtune.datasets._preference import PreferenceDataset -from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms.tokenizers import ModelTokenizer def hh_rlhf_helpful_dataset( diff --git a/torchtune/datasets/_instruct.py b/torchtune/datasets/_instruct.py index 0dfa46146d..20168aac1d 100644 --- a/torchtune/datasets/_instruct.py +++ b/torchtune/datasets/_instruct.py @@ -9,7 +9,7 @@ from torchtune.data import InputOutputToMessages from torchtune.datasets._packed import PackedDataset from torchtune.datasets._sft import SFTDataset -from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms.tokenizers import ModelTokenizer def instruct_dataset( diff --git a/torchtune/datasets/_preference.py b/torchtune/datasets/_preference.py index dea4eec852..c9615fe93c 100644 --- a/torchtune/datasets/_preference.py +++ b/torchtune/datasets/_preference.py @@ -11,10 +11,10 @@ from torch.utils.data import Dataset from torchtune.data import ChosenRejectedToMessages, CROSS_ENTROPY_IGNORE_IDX - -from torchtune.modules.tokenizers import ModelTokenizer from torchtune.modules.transforms import Transform +from torchtune.modules.transforms.tokenizers import ModelTokenizer + class PreferenceDataset(Dataset): """ @@ -84,7 +84,7 @@ class requires the dataset to have "chosen" and "rejected" model responses. Thes of messages are stored in the ``"chosen"`` and ``"rejected"`` keys. tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. Since PreferenceDataset only supports text data, it requires a - :class:`~torchtune.modules.tokenizers.ModelTokenizer` instead of the ``model_transform`` in + :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` instead of the ``model_transform`` in :class:`~torchtune.datasets.SFTDataset`. filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See the Hugging Face `docs `_ for more diff --git a/torchtune/datasets/_samsum.py b/torchtune/datasets/_samsum.py index 905911d736..bd7f7dd8eb 100644 --- a/torchtune/datasets/_samsum.py +++ b/torchtune/datasets/_samsum.py @@ -10,7 +10,7 @@ from torchtune.data import InputOutputToMessages from torchtune.datasets._packed import PackedDataset from torchtune.datasets._sft import SFTDataset -from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms.tokenizers import ModelTokenizer def samsum_dataset( diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 9ee11244b6..0d1461dd0d 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -69,11 +69,13 @@ class SFTDataset(Dataset): multimodal datasets requires processing the images in a way specific to the vision encoder being used by the model and is agnostic to the specific dataset. - Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer` - can be treated as a ``model_transform`` since it uses the model-specific tokenizer to - transform the list of messages outputted from the ``message_transform`` into tokens - used by the model for training. Text-only datasets will simply pass the :class:`~torchtune.modules.tokenizers.ModelTokenizer` - into ``model_transform``. Tokenizers handle prompt templating, if configured. + Tokenization is handled by the ``model_transform``. All + :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` can be treated as + a ``model_transform`` since it uses the model-specific tokenizer to transform the + list of messages outputted from the ``message_transform`` into tokens used by the + model for training. Text-only datasets will simply pass the + :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` into ``model_transform``. + Tokenizers handle prompt templating, if configured. Args: source (str): path to dataset repository on Hugging Face. For local datasets, diff --git a/torchtune/datasets/_slimorca.py b/torchtune/datasets/_slimorca.py index 126b6b92e4..2701b2d717 100644 --- a/torchtune/datasets/_slimorca.py +++ b/torchtune/datasets/_slimorca.py @@ -10,7 +10,7 @@ from torchtune.datasets._packed import PackedDataset from torchtune.datasets._sft import SFTDataset -from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms.tokenizers import ModelTokenizer def slimorca_dataset( diff --git a/torchtune/datasets/_stack_exchange_paired.py b/torchtune/datasets/_stack_exchange_paired.py index 09eda929fe..a111d415d2 100644 --- a/torchtune/datasets/_stack_exchange_paired.py +++ b/torchtune/datasets/_stack_exchange_paired.py @@ -8,8 +8,8 @@ from torchtune.data import Message from torchtune.datasets._preference import PreferenceDataset -from torchtune.modules.tokenizers import ModelTokenizer from torchtune.modules.transforms import Transform +from torchtune.modules.transforms.tokenizers import ModelTokenizer class StackExchangePairedToMessages(Transform): diff --git a/torchtune/datasets/_text_completion.py b/torchtune/datasets/_text_completion.py index 5b5cc94299..342c6aa816 100644 --- a/torchtune/datasets/_text_completion.py +++ b/torchtune/datasets/_text_completion.py @@ -10,7 +10,7 @@ from torch.utils.data import Dataset from torchtune.data._utils import truncate from torchtune.datasets._packed import PackedDataset -from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms.tokenizers import ModelTokenizer class TextCompletionDataset(Dataset): diff --git a/torchtune/datasets/_wikitext.py b/torchtune/datasets/_wikitext.py index 01111a25c6..4f9ada6741 100644 --- a/torchtune/datasets/_wikitext.py +++ b/torchtune/datasets/_wikitext.py @@ -13,7 +13,7 @@ TextCompletionDataset, ) -from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms.tokenizers import ModelTokenizer def wikitext_dataset( diff --git a/torchtune/generation/_generation.py b/torchtune/generation/_generation.py index b73dd186a8..76d4acb743 100644 --- a/torchtune/generation/_generation.py +++ b/torchtune/generation/_generation.py @@ -94,15 +94,15 @@ def generate_next_token( - tokens (torch.Tensor): tensor with the generated tokens, with shape [bsz x 1]. - logits (torch.Tensor): tensor with the logits associated with the generated tokens, - with shape [bsz x seq_length x vocab_size]. + with shape [bsz x 1 x vocab_size]. """ # model produces logits in [bsz, seq_length, vocab_size] # we want to take the last token's logits as the input to the next model call - logits = model(x, input_pos=input_pos, mask=mask) + logits = model(x, input_pos=input_pos, mask=mask)[:, -1] return ( - sample(logits[:, -1].clone(), temperature=temperature, top_k=top_k, q=q), - logits, + sample(logits.clone(), temperature=temperature, top_k=top_k, q=q), + logits.unsqueeze(1), ) @@ -189,7 +189,7 @@ def get_position_ids_from_padding_mask( return ((padding_mask.cumsum(-1) - 1) * padding_mask).to(torch.int) -@torch.inference_mode() +@torch.no_grad() def generate( model: TransformerDecoder, prompt: torch.Tensor, @@ -241,7 +241,7 @@ def generate( with shape ``[bsz x seq_len + num_generated_tokens]`` where ``num_generated_tokens`` may be less than ``max_generated_tokens`` if ``stop_tokens`` are provided. - logits (torch.Tensor): tensor with the logits associated with the generated tokens, - with shape ``[bsz x seq_len + num_generated_tokens x vocab_size]``. + with shape ``[bsz x num_generated_tokens x vocab_size]``. """ prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt @@ -355,8 +355,8 @@ def generate( # if incremental decoding is enabled, we can use the current position # otherwise, we take the whole sequence up to the current position if incremental_decoding: - curr_input_pos = input_pos[:, curr_pos] - curr_masks = masks[:, curr_pos, None, :] + curr_input_pos = input_pos[:, curr_pos].contiguous() + curr_masks = masks[:, curr_pos, None, :].contiguous() else: tokens = generated_tokens.clone() curr_input_pos = input_pos[:, : curr_pos + 1] @@ -377,11 +377,8 @@ def generate( q=q, ) generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + generated_logits = torch.cat([generated_logits, logits], dim=1) curr_pos += 1 - if incremental_decoding: - generated_logits = torch.cat([generated_logits, logits], dim=1) - else: - generated_logits = logits if stop_tokens is not None: stop_token_reached = update_stop_tokens_tracker( @@ -393,6 +390,6 @@ def generate( # mask out generated tokens in seqs that already hit a stop token if stop_tokens is not None: generated_tokens *= stop_token_mask - generated_logits *= stop_token_mask[:, :-1, None] + generated_logits *= stop_token_mask[:, -generated_logits.shape[1] :, None] return generated_tokens, generated_logits diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index c8d19aae41..edbb31ad32 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -24,7 +24,7 @@ VisionRotaryPositionalEmbeddings, ) from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook -from torchtune.modules.peft import LORA_ATTN_MODULES, DoRALinear, LoRALinear +from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer @@ -43,6 +43,7 @@ def clip_vision_encoder( max_num_tiles: int = 4, in_channels: int = 3, append_cls_token: bool = False, + use_tile_pos_embed: bool = True, ) -> VisionTransformer: """ Builds the vision encoder associated with the clip model. This includes: @@ -78,6 +79,8 @@ def clip_vision_encoder( in_channels (int): The number of image input channels. append_cls_token (bool): If True, adds CLS token embedding to the end of the sequence in the vision transformer. Default is False, which adds CLS token to the beginning of the sequence. + use_tile_pos_embed (bool): If True, use pre-tile, post-tile, and tiled token positional embeddings, if max_num_tiles > 1. + If False, only use standard token positional embeddings. Returns: A `VisionTransformer` object. @@ -89,10 +92,6 @@ def clip_vision_encoder( raise ValueError( f"embed_dim must be divisible by num_heads, got {embed_dim} and {num_heads}" ) - if use_rope and max_num_tiles != 1: - raise ValueError( - f"2D RoPE is only supported for max_num_tiles = 1, got {max_num_tiles}" - ) head_dim = embed_dim // num_heads @@ -105,6 +104,7 @@ def clip_vision_encoder( VisionRotaryPositionalEmbeddings( patch_size=patch_size, tile_size=tile_size, + max_num_tiles=max_num_tiles, dim=head_dim // 2, base=10_000, append_cls_token=append_cls_token, @@ -143,13 +143,7 @@ def clip_vision_encoder( ) # position embeddings - if max_num_tiles == 1: - pre_tile_pos_embed = None - post_tile_pos_embed = None - token_pos_embedding = TokenPositionalEmbedding( - embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size - ) - else: + if use_tile_pos_embed and max_num_tiles > 1: pre_tile_pos_embed = TilePositionalEmbedding( max_num_tiles=max_num_tiles, embed_dim=embed_dim ) @@ -162,6 +156,12 @@ def clip_vision_encoder( patch_size=patch_size, tile_size=tile_size, ) + else: + pre_tile_pos_embed = None + post_tile_pos_embed = None + token_pos_embedding = TokenPositionalEmbedding( + embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size + ) return VisionTransformer( num_layers=num_layers, diff --git a/torchtune/models/clip/_position_embeddings.py b/torchtune/models/clip/_position_embeddings.py index 09a98862e1..50488a5380 100644 --- a/torchtune/models/clip/_position_embeddings.py +++ b/torchtune/models/clip/_position_embeddings.py @@ -126,12 +126,13 @@ def _load_state_dict_hook( **kwargs (Dict[str, Any]): Additional keyword arguments. Raises: - ValueError: if loaded local or global embedding n_tokens_per_tile is not derived - from a squared grid. - ValueError: if after interpolation, the shape of the loaded local embedding - is not compatible with the current embedding. - ValueError: if after interpolation, the shape of the loaded global embedding - is not compatible with the current embedding. + ValueError: + If loaded local or global embedding n_tokens_per_tile is not derived + from a squared grid, **or** + if after interpolation, the shape of the loaded local embedding + is not compatible with the current embedding, **or** + if after interpolation, the shape of the loaded global embedding + is not compatible with the current embedding. """ # process local_token_positional_embedding @@ -530,9 +531,10 @@ def _load_state_dict_hook( **kwargs (Dict[str, Any]): Additional keyword arguments. Raises: - ValueError: if the shape of the loaded embedding is not compatible with the current embedding. - ValueError: if max_num_tiles_x, max_num_tiles_y are not equal. - ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding. + ValueError: + If the shape of the loaded embedding is not compatible with the current embedding, **or** + if ``max_num_tiles_x``, ``max_num_tiles_y`` are not equal, **or** + if after interpolation, the shape of the loaded embedding is not compatible with the current embedding. """ embedding = state_dict.get(prefix + "embedding") diff --git a/torchtune/models/clip/_tokenizer.py b/torchtune/models/clip/_tokenizer.py index 69fed32c72..cdab2c9c05 100644 --- a/torchtune/models/clip/_tokenizer.py +++ b/torchtune/models/clip/_tokenizer.py @@ -7,7 +7,7 @@ import regex as re -from torchtune.modules.tokenizers._utils import BaseTokenizer +from torchtune.modules.transforms.tokenizers._utils import BaseTokenizer WORD_BOUNDARY = "" diff --git a/torchtune/models/gemma/_component_builders.py b/torchtune/models/gemma/_component_builders.py index ba5b666c98..0f02e6111a 100644 --- a/torchtune/models/gemma/_component_builders.py +++ b/torchtune/models/gemma/_component_builders.py @@ -76,33 +76,37 @@ def gemma( TransformerDecoder: Instantiation of gemma model. """ rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) - self_att = MultiHeadAttention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), - k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False), - pos_embeddings=rope, - kv_cache=None, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - ) - mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) - layer = TransformerSelfAttentionLayer( - attn=self_att, - mlp=mlp, - sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), - mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), - ) + + layers = nn.ModuleList() + for _ in range(num_layers): + self_att = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerSelfAttentionLayer( + attn=self_att, + mlp=mlp, + sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + ) + layers.append(layer) + tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) output_proj = TiedLinear(tok_embeddings) model = TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, output=output_proj, @@ -186,47 +190,50 @@ def lora_gemma( TransformerDecoder: Instantiation of Gemma model with LoRA applied to a subset of the attention projections in each layer. """ - self_attn = lora_gemma_self_attention( - lora_modules=lora_attn_modules, - embed_dim=embed_dim, - head_dim=head_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - rope_base=rope_base, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - use_dora=use_dora, - quantize_base=quantize_base, - ) - - if apply_lora_to_mlp: - mlp = lora_gemma_mlp( - dim=embed_dim, - hidden_dim=intermediate_dim, + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = lora_gemma_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + head_dim=head_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, ) - else: - mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), - mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), - ) + if apply_lora_to_mlp: + mlp = lora_gemma_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + else: + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + ) + layers.append(layer) + tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) output_proj = TiedLinear(tok_embeddings) model = TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, output=output_proj, diff --git a/torchtune/models/gemma/_tokenizer.py b/torchtune/models/gemma/_tokenizer.py index e5eb89e230..dc5d2eadf8 100644 --- a/torchtune/models/gemma/_tokenizer.py +++ b/torchtune/models/gemma/_tokenizer.py @@ -7,12 +7,12 @@ from typing import Any, List, Mapping, Optional, Tuple from torchtune.data import Message, PromptTemplate -from torchtune.modules.tokenizers import ( +from torchtune.modules.transforms import Transform +from torchtune.modules.transforms.tokenizers import ( ModelTokenizer, SentencePieceBaseTokenizer, tokenize_messages_no_special_tokens, ) -from torchtune.modules.transforms import Transform WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py index 1cd3bfdc12..0ed71b34d6 100644 --- a/torchtune/models/gemma2/_attention.py +++ b/torchtune/models/gemma2/_attention.py @@ -50,10 +50,11 @@ class Gemma2Attention(nn.Module): softcapping (Optional[float]): capping value used for soft caping, if None no capping is performed query_pre_attn_scalar (Optional[int]): value used for pre attention normalisation, if None head_dim is used instead Raises: - ValueError: If ``num_heads % num_kv_heads != 0`` - ValueError: If ``embed_dim % num_heads != 0`` - ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` - ValueError: if q_norm is defined without k_norm or vice versa + ValueError: + If ``num_heads % num_kv_heads != 0``, **or** + if ``embed_dim % num_heads != 0``, **or** + if ``attn_dropout < 0`` or ``attn_dropout > 1``, **or** + if ``q_norm`` is defined without k_norm or vice versa """ def __init__( @@ -156,7 +157,11 @@ def setup_cache( self.cache_enabled = True def reset_cache(self): - """Reset the key value caches.""" + """Reset the key value caches. + + Raises: + RuntimeError: if key value caches are not already setup. + """ if self.kv_cache is None: raise RuntimeError( "Key value caches are not setup. Call ``setup_caches()`` first." @@ -196,6 +201,7 @@ def forward( If none, assume the index of the token is its position id. Default is None. Raises: + NotImplementedError: If ``mask`` is provided, but mask is not an instance of ``torch.Tensor``. ValueError: If no ``y`` input and ``kv_cache`` is not enabled. Returns: diff --git a/torchtune/models/gemma2/_component_builders.py b/torchtune/models/gemma2/_component_builders.py index 0ddef36857..f276dc1bed 100644 --- a/torchtune/models/gemma2/_component_builders.py +++ b/torchtune/models/gemma2/_component_builders.py @@ -4,9 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torch import nn import torch -from typing import List +from torch import nn from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks from typing import List, Optional @@ -116,7 +115,6 @@ def gemma2( rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) layers = torch.nn.ModuleList() - for layer_idx in range(num_layers): mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) @@ -149,6 +147,7 @@ def gemma2( mlp_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), ) layers.append(layer) + tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) output_proj = TiedLinear(tok_embeddings) model = TransformerDecoder( @@ -231,8 +230,7 @@ def lora_gemma2( tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) output_proj = TiedLinear(tok_embeddings) - layers = torch.nn.ModuleList() - + layers = nn.ModuleList() for layer_idx in range(num_layers): if apply_lora_to_mlp: mlp = lora_gemma_mlp( diff --git a/torchtune/models/llama2/_component_builders.py b/torchtune/models/llama2/_component_builders.py index 12f04c93c8..e74fbad5b1 100644 --- a/torchtune/models/llama2/_component_builders.py +++ b/torchtune/models/llama2/_component_builders.py @@ -81,40 +81,43 @@ def llama2( """ head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads - - rope = RotaryPositionalEmbeddings( - dim=head_dim, max_seq_len=max_seq_len, base=rope_base - ) - self_attn = MultiHeadAttention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), - k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - output_proj=nn.Linear(embed_dim, embed_dim, bias=False), - pos_embeddings=rope, - kv_cache=None, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - ) hidden_dim = ( intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) ) - mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + + rope = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len, base=rope_base ) + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) output_proj = nn.Linear(embed_dim, vocab_size, bias=False) return TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=head_dim, @@ -212,45 +215,48 @@ def lora_llama2( a subset of the attention projections in each layer. """ - - self_attn = lora_llama2_self_attention( - lora_modules=lora_attn_modules, - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - use_dora=use_dora, - quantize_base=quantize_base, - ) - hidden_dim = ( intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) ) - if apply_lora_to_mlp: - mlp = lora_llama2_mlp( - dim=embed_dim, - hidden_dim=hidden_dim, + + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = lora_llama2_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, lora_rank=lora_rank, lora_alpha=lora_alpha, - quantize_base=quantize_base, - use_dora=use_dora, lora_dropout=lora_dropout, - ) - else: - mlp = llama2_mlp( - dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base + use_dora=use_dora, + quantize_base=quantize_base, ) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + if apply_lora_to_mlp: + mlp = lora_llama2_mlp( + dim=embed_dim, + hidden_dim=hidden_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + use_dora=use_dora, + lora_dropout=lora_dropout, + ) + else: + mlp = llama2_mlp( + dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base + ) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) tok_embeddings = nn.Embedding(vocab_size, embed_dim) @@ -269,8 +275,7 @@ def lora_llama2( ) model = TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=(embed_dim // num_heads), @@ -511,38 +516,42 @@ def llama2_classifier( """ head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads - - rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) - self_attn = MultiHeadAttention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), - k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - output_proj=nn.Linear(embed_dim, embed_dim, bias=False), - pos_embeddings=rope, - kv_cache=None, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - ) hidden_dim = ( intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) ) - mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) output_proj = nn.Linear(embed_dim, num_classes, bias=False) return TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=head_dim, @@ -616,43 +625,46 @@ def lora_llama2_classifier( a subset of the attention projections in each layer. """ - - self_attn = lora_llama2_self_attention( - lora_modules=lora_attn_modules, - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - use_dora=use_dora, - quantize_base=quantize_base, - ) - hidden_dim = ( intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) ) - if apply_lora_to_mlp: - mlp = lora_llama2_mlp( - dim=embed_dim, - hidden_dim=hidden_dim, + + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = lora_llama2_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, lora_rank=lora_rank, lora_alpha=lora_alpha, - quantize_base=quantize_base, - use_dora=use_dora, lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, ) - else: - mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + if apply_lora_to_mlp: + mlp = lora_llama2_mlp( + dim=embed_dim, + hidden_dim=hidden_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + use_dora=use_dora, + lora_dropout=lora_dropout, + ) + else: + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) tok_embeddings = nn.Embedding(vocab_size, embed_dim) @@ -671,8 +683,7 @@ def lora_llama2_classifier( ) model = TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=(embed_dim // num_heads), diff --git a/torchtune/models/llama2/_tokenizer.py b/torchtune/models/llama2/_tokenizer.py index 078494c531..4e2ab6a40c 100644 --- a/torchtune/models/llama2/_tokenizer.py +++ b/torchtune/models/llama2/_tokenizer.py @@ -8,12 +8,12 @@ from torchtune.data import Message, PromptTemplate from torchtune.models.llama2._prompt_template import Llama2ChatTemplate -from torchtune.modules.tokenizers import ( +from torchtune.modules.transforms import Transform +from torchtune.modules.transforms.tokenizers import ( ModelTokenizer, SentencePieceBaseTokenizer, tokenize_messages_no_special_tokens, ) -from torchtune.modules.transforms import Transform WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] diff --git a/torchtune/models/llama3/__init__.py b/torchtune/models/llama3/__init__.py index 90de8c286f..5cf4e6b616 100644 --- a/torchtune/models/llama3/__init__.py +++ b/torchtune/models/llama3/__init__.py @@ -15,6 +15,7 @@ qlora_llama3_70b, qlora_llama3_8b, ) +from ._parallelism import base_llama_tp_plan from ._tokenizer import Llama3Tokenizer __all__ = [ @@ -28,4 +29,5 @@ "lora_llama3_70b", "qlora_llama3_8b", "qlora_llama3_70b", + "base_llama_tp_plan", ] diff --git a/torchtune/models/llama3/_component_builders.py b/torchtune/models/llama3/_component_builders.py index ca3c1c34bc..0ba7a97257 100644 --- a/torchtune/models/llama3/_component_builders.py +++ b/torchtune/models/llama3/_component_builders.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from functools import partial from typing import List, Optional from torch import nn @@ -83,38 +82,42 @@ def llama3( """ head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads - rope = RotaryPositionalEmbeddings( - dim=head_dim, max_seq_len=max_seq_len, base=rope_base - ) - self_attn = MultiHeadAttention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), - k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - output_proj=nn.Linear(embed_dim, embed_dim, bias=False), - pos_embeddings=rope, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - ) hidden_dim = ( intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) ) - mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + rope = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len, base=rope_base ) + + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim) + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) output_proj = nn.Linear(embed_dim, vocab_size, bias=False) return TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=head_dim, @@ -213,46 +216,49 @@ def lora_llama3( a subset of the attention projections in each layer. """ - - self_attn = lora_llama3_self_attention( - lora_modules=lora_attn_modules, - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - rope_base=rope_base, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - quantize_base=quantize_base, - use_dora=use_dora, - ) - hidden_dim = ( intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) ) - if apply_lora_to_mlp: - mlp = lora_llama3_mlp( - dim=embed_dim, - hidden_dim=hidden_dim, + + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = lora_llama3_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, lora_rank=lora_rank, lora_alpha=lora_alpha, - quantize_base=quantize_base, lora_dropout=lora_dropout, + quantize_base=quantize_base, use_dora=use_dora, ) - else: - mlp = llama3_mlp( - dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base - ) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + if apply_lora_to_mlp: + mlp = lora_llama3_mlp( + dim=embed_dim, + hidden_dim=hidden_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + use_dora=use_dora, + ) + else: + mlp = llama3_mlp( + dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base + ) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) tok_embeddings = nn.Embedding(vocab_size, embed_dim) @@ -271,8 +277,7 @@ def lora_llama3( ) model = TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=(embed_dim // num_heads), diff --git a/torchtune/models/llama3/_model_builders.py b/torchtune/models/llama3/_model_builders.py index 0ddca90189..6c13e37cff 100644 --- a/torchtune/models/llama3/_model_builders.py +++ b/torchtune/models/llama3/_model_builders.py @@ -13,7 +13,7 @@ from torchtune.modules import TransformerDecoder from torchtune.modules.peft import LORA_ATTN_MODULES -from torchtune.modules.tokenizers import parse_hf_tokenizer_json +from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json """ diff --git a/torchtune/models/llama3/_parallelism.py b/torchtune/models/llama3/_parallelism.py new file mode 100644 index 0000000000..6046f0e83c --- /dev/null +++ b/torchtune/models/llama3/_parallelism.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +from torch.distributed._tensor import Replicate +from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel +from torch.distributed.tensor.parallel.style import ParallelStyle + + +# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models +BASE_LLAMA_TP_PLAN = { + "tok_embeddings": RowwiseParallel(input_layouts=Replicate()), + "output": ColwiseParallel(output_layouts=Replicate()), + "layers.*.attn.q_proj": ColwiseParallel(), + "layers.*.attn.k_proj": ColwiseParallel(), + "layers.*.attn.v_proj": ColwiseParallel(), + "layers.*.attn.output_proj": RowwiseParallel(), + "layers.*.mlp.w1": ColwiseParallel(), + "layers.*.mlp.w2": RowwiseParallel(), + "layers.*.mlp.w3": ColwiseParallel(), +} + + +def base_llama_tp_plan() -> Dict[str, ParallelStyle]: + """ + Helper function to get the base tensor parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models + + Returns: + Dict[str, Any]: The tensor parallel plan for Llama3 model. + """ + return BASE_LLAMA_TP_PLAN diff --git a/torchtune/models/llama3/_tokenizer.py b/torchtune/models/llama3/_tokenizer.py index 50ea0a7581..012aa9f584 100644 --- a/torchtune/models/llama3/_tokenizer.py +++ b/torchtune/models/llama3/_tokenizer.py @@ -8,8 +8,11 @@ from typing import Any, Dict, List, Mapping, Optional, Tuple from torchtune.data import Message, PromptTemplate, truncate -from torchtune.modules.tokenizers import ModelTokenizer, TikTokenBaseTokenizer from torchtune.modules.transforms import Transform +from torchtune.modules.transforms.tokenizers import ( + ModelTokenizer, + TikTokenBaseTokenizer, +) CL100K_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa diff --git a/torchtune/models/llama3_1/_component_builders.py b/torchtune/models/llama3_1/_component_builders.py index ba9e791dc1..3fc2431c12 100644 --- a/torchtune/models/llama3_1/_component_builders.py +++ b/torchtune/models/llama3_1/_component_builders.py @@ -83,7 +83,8 @@ def llama3_1( head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base, scale_factor=scale_factor) - layers = [] + + layers = nn.ModuleList() for _ in range(num_layers): self_attn = MultiHeadAttention( embed_dim=embed_dim, @@ -107,7 +108,6 @@ def llama3_1( mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), ) layers.append(layer) - layers = nn.ModuleList(layers) tok_embeddings = nn.Embedding(vocab_size, embed_dim) output_proj = nn.Linear(embed_dim, vocab_size, bias=False) @@ -206,7 +206,8 @@ def lora_llama3_1( hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) head_dim = embed_dim // num_heads rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base, scale_factor=scale_factor) - layers = [] + + layers = nn.ModuleList() for _ in range(num_layers): self_attn = lora_llama3_attention( lora_modules=lora_attn_modules, @@ -244,7 +245,7 @@ def lora_llama3_1( mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), ) layers.append(layer) - layers = nn.ModuleList(layers) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) # TODO: quantize_base is not applied to final output_proj currently. diff --git a/torchtune/models/llama3_1/_model_builders.py b/torchtune/models/llama3_1/_model_builders.py index b6439b2eb2..f48ce580f5 100644 --- a/torchtune/models/llama3_1/_model_builders.py +++ b/torchtune/models/llama3_1/_model_builders.py @@ -73,7 +73,7 @@ def llama3_1_405b() -> TransformerDecoder: num_heads=128, num_kv_heads=8, embed_dim=16384, - max_seq_len=8192, + max_seq_len=131072, intermediate_dim=53248, attn_dropout=0.0, norm_eps=1e-5, @@ -236,7 +236,7 @@ def lora_llama3_1_405b( num_heads=128, num_kv_heads=8, embed_dim=16384, - max_seq_len=8192, + max_seq_len=131072, intermediate_dim=53248, attn_dropout=0.0, norm_eps=1e-5, diff --git a/torchtune/models/llama3_2/_component_builders.py b/torchtune/models/llama3_2/_component_builders.py index 37c1f0df95..1817a10a20 100644 --- a/torchtune/models/llama3_2/_component_builders.py +++ b/torchtune/models/llama3_2/_component_builders.py @@ -85,7 +85,8 @@ def llama3_2( head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base, scale_factor=scale_factor) - layers = [] + + layers = nn.ModuleList() for _ in range(num_layers): self_attn = MultiHeadAttention( embed_dim=embed_dim, @@ -109,7 +110,6 @@ def llama3_2( mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), ) layers.append(layer) - layers = nn.ModuleList(layers) tok_embeddings = nn.Embedding(vocab_size, embed_dim) output_proj = TiedLinear(tok_embeddings) @@ -207,7 +207,8 @@ def lora_llama3_2( hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) head_dim = embed_dim // num_heads rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base, scale_factor=scale_factor) - layers = [] + + layers = nn.ModuleList() for _ in range(num_layers): self_attn = lora_llama3_2_self_attention( lora_modules=lora_attn_modules, @@ -245,7 +246,7 @@ def lora_llama3_2( mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), ) layers.append(layer) - layers = nn.ModuleList(layers) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) if apply_lora_to_output: diff --git a/torchtune/models/llama3_2_vision/_component_builders.py b/torchtune/models/llama3_2_vision/_component_builders.py index b83648eff0..dd48e6b337 100644 --- a/torchtune/models/llama3_2_vision/_component_builders.py +++ b/torchtune/models/llama3_2_vision/_component_builders.py @@ -179,9 +179,9 @@ def llama3_2_vision_decoder( head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads hidden_dim = intermediate_dim or scale_hidden_dim_for_mlp(embed_dim) - layers = [] - rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + + layers = nn.ModuleList() for idx in range(1, num_layers + 1): # Self attention layers for text decoder @@ -543,9 +543,9 @@ def lora_llama3_2_vision_decoder( head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads hidden_dim = intermediate_dim or scale_hidden_dim_for_mlp(embed_dim) - layers = [] - rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + + layers = nn.ModuleList() for idx in range(1, num_layers + 1): # Self attention layers for text decoder diff --git a/torchtune/models/llama3_2_vision/_model_builders.py b/torchtune/models/llama3_2_vision/_model_builders.py index f9da10095a..4f035f92c5 100644 --- a/torchtune/models/llama3_2_vision/_model_builders.py +++ b/torchtune/models/llama3_2_vision/_model_builders.py @@ -20,7 +20,6 @@ from torchtune.models.llama3_2_vision._transform import Llama3VisionTransform from torchtune.modules.model_fusion import DeepFusionModel from torchtune.modules.peft import LORA_ATTN_MODULES -from torchtune.modules.tokenizers import parse_hf_tokenizer_json def llama3_2_vision_transform( diff --git a/torchtune/models/llama3_2_vision/_transform.py b/torchtune/models/llama3_2_vision/_transform.py index eaf627d027..534ed4ab1c 100644 --- a/torchtune/models/llama3_2_vision/_transform.py +++ b/torchtune/models/llama3_2_vision/_transform.py @@ -10,8 +10,8 @@ from torchtune.models.clip import CLIPImageTransform from torchtune.models.llama3 import llama3_tokenizer -from torchtune.modules.tokenizers import ModelTokenizer from torchtune.modules.transforms import Transform, VisionCrossAttentionMask +from torchtune.modules.transforms.tokenizers import ModelTokenizer class Llama3VisionTransform(ModelTokenizer, Transform): diff --git a/torchtune/models/mistral/_component_builders.py b/torchtune/models/mistral/_component_builders.py index e848b116d5..73c3cce7ac 100644 --- a/torchtune/models/mistral/_component_builders.py +++ b/torchtune/models/mistral/_component_builders.py @@ -81,33 +81,36 @@ def mistral( rope = RotaryPositionalEmbeddings( dim=head_dim, max_seq_len=max_seq_len, base=rope_base ) - self_attn = MultiHeadAttention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), - k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - output_proj=nn.Linear(embed_dim, embed_dim, bias=False), - pos_embeddings=rope, - kv_cache=None, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - ) - mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) output_proj = nn.Linear(embed_dim, vocab_size, bias=False) return TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=head_dim, @@ -201,43 +204,45 @@ def lora_mistral( a subset of the attention projections in each layer. """ - - self_attn = lora_mistral_self_attention( - lora_modules=lora_attn_modules, - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - rope_base=rope_base, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - use_dora=use_dora, - quantize_base=quantize_base, - ) - - if apply_lora_to_mlp: - mlp = lora_mistral_mlp( - dim=embed_dim, - hidden_dim=intermediate_dim, + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = lora_mistral_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, ) - else: - mlp = mistral_mlp( - dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base - ) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + if apply_lora_to_mlp: + mlp = lora_mistral_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + else: + mlp = mistral_mlp( + dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base + ) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) tok_embeddings = nn.Embedding(vocab_size, embed_dim) @@ -250,8 +255,7 @@ def lora_mistral( ) model = TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=(embed_dim // num_heads), @@ -499,33 +503,36 @@ def mistral_classifier( rope = RotaryPositionalEmbeddings( dim=head_dim, max_seq_len=max_seq_len, base=rope_base ) - self_attn = MultiHeadAttention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), - k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - output_proj=nn.Linear(embed_dim, embed_dim, bias=False), - pos_embeddings=rope, - kv_cache=None, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - ) - mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) output_proj = nn.Linear(embed_dim, num_classes, bias=False) return TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=head_dim, @@ -600,41 +607,43 @@ def lora_mistral_classifier( a subset of the attention projections in each layer. """ - - self_attn = lora_mistral_self_attention( - lora_modules=lora_attn_modules, - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - rope_base=rope_base, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - use_dora=use_dora, - quantize_base=quantize_base, - ) - - if apply_lora_to_mlp: - mlp = lora_mistral_mlp( - dim=embed_dim, - hidden_dim=intermediate_dim, + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = lora_mistral_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, ) - else: - mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + if apply_lora_to_mlp: + mlp = lora_mistral_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + else: + mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) tok_embeddings = nn.Embedding(vocab_size, embed_dim) @@ -653,8 +662,7 @@ def lora_mistral_classifier( ) model = TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=(embed_dim // num_heads), diff --git a/torchtune/models/mistral/_tokenizer.py b/torchtune/models/mistral/_tokenizer.py index c3bbc8a4a7..49617220c3 100644 --- a/torchtune/models/mistral/_tokenizer.py +++ b/torchtune/models/mistral/_tokenizer.py @@ -8,12 +8,12 @@ from torchtune.data import Message, PromptTemplate from torchtune.models.mistral._prompt_template import MistralChatTemplate -from torchtune.modules.tokenizers import ( +from torchtune.modules.transforms import Transform +from torchtune.modules.transforms.tokenizers import ( ModelTokenizer, SentencePieceBaseTokenizer, tokenize_messages_no_special_tokens, ) -from torchtune.modules.transforms import Transform WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] diff --git a/torchtune/models/phi3/_component_builders.py b/torchtune/models/phi3/_component_builders.py index 49c0f3ca84..e961825d79 100644 --- a/torchtune/models/phi3/_component_builders.py +++ b/torchtune/models/phi3/_component_builders.py @@ -70,33 +70,35 @@ def phi3( num_kv_heads = num_kv_heads if num_kv_heads else num_heads rope = Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) - self_attn = MultiHeadAttention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), - k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - output_proj=nn.Linear(embed_dim, embed_dim, bias=False), - pos_embeddings=rope, - kv_cache=None, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - ) - mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) tok_embeddings = nn.Embedding(vocab_size, embed_dim) output_proj = nn.Linear(embed_dim, vocab_size, bias=False) return TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=head_dim, @@ -183,41 +185,43 @@ def lora_phi3( a subset of the attention projections in each layer. """ - - self_attn = lora_phi3_self_attention( - lora_modules=lora_attn_modules, - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - rope_base=rope_base, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - use_dora=use_dora, - quantize_base=quantize_base, - ) - - if apply_lora_to_mlp: - mlp = lora_phi3_mlp( - dim=embed_dim, - hidden_dim=intermediate_dim, + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = lora_phi3_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, ) - else: - mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) - - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + + if apply_lora_to_mlp: + mlp = lora_phi3_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + else: + mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) tok_embeddings = nn.Embedding(vocab_size, embed_dim) @@ -230,8 +234,7 @@ def lora_phi3( ) model = TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=(embed_dim // num_heads), diff --git a/torchtune/models/phi3/_model_builders.py b/torchtune/models/phi3/_model_builders.py index 91d42623d7..e1275df783 100644 --- a/torchtune/models/phi3/_model_builders.py +++ b/torchtune/models/phi3/_model_builders.py @@ -6,7 +6,7 @@ from torchtune.modules import TransformerDecoder from torchtune.modules.peft import LORA_ATTN_MODULES from functools import partial -from torchtune.modules.tokenizers import parse_hf_tokenizer_json +from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json from torchtune.data._prompt_templates import _TemplateType from torchtune.data._prompt_templates import _get_prompt_template diff --git a/torchtune/models/phi3/_tokenizer.py b/torchtune/models/phi3/_tokenizer.py index b48b1d93a3..44f66b5934 100644 --- a/torchtune/models/phi3/_tokenizer.py +++ b/torchtune/models/phi3/_tokenizer.py @@ -9,8 +9,11 @@ from torchtune.data._messages import Message from torchtune.data._prompt_templates import PromptTemplate from torchtune.data._utils import truncate -from torchtune.modules.tokenizers import ModelTokenizer, SentencePieceBaseTokenizer from torchtune.modules.transforms import Transform +from torchtune.modules.transforms.tokenizers import ( + ModelTokenizer, + SentencePieceBaseTokenizer, +) PHI3_SPECIAL_TOKENS = { "<|endoftext|>": 32000, @@ -157,6 +160,7 @@ def tokenize_messages( Raises: ValueError: If the role is not "user", "assistant", or "system". + RuntimeError: If ``message["type"] != "text``. Returns: Tuple[List[int], List[bool]]: The tokenized messages diff --git a/torchtune/models/qwen2/_component_builders.py b/torchtune/models/qwen2/_component_builders.py index 716fe337ad..c24683cfea 100644 --- a/torchtune/models/qwen2/_component_builders.py +++ b/torchtune/models/qwen2/_component_builders.py @@ -82,27 +82,32 @@ def qwen2( num_kv_heads = num_kv_heads if num_kv_heads else num_heads rope = Qwen2RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) - self_attn = MultiHeadAttention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=True), - k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True), - v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True), - output_proj=nn.Linear(embed_dim, embed_dim, bias=False), - pos_embeddings=rope, - kv_cache=None, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - ) - mlp = qwen2_mlp(dim=embed_dim, hidden_dim=intermediate_dim) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=True), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = qwen2_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) if tie_word_embeddings: output_proj = TiedLinear(tok_embeddings) @@ -110,8 +115,7 @@ def qwen2( output_proj = nn.Linear(embed_dim, vocab_size, bias=False) return TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=head_dim, @@ -199,41 +203,43 @@ def lora_qwen2( ValueError: if ``apply_lora_to_output`` and ``tie_word_embeddings``. """ - - self_attn = lora_qwen2_self_attention( - lora_modules=lora_attn_modules, - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - rope_base=rope_base, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - use_dora=use_dora, - quantize_base=quantize_base, - ) - - if apply_lora_to_mlp: - mlp = lora_qwen2_mlp( - dim=embed_dim, - hidden_dim=intermediate_dim, + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = lora_qwen2_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + rope_base=rope_base, lora_rank=lora_rank, lora_alpha=lora_alpha, - quantize_base=quantize_base, - use_dora=use_dora, lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, ) - else: - mlp = qwen2_mlp(dim=embed_dim, hidden_dim=intermediate_dim) - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), - ) + if apply_lora_to_mlp: + mlp = lora_qwen2_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + use_dora=use_dora, + lora_dropout=lora_dropout, + ) + else: + mlp = qwen2_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) tok_embeddings = nn.Embedding(vocab_size, embed_dim) @@ -254,8 +260,7 @@ def lora_qwen2( ) model = TransformerDecoder( tok_embeddings=tok_embeddings, - layers=layer, - num_layers=num_layers, + layers=layers, max_seq_len=max_seq_len, num_heads=num_heads, head_dim=(embed_dim // num_heads), diff --git a/torchtune/models/qwen2/_model_builders.py b/torchtune/models/qwen2/_model_builders.py index 2a0ee06f83..f1ca5b8506 100644 --- a/torchtune/models/qwen2/_model_builders.py +++ b/torchtune/models/qwen2/_model_builders.py @@ -11,7 +11,7 @@ from torchtune.models.qwen2._tokenizer import QWEN2_SPECIAL_TOKENS, Qwen2Tokenizer from torchtune.modules import TransformerDecoder from torchtune.modules.peft import LORA_ATTN_MODULES -from torchtune.modules.tokenizers import parse_hf_tokenizer_json +from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json """ Model builders build specific instantiations using component builders. For example diff --git a/torchtune/models/qwen2/_tokenizer.py b/torchtune/models/qwen2/_tokenizer.py index 0e4ee6bd35..dd6d038003 100644 --- a/torchtune/models/qwen2/_tokenizer.py +++ b/torchtune/models/qwen2/_tokenizer.py @@ -11,7 +11,7 @@ import regex as re from torchtune.data import ChatMLTemplate, Message, PromptTemplate, truncate -from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms.tokenizers import ModelTokenizer PRETOKENIZE_REGEX = ( r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|" diff --git a/torchtune/models/qwen2_5/_model_builders.py b/torchtune/models/qwen2_5/_model_builders.py index 7d39802375..716ae48329 100644 --- a/torchtune/models/qwen2_5/_model_builders.py +++ b/torchtune/models/qwen2_5/_model_builders.py @@ -11,7 +11,7 @@ from torchtune.models.qwen2_5._tokenizer import QWEN2_5_SPECIAL_TOKENS, Qwen2_5Tokenizer from torchtune.modules import TransformerDecoder from torchtune.modules.peft import LORA_ATTN_MODULES -from torchtune.modules.tokenizers import parse_hf_tokenizer_json +from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json """ Model builders build specific instantiations using component builders. For example diff --git a/torchtune/models/t5/_component_builders.py b/torchtune/models/t5/_component_builders.py index 4867b5036f..4e093ea42c 100644 --- a/torchtune/models/t5/_component_builders.py +++ b/torchtune/models/t5/_component_builders.py @@ -51,37 +51,39 @@ def t5_encoder( """ token_embedding = nn.Embedding(vocab_size, embed_dim) - attn = T5EncoderSelfAttention( - embed_dim=embed_dim, - num_heads=num_heads, - head_dim=head_dim, - q_proj=nn.Linear(embed_dim, embed_dim, bias=False), - k_proj=nn.Linear(embed_dim, embed_dim, bias=False), - v_proj=nn.Linear(embed_dim, embed_dim, bias=False), - output_proj=nn.Linear(embed_dim, embed_dim, bias=False), - ) + layers = nn.ModuleList() + for _ in range(num_layers): + attn = T5EncoderSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, embed_dim, bias=False), + k_proj=nn.Linear(embed_dim, embed_dim, bias=False), + v_proj=nn.Linear(embed_dim, embed_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + ) - mlp = FeedForward( - gate_proj=nn.Linear(embed_dim, mlp_dim, bias=False), - down_proj=nn.Linear(mlp_dim, embed_dim, bias=False), - up_proj=nn.Linear(embed_dim, mlp_dim, bias=False), - activation=nn.GELU(), - ) + mlp = FeedForward( + gate_proj=nn.Linear(embed_dim, mlp_dim, bias=False), + down_proj=nn.Linear(mlp_dim, embed_dim, bias=False), + up_proj=nn.Linear(embed_dim, mlp_dim, bias=False), + activation=nn.GELU(), + ) - layer = T5EncoderLayer( - attn=attn, - mlp=mlp, - sa_norm=RMSNorm(embed_dim, eps=norm_eps), - mlp_norm=RMSNorm(embed_dim, eps=norm_eps), - ) + layer = T5EncoderLayer( + attn=attn, + mlp=mlp, + sa_norm=RMSNorm(embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(embed_dim, eps=norm_eps), + ) + layers.append(layer) final_norm = RMSNorm(embed_dim, eps=norm_eps) return T5Encoder( token_embedding=token_embedding, - layer=layer, + layers=layers, final_norm=final_norm, - num_layers=num_layers, num_heads=num_heads, rel_pos_num_buckets=rel_pos_num_buckets, rel_pos_max_dist=rel_pos_max_dist, diff --git a/torchtune/models/t5/_encoder.py b/torchtune/models/t5/_encoder.py index 7828e9ecc5..71dd0c5a8d 100644 --- a/torchtune/models/t5/_encoder.py +++ b/torchtune/models/t5/_encoder.py @@ -3,14 +3,15 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import copy + import math +from typing import List, Optional, Union import torch import torch.nn.functional as F from torch import nn, Tensor - from torchtune.modules import MultiHeadAttention +from torchtune.modules.transformer import _get_clones class T5Encoder(nn.Module): @@ -21,9 +22,8 @@ class T5Encoder(nn.Module): Args: token_embedding (nn.Embedding): PyTorch embedding layer to place tokens in an embedding space. - layer (nn.Module): A single encoder layer. + layers (Union[nn.Module, List[nn.Module], nn.ModuleList]): A single encoder layer. final_norm (nn.Module): Module that applies normalization to the output of the encoder - num_layers (int): Number of encoder layers. num_heads (int): The number of attention heads. rel_pos_num_buckets (int): Number of discrete buckets to divide the relative positions into. See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias` @@ -31,23 +31,29 @@ class T5Encoder(nn.Module): Distances beyond this are grouped into the last bucket. See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias` max_seq_len (int): The maximum sequence length (context length) of the model. + num_layers (Optional[int]): Number of encoder layers, only define when layers is not a list. + + Raises: + AssertionError: + If ``num_layers`` is set and layer is a list, **or** + ``num_layers`` is not set and layer is an ``nn.Module``. + """ def __init__( self, *, token_embedding: nn.Embedding, - layer: nn.Module, + layers: Union[nn.Module, List[nn.Module], nn.ModuleList], final_norm: nn.Module, - num_layers: int, num_heads: int, rel_pos_num_buckets: int, rel_pos_max_dist: int, max_seq_len: int, - ): + num_layers: Optional[int] = None, + ) -> None: super().__init__() self.token_embedding = token_embedding - self.layers = nn.ModuleList([copy.deepcopy(layer) for i in range(num_layers)]) self.final_norm = final_norm self.max_seq_len = max_seq_len self.relative_position_bias = T5EncoderRelativePositionBias( @@ -57,6 +63,18 @@ def __init__( max_seq_len=max_seq_len, ) + self.layers = None + if isinstance(layers, nn.ModuleList): + self.layers = layers + elif isinstance(layers, list): + self.layers = nn.ModuleList(layers) + else: + if not isinstance(layers, nn.Module): + raise AssertionError("num_layers is defined, layers must be a module") + if num_layers is None: + raise AssertionError("num_layers is not defined, layers must be a list") + self.layers = _get_clones(layers, num_layers) + def forward(self, tokens: Tensor) -> Tensor: """ Args: @@ -147,8 +165,9 @@ class T5EncoderSelfAttention(nn.Module): output_proj (nn.Module): Projection layer for output. Raises: - ValueError: If ``num_heads % num_kv_heads != 0`` - ValueError: If ``embed_dim // num_heads != head_dim`` + ValueError: + If ``embed_dim % num_heads != 0``, **or** + if ``embed_dim // num_heads != head_dim`` """ def __init__( diff --git a/torchtune/models/t5/_tokenizer.py b/torchtune/models/t5/_tokenizer.py index f89dff00f6..e4fa9c539e 100644 --- a/torchtune/models/t5/_tokenizer.py +++ b/torchtune/models/t5/_tokenizer.py @@ -5,7 +5,9 @@ # LICENSE file in the root directory of this source tree. from typing import Any, Dict, List -from torchtune.modules.tokenizers._sentencepiece import SentencePieceBaseTokenizer +from torchtune.modules.transforms.tokenizers._sentencepiece import ( + SentencePieceBaseTokenizer, +) class T5Tokenizer(SentencePieceBaseTokenizer): diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index f4cfa142f7..ff6faccb5d 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -73,10 +73,11 @@ class MultiHeadAttention(nn.Module): Default value is 0.0. Raises: - ValueError: If ``num_heads % num_kv_heads != 0`` - ValueError: If ``embed_dim % num_heads != 0`` - ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` - ValueError: if q_norm is defined without k_norm or vice versa + ValueError: + If ``num_heads % num_kv_heads != 0``, **or** + if ``embed_dim % num_heads != 0``, **or** + if ``attn_dropout < 0`` or ``attn_dropout > 1``, **or** + if q_norm is defined without k_norm or vice versa """ def __init__( diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index e96491c22a..3d72e87adc 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -84,9 +84,13 @@ def update( Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively. Raises: - AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. ValueError: if the batch size of the new key (or value) tensor is greater than the batch size used during cache setup. + + Note: + This function will raise an ``AssertionError`` if the sequence length of ``k_val`` + is longer than the maximum cache sequence length. + """ bsz, _, seq_len, _ = k_val.shape if bsz > self.k_cache.shape[0]: @@ -109,6 +113,6 @@ def update( # this allows us to track the current position in the cache # after the last update in a compile-friendly way without any dynamism # e.g. relying on an int size tracker, or re-creating cache_pos every time - self.cache_pos += seq_len + self.cache_pos.add_(seq_len) return k_out, v_out diff --git a/torchtune/modules/loss/kd_losses.py b/torchtune/modules/loss/kd_losses.py index 3c8c9d1153..c3d48af8f0 100644 --- a/torchtune/modules/loss/kd_losses.py +++ b/torchtune/modules/loss/kd_losses.py @@ -54,7 +54,9 @@ def forward( mask = (labels != self.ignore_index).int() if not normalize: return -torch.sum(x * mask.view(-1), dim=0) - if torch.sum(mask.view(-1), dim=0) == 0: + + sum_masks = torch.sum(mask.view(-1), dim=0) + if sum_masks == 0: return torch.tensor(0.0, device=x.device) return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0) @@ -245,6 +247,10 @@ def forward( student_chunk, teacher_chunk, label_chunk, normalize=False ) + sum_masks = torch.sum(mask.view(-1), dim=0) + if sum_masks == 0: + return torch.tensor(0.0, device=student_logits[0].device) + return total_fkl_loss / torch.sum(mask.view(-1), dim=0) diff --git a/torchtune/modules/model_fusion/_early_fusion.py b/torchtune/modules/model_fusion/_early_fusion.py index d20b2d119f..7b41e896ad 100644 --- a/torchtune/modules/model_fusion/_early_fusion.py +++ b/torchtune/modules/model_fusion/_early_fusion.py @@ -130,23 +130,28 @@ def __init__( set_trainable_params(self, trainable_params) @staticmethod - def _state_dict_hook(module, state_dict, *args, **kwargs): + def _state_dict_hook(module, state_dict, prefix, *args, **kwargs): """ Keep tok_embeddings inside of decoder state_dict [!Note] This update changes the order of the OrderedDict """ for n, p in module.tok_embeddings.named_parameters(): - state_dict[f"decoder.tok_embeddings.{n}"] = p - del state_dict[f"tok_embeddings.{n}"] + orig_key = f"{prefix}tok_embeddings.{n}" + if orig_key in state_dict: + # preserve the original tensor with its requires_grad state + state_dict[f"{prefix}decoder.tok_embeddings.{n}"] = state_dict[orig_key] + del state_dict[orig_key] @staticmethod - def _load_state_dict_hook(module, state_dict, *args, **kwargs): + def _load_state_dict_hook(module, state_dict, prefix, *args, **kwargs): """Undo the change from _state_dict_hook""" old_keys = list(state_dict.keys()) for key in old_keys: - if key.startswith("decoder.tok_embeddings"): - state_dict[key[len("decoder.") :]] = state_dict[key] + if "decoder.tok_embeddings" in key: + state_dict[prefix + key[len("decoder.") + len(prefix) :]] = state_dict[ + key + ] del state_dict[key] def set_num_output_chunks(self, num_output_chunks: int) -> None: @@ -185,7 +190,9 @@ def reset_caches(self): def _decoder_embed(self, tokens) -> Tuple[torch.Tensor, torch.Tensor]: """Embed the text-only tokens with the decoder's tok_embeddings""" - encoder_token_ids = torch.tensor(list(self.encoder_tokens.values())) + encoder_token_ids = torch.tensor( + list(self.encoder_tokens.values()), device=tokens.device + ) # [bsz, seq_len], True indicates the token is not an encoder special token is_text = ~torch.isin(tokens, encoder_token_ids) text_tokens = torch.masked_select(tokens, is_text) diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index 1d0f1047b6..6d8b4c8b13 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -9,7 +9,6 @@ import torch from torch import nn -from torchtune.utils._logging import deprecated # Modules from MultiHeadAttention that LoRA can be applied to LORA_ATTN_MODULES = Literal["q_proj", "k_proj", "v_proj", "output_proj"] @@ -285,10 +284,11 @@ def validate_missing_and_unexpected_for_lora( None Raises: - AssertionError: if base_missing contains any base model keys. - AssertionError: if base_unexpected is nonempty. - AssertionError: if lora_missing contains any LoRA keys. - AssertionError: if lora_unexpected is nonempty. + AssertionError: + If base_missing contains any base model keys, **or** + if base_unexpected is nonempty, **or** + if lora_missing contains any LoRA keys, **or** + if lora_unexpected is nonempty. """ lora_modules = get_lora_module_names( lora_attn_modules, apply_lora_to_mlp, apply_lora_to_output @@ -312,17 +312,3 @@ def validate_missing_and_unexpected_for_lora( raise AssertionError(f"Missing LoRA key {k} from adapter state dict") if lora_unexpected: raise AssertionError("Unexpected key loading adapter") - - -@deprecated( - msg="load_dora_magnitudes will be deprecated in 0.6.0. Please use DoRALinear.initialize_dora_magnitude instead." -) -def load_dora_magnitudes(model: nn.Module) -> None: - """ - For DoRA magnitude we use setattr to move from meta device - """ - dora_parents = { - n: p for n, p in model.named_modules() if hasattr(p, "adapter_params") - } - sd = {f"{n}.magnitude": p.magnitude for n, p in dora_parents.items()} - model.load_state_dict(sd, strict=False, assign=True) diff --git a/torchtune/modules/position_embeddings.py b/torchtune/modules/position_embeddings.py index 5f07772d82..197295a328 100644 --- a/torchtune/modules/position_embeddings.py +++ b/torchtune/modules/position_embeddings.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +from typing import Any, Optional import torch from torch import nn @@ -127,7 +127,9 @@ class VisionRotaryPositionalEmbeddings(nn.Module): This class implements two-dimensional Rotary Positional Embeddings (RoPE) for images based on the axial frequency 2D RoPE described in https://arxiv.org/pdf/2403.13298. - The position embedding is simply applied to the x-axis and y-axis separately. + The position embedding is simply applied to the x-axis and y-axis separately, encoding + the x and y position of each patch within every tile.. The embedding is applied to each + tile identically. Note: This module assumes the CLS token embedding is appended at the end of the sequence. @@ -136,6 +138,8 @@ class VisionRotaryPositionalEmbeddings(nn.Module): E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches. tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, the size of the full input image. In this case, the function will consider your image as a single tile. + max_num_tiles (int): The maximum number of tiles in the image. This is used to unfold the input sequence + length into sequence length per tile so RoPE can be applied to each tile separately. dim (int): Embedding dimension. Unlike :class:`~torchtune.modules.RotaryPositionalEmbeddings`, this is usually set to the dim of each head in the attention module divided by 2, computed as ``embed_dim // num_heads // 2``. The divide by 2 accounts for x and y positions. @@ -149,12 +153,14 @@ def __init__( self, patch_size: int, tile_size: int, + max_num_tiles: int, dim: int, base: int = 10_000, append_cls_token: bool = True, ) -> None: super().__init__() self.patch_grid_size = tile_size // patch_size + self.max_num_tiles = max_num_tiles self.dim = dim self.base = base self.append_cls_token = append_cls_token @@ -209,46 +215,46 @@ def build_rope_cache(self) -> None: self.register_buffer("cache", cache, persistent=False) def forward( - self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None + self, + x: torch.Tensor, + **kwargs: Any, ) -> torch.Tensor: """ Args: - x (torch.Tensor): input tensor with shape - ``[b, s, n_h, h_d]`` - input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids - of each token. During training, this is used to indicate the positions - of each token relative to its sample when packed, shape [b, s]. - During inference, this indicates the position of the current token. - If none, assume the index of the token is its position id. Default is None. + x (torch.Tensor): input tensor with shape ``[b, s, n_h, h_d]`` + **kwargs (Any): additional keyword arguments. This is kept to match the forward signature of + :class:`~torchtune.modules.RotaryPositionalEmbeddings`. Returns: torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]`` + Raises: + ValueError: if sequence length of input tensor does not match the 2D RoPE cache's sequence length + Notation used for tensor shapes: - b: batch size - s: sequence length - n_h: num heads - h_d: head dim """ - # input tensor has shape [b, s, n_h, h_d] - seq_len = x.size(1) - - # extract the values based on whether input_pos is set or not - rope_cache = ( - self.cache[:seq_len] if input_pos is None else self.cache[input_pos] - ) + bsz, _, n_h, h_d = x.shape # reshape input; the last dimension is used for computing the output. + # Split tile dimension from the sequence dimension # Cast to float to match the reference implementation - # tensor has shape [b, s, n_h, h_d // 2, 2] - xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + # tensor has shape [b, max_num_tiles, s // max_num_tiles, n_h, h_d // 2, 2] + xshaped = x.float().reshape(bsz, self.max_num_tiles, -1, n_h, h_d // 2, 2) + seq_len = xshaped.size(2) + + if seq_len != self.cache.shape[0]: + raise ValueError( + f"Input sequence length {seq_len} does not match 2D RoPE cache sequence length {self.cache.shape[0]}." + ) # reshape the cache for broadcasting - # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, - # otherwise has shape [1, s, 1, h_d // 2, 2] - rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) + rope_cache = self.cache.view(1, 1, seq_len, 1, h_d // 2, 2) - # tensor has shape [b, s, n_h, h_d // 2, 2] + # tensor has shape [b, max_num_tiles, s // max_num_tiles, n_h, h_d // 2, 2] x_out = torch.stack( [ xshaped[..., 0] * rope_cache[..., 0] @@ -259,6 +265,6 @@ def forward( -1, ) - # tensor has shape [b, s, n_h, h_d] - x_out = x_out.flatten(3) + # Squash tile dimension back into sequence dimension - tensor has shape [b, s, n_h, h_d] + x_out = x_out.reshape(bsz, self.max_num_tiles * seq_len, n_h, h_d) return x_out.type_as(x) diff --git a/torchtune/modules/tokenizers/__init__.py b/torchtune/modules/tokenizers/__init__.py index 2fecc279ee..f10a9b3dd6 100644 --- a/torchtune/modules/tokenizers/__init__.py +++ b/torchtune/modules/tokenizers/__init__.py @@ -4,20 +4,28 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from ._sentencepiece import SentencePieceBaseTokenizer -from ._tiktoken import TikTokenBaseTokenizer -from ._utils import ( +# flake8: noqa: F401 + +# NOTE: This file is maintained for backward compatibility purposes. +# The imports below point to the new location in `torchtune.modules.transforms.tokenizers`. +# The import paths will be removed in v0.7. Please update your code to use the new path +# (torchtune.modules.transforms.tokenizers) to avoid breaking changes in future releases. + + +import warnings + +from torchtune.modules.transforms.tokenizers import ( BaseTokenizer, ModelTokenizer, parse_hf_tokenizer_json, + SentencePieceBaseTokenizer, + TikTokenBaseTokenizer, tokenize_messages_no_special_tokens, ) -__all__ = [ - "SentencePieceBaseTokenizer", - "TikTokenBaseTokenizer", - "ModelTokenizer", - "BaseTokenizer", - "tokenize_messages_no_special_tokens", - "parse_hf_tokenizer_json", -] +warnings.warn( + "The import path 'torchtune.modules.tokenizers' is deprecated and will be removed in v0.7. " + "Please update your imports to 'torchtune.modules.transforms.tokenizers'.", + DeprecationWarning, + stacklevel=2, +) diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 66ac92002f..fa4213b659 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -344,8 +344,9 @@ class TransformerDecoder(nn.Module): output_hidden_states (Optional[List[int]]): List of layers (indices) to include in the output Raises: - AssertionError: num_layers is set and layer is a list - AssertionError: num_layers is not set and layer is an nn.Module + AssertionError: + If ``num_layers`` is set and layer is a list, **or** + ``num_layers`` is not set and layer is an ``nn.Module``. Note: Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1]) @@ -418,12 +419,11 @@ def setup_caches( encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length. decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length. """ - has_encoder_layers = any( isinstance(m, TransformerCrossAttentionLayer) for m in self.modules() ) has_decoder_layers = any( - isinstance(l, TransformerSelfAttentionLayer) for l in self.layers + isinstance(m, TransformerSelfAttentionLayer) for m in self.modules() ) if has_encoder_layers: @@ -437,7 +437,6 @@ def setup_caches( self.decoder_max_cache_seq_len = decoder_max_seq_len else: self.decoder_max_cache_seq_len = self.max_seq_len - for layer in self.layers: layer.setup_caches( batch_size, @@ -519,10 +518,11 @@ def _validate_inputs( input_pos (Optional[torch.Tensor]): Input tensor position IDs. Raises: - ValueError: if seq_len of x is bigger than max_seq_len - ValueError: if the model has caches which have been setup with self-attention layers and ``mask`` is not provided. - ValueError: if the model has caches which have been setup with encoder layers and ``encoder_mask`` is not provided. - ValueError: if the model has caches which have been setup ``input_pos`` is not provided. + ValueError: + If seq_len of x is bigger than max_seq_len, **or** + if the model has caches which have been setup with self-attention layers and ``mask`` is not provided, **or** + if the model has caches which have been setup with encoder layers and ``encoder_mask`` is not provided, **or** + if the model has caches which have been setup ``input_pos`` is not provided. """ if seq_len > self.max_seq_len: diff --git a/torchtune/modules/transforms/tokenizers/__init__.py b/torchtune/modules/transforms/tokenizers/__init__.py new file mode 100644 index 0000000000..2fecc279ee --- /dev/null +++ b/torchtune/modules/transforms/tokenizers/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._sentencepiece import SentencePieceBaseTokenizer +from ._tiktoken import TikTokenBaseTokenizer +from ._utils import ( + BaseTokenizer, + ModelTokenizer, + parse_hf_tokenizer_json, + tokenize_messages_no_special_tokens, +) + +__all__ = [ + "SentencePieceBaseTokenizer", + "TikTokenBaseTokenizer", + "ModelTokenizer", + "BaseTokenizer", + "tokenize_messages_no_special_tokens", + "parse_hf_tokenizer_json", +] diff --git a/torchtune/modules/tokenizers/_sentencepiece.py b/torchtune/modules/transforms/tokenizers/_sentencepiece.py similarity index 98% rename from torchtune/modules/tokenizers/_sentencepiece.py rename to torchtune/modules/transforms/tokenizers/_sentencepiece.py index 0b22b63ee3..8d98617378 100644 --- a/torchtune/modules/tokenizers/_sentencepiece.py +++ b/torchtune/modules/transforms/tokenizers/_sentencepiece.py @@ -7,8 +7,7 @@ from typing import List, Optional from sentencepiece import SentencePieceProcessor - -from torchtune.modules.tokenizers._utils import BaseTokenizer +from torchtune.modules.transforms.tokenizers._utils import BaseTokenizer WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] diff --git a/torchtune/modules/tokenizers/_tiktoken.py b/torchtune/modules/transforms/tokenizers/_tiktoken.py similarity index 98% rename from torchtune/modules/tokenizers/_tiktoken.py rename to torchtune/modules/transforms/tokenizers/_tiktoken.py index 077b22b0cd..64733b4634 100644 --- a/torchtune/modules/tokenizers/_tiktoken.py +++ b/torchtune/modules/transforms/tokenizers/_tiktoken.py @@ -8,7 +8,7 @@ from tiktoken import Encoding from tiktoken.load import load_tiktoken_bpe -from torchtune.modules.tokenizers._utils import BaseTokenizer +from torchtune.modules.transforms.tokenizers._utils import BaseTokenizer # Constants controlling encode logic MAX_ENCODE_CHARS = 400_000 diff --git a/torchtune/modules/tokenizers/_utils.py b/torchtune/modules/transforms/tokenizers/_utils.py similarity index 97% rename from torchtune/modules/tokenizers/_utils.py rename to torchtune/modules/transforms/tokenizers/_utils.py index b580eda1c0..ff374738c7 100644 --- a/torchtune/modules/tokenizers/_utils.py +++ b/torchtune/modules/transforms/tokenizers/_utils.py @@ -14,8 +14,8 @@ class BaseTokenizer(Protocol): """ Abstract token encoding model that implements ``encode`` and ``decode`` methods. - See :class:`~torchtune.modules.tokenizers.SentencePieceBaseTokenizer` and - :class:`~torchtune.modules.tokenizers.TikTokenBaseTokenizer` for example implementations of this protocol. + See :class:`~torchtune.modules.transforms.tokenizers.SentencePieceBaseTokenizer` and + :class:`~torchtune.modules.transforms.tokenizers.TikTokenBaseTokenizer` for example implementations of this protocol. """ def encode(self, text: str, **kwargs: Dict[str, Any]) -> List[int]: diff --git a/torchtune/modules/vision_transformer.py b/torchtune/modules/vision_transformer.py index d44d0c930f..8ba14dfe45 100644 --- a/torchtune/modules/vision_transformer.py +++ b/torchtune/modules/vision_transformer.py @@ -190,9 +190,10 @@ class VisionTransformer(nn.Module): Default is False, which adds CLS token to the beginning of the sequence. Raises: - ValueError: If `tile_size` is not greater than 0. - ValueError: If `patch_size` is not greater than 0. - ValueError: If `len(out_indices)` is greater than `num_layers`. + ValueError: + If `tile_size` is not greater than 0, **or** + if `patch_size` is not greater than 0, **or** + if `len(out_indices)` is greater than `num_layers`. """ def __init__( diff --git a/torchtune/rlhf/loss/__init__.py b/torchtune/rlhf/loss/__init__.py index 5c4b649587..4058979f4a 100644 --- a/torchtune/rlhf/loss/__init__.py +++ b/torchtune/rlhf/loss/__init__.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. -from .dpo import DPOLoss, RSOLoss, SimPOLoss +from .dpo import DPOLoss, RSOLoss from .ppo import PPOLoss -__all__ = ["DPOLoss", "RSOLoss", "SimPOLoss", "PPOLoss"] +__all__ = ["DPOLoss", "RSOLoss", "PPOLoss"] diff --git a/torchtune/rlhf/loss/dpo.py b/torchtune/rlhf/loss/dpo.py index b19e0d93ca..c09d36a261 100644 --- a/torchtune/rlhf/loss/dpo.py +++ b/torchtune/rlhf/loss/dpo.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchtune.utils._logging import deprecated class DPOLoss(nn.Module): @@ -159,77 +158,3 @@ def forward( ) return losses, chosen_rewards, rejected_rewards - - -@deprecated(msg="SimPOLoss will be deprecated in an upcoming release.") -class SimPOLoss(nn.Module): - """ - SimPO: Simple Preference Optimization with a Reference-Free Reward: https://arxiv.org/abs/2405.14734. - Intuition from the paper: - - The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as - the implicit reward. Additionally, we introduce a target reward margin to the Bradley-Terry objective to - encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance. - - Based on the TRL implementation: - https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/cpo_trainer.py#L603 - - SimPO is pretty much identitcal to DPO but uses average logprobs to eliminate the need for a reference model to regularize - the policy during training. It also uses a target reward margin to guide the policy towards better responses. - This is kind of the same intuition as in :class:`~torchtune.rlhf.loss.IPOLoss`, but instead of optimizing against - a margin between the reference policy and policy models, we're optimizing against a margin between the chosen and - rejected responses. - - Args: - beta (float): Equivalent temperature scaling parameter to DPO loss, typically in the range of 2.0 to 2.5. Default is 2.0. - gamma (float): Target reward margin hyperparameter, typically we have ``gamma in (0, 1.5]``. - Default is 0.5. - label_smoothing (float): Parameter encoding uncertainty about the labels. Default is 0. - """ - - def __init__( - self, - beta: float = 2.0, - gamma: float = 0.5, - label_smoothing: float = 0.0, - ): - super().__init__() - self.beta = beta - self.gamma = gamma - self.label_smoothing = label_smoothing - - def forward( - self, - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Compute the SimPO loss for a batch chosen and rejected average log probabilities. - - Args: - policy_chosen_logps (torch.Tensor): Average log probabilities of the policy model - for the chosen responses with shape [b,]. - policy_rejected_logps (torch.Tensor): Average log probabilities of the policy model - for the rejected responses with shape [b,]. - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]; A tuple of three tensors with shape [b,]: - - losses: The SimPO loss for each example in the batch. - - chosen_rewards: Rewards for the chosen responses. - - rejected_rewards: Rewards for the rejected responses. - """ - - pi_logratios = policy_chosen_logps - policy_rejected_logps - - gamma_logratios = self.gamma / self.beta - logits = pi_logratios - gamma_logratios - - losses = ( - -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - - F.logsigmoid(-self.beta * logits) * self.label_smoothing - ) - - chosen_rewards = self.beta * (policy_chosen_logps).detach() - rejected_rewards = self.beta * (policy_rejected_logps).detach() - - return losses, chosen_rewards, rejected_rewards diff --git a/torchtune/rlhf/loss/ppo.py b/torchtune/rlhf/loss/ppo.py index d4770802f7..2c3d48a0e8 100644 --- a/torchtune/rlhf/loss/ppo.py +++ b/torchtune/rlhf/loss/ppo.py @@ -82,7 +82,9 @@ def forward( policy_losses_clipped = -advantages * clipped_ratios policy_losses_unclipped = -advantages * ratios - clipfrac = (policy_losses_clipped > policy_losses_unclipped).float() + clipfrac = (policy_losses_clipped > policy_losses_unclipped).to( + pi_logprobs.dtype + ) clipfrac = ( clipfrac.mean() if padding_masks is None diff --git a/torchtune/rlhf/rewards.py b/torchtune/rlhf/rewards.py index f0e42ca58c..f5882908bc 100644 --- a/torchtune/rlhf/rewards.py +++ b/torchtune/rlhf/rewards.py @@ -76,10 +76,6 @@ def get_rewards_ppo( - response_len: model response length """ - # 1. calculate kl between logprobs and reflogprobs - # 2. calculate kl reward using adaptive scaling value - # 3. calculate total reward by summing above - # return all kl = logprobs - ref_logprobs kl_reward = -kl_coeff * kl @@ -89,9 +85,9 @@ def get_rewards_ppo( # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L153 if valid_score_idxs is not None: - total_reward[ - torch.arange(scores.shape[0], device=scores.device), valid_score_idxs - ] += scores + total_reward.scatter_add_( + 1, valid_score_idxs.unsqueeze(-1), scores.unsqueeze(-1) + ) else: total_reward[:, -1] += scores @@ -113,17 +109,17 @@ def masked_mean( Returns: torch.Tensor: The mean tensor. """ - return (x * mask).sum(dim=dim) / mask.sum(dim=dim) + return (x * mask).sum(dim=dim) / (mask.sum(dim=dim) + 1e-8) def masked_var( - x: torch.Tensor, mask: torch.Tensor, unbiased: bool = True + centered_values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True ) -> torch.Tensor: """ - Compute variance of tensor with masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py - + Compute variance of mean-centered tensor with masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py + We use ``centered_values`` to avoid repeated calls to ``masked_mean``. Args: - x (torch.Tensor): The input tensor. + centered_values (torch.Tensor): The mean-centered tensor e.g. ``x - masked_mean(x)``. mask (torch.Tensor): The bool mask tensor, where True indicates the corresponding value in ``x`` should participate in the mean calculation. unbiased (bool): Whether to use the unbiased variance. @@ -131,21 +127,10 @@ def masked_var( Returns: torch.Tensor: The variance tensor. - Raises: - ValueError: If the sum of the mask is zero. """ - mean = masked_mean(x, mask) - centered_values = x - mean var = masked_mean(centered_values.pow(2), mask) if unbiased: - mask_sum = mask.sum() - if mask_sum == 0: - raise ValueError( - "The sum of the mask is zero, which can happen when ``ppo_batch_size=1``;" - "try increase the ``ppo_batch_size`` or ``gradient_accumulation_steps``" - ) - # note that if mask_sum == 1, then there is a division by zero issue - # to avoid it you just need to use a larger minibatch_size + mask_sum = mask.sum() + 1e-8 bessel_correction = mask_sum / (mask_sum - 1) var = var * bessel_correction return var @@ -158,16 +143,16 @@ def whiten( Whiten (normalises) values, optionally with masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py Args: x (torch.Tensor): The input tensor. - mask (Optional[torch.Tensor]): The bool mask tensor, where True indicates the corresponding value in ``x`` - should participate in the mean calculation. Default None. - shift_mean (bool): Whether to shift normalised values by the mean. + mask (Optional[torch.Tensor]): The bool mask tensor with the same shape as ``x``, and where True indicates + the corresponding value in ``x`` should participate in the mean calculation. Default None. + shift_mean (bool): Whether to shift normalised values by the mean. Default True. Returns: torch.Tensor: The whitened tensor. """ if mask is not None: mean = masked_mean(x, mask) - var = masked_var(x, mask) if mask.any() else x.var() + var = masked_var(x - mean, mask) else: mean, var = x.mean(), x.var() whitened = (x - mean) * torch.rsqrt(var + 1e-8) @@ -228,10 +213,8 @@ def estimate_advantages( returns = advantages + values # normalize advantages across the batch of trajectories to reduce variance + advantages = whiten(advantages, mask=masks) if masks is not None: - advantages = whiten(advantages, mask=masks) advantages[~masks] = 0.0 - else: - advantages = whiten(advantages) return advantages, returns diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index 9dd31246c3..d461d84dc4 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -18,6 +18,7 @@ is_distributed, load_from_full_model_state_dict, load_from_full_optimizer_state_dict, + prepare_mha_for_tp, set_torch_num_threads, shard_model, validate_no_params_on_meta_device, @@ -74,6 +75,7 @@ __all__ = [ "get_act_offloading_ctx_manager", + "prepare_mha_for_tp", "apply_selective_activation_checkpointing", "get_dtype", "set_default_dtype", diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py index a802ce98d8..8d32083b54 100644 --- a/torchtune/training/_activation_offloading.py +++ b/torchtune/training/_activation_offloading.py @@ -55,7 +55,6 @@ class OffloadActivations(saved_tensors_hooks): Raises: ValueError: if max_fwd_stash_size is not at least 1. - RuntimeError: if use_streams but torch installation is earlier than torch-2.5.0.dev20240907 Example: >>> with OffloadActivations(): diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 4001db768b..0e33bfd118 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -25,15 +25,17 @@ set_optimizer_state_dict, StateDictOptions, ) +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import ShardingStrategy from torch.nn.modules.module import _IncompatibleKeys from torch.optim import Optimizer from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 from torchtune.modules import TransformerDecoder +from torchtune.modules.attention import MultiHeadAttention +from torchtune.modules.model_fusion import DeepFusionModel from torchtune.modules.peft import get_adapter_state_dict from torchtune.utils import get_device, get_logger from torchtune.utils._logging import deprecated -from torchtune.utils._version import torch_version_ge _log: logging.Logger = get_logger() @@ -41,9 +43,11 @@ _valid_distributed_single_node_nnodes = ["1:1", "1"] torch_version = torch.__version__ -_DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = ( - "dev" not in torch_version and torch_version_ge("2.6.0") -) or ("dev" in torch_version and torch_version.split("dev")[1] >= "20241220") +# TODO: Fix issues with DSD before uncommenting. See #2313 and #2277. +# _DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = ( +# "dev" not in torch_version and torch_version_ge("2.6.0") +# ) or ("dev" in torch_version and torch_version.split("dev")[1] >= "20241220") +_DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = False def _get_sharding_strategy(strategy: str) -> ShardingStrategy: @@ -201,7 +205,7 @@ def load_from_full_model_state_dict( for param in model.parameters() ) meta_sharded_sd = model.state_dict() - # NF4Tensor is not supported in `set_model_state_dict` right now, running with the privious logic right + # NF4Tensor is not supported in `set_model_state_dict` right now, running with the previous logic right # now, would support in the future and remove the following code if _DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE and not has_nf4: for param_name in full_sd.keys(): @@ -546,3 +550,64 @@ def shard_model( # Finally shard the entire model to account for any stragglers fully_shard(model, **fsdp_kwargs) + + +def prepare_mha_for_tp( + model: nn.Module, + tp_mesh: DeviceMesh, +) -> nn.Module: + """ + Utility to scale MultiHeadAttention parameters(num_heads, num_kv_heads, embed_dim) across + tensor parallel devices. Each device will handle a portion of the attention computations. + + Args: + model (nn.Module): Model whose attention parameters will be scaled by TP size. + tp_mesh (DeviceMesh): Tensor parallel device mesh. + + Returns: + nn.Module: The model with scaled MultiHeadAttention parameters. + + Raises: + ValueError: If attention heads, kv heads, or embed dimension is not divisible by TP size. + + Examples: + >>> from torchtune.modules import TransformerDecoder + >>> from torch.distributed.device_mesh import DeviceMesh + >>> model = TransformerDecoder( + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + ) + >>> tp_mesh = DeviceMesh("cuda", torch.arange(2)) # 2 GPUs + >>> model = prepare_mha_for_tp(model, tp_mesh) + >>> # Now each GPU has: + >>> # num_heads = 16 (32/2) + >>> # num_kv_heads = 16 (32/2) + >>> # embed_dim = 2048 (4096/2) + """ + # Consider the case of Deep Fusion models + if isinstance(model, DeepFusionModel): + model = model.decoder + tp_size = tp_mesh.size() + for m in list(model.modules()): + if isinstance(m, MultiHeadAttention): + # Adjust attention module to use the local number of heads + if m.num_heads % tp_size != 0: + raise ValueError( + f"Number of attention heads ({m.num_heads}) must be divisible by " + f"tensor parallel size ({tp_size})." + ) + if m.num_kv_heads % tp_size != 0: + raise ValueError( + f"Number of KV heads ({m.num_kv_heads}) must be divisible by " + f"tensor parallel size ({tp_size})." + ) + if m.embed_dim % tp_size != 0: + raise ValueError( + f"Embedding dimension ({m.embed_dim}) must be divisible by " + f"tensor parallel size ({tp_size})." + ) + m.num_heads = m.num_heads // tp_size + m.num_kv_heads = m.num_kv_heads // tp_size + m.embed_dim = m.embed_dim // tp_size + return model diff --git a/torchtune/training/_profiler.py b/torchtune/training/_profiler.py index 5fe3d74b5c..4cf584c359 100644 --- a/torchtune/training/_profiler.py +++ b/torchtune/training/_profiler.py @@ -27,6 +27,7 @@ DEFAULT_PROFILER_ACTIVITIES = { torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, + torch.profiler.ProfilerActivity.XPU, } DEFAULT_SCHEDULE: dict = { @@ -111,7 +112,7 @@ def trace_handler( log.info(f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds") # Memory timeline sometimes fails to export - if prof.profile_memory: + if prof.profile_memory and torch.cuda.is_available(): if rank == 0: try: prof.export_memory_timeline( @@ -185,6 +186,7 @@ def setup_torch_profiler( enabled: bool = False, cpu: bool = True, cuda: bool = True, + xpu: bool = True, profile_memory: bool = DEFAULT_TRACE_OPTS["profile_memory"], with_stack: bool = DEFAULT_TRACE_OPTS["with_stack"], record_shapes: bool = DEFAULT_TRACE_OPTS["record_shapes"], @@ -252,6 +254,7 @@ def setup_torch_profiler( enabled (bool): Enable pytorch profiler. Default is False. cpu (bool): Enable cpu profiling. Default is True. cuda (bool): Enable cuda profiling. Default is True. + xpu (bool): Enable xpu profiling. Default is True. profile_memory (bool): Profile memory usage. Default is False. with_stack (bool): Profile stack. Default is False. record_shapes (bool): Record shapes. Default is True. @@ -276,10 +279,12 @@ def setup_torch_profiler( activities.append(torch.profiler.ProfilerActivity.CPU) if cuda: activities.append(torch.profiler.ProfilerActivity.CUDA) + if xpu: + activities.append(torch.profiler.ProfilerActivity.XPU) if len(activities) == 0: _warn("No activities specified, defaulting to CPU + CUDA") activities = DEFAULT_PROFILER_ACTIVITIES - cpu = cuda = True + cpu = cuda = xpu = True # Check for schedule # 1) If no schedule is provided, set to DEFAULT_SCHEDULE @@ -372,6 +377,7 @@ def setup_torch_profiler( "output_dir": output_dir, "cpu": cpu, "cuda": cuda, + "xpu": xpu, "profile_memory": profile_memory, "with_stack": with_stack, "record_shapes": record_shapes, diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 780a2af75f..7c8d2b0bed 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -943,7 +943,6 @@ class FullModelMetaCheckpointer(_CheckpointerInterface): Raises: ValueError: If ``checkpoint_files`` is not a list of length 1 - ValueError: If ``should_load_recipe_state`` is True but ``recipe_checkpoint`` is None """ def __init__( diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 6c9f8474e4..0366b6d2b7 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -284,8 +284,9 @@ def update_state_dict_for_classifier( if ``output.weight != model.output.weight``. Raises: - AssertionError: if ``state_dict`` does not contain ``output.weight``. - AssertionError: if ``model_named_parameters`` does not contain ``output.weight``. + AssertionError: + If ``state_dict`` does not contain ``output.weight``, **or** + if ``model_named_parameters`` does not contain ``output.weight``. """ output_weight = dict(model_named_parameters).get("output.weight", None) diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py index 42aa1f9d72..dde1619194 100644 --- a/torchtune/training/metric_logging.py +++ b/torchtune/training/metric_logging.py @@ -23,6 +23,31 @@ log = get_logger("DEBUG") +def save_config(config: DictConfig) -> Path: + """ + Save the OmegaConf configuration to a YAML file at `{config.output_dir}/torchtune_config.yaml`. + + Args: + config (DictConfig): The OmegaConf config object to be saved. It must contain an `output_dir` attribute + specifying where the configuration file should be saved. + + Returns: + Path: The path to the saved configuration file. + + Note: + If the specified `output_dir` does not exist, it will be created. + """ + try: + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + output_config_fname = output_dir / "torchtune_config.yaml" + OmegaConf.save(config, output_config_fname) + return output_config_fname + except Exception as e: + log.warning(f"Error saving config.\nError: \n{e}.") + + class MetricLoggerInterface(Protocol): """Abstract metric logger.""" @@ -42,7 +67,7 @@ def log( pass def log_config(self, config: DictConfig) -> None: - """Logs the config + """Logs the config as file Args: config (DictConfig): config to log @@ -99,6 +124,9 @@ def log(self, name: str, data: Scalar, step: int) -> None: self._file.write(f"Step {step} | {name}:{data}\n") self._file.flush() + def log_config(self, config: DictConfig) -> None: + _ = save_config(config) + def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: self._file.write(f"Step {step} | ") for name, data in payload.items(): @@ -119,6 +147,9 @@ class StdoutLogger(MetricLoggerInterface): def log(self, name: str, data: Scalar, step: int) -> None: print(f"Step {step} | {name}:{data}") + def log_config(self, config: DictConfig) -> None: + _ = save_config(config) + def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: print(f"Step {step} | ", end="") for name, data in payload.items(): @@ -183,6 +214,10 @@ def __init__( # Use dir if specified, otherwise use log_dir. self.log_dir = kwargs.pop("dir", log_dir) + # create log_dir if missing + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir) + _, self.rank = get_world_size_and_rank() if self._wandb.run is None and self.rank == 0: @@ -219,23 +254,16 @@ def log_config(self, config: DictConfig) -> None: self._wandb.config.update( resolved, allow_val_change=self.config_allow_val_change ) - try: - output_config_fname = Path( - os.path.join( - config.output_dir, - "torchtune_config.yaml", - ) - ) - OmegaConf.save(config, output_config_fname) - log.info(f"Logging {output_config_fname} to W&B under Files") + # Also try to save the config as a file + output_config_fname = save_config(config) + try: self._wandb.save( output_config_fname, base_path=output_config_fname.parent ) - except Exception as e: log.warning( - f"Error saving {output_config_fname} to W&B.\nError: \n{e}." + f"Error uploading {output_config_fname} to W&B.\nError: \n{e}." "Don't worry the config will be logged the W&B workspace" ) @@ -305,6 +333,9 @@ def log(self, name: str, data: Scalar, step: int) -> None: if self._writer: self._writer.add_scalar(name, data, global_step=step, new_style=True) + def log_config(self, config: DictConfig) -> None: + _ = save_config(config) + def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: for name, data in payload.items(): self.log(name, data, step) @@ -387,13 +418,16 @@ def __init__( "Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'." ) from e + # Remove 'log_dir' from kwargs as it is not a valid argument for comet_ml.ExperimentConfig + if "log_dir" in kwargs: + del kwargs["log_dir"] + _, self.rank = get_world_size_and_rank() # Declare it early so further methods don't crash in case of # Experiment Creation failure due to mis-named configuration for # example self.experiment = None - if self.rank == 0: self.experiment = comet_ml.start( api_key=api_key, @@ -421,24 +455,13 @@ def log_config(self, config: DictConfig) -> None: self.experiment.log_parameters(resolved) # Also try to save the config as a file + output_config_fname = save_config(config) try: - self._log_config_as_file(config) + self.experiment.log_asset( + output_config_fname, file_name=output_config_fname.name + ) except Exception as e: - log.warning(f"Error saving Config to disk.\nError: \n{e}.") - return - - def _log_config_as_file(self, config: DictConfig): - output_config_fname = Path( - os.path.join( - config.checkpointer.checkpoint_dir, - "torchtune_config.yaml", - ) - ) - OmegaConf.save(config, output_config_fname) - - self.experiment.log_asset( - output_config_fname, file_name="torchtune_config.yaml" - ) + log.warning(f"Failed to upload config to Comet assets. Error: {e}") def close(self) -> None: if self.experiment is not None: diff --git a/torchtune/training/pooling.py b/torchtune/training/pooling.py index 3e0ba41507..e0fe204a5b 100644 --- a/torchtune/training/pooling.py +++ b/torchtune/training/pooling.py @@ -8,7 +8,7 @@ def get_unmasked_sequence_lengths(mask: torch.Tensor) -> torch.Tensor: """ - Returns the sequence lengths for each batch element, excluding masked tokens. + Returns the sequence lengths (0-indexed) for each batch element, excluding masked tokens. Args: mask (torch.Tensor): Boolean mask with shape [b x s], where True indicates a value to be masked out @@ -37,13 +37,6 @@ def get_unmasked_sequence_lengths(mask: torch.Tensor) -> torch.Tensor: """ # calculate per-batch-element sequence lengths by finding last valid tokens - if mask.any(): - sequence_lengths = ( - (~mask).sum(-1).sub(1).clip(0).to(mask.device, dtype=torch.long) - ) - else: - sequence_lengths = torch.full( - (mask.shape[0],), mask.shape[1] - 1, dtype=torch.long, device=mask.device - ) - - return sequence_lengths + sequence_lengths = (~mask).cumsum(dim=-1).argmax(dim=-1).to(dtype=torch.long) + + return sequence_lengths.clip(0, mask.shape[1] - 1) diff --git a/torchtune/training/precision.py b/torchtune/training/precision.py index 85a2c07e4f..d9232cf97f 100644 --- a/torchtune/training/precision.py +++ b/torchtune/training/precision.py @@ -33,7 +33,8 @@ def _set_float32_precision(precision: str = "high") -> None: Args: precision (str): The setting to determine which datatypes to use for matrix multiplication and convolution operations. """ - if not torch.cuda.is_available(): # Not relevant for non-CUDA devices + # Not relevant for non-CUDA or non-NPU devices + if not (torch.cuda.is_available() or is_npu_available): return # set precision for matrix multiplications torch.set_float32_matmul_precision(precision) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index b158d4b9a3..e5c88da04f 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. from typing import Callable, Optional -from warnings import warn from torch import nn from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear @@ -144,6 +143,7 @@ class Int4WeightOnlyQATQuantizerModuleSwap(Int4WeightOnlyQATQuantizer): "4w-qat-module-swap" ] = enable_4w_fake_quant_module_swap + # int8 dynamic activations + int4 weight class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantizer): pass @@ -179,12 +179,7 @@ def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]: Returns: Optional[str]: The quantization mode. """ - mode = _quantizer_to_mode.get(type(quantizer), None) - if mode is not None and "module-swap" in mode: - warn( - "*QuantizerModuleSwap is deprecated. Please use the version without 'ModuleSwap' instead" - ) - return mode + return _quantizer_to_mode.get(type(quantizer), None) def _get_disable_fake_quant(quantizer_mode: str) -> Callable: diff --git a/torchtune/utils/_device.py b/torchtune/utils/_device.py index 10d5e62a05..1d6defefae 100644 --- a/torchtune/utils/_device.py +++ b/torchtune/utils/_device.py @@ -177,10 +177,11 @@ def batch_to_device(batch: dict, device: torch.device) -> None: Args: batch (dict): dict of Tensors or more nested dicts of tensors. - device (torch.device): torch device to move the tensor's too + device (torch.device): torch device to move the tensors to. Raises: - AttributeError: if batch dict contains anything other than tensors + ValueError: if batch dict contains anything other than ``torch.Tensor``. + """ for k, v in batch.items(): if isinstance(v, dict): diff --git a/torchtune/utils/_logging.py b/torchtune/utils/_logging.py index ec3912e317..40b8d229b1 100644 --- a/torchtune/utils/_logging.py +++ b/torchtune/utils/_logging.py @@ -67,6 +67,9 @@ def deprecated(msg: str = "") -> Callable[[T], T]: @lru_cache(maxsize=1) def warn(obj): + rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 + if rank != 0: + return warnings.warn( f"{obj.__name__} is deprecated and will be removed in future versions. " + msg,