Skip to content

Commit

Permalink
Implement CUDASTF_DOT_TIMING for the host_launch construct (NVIDIA#3170)
Browse files (browse the repository at this point in the history)
  • Loading branch information
caugonnet authored Dec 15, 2024
1 parent 0aa0b37 commit 1393082
Showing 1 changed file with 40 additions and 2 deletions.
42 changes: 40 additions & 2 deletions cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -120,20 +120,58 @@ public:
template <typename Fun>
void operator->*(Fun&& f)
{
auto& dot = *ctx.get_dot();
auto& statistics = reserved::task_statistics::instance();

auto t = ctx.task(exec_place::host);
t.add_deps(deps);
if (!symbol.empty())
{
t.set_symbol(symbol);
}

cudaEvent_t start_event, end_event;
const bool record_time = t.schedule_task() || statistics.is_calibrating_to_file();

t.start();

if constexpr (::std::is_same_v<Ctx, stream_ctx>)
{
if (record_time)
{
cuda_safe_call(cudaEventCreate(&start_event));
cuda_safe_call(cudaEventCreate(&end_event));
cuda_safe_call(cudaEventRecord(start_event, t.get_stream()));
}
}

SCOPE(exit)
{
t.end();
t.end_uncleared();
if constexpr (::std::is_same_v<Ctx, stream_ctx>)
{
if (record_time)
{
cuda_safe_call(cudaEventRecord(end_event, t.get_stream()));
cuda_safe_call(cudaEventSynchronize(end_event));

float milliseconds = 0;
cuda_safe_call(cudaEventElapsedTime(&milliseconds, start_event, end_event));

if (dot.is_tracing())
{
dot.template add_vertex_timing<typename Ctx::task_type>(t, milliseconds, -1);
}

if (statistics.is_calibrating())
{
statistics.log_task_time(t, milliseconds);
}
}
}
t.clear();
};

auto& dot = *ctx.get_dot();
if (dot.is_tracing())
{
dot.template add_vertex<typename Ctx::task_type, logical_data_untyped>(t);
Expand Down

0 comments on commit 1393082

Please sign in to comment.