JP-3755: Remove unused options and add unit tests to SOSS extraction algorithm #9000

Draft
wants to merge 59 commits into
base: main
a1a6ca3
added unit tests for wv_map_bounds and arange_2d
emolter Oct 25, 2024
1842f84
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Oct 25, 2024
fbd4521
simplify oversample_grid and write tests
emolter Oct 25, 2024
7cb20b1
fixed edge case where extrapolate_grid could run forever, added unit …
emolter Oct 25, 2024
8a4759b
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Oct 28, 2024
a69282a
replace legacy scipy interp1d with make_interp_spline
emolter Oct 28, 2024
2dc62d3
add unit tests to get_soss_grid and helpers
emolter Oct 28, 2024
4197b38
remove unused soss utilities
emolter Oct 28, 2024
5f70c9c
combine BaseOverlap class with ExtractionEngine
emolter Oct 29, 2024
33712be
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Oct 29, 2024
c4de667
make data, err required for ExtractionEngine calls, remove them as en…
emolter Oct 30, 2024
8e2766c
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Oct 31, 2024
5f3610f
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Oct 31, 2024
324b3ff
made functions private, removed some unused optionals
emolter Nov 1, 2024
65e3b0a
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Nov 5, 2024
333d817
added fixes and unit tests to WebbKernel
emolter Nov 6, 2024
6be6310
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Nov 6, 2024
d736f8e
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Nov 8, 2024
20ea48e
fix some failing unit tests
emolter Nov 11, 2024
a85d542
remove unused options and simplify call structure of Tikhonov and Tik…
emolter Nov 13, 2024
a6e448b
simplify indexing in get_w
emolter Nov 21, 2024
93c3eab
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Nov 21, 2024
0a54415
some style fixes
emolter Nov 21, 2024
0894291
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Nov 21, 2024
0e976b8
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Nov 21, 2024
402f34d
tracking down bug
emolter Nov 22, 2024
2e20b6f
bugfix for kernels
emolter Nov 22, 2024
53e6349
trying to fix results is nan bug
emolter Nov 22, 2024
056424b
still trying to bugfix mask
emolter Nov 22, 2024
da9c2e2
add note
emolter Nov 22, 2024
6968a3e
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Nov 25, 2024
b80b6fb
adding tests of engine
emolter Nov 26, 2024
7d99e53
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Nov 26, 2024
7742e62
add more tests to extraction engine
emolter Nov 26, 2024
824a916
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Nov 26, 2024
b0e587c
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Nov 27, 2024
53607a6
more unit tests for extraction engine methods
emolter Nov 27, 2024
bed22ac
more unit tests of extraction engine
emolter Nov 29, 2024
cddde89
toy model for tests is round-tripping successfully now
emolter Dec 2, 2024
b13909a
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Dec 3, 2024
d98cf94
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Dec 3, 2024
85e7fb9
move utils tests into atoca main
emolter Dec 4, 2024
f42a74e
split fixtures into conftest, add more tests of kernels
emolter Dec 4, 2024
1797037
ruff check according to stcal rules
emolter Dec 4, 2024
432e38f
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Dec 4, 2024
4d6a71c
starting tests for box extract
emolter Dec 5, 2024
81ffd52
added unit tests for boxextract functions
emolter Dec 6, 2024
1aacac1
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Dec 6, 2024
4fceb78
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Dec 9, 2024
ba52f91
added test coverage for pastasoss helper functions
emolter Dec 10, 2024
22e111d
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Dec 10, 2024
1359e05
added unit tests for soss_syscor functions
emolter Dec 10, 2024
3ae226d
added test for estim_flux_first_order
emolter Dec 11, 2024
9bb45ab
Added unit tests and supporting fixtures for model_image testing
emolter Dec 12, 2024
8c2b093
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Dec 12, 2024
43a77d4
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Dec 16, 2024
74b5719
small start to docs
emolter Dec 16, 2024
b661cb2
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
emolter Dec 17, 2024
e0c8883
fixed a few typos during self review
emolter Dec 17, 2024
combine BaseOverlap class with ExtractionEngine
emolter committed Oct 29, 2024
commit 5f70c9c33fc00a1e4146490a18c91572db6ccf11
201 changes: 57 additions & 144 deletions jwst/extract_1d/soss_extract/atoca.py
@@ -15,7 +15,6 @@
import warnings
from scipy.sparse import issparse, csr_matrix, diags
from scipy.sparse.linalg import spsolve, lsqr, MatrixRankWarning
from scipy.interpolate import interp1d

# Local imports.
from . import atoca_utils
@@ -33,8 +32,12 @@ def __init__(self, message):
super().__init__(self.message)


class _BaseOverlap:
"""Base class for the ATOCA algorithm (Darveau-Bernier 2021, in prep).
Comment on lines -36 to -37
Collaborator Author


This base class was only ever used as the parent of ExtractionEngine, and it wasn't clear why the distinction was helpful. The code seems much cleaner as a single ExtractionEngine class.
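
A minimal sketch of the refactor described in this comment, using toy classes rather than the real ATOCA code. The names `_BaseSketch`, `EngineBefore`, `EngineAfter`, and `toy_wave_map` are hypothetical and exist only for illustration. The "before" layout mirrors the pattern visible in this diff: a base class with a dummy `get_w`, and a subclass that sets some state before calling `super().__init__`.

```python
class _BaseSketch:
    """Hypothetical single-use base class (stand-in for the removed base)."""

    def __init__(self, wave_map):
        self.wave_map = [list(w) for w in wave_map]

    def get_w(self, i_order):
        # Dummy method that exists only so the base class can be instantiated.
        return [], []

    def compute_weights(self):
        return [self.get_w(i) for i in range(len(self.wave_map))]


class EngineBefore(_BaseSketch):
    """Hypothetical subclass: the only class that ever inherited from the base."""

    def __init__(self, wave_map):
        # State the subclass sets up before delegating to the parent.
        self.wave_edges = [(min(w), max(w)) for w in wave_map]
        super().__init__(wave_map)

    def get_w(self, i_order):
        n = len(self.wave_map[i_order])
        return [1.0 / n] * n, list(range(n))


class EngineAfter:
    """Same behavior in one class: no dummy method, no super() call."""

    def __init__(self, wave_map):
        self.wave_edges = [(min(w), max(w)) for w in wave_map]
        self.wave_map = [list(w) for w in wave_map]

    def get_w(self, i_order):
        n = len(self.wave_map[i_order])
        return [1.0 / n] * n, list(range(n))

    def compute_weights(self):
        return [self.get_w(i) for i in range(len(self.wave_map))]


toy_wave_map = [[1.0, 1.1, 1.2], [0.6, 0.7]]
before = EngineBefore(toy_wave_map)
after = EngineAfter(toy_wave_map)
```

In this toy version the merged class needs neither the placeholder `get_w` nor the ordering constraint on `super().__init__`, which is the kind of simplification the comment points to.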

class ExtractionEngine():
"""
Run the ATOCA algorithm (Darveau-Bernier 2021, in prep).

TODO: fix the below, which came from _BaseOverlap class
Base class for the ATOCA algorithm (Darveau-Bernier 2021, in prep).
Used to perform an overlapping extraction of the form:
(B_T * B) * f = (data/sig)_T * B
where B is a matrix and f is an array.
@@ -47,8 +50,19 @@ class _BaseOverlap:
The classes inheriting from this class should specify the
methods get_w which computes the 'k' associated to each pixel 'i'.
These depends of the type of interpolation used.
"""

This version models the pixels of the detector using an oversampled trapezoidal integration.
TODO: Merge with BaseOverlap class for readability
TODO: the following arguments can be simplified
mask - never used, always superseded by mask_trace_profile
wave_grid - always passed in explicitly; no need to have a default
wave_bounds - never used, default is always computed from wave_map
n_os - apparently never used because wave_grid is always explicit


TODO: I don't understand why data is not a required argument. If the data are only
needed when ExtractionEngine.__call__ happens, then the data should be input there somehow
"""
# The desired data-type for computations, e.g., 'float32'. 'float64' is recommended.
dtype = 'float64'

@@ -59,18 +73,19 @@ def __init__(self, wave_map, trace_profile, throughput, kernels,
"""
Parameters
----------
wave_map : (N_ord, N, M) list or array of 2-D arrays
A list or array of the central wavelength position for each
order on the detector. It must have the same (N, M) as `data`.
trace_profile : (N_ord, N, M) list or array of 2-D arrays
A list or array of the spatial profile for each order
on the detector. It must have the same (N, M) as `data`.
on the detector. It has to have the same (N, M) as `data`.
wave_map : (N_ord, N, M) list or array of 2-D arrays
A list or array of the central wavelength position for each
order on the detector.
It has to have the same (N, M) as `data`.
throughput : (N_ord [, N_k]) list of array or callable
A list of functions or array of the throughput at each order.
If callable, the functions depend on the wavelength.
If array, projected on `wave_grid`.
kernels : array, callable or sparse matrix
Convolution kernel to be applied on the spectrum (f_k) for each order.
Convolution kernel to be applied on spectrum (f_k) for each orders.
Can be array of the shape (N_ker, N_k_c).
Can be a callable with the form f(x, x0) where x0 is
the position of the center of the kernel. In this case, it must
@@ -81,21 +96,24 @@ def __init__(self, wave_map, trace_profile, throughput, kernels,
If sparse, the shape has to be (N_k_c, N_k) and it will
be used directly. N_ker is the length of the effective kernel
and N_k_c is the length of the spectrum (f_k) convolved.
global_mask : (N, M) array_like boolean, optional
data : (N, M) array_like, optional
A 2-D array of real values representing the detector image.
error : (N, M) array_like, optional
Estimate of the error on each pixel. Default is one everywhere.
mask : (N, M) array_like boolean, optional
Boolean Mask of the detector pixels to mask for every extraction.
Should not be related to a specific order (if so, use `mask_trace_profile` instead).
mask_trace_profile : (N_ord, N, M) list or array of 2-D arrays[bool], optional
A list or array of the pixels that need to be used for extraction,
A list or array of the pixel that need to be used for extraction,
for each order on the detector. It has to have the same (N_ord, N, M) as `trace_profile`.
If not given, `threshold` will be applied on spatial profiles to define the masks.
orders : list, optional:
orders : list, optional
List of orders considered. Default is orders = [1, 2]
wave_grid : (N_k) array_like, optional
The grid on which f(lambda) will be projected.
Default is a grid from `utils.get_soss_grid`.
`n_os` will be passed to this function.
Default still has to be improved.
wave_bounds : list or array-like (N_ord, 2), optional
Boundary wavelengths covered by each order.
Boundary wavelengths covered by each orders.
Default is the wavelength covered by `wave_map`.
n_os : int, optional
Oversampling rate. If `wave_grid`is None, it will be used to
@@ -111,6 +129,17 @@ def __init__(self, wave_map, trace_profile, throughput, kernels,
If dictionary, the same c_kwargs will be used for each order.
"""

# Get wavelength at the boundary of each pixel
wave_p, wave_m = [], []
for wave in wave_map: # For each order
lp, lm = atoca_utils.get_wave_p_or_m(wave) # Lambda plus or minus
# Make sure it is the good precision
wave_p.append(lp.astype(self.dtype))
wave_m.append(lm.astype(self.dtype))

# Save values
self.wave_p, self.wave_m = wave_p, wave_m

# If no orders specified extract on orders 1 and 2.
if orders is None:
orders = [1, 2]
@@ -227,7 +256,6 @@ def __init__(self, wave_map, trace_profile, throughput, kernels,
self.tikho_mat = None
self.w_t_wave_c = None

return

def get_attributes(self, *args, i_order=None):
"""Return list of attributes
@@ -258,6 +286,7 @@ def get_attributes(self, *args, i_order=None):

def update_wave_map(self, wave_map):
"""Update internal wave_map
TODO: can this be removed?
Parameters
----------
wave_map : array[float]
@@ -270,10 +299,10 @@ def update_wave_map(self, wave_map):
dtype = self.dtype
self.wave_map = [wave_n.astype(dtype).copy() for wave_n in wave_map]

return

def update_trace_profile(self, trace_profile):
"""Update internal trace_profiles
TODO: can this be removed?
Parameters
----------
trace_profile : array[float]
@@ -288,10 +317,10 @@ def update_trace_profile(self, trace_profile):
# Update the trace_profile profile.
self.trace_profile = [trace_profile_n.astype(dtype).copy() for trace_profile_n in trace_profile]

return

def update_throughput(self, throughput):
"""Update internal throughput values
TODO: can this be removed?
Parameters
----------
throughput : array[float] or callable
@@ -325,10 +354,10 @@ def update_throughput(self, throughput):
# Set the attribute to the new values.
self.throughput = throughput_new

return

def update_kernels(self, kernels, c_kwargs):
"""Update internal kernels
TODO: can this be removed?
Parameters
----------
kernels : array, callable or sparse matrix
@@ -377,30 +406,6 @@ def update_kernels(self, kernels, c_kwargs):
self.kernels = kernels_new


def get_mask_wave(self, i_order):
Collaborator Author


This method was duplicated after removing the base class; I kept the other instance below.
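
For context on the duplicate noted in the comment above: the instance removed here tests pixel-center wavelengths, while the retained instance further down tests the pixel-edge wavelengths (`wave_m`, `wave_p`). The following is a minimal sketch of how the two conditions can differ. It assumes `wave_m` and `wave_p` are the lower and upper wavelength edges of each pixel, and it uses made-up values and a made-up constant half-width; the function names are hypothetical.

```python
import numpy as np


def mask_from_centers(wave, wave_min, wave_max):
    """Condition used by the removed instance: compare pixel centers."""
    return (wave <= wave_min) | (wave >= wave_max)


def mask_from_edges(wave_m, wave_p, wave_min, wave_max):
    """Condition used by the retained instance: compare pixel edges."""
    return (wave_m < wave_min) | (wave_p > wave_max)


# Toy pixel centers and an assumed constant half-width per pixel.
wave = np.array([1.0, 1.1, 1.2, 1.3, 1.4])
half_width = 0.05
wave_m, wave_p = wave - half_width, wave + half_width

# Grid limits chosen so that they fall inside a pixel rather than on a center.
wave_min, wave_max = 1.08, 1.32

center_mask = mask_from_centers(wave, wave_min, wave_max)
edge_mask = mask_from_edges(wave_m, wave_p, wave_min, wave_max)
```

With these values, the edge-based mask also flags the two pixels that only partly overlap the grid limits, so it is the stricter of the two conditions in this example.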

"""Generate mask bounded by limits of wavelength grid
Parameters
----------
i_order : int
Order to select the wave_map on which a mask
will be generated

Returns
-------
array[bool]
A mask with True where wave_map is outside the bounds
of wave_grid
"""

wave = self.wave_map[i_order]
imin, imax = self.i_bounds[i_order]
wave_min = self.wave_grid[imin]
wave_max = self.wave_grid[imax - 1]

mask = (wave <= wave_min) | (wave >= wave_max)

return mask

def _get_masks(self, global_mask):
"""Compute a general mask on the detector and for each order.
Depends on the spatial profile, the wavelength grid
@@ -486,7 +491,6 @@ def update_mask(self, mask):
# Re-compute weights
self.weights, self.weights_k_idx = self.compute_weights()

return

def _get_i_bnds(self, wave_bounds=None):
"""Define wavelength boundaries for each order using the order's mask.
@@ -550,7 +554,6 @@ def update_i_bnds(self):
# Update attribute.
self.i_bounds = i_bnds_new

return

def wave_grid_c(self, i_order):
"""Return wave_grid for the convolved flux at a given order.
@@ -560,12 +563,6 @@ def wave_grid_c(self, i_order):

return self.wave_grid[index]

def get_w(self, i_order):
"""Dummy method to init this class
TODO: so is this an abstract base class? is it ever actually subclassed more than once?
"""

return np.array([]), np.array([])

def compute_weights(self):
"""
@@ -583,9 +580,10 @@ def compute_weights(self):

# Init lists
weights, weights_k_idx = [], []
for i_order in range(self.n_orders): # For each orders
for i_order in range(self.n_orders):

weights_n, k_idx_n = self.get_w(i_order) # Compute weights
# Compute weights
weights_n, k_idx_n = self.get_w(i_order)

# Convert to sparse matrix
# First get the dimension of the convolved grid
@@ -608,7 +606,6 @@ def _set_w_t_wave_c(self, i_order, product):
# Assign value
self.w_t_wave_c[i_order] = product.copy()

return

def grid_from_map(self, i_order=0):
"""Return the wavelength grid and the columns associated
@@ -962,7 +959,6 @@ def set_tikho_matrix(self, t_mat=None, t_mat_func=None, fargs=None, fkwargs=None
# Set attribute
self.tikho_mat = t_mat

return

def get_tikho_matrix(self, **kwargs):
"""
@@ -1276,9 +1272,8 @@ def compute_likelihood(self, spectrum=None, same=False):
# Compute the log-likelihood for the spectrum.
with np.errstate(divide='ignore'):
logl = (model - data) / error
logl = -np.nansum((logl[~mask])**2)
return -np.nansum((logl[~mask])**2)

return logl
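
The quantity returned in the hunk above is the negative sum of squared, error-normalized residuals over unmasked pixels, with NaNs ignored. The following is a minimal standalone sketch of that calculation. The function name `neg_chi2_sketch` and the toy arrays are hypothetical and are not part of the jwst package.

```python
import numpy as np


def neg_chi2_sketch(model, data, error, mask):
    """Return -sum(((model - data) / error)**2) over unmasked, non-NaN pixels."""
    with np.errstate(divide='ignore'):
        resid = (model - data) / error
    # mask is True where a pixel must be excluded; NaN residuals are skipped.
    return -np.nansum(resid[~mask] ** 2)


model = np.array([1.0, 2.0, 3.0, 4.0])
data = np.array([1.0, 1.0, np.nan, 10.0])
error = np.array([1.0, 0.5, 1.0, 1.0])
mask = np.array([False, False, False, True])

logl = neg_chi2_sketch(model, data, error, mask)
```

In this example only the second pixel contributes: the first has zero residual, the third is NaN and is skipped by `np.nansum`, and the fourth is masked. Note that `np.nansum` skips NaN but not infinity, so a zero `error` combined with a nonzero residual would still drive the result to negative infinity.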

@staticmethod
def _solve(matrix, result):
@@ -1387,87 +1382,6 @@ def __call__(self, tikhonov=False, tikho_kwargs=None, factor=None, **kwargs):
return spectrum


class ExtractionEngine(_BaseOverlap):
"""
Run the ATOCA algorithm (Darveau-Bernier 2021, in prep).

This version models the pixels of the detector using an oversampled trapezoidal integration.
"""

def __init__(self, wave_map, trace_profile, *args, **kwargs):
"""
Parameters
----------
trace_profile : (N_ord, N, M) list or array of 2-D arrays
A list or array of the spatial profile for each order
on the detector. It has to have the same (N, M) as `data`.
wave_map : (N_ord, N, M) list or array of 2-D arrays
A list or array of the central wavelength position for each
order on the detector.
It has to have the same (N, M) as `data`.
throughput : (N_ord [, N_k]) list of array or callable
A list of functions or array of the throughput at each order.
If callable, the functions depend on the wavelength.
If array, projected on `wave_grid`.
kernels : array, callable or sparse matrix
Convolution kernel to be applied on spectrum (f_k) for each orders.
Can be array of the shape (N_ker, N_k_c).
Can be a callable with the form f(x, x0) where x0 is
the position of the center of the kernel. In this case, it must
return a 1D array (len(x)), so a kernel value
for each pairs of (x, x0). If array or callable,
it will be passed to `convolution.get_c_matrix` function
and the `c_kwargs` can be passed to this function.
If sparse, the shape has to be (N_k_c, N_k) and it will
be used directly. N_ker is the length of the effective kernel
and N_k_c is the length of the spectrum (f_k) convolved.
data : (N, M) array_like, optional
A 2-D array of real values representing the detector image.
error : (N, M) array_like, optional
Estimate of the error on each pixel. Default is one everywhere.
mask : (N, M) array_like boolean, optional
Boolean Mask of the detector pixels to mask for every extraction.
Should not be related to a specific order (if so, use `mask_trace_profile` instead).
mask_trace_profile : (N_ord, N, M) list or array of 2-D arrays[bool], optional
A list or array of the pixel that need to be used for extraction,
for each order on the detector. It has to have the same (N_ord, N, M) as `trace_profile`.
If not given, `threshold` will be applied on spatial profiles to define the masks.
orders : list, optional
List of orders considered. Default is orders = [1, 2]
wave_grid : (N_k) array_like, optional
The grid on which f(lambda) will be projected.
Default still has to be improved.
wave_bounds : list or array-like (N_ord, 2), optional
Boundary wavelengths covered by each orders.
Default is the wavelength covered by `wave_map`.
n_os : int, optional
Oversampling rate. If `wave_grid`is None, it will be used to
generate a grid. Default is 2.
threshold : float, optional:
The contribution of any order on a pixel is considered significant if
its estimated spatial profile is greater than this threshold value.
If it is not properly modeled (not covered by the wavelength grid),
it will be masked. Default is 1e-3.
c_kwargs : list of N_ord dictionaries or dictionary, optional
Inputs keywords arguments to pass to
`convolution.get_c_matrix` function for each order.
If dictionary, the same c_kwargs will be used for each order.
"""

# Get wavelength at the boundary of each pixel
wave_p, wave_m = [], []
for wave in wave_map: # For each order
lp, lm = atoca_utils.get_wave_p_or_m(wave) # Lambda plus or minus
# Make sure it is the good precision
wave_p.append(lp.astype(self.dtype))
wave_m.append(lm.astype(self.dtype))

# Save values
self.wave_p, self.wave_m = wave_p, wave_m

# Init upper class
super().__init__(wave_map, trace_profile, *args, **kwargs)

def _get_lo_hi(self, grid, i_order):
"""
Find the lowest (lo) and highest (hi) index
@@ -1509,7 +1423,8 @@ def _get_lo_hi(self, grid, i_order):
ma = mask_ord[~mask]
lo[ma], hi[ma] = -1, -2

return lo, hi
return lo, hi


def get_mask_wave(self, i_order):
"""Generate mask bounded by limits of wavelength grid
@@ -1531,15 +1446,13 @@ def get_mask_wave(self, i_order):
wave_min = self.wave_grid[i_bnds[0]]
wave_max = self.wave_grid[i_bnds[1] - 1]

mask = (wave_m < wave_min) | (wave_p > wave_max)
return (wave_m < wave_min) | (wave_p > wave_max)

return mask

def get_w(self, i_order):
"""Compute integration weights for each grid points and each pixels.
Depends on the order `n`.
TODO: is this the same order as the spectral order? if so, can we ignore
a bunch of the cases in this function and just keep orders 1 and 2?
TODO: what is this doing? where can we find the math?

Parameters
----------
@@ -1616,7 +1529,7 @@ def get_w(self, i_order):

# Generate array of all k_i. Set to max value of uint16 if not valid
k_n = atoca_utils.arange_2d(k_first, k_last + 1)
bad = k_n == np.iinfo(k_n.dtype).max
bad = k_n == -1

# Number of valid k per pixel
n_k = np.sum(~bad, axis=-1)
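
The last hunk switches the invalid-index check from the dtype's maximum value to a `-1` sentinel. The following is a minimal sketch of how a `-1` fill value can mark padding in a ragged 2-D index array. `arange_2d_sketch` is a hypothetical stand-in written for this example; it is not the implementation of `atoca_utils.arange_2d`, whose behavior is not shown in this diff.

```python
import numpy as np


def arange_2d_sketch(starts, stops, fill=-1):
    """Row i holds arange(starts[i], stops[i]), right-padded with `fill`."""
    starts = np.asarray(starts)
    stops = np.asarray(stops)
    lengths = stops - starts
    width = max(int(lengths.max()), 0) if lengths.size else 0
    offsets = np.arange(width)
    values = starts[:, None] + offsets[None, :]
    valid = offsets[None, :] < lengths[:, None]
    # A signed integer dtype is required for the -1 sentinel to be representable.
    return np.where(valid, values, fill)


k_first = np.array([0, 3, 5])
k_last = np.array([2, 3, 4])  # the last row is empty (k_last < k_first)

k_n = arange_2d_sketch(k_first, k_last + 1)
bad = k_n == -1

# Number of valid k per pixel
n_k = np.sum(~bad, axis=-1)
```

A `-1` sentinel is only unambiguous when every valid index is non-negative, which holds for array indices; it also avoids tying the check to whichever integer dtype the array happens to have.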