Skip to content

Commit

Permalink
Initial Phaser port
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz committed Nov 17, 2023
1 parent dbb49cb commit 48dd04c
Show file tree
Hide file tree
Showing 18 changed files with 1,071 additions and 2 deletions.
58 changes: 58 additions & 0 deletions examples/analog/gradients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# This example shows how to calculate gradients
from __future__ import annotations

import jax
import jax.numpy as jnp
from jax import random
from phaser import simulate
from phaser.models import RydbergHamiltonian
from phaser.utils import init_state

key = random.PRNGKey(42)

# Initializing Hamiltonian
n_qubits = 15
dt, N = 1e-3, 3000
laser_params = (1.0, 2.0)
U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2)
in_state = init_state(n_qubits)


def laser(laser_params, t):
(w_rabi, w_detune) = laser_params
return {
"rabi": 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t),
"detune": 15.0 * jnp.cos(2 * jnp.pi * w_detune * t),
}


hamiltonian = RydbergHamiltonian(n_qubits, U)
hamiltonian_params = hamiltonian.init(
key,
in_state,
laser(laser_params, 0),
)


# We take the gradient of some random state w.r.t the laser params and interaction_matrix
def forward(laser_params, hamiltonian_params):
out_state = simulate(
hamiltonian,
hamiltonian_params,
laser,
laser_params,
N,
dt,
in_state,
)
return (jnp.abs(out_state) ** 2).flatten()[-1]


# Getting the gradient fn w.r.t. both the pulse and interaction matrix and printing the grads
# Note that we jit (compile) the function so the timing here includes compiling
# but this only needs to happen once
grad_fn = jax.jit(jax.grad(forward, argnums=[0, 1]))
laser_grads, interaction_grads = grad_fn(laser_params, hamiltonian_params)

print(f"Gradients w.r.t laser params: \n {laser_grads}")
print(f"Gradients w.r.t interaction matrix: \n {interaction_grads}")
88 changes: 88 additions & 0 deletions examples/analog/introduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# This example shows how to build a model hamiltonian and simulate it.
from __future__ import annotations

from time import time

import flax.linen as nn
import jax.numpy as jnp
import numpy as np
from chex import Array
from jax import random
from phaser.hamiltonians import Interaction, Number, Pauli_x
from phaser.propagators import second_order_trotter
from phaser.simulate import simulate
from phaser.utils import init_state, kron_sum

key = random.PRNGKey(42)


class RydbergHamiltonian(nn.Module):
n_qubits: int
U: Array

def setup(self):
# Rabi terms
H_rabi = [Pauli_x((idx,), None) for idx in np.arange(self.n_qubits)]

# Detuning terms
H_detune = [Number((idx,), None) for idx in np.arange(self.n_qubits)]

# Interaction term
# We don't want to learn U here so it's just a matrix
self.U_params = self.U[np.triu_indices_from(self.U, k=1)]
H_interact = [Interaction(idx, None) for idx in zip(*np.triu_indices_from(self.U, k=1))]

# Joining all terms
self.H = H_rabi + H_detune + H_interact

def __call__(self, state, weights):
weights = jnp.concatenate([weights["rabi"] / 2, -weights["detune"], self.U_params])
return kron_sum(self.H, state, weights)

def evolve(self, state: Array, dt: float, weights: dict):
# Getting weights into same shape
weights = jnp.concatenate([weights["rabi"] / 2, -weights["detune"], self.U_params])
return second_order_trotter(self.H, state, dt, weights)


# Initializing Hamiltonian
n_qubits = 15
dt, N = 1e-3, 3000
laser_params = (1.0, 2.0)
U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2)
in_state = init_state(n_qubits)


# We call it laser here but it's just a function which takes in 1) some parameters and 2) the time of the simulation
# and returns the parameter values of the hamiltonian. So it's really just a way to simulate time dependent hamiltonians.
def laser(laser_params, t):
(w_rabi, w_detune) = laser_params
return {
"rabi": jnp.full((n_qubits,), 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t)),
"detune": jnp.full((n_qubits,), 15.0 * jnp.cos(2 * jnp.pi * w_detune * t)),
}


hamiltonian = RydbergHamiltonian(n_qubits, U)
hamiltonian_params = hamiltonian.init(
key,
in_state,
laser(laser_params, 0),
)


# Timing
start = time()
_ = simulate(
hamiltonian,
hamiltonian_params,
laser,
laser_params,
N,
dt,
in_state,
).block_until_ready()
stop = time()

print(f"Simulation time for {n_qubits} qubits and {N} steps: {stop - start}s")
print("Note that for clarity we didn't jit the final function, so compilation time is included.")
121 changes: 121 additions & 0 deletions examples/analog/making_efficient_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# This shows how to build an efficient model using diagonalization
from __future__ import annotations

from time import time

import flax.linen as nn
import jax.numpy as jnp
import numpy as np
from chex import Array
from jax import random
from phaser.diagonal import diagonal_onebody_hamiltonian, diagonal_twobody_hamiltonian
from phaser.hamiltonians import HamiltonianTerm, Pauli_x, n
from phaser.propagators import second_order_trotter
from phaser.simulate import simulate
from phaser.utils import init_state, kron_sum

