AMPLIFY Mega-PR #640

Draft
wants to merge 4 commits into
base: main
1 change: 1 addition & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -47,6 +47,7 @@
"./sub-packages/bionemo-geneformer/src",
"./sub-packages/bionemo-llm/src",
"./sub-packages/bionemo-testing/src",
"./sub-packages/bionemo-amplify/src",
"./sub-packages/bionemo-example_model/src",
"./3rdparty/NeMo",
"./3rdparty/Megatron-LM"
10 changes: 7 additions & 3 deletions Dockerfile
@@ -38,7 +38,7 @@ EOF
# Reinstall TE to avoid debugpy bug in vscode: https://nvbugspro.nvidia.com/bug/5078830
# Pull the latest TE version from https://github.com/NVIDIA/TransformerEngine/releases
# Use the version that matches the pytorch base container.
-ARG TE_TAG=v1.13
+ARG TE_TAG=2215fa5c7557b66034068816020f9f611019e457
RUN NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \
pip --disable-pip-version-check --no-cache-dir install \
git+https://github.com/NVIDIA/TransformerEngine.git@${TE_TAG}
@@ -48,10 +48,13 @@ RUN NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \
RUN CAUSAL_CONV1D_FORCE_BUILD=TRUE pip --disable-pip-version-check --no-cache-dir install \
git+https://github.com/Dao-AILab/[email protected]

-# Mamba dependancy installation
+# Mamba dependency installation
RUN pip --disable-pip-version-check --no-cache-dir install \
git+https://github.com/state-spaces/[email protected]

ARG XFORMER_ENGINE_TAG=v0.0.29.post1
RUN pip install -v -U git+https://github.com/facebookresearch/xformers.git@${XFORMER_ENGINE_TAG}#egg=xformers

RUN pip install hatchling # needed to install nemo-run
ARG NEMU_RUN_TAG=34259bd3e752fef94045a9a019e4aaf62bd11ce2
RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMU_RUN_TAG}
@@ -100,7 +103,7 @@ COPY ./sub-packages /workspace/bionemo2/sub-packages
RUN --mount=type=bind,source=./.git,target=./.git \
--mount=type=bind,source=./requirements-test.txt,target=/requirements-test.txt \
--mount=type=bind,source=./requirements-cve.txt,target=/requirements-cve.txt \
---mount=type=cache,target=/root/.cache <<EOF
+<<EOF
set -eo pipefail

uv pip install maturin --no-build-isolation
@@ -114,6 +117,7 @@ uv pip install --no-build-isolation \
rm -rf ./3rdparty
rm -rf /tmp/*
rm -rf ./sub-packages/bionemo-noodles/target
rm -rf /root/.cache
EOF

# In the devcontainer image, we just copy over the finished `dist-packages` folder from the build image back into the
1 change: 1 addition & 0 deletions pyproject.toml
@@ -47,6 +47,7 @@ nemo_toolkit = { workspace = true }
megatron-core = { workspace = true }
# in sub-packages/
bionemo-core = { workspace = true }
bionemo-amplify = { workspace = true }
bionemo-esm2 = { workspace = true }
bionemo-example_model = { workspace = true }
bionemo-fw = { workspace = true }
11 changes: 11 additions & 0 deletions sub-packages/bionemo-amplify/README.md
@@ -0,0 +1,11 @@
# bionemo-amplify

To install, execute the following:
```bash
pip install -e .
```

To run unit tests, execute:
```bash
pytest -v .
```
1 change: 1 addition & 0 deletions sub-packages/bionemo-amplify/VERSION
@@ -0,0 +1 @@
0.0.1
33 changes: 33 additions & 0 deletions sub-packages/bionemo-amplify/pyproject.toml
@@ -0,0 +1,33 @@

[build-system]
requires = ["setuptools>=64", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "bionemo-amplify"
readme = "README.md"
description = ""
authors = [{ name = "BioNeMo Team", email = "[email protected]" }]
requires-python = ">=3.10"
license = { file = "LICENSE" }
dynamic = ["version"]
dependencies = [
# internal
'bionemo-core',
'bionemo-llm',
'bionemo-esm2',
# external
# 'xformers'
]

[tool.setuptools.packages.find]
where = ["src"]
include = ["bionemo.*"]
namespaces = true
exclude = ["test*."]

[tool.uv]
cache-keys = [{ git = true }]

[tool.setuptools.dynamic]
version = { file = "VERSION" }
14 changes: 14 additions & 0 deletions sub-packages/bionemo-amplify/src/bionemo/amplify/__init__.py
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
163 changes: 163 additions & 0 deletions sub-packages/bionemo-amplify/src/bionemo/amplify/convert.py
@@ -0,0 +1,163 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pathlib import Path

import torch
from nemo.lightning import io, teardown
from nemo.lightning.pytorch.utils import dtype_from_hf
from transformers import AutoConfig as HFAutoConfig
from transformers import AutoModel

from bionemo.amplify.model import AMPLIFYConfig
from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer
from bionemo.llm.lightning import BionemoLightningModule
from bionemo.llm.model.biobert.lightning import biobert_lightning_module


@io.model_importer(BionemoLightningModule, "hf")
class HFAMPLIFYImporter(io.ModelConnector[AutoModel, BionemoLightningModule]):
"""Converts a Hugging Face AMPLIFY model to a NeMo AMPLIFY model."""

def init(self) -> BionemoLightningModule:
"""Initialize the converted model."""
return biobert_lightning_module(self.config, tokenizer=self.tokenizer)

def apply(self, output_path: Path) -> Path:
"""Applies the transformation."""
source = AutoModel.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto")
target = self.init()
trainer = self.nemo_setup(target)
self.convert_state(source, target)
self.nemo_save(output_path, trainer)
teardown(trainer, target)
return output_path

def convert_state(self, source, target):
"""Converting HF state dict to NeMo state dict."""
mapping = {
"encoder.weight": "embedding.word_embeddings.weight",
"transformer_encoder.*.wo.weight": "encoder.layers.*.self_attention.linear_proj.weight",
"transformer_encoder.*.ffn.w12.weight": "encoder.layers.*.mlp.linear_fc1.weight",
"transformer_encoder.*.ffn.w3.weight": "encoder.layers.*.mlp.linear_fc2.weight",
"transformer_encoder.*.attention_norm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
"transformer_encoder.*.ffn_norm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight",
"layer_norm_2.weight": "encoder.final_layernorm.weight",
"decoder.weight": "output_layer.weight",
"decoder.bias": "output_layer.bias",
}

# lm_head.bias
return io.apply_transforms(
source,
target,
mapping=mapping,
transforms=[_import_qkv_weight],
# transforms=[_pad_embeddings, _pad_bias, _import_qkv_weight],
)

@property
def tokenizer(self) -> BioNeMoAMPLIFYTokenizer:
"""We just have the one tokenizer for AMPLIFY."""
return BioNeMoAMPLIFYTokenizer()

@property
def config(self) -> AMPLIFYConfig:
"""Returns the transformed AMPLIFY config given the model tag."""
source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True)
output = AMPLIFYConfig(
num_layers=source.num_hidden_layers,
hidden_size=source.hidden_size,
ffn_hidden_size=source.intermediate_size,
position_embedding_type="rope",
num_attention_heads=source.num_attention_heads,
seq_length=source.max_length,
fp16=(dtype_from_hf(source) == torch.float16),
bf16=(dtype_from_hf(source) == torch.bfloat16),
params_dtype=dtype_from_hf(source),
)

return output


@io.state_transform(
source_key="esm.embeddings.word_embeddings.weight",
target_key="embedding.word_embeddings.weight",
)
def _pad_embeddings(ctx: io.TransformCTX, source_embed):
"""Pad the embedding layer to the new input dimension."""
nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
hf_embedding_dimension = source_embed.size(0)
num_padding_rows = nemo_embedding_dimension - hf_embedding_dimension
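# Match the source dtype and device so torch.cat does not promote dtypes or fail across devices.
padding_rows = torch.zeros(
num_padding_rows, source_embed.size(1), dtype=source_embed.dtype, device=source_embed.device
)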
return torch.cat((source_embed, padding_rows), dim=0)


@io.state_transform(
source_key="lm_head.bias",
target_key="output_layer.bias",
)
def _pad_bias(ctx: io.TransformCTX, source_bias):
"""Pad the output-layer bias with zeros up to the padded vocabulary size."""
nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
hf_embedding_dimension = source_bias.size(0)
output_bias = torch.zeros(nemo_embedding_dimension, dtype=source_bias.dtype, device=source_bias.device)
output_bias[:hf_embedding_dimension] = source_bias
return output_bias


@io.state_transform(
source_key=(
"transformer_encoder.*.q.weight",
"transformer_encoder.*.k.weight",
"transformer_encoder.*.v.weight",
),
target_key="encoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv_weight(ctx: io.TransformCTX, query, key, value):
"""Concatenate the query, key and value weights and interleave them per attention head."""
concat_weights = torch.cat((query, key, value), dim=0)
input_shape = concat_weights.size()
np = ctx.target.config.num_attention_heads
# regroup the stacked weights so that each head's Q, K and V rows are adjacent
# [3 * #attention heads * attention head size, hidden size]
# --> [#attention heads * 3 * attention head size, hidden size]
concat_weights = concat_weights.view(3, np, -1, query.size()[-1])
concat_weights = concat_weights.transpose(0, 1).contiguous()
concat_weights = concat_weights.view(*input_shape)
return concat_weights


@io.state_transform(
source_key=(
"esm.encoder.layer.*.attention.self.query.bias",
"esm.encoder.layer.*.attention.self.key.bias",
"esm.encoder.layer.*.attention.self.value.bias",
),
target_key="encoder.layers.*.self_attention.linear_qkv.bias",
)
def _import_qkv_bias(ctx: io.TransformCTX, query, key, value):
"""Concatenate the query, key and value biases and interleave them per attention head."""
concat_biases = torch.cat((query, key, value), dim=0)
input_shape = concat_biases.size()
np = ctx.target.config.num_attention_heads
# regroup the stacked biases so that each head's Q, K and V entries are adjacent
# [3 * #attention heads * attention head size]
# --> [#attention heads * 3 * attention head size]
concat_biases = concat_biases.view(3, np, -1)
concat_biases = concat_biases.transpose(0, 1).contiguous()
concat_biases = concat_biases.view(*input_shape)
return concat_biases
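
Note on the QKV transforms in `convert.py`: the `view(3, np, -1, ...)` / `transpose(0, 1)` / `view(*input_shape)` sequence in `_import_qkv_weight` and `_import_qkv_bias` amounts to a row permutation of the stacked `[Q; K; V]` tensor. The sketch below reproduces only that permutation in plain Python so the resulting layout is easy to inspect. The helper name `interleaved_qkv_row_order` is hypothetical and is not part of this PR; it assumes every head has the same size and that Q, K and V have the same number of rows.

```python
def interleaved_qkv_row_order(num_heads: int, head_dim: int) -> list[int]:
    """For each output row, return the index of its source row in cat([Q, K, V], dim=0).

    Hypothetical helper for illustration only; it mirrors
    view(3, num_heads, head_dim, -1).transpose(0, 1).reshape(3 * num_heads * head_dim, -1).
    """
    hidden = num_heads * head_dim  # number of rows in each of Q, K and V
    order = []
    for head in range(num_heads):  # output is grouped by head first
        for part in range(3):  # 0 = Q, 1 = K, 2 = V
            for d in range(head_dim):
                order.append(part * hidden + head * head_dim + d)
    return order


# Two heads of size two: stacked rows are Q (0-3), K (4-7), V (8-11).
print(interleaved_qkv_row_order(num_heads=2, head_dim=2))
# -> [0, 1, 4, 5, 8, 9, 2, 3, 6, 7, 10, 11]
```

Each head's query, key and value rows end up adjacent, which is the layout the fused `linear_qkv` target key receives from these transforms.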
95 changes: 95 additions & 0 deletions sub-packages/bionemo-amplify/src/bionemo/amplify/hf_rotary.py
@@ -0,0 +1,95 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Code copied from https://huggingface.co/chandar-lab/AMPLIFY_350M/blob/main/rotary.py


from typing import Tuple

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.

Args:
dim (int): Dimension of the frequency tensor.
end (int): End index for precomputing frequencies.
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials.
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
return torch.polar(torch.ones_like(freqs), freqs) # complex64


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
"""Reshape frequency tensor for broadcasting it with another tensor.

This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.

Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.

Returns:
torch.Tensor: Reshaped frequency tensor.

Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply rotary embeddings to input tensors using the given frequency tensor.

This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.

Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings.
xk (torch.Tensor): Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
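
Note on `hf_rotary.py`: the functions above treat consecutive channel pairs as complex numbers and multiply them by unit-modulus rotations `e^{i * t * freq_j}`. The following is a minimal pure-Python sketch of the same arithmetic for a single vector, written with `cmath` instead of `torch` so it can be run without a GPU stack. The helper names `rotate` and `dot` are hypothetical and exist only for this illustration; `precompute_freqs_cis` here returns nested lists rather than a tensor.

```python
import cmath


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> list[list[complex]]:
    """Rotations e^{i * t * freq_j} for each position t < end and channel pair j < dim // 2."""
    freqs = [1.0 / (theta ** (2 * j / dim)) for j in range(dim // 2)]
    return [[cmath.rect(1.0, t * f) for f in freqs] for t in range(end)]


def rotate(x: list[float], rotations: list[complex]) -> list[float]:
    """Hypothetical helper: view consecutive pairs of x as complex numbers and rotate each."""
    out: list[float] = []
    for j, rot in enumerate(rotations):
        z = complex(x[2 * j], x[2 * j + 1]) * rot
        out.extend((z.real, z.imag))
    return out


def dot(a: list[float], b: list[float]) -> float:
    """Plain real-valued inner product."""
    return sum(u * v for u, v in zip(a, b))


freqs_cis = precompute_freqs_cis(dim=4, end=8)
q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]

# Position 0 is the identity rotation, and rotations never change the vector norm.
same_at_zero = rotate(q, freqs_cis[0])
norm_before, norm_after = dot(q, q), dot(rotate(q, freqs_cis[5]), rotate(q, freqs_cis[5]))

# The query-key score depends only on the position offset (here 2 in both cases).
score_a = dot(rotate(q, freqs_cis[3]), rotate(k, freqs_cis[1]))
score_b = dot(rotate(q, freqs_cis[5]), rotate(k, freqs_cis[3]))
```

The last two values agree up to floating-point error, which is the relative-position property that makes rotary embeddings useful in attention.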