Skip to content

Commit

Permalink
add math functions for complex type for oneAPI backend
Browse files Browse the repository at this point in the history
  • Loading branch information
AuroraPerego committed Dec 5, 2024
1 parent b78fd8b commit d01a3d5
Show file tree
Hide file tree
Showing 2 changed files with 316 additions and 325 deletions.
316 changes: 315 additions & 1 deletion include/alpaka/math/Complex.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2022 Sergei Bastrakov
/* Copyright 2024 Sergei Bastrakov, Aurora Perego
* SPDX-License-Identifier: MPL-2.0
*/

Expand Down Expand Up @@ -579,4 +579,318 @@ namespace alpaka
} // namespace internal

using internal::Complex;

#if defined(ALPAKA_ACC_SYCL_ENABLED) || defined(ALPAKA_ACC_GPU_CUDA_ENABLED) || defined(ALPAKA_ACC_GPU_HIP_ENABLED)

namespace math::trait
{

//! The abs trait specialization for complex types.
template<typename TAcc, typename T>
struct Abs<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
return sqrt(ctx, arg.real() * arg.real() + arg.imag() * arg.imag());
}
};

//! The acos trait specialization for complex types.
template<typename TAcc, typename T>
struct Acos<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// This holds everywhere, including the branch cuts: acos(z) = -i * ln(z + i * sqrt(1 - z^2))
return Complex<T>{static_cast<T>(0.0), static_cast<T>(-1.0)}
* log(
ctx,
arg
+ Complex<T>{static_cast<T>(0.0), static_cast<T>(1.0)}
* sqrt(ctx, static_cast<T>(1.0) - arg * arg));
}
};

//! The acosh trait specialization for complex types.
template<typename TAcc, typename T>
struct Acosh<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// acos(z) = ln(z + sqrt(z-1) * sqrt(z+1))
return log(ctx, arg + sqrt(ctx, arg - static_cast<T>(1.0)) * sqrt(ctx, arg + static_cast<T>(1.0)));
}
};

//! The arg Complex<T> specialization for complex types.
template<typename TAcc, typename T>
struct Arg<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& argument)
{
return atan2(ctx, argument.imag(), argument.real());
}
};

//! The asin trait specialization for complex types.
template<typename TAcc, typename T>
struct Asin<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// This holds everywhere, including the branch cuts: asin(z) = i * ln(sqrt(1 - z^2) - i * z)
return Complex<T>{static_cast<T>(0.0), static_cast<T>(1.0)}
* log(
ctx,
sqrt(ctx, static_cast<T>(1.0) - arg * arg)
- Complex<T>{static_cast<T>(0.0), static_cast<T>(1.0)} * arg);
}
};

//! The asinh trait specialization for complex types.
template<typename TAcc, typename T>
struct Asinh<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// asinh(z) = ln(z + sqrt(z^2 + 1))
return log(ctx, arg + sqrt(ctx, arg * arg + static_cast<T>(1.0)));
}
};

//! The atan trait specialization for complex types.
template<typename TAcc, typename T>
struct Atan<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// This holds everywhere, including the branch cuts: atan(z) = -i/2 * ln((i - z) / (i + z))
return Complex<T>{static_cast<T>(0.0), static_cast<T>(-0.5)}
* log(
ctx,
(Complex<T>{static_cast<T>(0.0), static_cast<T>(1.0)} - arg)
/ (Complex<T>{static_cast<T>(0.0), static_cast<T>(1.0)} + arg));
}
};

//! The atanh trait specialization for complex types.
template<typename TAcc, typename T>
struct Atanh<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// atanh(z) = 0.5 * (ln(1 + z) - ln(1 - z))
return static_cast<T>(0.5)
* (log(ctx, static_cast<T>(1.0) + arg) - log(ctx, static_cast<T>(1.0) - arg));
}
};

//! The conj specialization for complex types.
template<typename TAcc, typename T>
struct Conj<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& /* conj_ctx */, Complex<T> const& arg)
{
return Complex<T>{arg.real(), -arg.imag()};
}
};

//! The cos trait specialization for complex types.
template<typename TAcc, typename T>
struct Cos<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// cos(z) = 0.5 * (exp(i * z) + exp(-i * z))
return T(0.5)
* (exp(ctx, Complex<T>{static_cast<T>(0.0), static_cast<T>(1.0)} * arg)
+ exp(ctx, Complex<T>{static_cast<T>(0.0), static_cast<T>(-1.0)} * arg));
}
};

//! The cosh trait specialization for complex types.
template<typename TAcc, typename T>
struct Cosh<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// cosh(z) = 0.5 * (exp(z) + exp(-z))
return T(0.5) * (exp(ctx, arg) + exp(ctx, static_cast<T>(-1.0) * arg));
}
};

//! The exp trait specialization for complex types.
template<typename TAcc, typename T>
struct Exp<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// exp(z) = exp(x + iy) = exp(x) * (cos(y) + i * sin(y))
auto re = T{}, im = T{};
sincos(ctx, arg.imag(), im, re);
return exp(ctx, arg.real()) * Complex<T>{re, im};
}
};

