diff --git a/nitransforms/base.py b/nitransforms/base.py index 81ed1a5e..fa05f1f6 100644 --- a/nitransforms/base.py +++ b/nitransforms/base.py @@ -7,6 +7,7 @@ # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Common interface for transforms.""" + from pathlib import Path import numpy as np import h5py @@ -146,13 +147,13 @@ def from_arrays(cls, coordinates, triangles): darrays = [ nb.gifti.GiftiDataArray( coordinates.astype(np.float32), - intent=nb.nifti1.intent_codes['NIFTI_INTENT_POINTSET'], - datatype=nb.nifti1.data_type_codes['NIFTI_TYPE_FLOAT32'], + intent=nb.nifti1.intent_codes["NIFTI_INTENT_POINTSET"], + datatype=nb.nifti1.data_type_codes["NIFTI_TYPE_FLOAT32"], ), nb.gifti.GiftiDataArray( triangles.astype(np.int32), - intent=nb.nifti1.intent_codes['NIFTI_INTENT_TRIANGLE'], - datatype=nb.nifti1.data_type_codes['NIFTI_TYPE_INT32'], + intent=nb.nifti1.intent_codes["NIFTI_INTENT_TRIANGLE"], + datatype=nb.nifti1.data_type_codes["NIFTI_TYPE_INT32"], ), ] gii = nb.gifti.GiftiImage(darrays=darrays) @@ -279,6 +280,22 @@ def __add__(self, b): return TransformChain(transforms=[self, b]) + def __len__(self): + """ + Enable ``len()``. + + By default, all transforms are of length one. + This must be overriden by transforms arrays and chains. + + Example + ------- + >>> T1 = TransformBase() + >>> len(T1) + 1 + + """ + return 1 + @property def reference(self): """Access a reference space where data will be resampled onto.""" @@ -335,10 +352,8 @@ def apply(self, *args, **kwargs): Deprecated. Please use ``nitransforms.resampling.apply`` instead. """ - message = ( - "The `apply` method is deprecated. Please use `nitransforms.resampling.apply` instead." - ) - warnings.warn(message, DeprecationWarning, stacklevel=2) + _msg = "This method is deprecated. Please use `nitransforms.resampling.apply` instead." + warnings.warn(_msg, DeprecationWarning, stacklevel=2) from .resampling import apply return apply(self, *args, **kwargs) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index 9de0d2d6..d7c7f9c5 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -7,49 +7,177 @@ # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Resampling utilities.""" + +import asyncio +from os import cpu_count +from functools import partial from pathlib import Path +from typing import Callable, TypeVar + import numpy as np from nibabel.loadsave import load as _nbload +from nibabel.arrayproxy import get_obj_dtype +from nibabel.spatialimages import SpatialImage from scipy import ndimage as ndi from nitransforms.base import ( ImageGrid, + TransformBase, TransformError, SpatialReference, _as_homogeneous, ) +R = TypeVar("R") -def apply( - transform, - spatialimage, - reference=None, - order=3, - mode="constant", - cval=0.0, - prefilter=True, - output_dtype=None, +SERIALIZE_VOLUME_WINDOW_WIDTH: int = 8 +"""Minimum number of volumes to automatically serialize 4D transforms.""" + + +async def worker(job: Callable[[], R], semaphore) -> R: + async with semaphore: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, job) + + +async def _apply_serial( + data: np.ndarray, + spatialimage: SpatialImage, + targets: np.ndarray, + transform: TransformBase, + ref_ndim: int, + ref_ndcoords: np.ndarray, + n_resamplings: int, + output: np.ndarray, + input_dtype: np.dtype, + order: int = 3, + mode: str = "constant", + cval: float = 0.0, + prefilter: bool = True, + max_concurrent: int = min(cpu_count(), 12), ): + """ + Resample through a given transform serially, in a 3D+t setting. + + Parameters + ---------- + data : :obj:`~numpy.ndarray` + The input data array. + spatialimage : :obj:`~nibabel.spatialimages.SpatialImage` or `os.pathlike` + The image object containing the data to be resampled in reference + space + targets : :obj:`~numpy.ndarray` + The target coordinates for mapping. + transform : :obj:`~nitransforms.base.TransformBase` + The 3D, 3D+t, or 4D transform through which data will be resampled. + ref_ndim : :obj:`int` + Dimensionality of the resampling target (reference image). + ref_ndcoords : :obj:`~numpy.ndarray` + Physical coordinates (RAS+) where data will be interpolated, if the resampling + target is a grid, the scanner coordinates of all voxels. + n_resamplings : :obj:`int` + Total number of 3D resamplings (can be defined by the input image, the transform, + or be matched, that is, same number of volumes in the input and number of transforms). + output : :obj:`~numpy.ndarray` + The output data array where resampled values will be stored volume-by-volume. + order : :obj:`int`, optional + The order of the spline interpolation, default is 3. + The order has to be in the range 0-5. + mode : :obj:`str`, optional + Determines how the input image is extended when the resamplings overflows + a border. One of ``'constant'``, ``'reflect'``, ``'nearest'``, ``'mirror'``, + or ``'wrap'``. Default is ``'constant'``. + cval : :obj:`float`, optional + Constant value for ``mode='constant'``. Default is 0.0. + prefilter: :obj:`bool`, optional + Determines if the image's data array is prefiltered with + a spline filter before interpolation. The default is ``True``, + which will create a temporary *float64* array of filtered values + if *order > 1*. If setting this to ``False``, the output will be + slightly blurred if *order > 1*, unless the input is prefiltered, + i.e. it is the result of calling the spline filter on the original + input. + + Returns + ------- + np.ndarray + Data resampled on the 3D+t array of input coordinates. + + """ + tasks = [] + semaphore = asyncio.Semaphore(max_concurrent) + + for t in range(n_resamplings): + xfm_t = transform if n_resamplings == 1 else transform[t] + + if targets is None: + targets = ImageGrid(spatialimage).index( # data should be an image + _as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim) + ) + + data_t = ( + data + if data is not None + else spatialimage.dataobj[..., t].astype(input_dtype, copy=False) + ) + + tasks.append( + asyncio.create_task( + worker( + partial( + ndi.map_coordinates, + data_t, + targets, + output=output[..., t], + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ), + semaphore, + ) + ) + ) + await asyncio.gather(*tasks) + return output + + +def apply( + transform: TransformBase, + spatialimage: str | Path | SpatialImage, + reference: str | Path | SpatialImage = None, + order: int = 3, + mode: str = "constant", + cval: float = 0.0, + prefilter: bool = True, + output_dtype: np.dtype = None, + dtype_width: int = 8, + serialize_nvols: int = SERIALIZE_VOLUME_WINDOW_WIDTH, + max_concurrent: int = min(cpu_count(), 12), +) -> SpatialImage | np.ndarray: """ Apply a transformation to an image, resampling on the reference spatial object. Parameters ---------- - spatialimage : `spatialimage` + transform: :obj:`~nitransforms.base.TransformBase` + The 3D, 3D+t, or 4D transform through which data will be resampled. + spatialimage : :obj:`~nibabel.spatialimages.SpatialImage` or `os.pathlike` The image object containing the data to be resampled in reference space - reference : spatial object, optional + reference : :obj:`~nibabel.spatialimages.SpatialImage` or `os.pathlike` The image, surface, or combination thereof containing the coordinates of samples that will be sampled. - order : int, optional + order : :obj:`int`, optional The order of the spline interpolation, default is 3. The order has to be in the range 0-5. - mode : {'constant', 'reflect', 'nearest', 'mirror', 'wrap'}, optional + mode : :obj:`str`, optional Determines how the input image is extended when the resamplings overflows - a border. Default is 'constant'. - cval : float, optional + a border. One of ``'constant'``, ``'reflect'``, ``'nearest'``, ``'mirror'``, + or ``'wrap'``. Default is ``'constant'``. + cval : :obj:`float`, optional Constant value for ``mode='constant'``. Default is 0.0. - prefilter: bool, optional + prefilter : :obj:`bool`, optional Determines if the image's data array is prefiltered with a spline filter before interpolation. The default is ``True``, which will create a temporary *float64* array of filtered values @@ -57,7 +185,7 @@ def apply( slightly blurred if *order > 1*, unless the input is prefiltered, i.e. it is the result of calling the spline filter on the original input. - output_dtype: dtype specifier, optional + output_dtype : :obj:`~numpy.dtype`, optional The dtype of the returned array or image, if specified. If ``None``, the default behavior is to use the effective dtype of the input image. If slope and/or intercept are defined, the effective @@ -66,10 +194,21 @@ def apply( If ``reference`` is defined, then the return value is an image, with a data array of the effective dtype but with the on-disk dtype set to the input image's on-disk dtype. + dtype_width : :obj:`int` + Cap the width of the input data type to the given number of bytes. + This argument is intended to work as a way to implement lower memory + requirements in resampling. + serialize_nvols : :obj:`int` + Minimum number of volumes in a 3D+t (that is, a series of 3D transformations + independent in time) to resample on a one-by-one basis. + Serialized resampling can be executed concurrently (parallelized) with + the argument ``max_concurrent``. + max_concurrent : :obj:`int` + Maximum number of 3D resamplings to be executed concurrently. Returns ------- - resampled : `spatialimage` or ndarray + resampled : :obj:`~nibabel.spatialimages.SpatialImage` or :obj:`~numpy.ndarray` The data imaged after resampling to reference space. """ @@ -88,52 +227,140 @@ def apply( if isinstance(spatialimage, (str, Path)): spatialimage = _nbload(str(spatialimage)) - data = np.asanyarray(spatialimage.dataobj) + # Avoid opening the data array just yet + input_dtype = cap_dtype(get_obj_dtype(spatialimage.dataobj), dtype_width) + + # Number of data volumes + data_nvols = 1 if spatialimage.ndim < 4 else spatialimage.shape[-1] + # Number of transforms: transforms chains (e.g., affine + field, are a single transform) + xfm_nvols = 1 if transform.ndim < 4 else len(transform) - if data.ndim == 4 and data.shape[-1] != len(transform): + if data_nvols != xfm_nvols and min(data_nvols, xfm_nvols) > 1: raise ValueError( - "The fourth dimension of the data does not match the tranform's shape." + "The fourth dimension of the data does not match the transform's shape." ) - if data.ndim < transform.ndim: - data = data[..., np.newaxis] + serialize_nvols = ( + serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf + ) + n_resamplings = max(data_nvols, xfm_nvols) + serialize_4d = n_resamplings >= serialize_nvols - # For model-based nonlinear transforms, generate the corresponding dense field + targets = None + ref_ndcoords = _ref.ndcoords.T if hasattr(transform, "to_field") and callable(transform.to_field): targets = ImageGrid(spatialimage).index( _as_homogeneous( - transform.to_field(reference=reference).map(_ref.ndcoords.T), + transform.to_field(reference=reference).map(ref_ndcoords), dim=_ref.ndim, ) ) - else: + elif xfm_nvols == 1: targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) + _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) ) - if transform.ndim == 4: - targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T - - resampled = ndi.map_coordinates( - data, - targets, - output=output_dtype, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - ) + if serialize_4d: + data = ( + np.asanyarray(spatialimage.dataobj, dtype=input_dtype) + if data_nvols == 1 + else None + ) + + # Order F ensures individual volumes are contiguous in memory + # Also matches NIfTI, making final save more efficient + resampled = np.zeros( + (len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F" + ) + + resampled = asyncio.run( + _apply_serial( + data, + spatialimage, + targets, + transform, + _ref.ndim, + ref_ndcoords, + n_resamplings, + resampled, + input_dtype, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + max_concurrent=max_concurrent, + ) + ) + else: + data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype) + + if targets is None: + targets = ImageGrid(spatialimage).index( # data should be an image + _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) + ) + + # Cast 3D data into 4D if 4D nonsequential transform + if data_nvols == 1 and xfm_nvols > 1: + data = data[..., np.newaxis] + + if transform.ndim == 4: + targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T + + resampled = ndi.map_coordinates( + data, + targets, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) if isinstance(_ref, ImageGrid): # If reference is grid, reshape - hdr = None - if _ref.header is not None: - hdr = _ref.header.copy() - hdr.set_data_dtype(output_dtype or spatialimage.get_data_dtype()) + hdr = ( + _ref.header.copy() + if _ref.header is not None + else spatialimage.header.__class__() + ) + hdr.set_data_dtype(output_dtype or spatialimage.header.get_data_dtype()) + moved = spatialimage.__class__( - resampled.reshape(_ref.shape if data.ndim < 4 else _ref.shape + (-1,)), + resampled.reshape(_ref.shape if n_resamplings == 1 else _ref.shape + (-1,)), _ref.affine, hdr, ) return moved - return resampled + output_dtype = output_dtype or input_dtype + return resampled.astype(output_dtype) + + +def cap_dtype(dt, nbytes): + """ + Cap the datatype size to shave off memory requirements. + + Examples + -------- + >>> cap_dtype(np.dtype('f8'), 4) + dtype('float32') + + >>> cap_dtype(np.dtype('f8'), 16) + dtype('float64') + + >>> cap_dtype('float64', 4) + dtype('float32') + + >>> cap_dtype(np.dtype('i1'), 4) + dtype('int8') + + >>> cap_dtype('int8', 4) + dtype('int8') + + >>> cap_dtype('int32', 1) + dtype('int8') + + >>> cap_dtype(np.dtype('i8'), 4) + dtype('int32') + + """ + dt = np.dtype(dt) + return np.dtype(f"{dt.byteorder}{dt.kind}{min(nbytes, dt.itemsize)}") diff --git a/nitransforms/tests/test_base.py b/nitransforms/tests/test_base.py index fb4be8d8..49d7f7af 100644 --- a/nitransforms/tests/test_base.py +++ b/nitransforms/tests/test_base.py @@ -1,6 +1,9 @@ """Tests of the base module.""" + import numpy as np import nibabel as nb +from nibabel.arrayproxy import get_obj_dtype + import pytest import h5py @@ -97,7 +100,7 @@ def _to_hdf5(klass, x5_root): fname = testdata_path / "someones_anatomy.nii.gz" img = nb.load(fname) - imgdata = np.asanyarray(img.dataobj, dtype=img.get_data_dtype()) + imgdata = np.asanyarray(img.dataobj, dtype=get_obj_dtype(img.dataobj)) # Test identity transform - setting reference xfm = TransformBase() @@ -111,7 +114,10 @@ def _to_hdf5(klass, x5_root): xfm = nitl.Affine() xfm.reference = fname moved = apply(xfm, fname, order=0) - assert np.all(imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype())) + + assert np.all( + imgdata == np.asanyarray(moved.dataobj, dtype=get_obj_dtype(moved.dataobj)) + ) # Test ndim returned by affine assert nitl.Affine().ndim == 3 @@ -165,7 +171,10 @@ def test_concatenation(testdata_path): def test_SurfaceMesh(testdata_path): surf_path = testdata_path / "sub-200148_hemi-R_pial.surf.gii" - shape_path = testdata_path / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_thickness.shape.gii" + shape_path = ( + testdata_path + / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_thickness.shape.gii" + ) img_path = testdata_path / "bold.nii.gz" mesh = SurfaceMesh(nb.load(surf_path)) @@ -189,3 +198,18 @@ def test_SurfaceMesh(testdata_path): with pytest.raises(TypeError): SurfaceMesh(nb.load(shape_path)) + + +def test_apply_deprecation(monkeypatch): + """Make sure a deprecation warning is issued.""" + from nitransforms import resampling + + def _retval(*args, **kwargs): + return 1 + + monkeypatch.setattr(resampling, "apply", _retval) + + with pytest.deprecated_call(): + retval = TransformBase().apply() + + assert retval == 1 diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index 50cc5371..969b33ab 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -1,42 +1,26 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Tests of linear transforms.""" -import os + import pytest import numpy as np -from subprocess import check_call -import shutil import h5py -import nibabel as nb from nibabel.eulerangles import euler2mat from nibabel.affines import from_matvec from nitransforms import linear as nitl from nitransforms import io -from nitransforms.resampling import apply from .utils import assert_affines_by_filename -RMSE_TOL = 0.1 -APPLY_LINEAR_CMD = { - "fsl": """\ -flirt -setbackground 0 -interp nearestneighbour -in {moving} -ref {reference} \ --applyxfm -init {transform} -out {resampled}\ -""".format, - "itk": """\ -antsApplyTransforms -d 3 -r {reference} -i {moving} \ --o {resampled} -n NearestNeighbor -t {transform} --float\ -""".format, - "afni": """\ -3dAllineate -base {reference} -input {moving} \ --prefix {resampled} -1Dmatrix_apply {transform} -final NN\ -""".format, - "fs": """\ -mri_vol2vol --mov {moving} --targ {reference} --lta {transform} \ ---o {resampled} --nearest""".format, -} - - -@pytest.mark.parametrize("matrix", [[0.0], np.ones((3, 3, 3)), np.ones((3, 4)), ]) + +@pytest.mark.parametrize( + "matrix", + [ + [0.0], + np.ones((3, 3, 3)), + np.ones((3, 4)), + ], +) def test_linear_typeerrors1(matrix): """Exercise errors in Affine creation.""" with pytest.raises(TypeError): @@ -158,7 +142,9 @@ def test_loadsave(tmp_path, data_path, testdata_path, autofmt, fmt): assert np.allclose( xfm.matrix, - nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix, + nitl.load( + fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file + ).matrix, ) else: assert xfm == nitl.load(fname, fmt=supplied_fmt, reference=ref_file) @@ -168,7 +154,9 @@ def test_loadsave(tmp_path, data_path, testdata_path, autofmt, fmt): if fmt == "fsl": assert np.allclose( xfm.matrix, - nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix, + nitl.load( + fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file + ).matrix, rtol=1e-2, # FSL incurs into large errors due to rounding ) else: @@ -182,7 +170,9 @@ def test_loadsave(tmp_path, data_path, testdata_path, autofmt, fmt): if fmt == "fsl": assert np.allclose( xfm.matrix, - nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix, + nitl.load( + fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file + ).matrix, rtol=1e-2, # FSL incurs into large errors due to rounding ) else: @@ -192,7 +182,9 @@ def test_loadsave(tmp_path, data_path, testdata_path, autofmt, fmt): if fmt == "fsl": assert np.allclose( xfm.matrix, - nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix, + nitl.load( + fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file + ).matrix, rtol=1e-2, # FSL incurs into large errors due to rounding ) else: @@ -212,12 +204,15 @@ def test_linear_save(tmpdir, data_path, get_testdata, image_orientation, sw_tool T = np.linalg.inv(T) xfm = ( - nitl.Affine(T) if (sw_tool, image_orientation) != ("afni", "oblique") else + nitl.Affine(T) + if (sw_tool, image_orientation) != ("afni", "oblique") # AFNI is special when moving or reference are oblique - let io do the magic - nitl.Affine(io.afni.AFNILinearTransform.from_ras(T).to_ras( - reference=img, - moving=img, - )) + else nitl.Affine( + io.afni.AFNILinearTransform.from_ras(T).to_ras( + reference=img, + moving=img, + ) + ) ) xfm.reference = img @@ -234,96 +229,6 @@ def test_linear_save(tmpdir, data_path, get_testdata, image_orientation, sw_tool assert_affines_by_filename(xfm_fname1, xfm_fname2) -@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", 'oblique', ]) -@pytest.mark.parametrize("sw_tool", ["itk", "fsl", "afni", "fs"]) -def test_apply_linear_transform(tmpdir, get_testdata, get_testmask, image_orientation, sw_tool): - """Check implementation of exporting affines to formats.""" - tmpdir.chdir() - - img = get_testdata[image_orientation] - msk = get_testmask[image_orientation] - - # Generate test transform - T = from_matvec(euler2mat(x=0.9, y=0.001, z=0.001), [4.0, 2.0, -1.0]) - xfm = nitl.Affine(T) - xfm.reference = img - - ext = "" - if sw_tool == "itk": - ext = ".tfm" - elif sw_tool == "fs": - ext = ".lta" - - img.to_filename("img.nii.gz") - msk.to_filename("mask.nii.gz") - - # Write out transform file (software-dependent) - xfm_fname = f"M.{sw_tool}{ext}" - # Change reference dataset for AFNI & oblique - if (sw_tool, image_orientation) == ("afni", "oblique"): - io.afni.AFNILinearTransform.from_ras( - T, - moving=img, - reference=img, - ).to_filename(xfm_fname) - else: - xfm.to_filename(xfm_fname, fmt=sw_tool) - - cmd = APPLY_LINEAR_CMD[sw_tool]( - transform=os.path.abspath(xfm_fname), - reference=os.path.abspath("mask.nii.gz"), - moving=os.path.abspath("mask.nii.gz"), - resampled=os.path.abspath("resampled_brainmask.nii.gz"), - ) - - # skip test if command is not available on host - exe = cmd.split(" ", 1)[0] - if not shutil.which(exe): - pytest.skip(f"Command {exe} not found on host") - - # resample mask - exit_code = check_call([cmd], shell=True) - assert exit_code == 0 - sw_moved_mask = nb.load("resampled_brainmask.nii.gz") - - nt_moved_mask = apply(xfm, msk, order=0) - nt_moved_mask.set_data_dtype(msk.get_data_dtype()) - nt_moved_mask.to_filename("ntmask.nii.gz") - diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj) - - assert np.sqrt((diff ** 2).mean()) < RMSE_TOL - brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool) - - cmd = APPLY_LINEAR_CMD[sw_tool]( - transform=os.path.abspath(xfm_fname), - reference=os.path.abspath("img.nii.gz"), - moving=os.path.abspath("img.nii.gz"), - resampled=os.path.abspath("resampled.nii.gz"), - ) - - exit_code = check_call([cmd], shell=True) - assert exit_code == 0 - sw_moved = nb.load("resampled.nii.gz") - sw_moved.set_data_dtype(img.get_data_dtype()) - - nt_moved = apply(xfm, img, order=0) - diff = ( - np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype()) - - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) - ) - - # A certain tolerance is necessary because of resampling at borders - assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL - - nt_moved = apply(xfm, "img.nii.gz", order=0) - diff = ( - np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype()) - - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) - ) - # A certain tolerance is necessary because of resampling at borders - assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL - - def test_Affine_to_x5(tmpdir, testdata_path): """Test affine's operations.""" tmpdir.chdir() @@ -336,40 +241,6 @@ def test_Affine_to_x5(tmpdir, testdata_path): aff._to_hdf5(f.create_group("Affine")) -def test_LinearTransformsMapping_apply(tmp_path, data_path, testdata_path): - """Apply transform mappings.""" - hmc = nitl.load( - data_path / "hmc-itk.tfm", fmt="itk", reference=testdata_path / "sbref.nii.gz" - ) - assert isinstance(hmc, nitl.LinearTransformsMapping) - - # Test-case: realign functional data on to sbref - nii = apply( - hmc, testdata_path / "func.nii.gz", order=1, reference=testdata_path / "sbref.nii.gz" - ) - assert nii.dataobj.shape[-1] == len(hmc) - - # Test-case: write out a fieldmap moved with head - hmcinv = nitl.LinearTransformsMapping( - np.linalg.inv(hmc.matrix), reference=testdata_path / "func.nii.gz" - ) - - nii = apply( - hmcinv, testdata_path / "fmap.nii.gz", order=1 - ) - assert nii.dataobj.shape[-1] == len(hmc) - - # Ensure a ValueError is issued when trying to do weird stuff - hmc = nitl.LinearTransformsMapping(hmc.matrix[:1, ...]) - with pytest.raises(ValueError): - apply( - hmc, - testdata_path / "func.nii.gz", - order=1, - reference=testdata_path / "sbref.nii.gz", - ) - - def test_mulmat_operator(testdata_path): """Check the @ operator.""" ref = testdata_path / "someones_anatomy.nii.gz" diff --git a/nitransforms/tests/test_manip.py b/nitransforms/tests/test_manip.py index b7f6a6e4..b5dd5c62 100644 --- a/nitransforms/tests/test_manip.py +++ b/nitransforms/tests/test_manip.py @@ -1,67 +1,16 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Tests of nonlinear transforms.""" -import os -import shutil -from subprocess import check_call + import pytest import numpy as np -import nibabel as nb -from ..manip import load as _load, TransformChain +from ..manip import TransformChain from ..linear import Affine -from .test_nonlinear import ( - RMSE_TOL, - APPLY_NONLINEAR_CMD, -) -from nitransforms.resampling import apply FMT = {"lta": "fs", "tfm": "itk"} -def test_itk_h5(tmp_path, testdata_path): - """Check a translation-only field on one or more axes, different image orientations.""" - os.chdir(str(tmp_path)) - img_fname = testdata_path / "T1w_scanner.nii.gz" - xfm_fname = ( - testdata_path - / "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5" - ) - - xfm = _load(xfm_fname) - - assert len(xfm) == 2 - - ref_fname = tmp_path / "reference.nii.gz" - nb.Nifti1Image( - np.zeros(xfm.reference.shape, dtype="uint16"), xfm.reference.affine, - ).to_filename(str(ref_fname)) - - # Then apply the transform and cross-check with software - cmd = APPLY_NONLINEAR_CMD["itk"]( - transform=xfm_fname, - reference=ref_fname, - moving=img_fname, - output="resampled.nii.gz", - extra="", - ) - - # skip test if command is not available on host - exe = cmd.split(" ", 1)[0] - if not shutil.which(exe): - pytest.skip(f"Command {exe} not found on host") - - exit_code = check_call([cmd], shell=True) - assert exit_code == 0 - sw_moved = nb.load("resampled.nii.gz") - - nt_moved = apply(xfm, img_fname, order=0) - nt_moved.to_filename("nt_resampled.nii.gz") - diff = sw_moved.get_fdata() - nt_moved.get_fdata() - # A certain tolerance is necessary because of resampling at borders - assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL - - @pytest.mark.parametrize("ext0", ["lta", "tfm"]) @pytest.mark.parametrize("ext1", ["lta", "tfm"]) @pytest.mark.parametrize("ext2", ["lta", "tfm"]) diff --git a/nitransforms/tests/test_nonlinear.py b/nitransforms/tests/test_nonlinear.py index 24d1f83e..6112f633 100644 --- a/nitransforms/tests/test_nonlinear.py +++ b/nitransforms/tests/test_nonlinear.py @@ -1,9 +1,8 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Tests of nonlinear transforms.""" + import os -import shutil -from subprocess import check_call import pytest import numpy as np @@ -14,27 +13,10 @@ from nitransforms.nonlinear import ( BSplineFieldTransform, DenseFieldTransform, - load as nlload, ) from ..io.itk import ITKDisplacementsField -RMSE_TOL = 0.05 -APPLY_NONLINEAR_CMD = { - "itk": """\ -antsApplyTransforms -d 3 -r {reference} -i {moving} \ --o {output} -n NearestNeighbor -t {transform} {extra}\ -""".format, - "afni": """\ -3dNwarpApply -nwarp {transform} -source {moving} \ --master {reference} -interp NN -prefix {output} {extra}\ -""".format, - "fsl": """\ -applywarp -i {moving} -r {reference} -o {output} {extra}\ --w {transform} --interp=nn""".format, -} - - @pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 3)]) def test_itk_disp_load(size): """Checks field sizes.""" @@ -113,132 +95,6 @@ def test_bsplines_references(testdata_path): ) -@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"]) -@pytest.mark.parametrize("sw_tool", ["itk", "afni"]) -@pytest.mark.parametrize("axis", [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)]) -def test_displacements_field1( - tmp_path, - get_testdata, - get_testmask, - image_orientation, - sw_tool, - axis, -): - """Check a translation-only field on one or more axes, different image orientations.""" - if (image_orientation, sw_tool) == ("oblique", "afni"): - pytest.skip("AFNI obliques are not yet implemented for displacements fields") - - os.chdir(str(tmp_path)) - nii = get_testdata[image_orientation] - msk = get_testmask[image_orientation] - nii.to_filename("reference.nii.gz") - msk.to_filename("mask.nii.gz") - - fieldmap = np.zeros( - (*nii.shape[:3], 1, 3) if sw_tool != "fsl" else (*nii.shape[:3], 3), - dtype="float32", - ) - fieldmap[..., axis] = -10.0 - - _hdr = nii.header.copy() - if sw_tool in ("itk",): - _hdr.set_intent("vector") - _hdr.set_data_dtype("float32") - - xfm_fname = "warp.nii.gz" - field = nb.Nifti1Image(fieldmap, nii.affine, _hdr) - field.to_filename(xfm_fname) - - xfm = nlload(xfm_fname, fmt=sw_tool) - - # Then apply the transform and cross-check with software - cmd = APPLY_NONLINEAR_CMD[sw_tool]( - transform=os.path.abspath(xfm_fname), - reference=tmp_path / "mask.nii.gz", - moving=tmp_path / "mask.nii.gz", - output=tmp_path / "resampled_brainmask.nii.gz", - extra="--output-data-type uchar" if sw_tool == "itk" else "", - ) - - # skip test if command is not available on host - exe = cmd.split(" ", 1)[0] - if not shutil.which(exe): - pytest.skip(f"Command {exe} not found on host") - - # resample mask - exit_code = check_call([cmd], shell=True) - assert exit_code == 0 - sw_moved_mask = nb.load("resampled_brainmask.nii.gz") - nt_moved_mask = apply(xfm, msk, order=0) - nt_moved_mask.set_data_dtype(msk.get_data_dtype()) - diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj) - - assert np.sqrt((diff**2).mean()) < RMSE_TOL - brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool) - - # Then apply the transform and cross-check with software - cmd = APPLY_NONLINEAR_CMD[sw_tool]( - transform=os.path.abspath(xfm_fname), - reference=tmp_path / "reference.nii.gz", - moving=tmp_path / "reference.nii.gz", - output=tmp_path / "resampled.nii.gz", - extra="--output-data-type uchar" if sw_tool == "itk" else "", - ) - - exit_code = check_call([cmd], shell=True) - assert exit_code == 0 - sw_moved = nb.load("resampled.nii.gz") - - nt_moved = apply(xfm, nii, order=0) - nt_moved.set_data_dtype(nii.get_data_dtype()) - nt_moved.to_filename("nt_resampled.nii.gz") - sw_moved.set_data_dtype(nt_moved.get_data_dtype()) - diff = np.asanyarray( - sw_moved.dataobj, dtype=sw_moved.get_data_dtype() - ) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) - # A certain tolerance is necessary because of resampling at borders - assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL - - -@pytest.mark.parametrize("sw_tool", ["itk", "afni"]) -def test_displacements_field2(tmp_path, testdata_path, sw_tool): - """Check a translation-only field on one or more axes, different image orientations.""" - os.chdir(str(tmp_path)) - img_fname = testdata_path / "tpl-OASIS30ANTs_T1w.nii.gz" - xfm_fname = testdata_path / "ds-005_sub-01_from-OASIS_to-T1_warp_{}.nii.gz".format( - sw_tool - ) - - xfm = nlload(xfm_fname, fmt=sw_tool) - - # Then apply the transform and cross-check with software - cmd = APPLY_NONLINEAR_CMD[sw_tool]( - transform=xfm_fname, - reference=img_fname, - moving=img_fname, - output="resampled.nii.gz", - extra="", - ) - - # skip test if command is not available on host - exe = cmd.split(" ", 1)[0] - if not shutil.which(exe): - pytest.skip(f"Command {exe} not found on host") - - exit_code = check_call([cmd], shell=True) - assert exit_code == 0 - sw_moved = nb.load("resampled.nii.gz") - - nt_moved = apply(xfm, img_fname, order=0) - nt_moved.to_filename("nt_resampled.nii.gz") - sw_moved.set_data_dtype(nt_moved.get_data_dtype()) - diff = np.asanyarray( - sw_moved.dataobj, dtype=sw_moved.get_data_dtype() - ) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) - # A certain tolerance is necessary because of resampling at borders - assert np.sqrt((diff**2).mean()) < RMSE_TOL - - def test_bspline(tmp_path, testdata_path): """Cross-check B-Splines and deformation field.""" os.chdir(str(tmp_path)) diff --git a/nitransforms/tests/test_resampling.py b/nitransforms/tests/test_resampling.py new file mode 100644 index 00000000..2384ad97 --- /dev/null +++ b/nitransforms/tests/test_resampling.py @@ -0,0 +1,365 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +"""Exercise the standalone ``apply()`` implementation.""" + +import os +import pytest +import numpy as np +from subprocess import check_call +import shutil + +import nibabel as nb +from nibabel.eulerangles import euler2mat +from nibabel.affines import from_matvec +from nitransforms import linear as nitl +from nitransforms import nonlinear as nitnl +from nitransforms import manip as nitm +from nitransforms import io +from nitransforms.resampling import apply + +RMSE_TOL_LINEAR = 0.09 +RMSE_TOL_NONLINEAR = 0.05 +APPLY_LINEAR_CMD = { + "fsl": """\ +flirt -setbackground 0 -interp nearestneighbour -in {moving} -ref {reference} \ +-applyxfm -init {transform} -out {resampled}\ +""".format, + "itk": """\ +antsApplyTransforms -d 3 -r {reference} -i {moving} \ +-o {resampled} -n NearestNeighbor -t {transform} --float\ +""".format, + "afni": """\ +3dAllineate -base {reference} -input {moving} \ +-prefix {resampled} -1Dmatrix_apply {transform} -final NN\ +""".format, + "fs": """\ +mri_vol2vol --mov {moving} --targ {reference} --lta {transform} \ +--o {resampled} --nearest""".format, +} +APPLY_NONLINEAR_CMD = { + "itk": """\ +antsApplyTransforms -d 3 -r {reference} -i {moving} \ +-o {output} -n NearestNeighbor -t {transform} {extra}\ +""".format, + "afni": """\ +3dNwarpApply -nwarp {transform} -source {moving} \ +-master {reference} -interp NN -prefix {output} {extra}\ +""".format, + "fsl": """\ +applywarp -i {moving} -r {reference} -o {output} {extra}\ +-w {transform} --interp=nn""".format, +} + + +@pytest.mark.parametrize( + "image_orientation", + [ + "RAS", + "LAS", + "LPS", + "oblique", + ], +) +@pytest.mark.parametrize("sw_tool", ["itk", "fsl", "afni", "fs"]) +def test_apply_linear_transform( + tmpdir, get_testdata, get_testmask, image_orientation, sw_tool +): + """Check implementation of exporting affines to formats.""" + tmpdir.chdir() + + img = get_testdata[image_orientation] + msk = get_testmask[image_orientation] + + # Generate test transform + T = from_matvec(euler2mat(x=0.9, y=0.001, z=0.001), [4.0, 2.0, -1.0]) + xfm = nitl.Affine(T) + xfm.reference = img + + ext = "" + if sw_tool == "itk": + ext = ".tfm" + elif sw_tool == "fs": + ext = ".lta" + + img.to_filename("img.nii.gz") + msk.to_filename("mask.nii.gz") + + # Write out transform file (software-dependent) + xfm_fname = f"M.{sw_tool}{ext}" + # Change reference dataset for AFNI & oblique + if (sw_tool, image_orientation) == ("afni", "oblique"): + io.afni.AFNILinearTransform.from_ras( + T, + moving=img, + reference=img, + ).to_filename(xfm_fname) + else: + xfm.to_filename(xfm_fname, fmt=sw_tool) + + cmd = APPLY_LINEAR_CMD[sw_tool]( + transform=os.path.abspath(xfm_fname), + reference=os.path.abspath("mask.nii.gz"), + moving=os.path.abspath("mask.nii.gz"), + resampled=os.path.abspath("resampled_brainmask.nii.gz"), + ) + + # skip test if command is not available on host + exe = cmd.split(" ", 1)[0] + if not shutil.which(exe): + pytest.skip(f"Command {exe} not found on host") + + # resample mask + exit_code = check_call([cmd], shell=True) + assert exit_code == 0 + sw_moved_mask = nb.load("resampled_brainmask.nii.gz") + + nt_moved_mask = apply(xfm, msk, order=0) + nt_moved_mask.set_data_dtype(msk.get_data_dtype()) + nt_moved_mask.to_filename("ntmask.nii.gz") + diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj) + + assert np.sqrt((diff**2).mean()) < RMSE_TOL_LINEAR + brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool) + + cmd = APPLY_LINEAR_CMD[sw_tool]( + transform=os.path.abspath(xfm_fname), + reference=os.path.abspath("img.nii.gz"), + moving=os.path.abspath("img.nii.gz"), + resampled=os.path.abspath("resampled.nii.gz"), + ) + + exit_code = check_call([cmd], shell=True) + assert exit_code == 0 + sw_moved = nb.load("resampled.nii.gz") + sw_moved.set_data_dtype(img.get_data_dtype()) + + nt_moved = apply(xfm, img, order=0) + diff = np.asanyarray( + sw_moved.dataobj, dtype=sw_moved.get_data_dtype() + ) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) + + # A certain tolerance is necessary because of resampling at borders + assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL_LINEAR + + nt_moved = apply(xfm, "img.nii.gz", order=0) + diff = np.asanyarray( + sw_moved.dataobj, dtype=sw_moved.get_data_dtype() + ) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) + # A certain tolerance is necessary because of resampling at borders + assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL_LINEAR + + +@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"]) +@pytest.mark.parametrize("sw_tool", ["itk", "afni"]) +@pytest.mark.parametrize("axis", [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)]) +def test_displacements_field1( + tmp_path, + get_testdata, + get_testmask, + image_orientation, + sw_tool, + axis, +): + """Check a translation-only field on one or more axes, different image orientations.""" + if (image_orientation, sw_tool) == ("oblique", "afni"): + pytest.skip("AFNI obliques are not yet implemented for displacements fields") + + os.chdir(str(tmp_path)) + nii = get_testdata[image_orientation] + msk = get_testmask[image_orientation] + nii.to_filename("reference.nii.gz") + msk.to_filename("mask.nii.gz") + + fieldmap = np.zeros( + (*nii.shape[:3], 1, 3) if sw_tool != "fsl" else (*nii.shape[:3], 3), + dtype="float32", + ) + fieldmap[..., axis] = -10.0 + + _hdr = nii.header.copy() + if sw_tool in ("itk",): + _hdr.set_intent("vector") + _hdr.set_data_dtype("float32") + + xfm_fname = "warp.nii.gz" + field = nb.Nifti1Image(fieldmap, nii.affine, _hdr) + field.to_filename(xfm_fname) + + xfm = nitnl.load(xfm_fname, fmt=sw_tool) + + # Then apply the transform and cross-check with software + cmd = APPLY_NONLINEAR_CMD[sw_tool]( + transform=os.path.abspath(xfm_fname), + reference=tmp_path / "mask.nii.gz", + moving=tmp_path / "mask.nii.gz", + output=tmp_path / "resampled_brainmask.nii.gz", + extra="--output-data-type uchar" if sw_tool == "itk" else "", + ) + + # skip test if command is not available on host + exe = cmd.split(" ", 1)[0] + if not shutil.which(exe): + pytest.skip(f"Command {exe} not found on host") + + # resample mask + exit_code = check_call([cmd], shell=True) + assert exit_code == 0 + sw_moved_mask = nb.load("resampled_brainmask.nii.gz") + nt_moved_mask = apply(xfm, msk, order=0) + nt_moved_mask.set_data_dtype(msk.get_data_dtype()) + diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj) + + assert np.sqrt((diff**2).mean()) < RMSE_TOL_LINEAR + brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool) + + # Then apply the transform and cross-check with software + cmd = APPLY_NONLINEAR_CMD[sw_tool]( + transform=os.path.abspath(xfm_fname), + reference=tmp_path / "reference.nii.gz", + moving=tmp_path / "reference.nii.gz", + output=tmp_path / "resampled.nii.gz", + extra="--output-data-type uchar" if sw_tool == "itk" else "", + ) + + exit_code = check_call([cmd], shell=True) + assert exit_code == 0 + sw_moved = nb.load("resampled.nii.gz") + + nt_moved = apply(xfm, nii, order=0) + nt_moved.set_data_dtype(nii.get_data_dtype()) + nt_moved.to_filename("nt_resampled.nii.gz") + sw_moved.set_data_dtype(nt_moved.get_data_dtype()) + diff = np.asanyarray( + sw_moved.dataobj, dtype=sw_moved.get_data_dtype() + ) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) + # A certain tolerance is necessary because of resampling at borders + assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL_LINEAR + + +@pytest.mark.parametrize("sw_tool", ["itk", "afni"]) +def test_displacements_field2(tmp_path, testdata_path, sw_tool): + """Check a translation-only field on one or more axes, different image orientations.""" + os.chdir(str(tmp_path)) + img_fname = testdata_path / "tpl-OASIS30ANTs_T1w.nii.gz" + xfm_fname = testdata_path / "ds-005_sub-01_from-OASIS_to-T1_warp_{}.nii.gz".format( + sw_tool + ) + + xfm = nitnl.load(xfm_fname, fmt=sw_tool) + + # Then apply the transform and cross-check with software + cmd = APPLY_NONLINEAR_CMD[sw_tool]( + transform=xfm_fname, + reference=img_fname, + moving=img_fname, + output="resampled.nii.gz", + extra="", + ) + + # skip test if command is not available on host + exe = cmd.split(" ", 1)[0] + if not shutil.which(exe): + pytest.skip(f"Command {exe} not found on host") + + exit_code = check_call([cmd], shell=True) + assert exit_code == 0 + sw_moved = nb.load("resampled.nii.gz") + + nt_moved = apply(xfm, img_fname, order=0) + nt_moved.to_filename("nt_resampled.nii.gz") + sw_moved.set_data_dtype(nt_moved.get_data_dtype()) + diff = np.asanyarray( + sw_moved.dataobj, dtype=sw_moved.get_data_dtype() + ) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) + # A certain tolerance is necessary because of resampling at borders + assert np.sqrt((diff**2).mean()) < RMSE_TOL_LINEAR + + +def test_apply_transformchain(tmp_path, testdata_path): + """Check a translation-only field on one or more axes, different image orientations.""" + os.chdir(str(tmp_path)) + img_fname = testdata_path / "T1w_scanner.nii.gz" + xfm_fname = ( + testdata_path + / "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5" + ) + + xfm = nitm.load(xfm_fname) + + assert len(xfm) == 2 + + ref_fname = tmp_path / "reference.nii.gz" + nb.Nifti1Image( + np.zeros(xfm.reference.shape, dtype="uint16"), + xfm.reference.affine, + ).to_filename(str(ref_fname)) + + # Then apply the transform and cross-check with software + cmd = APPLY_NONLINEAR_CMD["itk"]( + transform=xfm_fname, + reference=ref_fname, + moving=img_fname, + output="resampled.nii.gz", + extra="", + ) + + # skip test if command is not available on host + exe = cmd.split(" ", 1)[0] + if not shutil.which(exe): + pytest.skip(f"Command {exe} not found on host") + + exit_code = check_call([cmd], shell=True) + assert exit_code == 0 + sw_moved = nb.load("resampled.nii.gz") + + nt_moved = apply(xfm, img_fname, order=0) + nt_moved.to_filename("nt_resampled.nii.gz") + diff = sw_moved.get_fdata() - nt_moved.get_fdata() + # A certain tolerance is necessary because of resampling at borders + assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL_LINEAR + + +@pytest.mark.parametrize("serialize_4d", [True, False]) +def test_LinearTransformsMapping_apply( + tmp_path, data_path, testdata_path, serialize_4d +): + """Apply transform mappings.""" + hmc = nitl.load( + data_path / "hmc-itk.tfm", fmt="itk", reference=testdata_path / "sbref.nii.gz" + ) + assert isinstance(hmc, nitl.LinearTransformsMapping) + + # Test-case: realign functional data on to sbref + nii = apply( + hmc, + testdata_path / "func.nii.gz", + order=1, + reference=testdata_path / "sbref.nii.gz", + serialize_nvols=2 if serialize_4d else np.inf, + ) + assert nii.dataobj.shape[-1] == len(hmc) + + # Test-case: write out a fieldmap moved with head + hmcinv = nitl.LinearTransformsMapping( + np.linalg.inv(hmc.matrix), reference=testdata_path / "func.nii.gz" + ) + + nii = apply( + hmcinv, + testdata_path / "fmap.nii.gz", + order=1, + serialize_nvols=2 if serialize_4d else np.inf, + ) + assert nii.dataobj.shape[-1] == len(hmc) + + # Ensure a ValueError is issued when trying to apply mismatched transforms + # (e.g., in this case, two transforms while the functional has 8 volumes) + hmc = nitl.LinearTransformsMapping(hmc.matrix[:2, ...]) + with pytest.raises(ValueError): + apply( + hmc, + testdata_path / "func.nii.gz", + order=1, + reference=testdata_path / "sbref.nii.gz", + serialize_nvols=2 if serialize_4d else np.inf, + )