EnsembleFunction and friends. #4025

Draft · wants to merge 14 commits into master
Conversation

@JHopeCollins (Member) commented Feb 10, 2025

Description

This PR introduces an EnsembleFunction to Firedrake, representing a mixed function whose components are distributed over an Ensemble. It is defined over an EnsembleFunctionSpace, and an EnsembleCofunction and EnsembleDualFunction represent the corresponding dual objects.

Why?

Methods that use the Ensemble will often involve something that looks like a mixed finite element function, where the subcomponents are distributed over the ensemble members. We need to treat the subcomponents as firedrake.Functions within each ensemble member, but also to have some collective semantics over the entire ensemble. See for example the AllAtOnceFunction in asQ which represents a timeseries over an ensemble.

Note that, whereas firedrake.Function does not allow nested mixed-ness, a particular component of an EnsembleFunction can itself be mixed. This is essential for things like a timeseries of a mixed function, e.g. for the shallow water equations.
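As a rough illustration of this nested case (not taken from the PR; the mesh, element choices, and variable names below are mine), a timeseries of a mixed shallow-water-style function might be set up like this:

import firedrake as fd

# one spatial rank per ensemble member; assumes the script is launched on several MPI ranks
ensemble = fd.Ensemble(fd.COMM_WORLD, 1)
mesh = fd.UnitSquareMesh(8, 8, comm=ensemble.comm)

# a mixed (velocity, depth) space, as for the shallow water equations
Vu = fd.VectorFunctionSpace(mesh, "DG", 1)
Vh = fd.FunctionSpace(mesh, "DG", 1)
Wswe = Vu*Vh

# two timesteps of the mixed space on this ensemble member: each
# subcomponent of the EnsembleFunction is itself mixed
efs = fd.EnsembleFunctionSpace([Wswe, Wswe], ensemble)
uts = fd.EnsembleFunction(efs)

# velocity and depth of the first local timestep, as ordinary firedrake.Functions
u0, h0 = uts.subfunctions[0].subfunctions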

Demo

At the bottom is a small script using the new objects to solve the heat equation with backward Euler, parallel in time, with a block Jacobi preconditioner. It shows some typical usage, including:

  1. setting up the ensemble objects
  2. accessing data via local Function subcomponents
  3. accessing data via the global petsc Vec
  4. setting up global petsc objects
  5. ensemble communication

EnsembleFunctionSpace

  • An EnsembleFunction is defined from an EnsembleFunctionSpace, which just collects an Ensemble with a list of firedrake.FunctionSpaces, one for each component on the local ensemble member. These are accessible through EnsembleFunctionSpace.local_spaces.

  • The EnsembleFunctionSpace also has some handy utilities, e.g. a dual method, equality comparison, and the number of dofs on the local rank, on the local ensemble member, and on the whole global function space (see the short sketch after this list).

  • Currently the global list of subcomponents is not available, i.e. you cannot programmatically check from rank i what FunctionSpace is in EnsembleFunctionSpace.local_spaces on rank j, but I'm open to arguments for including this.
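A minimal sketch of these utilities, continuing the hypothetical efs from the sketch above and using only names that appear elsewhere in this PR and the demo below:

print(efs.nlocal_spaces)      # number of component spaces on this ensemble member
print(efs.local_spaces)       # the component FunctionSpaces themselves
print(efs.nlocal_rank_dofs)   # dofs owned by this MPI rank
print(efs.nglobal_dofs)       # dofs over the whole ensemble
dual_space = efs.dual()       # the dual space, used to build EnsembleCofunctions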

EnsembleFunction

  • It's important to point out that this does not have any UFL symbolic information, i.e. it is a data-structure abstraction, not a symbolic one. The subcomponents on the local ensemble member can be accessed as firedrake.Functions, with the usual UFL symbolic information, via EnsembleFunction.subfunctions.

  • Equivalents to a few firedrake.Function arithmetic operations are implemented, e.g. assign, zero, and addition/multiplication overloads.

  • A PETSc.Vec defined over the Ensemble.global_comm for the data in all subcomponents on all ranks is accessed via context managers vec, vec_ro, and vec_wo (a short sketch of these access patterns follows this list).
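A minimal sketch of the three access patterns above, continuing the hypothetical efs and uts from earlier (the exact semantics of the arithmetic overloads, and of assign taking another EnsembleFunction, are my assumptions here):

vts = fd.EnsembleFunction(efs)

# local access: the subcomponents are ordinary firedrake.Functions
for vsub in vts.subfunctions:
    vsub.subfunctions[1].interpolate(fd.Constant(1.0))   # set each local depth field

# whole-object operations
uts.assign(vts)
uts.zero()

# global access: a PETSc Vec over Ensemble.global_comm (read-only here)
with vts.vec_ro() as vvec:
    print(vvec.norm())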

Adjoint

Currently no EnsembleFunction operations are taped, so it cannot be used for adjoint problems in the same way firedrake.Function can be.
However, it does have the bare minimum OverloadedType functionality implemented to be usable as a Control type, if you take responsibility for evaluating the tape (e.g. with EnsembleReducedFunctional, or the 4DVar reduced functional in another PR).
For example:

  • _ad_dot - collective over the ensemble, so it will do the right thing with pyadjoint.taylor_test and we should be able to remove the xfail markers here.
  • _ad_to_list and _ad_assign_numpy, so it will work with the pyadjoint interface to scipy.optimize that relies on ReducedFunctionalNumPy.
  • _ad_{to,from}_petsc so it can be used with TAO.

I'm planning on integrating EnsembleFunction with EnsembleReducedFunctional in a subsequent PR.

Implementation notes

To build PETSc solvers over an Ensemble, we need a mechanism for gluing together two types of operation over the same data:

  1. Providing Firedrake with firedrake.Function objects on each ensemble member to evaluate local finite element operations.
  2. Providing PETSc with Vec objects on the global comm which represent the global solution, to pass to KSP, SNES, etc.

This is shown in the diagram below for an EnsembleFunction with 8 subcomponents, distributed equally over four ensemble members.

  • Internally all subcomponents on each ensemble member are represented as a (flattened) firedrake.MixedFunctionSpace (here).
  • The EnsembleFunction creates a firedrake.Function (called _full_local_function) from this mixed function space to store all local subcomponents. These are the dark blue boxes in the diagram, each with two subcomponents in light blue.
    (The API is independent of the choice to use MixedFunctionSpace internally).
  • The EnsembleFunction.subfunctions is a tuple of firedrake.Functions that view the relevant components of _full_local_function.subfunctions (see here). These are the light blue boxes on the right of the diagram, and will be the main way users interact with the data.
    For any elements of EnsembleFunction.subfunctions that are themselves mixed, we need to construct a new MixedDat to view multiple components of _full_local_function.subfunctions.
  • The Vec accessed via the EnsembleFunction.vec context managers is created over the Ensemble.global_comm, and is shown in dark green on the left of the diagram. This Vec is created as a view over the data in the internal mixed firedrake.Function by passing _full_local_function.dat.vec.array to Vec.createWithArray (here) (shown as the light green boxes).
    This means that for the data in EnsembleFunction.vec to be valid (the large dark green box), we need the data in _full_local_function.dat.vec to be valid (the smaller light green box attached to the MixedFunction). To ensure this data is valid, we nest the _full_local_function.dat.vec context managers inside the EnsembleFunction.vec context managers here.
    One thing to be careful of is that if the data in _full_local_function.dat.vec has been modified (i.e. if the user has done anything with EnsembleFunction.subfunctions!), then the data in the global EnsembleFunction.vec will have changed without EnsembleFunction.vec knowing. This means we have to manually increase the state counter for EnsembleFunction.vec and EnsembleFunction.vec_ro (a small petsc4py illustration follows the diagram).

[Diagram: asQ_datalayout_updated_fitted — data layout of an EnsembleFunction with 8 subcomponents distributed equally over four ensemble members]
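The view mechanism in the last bullet above can be illustrated with plain petsc4py. This is a sketch, not the PR's code; in particular, using stateIncrease to bump the state counter is my reading of how this would be done:

import numpy as np
from petsc4py import PETSc

# stand-in for _full_local_function.dat.vec.array on each rank
local_data = np.zeros(10)

# a Vec on the global communicator that views (does not copy) the local array
gvec = PETSc.Vec().createWithArray(local_data, comm=PETSc.COMM_WORLD)

# if local_data is modified behind PETSc's back (as happens whenever a user
# touches EnsembleFunction.subfunctions), the Vec's cached state is stale and
# the owner must bump the state counter manually
local_data[:] = 1.0
gvec.stateIncrease()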

import firedrake as fd
from firedrake.petsc import PETSc
from numpy.random import seed, random_sample
seed(6)

# Solve the all-at-once system for the heat equation with backwards Euler.
# If A1 = M + K and A0 = -M, then the lower-triangular and bidiagonal
# all-at-once Jacobian is:
# A1  0  0  0
# A0 A1  0  0
#  0 A0 A1  0
#  0  0 A0 A1
# The preconditioner is block diagonal using the A1 blocks

ensemble = fd.Ensemble(fd.COMM_WORLD, 1)
ensemble_rank = ensemble.ensemble_rank
ensemble_size = ensemble.ensemble_size

nx = 16
mesh = fd.UnitIntervalMesh(nx, comm=ensemble.comm)

dx = 1/nx
cfl = 2
dt = cfl*dx**2

V = fd.FunctionSpace(mesh, "CG", 1)

u = fd.TrialFunction(V)
v = fd.TestFunction(V)

# mass and stiffness matrices for heat equation
M = u*v*fd.dx
K = dt*fd.inner(fd.grad(u), fd.grad(v))*fd.dx

# backward Euler timestepping
A1 = M + K
A0 = -M

# how many timesteps on each ensemble rank
time_partition = [1, 2, 1, 5, 3]
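# (with Ensemble(COMM_WORLD, 1), the five entries above imply the script
#  is run on five MPI ranks, one per ensemble member)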

# function space for the time series
local_spaces = [V for _ in range(time_partition[ensemble_rank])]
W = fd.EnsembleFunctionSpace(local_spaces, ensemble)


# python mat context for the Jacobian
class AllAtOnceMat:
    def __init__(self):
        self.x = fd.EnsembleFunction(W)
        self.y = fd.EnsembleCofunction(W.dual())

        self.xhalo = fd.Function(V)
        self.yhalo = fd.Cofunction(V.dual())
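        # xhalo receives the final timestep of the previous ensemble member;
        # yhalo is scratch space for assembling the off-diagonal contributions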

    def mult(self, A, x, y):
        # copy input into ensemble function
        with self.x.vec_wo() as xvec:
            x.copy(xvec)

        # receive the time-halos from the previous ensemble rank
        self.update_halos()
        self.y.zero()
        self.yhalo.zero()

        # contribution of diagonal blocks A1
        for xsub, ysub in zip(self.x.subfunctions,
                              self.y.subfunctions):
            fd.assemble(fd.action(A1, xsub),
                        tensor=ysub)

        # contribution of local sub-diagonal blocks A0
        for xsub, ysub in zip(self.x.subfunctions[:-1],
                              self.y.subfunctions[1:]):
            fd.assemble(fd.action(A0, xsub),
                        tensor=self.yhalo)
            ysub += self.yhalo

        # contribution of halo sub-diagonal blocks A0
        if ensemble_rank != 0:
            fd.assemble(fd.action(A0, self.xhalo),
                        tensor=self.yhalo)
            ysub = self.y.subfunctions[0]
            ysub += self.yhalo

        # copy result back out
        with self.y.vec_ro() as yvec:
            yvec.copy(y)

    def update_halos(self):
        # halo swap is a right shift
        src = (ensemble_rank - 1) % ensemble_size
        dst = (ensemble_rank + 1) % ensemble_size
        frecv = self.xhalo
        fsend = self.x.subfunctions[-1]

        ensemble.sendrecv(fsend=fsend, dest=dst, sendtag=dst,
                          frecv=frecv, source=src, recvtag=ensemble_rank)


# python pc context for a block-diagonal preconditioner with A1
class BJacobiPC:
    def __init__(self):
        self.x = fd.EnsembleCofunction(W.dual())
        self.y = fd.EnsembleFunction(W)

        self.blocks = []
        for i in range(W.nlocal_spaces):
            global_step = i + sum(time_partition[:ensemble_rank])

            # use the subcomponents directly here; this avoids additional shuffling in apply
            problem = fd.LinearVariationalProblem(
                A1, self.x.subfunctions[i],
                self.y.subfunctions[i])

            solver = fd.LinearVariationalSolver(
                problem, options_prefix=f"step_{global_step}")

            self.blocks.append(solver)

    def apply(self, pc, x, y):
        with self.x.vec_wo() as xvec:
            x.copy(xvec)

        for block in self.blocks:
            block.solve()

        with self.y.vec_ro() as yvec:
            yvec.copy(y)


# Set up the PETSc objects
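# matrix sizes: (rows owned by this MPI rank, global rows over the whole ensemble)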
sizes = (W.nlocal_rank_dofs, W.nglobal_dofs)
mat = PETSc.Mat().createPython(
    (sizes, sizes), AllAtOnceMat(),
    comm=ensemble.global_comm)
mat.setUp()

ksp = PETSc.KSP().create(comm=ensemble.global_comm)
ksp.setOptionsPrefix("")
ksp.setType("richardson")
ksp.setOperators(mat)
ksp.setFromOptions()
ksp.setUp()

ksp.pc.setType("python")
ksp.pc.setPythonContext(BJacobiPC())

# the right-hand side and solution functions
u = fd.EnsembleFunction(W)
b = fd.EnsembleCofunction(W.dual())

with u.vec_wo() as sol, b.vec() as rhs:
    rhs.array[:] = ensemble_rank*random_sample(rhs.array.shape)
    ksp.solve(rhs, sol)

# now the timesteps are accessible via subfunctions
if ensemble_rank == ensemble_size - 1:
    final_timestep = u.subfunctions[-1]

github-actions bot commented Feb 10, 2025

Firedrake complex: 8110 ran, 6567 passed, 1543 skipped, 0 failed

github-actions bot commented Feb 10, 2025

Firedrake real: 8040 ran, 7331 passed, 709 skipped, 0 failed

@connorjward (Contributor) left a comment

Sorry for all the nitpicks but this is a big addition to the API so we should try and follow what we claim are our best practices. In particular I think the docstrings need tweaking as per https://github.com/firedrakeproject/firedrake/wiki/Docstrings

@@ -0,0 +1,242 @@
from firedrake.petsc import PETSc

Missing __all__ (though I dislike the design pattern). You can avoid this by not using a wildcard import in ensemble/__init__.py.

from .checkpointing import disk_checkpointing

from functools import wraps


Missing __all__

Outdated review comments (now resolved) on:
  • firedrake/adjoint_utils/ensemble_function.py (3)
  • tests/firedrake/ensemble/test_ensemble_function.py (4)
  • tests/firedrake/ensemble/test_ensemble_functionspace.py (1)
@JHopeCollins self-assigned this Feb 12, 2025