key = random.PRNGKey(42)


# Defining diagonal detuning
def diagonal_detune_H(idx, weights):
return diagonal_onebody_hamiltonian(n, weights, idx)


def diagonal_detune_expm(idx, weights):
return jnp.exp(-1j * diagonal_detune_H(idx, weights))


DiagonalDetune = HamiltonianTerm.create(diagonal_detune_H, diagonal_detune_expm)


# Interaction
def diagonal_interaction_H(idx, weights):
return diagonal_twobody_hamiltonian((n, n), weights, idx)


def diagonal_interaction_expm(idx, weights):
return jnp.exp(-1j * diagonal_interaction_H(idx, weights))


DiagonalInteraction = HamiltonianTerm.create(diagonal_interaction_H, diagonal_interaction_expm)


def generate_interaction(U):
U_params = jnp.stack(U[np.triu_indices_from(U, k=1)])
idx = tuple(zip(*np.triu_indices_from(U, k=1)))

return DiagonalInteraction(idx, lambda key: U_params)


class DiagonalRydbergHamiltonian(nn.Module):
n_qubits: int
U: Array

def setup(self):
# Rabi terms
H_rabi = [Pauli_x((idx,), None) for idx in range(self.n_qubits)]

# Detuning
H_detune = DiagonalDetune(range(self.n_qubits), None)

# Interaction term
H_interact = generate_interaction(self.U)

# Joining all terms
self.H = [*H_rabi, H_detune, H_interact]

def __call__(self, state, weights):
return kron_sum(self.H, state, self.parse_weights(weights))

def evolve(self, state: Array, dt: float, weights: dict):
return second_order_trotter(self.H, state, dt, self.parse_weights(weights))

def parse_weights(self, weights):
# Parse the weights from tuple to correct shape and values
return [
*jnp.full((self.n_qubits,), weights["rabi"] / 2),
jnp.full((self.n_qubits,), -weights["detune"]),
None,
]


if __name__ == "__main__":
# Initializing Hamiltonian
n_qubits = 20
dt, N = 1e-3, 3000
laser_params = (1.0, 2.0)
U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2)
in_state = init_state(n_qubits)

def laser(laser_params, t):
(w_rabi, w_detune) = laser_params
return {
"rabi": 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t),
"detune": 15.0 * jnp.cos(2 * jnp.pi * w_detune * t),
}

hamiltonian = DiagonalRydbergHamiltonian(n_qubits, U)
hamiltonian_params = hamiltonian.init(
key,
in_state,
laser(laser_params, 0),
)

# Timing
start = time()
_ = simulate(
hamiltonian,
hamiltonian_params,
laser,
laser_params,
N,
dt,
in_state,
).block_until_ready()
stop = time()

print(f"Simulation time for {n_qubits} qubits and {N} steps: {stop - start}s")
print(
"Note that for clarity we didn't jit the final function, so compilation time is included."
)
3 changes: 3 additions & 0 deletions horqrux/phaser/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from __future__ import annotations

from .simulate import simulate
44 changes: 44 additions & 0 deletions horqrux/phaser/diagonal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from functools import reduce
from itertools import chain

import jax.numpy as jnp
from chex import Array

from .utils import diagonal_kronecker, kron_AI, kron_IA


def diagonal_onebody_hamiltonian(Hi: Array, weights: Array, idx: list[int]) -> Array:
# Generates diagonal of diagonal onebody hamiltonian terms.
# Not pretty but it works...
def diagonal_Hi(diagonal: Array, idx: int) -> Array:
return kron_IA(kron_AI(diagonal, 2 ** (n_qubits - idx - 1)), 2**idx)

n_qubits = max(idx) + 1 # +1 cause of index
Hi_diag = jnp.diag(Hi)
return reduce(
lambda state, x: state + x[0] * diagonal_Hi(Hi_diag, x[1]),
zip(weights, idx),
jnp.zeros(2**n_qubits),
)


def diagonal_twobody_hamiltonian(
HiHj: tuple[Array, Array], weights: Array, idx: list[tuple[int, int]]
) -> Array:
# Generates diagonal of diagonal two-body hamiltonian terms.
# Not pretty but it works...
def diagonal_Hi(diagonal: list[Array], idx_ij: tuple[int, int]) -> Array:
idx_i, idx_j = idx_ij
left = kron_IA(diagonal[0], 2 ** (idx_i))
right = kron_IA(kron_AI(diagonal[1], 2 ** (n_qubits - idx_j - 1)), 2 ** (idx_j - idx_i - 1))
return diagonal_kronecker(left, right)

n_qubits = max(list(chain(*idx))) + 1 # +1 cause of index
HiHj_diag = [jnp.diag(H) for H in HiHj]
return reduce(
lambda state, x: state + x[0] * diagonal_Hi(HiHj_diag, x[1]),
zip(weights, idx),
jnp.zeros(2**n_qubits),
)
Loading

0 comments on commit 48dd04c

Please sign in to comment.