
google/grain

Grain - Feeding JAX Models

Grain is a Python library for reading and processing data for training and evaluating JAX models. It is flexible, fast and deterministic.

Grain lets you define data processing steps in a simple, declarative way:

import grain

dataset = (
    grain.MapDataset.source([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    .shuffle(seed=42)  # Shuffles elements globally.
    .map(lambda x: x + 1)  # Maps each element.
    .batch(batch_size=2)  # Batches consecutive elements.
)

for batch in dataset:
  # Training step goes here; `pass` keeps the loop syntactically valid.
  pass
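To make the semantics of the pipeline above concrete, here is a minimal pure-Python sketch of what shuffle, map and batch compute, and why a fixed seed makes the result reproducible. It is not Grain's implementation: the function `sketch_pipeline` is hypothetical, `random.Random` is only a stand-in (Grain's shuffle produces a different order), and real Grain batches stack elements into arrays rather than Python lists. The `drop_remainder` flag mirrors what we understand to be Grain's default of keeping a short final batch.

```python
import random
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def sketch_pipeline(
    source: Sequence[T],
    seed: int,
    fn: Callable[[T], U],
    batch_size: int,
    drop_remainder: bool = False,
) -> Iterator[List[U]]:
    """Hypothetical stand-in for shuffle -> map -> batch (not Grain's code)."""
    # Global shuffle: a seeded permutation over all indices. random.Random is
    # an illustrative assumption; Grain's shuffle yields a different order.
    indices = list(range(len(source)))
    random.Random(seed).shuffle(indices)
    # Map: apply fn to every element, in shuffled order.
    mapped = [fn(source[i]) for i in indices]
    # Batch: group consecutive elements; the final batch may be shorter
    # unless drop_remainder is set.
    for start in range(0, len(mapped), batch_size):
        batch = mapped[start:start + batch_size]
        if drop_remainder and len(batch) < batch_size:
            break
        yield batch


batches = list(sketch_pipeline(list(range(11)), seed=42, fn=lambda x: x + 1, batch_size=2))
# Re-running with the same seed reproduces the same batches.
again = list(sketch_pipeline(list(range(11)), seed=42, fn=lambda x: x + 1, batch_size=2))
print(batches == again)  # True
print([len(b) for b in batches])  # [2, 2, 2, 2, 2, 1]
```

With 11 source elements and `batch_size=2`, the sketch yields five full batches and one batch of a single element; passing `drop_remainder=True` would drop that last batch.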

Grain is designed to work with JAX models, but it does not require JAX to run and can be used with other frameworks as well.

Installation

Grain is available on PyPI and can be installed with pip install grain.

Supported platforms

Grain does not directly use GPUs or TPUs in its transformations; processing within Grain is done on the CPU by default.

Architecture   Linux   Mac   Windows
x86_64         yes     no    no
aarch64        yes     yes   n/a

Quickstart

Existing users

Grain is used by MaxText, kauldron and multiple internal Google projects.