-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dbb49cb
commit 48dd04c
Showing
18 changed files
with
1,071 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# This example shows how to calculate gradients | ||
from __future__ import annotations | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import random | ||
from phaser import simulate | ||
from phaser.models import RydbergHamiltonian | ||
from phaser.utils import init_state | ||
|
||
key = random.PRNGKey(42) | ||
|
||
# Initializing Hamiltonian | ||
n_qubits = 15 | ||
dt, N = 1e-3, 3000 | ||
laser_params = (1.0, 2.0) | ||
U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2) | ||
in_state = init_state(n_qubits) | ||
|
||
|
||
def laser(laser_params, t): | ||
(w_rabi, w_detune) = laser_params | ||
return { | ||
"rabi": 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t), | ||
"detune": 15.0 * jnp.cos(2 * jnp.pi * w_detune * t), | ||
} | ||
|
||
|
||
hamiltonian = RydbergHamiltonian(n_qubits, U) | ||
hamiltonian_params = hamiltonian.init( | ||
key, | ||
in_state, | ||
laser(laser_params, 0), | ||
) | ||
|
||
|
||
# We take the gradient of some random state w.r.t the laser params and interaction_matrix | ||
def forward(laser_params, hamiltonian_params): | ||
out_state = simulate( | ||
hamiltonian, | ||
hamiltonian_params, | ||
laser, | ||
laser_params, | ||
N, | ||
dt, | ||
in_state, | ||
) | ||
return (jnp.abs(out_state) ** 2).flatten()[-1] | ||
|
||
|
||
# Getting the gradient fn w.r.t. both the pulse and interaction matrix and printing the grads | ||
# Note that we jit (compile) the function so the timing here includes compiling | ||
# but this only needs to happen once | ||
grad_fn = jax.jit(jax.grad(forward, argnums=[0, 1])) | ||
laser_grads, interaction_grads = grad_fn(laser_params, hamiltonian_params) | ||
|
||
print(f"Gradients w.r.t laser params: \n {laser_grads}") | ||
print(f"Gradients w.r.t interaction matrix: \n {interaction_grads}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# This example shows how to build a model hamiltonian and simulate it. | ||
from __future__ import annotations | ||
|
||
from time import time | ||
|
||
import flax.linen as nn | ||
import jax.numpy as jnp | ||
import numpy as np | ||
from chex import Array | ||
from jax import random | ||
from phaser.hamiltonians import Interaction, Number, Pauli_x | ||
from phaser.propagators import second_order_trotter | ||
from phaser.simulate import simulate | ||
from phaser.utils import init_state, kron_sum | ||
|
||
key = random.PRNGKey(42) | ||
|
||
|
||
class RydbergHamiltonian(nn.Module): | ||
n_qubits: int | ||
U: Array | ||
|
||
def setup(self): | ||
# Rabi terms | ||
H_rabi = [Pauli_x((idx,), None) for idx in np.arange(self.n_qubits)] | ||
|
||
# Detuning terms | ||
H_detune = [Number((idx,), None) for idx in np.arange(self.n_qubits)] | ||
|
||
# Interaction term | ||
# We don't want to learn U here so it's just a matrix | ||
self.U_params = self.U[np.triu_indices_from(self.U, k=1)] | ||
H_interact = [Interaction(idx, None) for idx in zip(*np.triu_indices_from(self.U, k=1))] | ||
|
||
# Joining all terms | ||
self.H = H_rabi + H_detune + H_interact | ||
|
||
def __call__(self, state, weights): | ||
weights = jnp.concatenate([weights["rabi"] / 2, -weights["detune"], self.U_params]) | ||
return kron_sum(self.H, state, weights) | ||
|
||
def evolve(self, state: Array, dt: float, weights: dict): | ||
# Getting weights into same shape | ||
weights = jnp.concatenate([weights["rabi"] / 2, -weights["detune"], self.U_params]) | ||
return second_order_trotter(self.H, state, dt, weights) | ||
|
||
|
||
# Initializing Hamiltonian | ||
n_qubits = 15 | ||
dt, N = 1e-3, 3000 | ||
laser_params = (1.0, 2.0) | ||
U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2) | ||
in_state = init_state(n_qubits) | ||
|
||
|
||
# We call it laser here but it's just a function which takes in 1) some parameters and 2) the time of the simulation | ||
# and returns the parameter values of the hamiltonian. So it's really just a way to simulate time dependent hamiltonians. | ||
def laser(laser_params, t): | ||
(w_rabi, w_detune) = laser_params | ||
return { | ||
"rabi": jnp.full((n_qubits,), 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t)), | ||
"detune": jnp.full((n_qubits,), 15.0 * jnp.cos(2 * jnp.pi * w_detune * t)), | ||
} | ||
|
||
|
||
hamiltonian = RydbergHamiltonian(n_qubits, U) | ||
hamiltonian_params = hamiltonian.init( | ||
key, | ||
in_state, | ||
laser(laser_params, 0), | ||
) | ||
|
||
|
||
# Timing | ||
start = time() | ||
_ = simulate( | ||
hamiltonian, | ||
hamiltonian_params, | ||
laser, | ||
laser_params, | ||
N, | ||
dt, | ||
in_state, | ||
).block_until_ready() | ||
stop = time() | ||
|
||
print(f"Simulation time for {n_qubits} qubits and {N} steps: {stop - start}s") | ||
print("Note that for clarity we didn't jit the final function, so compilation time is included.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# This shows how to build an efficient model using diagonalization | ||
from __future__ import annotations | ||
|
||
from time import time | ||
|
||
import flax.linen as nn | ||
import jax.numpy as jnp | ||
import numpy as np | ||
from chex import Array | ||
from jax import random | ||
from phaser.diagonal import diagonal_onebody_hamiltonian, diagonal_twobody_hamiltonian | ||
from phaser.hamiltonians import HamiltonianTerm, Pauli_x, n | ||
from phaser.propagators import second_order_trotter | ||
from phaser.simulate import simulate | ||
from phaser.utils import init_state, kron_sum | ||
|
||
key = random.PRNGKey(42) | ||
|
||
|
||
# Defining diagonal detuning | ||
def diagonal_detune_H(idx, weights): | ||
return diagonal_onebody_hamiltonian(n, weights, idx) | ||
|
||
|
||
def diagonal_detune_expm(idx, weights): | ||
return jnp.exp(-1j * diagonal_detune_H(idx, weights)) | ||
|
||
|
||
DiagonalDetune = HamiltonianTerm.create(diagonal_detune_H, diagonal_detune_expm) | ||
|
||
|
||
# Interaction | ||
def diagonal_interaction_H(idx, weights): | ||
return diagonal_twobody_hamiltonian((n, n), weights, idx) | ||
|
||
|
||
def diagonal_interaction_expm(idx, weights): | ||
return jnp.exp(-1j * diagonal_interaction_H(idx, weights)) | ||
|
||
|
||
DiagonalInteraction = HamiltonianTerm.create(diagonal_interaction_H, diagonal_interaction_expm) | ||
|
||
|
||
def generate_interaction(U): | ||
U_params = jnp.stack(U[np.triu_indices_from(U, k=1)]) | ||
idx = tuple(zip(*np.triu_indices_from(U, k=1))) | ||
|
||
return DiagonalInteraction(idx, lambda key: U_params) | ||
|
||
|
||
class DiagonalRydbergHamiltonian(nn.Module): | ||
n_qubits: int | ||
U: Array | ||
|
||
def setup(self): | ||
# Rabi terms | ||
H_rabi = [Pauli_x((idx,), None) for idx in range(self.n_qubits)] | ||
|
||
# Detuning | ||
H_detune = DiagonalDetune(range(self.n_qubits), None) | ||
|
||
# Interaction term | ||
H_interact = generate_interaction(self.U) | ||
|
||
# Joining all terms | ||
self.H = [*H_rabi, H_detune, H_interact] | ||
|
||
def __call__(self, state, weights): | ||
return kron_sum(self.H, state, self.parse_weights(weights)) | ||
|
||
def evolve(self, state: Array, dt: float, weights: dict): | ||
return second_order_trotter(self.H, state, dt, self.parse_weights(weights)) | ||
|
||
def parse_weights(self, weights): | ||
# Parse the weights from tuple to correct shape and values | ||
return [ | ||
*jnp.full((self.n_qubits,), weights["rabi"] / 2), | ||
jnp.full((self.n_qubits,), -weights["detune"]), | ||
None, | ||
] | ||
|
||
|
||
if __name__ == "__main__": | ||
# Initializing Hamiltonian | ||
n_qubits = 20 | ||
dt, N = 1e-3, 3000 | ||
laser_params = (1.0, 2.0) | ||
U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2) | ||
in_state = init_state(n_qubits) | ||
|
||
def laser(laser_params, t): | ||
(w_rabi, w_detune) = laser_params | ||
return { | ||
"rabi": 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t), | ||
"detune": 15.0 * jnp.cos(2 * jnp.pi * w_detune * t), | ||
} | ||
|
||
hamiltonian = DiagonalRydbergHamiltonian(n_qubits, U) | ||
hamiltonian_params = hamiltonian.init( | ||
key, | ||
in_state, | ||
laser(laser_params, 0), | ||
) | ||
|
||
# Timing | ||
start = time() | ||
_ = simulate( | ||
hamiltonian, | ||
hamiltonian_params, | ||
laser, | ||
laser_params, | ||
N, | ||
dt, | ||
in_state, | ||
).block_until_ready() | ||
stop = time() | ||
|
||
print(f"Simulation time for {n_qubits} qubits and {N} steps: {stop - start}s") | ||
print( | ||
"Note that for clarity we didn't jit the final function, so compilation time is included." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from __future__ import annotations | ||
|
||
from .simulate import simulate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from __future__ import annotations | ||
|
||
from functools import reduce | ||
from itertools import chain | ||
|
||
import jax.numpy as jnp | ||
from chex import Array | ||
|
||
from .utils import diagonal_kronecker, kron_AI, kron_IA | ||
|
||
|
||
def diagonal_onebody_hamiltonian(Hi: Array, weights: Array, idx: list[int]) -> Array: | ||
# Generates diagonal of diagonal onebody hamiltonian terms. | ||
# Not pretty but it works... | ||
def diagonal_Hi(diagonal: Array, idx: int) -> Array: | ||
return kron_IA(kron_AI(diagonal, 2 ** (n_qubits - idx - 1)), 2**idx) | ||
|
||
n_qubits = max(idx) + 1 # +1 cause of index | ||
Hi_diag = jnp.diag(Hi) | ||
return reduce( | ||
lambda state, x: state + x[0] * diagonal_Hi(Hi_diag, x[1]), | ||
zip(weights, idx), | ||
jnp.zeros(2**n_qubits), | ||
) | ||
|
||
|
||
def diagonal_twobody_hamiltonian( | ||
HiHj: tuple[Array, Array], weights: Array, idx: list[tuple[int, int]] | ||
) -> Array: | ||
# Generates diagonal of diagonal two-body hamiltonian terms. | ||
# Not pretty but it works... | ||
def diagonal_Hi(diagonal: list[Array], idx_ij: tuple[int, int]) -> Array: | ||
idx_i, idx_j = idx_ij | ||
left = kron_IA(diagonal[0], 2 ** (idx_i)) | ||
right = kron_IA(kron_AI(diagonal[1], 2 ** (n_qubits - idx_j - 1)), 2 ** (idx_j - idx_i - 1)) | ||
return diagonal_kronecker(left, right) | ||
|
||
n_qubits = max(list(chain(*idx))) + 1 # +1 cause of index | ||
HiHj_diag = [jnp.diag(H) for H in HiHj] | ||
return reduce( | ||
lambda state, x: state + x[0] * diagonal_Hi(HiHj_diag, x[1]), | ||
zip(weights, idx), | ||
jnp.zeros(2**n_qubits), | ||
) |
Oops, something went wrong.