Skip to content

Commit

Permalink
Updates and test coverage for hipGraph support in rocRAND (#439)
Browse files Browse the repository at this point in the history
This change allows device-side generators to be used inside of hipGraphs.
More specifically, you can wrap calls to rocrand_generate_* inside of a hipGraph. There are a few things to be aware of:
  - Generator creation (rocrand_create_generator), initialization (rocrand_initialize_generator), and destruction (rocrand_destroy_generator) must still happen outside the hipGraph.
  - After the generator is created, you may call API functions to set its seed, offset, and order.
  - After the generator is initialized (but before stream capture or manual graph creation begins), use rocrand_set_stream to set the stream the generator will use within the graph.
  - A generator's seed, offset, and stream may not be changed from within the hipGraph. Attempting to do so may result in unpredicable behaviour.
  - API calls for the poisson distribution (eg. rocrand_generate_poisson) are not yet supported inside of hipGraphs.

I've added a note to the changelog that mentions these details.
In addition to the changes necessary to support the behaviour described above, this change also:
- updates the changelog to alert the user to the restrictions mentioned above
- adds new unit test coverage to exercises generators and distributions within hipGraphs.
  • Loading branch information
umfranzw authored Apr 22, 2024
1 parent 22f00da commit c710cb6
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 0 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ Documentation for rocRAND is available at

### Changes

* For device-side generators, you can now wrap calls to rocrand_generate_* inside of a hipGraph. There are a few
things to be aware of:
- Generator creation (rocrand_create_generator), initialization (rocrand_initialize_generator), and destruction (rocrand_destroy_generator) must still happen outside the hipGraph.
- After the generator is created, you may call API functions to set its seed, offset, and order.
- After the generator is initialized (but before stream capture or manual graph creation begins), use rocrand_set_stream to set the stream the generator will use within the graph.
- A generator's seed, offset, and stream may not be changed from within the hipGraph. Attempting to do so may result in unpredicable behaviour.
- API calls for the poisson distribution (eg. rocrand_generate_poisson) are not yet supported inside of hipGraphs.
- For sample usage, see the unit tests in test/test_rocrand_hipgraphs.cpp
* Building rocRAND now requires a C++17 capable compiler, as the internal library sources now require it. However consuming rocRAND is still possible from C++11 as public headers don't make use of the new features.
* Building rocRAND should be faster on machines with multiple CPU cores as the library has been
split to multiple compilation units.
Expand Down
145 changes: 145 additions & 0 deletions test/test_rocrand_hipgraphs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#include <stdio.h>
#include <gtest/gtest.h>

#include <hip/hip_runtime.h>
#include <rocrand/rocrand.h>

#include "test_common.hpp"
#include "test_rocrand_common.hpp"
#include "test_utils_hipgraphs.hpp"

class rocrand_hipgraph_generate_tests : public ::testing::TestWithParam<rocrand_rng_type> {};

void test_float(std::function<rocrand_status(rocrand_generator, float*, size_t, float, float)> generate_fn, rocrand_rng_type rng_type)
{
rocrand_generator generator;
ROCRAND_CHECK(
rocrand_create_generator(
&generator,
rng_type
)
);

ROCRAND_CHECK(rocrand_initialize_generator(generator));

const size_t size = 12563;
float mean = 5.0f;
float stddev = 2.0f;
float * data;
HIP_CHECK(hipMallocHelper(&data, size * sizeof(float)));
HIP_CHECK(hipDeviceSynchronize());

// Default stream does not support hipGraph stream capture, so create a non-blocking one
hipStream_t stream = 0;
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
rocrand_set_stream(generator, stream);

hipGraphExec_t graph_instance;
hipGraph_t graph = test_utils::createGraphHelper(stream);

// Any sizes
ROCRAND_CHECK(
generate_fn(generator, data, 1, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);

// Any alignment
ROCRAND_CHECK(
generate_fn(generator, data+1, 2, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);

ROCRAND_CHECK(
generate_fn(generator, data, size, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
test_utils::cleanupGraphHelper(graph, graph_instance);
HIP_CHECK(hipStreamDestroy(stream));
}

TEST_P(rocrand_hipgraph_generate_tests, normal_float_test)
{
auto generator_fcn = [](rocrand_generator generator, float* output_data, size_t n, float mean, float stddev)
{
return rocrand_generate_normal(generator, output_data, n, mean, stddev);
};

test_float(generator_fcn, GetParam());
}

TEST_P(rocrand_hipgraph_generate_tests, log_normal_float_test)
{
auto generator_fcn = [](rocrand_generator generator, float* output_data, size_t n, float mean, float stddev)
{
return rocrand_generate_log_normal(generator, output_data, n, mean, stddev);
};

test_float(generator_fcn, GetParam());
}

TEST_P(rocrand_hipgraph_generate_tests, uniform_float_test)
{
const rocrand_rng_type rng_type = GetParam();

rocrand_generator generator;
ROCRAND_CHECK(
rocrand_create_generator(
&generator,
rng_type
)
);

ROCRAND_CHECK(rocrand_initialize_generator(generator));

const size_t size = 12563;
float * data;
HIP_CHECK(hipMallocHelper(&data, size * sizeof(float)));
HIP_CHECK(hipDeviceSynchronize());

// Default stream does not support hipGraph stream capture, so create a non-blocking one
hipStream_t stream = 0;
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
rocrand_set_stream(generator, stream);

hipGraphExec_t graph_instance;
hipGraph_t graph = test_utils::createGraphHelper(stream);

// Any sizes
ROCRAND_CHECK(
rocrand_generate_uniform(generator, data, 1)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);

// Any alignment
ROCRAND_CHECK(
rocrand_generate_uniform(generator, data+1, 2)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);

ROCRAND_CHECK(
rocrand_generate_uniform(generator, data, size)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
test_utils::cleanupGraphHelper(graph, graph_instance);
HIP_CHECK(hipStreamDestroy(stream));
}

INSTANTIATE_TEST_SUITE_P(rocrand_hipgraph_generate_tests,
rocrand_hipgraph_generate_tests,
::testing::ValuesIn(rng_types));
91 changes: 91 additions & 0 deletions test/test_utils_hipgraphs.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#ifndef ROCRAND_TEST_UTILS_HIPGRAPHS_HPP
#define ROCRAND_TEST_UTILS_HIPGRAPHS_HPP

#include <hip/hip_runtime.h>
#include "test_common.hpp"

// Helper functions for testing with hipGraph stream capture.
// Note: graphs will not work on the default stream.
namespace test_utils
{

inline hipGraph_t createGraphHelper(hipStream_t& stream, const bool beginCapture=true)
{
// Create a new graph
hipGraph_t graph;
HIP_CHECK_NON_VOID(hipGraphCreate(&graph, 0));

// Optionally begin stream capture
if (beginCapture)
HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));

return graph;
}

inline void cleanupGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance)
{
HIP_CHECK_NON_VOID(hipGraphDestroy(graph));
HIP_CHECK_NON_VOID(hipGraphExecDestroy(instance));
}

inline void resetGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance, hipStream_t& stream, const bool beginCapture=true)
{
// Destroy the old graph and instance
cleanupGraphHelper(graph, instance);

// Create a new graph and optionally begin capture
graph = createGraphHelper(stream, beginCapture);
}

inline hipGraphExec_t endCaptureGraphHelper(hipGraph_t& graph, hipStream_t& stream, const bool launchGraph=false, const bool sync=false)
{
// End the capture
HIP_CHECK_NON_VOID(hipStreamEndCapture(stream, &graph));

// Instantiate the graph
hipGraphExec_t instance;
HIP_CHECK_NON_VOID(hipGraphInstantiate(&instance, graph, nullptr, nullptr, 0));

// Optionally launch the graph
if (launchGraph)
HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream));

// Optionally synchronize the stream when we're done
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));

return instance;
}

inline void launchGraphHelper(hipGraphExec_t& instance, hipStream_t& stream, const bool sync=false)
{
HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream));

// Optionally sync after the launch
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
}

} // end namespace test_utils

#endif //ROCRAND_TEST_UTILS_HIPGRAPHS_HPP

0 comments on commit c710cb6

Please sign in to comment.