-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add AMPLIFY huggingface conversion utility
Signed-off-by: Peter St. John <[email protected]>
- Loading branch information
Showing
8 changed files
with
463 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,7 +38,7 @@ EOF | |
# Reinstall TE to avoid debugpy bug in vscode: https://nvbugspro.nvidia.com/bug/5078830 | ||
# Pull the latest TE version from https://github.com/NVIDIA/TransformerEngine/releases | ||
# Use the version that matches the pytorch base container. | ||
ARG TE_TAG=v1.13 | ||
ARG TE_TAG=2215fa5c7557b66034068816020f9f611019e457 | ||
RUN NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \ | ||
pip --disable-pip-version-check --no-cache-dir install \ | ||
git+https://github.com/NVIDIA/TransformerEngine.git@${TE_TAG} | ||
|
@@ -48,10 +48,13 @@ RUN NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \ | |
RUN CAUSAL_CONV1D_FORCE_BUILD=TRUE pip --disable-pip-version-check --no-cache-dir install \ | ||
git+https://github.com/Dao-AILab/[email protected] | ||
|
||
# Mamba dependancy installation | ||
# Mamba dependency installation | ||
RUN pip --disable-pip-version-check --no-cache-dir install \ | ||
git+https://github.com/state-spaces/[email protected] | ||
|
||
ARG XFORMER_ENGINE_TAG=v0.0.29.post1 | ||
RUN pip install -v -U git+https://github.com/facebookresearch/xformers.git@${XFORMER_ENGINE_TAG}#egg=xformers | ||
|
||
RUN pip install hatchling # needed to install nemo-run | ||
ARG NEMU_RUN_TAG=34259bd3e752fef94045a9a019e4aaf62bd11ce2 | ||
RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMU_RUN_TAG} | ||
|
@@ -100,7 +103,7 @@ COPY ./sub-packages /workspace/bionemo2/sub-packages | |
RUN --mount=type=bind,source=./.git,target=./.git \ | ||
--mount=type=bind,source=./requirements-test.txt,target=/requirements-test.txt \ | ||
--mount=type=bind,source=./requirements-cve.txt,target=/requirements-cve.txt \ | ||
--mount=type=cache,target=/root/.cache <<EOF | ||
<<EOF | ||
set -eo pipefail | ||
|
||
uv pip install maturin --no-build-isolation | ||
|
@@ -114,6 +117,7 @@ uv pip install --no-build-isolation \ | |
rm -rf ./3rdparty | ||
rm -rf /tmp/* | ||
rm -rf ./sub-packages/bionemo-noodles/target | ||
rm -rf /root/.cache | ||
EOF | ||
|
||
# In the devcontainer image, we just copy over the finished `dist-packages` folder from the build image back into the | ||
|
163 changes: 163 additions & 0 deletions
163
sub-packages/bionemo-amplify/src/bionemo/amplify/convert.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: LicenseRef-Apache2 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from pathlib import Path | ||
|
||
import torch | ||
from nemo.lightning import io, teardown | ||
from nemo.lightning.pytorch.utils import dtype_from_hf | ||
from transformers import AutoConfig as HFAutoConfig | ||
from transformers import AutoModel | ||
|
||
from bionemo.amplify.model import AMPLIFYConfig | ||
from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer | ||
from bionemo.llm.lightning import BionemoLightningModule | ||
from bionemo.llm.model.biobert.lightning import biobert_lightning_module | ||
|
||
|
||
@io.model_importer(BionemoLightningModule, "hf") | ||
class HFAMPLIFYImporter(io.ModelConnector[AutoModel, BionemoLightningModule]): | ||
"""Converts a Hugging Face ESM-2 model to a NeMo ESM-2 model.""" | ||
|
||
def init(self) -> BionemoLightningModule: | ||
"""Initialize the converted model.""" | ||
return biobert_lightning_module(self.config, tokenizer=self.tokenizer) | ||
|
||
def apply(self, output_path: Path) -> Path: | ||
"""Applies the transformation.""" | ||
source = AutoModel.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto") | ||
target = self.init() | ||
trainer = self.nemo_setup(target) | ||
self.convert_state(source, target) | ||
self.nemo_save(output_path, trainer) | ||
teardown(trainer, target) | ||
return output_path | ||
|
||
def convert_state(self, source, target): | ||
"""Converting HF state dict to NeMo state dict.""" | ||
mapping = { | ||
"encoder.weight": "embedding.word_embeddings.weight", | ||
"transformer_encoder.*.wo.weight": "encoder.layers.*.self_attention.linear_proj.weight", | ||
"transformer_encoder.*.ffn.w12.weight": "encoder.layers.*.mlp.linear_fc1.weight", | ||
"transformer_encoder.*.ffn.w3.weight": "encoder.layers.*.mlp.linear_fc2.weight", | ||
"transformer_encoder.*.attention_norm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight", | ||
"transformer_encoder.*.ffn_norm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight", | ||
"layer_norm_2.weight": "encoder.final_layernorm.weight", | ||
"decoder.weight": "output_layer.weight", | ||
"decoder.bias": "output_layer.bias", | ||
} | ||
|
||
# lm_head.bias | ||
return io.apply_transforms( | ||
source, | ||
target, | ||
mapping=mapping, | ||
transforms=[_import_qkv_weight], | ||
# transforms=[_pad_embeddings, _pad_bias, _import_qkv_weight], | ||
) | ||
|
||
@property | ||
def tokenizer(self) -> BioNeMoAMPLIFYTokenizer: | ||
"""We just have the one tokenizer for ESM-2.""" | ||
return BioNeMoAMPLIFYTokenizer() | ||
|
||
@property | ||
def config(self) -> AMPLIFYConfig: | ||
"""Returns the transformed ESM-2 config given the model tag.""" | ||
source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True) | ||
output = AMPLIFYConfig( | ||
num_layers=source.num_hidden_layers, | ||
hidden_size=source.hidden_size, | ||
ffn_hidden_size=source.intermediate_size, | ||
position_embedding_type="rope", | ||
num_attention_heads=source.num_attention_heads, | ||
seq_length=source.max_length, | ||
fp16=(dtype_from_hf(source) == torch.float16), | ||
bf16=(dtype_from_hf(source) == torch.bfloat16), | ||
params_dtype=dtype_from_hf(source), | ||
) | ||
|
||
return output | ||
|
||
|
||
@io.state_transform( | ||
source_key="esm.embeddings.word_embeddings.weight", | ||
target_key="embedding.word_embeddings.weight", | ||
) | ||
def _pad_embeddings(ctx: io.TransformCTX, source_embed): | ||
"""Pad the embedding layer to the new input dimension.""" | ||
nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by | ||
hf_embedding_dimension = source_embed.size(0) | ||
num_padding_rows = nemo_embedding_dimension - hf_embedding_dimension | ||
padding_rows = torch.zeros(num_padding_rows, source_embed.size(1)) | ||
return torch.cat((source_embed, padding_rows), dim=0) | ||
|
||
|
||
@io.state_transform( | ||
source_key="lm_head.bias", | ||
target_key="output_layer.bias", | ||
) | ||
def _pad_bias(ctx: io.TransformCTX, source_bias): | ||
"""Pad the embedding layer to the new input dimension.""" | ||
nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by | ||
hf_embedding_dimension = source_bias.size(0) | ||
output_bias = torch.zeros(nemo_embedding_dimension, dtype=source_bias.dtype, device=source_bias.device) | ||
output_bias[:hf_embedding_dimension] = source_bias | ||
return output_bias | ||
|
||
|
||
@io.state_transform( | ||
source_key=( | ||
"transformer_encoder.*.q.weight", | ||
"transformer_encoder.*.k.weight", | ||
"transformer_encoder.*.v.weight", | ||
), | ||
target_key="encoder.layers.*.self_attention.linear_qkv.weight", | ||
) | ||
def _import_qkv_weight(ctx: io.TransformCTX, query, key, value): | ||
"""Pad the embedding layer to the new input dimension.""" | ||
concat_weights = torch.cat((query, key, value), dim=0) | ||
input_shape = concat_weights.size() | ||
np = ctx.target.config.num_attention_heads | ||
# transpose weights | ||
# [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads] | ||
# --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads] | ||
concat_weights = concat_weights.view(3, np, -1, query.size()[-1]) | ||
concat_weights = concat_weights.transpose(0, 1).contiguous() | ||
concat_weights = concat_weights.view(*input_shape) | ||
return concat_weights | ||
|
||
|
||
@io.state_transform( | ||
source_key=( | ||
"esm.encoder.layer.*.attention.self.query.bias", | ||
"esm.encoder.layer.*.attention.self.key.bias", | ||
"esm.encoder.layer.*.attention.self.value.bias", | ||
), | ||
target_key="encoder.layers.*.self_attention.linear_qkv.bias", | ||
) | ||
def _import_qkv_bias(ctx: io.TransformCTX, query, key, value): | ||
"""Pad the embedding layer to the new input dimension.""" | ||
concat_biases = torch.cat((query, key, value), dim=0) | ||
input_shape = concat_biases.size() | ||
np = ctx.target.config.num_attention_heads | ||
# transpose biases | ||
# [num_splits_model_parallel * attention head size * #attention heads] | ||
# --> [attention head size * num_splits_model_parallel * #attention heads] | ||
concat_biases = concat_biases.view(3, np, -1) | ||
concat_biases = concat_biases.transpose(0, 1).contiguous() | ||
concat_biases = concat_biases.view(*input_shape) | ||
return concat_biases |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
179 changes: 179 additions & 0 deletions
179
sub-packages/bionemo-amplify/tests/bionemo/amplify/test_convert.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: LicenseRef-Apache2 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import gc | ||
from pathlib import Path | ||
|
||
import torch | ||
from megatron.core.transformer.module import Float16Module | ||
from nemo.lightning import io | ||
from transformers import AutoModel | ||
|
||
from bionemo.amplify.convert import HFAMPLIFYImporter # noqa: F401 | ||
from bionemo.amplify.model import AMPLIFYConfig | ||
from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer | ||
from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype | ||
from bionemo.esm2.testing.compare import assert_cosine_similarity, get_input_tensors | ||
from bionemo.llm.model.biobert.lightning import biobert_lightning_module | ||
from bionemo.testing import megatron_parallel_state_utils | ||
|
||
|
||
def assert_amplify_equivalence( | ||
ckpt_path: str, | ||
model_tag: str, | ||
precision: PrecisionTypes = "fp32", | ||
rtol: float | None = None, | ||
atol: float | None = None, | ||
) -> None: | ||
tokenizer = BioNeMoAMPLIFYTokenizer() | ||
|
||
input_ids, attention_mask = get_input_tensors(tokenizer) | ||
hf_logits, hf_hidden_state, hf_attn_inputs, hf_attn_outputs = load_and_evaluate_hf_amplify( | ||
model_tag, precision, input_ids, attention_mask | ||
) | ||
gc.collect() | ||
torch.cuda.empty_cache() | ||
nemo_logits, nemo_hidden_state, nemo_attn_inputs, nemo_attn_outputs = load_and_evaluate_nemo_amplify( | ||
tokenizer, | ||
ckpt_path, | ||
precision, | ||
input_ids, | ||
attention_mask, | ||
) | ||
|
||
# Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These | ||
# should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences. | ||
# We don't care about the padding tokens, so we only compare the non-padding tokens. | ||
assert_cosine_similarity(nemo_attn_inputs[0].transpose(0, 1), hf_attn_inputs[0], attention_mask, msg="Attn inputs") | ||
assert_cosine_similarity( | ||
nemo_attn_outputs[0].transpose(0, 1), hf_attn_outputs[0], attention_mask, msg="Attn inputs" | ||
) | ||
|
||
assert_cosine_similarity(nemo_hidden_state, hf_hidden_state, attention_mask, rtol, atol) | ||
assert_cosine_similarity(nemo_logits, hf_logits, attention_mask, rtol, atol) | ||
|
||
|
||
def load_and_evaluate_hf_amplify( | ||
model_tag: str, precision: PrecisionTypes, input_ids: torch.Tensor, attention_mask: torch.Tensor | ||
) -> tuple[torch.Tensor, ...]: | ||
"""Load a HuggingFace model and evaluate it on the given inputs. | ||
Args: | ||
model_tag: The HuggingFace model tag for the model to compare against. | ||
precision: The precision type to use for the comparison. | ||
input_ids: The input IDs tensor to evaluate. | ||
attention_mask: The attention mask tensor to evaluate. | ||
Returns: | ||
A tuple of the logits and hidden states tensors calculated by the HuggingFace model, respectively. | ||
""" | ||
hf_model = AutoModel.from_pretrained( | ||
model_tag, | ||
torch_dtype=get_autocast_dtype(precision), | ||
trust_remote_code=True, | ||
) | ||
|
||
def hook_fn(module, inputs, outputs): | ||
hook_fn.inputs = inputs | ||
hook_fn.outputs = outputs | ||
|
||
hook_fn.inputs = None | ||
hook_fn.outputs = None | ||
|
||
hf_model.transformer_encoder[0].register_forward_hook(hook_fn) | ||
# hf_model.transformer_encoder[0].ffn.register_forward_hook(hook_fn) | ||
|
||
hf_model = hf_model.to("cuda").eval() | ||
hf_output_all = hf_model(input_ids, attention_mask.float(), output_hidden_states=True) | ||
hf_hidden_state = hf_output_all.hidden_states[-1] | ||
return hf_output_all.logits, hf_hidden_state, hook_fn.inputs, hook_fn.outputs | ||
|
||
|
||
def load_and_evaluate_nemo_amplify( | ||
tokenizer: BioNeMoAMPLIFYTokenizer, | ||
ckpt_path: Path | str, | ||
precision: PrecisionTypes, | ||
input_ids: torch.Tensor, | ||
attention_mask: torch.Tensor, | ||
) -> tuple[torch.Tensor, ...]: | ||
"""Load a AMPLIFY NeMo2 model checkpoint and evaluate it on the input tensors. | ||
It would be great to make this more ergonomic, i.e., how to create a model from a checkpoint and evaluate it. | ||
Args: | ||
tokenizer: Not sure why we need to pass a tokenizer to `configure_model`. | ||
ckpt_path: Path to the newly created NeMo2 converted checkpoint. | ||
precision: Precision type to use for the model. | ||
input_ids: Input tokens | ||
attention_mask: Input attention mask | ||
Returns: | ||
The logits and hidden states from the model. | ||
""" | ||
|
||
dtype = get_autocast_dtype(precision) | ||
nemo_config = AMPLIFYConfig( | ||
initial_ckpt_path=str(ckpt_path), | ||
include_embeddings=True, | ||
include_hiddens=True, | ||
params_dtype=dtype, | ||
pipeline_dtype=dtype, | ||
autocast_dtype=dtype, | ||
bf16=dtype is torch.bfloat16, | ||
fp16=dtype is torch.float16, | ||
) | ||
|
||
nemo_model = nemo_config.configure_model(tokenizer).to("cuda").eval() | ||
|
||
if dtype is torch.float16 or dtype is torch.bfloat16: | ||
nemo_model = Float16Module(nemo_config, nemo_model) | ||
|
||
def hook_fn(module, inputs, outputs): | ||
hook_fn.inputs = inputs | ||
hook_fn.outputs = outputs | ||
|
||
hook_fn.inputs = None | ||
hook_fn.outputs = None | ||
|
||
nemo_model.encoder.layers[0].self_attention.register_forward_hook(hook_fn) | ||
# nemo_model.encoder.layers[0].mlp.register_forward_hook(hook_fn) | ||
|
||
nemo_output = nemo_model(input_ids, attention_mask) | ||
nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size] | ||
nemo_hidden_state = nemo_output["hidden_states"] | ||
return nemo_logits, nemo_hidden_state, hook_fn.inputs, hook_fn.outputs | ||
|
||
|
||
def test_convert_amplify_120M_smoke(tmp_path): | ||
model_tag = "chandar-lab/AMPLIFY_120M" | ||
module = biobert_lightning_module(config=AMPLIFYConfig()) | ||
io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") | ||
|
||
|
||
def test_convert_amplify_120M(tmp_path): | ||
model_tag = "chandar-lab/AMPLIFY_120M" | ||
module = biobert_lightning_module(config=AMPLIFYConfig()) | ||
io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") | ||
with megatron_parallel_state_utils.distributed_model_parallel_state(): | ||
assert_amplify_equivalence(tmp_path / "nemo_checkpoint", model_tag) | ||
|
||
|
||
def test_convert_amplify_350M(tmp_path): | ||
model_tag = "chandar-lab/AMPLIFY_350M" | ||
module = biobert_lightning_module(config=AMPLIFYConfig()) | ||
io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") | ||
with megatron_parallel_state_utils.distributed_model_parallel_state(): | ||
assert_amplify_equivalence(tmp_path / "nemo_checkpoint", model_tag) |
Oops, something went wrong.