Towards an Automatic GPU Batching Interface (AKA GPU go brrr) #127
THargreaves wants to merge 45 commits into main from th/batching
Update 2025-01-07
See original PR below.
The latest commits make a massive improvement to the interface, in particular to the way that linear algebra wrappers (e.g. `Adjoint`, `LowerTriangular`, `PDMat`) work with batching.
Previous Approach
The previous approach let a `BatchedCuMatrix` represent any element type via its type parameter. This was really problematic and led to bugs where the batched version of `A + B'` was reading the raw data of `B` and silently returning `A + B` instead. It also required hardcoded batched types for Cholesky and PDMat, which was unsustainable.
New Approach
We now only allow `BatchedCuMatrix` to represent a vanilla, unwrapped `CuMatrix` (now with its memory layout specified so the type is concrete). Wrapping is entirely handled by StructArrays. This works for one-argument wrapper types (e.g. `Adjoint`), one-argument wrappers with additional type parameters (e.g. `LowerTriangular{T, AbstractMatrix{T}}`), and multi-argument types with additional parameters, e.g. `MvNormal`. Broadcasting is performed automatically by asking the Julia compiler what type we would get if we ran the constructor with the eltypes of our batched inputs. As far as I can tell, this is a pretty robust approach.
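As a minimal illustration of the compiler-query idea (using `Base.promote_op`; the PR's internals may use a different mechanism, and a plain CPU `Matrix` is used here so the snippet runs without a GPU):

```julia
using LinearAlgebra

# Element type stored by a batched array (a plain, unwrapped CuMatrix in the PR;
# a Matrix is used here purely for illustration).
elty = Matrix{Float32}

# Ask the compiler what type the wrapper constructor would return for that eltype.
Base.promote_op(LowerTriangular, elty)  # LowerTriangular{Float32, Matrix{Float32}}
Base.promote_op(adjoint, elty)          # Adjoint{Float32, Matrix{Float32}}

# The batched broadcast can then build a StructArray with this inferred type as its
# eltype, backed by the original batched data.
```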
A demo of this wrapping behaviour can be found in `research/batching/wrapping_demo.jl`.
Current Limitations
- `PDMat.dim` is not a stored field, so `P.dim` fails on a batched `PDMat`. We just need the broadcast to check if the field is "real".
- `trmm`, `syrk` and some `SharedCuMatrix` versions are still missing.
- `Adjoint(LowerTriangular(...))` doesn't dispatch to efficient BLAS yet.
- Broadcasting currently relies on the stock `StructArray`. This can be fixed by defining our own version of `StructArray`.

These are all easy fixes.
Original PR Description
Overview
This PR introduces infrastructure for automatically generating batched GPU versions of linear algebraic functions, similar to JAX's `vmap`. This allows us to take our existing CPU implementations of filters and batch them on the GPU with no duplicate code or rewriting.
The interface supports batched and single arrays which can be contained in arbitrary structs using StructArrays, and, most importantly, supports nested functions, even if those functions' argument types are restricted to `AbstractMatrix`. This was something we were unable to achieve with previous iterations. Currently, we only wrap MAGMA functions, but it should be possible to combine this interface with the fused kernel generation work.
Not only will this be useful for us internally, but it will be important for users defining GPU accelerated particle filters. Further, I imagine a consistent and powerful batched GPU interface would be greatly valued by the wider Julia community.
This demo is still pretty shaky and I imagine things will break quickly if you modify the demo too much. I'm hoping to talk to Joseph to get some advice on tightening it up.
Core Design
The batching system is built around two abstractions for GPU arrays:
- Batched arrays (`BatchedCuMatrix`, `BatchedCuVector`): store N independent matrices/vectors as a single contiguous 3D/2D `CuArray`. Each batch element can have different values.
- Shared arrays (`SharedCuMatrix`, `SharedCuVector`): wrap a single GPU matrix/vector that is reused across all batch elements. This avoids memory duplication when parameters (like transition matrices) are constant across the batch.
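Roughly, the two kinds of wrapper could look like the following sketch (field names and exact layout are my assumptions, not necessarily the PR's definitions):

```julia
using CUDA

# N different (m × n) matrices stored contiguously as one (m, n, N) CuArray.
struct BatchedCuMatrix{T}
    data::CuArray{T, 3}
end

# A single (m × n) matrix logically reused by every batch element, so the
# parameter is stored only once on the GPU.
struct SharedCuMatrix{T}
    data::CuArray{T, 2}
end

batch_size(A::BatchedCuMatrix) = size(A.data, 3)  # number of batch elements
```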
Automatic Broadcasting via IR Transformation
The key concept of this interface is that whenever you try to broadcast over a batched CUDA array, the broadcast should be intercepted (to avoid slow looping over batch elements) and the function should instead be called over the entire batch. This requires the inner function calls to be converted to broadcasts recursively, continuing until we hit a broadcasted linear algebra call that we know how to handle (by passing it to an appropriate CUDA wrapper). Roughly, this looks like:
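(Conceptual outline only; the real transformation operates on the function's IR rather than on source text.)

```julia
# kalman_predict.(states, Ref(params))      broadcast over a batched StructArray is intercepted
#   kalman_predict(state, params)           its body is rewritten so every inner call becomes a broadcast:
#     A * state.μ                     ==>   broadcasted(*, A, states.μ)         # leaf: batched GEMM
#     PDMat(X_A_Xt(state.Σ, A) + Q)   ==>   broadcasted(PDMat, ...)             # leaf: batched Cholesky
#     MvNormal(μ̂, Σ̂)                  ==>   StructArray{MvNormal}((μ̂s, Σ̂s))     # struct construction wraps fields
```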
Importantly, this means that you only need to define broadcasted versions of primitive linear algebra operations; the rest happens automatically. This is similar to the burden that StaticArrays puts on the user/developer and is much more scalable than having to define custom GPU versions of every high-level function.
What you define manually (in `operations.jl`):
- `broadcasted(::typeof(*), A::BatchedCuMatrix, B::BatchedCuMatrix)` → calls MAGMA's batched GEMM
- `broadcasted(::typeof(*), A::BatchedCuMatrix, x::BatchedCuVector)` → calls MAGMA's batched GEMV
- `broadcasted(::typeof(+), A::BatchedCuMatrix, B::BatchedCuMatrix)` → elementwise addition
- `broadcasted(::Type{PDMat}, A::BatchedCuMatrix)` → batched Cholesky factorization
- `broadcasted(::typeof(/), A::BatchedCuMatrix, S::BatchedPDMat)` → batched triangular solves
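For example, the first of these definitions might look roughly like the sketch below (the output allocation helper and the exact `gemm_batched!` signature are assumptions; `size`/`eltype` are assumed to forward to the underlying batched data):

```julia
import Base.Broadcast: broadcasted

function broadcasted(::typeof(*), A::BatchedCuMatrix, B::BatchedCuMatrix)
    # Allocate the output batch (hypothetical helper: one (m × n) result per batch element).
    C = similar_batched(A, size(A, 1), size(B, 2))
    # Dispatch the whole batch to MAGMA's batched GEMM via the wrappers.jl binding.
    gemm_batched!('N', 'N', one(eltype(A)), A, B, zero(eltype(A)), C)
    return C
end
```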
What happens automatically (in `broadcasting.jl`):
When you broadcast a function like `kalman_predict.(states, Ref(params))`, the system uses IRTools to inspect the function's intermediate representation and transforms every operation into a `broadcasted` call. This means:
- Inner function calls propagate: if `kalman_step` calls `kalman_predict`, which calls a matrix multiply, each level gets transformed. You don't need to manually define a batched version of every function, only the leaf operations.
- Struct construction works automatically: when the transformed code hits a constructor like `MvNormal(μ̂, Σ̂)`, it detects that the arguments are batched arrays and wraps them in a `StructArray{MvNormal}` instead. This means the return type of `kalman_predict.(...)` is a `StructArray` where `.μ` returns a `BatchedCuVector` and `.Σ` returns a `BatchedPDMat`.
- Tuple/getfield pass through: operations like `params[1]` or `state.μ` are recognized as structural and passed through without transformation, correctly extracting components from `Ref`-wrapped parameter tuples or `StructArray` states.

The result is that you write completely standard scalar code:
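For instance, a predict step written in ordinary scalar style might look like this (a sketch, not the PR's exact demo code; the parameter layout is assumed):

```julia
using Distributions, PDMats

function kalman_predict(state::MvNormal, params)
    A, Q = params                        # transition matrix and process-noise covariance (assumed layout)
    μ̂ = A * state.μ                      # ordinary matrix-vector multiply
    Σ̂ = PDMat(X_A_Xt(state.Σ, A) + Q)    # A Σ Aᵀ + Q, kept positive definite
    return MvNormal(μ̂, Σ̂)
end
```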
And broadcast it over GPU arrays with no modification:
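Concretely, with `states` a `StructArray` of `MvNormal`s backed by batched GPU arrays and `params` a tuple of shared (batch-constant) matrices, the call is exactly the scalar one, just broadcast:

```julia
# states.μ :: BatchedCuVector, states.Σ :: BatchedPDMat  (one entry per batch element)
# params   :: tuple of SharedCuMatrix parameters, reused across the batch
new_states = kalman_predict.(states, Ref(params))

# The result is again a StructArray of MvNormals whose fields are batched GPU arrays:
new_states.μ  # BatchedCuVector
new_states.Σ  # BatchedPDMat
```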
The IR transformation walks through the function, sees `*`, `+` and `X_A_Xt`, and rewrites each to call the corresponding `broadcasted` method that dispatches to MAGMA.
MAGMA Backend
The actual linear algebra is performed via MAGMA's batched routines, which are much more optimised than those from CUBLAS (though they do not support fusion). Key operations include batched GEMM, GEMV, Cholesky factorization (POTRF), triangular solves (TRSM), and symmetric rank-k updates (SYRK). See the MAGMA documentation for details.
Example Usage
See `research/batching/batching_demo.jl` for a complete example running Kalman filter steps on 1000 parallel 64-dimensional Gaussians.
Important
You'll need to use my branch of MAGMA, which has been updated for recent CUDA.jl versions.
File Structure
types.jl
Core type definitions:
- `BatchedCuMatrix`, `BatchedCuVector`: arrays where each batch element has different data
- `SharedCuMatrix`, `SharedCuVector`: arrays where data is shared across all batch elements
- `BatchedCholesky`, `BatchedPDMat`: batched versions of factorization types
- Utilities: `batch_size`, `inner_size`, `is_shared`, `unwrap_data`, `trans_flag`
wrappers.jl
Low-level MAGMA bindings:
- `gemm_batched!`: batched matrix-matrix multiply
- `gemv_batched!`: batched matrix-vector multiply
- `potrf_batched!`: batched Cholesky factorization
- `potrs_batched!`: batched Cholesky solve
- `trsm_batched!`: batched triangular solve
- `trmm_batched!`: batched triangular matrix multiply
- `syrk_batched!`: batched symmetric rank-k update
- Small-square variants (`*_smallsq`) for dimensions < 32
operations.jl
High-level broadcasted operations:
- `*`, `+`, `-`
- `PDMat` construction via batched Cholesky
- Solves with `PDMat`: `\`, `/`
- `X_A_Xt` optimized with TRMM+SYRK for `BatchedPDMat`
- `I - K*H`
broadcasting.jl
IR transformation machinery:
- `BatchedStyle`: custom broadcast style for dispatching
- `transform_to_batched`: IR pass that rewrites operations to `broadcasted` calls
- `wrap_if_batched`: automatic `StructArray` wrapping for struct constructors
- `BATCHED_FUNC_CACHE`: memoization of generated batched functions
- `Broadcast.materialize`: entry point that triggers IR transformation
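To make the dispatch entry point concrete, a custom broadcast style along these lines could be wired up as sketched below (only the names come from the list above; the method bodies and the `transform_to_batched` signature are assumptions):

```julia
import Base.Broadcast: BroadcastStyle, Broadcasted, DefaultArrayStyle

struct BatchedStyle <: BroadcastStyle end

# Any broadcast involving a batched array adopts BatchedStyle...
BroadcastStyle(::Type{<:BatchedCuMatrix}) = BatchedStyle()
# ...and it wins against the default array style when the two are mixed.
BroadcastStyle(s::BatchedStyle, ::DefaultArrayStyle) = s

# Entry point: instead of looping over batch elements, transform the broadcast
# function's IR and run the generated batched version over the whole batch.
function Base.Broadcast.materialize(bc::Broadcasted{BatchedStyle})
    batched_f = transform_to_batched(bc.f, bc.args...)  # IR pass (signature assumed)
    return batched_f(bc.args...)
end
```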
Rough Edges
- The automatic construction of `StructArray`s is a bit clunky and often picks a type that is too abstract.
- I'm not very happy with the way that adjoints and other wrapper types are currently handled.