diff --git a/CHANGELOG.md b/CHANGELOG.md index fc8dd28e..ce34c619 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,14 @@ Documentation for rocRAND is available at ### Changes +* For device-side generators, you can now wrap calls to rocrand_generate_* inside of a hipGraph. There are a few + things to be aware of: + - Generator creation (rocrand_create_generator), initialization (rocrand_initialize_generator), and destruction (rocrand_destroy_generator) must still happen outside the hipGraph. + - After the generator is created, you may call API functions to set its seed, offset, and order. + - After the generator is initialized (but before stream capture or manual graph creation begins), use rocrand_set_stream to set the stream the generator will use within the graph. + - A generator's seed, offset, and stream may not be changed from within the hipGraph. Attempting to do so may result in unpredicable behaviour. + - API calls for the poisson distribution (eg. rocrand_generate_poisson) are not yet supported inside of hipGraphs. + - For sample usage, see the unit tests in test/test_rocrand_hipgraphs.cpp * Building rocRAND now requires a C++17 capable compiler, as the internal library sources now require it. However consuming rocRAND is still possible from C++11 as public headers don't make use of the new features. * Building rocRAND should be faster on machines with multiple CPU cores as the library has been split to multiple compilation units. diff --git a/test/test_rocrand_hipgraphs.cpp b/test/test_rocrand_hipgraphs.cpp new file mode 100644 index 00000000..d6a67144 --- /dev/null +++ b/test/test_rocrand_hipgraphs.cpp @@ -0,0 +1,145 @@ +#include +#include + +#include +#include + +#include "test_common.hpp" +#include "test_rocrand_common.hpp" +#include "test_utils_hipgraphs.hpp" + +class rocrand_hipgraph_generate_tests : public ::testing::TestWithParam {}; + +void test_float(std::function generate_fn, rocrand_rng_type rng_type) +{ + rocrand_generator generator; + ROCRAND_CHECK( + rocrand_create_generator( + &generator, + rng_type + ) + ); + + ROCRAND_CHECK(rocrand_initialize_generator(generator)); + + const size_t size = 12563; + float mean = 5.0f; + float stddev = 2.0f; + float * data; + HIP_CHECK(hipMallocHelper(&data, size * sizeof(float))); + HIP_CHECK(hipDeviceSynchronize()); + + // Default stream does not support hipGraph stream capture, so create a non-blocking one + hipStream_t stream = 0; + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + rocrand_set_stream(generator, stream); + + hipGraphExec_t graph_instance; + hipGraph_t graph = test_utils::createGraphHelper(stream); + + // Any sizes + ROCRAND_CHECK( + generate_fn(generator, data, 1, mean, stddev) + ); + + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + test_utils::resetGraphHelper(graph, graph_instance, stream); + + // Any alignment + ROCRAND_CHECK( + generate_fn(generator, data+1, 2, mean, stddev) + ); + + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + test_utils::resetGraphHelper(graph, graph_instance, stream); + + ROCRAND_CHECK( + generate_fn(generator, data, size, mean, stddev) + ); + + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + + HIP_CHECK(hipFree(data)); + ROCRAND_CHECK(rocrand_destroy_generator(generator)); + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); +} + +TEST_P(rocrand_hipgraph_generate_tests, normal_float_test) +{ + auto generator_fcn = [](rocrand_generator generator, float* output_data, size_t n, float mean, float stddev) + { + return rocrand_generate_normal(generator, output_data, n, mean, stddev); + }; + + test_float(generator_fcn, GetParam()); +} + +TEST_P(rocrand_hipgraph_generate_tests, log_normal_float_test) +{ + auto generator_fcn = [](rocrand_generator generator, float* output_data, size_t n, float mean, float stddev) + { + return rocrand_generate_log_normal(generator, output_data, n, mean, stddev); + }; + + test_float(generator_fcn, GetParam()); +} + +TEST_P(rocrand_hipgraph_generate_tests, uniform_float_test) +{ + const rocrand_rng_type rng_type = GetParam(); + + rocrand_generator generator; + ROCRAND_CHECK( + rocrand_create_generator( + &generator, + rng_type + ) + ); + + ROCRAND_CHECK(rocrand_initialize_generator(generator)); + + const size_t size = 12563; + float * data; + HIP_CHECK(hipMallocHelper(&data, size * sizeof(float))); + HIP_CHECK(hipDeviceSynchronize()); + + // Default stream does not support hipGraph stream capture, so create a non-blocking one + hipStream_t stream = 0; + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + rocrand_set_stream(generator, stream); + + hipGraphExec_t graph_instance; + hipGraph_t graph = test_utils::createGraphHelper(stream); + + // Any sizes + ROCRAND_CHECK( + rocrand_generate_uniform(generator, data, 1) + ); + + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + test_utils::resetGraphHelper(graph, graph_instance, stream); + + // Any alignment + ROCRAND_CHECK( + rocrand_generate_uniform(generator, data+1, 2) + ); + + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + test_utils::resetGraphHelper(graph, graph_instance, stream); + + ROCRAND_CHECK( + rocrand_generate_uniform(generator, data, size) + ); + + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + + HIP_CHECK(hipFree(data)); + ROCRAND_CHECK(rocrand_destroy_generator(generator)); + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); +} + +INSTANTIATE_TEST_SUITE_P(rocrand_hipgraph_generate_tests, + rocrand_hipgraph_generate_tests, + ::testing::ValuesIn(rng_types)); diff --git a/test/test_utils_hipgraphs.hpp b/test/test_utils_hipgraphs.hpp new file mode 100644 index 00000000..37dc61a4 --- /dev/null +++ b/test/test_utils_hipgraphs.hpp @@ -0,0 +1,91 @@ +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCRAND_TEST_UTILS_HIPGRAPHS_HPP +#define ROCRAND_TEST_UTILS_HIPGRAPHS_HPP + +#include +#include "test_common.hpp" + +// Helper functions for testing with hipGraph stream capture. +// Note: graphs will not work on the default stream. +namespace test_utils +{ + + inline hipGraph_t createGraphHelper(hipStream_t& stream, const bool beginCapture=true) + { + // Create a new graph + hipGraph_t graph; + HIP_CHECK_NON_VOID(hipGraphCreate(&graph, 0)); + + // Optionally begin stream capture + if (beginCapture) + HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); + + return graph; + } + + inline void cleanupGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance) + { + HIP_CHECK_NON_VOID(hipGraphDestroy(graph)); + HIP_CHECK_NON_VOID(hipGraphExecDestroy(instance)); + } + + inline void resetGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance, hipStream_t& stream, const bool beginCapture=true) + { + // Destroy the old graph and instance + cleanupGraphHelper(graph, instance); + + // Create a new graph and optionally begin capture + graph = createGraphHelper(stream, beginCapture); + } + + inline hipGraphExec_t endCaptureGraphHelper(hipGraph_t& graph, hipStream_t& stream, const bool launchGraph=false, const bool sync=false) + { + // End the capture + HIP_CHECK_NON_VOID(hipStreamEndCapture(stream, &graph)); + + // Instantiate the graph + hipGraphExec_t instance; + HIP_CHECK_NON_VOID(hipGraphInstantiate(&instance, graph, nullptr, nullptr, 0)); + + // Optionally launch the graph + if (launchGraph) + HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream)); + + // Optionally synchronize the stream when we're done + if (sync) + HIP_CHECK_NON_VOID(hipStreamSynchronize(stream)); + + return instance; + } + + inline void launchGraphHelper(hipGraphExec_t& instance, hipStream_t& stream, const bool sync=false) + { + HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream)); + + // Optionally sync after the launch + if (sync) + HIP_CHECK_NON_VOID(hipStreamSynchronize(stream)); + } + +} // end namespace test_utils + +#endif //ROCRAND_TEST_UTILS_HIPGRAPHS_HPP