//! The log trait specialization for complex types.
template<typename TAcc, typename T>
struct Log<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& argument)
{
// Branch cut along the negative real axis (same as for std::complex),
// principal value of ln(z) = ln(|z|) + i * arg(z)
return log(ctx, abs(ctx, argument))
+ Complex<T>{static_cast<T>(0.0), static_cast<T>(1.0)} * arg(ctx, argument);
}
};

//! The log2 trait specialization for complex types.
template<typename TAcc, typename T>
struct Log2<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& argument)
{
return log(ctx, argument) / log(ctx, static_cast<T>(2));
}
};

//! The log10 trait specialization for complex types.
template<typename TAcc, typename T>
struct Log10<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& argument)
{
return log(ctx, argument) / log(ctx, static_cast<T>(10));
}
};

//! The pow trait specialization for complex types.
template<typename TAcc, typename T, typename U>
struct Pow<TAcc, Complex<T>, Complex<U>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& base, Complex<U> const& exponent)
{
// Type promotion matching rules of complex std::pow but simplified given our math only supports float
// and double, no long double.
using Promoted
= Complex<std::conditional_t<is_decayed_v<T, float> && is_decayed_v<U, float>, float, double>>;
// pow(z1, z2) = e^(z2 * log(z1))
return exp(ctx, Promoted{exponent} * log(ctx, Promoted{base}));
}
};

//! The pow trait specialization for complex and real types.
template<typename TAcc, typename T, typename U>
struct Pow<TAcc, Complex<T>, U>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& base, U const& exponent)
{
return pow(ctx, base, Complex<U>{exponent});
}
};

//! The pow trait specialization for real and complex types.
template<typename TAcc, typename T, typename U>
struct Pow<TAcc, T, Complex<U>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, T const& base, Complex<U> const& exponent)
{
return pow(ctx, Complex<T>{base}, exponent);
}
};

//! The rsqrt trait specialization for complex types.
template<typename TAcc, typename T>
struct Rsqrt<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
return static_cast<T>(1.0) / sqrt(ctx, arg);
}
};

//! The sin trait specialization for complex types.
template<typename TAcc, typename T>
struct Sin<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// sin(z) = (exp(i * z) - exp(-i * z)) / 2i
return (exp(ctx, Complex<T>{static_cast<T>(0.0), static_cast<T>(1.0)} * arg)
- exp(ctx, Complex<T>{static_cast<T>(0.0), static_cast<T>(-1.0)} * arg))
/ Complex<T>{static_cast<T>(0.0), static_cast<T>(2.0)};
}
};

//! The sinh trait specialization for complex types.
template<typename TAcc, typename T>
struct Sinh<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// sinh(z) = (exp(z) - exp(-i * z)) / 2
return (exp(ctx, arg) - exp(ctx, static_cast<T>(-1.0) * arg)) / static_cast<T>(2.0);
}
};

//! The sincos trait specialization for complex types.
template<typename TAcc, typename T>
struct SinCos<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(
TAcc const& ctx,
Complex<T> const& arg,
Complex<T>& result_sin,
Complex<T>& result_cos) -> void
{
result_sin = sin(ctx, arg);
result_cos = cos(ctx, arg);
}
};

//! The sqrt trait specialization for complex types.
template<typename TAcc, typename T>
struct Sqrt<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& argument)
{
// Branch cut along the negative real axis (same as for std::complex),
// principal value of sqrt(z) = sqrt(|z|) * e^(i * arg(z) / 2)
auto const halfArg = T(0.5) * arg(ctx, argument);
auto re = T{}, im = T{};
sincos(ctx, halfArg, im, re);
return sqrt(ctx, abs(ctx, argument)) * Complex<T>(re, im);
}
};

//! The tan trait specialization for complex types.
template<typename TAcc, typename T>
struct Tan<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// tan(z) = i * (e^-iz - e^iz) / (e^-iz + e^iz) = i * (1 - e^2iz) / (1 + e^2iz)
// Warning: this straightforward implementation can easily result in NaN as 0/0 or inf/inf.
auto const expValue = exp(ctx, Complex<T>{static_cast<T>(0.0), static_cast<T>(2.0)} * arg);
return Complex<T>{static_cast<T>(0.0), static_cast<T>(1.0)} * (static_cast<T>(1.0) - expValue)
/ (static_cast<T>(1.0) + expValue);
}
};

//! The tanh trait specialization for complex types.
template<typename TAcc, typename T>
struct Tanh<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& ctx, Complex<T> const& arg)
{
// tanh(z) = (e^z - e^-z)/(e^z+e^-z)
return (exp(ctx, arg) - exp(ctx, static_cast<T>(-1.0) * arg))
/ (exp(ctx, arg) + exp(ctx, static_cast<T>(-1.0) * arg));
}
};
} // namespace math::trait

#endif

} // namespace alpaka
Loading

0 comments on commit d01a3d5

Please sign in to comment.