Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed memory leaks in rocrand_tests #557

Merged
2 changes: 2 additions & 0 deletions test/internal/test_rocrand_config_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ TEST(rocrand_config_dispatch_tests, host_matches_device)

ASSERT_NE(host_arch, rocrand_impl::host::target_arch::invalid);
ASSERT_EQ(host_arch, device_arch);

HIP_CHECK(hipFree(device_arch_ptr));
}

TEST(rocrand_config_dispatch_tests, parse_common_architectures)
Expand Down
6 changes: 4 additions & 2 deletions test/test_rocrand_cpp_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ TYPED_TEST(rocrand_cpp_basic_tests, move_construction)

float actual;
HIP_CHECK(hipMemcpy(&actual, d_data, sizeof(actual), hipMemcpyDeviceToHost));

ASSERT_EQ(expected, actual);

HIP_CHECK(hipFree(d_data));
}

TYPED_TEST(rocrand_cpp_basic_tests, move_assignment)
Expand Down Expand Up @@ -119,6 +120,7 @@ TYPED_TEST(rocrand_cpp_basic_tests, move_assignment)

float actual;
HIP_CHECK(hipMemcpy(&actual, d_data, sizeof(actual), hipMemcpyDeviceToHost));

ASSERT_EQ(expected, actual);

HIP_CHECK(hipFree(d_data));
}
50 changes: 25 additions & 25 deletions test/test_rocrand_hipgraphs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,34 @@ void test_float(std::function<rocrand_status(rocrand_generator, float*, size_t,
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
rocrand_set_stream(generator, stream);

hipGraphExec_t graph_instance;
hipGraph_t graph = test_utils::createGraphHelper(stream);
test_utils::GraphHelper gHelper;
NguyenNhuDi marked this conversation as resolved.
Show resolved Hide resolved

gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(
generate_fn(generator, data, 1, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(
generate_fn(generator, data+1, 2, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(
generate_fn(generator, data, size, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
gHelper.createAndLaunchGraph(stream);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
test_utils::cleanupGraphHelper(graph, graph_instance);
gHelper.cleanupGraphHelper();
HIP_CHECK(hipStreamDestroy(stream));
}

Expand Down Expand Up @@ -109,34 +109,34 @@ TEST_P(rocrand_hipgraph_generate_tests, uniform_float_test)
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
rocrand_set_stream(generator, stream);

hipGraphExec_t graph_instance;
hipGraph_t graph = test_utils::createGraphHelper(stream);
test_utils::GraphHelper gHelper;
gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(
rocrand_generate_uniform(generator, data, 1)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(
rocrand_generate_uniform(generator, data+1, 2)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(
rocrand_generate_uniform(generator, data, size)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
gHelper.createAndLaunchGraph(stream);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
test_utils::cleanupGraphHelper(graph, graph_instance);
gHelper.cleanupGraphHelper();
HIP_CHECK(hipStreamDestroy(stream));
}

Expand All @@ -159,28 +159,28 @@ TEST_P(rocrand_hipgraph_generate_tests, poisson_test)
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
rocrand_set_stream(generator, stream);

hipGraphExec_t graph_instance;
hipGraph_t graph = test_utils::createGraphHelper(stream);
test_utils::GraphHelper gHelper;
gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(rocrand_generate_poisson(generator, data, 1, 10.0));

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(rocrand_generate_poisson(generator, data + 1, 2, 500.0));

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(rocrand_generate_poisson(generator, data, size, 5000.0));

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
gHelper.createAndLaunchGraph(stream);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
test_utils::cleanupGraphHelper(graph, graph_instance);
gHelper.cleanupGraphHelper();
HIP_CHECK(hipStreamDestroy(stream));
}

Expand Down
112 changes: 55 additions & 57 deletions test/test_utils_hipgraphs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,64 +28,62 @@
// Note: graphs will not work on the default stream.
namespace test_utils
{
class GraphHelper{
private:
hipGraph_t graph;
hipGraphExec_t graph_instance;
public:

inline void startStreamCapture(hipStream_t & stream)
{
HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
}

inline void endStreamCapture(hipStream_t & stream)
{
HIP_CHECK_NON_VOID(hipStreamEndCapture(stream, &graph));
}

inline void createAndLaunchGraph(hipStream_t & stream, const bool launchGraph=true, const bool sync=true)
{

endStreamCapture(stream);

HIP_CHECK_NON_VOID(hipGraphInstantiate(&graph_instance, graph, nullptr, nullptr, 0));

// Optionally launch the graph
if (launchGraph)
HIP_CHECK_NON_VOID(hipGraphLaunch(graph_instance, stream));

// Optionally synchronize the stream when we're done
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
}

inline hipGraph_t createGraphHelper(hipStream_t& stream, const bool beginCapture=true)
{
// Create a new graph
hipGraph_t graph;
HIP_CHECK_NON_VOID(hipGraphCreate(&graph, 0));

// Optionally begin stream capture
if (beginCapture)
HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));

return graph;
}

inline void cleanupGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance)
{
HIP_CHECK_NON_VOID(hipGraphDestroy(graph));
HIP_CHECK_NON_VOID(hipGraphExecDestroy(instance));
}

inline void resetGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance, hipStream_t& stream, const bool beginCapture=true)
{
// Destroy the old graph and instance
cleanupGraphHelper(graph, instance);

// Create a new graph and optionally begin capture
graph = createGraphHelper(stream, beginCapture);
}

inline hipGraphExec_t endCaptureGraphHelper(hipGraph_t& graph, hipStream_t& stream, const bool launchGraph=false, const bool sync=false)
{
// End the capture
HIP_CHECK_NON_VOID(hipStreamEndCapture(stream, &graph));

// Instantiate the graph
hipGraphExec_t instance;
HIP_CHECK_NON_VOID(hipGraphInstantiate(&instance, graph, nullptr, nullptr, 0));

// Optionally launch the graph
if (launchGraph)
HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream));

// Optionally synchronize the stream when we're done
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));

return instance;
}

inline void launchGraphHelper(hipGraphExec_t& instance, hipStream_t& stream, const bool sync=false)
{
HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream));

// Optionally sync after the launch
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
}

inline void cleanupGraphHelper()
{
HIP_CHECK_NON_VOID(hipGraphDestroy(this->graph));
HIP_CHECK_NON_VOID(hipGraphExecDestroy(this->graph_instance));
}

inline void resetGraphHelper(hipStream_t& stream, const bool beginCapture=true)
{
// Destroy the old graph and instance
cleanupGraphHelper();

if(beginCapture)
startStreamCapture(stream);
}

inline void launchGraphHelper(hipStream_t& stream,const bool sync=false)
{
HIP_CHECK_NON_VOID(hipGraphLaunch(this->graph_instance, stream));

// Optionally sync after the launch
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
}
};
} // end namespace test_utils

#endif //ROCRAND_TEST_UTILS_HIPGRAPHS_HPP
Loading