Skip to content

Commit

Permalink
PERF: weights caching derivative calculation for MeanSquaresImageToIm…
Browse files Browse the repository at this point in the history
…ageMetric

Weights caching is used for derivative calculation to reduce registration time with BSplineTransforms. Currently only implemented for MeanSquaresImageToImageMetric, other metrics to be added at a later date.
  • Loading branch information
ljm898 authored and hjmjohnson committed Dec 22, 2024
1 parent 4eedf6a commit 34ea07c
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,33 @@ MeanSquaresImageToImageMetric<TFixedImage, TMovingImage>::GetValueAndDerivativeT
transform = this->m_Transform;
}

// Jacobian should be evaluated at the unmapped (fixed image) point.
transform->ComputeJacobianWithRespectToParameters(fixedImagePoint, threadS.m_Jacobian);
for (unsigned int par = 0; par < this->m_NumberOfParameters; ++par)
if (this->m_BSplineTransform && this->m_UseCachingOfBSplineWeights)
{
double sum = 0.0;
for (unsigned int dim = 0; dim < MovingImageDimension; ++dim)
// using pre-computed weights and indexes to calculate only non zero elements of the derivative
for (unsigned int w = 0; w < this->m_NumBSplineWeights; ++w)
{
sum += 2.0 * diff * threadS.m_Jacobian(dim, par) * movingImageGradientValue[dim];
const auto precomputedIndex = this->m_BSplineTransformIndicesArray[fixedImageSample][w];
const auto precomputedWeight = this->m_BSplineTransformWeightsArray[fixedImageSample][w];
for (unsigned int dim = 0; dim < MovingImageDimension; ++dim)
{
const int par = precomputedIndex + this->m_BSplineParametersOffset[dim];
threadS.m_MSEDerivative[par] += 2.0 * diff * precomputedWeight * movingImageGradientValue[dim];
}
}
}
else
{
// Use generic transform to compute Jacobian at the unmapped (fixed image) point.
transform->ComputeJacobianWithRespectToParameters(fixedImagePoint, threadS.m_Jacobian);
for (unsigned int par = 0; par < this->m_NumberOfParameters; ++par)
{
double sum = 0.0;
for (unsigned int dim = 0; dim < MovingImageDimension; ++dim)
{
sum += 2.0 * diff * threadS.m_Jacobian(dim, par) * movingImageGradientValue[dim];
}
threadS.m_MSEDerivative[par] += sum;
}
threadS.m_MSEDerivative[par] += sum;
}

return true;
Expand Down
77 changes: 77 additions & 0 deletions Modules/Registration/Common/test/itkMeanSquaresImageMetricTest.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
#include "itkLinearInterpolateImageFunction.h"
#include "itkMeanSquaresImageToImageMetric.h"
#include "itkGaussianImageSource.h"
#include "itkBSplineTransformInitializer.h"
#include "itkBSplineTransform.h"
#include "itkMersenneTwisterRandomVariateGenerator.h"
#include "itkTestingMacros.h"

#include <iostream>
#include "itkStdStreamStateSave.h"
Expand All @@ -31,6 +35,7 @@
* This test computes the mean squares value and derivatives
* for various shift values in (-10,10).
*
* This test checks the weights caching derivative optimisation.
*/

int
Expand Down Expand Up @@ -294,6 +299,78 @@ itkMeanSquaresImageMetricTest(int, char *[])
//-------------------------------------------------------
metric->Print(std::cout);

// Check consistency between BSplineTransform derivatives computed with vs omitting weights caching.
constexpr unsigned int splineOrder = 3;
constexpr unsigned int nodesPerDimension = 8;
using BSplineTransformType = itk::BSplineTransform<CoordinateRepresentationType, ImageDimension, splineOrder>;
using InitializerType = itk::BSplineTransformInitializer<BSplineTransformType, MovingImageType>;
using GeneratorType = itk::Statistics::MersenneTwisterRandomVariateGenerator;

// Offset moving image to ensure derivatives for BSplineTransform are non-zero.
MovingImageType::PointValueType bSplineTestMovingImageOrigin[] = { 1.0f, 1.0f };
movingImage->SetOrigin(bSplineTestMovingImageOrigin);

// Initialise BSplineTransform
auto bSplineTransform = BSplineTransformType::New();
auto initializer = InitializerType::New();
BSplineTransformType::MeshSizeType meshSize;
meshSize.Fill(nodesPerDimension - splineOrder);
initializer->SetTransform(bSplineTransform);
initializer->SetImage(movingImage);
initializer->SetTransformDomainMeshSize(meshSize);
initializer->InitializeTransform();

// Set bSplineTransform parameters with MersenneTwister
ParametersType bSplineParameters(bSplineTransform->GetNumberOfParameters());
auto generator = GeneratorType::New();
generator->Initialize();
for (unsigned int d = 0; d < bSplineParameters.Size(); ++d)
{
bSplineParameters[d] = generator->GetNormalVariate();
}
bSplineTransform->SetParameters(bSplineParameters);

// Connect metric to bSplineTransform
auto metricCacheTest = MetricType::New();
metricCacheTest->SetFixedImage(fixedImage);
metricCacheTest->SetMovingImage(movingImage);
metricCacheTest->SetInterpolator(interpolator);
metricCacheTest->SetFixedImageRegion(fixedImage->GetBufferedRegion());
metricCacheTest->SetTransform(bSplineTransform);
metricCacheTest->SetUseCachingOfBSplineWeights(true);
ITK_TRY_EXPECT_NO_EXCEPTION(metricCacheTest->Initialize());

// Compute derivatives with and without weights caching
MetricType::DerivativeType derivativeWithCaching, derivativeNoCaching;
metricCacheTest->GetDerivative(bSplineParameters, derivativeWithCaching);
metricCacheTest->SetUseCachingOfBSplineWeights(false);
metricCacheTest->GetDerivative(bSplineParameters, derivativeNoCaching);

// Check consistency between derivatives
bool sameDerivative = true;
for (unsigned int d = 0; d < bSplineParameters.Size(); ++d)
{
if (itk::Math::abs(derivativeWithCaching[d] - derivativeNoCaching[d]) > 1e-5)
{
sameDerivative = false;
break;
}
}

// Set moving to original origin
movingImage->SetOrigin(movingImageOrigin);
if (!sameDerivative)
{
std::cout << "\nTesting weights caching derivative calculation for BSpline Transfrom... FAILED" << std::endl;
std::cout << "Computed derivative using weights caching was:\n"
<< derivativeWithCaching << "\n Derivative should be:\n"
<< derivativeNoCaching << "\n"
<< std::endl;

return EXIT_FAILURE;
}
std::cout << "\nTesting weights caching derivative calculation for BSpline Transfrom... PASSED\n" << std::endl;

//-------------------------------------------------------
// exercise misc member functions
//-------------------------------------------------------
Expand Down

0 comments on commit 34ea07c

Please sign in to comment.