A minimal JAX library for connectivity analysis at scales

vboussange/jaxscape

JAXScape is a minimal JAX library for connectivity analysis at scales. It provides key utilities for building your own connectivity analysis workflow, including:

  • differentiable raster to graph and graph to raster mappings
  • differentiable graph distance metrics
  • moving window utilities

JAXScape leverages JAX's capabilities to accelerate distance computations on CPUs/GPUs/TPUs, while ensuring differentiability of all implemented classes and methods for awesome sensitivity analysis and optimization.
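To make the idea of a differentiable raster to graph mapping concrete, here is a minimal sketch in plain JAX. It is not JAXScape's implementation: the helpers raster_to_adjacency and graph_to_raster are hypothetical, and the sketch assumes 4-neighbour connectivity, edge weights equal to the mean of the two pixel weights, and that every pixel is an active vertex.

```python
import jax
import jax.numpy as jnp


def raster_to_adjacency(vertex_weights):
    """Dense adjacency matrix of a 4-neighbour grid graph (hypothetical helper).

    Pixels are numbered in row-major order; the weight of the edge between two
    neighbouring pixels is assumed to be the mean of their vertex weights.
    """
    nrows, ncols = vertex_weights.shape
    n = nrows * ncols
    w = vertex_weights.ravel()
    idx = jnp.arange(n).reshape(nrows, ncols)
    # horizontally adjacent pairs, then vertically adjacent pairs
    src = jnp.concatenate([idx[:, :-1].ravel(), idx[:-1, :].ravel()])
    dst = jnp.concatenate([idx[:, 1:].ravel(), idx[1:, :].ravel()])
    edge_w = 0.5 * (w[src] + w[dst])
    adjacency = jnp.zeros((n, n)).at[src, dst].set(edge_w)
    # undirected graph: mirror the edges to get a symmetric matrix
    return adjacency.at[dst, src].set(edge_w)


def graph_to_raster(node_values, shape):
    """Map per-node values back onto the raster (valid when all pixels are active)."""
    return jnp.reshape(node_values, shape)


weights = jnp.array([[1.0, 2.0],
                     [3.0, 4.0]])
A = raster_to_adjacency(weights)
# gradient of the summed adjacency matrix w.r.t. each pixel weight
sens = jax.grad(lambda v: raster_to_adjacency(v).sum())(weights)
```

Because the mapping only uses jnp gathers and scatter updates (.at[...].set), jax.grad propagates through it. In this toy case the gradient of the summed adjacency matrix with respect to each pixel weight equals the pixel's number of neighbours, which is 2 for every pixel of a 2x2 grid.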

Installation

pip install git+https://github.com/vboussange/jaxscape.git

Quick start

Let's define our graph.

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jaxscape.gridgraph import GridGraph

# loading jax array representing permeability
permeability = jnp.array(np.loadtxt("permeability.csv", delimiter=","))

# we discard pixels with permeability equal to 0
activities = permeability > 0
plt.imshow(permeability, cmap="gray")
plt.axis("off")

grid = GridGraph(activities=activities, vertex_weights=permeability)
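The snippet above reads permeability from a permeability.csv file. If you only want to try the workflow, a synthetic raster can be written to that file with NumPy. Everything here is an arbitrary illustration, not project data: the 50x50 size, the smoothing, and the roughly 5% of zero-permeability pixels that will end up as inactive vertices.

```python
import numpy as np

rng = np.random.default_rng(0)
nrows, ncols = 50, 50

# random field, crudely smoothed by averaging each pixel with its 4 neighbours
# (np.roll wraps around, i.e. periodic borders)
field = rng.normal(size=(nrows, ncols))
for _ in range(5):
    field = (field
             + np.roll(field, 1, axis=0) + np.roll(field, -1, axis=0)
             + np.roll(field, 1, axis=1) + np.roll(field, -1, axis=1)) / 5

# logistic transform to get permeabilities in (0, 1)
permeability = 1 / (1 + np.exp(-3 * field))
# arbitrarily mark about 5% of the pixels as impermeable (value 0)
permeability[rng.random((nrows, ncols)) < 0.05] = 0.0

np.savetxt("permeability.csv", permeability, delimiter=",")
```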

Let's calculate some distances on the grid graph. Specifically, we compute the distance from every pixel to the top-left pixel and project it back onto the raster.

from jaxscape.resistance_distance import ResistanceDistance
from jaxscape.lcp_distance import LCPDistance
from jaxscape.rsp_distance import RSPDistance

# index of the top-left pixel, used as the source for all distance calculations
source = grid.coord_to_active_vertex_index([0], [0])

distances = {
    "LCP distance": LCPDistance(),
    "Resistance distance": ResistanceDistance(),
    "RSP distance": RSPDistance(theta=0.01, cost=lambda x: 1 / x)
}

fig, axs = plt.subplots(1, 3, figsize=(10, 4))
for ax, (title, distance) in zip(axs, distances.items()):
    dist_to_node = distance(grid, source)
    cbar = ax.imshow(grid.node_values_to_array(dist_to_node.ravel()), cmap="magma")
    ax.axis("off")
    ax.set_title(title)
    fig.colorbar(cbar, ax=ax, shrink=0.2)

fig.suptitle("Distance to top left pixel")
plt.tight_layout()
plt.show()

[Figure: Distances]
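In the figure, LCPDistance is evaluated one-to-all from a single source, a case the roadmap below associates with Bellman-Ford. The following sketch shows Bellman-Ford relaxation in JAX on a dense cost matrix. It is an illustration under stated assumptions, not LCPDistance's code: the bellman_ford helper is hypothetical, and the dense representation only suits small graphs.

```python
import jax
import jax.numpy as jnp


def bellman_ford(cost, source):
    """One-to-all least-cost path distances (hypothetical helper).

    cost[i, j] is the cost of the edge i -> j, with jnp.inf where there is no edge.
    """
    n = cost.shape[0]
    dist = jnp.full(n, jnp.inf).at[source].set(0.0)

    def relax(_, d):
        # d_new[j] = min(d[j], min_i d[i] + cost[i, j])
        return jnp.minimum(d, jnp.min(d[:, None] + cost, axis=0))

    # n - 1 rounds suffice for non-negative costs: a shortest path has at most n - 1 edges
    return jax.lax.fori_loop(0, n - 1, relax, dist)


inf = jnp.inf
# toy undirected graph: 0 -(1)- 1 -(2)- 2, plus a direct edge 0 -(5)- 2
cost = jnp.array([[inf, 1.0, 5.0],
                  [1.0, inf, 2.0],
                  [5.0, 2.0, inf]])
dist = bellman_ford(cost, source=0)  # 0 -> 2 goes through node 1 (cost 3 < 5)
```

Each relaxation round is a vectorised min over all edges, which maps well onto accelerators; realistic raster sizes would call for a sparse edge list instead of a dense matrix.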

But what's really cool about jaxscape is that you can autodiff through those distances! Here we calculate the gradient of the average path length of the graph w.r.t. pixel permeability.

import jax
import equinox as eqx

# we need to provide the number of active vertices, for jit compilation
@eqx.filter_jit
def average_path_length(permeability, activities, nb_active, distance):
    grid = GridGraph(activities=activities, 
                     vertex_weights=permeability,
                     nb_active=nb_active)
    dist = distance(grid)
    return dist.sum() / nb_active**2

grad_connectivity = jax.grad(average_path_length)
nb_active = int(activities.sum())


distance = LCPDistance()
apl = average_path_length(permeability, activities, nb_active, distance)
print("Average path length:", apl)


sensitivities = grad_connectivity(permeability, activities, nb_active, distance)
plt.figure()
cbar = plt.imshow(sensitivities, cmap="magma")
plt.title("Gradient of APL w.r.t. pixel permeability")
plt.colorbar(cbar)
plt.show()

[Figure: Sensitivities]
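Here is a self-contained sketch of the same idea on a toy graph, without JAXScape: it differentiates the average all-pairs least-cost distance with respect to node permeability. It assumes a dense Floyd-Warshall (the all-to-all LCP variant listed in the roadmap) and edge costs equal to the mean resistance 1/permeability of the two endpoints; the graph and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def floyd_warshall(cost):
    """All-to-all least-cost distances from a dense cost matrix (jnp.inf = no edge)."""
    n = cost.shape[0]
    dist = jnp.where(jnp.eye(n, dtype=bool), 0.0, cost)

    def relax_through(k, d):
        # allow paths that pass through intermediate node k
        return jnp.minimum(d, d[:, k][:, None] + d[k, :][None, :])

    return jax.lax.fori_loop(0, n, relax_through, dist)


def toy_average_path_length(permeability, adjacency):
    """Mean of all pairwise distances; edge cost = mean resistance of its endpoints."""
    resistance = 1.0 / permeability
    edge_cost = 0.5 * (resistance[:, None] + resistance[None, :])
    cost = jnp.where(adjacency, edge_cost, jnp.inf)
    n = permeability.shape[0]
    return floyd_warshall(cost).sum() / n**2


# hypothetical 3-node path graph 0 - 1 - 2
adjacency = jnp.array([[False, True, False],
                       [True, False, True],
                       [False, True, False]])
permeability = jnp.array([1.0, 0.5, 1.0])

apl = toy_average_path_length(permeability, adjacency)
# gradient w.r.t. the first argument (permeability) only
sensitivities = jax.grad(toy_average_path_length)(permeability, adjacency)
```

Since jnp.minimum has a well-defined gradient, jax.grad routes the sensitivity along the selected least-cost paths. Here the middle node gets the most negative gradient, because it lies on every path between the two end nodes and has the lowest permeability.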

For a more advanced example, with windowed sensitivity analysis and dispatch across multiple GPUs, see benchmark/moving_window_*.py.
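The scripts mentioned above are not reproduced here, but the basic moving window idea can be sketched as follows. moving_windows is a hypothetical NumPy helper, not JAXScape's generator: it tiles a raster into square core windows and extends each by a buffer, so that connectivity just outside the core is still taken into account. Zero padding at the raster edges is an assumption, chosen so that padded pixels are inactive under the permeability > 0 convention used above.

```python
import numpy as np


def moving_windows(raster, window_size, buffer_size):
    """Yield (row, col, window) for each buffered window of a 2D raster (hypothetical helper).

    Each window_size x window_size core is extended by buffer_size pixels on
    every side; the raster is zero-padded so border windows keep that extent.
    Windows at the far edges are smaller if the raster size is not a multiple
    of window_size.
    """
    padded = np.pad(raster, buffer_size, mode="constant")
    nrows, ncols = raster.shape
    total = window_size + 2 * buffer_size
    for i in range(0, nrows, window_size):
        for j in range(0, ncols, window_size):
            # core pixel (i, j) sits at (i + buffer_size, j + buffer_size) in `padded`,
            # so the buffered window starts at (i, j) in padded coordinates
            yield i, j, padded[i:i + total, j:j + total]


windows = list(moving_windows(np.ones((6, 6)), window_size=3, buffer_size=1))
```

When the raster dimensions are multiples of window_size, all windows share the same shape. This matters in JAX because jit recompiles a function for every new input shape; uniformly shaped windows can also be stacked and processed with jax.vmap.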

Features and roadmap 🚀

Raster to graphs

  • GridGraph with differentiable adjacency matrix method

Distances

  • Least-cost path

    • Bellman-Ford (one-to-all)
    • Floyd-Warshall (all-to-all)
    • Differentiable Dijkstra or A* (see implementation here)
  • Resistance distance

    • all-to-all calculation with dense solver (pinv, resulting in full distance matrix materialization)
    • [-] advanced mode with direct solvers (Laplacian factorization, does not scale to large landscapes)
      • Must rely on lineax, with wrappers over specialized solvers for sparse systems:
        • UMFPACK and CHOLMOD (see implementation here, where scipy.sparse.linalg.spsolve is wrapped in JAX and a VJP has been implemented; this could also work with CHOLMOD) 🏃‍♀️
        • jax.experimental.sparse.linalg.spsolve
    • advanced mode with indirect solvers (no Laplacian factorization, requires preconditioning)
  • Randomized shortest path distance (REF)

    • all-to-all calculation (distance matrix materialization)
    • [-] all-to-few calculation
      • Should be based on direct or indirect solvers, similarly to ResistanceDistance
    • one-to-one calculation
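The dense approach listed first for resistance distance can be illustrated with the classical identity R_ij = L⁺_ii + L⁺_jj − 2 L⁺_ij, where L⁺ is the Moore-Penrose pseudo-inverse of the graph Laplacian. The sketch below is a minimal stand-alone version assuming a dense, symmetric conductance matrix of a connected graph; the resistance_distance function is hypothetical and is not ResistanceDistance's code. Its O(n²) memory and O(n³) cost are precisely why the roadmap lists sparse direct and indirect solvers as next steps.

```python
import jax.numpy as jnp


def resistance_distance(conductance):
    """All-to-all resistance distances from a dense symmetric conductance matrix.

    Dense-solver approach: pseudo-inverse of the graph Laplacian, which
    materialises the full n x n distance matrix.
    """
    laplacian = jnp.diag(conductance.sum(axis=1)) - conductance
    lp = jnp.linalg.pinv(laplacian)
    d = jnp.diag(lp)
    # R_ij = L+_ii + L+_jj - 2 L+_ij
    return d[:, None] + d[None, :] - 2 * lp


# path graph 0 - 1 - 2 with unit conductances: resistors in series
path = jnp.array([[0.0, 1.0, 0.0],
                  [1.0, 0.0, 1.0],
                  [0.0, 1.0, 0.0]])
R_path = resistance_distance(path)

# triangle with unit conductances: 1 ohm in parallel with 2 ohm gives 2/3
triangle = jnp.ones((3, 3)) - jnp.eye(3)
R_triangle = resistance_distance(triangle)
```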

Utilities

  • Moving window generator
  • Differentiable connected component algorithm (see here)
    • An external call to SciPy/cuSPARSE connected component libraries could do for our purposes (but would not support differentiation)
    • see the JAX documentation on pure callbacks, and a cool concrete example here
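The pure-callback route mentioned above could look like the following minimal sketch, which wraps scipy.sparse.csgraph.connected_components with jax.pure_callback so that it can be called from jit-compiled code. The function names are hypothetical. As noted in the list, gradients do not flow through the callback; the output is an integer label array anyway.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components


def _scipy_labels(adjacency):
    # runs on the host with NumPy/SciPy, outside of JAX tracing
    _, labels = connected_components(csr_matrix(np.asarray(adjacency)), directed=False)
    return labels.astype(np.int32)


@jax.jit
def component_labels(adjacency):
    """Connected-component label of each node (not differentiable)."""
    out_shape = jax.ShapeDtypeStruct((adjacency.shape[0],), jnp.int32)
    return jax.pure_callback(_scipy_labels, out_shape, adjacency)


# two components: {0, 1} and {2, 3}
adjacency = jnp.array([[0.0, 1.0, 0.0, 0.0],
                       [1.0, 0.0, 0.0, 0.0],
                       [0.0, 0.0, 0.0, 1.0],
                       [0.0, 0.0, 1.0, 0.0]])
labels = component_labels(adjacency)
```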

Benchmark

  • scaling with number of nodes, CPU/GPU
  • Moving window tests
  • benchmark against Circuitscape and ConScape (Julia-based)

License

jaxscape is distributed under the terms of the MIT license.

Related packages

  • gdistance
  • ConScape
  • Circuitscape
  • graphhab
  • conefor
