diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index 408ba1351f6b1..13fa608a5d890 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -757,8 +757,9 @@ void exec_graph_impl::createCommandBuffers( exec_graph_impl::exec_graph_impl(sycl::context Context, const std::shared_ptr &GraphImpl, const property_list &PropList) - : MSchedule(), MGraphImpl(GraphImpl), MPiSyncPoints(), MContext(Context), - MRequirements(), MExecutionEvents(), + : MSchedule(), MGraphImpl(GraphImpl), MPiSyncPoints(), + MDevice(GraphImpl->getDevice()), MContext(Context), MRequirements(), + MExecutionEvents(), MIsUpdatable(PropList.has_property()) { // If the graph has been marked as updatable then check if the backend @@ -1155,9 +1156,48 @@ void exec_graph_impl::duplicateNodes() { MNodeStorage.insert(MNodeStorage.begin(), NewNodes.begin(), NewNodes.end()); } +void exec_graph_impl::update(std::shared_ptr GraphImpl) { + + if (MDevice != GraphImpl->getDevice()) { + throw sycl::exception( + sycl::make_error_code(errc::invalid), + "The graphs must have been created with matching devices."); + } + if (MContext != GraphImpl->getContext()) { + throw sycl::exception( + sycl::make_error_code(errc::invalid), + "The graphs must have been created with matching contexts."); + } + + if (MNodeStorage.size() != GraphImpl->MNodeStorage.size()) { + throw sycl::exception(sycl::make_error_code(errc::invalid), + "Mismatch found in the number of nodes. The graphs " + "must have a matching topology."); + } else { + for (uint32_t i = 0; i < MNodeStorage.size(); ++i) { + if (MNodeStorage[i]->MSuccessors.size() != + GraphImpl->MNodeStorage[i]->MSuccessors.size() || + MNodeStorage[i]->MPredecessors.size() != + GraphImpl->MNodeStorage[i]->MPredecessors.size()) { + throw sycl::exception(sycl::make_error_code(errc::invalid), + "Mismatch found in the number of edges. The " + "graphs must have a matching topology."); + } + } + } + + for (uint32_t i = 0; i < MNodeStorage.size(); ++i) { + MIDCache.insert( + std::make_pair(GraphImpl->MNodeStorage[i]->MID, MNodeStorage[i])); + } + + update(GraphImpl->MNodeStorage); +} + void exec_graph_impl::update(std::shared_ptr Node) { this->update(std::vector>{Node}); } + void exec_graph_impl::update( const std::vector> Nodes) { @@ -1598,9 +1638,7 @@ void executable_command_graph::finalizeImpl() { void executable_command_graph::update( const command_graph &Graph) { - (void)Graph; - throw sycl::exception(sycl::make_error_code(errc::invalid), - "Method not yet implemented"); + impl->update(sycl::detail::getSyclObjImpl(Graph)); } void executable_command_graph::update(const node &Node) { diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index ae6fedbfd12a0..8cc1e75204697 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -1281,6 +1281,10 @@ class exec_graph_impl { void createCommandBuffers(sycl::device Device, std::shared_ptr &Partition); + /// Query for the device tied to this graph. + /// @return Device associated with graph. + sycl::device getDevice() const { return MDevice; } + /// Query for the context tied to this graph. /// @return Context associated with graph. sycl::context getContext() const { return MContext; } @@ -1320,6 +1324,7 @@ class exec_graph_impl { return MRequirements; } + void update(std::shared_ptr GraphImpl); void update(std::shared_ptr Node); void update(const std::vector> Nodes); @@ -1408,6 +1413,8 @@ class exec_graph_impl { /// Map of nodes in the exec graph to the partition number to which they /// belong. std::unordered_map, int> MPartitionNodes; + /// Device associated with this executable graph. + sycl::device MDevice; /// Context associated with this executable graph. sycl::context MContext; /// List of requirements for enqueueing this command graph, accumulated from diff --git a/sycl/test-e2e/Graph/Explicit/double_buffer.cpp b/sycl/test-e2e/Graph/Explicit/double_buffer.cpp index 578234b0e0bcf..a9db218bc9271 100644 --- a/sycl/test-e2e/Graph/Explicit/double_buffer.cpp +++ b/sycl/test-e2e/Graph/Explicit/double_buffer.cpp @@ -11,4 +11,4 @@ #define GRAPH_E2E_EXPLICIT -#include "../Inputs/double_buffer.cpp" +#include "../Update/whole_update_double_buffer.cpp" diff --git a/sycl/test-e2e/Graph/Explicit/executable_graph_update.cpp b/sycl/test-e2e/Graph/Explicit/executable_graph_update.cpp index a3cf942d72e22..ca4d30781c5de 100644 --- a/sycl/test-e2e/Graph/Explicit/executable_graph_update.cpp +++ b/sycl/test-e2e/Graph/Explicit/executable_graph_update.cpp @@ -11,4 +11,4 @@ #define GRAPH_E2E_EXPLICIT -#include "../Inputs/executable_graph_update.cpp" +#include "../Update/whole_update_usm.cpp" diff --git a/sycl/test-e2e/Graph/Explicit/executable_graph_update_ordering.cpp b/sycl/test-e2e/Graph/Explicit/executable_graph_update_ordering.cpp index c6ca7cd801ac8..0668d94c35077 100644 --- a/sycl/test-e2e/Graph/Explicit/executable_graph_update_ordering.cpp +++ b/sycl/test-e2e/Graph/Explicit/executable_graph_update_ordering.cpp @@ -13,4 +13,4 @@ #define GRAPH_E2E_EXPLICIT -#include "../Inputs/executable_graph_update_ordering.cpp" +#include "../Update/whole_update_delay.cpp" diff --git a/sycl/test-e2e/Graph/Inputs/double_buffer.cpp b/sycl/test-e2e/Graph/Inputs/double_buffer.cpp deleted file mode 100644 index ac340ecd08091..0000000000000 --- a/sycl/test-e2e/Graph/Inputs/double_buffer.cpp +++ /dev/null @@ -1,104 +0,0 @@ -// Tests executable graph update by creating a double buffering scenario, where -// a single graph is repeatedly executed then updated to swap between two sets -// of buffers. - -#include "../graph_common.hpp" - -int main() { - queue Queue{}; - - using T = int; - - std::vector DataA(Size), DataB(Size), DataC(Size); - std::vector DataA2(Size), DataB2(Size), DataC2(Size); - - std::iota(DataA.begin(), DataA.end(), 1); - std::iota(DataB.begin(), DataB.end(), 10); - std::iota(DataC.begin(), DataC.end(), 1000); - - std::iota(DataA2.begin(), DataA2.end(), 3); - std::iota(DataB2.begin(), DataB2.end(), 13); - std::iota(DataC2.begin(), DataC2.end(), 1333); - - std::vector ReferenceA(DataA), ReferenceB(DataB), ReferenceC(DataC); - std::vector ReferenceA2(DataA2), ReferenceB2(DataB2), ReferenceC2(DataC2); - - calculate_reference_data(Iterations, Size, ReferenceA, ReferenceB, - ReferenceC); - calculate_reference_data(Iterations, Size, ReferenceA2, ReferenceB2, - ReferenceC2); - - exp_ext::command_graph Graph{Queue.get_context(), Queue.get_device()}; - - T *PtrA = malloc_device(Size, Queue); - T *PtrB = malloc_device(Size, Queue); - T *PtrC = malloc_device(Size, Queue); - - T *PtrA2 = malloc_device(Size, Queue); - T *PtrB2 = malloc_device(Size, Queue); - T *PtrC2 = malloc_device(Size, Queue); - - Queue.copy(DataA.data(), PtrA, Size); - Queue.copy(DataB.data(), PtrB, Size); - Queue.copy(DataC.data(), PtrC, Size); - - Queue.copy(DataA2.data(), PtrA, Size); - Queue.copy(DataB2.data(), PtrB, Size); - Queue.copy(DataC2.data(), PtrC, Size); - Queue.wait_and_throw(); - - add_nodes(Graph, Queue, Size, PtrA, PtrB, PtrC); - - auto ExecGraph = Graph.finalize(); - - // Create second graph using other buffer set - exp_ext::command_graph GraphUpdate{Queue.get_context(), Queue.get_device()}; - add_nodes(GraphUpdate, Queue, Size, PtrA, PtrB, PtrC); - - event Event; - for (size_t i = 0; i < Iterations; i++) { - Event = Queue.submit([&](handler &CGH) { - CGH.depends_on(Event); - CGH.ext_oneapi_graph(ExecGraph); - }); - // Update to second set of buffers - ExecGraph.update(GraphUpdate); - Event = Queue.submit([&](handler &CGH) { - CGH.depends_on(Event); - CGH.ext_oneapi_graph(ExecGraph); - }); - // Reset back to original buffers - ExecGraph.update(Graph); - } - - Queue.wait_and_throw(); - - Queue.copy(PtrA, DataA.data(), Size); - Queue.copy(PtrB, DataB.data(), Size); - Queue.copy(PtrC, DataC.data(), Size); - - Queue.copy(PtrA2, DataA2.data(), Size); - Queue.copy(PtrB2, DataB2.data(), Size); - Queue.copy(PtrC2, DataC2.data(), Size); - Queue.wait_and_throw(); - - free(PtrA, Queue); - free(PtrB, Queue); - free(PtrC, Queue); - - free(PtrA2, Queue); - free(PtrB2, Queue); - free(PtrC2, Queue); - - for (size_t i = 0; i < Size; i++) { - assert(check_value(i, ReferenceA[i], DataA[i], "DataA")); - assert(check_value(i, ReferenceB[i], DataB[i], "DataB")); - assert(check_value(i, ReferenceC[i], DataC[i], "DataC")); - - assert(check_value(i, ReferenceA2[i], DataA2[i], "DataA2")); - assert(check_value(i, ReferenceB2[i], DataB2[i], "DataB2")); - assert(check_value(i, ReferenceC2[i], DataC2[i], "DataC2")); - } - - return 0; -} diff --git a/sycl/test-e2e/Graph/RecordReplay/double_buffer.cpp b/sycl/test-e2e/Graph/RecordReplay/double_buffer.cpp index ec66cf58f9f93..76700d8815603 100644 --- a/sycl/test-e2e/Graph/RecordReplay/double_buffer.cpp +++ b/sycl/test-e2e/Graph/RecordReplay/double_buffer.cpp @@ -11,4 +11,4 @@ #define GRAPH_E2E_RECORD_REPLAY -#include "../Inputs/double_buffer.cpp" +#include "../Update/whole_update_double_buffer.cpp" diff --git a/sycl/test-e2e/Graph/RecordReplay/executable_graph_update.cpp b/sycl/test-e2e/Graph/RecordReplay/executable_graph_update.cpp index bbe69300bb08c..e543528555ac4 100644 --- a/sycl/test-e2e/Graph/RecordReplay/executable_graph_update.cpp +++ b/sycl/test-e2e/Graph/RecordReplay/executable_graph_update.cpp @@ -11,4 +11,4 @@ #define GRAPH_E2E_RECORD_REPLAY -#include "../Inputs/executable_graph_update.cpp" +#include "../Update/whole_update_usm.cpp" diff --git a/sycl/test-e2e/Graph/RecordReplay/executable_graph_update_ordering.cpp b/sycl/test-e2e/Graph/RecordReplay/executable_graph_update_ordering.cpp index 71d8d7133780e..aaf5841587b5a 100644 --- a/sycl/test-e2e/Graph/RecordReplay/executable_graph_update_ordering.cpp +++ b/sycl/test-e2e/Graph/RecordReplay/executable_graph_update_ordering.cpp @@ -13,4 +13,4 @@ #define GRAPH_E2E_RECORD_REPLAY -#include "../Inputs/executable_graph_update_ordering.cpp" +#include "../Update/whole_update_delay.cpp" diff --git a/sycl/test-e2e/Graph/Inputs/executable_graph_update_ordering.cpp b/sycl/test-e2e/Graph/Update/whole_update_delay.cpp similarity index 97% rename from sycl/test-e2e/Graph/Inputs/executable_graph_update_ordering.cpp rename to sycl/test-e2e/Graph/Update/whole_update_delay.cpp index c38cf6d4f5e3c..25a3b6186f5fe 100644 --- a/sycl/test-e2e/Graph/Inputs/executable_graph_update_ordering.cpp +++ b/sycl/test-e2e/Graph/Update/whole_update_delay.cpp @@ -1,6 +1,7 @@ // Tests executable graph update by introducing a delay in to the update // transactions dependencies to check correctness of behaviour. - +// TODO This test is disabled because host-tasks are not supported for graph +// updates yet. #include "../graph_common.hpp" #include diff --git a/sycl/test-e2e/Graph/Update/whole_update_double_buffer.cpp b/sycl/test-e2e/Graph/Update/whole_update_double_buffer.cpp new file mode 100644 index 0000000000000..d0f075d2225db --- /dev/null +++ b/sycl/test-e2e/Graph/Update/whole_update_double_buffer.cpp @@ -0,0 +1,107 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out +// Extra run to check for leaks in Level Zero using UR_L0_LEAKS_DEBUG +// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=0 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %} +// Extra run to check for immediate-command-list in Level Zero +// RUN: %if level_zero && linux %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %} +// +// UNSUPPORTED: opencl, level_zero + +// Tests executable graph update by creating a double buffering scenario, where +// a single graph is repeatedly executed then updated to swap between two sets +// of buffers. +#define GRAPH_E2E_EXPLICIT + +#include "../graph_common.hpp" + +int main() { + queue Queue{}; + + using T = int; + + std::vector DataA(Size), DataB(Size), DataC(Size); + std::vector DataA2(Size), DataB2(Size), DataC2(Size); + + std::iota(DataA.begin(), DataA.end(), 1); + std::iota(DataB.begin(), DataB.end(), 10); + std::iota(DataC.begin(), DataC.end(), 1000); + + std::iota(DataA2.begin(), DataA2.end(), 3); + std::iota(DataB2.begin(), DataB2.end(), 13); + std::iota(DataC2.begin(), DataC2.end(), 1333); + + std::vector ReferenceA(DataA), ReferenceB(DataB), ReferenceC(DataC); + std::vector ReferenceA2(DataA2), ReferenceB2(DataB2), ReferenceC2(DataC2); + + calculate_reference_data(Iterations, Size, ReferenceA, ReferenceB, + ReferenceC); + calculate_reference_data(Iterations, Size, ReferenceA2, ReferenceB2, + ReferenceC2); + + buffer BufferA{DataA}; + buffer BufferB{DataB}; + buffer BufferC{DataC}; + + buffer BufferA2{DataA2}; + buffer BufferB2{DataB2}; + buffer BufferC2{DataC2}; + + BufferA.set_write_back(false); + BufferB.set_write_back(false); + BufferC.set_write_back(false); + BufferA2.set_write_back(false); + BufferB2.set_write_back(false); + BufferC2.set_write_back(false); + + Queue.wait_and_throw(); + { + exp_ext::command_graph Graph{ + Queue.get_context(), Queue.get_device(), + exp_ext::property::graph::assume_buffer_outlives_graph{}}; + add_nodes(Graph, Queue, Size, BufferA, BufferB, BufferC); + + auto ExecGraph = Graph.finalize(exp_ext::property::graph::updatable{}); + + // Create second graph using other buffer set + exp_ext::command_graph GraphUpdate{ + Queue.get_context(), Queue.get_device(), + exp_ext::property::graph::assume_buffer_outlives_graph{}}; + add_nodes(GraphUpdate, Queue, Size, BufferA2, BufferB2, BufferC2); + + event Event; + for (size_t i = 0; i < Iterations; i++) { + Event = Queue.submit([&](handler &CGH) { + CGH.depends_on(Event); + CGH.ext_oneapi_graph(ExecGraph); + }); + // Update to second set of buffers + ExecGraph.update(GraphUpdate); + Event = Queue.submit([&](handler &CGH) { + CGH.depends_on(Event); + CGH.ext_oneapi_graph(ExecGraph); + }); + // Reset back to original buffers + ExecGraph.update(Graph); + } + + Queue.wait_and_throw(); + } + host_accessor HostDataA(BufferA); + host_accessor HostDataB(BufferB); + host_accessor HostDataC(BufferC); + host_accessor HostDataA2(BufferA2); + host_accessor HostDataB2(BufferB2); + host_accessor HostDataC2(BufferC2); + + for (size_t i = 0; i < Size; i++) { + assert(check_value(i, ReferenceA[i], HostDataA[i], "DataA")); + assert(check_value(i, ReferenceB[i], HostDataB[i], "DataB")); + assert(check_value(i, ReferenceC[i], HostDataC[i], "DataC")); + + assert(check_value(i, ReferenceA2[i], HostDataA2[i], "DataA2")); + assert(check_value(i, ReferenceB2[i], HostDataB2[i], "DataB2")); + assert(check_value(i, ReferenceC2[i], HostDataC2[i], "DataC2")); + } + + return 0; +} diff --git a/sycl/test-e2e/Graph/Update/whole_update_dynamic_param.cpp b/sycl/test-e2e/Graph/Update/whole_update_dynamic_param.cpp new file mode 100644 index 0000000000000..131b265eeeee0 --- /dev/null +++ b/sycl/test-e2e/Graph/Update/whole_update_dynamic_param.cpp @@ -0,0 +1,83 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out +// Extra run to check for leaks in Level Zero using UR_L0_LEAKS_DEBUG +// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=0 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %} +// Extra run to check for immediate-command-list in Level Zero +// RUN: %if level_zero && linux %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %} +// +// UNSUPPORTED: opencl, level_zero + +// Tests that whole graph update works when using dynamic parameters. +#include "../graph_common.hpp" + +int main() { + queue Queue{}; + + using T = int; + + std::vector InputDataHost1(Size); + std::vector InputDataHost2(Size); + std::vector OutputDataHost1(Size); + + std::iota(InputDataHost1.begin(), InputDataHost1.end(), 1); + std::iota(InputDataHost2.begin(), InputDataHost2.end(), 10); + std::iota(OutputDataHost1.begin(), OutputDataHost1.end(), 100); + + T *InputDataDevice1 = malloc_device(Size, Queue); + T *InputDataDevice2 = malloc_device(Size, Queue); + T *OutputDataDevice1 = malloc_device(Size, Queue); + + Queue.copy(InputDataHost1.data(), InputDataDevice1, Size); + Queue.copy(InputDataHost2.data(), InputDataDevice2, Size); + Queue.copy(OutputDataHost1.data(), OutputDataDevice1, Size); + + exp_ext::command_graph GraphA{Queue.get_context(), Queue.get_device()}; + + exp_ext::dynamic_parameter InputParam(GraphA, InputDataDevice1); + auto GraphANode = GraphA.add([&](handler &CGH) { + CGH.set_arg(1, InputParam); + CGH.single_task([=]() { + for (size_t i = 0; i < Size; i++) { + OutputDataDevice1[i] = InputDataDevice1[i]; + } + }); + }); + + auto GraphExecA = GraphA.finalize(); + Queue.ext_oneapi_graph(GraphExecA).wait(); + + Queue.copy(OutputDataDevice1, OutputDataHost1.data(), Size); + Queue.wait_and_throw(); + + for (size_t i = 0; i < Size; i++) { + assert(check_value(i, InputDataHost1[i], OutputDataHost1[i], "OutputDataHost1")); + } + + InputParam.update(InputDataDevice2); + exp_ext::command_graph GraphB{Queue.get_context(), Queue.get_device()}; + + auto GraphBNode = GraphB.add([&](handler &CGH) { + CGH.single_task([=]() { + for (size_t i = 0; i < Size; i++) { + OutputDataDevice1[i] = InputDataDevice1[i]; + } + }); + }); + + auto GraphExecB = GraphB.finalize(exp_ext::property::graph::updatable{}); + GraphExecB.update(GraphA); + Queue.ext_oneapi_graph(GraphExecB).wait(); + + Queue.copy(OutputDataDevice1, OutputDataHost1.data(), Size); + Queue.wait_and_throw(); + + free(InputDataDevice1, Queue); + free(InputDataDevice2, Queue); + free(OutputDataDevice1, Queue); + + for (size_t i = 0; i < Size; i++) { + assert(check_value(i, InputDataHost2[i], OutputDataHost1[i], "OutputDataHost1")); + } + + return 0; +} diff --git a/sycl/test-e2e/Graph/Inputs/executable_graph_update.cpp b/sycl/test-e2e/Graph/Update/whole_update_usm.cpp similarity index 81% rename from sycl/test-e2e/Graph/Inputs/executable_graph_update.cpp rename to sycl/test-e2e/Graph/Update/whole_update_usm.cpp index 96c2c0c325024..4aaed6364dd6f 100644 --- a/sycl/test-e2e/Graph/Inputs/executable_graph_update.cpp +++ b/sycl/test-e2e/Graph/Update/whole_update_usm.cpp @@ -1,5 +1,15 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out +// Extra run to check for leaks in Level Zero using UR_L0_LEAKS_DEBUG +// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=0 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %} +// Extra run to check for immediate-command-list in Level Zero +// RUN: %if level_zero && linux %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %} +// +// UNSUPPORTED: opencl, level_zero + // Tests executable graph update by creating two graphs with USM ptrs and // attempting to update one from the other. +#define GRAPH_E2E_EXPLICIT #include "../graph_common.hpp" @@ -35,7 +45,7 @@ int main() { // Add commands to first graph add_nodes(GraphA, Queue, Size, PtrA, PtrB, PtrC); - auto GraphExec = GraphA.finalize(); + auto GraphExec = GraphA.finalize(exp_ext::property::graph::updatable{}); exp_ext::command_graph GraphB{Queue.get_context(), Queue.get_device()}; diff --git a/sycl/unittests/Extensions/CommandGraph/Common.hpp b/sycl/unittests/Extensions/CommandGraph/Common.hpp index a2e0965572cbf..2056846f92ac3 100644 --- a/sycl/unittests/Extensions/CommandGraph/Common.hpp +++ b/sycl/unittests/Extensions/CommandGraph/Common.hpp @@ -20,6 +20,7 @@ using namespace sycl; using namespace sycl::ext::oneapi; +namespace exp_ext = sycl::ext::oneapi::experimental; // Common Test fixture class CommandGraphTest : public ::testing::Test { diff --git a/sycl/unittests/Extensions/CommandGraph/Update.cpp b/sycl/unittests/Extensions/CommandGraph/Update.cpp index 92246fb83678d..05d077b32d556 100644 --- a/sycl/unittests/Extensions/CommandGraph/Update.cpp +++ b/sycl/unittests/Extensions/CommandGraph/Update.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "Common.hpp" +#include "sycl/exception.hpp" using namespace sycl; using namespace sycl::ext::oneapi; @@ -150,3 +151,255 @@ TEST_F(CommandGraphTest, UpdateRangeErrors) { // Can't update with a different number of dimensions EXPECT_ANY_THROW(NodeRange.update_range(range<2>{128, 128})); } + +class WholeGraphUpdateTest : public CommandGraphTest { + +protected: + static constexpr size_t Size = 1024; + + WholeGraphUpdateTest() + : UpdateGraph{ + Queue.get_context(), + Dev, + {experimental::property::graph::assume_buffer_outlives_graph{}}} {} + + experimental::command_graph + UpdateGraph; + + std::function EmptyKernel = [&](handler &CGH) { + CGH.parallel_for>(range<1>(Size), [=](item<1> Id) {}); + }; +}; + +TEST_F(WholeGraphUpdateTest, NoUpdates) { + // Test that using an update graph that has no updates is fine. + + auto NodeA = Graph.add(EmptyKernel); + auto NodeB = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeC = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeD = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeB, NodeC)); + + auto UpdateNodeA = UpdateGraph.add(EmptyKernel); + auto UpdateNodeB = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeC = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeD = UpdateGraph.add( + EmptyKernel, + exp_ext::property::node::depends_on(UpdateNodeB, UpdateNodeC)); + + auto GraphExec = Graph.finalize(experimental::property::graph::updatable{}); + EXPECT_NO_THROW(GraphExec.update(UpdateGraph)); +} + +TEST_F(WholeGraphUpdateTest, MoreNodes) { + // Test that using an update graph that has extra nodes results in an error. + + auto NodeA = Graph.add(EmptyKernel); + auto NodeB = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeC = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeD = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeB, NodeC)); + + auto UpdateNodeA = UpdateGraph.add(EmptyKernel); + auto UpdateNodeB = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeC = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeD = UpdateGraph.add( + EmptyKernel, + exp_ext::property::node::depends_on(UpdateNodeB, UpdateNodeC)); + // NodeE is the extra node + auto UpdateNodeE = UpdateGraph.add(EmptyKernel); + + auto GraphExec = Graph.finalize(experimental::property::graph::updatable{}); + EXPECT_THROW(GraphExec.update(UpdateGraph), sycl::exception); +} + +TEST_F(WholeGraphUpdateTest, LessNodes) { + // Test that using an update graph that has less nodes results in an error. + + auto NodeA = Graph.add(EmptyKernel); + auto NodeB = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeC = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeD = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeB, NodeC)); + + auto UpdateNodeA = UpdateGraph.add(EmptyKernel); + auto UpdateNodeB = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeC = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + // NodeD is missing in the update + + auto GraphExec = Graph.finalize(experimental::property::graph::updatable{}); + EXPECT_THROW(GraphExec.update(UpdateGraph), sycl::exception); +} + +TEST_F(WholeGraphUpdateTest, ExtraEdges) { + // Test that using an update graph with extra nodes results in an error. + + auto NodeA = Graph.add(EmptyKernel); + auto NodeB = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeC = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeD = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeB, NodeC)); + + auto UpdateNodeA = UpdateGraph.add(EmptyKernel); + auto UpdateNodeB = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeC = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeD = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on( + UpdateNodeA, UpdateNodeB, UpdateNodeC /* Extra Edge */)); + + auto GraphExec = Graph.finalize(experimental::property::graph::updatable{}); + EXPECT_THROW(GraphExec.update(UpdateGraph), sycl::exception); +} + +TEST_F(WholeGraphUpdateTest, MissingEdges) { + // Test that using an update graph with missing edges results in an error. + + auto NodeA = Graph.add(EmptyKernel); + auto NodeB = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeC = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeD = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeB, NodeC)); + + auto UpdateNodeA = UpdateGraph.add(EmptyKernel); + auto UpdateNodeB = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeC = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeD = UpdateGraph.add( + EmptyKernel, + exp_ext::property::node::depends_on(/* Missing Edge */ UpdateNodeB)); + + auto GraphExec = Graph.finalize(experimental::property::graph::updatable{}); + EXPECT_THROW(GraphExec.update(UpdateGraph), sycl::exception); +} + +// FIXME TODO Is this an error or not? +TEST_F(WholeGraphUpdateTest, WrongOrderEdges) { + // Test that using an update graph with edges added in a different order + // does not result in an error. + + auto NodeA = Graph.add(EmptyKernel); + auto NodeB = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeC = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeD = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeB, NodeC)); + + auto UpdateNodeA = UpdateGraph.add(EmptyKernel); + auto UpdateNodeB = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeC = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeD = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on( + UpdateNodeC, UpdateNodeB /* Reversed Edges */)); + + auto GraphExec = Graph.finalize(experimental::property::graph::updatable{}); + EXPECT_NO_THROW(GraphExec.update(UpdateGraph)); +} + +TEST_F(WholeGraphUpdateTest, UnsupportedNodeType) { + // Test that using an update graph that contains unsupported node types + // results in an error. + buffer Buffer{range<1>{1}}; + + auto NodeA = Graph.add(EmptyKernel); + auto NodeB = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeC = + Graph.add(EmptyKernel, exp_ext::property::node::depends_on(NodeA)); + auto NodeD = Graph.add( + [&](handler &CGH) { + auto Acc = Buffer.get_access(CGH); + CGH.fill(Acc, 1); + }, + exp_ext::property::node::depends_on(NodeB, NodeC)); + + auto UpdateNodeA = UpdateGraph.add(EmptyKernel); + auto UpdateNodeB = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeC = UpdateGraph.add( + EmptyKernel, exp_ext::property::node::depends_on(UpdateNodeA)); + auto UpdateNodeD = Graph.add( + [&](handler &CGH) { + auto Acc = Buffer.get_access(CGH); + CGH.fill(Acc, 1); + }, + exp_ext::property::node::depends_on(UpdateNodeB, UpdateNodeC)); + + auto GraphExec = Graph.finalize(experimental::property::graph::updatable{}); + EXPECT_THROW(GraphExec.update(UpdateGraph), sycl::exception); +} + +TEST_F(WholeGraphUpdateTest, WrongContext) { + // Test that using an update graph that was created with a different context + // (when compared to the original graph) results in an error. + + auto NodeA = Graph.add(EmptyKernel); + auto GraphExec = Graph.finalize(experimental::property::graph::updatable{}); + + context OtherContext(Dev); + experimental::command_graph + WrongContextGraph{ + OtherContext, + Dev, + {experimental::property::graph::assume_buffer_outlives_graph{}}}; + + auto UpdateNodeA = WrongContextGraph.add(EmptyKernel); + + EXPECT_THROW(GraphExec.update(WrongContextGraph), sycl::exception); +} + +TEST_F(WholeGraphUpdateTest, WrongDevice) { + // Test that using an update graph that was created with a different device + // (when compared to the original graph) results in an error. + + auto devices = device::get_devices(); + if (devices.size() > 1) { + + device &OtherDevice = (devices[0] == Dev ? devices[1] : devices[0]); + + auto NodeA = Graph.add(EmptyKernel); + auto GraphExec = Graph.finalize(experimental::property::graph::updatable{}); + + experimental::command_graph + WrongDeviceGraph{ + Queue.get_context(), + OtherDevice, + {experimental::property::graph::assume_buffer_outlives_graph{}}}; + + auto UpdateNodeA = WrongDeviceGraph.add(EmptyKernel); + + EXPECT_THROW(GraphExec.update(WrongDeviceGraph), sycl::exception); + } +} + +TEST_F(WholeGraphUpdateTest, MissingUpdatableProperty) { + // Test that updating a graph that was not created with the updatable property + // results in an error. + + auto NodeA = Graph.add(EmptyKernel); + auto UpdateNodeA = UpdateGraph.add(EmptyKernel); + + auto GraphExec = Graph.finalize(); + EXPECT_THROW(GraphExec.update(UpdateGraph), sycl::exception); +}