1 change: 1 addition & 0 deletions xfields/__init__.py
@@ -25,6 +25,7 @@
from .beam_elements.temp_slicer import TempSlicer
from .beam_elements.electroncloud import ElectronCloud
from .beam_elements.electronlens_interpolated import ElectronLensInterpolated
from .beam_elements.waketracker import WakeTracker

from .general import _pkg_root
from .config_tools import replace_spacecharge_with_quasi_frozen
10 changes: 10 additions & 0 deletions xfields/beam_elements/element_with_slicer.py
@@ -77,6 +77,7 @@ def __init__(self,
num_slices=num_slices, # Per bunch, this is N_1 in the paper
bunch_spacing_zeta=bunch_spacing_zeta, # This is P in the paper
filling_scheme=filling_scheme,
bunch_selection=bunch_selection,
num_turns=num_turns,
circumference=circumference)

@@ -104,19 +105,28 @@ def _initialize_moments(
num_slices=None, # Per bunch, this is N_1 in the paper
bunch_spacing_zeta=None, # This is P in the paper
filling_scheme=None,
bunch_selection=None,
num_turns=1,
circumference=None):


if filling_scheme is not None:
i_last_bunch = np.where(filling_scheme)[0][-1]
num_periods = i_last_bunch + 1
else:
num_periods = 1

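# Number of wake targets: all periods when no bunch selection is given,
# otherwise the inclusive span of the selected bunch indices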
if bunch_selection is None:
num_targets = num_periods
else:
num_targets = 1 + np.max(bunch_selection) - np.min(bunch_selection)

self.moments_data = CompressedProfile(
moments=self.source_moments + ['result'],
zeta_range=zeta_range,
num_slices=num_slices,
bunch_spacing_zeta=bunch_spacing_zeta,
num_targets=num_targets,
num_periods=num_periods,
num_turns=num_turns,
circumference=circumference,
19 changes: 13 additions & 6 deletions xfields/beam_elements/waketracker/convolution.py
@@ -1,4 +1,6 @@
import time
import numpy as np
from matplotlib import pyplot as plt
from scipy.constants import e as qe

import xobjects as xo
@@ -64,13 +66,17 @@ def __init__(self, component, waketracker=None, _flatten=False, log_moments=None

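# On CPU, scipy's FFT (splike_lib) can run multi-threaded via the 'workers'
# argument when the context uses more than one OpenMP thread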
def my_rfft(self, data, **kwargs):
if type(self._context) in (xo.ContextCpu, xo.ContextCupy):
return self._context.nplike_lib.fft.rfft(data, **kwargs)
if hasattr(self._context,'omp_num_threads') and self._context.omp_num_threads > 1:
kwargs['workers'] = self._context.omp_num_threads
return self._context.splike_lib.fft.rfft(data, **kwargs)
else:
raise NotImplementedError('WakeTracker implemented only for CPU and Cupy')

def my_irfft(self, data, **kwargs):
if type(self._context) in (xo.ContextCpu, xo.ContextCupy):
return self._context.nplike_lib.fft.irfft(data, **kwargs)
if hasattr(self._context,'omp_num_threads') and self._context.omp_num_threads > 1:
kwargs['workers'] = self._context.omp_num_threads
return self._context.splike_lib.fft.irfft(data, **kwargs)
else:
raise NotImplementedError('WakeTracker implemented only for CPU and Cupy')

@@ -81,13 +87,13 @@ def _initialize_conv_data(self, _flatten=False, moments_data=None, beta0=None):
self._M_aux = moments_data._M_aux
self._N_1 = moments_data._N_1
self._N_S = moments_data._N_S
self._N_T = moments_data._N_S
self._N_T = moments_data._N_T
self._BB = 1 # B in the paper
# (for now we assume that B=0 is the first bunch in time and the
# last one in zeta)
self._AA = self._BB - self._N_S
self._CC = self._AA
self._DD = self._BB
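# With a bunch selection, D is derived from the lowest selected bunch index
# and C spans the N_T target bunches below it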
self._DD = -1*np.min(self.waketracker.slicer.bunch_selection)+1
self._CC = self._DD - self._N_T

# Build wake matrix
self.z_wake = _build_z_wake(moments_data._z_a, moments_data._z_b,
@@ -97,6 +103,7 @@ def _initialize_conv_data(self, _flatten=False, moments_data=None, beta0=None):
moments_data.dz, self._AA,
self._BB, self._CC, self._DD,
moments_data._z_P)

assert beta0 is not None
# here below I had to add float() to beta0 because when using Cupy
# context particles.beta0[0] turns out to be a 0d array. To be checked
@@ -246,6 +253,6 @@ def _build_z_wake(z_a, z_b, num_turns, n_aux, m_aux, circumference, dz,
z_p = 0

for ii, ll in enumerate(range(
cc - bb + 1, dd - aa)):
int(cc - bb + 1), int(dd - aa))):
z_wake[tt, ii * n_aux:(ii + 1) * n_aux] = temp_z + ll * z_p
return z_wake
52 changes: 44 additions & 8 deletions xfields/beam_elements/waketracker/waketracker.py
@@ -1,5 +1,4 @@
from typing import Tuple

import numpy as np

from scipy.constants import c as clight
@@ -51,6 +50,9 @@ def __init__(self, components,
filling_scheme=None,
bunch_selection=None,
num_turns=1,
fake_coupled_bunch_phase_x=None,
fake_coupled_bunch_phase_y=None,
beta_x=None, beta_y=None,
circumference=None,
log_moments=None,
_flatten=False,
@@ -61,11 +63,30 @@
self.components = components
self.pipeline_manager = None

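# Optional "fake" coupled-bunch mode: a single reference bunch is tracked and
# the wake sources of all other filled slots are generated from it with a
# fixed phase advance per slot (see _compute_fake_bunch_moments below)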
self.fake_coupled_bunch_phases = {}
self.betas = {}
if fake_coupled_bunch_phase_x is not None:
self.fake_coupled_bunch_phases['x'] = fake_coupled_bunch_phase_x
assert beta_x is not None and beta_x > 0
self.betas['x'] = beta_x
if fake_coupled_bunch_phase_y is not None:
self.fake_coupled_bunch_phases['y'] = fake_coupled_bunch_phase_y
assert beta_y is not None and beta_y > 0
self.betas['y'] = beta_y
if self.fake_coupled_bunch_phases:
assert bunch_selection is not None and filling_scheme is not None
assert len(bunch_selection) == 1, "When faking a coupled bunch mode, only one bunch should be selected as ref."

all_slicer_moments = []
for cc in self.components:
assert not hasattr(cc, 'moments_data') or cc.moments_data is None
all_slicer_moments += cc.source_moments

if self.fake_coupled_bunch_phases:
for moment_name in self.fake_coupled_bunch_phases.keys():
if moment_name in all_slicer_moments:
all_slicer_moments.append('p'+moment_name)

self.all_slicer_moments = list(set(all_slicer_moments))

super().__init__(
@@ -79,18 +100,18 @@ def __init__(self, components,
num_turns=num_turns,
circumference=circumference,
with_compressed_profile=True,
_context=self.context)
_context=self._context)

self._initialize_moments(
zeta_range=zeta_range, # These are [a, b] in the paper
num_slices=num_slices, # Per bunch, this is N_1 in the paper
bunch_spacing_zeta=bunch_spacing_zeta, # This is P in the paper
filling_scheme=filling_scheme,
bunch_selection=bunch_selection,
num_turns=num_turns,
circumference=circumference)

self._flatten = _flatten
all_slicer_moments = list(set(all_slicer_moments))

def init_pipeline(self, pipeline_manager, element_name, partner_names):

@@ -99,12 +120,11 @@ def init_pipeline(self, pipeline_manager, element_name, partner_names):
partner_names=partner_names)

def track(self, particles):

# Find first active particle to get beta0
if particles.state[0] > 0:
beta0 = particles.beta0[0]
else:
i_alive = np.where(particles.state > 0)[0]
i_alive = self._context.nplike_lib.where(particles.state > 0)[0]
if len(i_alive) == 0:
return
i_first = i_alive[0]
@@ -122,19 +142,35 @@ def track(self, particles):
cc._conv_data._initialize_conv_data(_flatten=self._flatten,
moments_data=self.moments_data,
beta0=beta0)

# Use common slicer from parent class to measure all moments
status = super().track(particles)

if status and status.on_hold == True:
return status

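# Overwrite the measured multi-bunch moments with the synthetic coupled-bunch
# pattern before the wake convolution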
if self.fake_coupled_bunch_phases:
self._compute_fake_bunch_moments()

for wf in self.components:
wf._conv_data.track(particles,
i_slot_particles=self.i_slot_particles,
i_slice_particles=self.i_slice_particles,
moments_data=self.moments_data)

def _compute_fake_bunch_moments(self):
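# Build the moments of all filled slots from the single selected reference
# bunch: the transverse moment of each slot is the reference one rotated by
# the requested phase advance times the slot offset, while the longitudinal
# profile (num_particles) is copied unchanged.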
conjugate_names = {'x': 'px', 'y': 'py'}
nplike = self._context.nplike_lib
n_slots = int(nplike.max(self.slicer.filled_slots)) + 1
for moment_name in self.fake_coupled_bunch_phases.keys():
# Normalised complex amplitude of the reference bunch: mom + i * beta * mom_conj
z_dummy, mom = self.moments_data.get_source_moment_profile(moment_name, 0, self.bunch_selection[0])
z_dummy, mom_conj = self.moments_data.get_source_moment_profile(conjugate_names[moment_name], 0, self.bunch_selection[0])
complex_normalised_moments = mom + (1j * self.betas[moment_name]) * mom_conj
# Replicate the reference profile over all slots and rotate it by the phase advance times the slot offset
slots = nplike.transpose(nplike.tile(self.slicer.filled_slots, (len(complex_normalised_moments), 1)))
complex_normalised_moments = nplike.tile(complex_normalised_moments, (n_slots, 1))
all_beam_moments = nplike.real(complex_normalised_moments * nplike.exp(1j * self.fake_coupled_bunch_phases[moment_name] * (self.bunch_selection[0] - slots)))
self.moments_data.set_all_beam_moments(moment_name, 0, all_beam_moments)
# The longitudinal profile of every fake bunch is a copy of the reference one
z_dummy, mom = self.moments_data.get_source_moment_profile('num_particles', 0, self.bunch_selection[0])
all_beam_num_particles = nplike.tile(mom, (n_slots, 1))
self.moments_data.set_all_beam_moments('num_particles', 0, all_beam_num_particles)

@property
def zeta_range(self):
return self.slicer.zeta_range
@@ -158,7 +194,7 @@ def num_turns(self):
@property
def circumference(self):
return self.moments_data.circumference

def __add__(self, other):

if other == 0:
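As an aside, the slot-by-slot phase rotation performed in _compute_fake_bunch_moments above can be reproduced with a small self-contained numpy sketch; all shapes and numerical values below are placeholders, not taken from this PR:

import numpy as np

# Reference-bunch slice moments (placeholder values)
beta_x = 70.0
x_ref = np.array([1e-6, 2e-6, 3e-6, 2e-6, 1e-6])
px_ref = np.array([0.0, 1e-8, 0.0, -1e-8, 0.0])

# Normalised complex amplitude, as in _compute_fake_bunch_moments
a_ref = x_ref + 1j * beta_x * px_ref

phase = 2 * np.pi * 0.25   # assumed coupled-bunch phase advance per slot
ref_slot = 0
slots = np.arange(12)      # filled slots (placeholder filling scheme)

# Rotate the reference amplitude by phase * (ref_slot - slot) for every slot
x_all = np.real(a_ref[None, :] * np.exp(1j * phase * (ref_slot - slots))[:, None])
print(x_all.shape)  # (12, 5)

Each row of x_all is the moment profile assigned to one slot; the reference slot keeps the unrotated profile.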
36 changes: 36 additions & 0 deletions xfields/slicers/compressed_profile.py
@@ -220,3 +220,39 @@ def get_moment_profile(self, moment_name, i_turn):
i_start_in_moments_data:i_end_in_moments_data])

return z_out, moment_out

def get_source_moment_profile(self, moment_name, i_turn, i_source):
"""
Get the moment profile of a single source bunch for a given turn.

Parameters
----------
moment_name : str
The name of the moment to get
i_turn : int
The turn index, 0 <= i_turn < self.num_turns
i_source : int
The index of the source bunch, 0 <= i_source < self._N_S

Returns
-------
z_out : np.ndarray
The z positions within the moment profile
moment_out : np.ndarray
The moment profile
"""

z_out = self._arr2ctx(np.zeros(self._N_1))
moment_out = self._arr2ctx(np.zeros(self._N_1))
i_moment = self.moments_names.index(moment_name)
_z_P = self._z_P or 0

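# Slice centres of source bunch i_source: the base grid shifted by
# -i_source * _z_P along zeta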
z_out = (
self._z_a + self.dz / 2
- i_source * _z_P + self.dz * self._arr2ctx(np.arange(self._N_1)))

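# Source i_source occupies the block starting at (_N_S - i_source - 1) * _N_aux
# in the stored moments data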
i_start_in_moments_data = (self._N_S - i_source - 1) * self._N_aux
i_end_in_moments_data = i_start_in_moments_data + self._N_1
moment_out = (
self.data[i_moment, i_turn,
i_start_in_moments_data:i_end_in_moments_data])

return z_out, moment_out