JAXScape is a minimal JAX library for connectivity analysis at scale. It provides key utilities to build your own connectivity analysis workflow, including:
- differentiable raster to graph and graph to raster mappings
- differentiable graph distance metrics
- moving window utilities
JAXScape leverages JAX's capabilities to accelerate distance computations on CPUs/GPUs/TPUs, while ensuring differentiability of all implemented classes and methods for awesome sensitivity analysis and optimization.
pip install git+https://github.com/vboussange/jaxscape.git
Let's define our graph.
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jaxscape.gridgraph import GridGraph
# loading jax array representing permeability
permeability = jnp.array(np.loadtxt("permeability.csv", delimiter=","))
# we discard pixels with permeability equal to 0
activities = permeability > 0
plt.imshow(permeability, cmap="gray")
plt.axis("off")
grid = GridGraph(activities=activities, vertex_weights=permeability)
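As a rough illustration of what a raster-to-graph mapping involves, the sketch below builds a dense adjacency matrix for a 4-neighbour grid from a permeability raster in plain JAX. It is not JAXScape's implementation: the function name `rook_adjacency`, the choice of the mean permeability of the two end pixels as edge weight, and the dense matrix layout are all assumptions made for this example.

```python
import jax.numpy as jnp


def rook_adjacency(permeability):
    """Dense adjacency matrix of a 4-neighbour grid graph (illustrative only).

    Assumption: an edge weight is the mean permeability of its two end
    pixels, and any edge touching a zero-permeability pixel gets weight 0.
    """
    nrows, ncols = permeability.shape
    n = nrows * ncols
    idx = jnp.arange(n).reshape(nrows, ncols)  # pixel (i, j) -> vertex id

    # Pairs of horizontally adjacent, then vertically adjacent pixels.
    src = jnp.concatenate([idx[:, :-1].ravel(), idx[:-1, :].ravel()])
    dst = jnp.concatenate([idx[:, 1:].ravel(), idx[1:, :].ravel()])

    flat = permeability.ravel()
    active = (flat[src] > 0) & (flat[dst] > 0)
    weights = jnp.where(active, 0.5 * (flat[src] + flat[dst]), 0.0)

    upper = jnp.zeros((n, n)).at[src, dst].set(weights)
    return upper + upper.T  # symmetric, i.e. undirected graph


# Toy 2 x 2 raster; the bottom-left pixel is impermeable.
toy_permeability = jnp.array([[1.0, 2.0], [0.0, 4.0]])
A = rook_adjacency(toy_permeability)
```

Because the edge weights are computed with ordinary array operations, the matrix remains a differentiable function of the raster, which is the property the library relies on.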
Let's calculate some distances on the grid graph. Specifically, we compute the distance from every pixel to the top-left pixel and project it back onto the raster.
from jaxscape.resistance_distance import ResistanceDistance
from jaxscape.lcp_distance import LCPDistance
from jaxscape.rsp_distance import RSPDistance
# Calculating distances of all pixels to top left pixel
source = grid.coord_to_active_vertex_index([0], [0])
distances = {
    "LCP distance": LCPDistance(),
    "Resistance distance": ResistanceDistance(),
    "RSP distance": RSPDistance(theta=0.01, cost=lambda x: 1 / x),
}
fig, axs = plt.subplots(1, 3, figsize=(10, 4))
for ax, (title, distance) in zip(axs, distances.items()):
    dist_to_node = distance(grid, source)
    cbar = ax.imshow(grid.node_values_to_array(dist_to_node.ravel()), cmap="magma")
    ax.axis("off")
    ax.set_title(title)
    fig.colorbar(cbar, ax=ax, shrink=0.2)
fig.suptitle("Distance to top left pixel")
plt.tight_layout()
plt.show()
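The least-cost path distance used above is a one-to-all computation. A minimal Bellman-Ford sketch in plain JAX gives a feel for it, assuming a small dense cost matrix. JAXScape's `LCPDistance` may well be implemented differently; the function `bellman_ford` and the toy cost matrix are hypothetical.

```python
import jax
import jax.numpy as jnp


def bellman_ford(cost, source):
    """One-to-all least-cost distances on a dense cost matrix.

    cost[i, j] is the cost of edge i -> j, with jnp.inf where no edge exists.
    """
    n = cost.shape[0]
    dist0 = jnp.full((n,), jnp.inf).at[source].set(0.0)

    def relax(_, dist):
        # Cheapest way to reach each vertex j through any predecessor i.
        via_neighbour = jnp.min(dist[:, None] + cost, axis=0)
        return jnp.minimum(dist, via_neighbour)

    # n - 1 relaxation rounds suffice when there are no negative cycles.
    return jax.lax.fori_loop(0, n - 1, relax, dist0)


inf = jnp.inf
toy_cost = jnp.array([
    [inf, 1.0, 4.0, inf],
    [1.0, inf, 2.0, 5.0],
    [4.0, 2.0, inf, 1.0],
    [inf, 5.0, 1.0, inf],
])
dist = bellman_ford(toy_cost, 0)
```

The dense formulation costs O(n^2) memory and is only meant to show the relaxation step; it would not scale to large landscapes.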
But what's really cool about JAXScape is that you can autodiff through those distances! Here we calculate the gradient of the average path length (APL) of the graph w.r.t. pixel permeability.
import equinox as eqx
import jax

# we need to provide the number of active vertices, for jit compilation
@eqx.filter_jit
def average_path_length(permeability, activities, nb_active, distance):
    grid = GridGraph(activities=activities,
                     vertex_weights=permeability,
                     nb_active=nb_active)
    dist = distance(grid)
    return dist.sum() / nb_active**2
grad_connectivity = jax.grad(average_path_length)
nb_active = int(activities.sum())
distance = LCPDistance()
average_path_length(permeability, activities, nb_active, distance)
sensitivities = grad_connectivity(permeability, activities, nb_active, distance)
plt.figure()
cbar = plt.imshow(sensitivities, cmap="magma")
plt.title("Gradient of APL w.r.t. pixel permeability")
plt.colorbar(cbar)
plt.show()
For a more advanced example with windowed sensitivity analysis and dispatch across multiple GPUs, see benchmark/moving_window_*.py.
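To see the same idea without any JAXScape code, here is a toy sketch that differentiates the mean resistance distance of a three-node path graph with respect to its two edge weights. It uses the dense pseudo-inverse formulation of the resistance distance; the function name `mean_resistance_distance` and the graph are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp


def mean_resistance_distance(edge_weights):
    """Mean pairwise resistance distance of the path graph 0 - 1 - 2.

    edge_weights holds the conductances of edges (0, 1) and (1, 2).
    """
    rows = jnp.array([0, 1])
    cols = jnp.array([1, 2])
    upper = jnp.zeros((3, 3)).at[rows, cols].set(edge_weights)
    conductance = upper + upper.T  # symmetric weight matrix

    laplacian = jnp.diag(conductance.sum(axis=1)) - conductance
    l_pinv = jnp.linalg.pinv(laplacian)
    diag = jnp.diag(l_pinv)
    # R[i, j] = L+[i, i] + L+[j, j] - 2 * L+[i, j]
    resistance = diag[:, None] + diag[None, :] - 2.0 * l_pinv
    return resistance.mean()


weights = jnp.array([1.0, 2.0])
value, grads = jax.value_and_grad(mean_resistance_distance)(weights)
```

For this graph the result can be checked by hand: resistances add in series, so the mean over all nine matrix entries is 4/9 * (1/w01 + 1/w12), and both gradient entries are negative. Raising the conductance of either edge improves connectivity.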
- GridGraph with differentiable adjacency matrix method
- Least-cost path
  - Bellman-Ford (one-to-all)
  - Floyd-Warshall (all-to-all)
  - Differentiable Dijkstra or A* (see implementation here)
- Resistance distance
  - all-to-all calculation with dense solver (pinv, resulting in full distance matrix materialization)
  - [-] advanced mode with direct solvers (Laplacian factorization, cannot scale to large landscapes)
    - Must rely on lineax, with a wrapper over specialized solvers for sparse systems:
      - UMFPACK and CHOLMOD (see implementation here, where scipy.sparse.linalg.spsolve is wrapped in JAX and a vjp has been implemented; could also work with CHOLMOD) 🏃♀️
      - jax.experimental.sparse.linalg.spsolve
  - advanced mode with indirect solvers (no Laplacian factorization, requires preconditioning)
    - GMRES/CG with preconditioners for Krylov-based solvers
    - See AlgebraicMultigrid.jl or PyAMG
    - See also lineax issues
- Randomized shortest path distance (REF)
  - all-to-all calculation (distance matrix materialization)
  - [-] all-to-few calculation
    - Should be based on direct or indirect solvers, similarly to ResistanceDistance
  - one-to-one calculation
- Moving window generator
- Differentiable connected component algorithm (see here)
  - An external call to scipy/cusparse connected component libraries could do for our purposes (but would not support differentiation)
    - see the JAX docs, the docs on pure callbacks, and a cool concrete example here
- scaling with number of nodes, CPU/GPU
- Moving window tests
- benchmark against Circuitscape and ConScape (Julia based)
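The connected-component item in the list above mentions an external call to SciPy through a JAX pure callback. A minimal sketch of that pattern follows, assuming a small dense adjacency matrix. The helper names `_host_labels` and `component_labels` are hypothetical and not part of JAXScape, and, as the item itself notes, the result is not differentiable.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components


def _host_labels(adjacency):
    # Runs on the host with NumPy/SciPy, outside of JAX tracing.
    graph = csr_matrix(np.asarray(adjacency))
    _, labels = connected_components(graph, directed=False)
    return labels.astype(np.int32)


@jax.jit
def component_labels(adjacency):
    """Label connected components through an external SciPy call."""
    # JAX needs the output shape and dtype up front to trace the callback.
    out_spec = jax.ShapeDtypeStruct((adjacency.shape[0],), jnp.int32)
    return jax.pure_callback(_host_labels, out_spec, adjacency)


# Two disconnected edges: vertices {0, 1} and {2, 3}.
toy_adjacency = jnp.array([
    [0.0, 1.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 1.0, 0.0],
])
labels = component_labels(toy_adjacency)
```

The callback moves data to the host on every call, so this pattern fits preprocessing steps better than inner loops on GPU.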
jaxscape is distributed under the terms of the MIT license.
- gdistance
- ConScape
- Circuitscape
- graphhab
- conefor