Add batched cholesky implementation and tests (#1029)
* add batched cholesky implementation and tests

* missing files

* fix correctness issues in transpose lower implementation

* address PR comments

* remove print statements

* address more PR comments

* test fixes

* remove outdated comment

* Add missing "throws exception" annotation

---------

Co-authored-by: Manolis Papadakis <[email protected]>
jjwilke and manopapad authored Nov 8, 2023
1 parent 8c67416 commit b66e2ec
Showing 17 changed files with 692 additions and 110 deletions.
2 changes: 2 additions & 0 deletions cunumeric/config.py
@@ -32,6 +32,7 @@ class _CunumericSharedLib:
CUNUMERIC_ADVANCED_INDEXING: int
CUNUMERIC_ARANGE: int
CUNUMERIC_ARGWHERE: int
CUNUMERIC_BATCHED_CHOLESKY: int
CUNUMERIC_BINARY_OP: int
CUNUMERIC_BINARY_RED: int
CUNUMERIC_BINCOUNT: int
@@ -333,6 +334,7 @@ class CuNumericOpCode(IntEnum):
ADVANCED_INDEXING = _cunumeric.CUNUMERIC_ADVANCED_INDEXING
ARANGE = _cunumeric.CUNUMERIC_ARANGE
ARGWHERE = _cunumeric.CUNUMERIC_ARGWHERE
BATCHED_CHOLESKY = _cunumeric.CUNUMERIC_BATCHED_CHOLESKY
BINARY_OP = _cunumeric.CUNUMERIC_BINARY_OP
BINARY_RED = _cunumeric.CUNUMERIC_BINARY_RED
BINCOUNT = _cunumeric.CUNUMERIC_BINCOUNT
40 changes: 38 additions & 2 deletions cunumeric/linalg/cholesky.py
@@ -1,4 +1,4 @@
# Copyright 2021-2022 NVIDIA Corporation
# Copyright 2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -202,11 +202,47 @@ def tril(context: Context, p_output: StorePartition, n: int) -> None:
task.execute()


def _batched_cholesky(output: DeferredArray, input: DeferredArray) -> None:
# The only feasible implementation right now requires each Cholesky
# submatrix to fit on a single processor. Available memory will vary
# wildly depending on the system, so just use a fixed cutoff to provide
# a sensible warning.
# TODO: find a better way to inform the user that the dims are too big
context: Context = output.context
task = context.create_auto_task(CuNumericOpCode.BATCHED_CHOLESKY)
task.add_input(input.base)
task.add_output(output.base)
ndim = input.base.ndim
task.add_broadcast(input.base, (ndim - 2, ndim - 1))
task.add_broadcast(output.base, (ndim - 2, ndim - 1))
task.add_alignment(input.base, output.base)
task.throws_exception(LinAlgError)
task.execute()


def cholesky(
output: DeferredArray, input: DeferredArray, no_tril: bool
) -> None:
runtime = output.runtime
context = output.context
context: Context = output.context
if len(input.base.shape) > 2:
if no_tril:
raise NotImplementedError(
"batched cholesky expects to only "
"produce the lower triangular matrix"
)
size = input.base.shape[-1]
# Choose 32768 as the dimension cutoff for the warning, so that
# for float64 any submatrix larger than 8 GiB triggers a warning
if size > 32768:
runtime.warn(
"batched cholesky is only valid"
" when the square submatrices fit"
f" on a single proc, n > {size} may be too large",
category=UserWarning,
)
return _batched_cholesky(output, input)

if runtime.num_procs == 1:
transpose_copy_single(context, input.base, output.base)
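
As a sanity check on the comment above the warning: with float64 entries, a single 32768 x 32768 submatrix occupies exactly 8 GiB, which is where the cutoff comes from. A quick sketch of the arithmetic (illustrative only, not part of the commit):

n = 32768
elem_bytes = 8                    # float64
gib = n * n * elem_bytes / 2**30  # bytes -> GiB
print(gib)                        # 8.0, i.e. an 8 GiB submatrix
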
4 changes: 0 additions & 4 deletions cunumeric/linalg/linalg.py
@@ -82,10 +82,6 @@ def cholesky(a: ndarray) -> ndarray:
elif shape[-1] != shape[-2]:
raise ValueError("Last 2 dimensions of the array must be square")

if len(shape) > 2:
raise NotImplementedError(
"cuNumeric needs to support stacked 2d arrays"
)
return _cholesky(a)


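With the stacked-array NotImplementedError above removed, cunumeric.linalg.cholesky now accepts batched inputs and dispatches them to the new task. A minimal usage sketch, building the inputs with plain NumPy and handing them to cuNumeric; the cn.array conversion and NumPy interop are assumptions about the environment, not code from the commit:

import numpy as np
import cunumeric as cn

# A batch of symmetric positive-definite matrices:
# B @ B.T + n * I is guaranteed to be positive definite.
batch, n = 4, 128
b = np.random.rand(batch, n, n)
a = b @ b.swapaxes(-2, -1) + n * np.eye(n)

# Before this commit a 3-D input raised NotImplementedError; now each
# (n, n) submatrix is factored independently.
l = cn.linalg.cholesky(cn.array(a))
print(l.shape)  # (4, 128, 128)
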
3 changes: 3 additions & 0 deletions cunumeric_cpp.cmake
@@ -143,6 +143,7 @@ list(APPEND cunumeric_SOURCES
src/cunumeric/index/putmask.cc
src/cunumeric/item/read.cc
src/cunumeric/item/write.cc
src/cunumeric/matrix/batched_cholesky.cc
src/cunumeric/matrix/contract.cc
src/cunumeric/matrix/diag.cc
src/cunumeric/matrix/gemm.cc
@@ -195,6 +196,7 @@ if(Legion_USE_OpenMP)
src/cunumeric/index/repeat_omp.cc
src/cunumeric/index/wrap_omp.cc
src/cunumeric/index/zip_omp.cc
src/cunumeric/matrix/batched_cholesky_omp.cc
src/cunumeric/matrix/contract_omp.cc
src/cunumeric/matrix/diag_omp.cc
src/cunumeric/matrix/gemm_omp.cc
@@ -245,6 +247,7 @@ if(Legion_USE_CUDA)
src/cunumeric/index/putmask.cu
src/cunumeric/item/read.cu
src/cunumeric/item/write.cu
src/cunumeric/matrix/batched_cholesky.cu
src/cunumeric/matrix/contract.cu
src/cunumeric/matrix/diag.cu
src/cunumeric/matrix/gemm.cu
1 change: 1 addition & 0 deletions src/cunumeric/cunumeric_c.h
@@ -29,6 +29,7 @@ enum CuNumericOpCode {
CUNUMERIC_ADVANCED_INDEXING,
CUNUMERIC_ARANGE,
CUNUMERIC_ARGWHERE,
CUNUMERIC_BATCHED_CHOLESKY,
CUNUMERIC_BINARY_OP,
CUNUMERIC_BINARY_RED,
CUNUMERIC_BINCOUNT,
19 changes: 19 additions & 0 deletions src/cunumeric/mapper.cc
@@ -145,6 +145,25 @@ std::vector<StoreMapping> CuNumericMapper::store_mappings(
}
return std::move(mappings);
}
// CHANGE: If this code is changed, make sure all layouts are
// consistent with those assumed in batched_cholesky.cu, etc
case CUNUMERIC_BATCHED_CHOLESKY: {
std::vector<StoreMapping> mappings;
auto& inputs = task.inputs();
auto& outputs = task.outputs();
mappings.reserve(inputs.size() + outputs.size());
for (auto& input : inputs) {
mappings.push_back(StoreMapping::default_mapping(input, options.front()));
mappings.back().policy.exact = true;
mappings.back().policy.ordering.set_c_order();
}
for (auto& output : outputs) {
mappings.push_back(StoreMapping::default_mapping(output, options.front()));
mappings.back().policy.exact = true;
mappings.back().policy.ordering.set_c_order();
}
return std::move(mappings);
}
case CUNUMERIC_TRILU: {
if (task.scalars().size() == 2) return {};
// If we're here, this task was the post-processing for Cholesky.
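
The mapping above requests an exact, C-ordered (row-major) instance for every input and output because the kernels in batched_cholesky.cc/.cu address each submatrix with the flat index r * n + c. A small sketch of that layout assumption, using NumPy strides as a stand-in for the store layout (illustrative only, not the Legate mapping API):

import numpy as np

batch, n = 4, 3
a = np.arange(batch * n * n, dtype=np.float64).reshape(batch, n, n)
print(a.strides)  # (72, 24, 8): C order, the last dimension is contiguous

# Within one submatrix, element (r, c) lives at flat offset r * n + c,
# matching how the batched Cholesky kernels index the data.
r, c = 2, 1
flat = a[0].reshape(-1)
assert flat[r * n + c] == a[0, r, c]
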
85 changes: 85 additions & 0 deletions src/cunumeric/matrix/batched_cholesky.cc
@@ -0,0 +1,85 @@
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include "cunumeric/matrix/batched_cholesky.h"
#include "cunumeric/cunumeric.h"
#include "cunumeric/matrix/batched_cholesky_template.inl"

#include <cblas.h>
#include <core/type/type_info.h>
#include <lapack.h>

namespace cunumeric {

using namespace legate;

template <>
void CopyBlockImpl<VariantKind::CPU>::operator()(void* dst, const void* src, size_t size)
{
::memcpy(dst, src, size);
}

template <Type::Code CODE>
struct BatchedTransposeImplBody<VariantKind::CPU, CODE> {
using VAL = legate_type_of<CODE>;

static constexpr int tile_size = 64;

void operator()(VAL* out, int n) const
{
VAL tile[tile_size][tile_size];
int nblocks = (n + tile_size - 1) / tile_size;

for (int rb = 0; rb < nblocks; ++rb) {
for (int cb = 0; cb < nblocks; ++cb) {
int r_start = rb * tile_size;
int r_stop = std::min(r_start + tile_size, n);
int c_start = cb * tile_size;
int c_stop = std::min(c_start + tile_size, n);
for (int r = r_start, tr = 0; r < r_stop; ++r, ++tr) {
for (int c = c_start, tc = 0; c < c_stop; ++c, ++tc) {
if (r <= c) {
tile[tr][tc] = out[r * n + c];
} else {
tile[tr][tc] = 0;
}
}
}
for (int r = c_start, tr = 0; r < c_stop; ++r, ++tr) {
for (int c = r_start, tc = 0; c < r_stop; ++c, ++tc) { out[r * n + c] = tile[tc][tr]; }
}
}
}
}
};

/*static*/ void BatchedCholeskyTask::cpu_variant(TaskContext& context)
{
#ifdef LEGATE_USE_OPENMP
openblas_set_num_threads(1); // make sure this isn't overzealous
#endif
batched_cholesky_task_context_dispatch<VariantKind::CPU>(context);
}

namespace // unnamed
{
static void __attribute__((constructor)) register_tasks(void)
{
BatchedCholeskyTask::register_variants();
}
} // namespace

} // namespace cunumeric
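
The tiled loop above (and the CUDA kernel in the next file) rewrites a factored matrix in place so that the result is the transpose of its upper triangle, with everything above the diagonal set to zero. A NumPy reference for what a single n x n submatrix should look like afterwards, usable as a mental model or a test oracle (a sketch, not code from the commit):

import numpy as np

def transpose_lower_reference(a):
    # Keep the upper triangle (including the diagonal), move it into the
    # lower triangle, and zero everything above the diagonal.
    return np.triu(a).T

a = np.arange(16.0).reshape(4, 4)
out = transpose_lower_reference(a)
assert np.array_equal(out, np.tril(out))         # result is lower triangular
assert np.array_equal(np.diag(out), np.diag(a))  # diagonal is preserved
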
111 changes: 111 additions & 0 deletions src/cunumeric/matrix/batched_cholesky.cu
@@ -0,0 +1,111 @@
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include "cunumeric/matrix/batched_cholesky.h"
#include "cunumeric/matrix/potrf.h"
#include "cunumeric/matrix/batched_cholesky_template.inl"

#include "cunumeric/cuda_help.h"

namespace cunumeric {

using namespace legate;

#define TILE_DIM 32
#define BLOCK_ROWS 8

template <>
void CopyBlockImpl<VariantKind::GPU>::operator()(void* dst, const void* src, size_t size)
{
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, get_cached_stream());
}

template <typename VAL>
__global__ static void __launch_bounds__((TILE_DIM * BLOCK_ROWS), MIN_CTAS_PER_SM)
transpose_2d_lower(VAL* out, int n)
{
__shared__ VAL tile[TILE_DIM][TILE_DIM + 1 /*avoid bank conflicts*/];

// The y dim is the fast-moving index for coalescing
auto r_block = blockIdx.x * TILE_DIM;
auto c_block = blockIdx.y * TILE_DIM;
auto r = blockIdx.x * TILE_DIM + threadIdx.x;
auto c = blockIdx.y * TILE_DIM + threadIdx.y;
auto stride = BLOCK_ROWS;
// The tile coordinates
auto tr = threadIdx.x;
auto tc = threadIdx.y;
auto offset = r * n + c;

// Only blocks on or above the diagonal do any work: each such thread block
// stores its upper-triangular entries into temporary shared memory and then
// zeroes them in the output
if (c_block >= r_block) {
#pragma unroll
for (int i = 0; i < TILE_DIM; i += BLOCK_ROWS, offset += stride) {
if (r < n && (c + i) < n) {
if (r <= (c + i)) {
tile[tr][tc + i] = out[offset];
// clear the upper diagonal entry
out[offset] = 0;
} else {
tile[tr][tc + i] = 0;
}
}
}

// Make sure all the data is in shared memory
__syncthreads();

// Transpose the global coordinates, keep y the fast-moving index
r = blockIdx.y * TILE_DIM + threadIdx.x;
c = blockIdx.x * TILE_DIM + threadIdx.y;
offset = r * n + c;

#pragma unroll
for (int i = 0; i < TILE_DIM; i += BLOCK_ROWS, offset += stride) {
if (r < n && (c + i) < n) {
if (r >= (c + i)) { out[offset] = tile[tc + i][tr]; }
}
}
}
}

template <Type::Code CODE>
struct BatchedTransposeImplBody<VariantKind::GPU, CODE> {
using VAL = legate_type_of<CODE>;

void operator()(VAL* out, int n) const
{
const dim3 blocks((n + TILE_DIM - 1) / TILE_DIM, (n + TILE_DIM - 1) / TILE_DIM, 1);
const dim3 threads(TILE_DIM, BLOCK_ROWS, 1);

auto stream = get_cached_stream();

// CUDA potrf produces the full matrix; we only want
// the lower triangle
transpose_2d_lower<VAL><<<blocks, threads, 0, stream>>>(out, n);

CHECK_CUDA_STREAM(stream);
}
};

/*static*/ void BatchedCholeskyTask::gpu_variant(TaskContext& context)
{
batched_cholesky_task_context_dispatch<VariantKind::GPU>(context);
}

} // namespace cunumeric
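
For reference, the launch geometry implied by TILE_DIM = 32 and BLOCK_ROWS = 8: the grid covers the matrix in 32 x 32 tiles, each block runs 32 x 8 threads (so every thread strides through 4 rows of its tile), and only blocks on or above the block diagonal do any work. A small sketch of that arithmetic with a hypothetical helper (not part of the commit):

def launch_geometry(n, tile_dim=32, block_rows=8):
    nblocks = (n + tile_dim - 1) // tile_dim      # tiles per matrix side
    grid = (nblocks, nblocks)                     # matches dim3 blocks(...)
    threads = (tile_dim, block_rows)              # matches dim3 threads(...)
    rows_per_thread = tile_dim // block_rows      # the unrolled i-loop covers 4 rows
    active_blocks = nblocks * (nblocks + 1) // 2  # blocks with c_block >= r_block
    return grid, threads, rows_per_thread, active_blocks

print(launch_geometry(1000))  # ((32, 32), (32, 8), 4, 528)
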
38 changes: 38 additions & 0 deletions src/cunumeric/matrix/batched_cholesky.h
@@ -0,0 +1,38 @@
/* Copyright 2021-2022 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#pragma once

#include "cunumeric/cunumeric.h"
#include "cunumeric/cunumeric_c.h"

namespace cunumeric {

class BatchedCholeskyTask : public CuNumericTask<BatchedCholeskyTask> {
public:
static const int TASK_ID = CUNUMERIC_BATCHED_CHOLESKY;

public:
static void cpu_variant(legate::TaskContext& context);
#ifdef LEGATE_USE_OPENMP
static void omp_variant(legate::TaskContext& context);
#endif
#ifdef LEGATE_USE_CUDA
static void gpu_variant(legate::TaskContext& context);
#endif
};

} // namespace cunumeric
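
The commit title mentions tests, but the test files themselves are not shown in this capture. A minimal verification in the same spirit, checking that every factor is lower triangular and reconstructs its input, might look like the following sketch (batch size, dimensions, tolerances, and the NumPy interop are illustrative assumptions):

import numpy as np
import cunumeric as cn

batch, n = 8, 64
b = np.random.rand(batch, n, n)
a = b @ b.swapaxes(-2, -1) + n * np.eye(n)   # symmetric positive definite

l = np.asarray(cn.linalg.cholesky(cn.array(a)))

# Each factor should be lower triangular ...
assert np.allclose(l, np.tril(l))
# ... and should reconstruct the corresponding input matrix.
assert np.allclose(l @ l.swapaxes(-2, -1), a, rtol=1e-5, atol=1e-5)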