
Conversation


@THargreaves THargreaves commented Jan 6, 2026

Update 2025-01-07

See original PR below.

The latest commits massively improve the way that linear algebra wrappers (e.g. Adjoint, LowerTriangular, PDMat) interact with batching.

Previous Approach

The previous approach let a BatchedCuMatrix represent any element type via its type parameter.

struct BatchedCuMatrix{T, Inner} <: AbstractVector{Inner}
    data::CuArray{T,3}
end

# Had to specify inner type explicitly
BatchedCuMatrix{Float32, Adjoint{Float32, CuMatrix{Float32}}}(data)

This was highly error-prone and led to bugs where the batched version of A + B' read the raw data underlying B' (i.e. B itself, with the adjoint dropped) and silently returned A + B instead.
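A CPU analogue of this failure mode (a toy illustration, not the PR's actual code; the real bug sat in the batched GPU path):

```julia
using LinearAlgebra

# Reading the raw storage behind B' drops the Adjoint wrapper,
# so the old batched kernel effectively saw B instead of B'.
A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 2.0 0.0]

raw = parent(B')          # the unwrapped storage: just B, adjoint lost

buggy   = A + raw         # what the old batched code effectively computed
correct = A + B'          # what the user asked for

@assert buggy == A + B
@assert buggy != correct  # B is not symmetric, so the results differ
```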

It also required hardcoded batched types for Cholesky and PDMat, which was unsustainable.

New Approach

We now only allow BatchedCuMatrix to represent a vanilla, unwrapped CuMatrix (now with its memory layout specified so the type is concrete).

struct BatchedCuMatrix{T,M} <: AbstractVector{CuArray{T,2,M}}
    data::CuArray{T,3,M}
end

struct BatchedCuVector{T,M} <: AbstractVector{CuArray{T,1,M}}
    data::CuArray{T,2,M}
end

Wrapping is now handled entirely by StructArrays. This works for single-argument wrappers (e.g. Adjoint), wrappers with additional type parameters (e.g. LowerTriangular{T, S <: AbstractMatrix{T}}), and multi-argument types with additional parameters (e.g. MvNormal).

Broadcasting is all performed automatically by asking the Julia compiler what type we'd get if we ran the constructor with the eltypes of our batched inputs. As far as I can tell, this is a pretty robust approach.

function broadcasted(::Type{W}, args::Vararg{BatchedOrShared}) where {W}
    # Ask the compiler what W would return for the batch-element types
    element_types = Tuple{map(eltype, args)...}
    ElType = Core.Compiler.return_type(W, element_types)
    # Lay the batched arguments out as the fields of a StructArray
    field_names = fieldnames(ElType)
    nt = NamedTuple{field_names}(args)
    return StructArray{ElType}(nt)
end
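The same inference trick can be seen on ordinary CPU wrapper types (illustrative; the batched case substitutes CuMatrix element types):

```julia
using LinearAlgebra

# Ask the compiler what the Adjoint constructor returns for a Matrix{Float64}
ElType = Core.Compiler.return_type(Adjoint, Tuple{Matrix{Float64}})
@assert ElType <: Adjoint{Float64, Matrix{Float64}}

# The inferred type's fields tell us how to lay out the StructArray
@assert fieldnames(Adjoint) == (:parent,)

# Works for wrappers with extra type parameters too
@assert Core.Compiler.return_type(LowerTriangular, Tuple{Matrix{Float64}}) <:
        LowerTriangular{Float64, Matrix{Float64}}
```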

A demo of this wrapping behaviour can be found in `research/batching/wrapping_demo.jl`.

Current Limitations

  1. Computed properties: PDMat.dim is not a stored field, so P.dim fails on a batched PDMat. The broadcast just needs to check whether the field is actually stored.
  2. Printing: StructArrays prints by materialising the batch elements. This doesn't always work; e.g. PDMat(CuArray) requires scalar indexing and so errors. This can be handled by creating something similar to StructArrays but specialised for GPU batching.
  3. A few operations are missing for simplicity, such as trmm, syrk, and some SharedCuMatrix variants.
  4. Nested GEMM wrappers: Adjoint(LowerTriangular(...)) doesn't dispatch to efficient BLAS yet.
  5. We're committing a bit of type piracy on StructArray for broadcasting. This can be fixed by defining our own version of StructArray.
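Limitation 1 amounts to distinguishing stored fields from computed properties; a minimal sketch of the check (toy PD type below, not PDMats.jl itself):

```julia
# Toy type mimicking PDMat: `mat` is stored, `dim` is computed on access.
struct PD{T}
    mat::Matrix{T}
end
Base.getproperty(P::PD, s::Symbol) =
    s === :dim ? size(getfield(P, :mat), 1) : getfield(P, s)

# The batched getproperty broadcast should only fire for stored fields.
is_stored(::Type{T}, s::Symbol) where {T} = s in fieldnames(T)

@assert is_stored(PD{Float64}, :mat)
@assert !is_stored(PD{Float64}, :dim)   # needs special handling when batched
```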

These are all easy fixes.


Original PR Description

Overview

This PR introduces infrastructure for automatically generating batched GPU versions of linear algebra functions, similar to JAX's vmap. This lets us take our existing CPU implementations of filters and batch them on the GPU with no code duplication or rewriting.

The interface supports batched and single arrays, which can be contained in arbitrary structs using StructArrays, and, most importantly, supports nested function calls, even when argument types are restricted to AbstractMatrix. This was something we were unable to achieve with previous iterations.

Currently we only wrap MAGMA functions, but it should be possible to combine this interface with the fused-kernel generation work.

Not only will this be useful for us internally, but it will be important for users defining GPU accelerated particle filters. Further, I imagine a consistent and powerful batched GPU interface would be greatly valued by the wider Julia community.

This demo is still pretty shaky and I imagine things will break quickly if you modify it too much. I'm hoping to talk to Joseph to get advice on tightening it up.

Core Design

The batching system is built around two abstractions for GPU arrays:

  • Batched types (BatchedCuMatrix, BatchedCuVector): Store N independent matrices/vectors as a single contiguous 3D/2D CuArray. Each batch element can have different values.
  • Shared types (SharedCuMatrix, SharedCuVector): Wrap a single GPU matrix/vector that is reused across all batch elements. This avoids memory duplication when parameters (like transition matrices) are constant across the batch.
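A CPU sketch of the two abstractions (plain Array standing in for CuArray; the names and layout mirror the PR's types but are illustrative):

```julia
# Batched: n × m × b storage, one matrix per batch element.
struct BatchedMatrix{T} <: AbstractVector{Matrix{T}}
    data::Array{T,3}
end
Base.size(A::BatchedMatrix) = (size(A.data, 3),)
Base.getindex(A::BatchedMatrix, k::Int) = A.data[:, :, k]

# Shared: one matrix reused by every batch element, no duplication.
struct SharedMatrix{T} <: AbstractVector{Matrix{T}}
    data::Matrix{T}
    batch_size::Int
end
Base.size(A::SharedMatrix) = (A.batch_size,)
Base.getindex(A::SharedMatrix, ::Int) = A.data

A = BatchedMatrix(rand(2, 2, 5))
Q = SharedMatrix(rand(2, 2), 5)
@assert length(A) == length(Q) == 5
@assert Q[1] === Q[5]   # shared: every element is the same matrix
```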

Automatic Broadcasting via IR Transformation

The key concept of this interface is that whenever you broadcast over a batched CUDA array, the broadcast should be intercepted (to avoid slow looping over batch elements) and the function called over the entire batch instead. This requires inner function calls to be converted to broadcasts recursively, continuing until we hit a broadcasted linear algebra call that we know how to handle (by passing it to an appropriate CUDA wrapper). Roughly, this looks like:

# You write standard scalar code:
function kalman_predict(state, params)
    μ̂ = params.A * state.μ + params.b
    Σ̂ = X_A_Xt(state.Σ, params.A) + params.Q
    return MvNormal(μ̂, Σ̂)
end

# When you call:
new_states = kalman_predict.(gpu_states, Ref(gpu_params))

# Julia's broadcast would normally try to iterate over gpu_states and call
# kalman_predict on each element. Instead, the IR transformation intercepts
# this and rewrites the function body to:
function kalman_predict_batched(state, params)
    μ̂ = broadcasted(*, params.A, state.μ)    # → batched GEMV
    μ̂ = broadcasted(+, μ̂, params.b)          # → elementwise add
    Σ̂ = broadcasted(X_A_Xt, state.Σ, params.A)  # → batched GEMM chain
    Σ̂ = broadcasted(+, Σ̂, params.Q)          # → elementwise add
    return wrap_as_structarray(MvNormal, μ̂, Σ̂)  # → StructArray{MvNormal}
end

# The broadcasted methods dispatch to MAGMA's batched BLAS, so the entire
# batch of 1000 Kalman predicts runs as a few GPU kernel calls rather than
# 1000 separate operations.

Importantly, this means that you only need to define broadcasted versions of primitive linear algebra operations; the rest happens automatically. This is similar to the burden that StaticArrays puts on the user/developer and is much more scalable than defining custom GPU versions of every high-level function.

What you define manually (in operations.jl):

  • broadcasted(::typeof(*), A::BatchedCuMatrix, B::BatchedCuMatrix) → calls MAGMA's batched GEMM
  • broadcasted(::typeof(*), A::BatchedCuMatrix, x::BatchedCuVector) → calls MAGMA's batched GEMV
  • broadcasted(::typeof(+), A::BatchedCuMatrix, B::BatchedCuMatrix) → elementwise addition
  • broadcasted(::Type{PDMat}, A::BatchedCuMatrix) → batched Cholesky factorization
  • broadcasted(::typeof(/), A::BatchedCuMatrix, S::BatchedPDMat) → batched triangular solves
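A CPU analogue of one such definition (a serial loop stands in for MAGMA's batched GEMM, and the type name is illustrative):

```julia
using LinearAlgebra
import Base.Broadcast: broadcasted

struct BatchedMatrix{T} <: AbstractVector{Matrix{T}}
    data::Array{T,3}   # n × m × batch
end
Base.size(A::BatchedMatrix) = (size(A.data, 3),)
Base.getindex(A::BatchedMatrix, k::Int) = A.data[:, :, k]

# Intercept `A .* B` on batched arguments: one "batched GEMM" call
# (a loop here; the real method hands the 3D buffers to MAGMA).
function broadcasted(::typeof(*), A::BatchedMatrix, B::BatchedMatrix)
    n, _, b = size(A.data)
    m = size(B.data, 2)
    C = similar(A.data, n, m, b)
    for k in 1:b
        mul!(view(C, :, :, k), view(A.data, :, :, k), view(B.data, :, :, k))
    end
    return BatchedMatrix(C)
end

A = BatchedMatrix(rand(3, 3, 4))
B = BatchedMatrix(rand(3, 3, 4))
C = A .* B                 # dispatches to the batched method above
@assert C[2] ≈ A[2] * B[2]
```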

What happens automatically (in broadcasting.jl):

When you broadcast a function like kalman_predict.(states, Ref(params)), the system uses IRTools to inspect the function's intermediate representation and transforms every operation into a broadcasted call. This means:

  1. Inner function calls propagate: If kalman_step calls kalman_predict which calls matrix multiply, each level gets transformed. You don't need to manually define a batched version of every function—only the leaf operations.

  2. Struct construction works automatically: When the transformed code hits a constructor like MvNormal(μ̂, Σ̂), it detects that the arguments are batched arrays and wraps them in a StructArray{MvNormal} instead. The return type of kalman_predict.(...) is therefore a StructArray whose μ field is a BatchedCuVector and whose Σ field is a BatchedPDMat.

  3. Tuple/getfield pass through: Operations like params[1] or state.μ are recognized as structural and passed through without transformation, correctly extracting components from Ref-wrapped parameter tuples or StructArray states.

The result is that you write completely standard scalar code:

function kalman_predict(state, params)
    μ̂ = params.A * state.μ + params.b
    Σ̂ = X_A_Xt(state.Σ, params.A) + params.Q
    return MvNormal(μ̂, Σ̂)
end

And broadcast it over GPU arrays with no modification:

new_states = kalman_predict.(gpu_states, Ref(gpu_params))

The IR transformation walks through the function, sees * and + and X_A_Xt, and rewrites each to call the corresponding broadcasted method that dispatches to MAGMA.

MAGMA Backend

The actual linear algebra is performed via MAGMA's batched routines, which are much better optimised than those in CUBLAS (though they do not support fusion). Key operations include batched GEMM, GEMV, Cholesky factorisation (POTRF), triangular solves (TRSM), and symmetric rank-k updates (SYRK).

Example Usage

See research/batching/batching_demo.jl for a complete example running Kalman filter steps on 1000 parallel 64-dimensional Gaussians. The demo shows:

  • Mixed shared/batched parameters (H and c shared, R batched)
  • Predict, update, and full step operations
  • Validation against CPU reference implementation

Important

You'll need to use my branch of MAGMA, which has been updated for recent CUDA.jl versions.

File Structure

types.jl

Core type definitions:

  • BatchedCuMatrix, BatchedCuVector: Arrays where each batch element has different data
  • SharedCuMatrix, SharedCuVector: Arrays where data is shared across all batch elements
  • BatchedCholesky, BatchedPDMat: Batched versions of factorization types
  • Helper functions: batch_size, inner_size, is_shared, unwrap_data, trans_flag
  • Pointer array creation for MAGMA calls

wrappers.jl

Low-level MAGMA bindings:

  • gemm_batched!: Batched matrix-matrix multiply
  • gemv_batched!: Batched matrix-vector multiply
  • potrf_batched!: Batched Cholesky factorization
  • potrs_batched!: Batched Cholesky solve
  • trsm_batched!: Batched triangular solve
  • trmm_batched!: Batched triangular matrix multiply
  • syrk_batched!: Batched symmetric rank-k update
  • Small-square optimized variants (*_smallsq) for dimensions < 32

operations.jl

High-level broadcasted operations:

  • Matrix/vector arithmetic: *, +, -
  • Adjoint and transpose handling
  • PDMat construction via batched Cholesky
  • Linear solves with PDMat: \, /
  • Quadratic forms: X_A_Xt optimized with TRMM+SYRK for BatchedPDMat
  • Identity minus matrix kernel for I - K*H
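The last item can be sketched on the CPU as follows (a loop in place of the fused GPU kernel; the function name is illustrative):

```julia
using LinearAlgebra

# Batched I - K*H over b batch elements (n×m K, m×n H per element).
function i_minus_kh(K::Array{T,3}, H::Array{T,3}) where {T}
    n, _, b = size(K)
    out = Array{T,3}(undef, n, n, b)
    for k in 1:b
        out[:, :, k] = I - @view(K[:, :, k]) * @view(H[:, :, k])
    end
    return out
end

K = rand(3, 2, 4)
H = rand(2, 3, 4)
R = i_minus_kh(K, H)
@assert R[:, :, 1] ≈ I - K[:, :, 1] * H[:, :, 1]
```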

broadcasting.jl

IR transformation machinery:

  • BatchedStyle: Custom broadcast style for dispatching
  • transform_to_batched: IR pass that rewrites operations to broadcasted calls
  • wrap_if_batched: Automatic StructArray wrapping for struct constructors
  • BATCHED_FUNC_CACHE: Memoization of generated batched functions
  • Broadcast.materialize: Entry point that triggers IR transformation

Rough Edges

  • [FIXED] The automatic construction of StructArrays is a bit clunky and often picks a type that is too abstract
  • Error messages are often unclear. If a broadcast is missing, the code may try to broadcast deeper and deeper into the Julia code until you're batching things like getindex.
  • There are likely many edge cases of non-linear algebraic functions that don't work
  • [FIXED] I'm not very happy with the way that adjoints and other wrapper types are currently handled

@github-actions

github-actions bot commented Jan 6, 2026

SSMProblems.jl/SSMProblems documentation for PR #127 is available at:
https://TuringLang.github.io/SSMProblems.jl/SSMProblems/previews/PR127/

Removes parametric batched types and custom Cholesky/PDMat batched types, replacing them with a generic system based on StructArrays.