Run pretraining of llama 7b/8b models on GPUs using a3/1vm.sh #2193

@giacomoni

Description

Bug report

It looks like there is no working workflow to build a Docker image for GPUs that can run the MaxText/configs/a3/llama2_7b/1vm.sh script.

I have been relying on docker_build_dependency_image.sh and have tried different combinations of input arguments:

  1. DEVICE=gpu
  2. MODE=pinned DEVICE=gpu
  3. DEVICE=gpu MODE=stable_stack BASEIMAGE=us-central1-docker.pkg.dev/deeplearning-images/jax-ai-image/gpu:jax0.5.1-cuda_dl25.02-rev1
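For reference, the three attempts above correspond roughly to the following invocations (a sketch of the commands I ran, assuming the arguments are passed on the command line as key=value pairs; not a verified recipe):

```shell
# Case 1: plain GPU device
bash docker_build_dependency_image.sh DEVICE=gpu

# Case 2: pinned dependency mode
bash docker_build_dependency_image.sh MODE=pinned DEVICE=gpu

# Case 3: stable-stack mode on top of a JAX AI base image
bash docker_build_dependency_image.sh DEVICE=gpu MODE=stable_stack \
  BASEIMAGE=us-central1-docker.pkg.dev/deeplearning-images/jax-ai-image/gpu:jax0.5.1-cuda_dl25.02-rev1
```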

Each of the above fails at some point, whether during the Docker image build or at runtime.

Case 1) fails when running the script: the attention mask is a float tensor, but transformer-engine's dot-product attention expects a boolean mask.
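For context on the dtype mismatch: the mask arrives as a float additive mask (0.0 where attention is allowed, a large negative value where it is masked), while transformer-engine wants booleans. A minimal sketch of the kind of conversion I patched in (the function name and the True-means-masked convention are my assumptions, not MaxText's actual code):

```python
import numpy as np

def additive_to_bool_mask(additive_mask: np.ndarray) -> np.ndarray:
    """Convert a float additive attention mask (0.0 = attend,
    large negative = masked) into a boolean mask where True
    marks positions to mask out.

    NOTE: whether True means "masked" or "attend" depends on the
    attn_mask_type configured in transformer-engine; check the
    convention before applying a patch like this.
    """
    return additive_mask < 0.0

# Example: the second position of the first row is masked out.
additive = np.array([[0.0, -1e9],
                     [0.0, 0.0]], dtype=np.float32)
print(additive_to_bool_mask(additive))
# [[False  True]
#  [False False]]
```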

Case 2) fails while building the image, because pip cannot satisfy the requirements and the constraints at the same time.

In case 3) I get a RESOURCE_EXHAUSTED JAX compilation error.
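For anyone trying to reproduce: the kinds of overrides that can shrink the memory footprint look like the following. MaxText accepts key=value config overrides on the train.py command line, and XLA_PYTHON_CLIENT_MEM_FRACTION is JAX's preallocation cap; the specific values here are illustrative assumptions, not a configuration I have verified to work.

```shell
# Cap JAX's GPU memory preallocation (default is ~0.75 of device memory).
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85

# Lower the per-device batch size via MaxText's key=value overrides.
python3 MaxText/train.py MaxText/configs/base.yml \
  per_device_batch_size=1
```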

I also patched the code to fix the issues in case 1) and eventually got the script to run, until JAX compilation failed with RESOURCE_EXHAUSTED.

I am running the script on a single node with 8 H100 GPUs, so the current configs for Llama2 7B and Llama3 8B should fit.

I have also tried all of the above from the latest available tag (tpu-recipes-v0.1.4) rather than from master.

Any suggestions on how to get a working and stable Docker image for GPU execution?

Logs/Output

No response

Environment Information

No response

Additional Context

No response

Metadata

Assignees

Labels

bug (Something isn't working)
