forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
profiler_kineto.cpp
799 lines (700 loc) · 27.2 KB
/
profiler_kineto.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <torch/csrc/autograd/profiler_kineto.h>
#include <c10/macros/Export.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/util/overloaded.h>
#include <c10/util/variant.h>
#include <torch/csrc/profiler/api.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/containers.h>
#include <torch/csrc/profiler/kineto_shim.h>
#include <torch/csrc/profiler/nvtx_observer.h>
#include <ATen/Context.h>
#include <deque>
#include <limits>
#include <sstream>
#include <stdexcept>
#ifdef USE_KINETO
#include <libkineto.h>
#include <time_since_epoch.h>
#ifndef _MSC_VER
// TODO: TO be removed, once this properly works from libkineto
// Literal copy-n-paste from third_party/kineto/libkineto/src/WeakSymbols.cpp
extern "C" {
// Weak placeholder for acc_get_device_type(). Defining it here avoids a
// superfluous dependency on the GNU OpenMP library when cuPTI is linked
// statically (see https://github.com/pytorch/pytorch/issues/51026).
// Being weak, any strong definition elsewhere takes precedence; if this stub
// is ever reached it throws std::runtime_error.
__attribute__((weak)) int acc_get_device_type() {
  const char* const kNotSupposedToBeCalled =
      "Dummy implementation of acc_get_device_type is not supposed to be called!";
  throw std::runtime_error(kNotSupposedToBeCalled);
}
} // extern "C"
#endif // _MSC_VER
#endif // USE_KINETO
namespace torch {
namespace autograd {
namespace profiler {
namespace {
const std::string kMemoryEventName = "[memory]";
// TODO: consider TLS (tid + tls counter)
// Hands out correlation ids from a single process-wide atomic counter that
// starts at 1; every call (from any thread) receives a distinct value.
uint64_t next_correlation_id() {
  static std::atomic<uint64_t> counter{1};
  // fetch_add returns the value held before the increment.
  return counter.fetch_add(1);
}
// Current time used for profiler timestamps (microseconds, per the name).
inline int64_t getTimeUs() {
#ifndef USE_KINETO
  // Without Kineto: the profiler's own clock divided by 1000 (presumably
  // nanoseconds -> microseconds; TODO confirm getTime() units).
  return torch::profiler::impl::getTime() / 1000;
#else
  // With Kineto: system_clock "now" converted by libkineto::timeSinceEpoch so
  // our timestamps share Kineto's time base.
  return libkineto::timeSinceEpoch(std::chrono::system_clock::now());
#endif // USE_KINETO
}
} // namespace
namespace python_tracer {
namespace {
CallFn call_fn;
TraceEventsFn get_events_fn;
} // namespace
// Installs the Python tracer hooks used by call() and get_events(). Either
// hook may be null, in which case the corresponding entry point degrades to
// a no-op / an empty result.
void registerFunctions(CallFn call, TraceEventsFn get_events) {
  get_events_fn = get_events;
  call_fn = call;
}
// Forwards command `c` to the registered Python tracer; ignored when no
// tracer has been registered.
void call(Command c) {
  if (call_fn == nullptr) {
    return;
  }
  call_fn(c);
}
// Returns the events collected by the registered Python tracer, or an empty
// vector when no tracer has been registered.
std::vector<std::unique_ptr<PyTraceEvent>> get_events() {
  if (get_events_fn == nullptr) {
    return std::vector<std::unique_ptr<PyTraceEvent>>();
  }
  return get_events_fn();
}
// Timestamp source for the Python tracer. It exposes the profiler's clock
// (getTimeUs) so Python events use the same timing convention as profiler
// events, without making getTimeUs itself visible outside this file.
int64_t now() {
  const int64_t timestamp_us = getTimeUs();
  return timestamp_us;
}
// A single replay step for a traced Python frame: its entry (enter_ == true)
// or its exit. Steps are ordered by the tracer-assigned call/return index.
struct Replay {
  PyTraceEvent* frame_;
  bool enter_;

  // Time of this step: the frame's start time on entry, end time on exit.
  C10_NODISCARD int64_t t() const {
    if (enter_) {
      return frame_->startTime_;
    }
    return frame_->endTime_;
  }

  // Ordering key: call index on entry, return index on exit.
  C10_NODISCARD size_t idx() const {
    if (enter_) {
      return frame_->call_idx_;
    }
    return frame_->return_idx_;
  }

  bool operator<(const Replay& rhs) const {
    return idx() < rhs.idx();
  }
};
// Appends the names of `e` and all of its ancestors (via parent_) to `names`,
// outermost ancestor first and `e` itself last. A null `e` appends nothing.
void _push_reverse_order(PyTraceEvent* e, std::vector<std::string>& names) {
  // Walk up the parent chain first, then emit in reverse (root-first) order.
  std::vector<PyTraceEvent*> chain;
  for (PyTraceEvent* frame = e; frame != nullptr; frame = frame->parent_) {
    chain.push_back(frame);
  }
  for (auto it = chain.rbegin(); it != chain.rend(); ++it) {
    names.push_back((*it)->name_);
  }
}
} // namespace python_tracer
namespace {
using torch::profiler::impl::ProfilerThreadLocalStateBase;
using torch::profiler::impl::ActiveProfilerType;
// One raw memory usage report captured by
// KinetoThreadLocalState::reportMemoryUsage(). Records are buffered and later
// turned into CPU-trace memory activities and KinetoEvents in
// materializeOpEvents().
// NOTE: reportMemoryUsage() passes these fields positionally to
// memory_events_.emplace_back(), so the declaration order must stay in sync
// with that call.
struct MemoryEventData {
// Approximate timestamp; converted with the clock converter when materialized.
torch::profiler::impl::approx_time_t start_time;
void* ptr;
// Reported allocation size (signed; presumably negative for frees - TODO confirm).
int64_t alloc_size;
int64_t total_allocated;
int64_t total_reserved;
// at::RecordFunction::currentThreadId() of the reporting thread.
uint64_t threadID;
// torch::profiler::impl::kineto::kineto_ids() at the time of the report.
torch::profiler::impl::kineto::DeviceAndResource kineto_info;
c10::DeviceType device_type;
c10::DeviceIndex device_index;
};
// The struct is required to stay POD.
static_assert(std::is_pod<MemoryEventData>::value, "Non-POD member of MemoryEventData.");
// Builds the key that pairs a forward op with its backward op: the thread id
// is shifted into the top 16 bits and the low 48 bits of the sequence number
// fill the rest. Assumes thread ids fit in 16 bits (any higher bits are
// shifted out) and sequence numbers stay below 2^48 (higher bits are masked).
static inline uint64_t getForwardThreadKey(uint64_t tid, uint64_t seqNr) {
  constexpr uint64_t kSeqBits = 48;
  constexpr uint64_t kSeqMask = (uint64_t{1} << kSeqBits) - 1;
  const uint64_t thread_part = tid << kSeqBits;
  const uint64_t seq_part = seqNr & kSeqMask;
  return thread_part | seq_part;
}
// Profiler state for one Kineto profiling session. enableProfiler() pushes an
// instance as thread-local debug info (DebugInfoKind::PROFILER_STATE) and
// disableProfiler() pops it again.
// While profiling, op records accumulate in record_queue_ and memory reports
// in memory_events_. finalizeTrace() converts both into CPU trace activities
// (cpu_trace_) and KinetoEvents (kineto_events_), and then appends the
// activities Kineto itself collected.
struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase {
// NB: cpu_trace_ is built from start_time_; this is safe because start_time_
// is declared (and therefore initialized) before cpu_trace_.
explicit KinetoThreadLocalState(
const ProfilerConfig& config,
std::set<torch::profiler::impl::ActivityType> activities)
: ProfilerThreadLocalStateBase(config),
start_time_(getTimeUs()),
activities_(std::move(activities)),
record_queue_(config),
cpu_trace_(start_time_, "PyTorch Profiler") {}
~KinetoThreadLocalState() override = default;
// Returns the current thread's profiler state downcast to a Kineto state, or
// nullptr if none is active. Debug builds assert the state really is Kineto.
static KinetoThreadLocalState* getTLS() {
auto tls = ProfilerThreadLocalStateBase::getTLS();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
tls == nullptr || tls->profilerType() == ActiveProfilerType::KINETO);
return static_cast<KinetoThreadLocalState*>(tls);
}
ActiveProfilerType profilerType() override {
return ActiveProfilerType::KINETO;
}
// Python tracing is on only when stacks were requested (with_stack) and CPU
// activity is being profiled.
bool tracePython() {
return config().with_stack && activities_.count(ActivityType::CPU);
}
// Buffers one memory usage report when profile_memory is set and profiling is
// not disabled. Only raw data is stored here (under state_mutex_); it is
// converted into activities/events later in materializeOpEvents().
void reportMemoryUsage(
void* ptr,
int64_t alloc_size,
int64_t total_allocated,
int64_t total_reserved,
c10::Device device) override {
if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
std::lock_guard<std::mutex> guard(state_mutex_);
memory_events_.emplace_back(
torch::profiler::impl::getApproximateTime(),
ptr,
alloc_size,
total_allocated,
total_reserved,
at::RecordFunction::currentThreadId(),
torch::profiler::impl::kineto::kineto_ids(),
device.type(),
device.index());
}
}
const std::function<void(std::vector<KinetoEvent>&)>&
getEventPostProcessingCallback() const {
return event_post_process_cb_;
}
void setEventPostProcessingCallback(
std::function<void(std::vector<KinetoEvent>&)>&& cb) {
event_post_process_cb_ = std::move(cb);
}
// Produces the final trace for the session, in this order:
//  1. materialize buffered op/memory records into cpu_trace_/kineto_events_;
//  2. run the optional post-processing callback on kineto_events_;
//  3. attach metadata (and Python events) to the CPU trace activities;
//  4. transferCpuTrace(end_time), then stop the Kineto trace;
//  5. append the activities Kineto collected as extra KinetoEvents.
// NOTE(review): finalizeCPUTrace() asserts kineto_events_ and the CPU trace
// have the same size, so the callback must not add or remove events.
torch::profiler::impl::kineto::ActivityTraceWrapper finalizeTrace() {
auto end_time = getTimeUs();
materializeOpEvents();
// Call events post processing callback before finalizing trace, if there is
// one.
if (getEventPostProcessingCallback()) {
getEventPostProcessingCallback()(kineto_events_);
}
finalizeCPUTrace(cpu_trace_.get());
{
std::lock_guard<std::mutex> guard(state_mutex_);
cpu_trace_.transferCpuTrace(end_time);
}
auto trace = torch::profiler::impl::kineto::stopTrace();
TORCH_CHECK(trace || !torch::profiler::kKinetoAvailable);
addTraceEvents(trace);
return trace;
}
// Converts buffered memory reports and finished op records into CPU trace
// activities (cpu_trace_) and KinetoEvents (kineto_events_). Every entry is
// appended to both containers in the same order, so index i of one describes
// index i of the other; finalizeCPUTrace() relies on this. Holds state_mutex_
// throughout and clears memory_events_ once they are consumed.
void materializeOpEvents() {
std::lock_guard<std::mutex> guard(state_mutex_);
auto converter = clock_converter_.makeConverter();
for (const auto& e : memory_events_) {
// The division by 1000 presumably converts ns to us (TODO confirm units).
auto start_time_us = converter(e.start_time) / 1000;
cpu_trace_.addMemoryUsageActivity(
kMemoryEventName,
e.kineto_info,
start_time_us,
c10::Device(e.device_type, e.device_index),
e.ptr,
e.alloc_size,
e.total_allocated,
e.total_reserved);
kineto_events_.emplace_back();
auto& evt = kineto_events_.back();
evt.name(kMemoryEventName)
.startUs(start_time_us)
.deviceIndex(e.device_index)
.deviceType(e.device_type)
.nBytes(e.alloc_size)
.startThreadId(e.threadID);
}
memory_events_.clear();
for (const auto& e : record_queue_.getRecords(converter)) {
// `take_data` handles time conversion.
int64_t start_us = e.start_time_us_;
int64_t end_us = e.end_time_us_;
if (end_us < start_us) {
// We initialize end_us_ to the smallest int64_t, so this means that
// the op did not finish before we stopped profiling.
continue;
}
cpu_trace_.addCPUActivity(
e.name(),
e.kineto_info_,
e.correlation_id(),
start_us,
end_us);
kineto_events_.emplace_back();
kineto_events_.back()
.name(e.name())
.startUs(start_us)
.durationUs(end_us - start_us)
.correlationId(e.correlation_id())
.deviceType(c10::DeviceType::CPU)
.startThreadId(e.start_tid_);
// Fill in the fields that depend on the record's kind: a regular op
// (OpEvent) or an event reported by a backend (BackendEvent).
c10::visit(
c10::overloaded(
[&](const torch::profiler::impl::OpEvent& op_event) {
kineto_events_.back()
.endThreadId(op_event.end_thread_id_)
.sequenceNr(op_event.sequence_number_)
.fwdThreadId(op_event.forward_thread_id_)
.scope(op_event.record_function_scope_)
.setAsync(op_event.is_async_)
.debugHandle(op_event.debug_handle_);
},
[&](const torch::profiler::impl::BackendEvent& backend_event) {
kineto_events_.back()
.endThreadId(e.start_tid_)
.scope(backend_event.record_function_scope_)
.debugHandle(backend_event.debug_handle_)
.backend(backend_event.backend_);
}),
e.event_);
// Optional per-op data: only copied when it was captured.
if (!e.inputs_.shapes_.empty()) {
kineto_events_.back().shapes(e.inputs_.shapes_);
}
if (!e.inputs_.dtypes_.empty()) {
kineto_events_.back().dtypes(e.inputs_.dtypes_);
}
if (!e.jit_stack_.empty()) {
kineto_events_.back().stack(e.jit_stack_);
}
if (!e.jit_modules_.empty()) {
kineto_events_.back().moduleHierarchy(e.jit_modules_);
}
if (!e.extra_args_.empty()) {
kineto_events_.back().flops(
computeFlops(e.name(), e.extra_args_));
}
// CUDA events recorded in KINETO_GPU_FALLBACK mode (consumed by
// KinetoEvent::cudaElapsedUs()).
kineto_events_.back().cuda_event_start_ =
e.gpu_fallback_.cuda_event_start_;
kineto_events_.back().cuda_event_end_ =
e.gpu_fallback_.cuda_event_end_;
}
}
// Attaches per-op metadata (input dims/types, JIT call stack, module
// hierarchy, backend, forward thread id / sequence number) to the CPU trace
// activities, then merges in Python tracer events. Without USE_KINETO the
// function body is empty.
void finalizeCPUTrace(std::unique_ptr<torch::profiler::impl::kineto::trace_t>& cpu_trace) {
#ifndef USE_KINETO
}
#else // USE_KINETO
TORCH_INTERNAL_ASSERT(
cpu_trace->activities.size() == kineto_events_.size());
// startThreadId_seqNum to pointer of activity.
// Low-16bits of startThreadId and low-48bits seqNum are concatenated into
// one uint64_t variable as key.
// NB: tidSeq2activity and fwd_bwd_link_id are only consumed by the
// generateForwardBackwardLink() call that is commented out below.
std::unordered_map<uint64_t, libkineto::GenericTraceActivity*>
tidSeq2activity;
uint64_t fwd_bwd_link_id = 1;
for (const auto idx : c10::irange(cpu_trace->activities.size())) {
auto& kineto_event = kineto_events_[idx];
auto& activity = cpu_trace->activities[idx];
if (kineto_event.hasShapes()) {
activity.addMetadata("Input Dims", torch::profiler::impl::shapesToStr(kineto_event.shapes()));
}
if (kineto_event.hasStack()) {
// NB: This is only for the JIT stack. The python stack (if applicable)
// is constructed later.
activity.addMetadata(
"Call stack", torch::profiler::impl::stacksToStr(kineto_event.stack(), ";"));
}
if (kineto_event.hasModuleHierarchy()) {
activity.addMetadata(
"Module Hierarchy",
torch::profiler::impl::stacksToStr(kineto_event.moduleHierarchy(), "."));
}
if (kineto_event.hasTypes()) {
activity.addMetadata("Input type", torch::profiler::impl::dtypesToStr(kineto_event.dtypes()));
}
if (!kineto_event.backend().empty()) {
activity.addMetadata("Backend", "\"" + kineto_event.backend() + "\"");
}
// add information about an associated forward op, if a sequence number
// is available (e.g. during training)
if (kineto_event.sequenceNr() >= 0) {
activity.addMetadata(
"Fwd thread id", std::to_string(kineto_event.fwdThreadId()));
activity.addMetadata(
"Sequence number", std::to_string(kineto_event.sequenceNr()));
// From the time being, we need disable the forward/backward correlation feature to
// workaround the crash bug.
// TODO: by Mike Guo
// reenable the forward/backward correlation when kineto fix the following raw pointer
// GenericTraceActivity.flow.linkedActivity
// generateForwardBackwardLink(
// kineto_event, fwd_bwd_link_id, activity, tidSeq2activity);
}
}
addPythonEvents(cpu_trace);
}
// Merges Python tracer events into the CPU trace (only when tracePython()):
//  1. replays Python call/return steps to find, for every op start time, the
//     innermost Python frame active at that moment;
//  2. tags every op activity with "python_caller_id" and, for ops without a
//     JIT stack, sets the Python call stack on both the KinetoEvent and the
//     activity;
//  3. appends one PYTHON_FUNCTION activity per Python event after all
//     existing activities, in the order returned by the tracer.
void addPythonEvents(std::unique_ptr<torch::profiler::impl::kineto::trace_t>& cpu_trace) {
if (!tracePython()) {
return;
}
auto py_events = python_tracer::get_events();
for (const auto& e : py_events) {
TORCH_INTERNAL_ASSERT(
!e->thread_id_,
"Profiler expects only single threaded Python tracing.");
}
// The remainder of this function merges the Python and Kineto event
// streams into a single stream. If Python tracing is not enabled, we want
// to avoid this process altogether to cut down on processing time.
if (!py_events.size()) {
return;
}
// Kineto event times
std::vector<int64_t> op_start_times;
for (const auto& a : cpu_trace->activities) {
op_start_times.push_back(a.startTime);
}
std::sort(op_start_times.begin(), op_start_times.end());
// Map PyTraceEvent* to sequential integers for JSON export.
ska::flat_hash_map<python_tracer::PyTraceEvent*, std::string>
py_event_indices_{
{ nullptr,
std::string("null") }};
for (const auto i : c10::irange(py_events.size())) {
py_event_indices_.insert({py_events[i].get(), std::to_string(i)});
}
// Module calls get a per-module-name ordinal, assigned the first time a given
// module_id_ is seen during the replay below.
ska::flat_hash_map<std::string, size_t> module_counter_;
ska::flat_hash_map<size_t, std::string> module_id_map_;
auto record_module_id = [&](python_tracer::PyTraceEvent* e) {
if (e->call_type_ == python_tracer::CallType::kPyModuleCall &&
module_id_map_.find(e->module_id_) == module_id_map_.end()) {
// We use the fact that operator[] will default initialize new keys.
module_id_map_[e->module_id_] =
std::to_string(module_counter_[e->name_]++);
}
};
// Python events
// Every Python event contributes an entry step and an exit step; the steps
// are sorted by call/return index (Replay::operator<).
std::vector<python_tracer::Replay> py_replay;
for (const auto& e : py_events) {
py_replay.push_back({e.get(), true});
py_replay.push_back({e.get(), false});
}
std::sort(py_replay.begin(), py_replay.end());
// In order to determine the state of the python interpreter when a
// particular op is called, we have to replay the python events and note
// timestamps which are associated with op start times.
std::vector<python_tracer::PyTraceEvent*> py_stack;
ska::flat_hash_map<int64_t, python_tracer::PyTraceEvent*> op_py_map;
auto replay_it = py_replay.begin();
for (auto t : op_start_times) {
while (replay_it != py_replay.end() && replay_it->t() <= t) {
if (replay_it->enter_) {
py_stack.push_back(replay_it->frame_);
record_module_id(replay_it->frame_);
} else {
TORCH_INTERNAL_ASSERT(py_stack.size());
TORCH_INTERNAL_ASSERT(py_stack.back() == replay_it->frame_);
py_stack.pop_back();
}
replay_it++;
}
op_py_map.insert({t, py_stack.size() ? py_stack.back() : nullptr});
}
std::vector<libkineto::GenericTraceActivity> py_activities;
auto py_events_it = py_events.begin();
auto py_device = libkineto::processId();
auto main_thread = libkineto::systemThreadId();
// Converts the Python event at py_events_it into a PYTHON_FUNCTION activity
// (appended to py_activities) and advances the iterator.
auto push_py_event = [&]() {
auto e = (*py_events_it).get();
libkineto::GenericTraceActivity op(
cpu_trace->span, libkineto::ActivityType::PYTHON_FUNCTION, e->name_);
op.device = py_device;
op.resource = main_thread;
op.startTime = e->startTime_;
op.endTime = e->endTime_;
op.addMetadata("Python id", py_event_indices_.at(e));
op.addMetadata("Python parent id", py_event_indices_.at(e->parent_));
op.addMetadata("Python thread", std::to_string(e->thread_id_));
if (e->call_type_ == python_tracer::CallType::kPyModuleCall) {
op.addMetadata("Python module id", module_id_map_.at(e->module_id_));
}
py_activities.push_back(op);
py_events_it++;
};
TORCH_INTERNAL_ASSERT(cpu_trace->activities.size() == kineto_events_.size());
for (const auto idx : c10::irange(cpu_trace->activities.size())) {
auto& activity = cpu_trace->activities[idx];
// Add any python events that occurred between this Kineto event and the
// previous Kineto event.
// NOTE(review): push_py_event() only appends to py_activities, which is
// inserted after all existing activities at the end of this function, so
// this interleaving does not change the final activity order.
while (py_events_it != py_events.end() &&
(*py_events_it)->endTime_ <= activity.endTime) {
push_py_event();
}
auto python_caller = op_py_map.at(activity.startTime);
activity.addMetadata(
"python_caller_id", py_event_indices_.at(python_caller));
// If the kineto event has a stack that means the JIT model has a stack
// associated with it that we need to respect.
if (!kineto_events_[idx].hasStack()) {
std::vector<std::string> py_names;
_push_reverse_order(python_caller, py_names);
kineto_events_[idx].stack(py_names);
activity.addMetadata("Call stack", torch::profiler::impl::stacksToStr(py_names, ";"));
}
}
// Add any Python events which finish after the last Kineto event.
while (py_events_it != py_events.end()) {
push_py_event();
}
cpu_trace->activities.insert(cpu_trace->activities.end(), py_activities.begin(), py_activities.end());
}
// Links a backward op to the forward op that launched it via a Kineto flow
// (kLinkFwdBwd). Ops with fwdThreadId() > 0 are treated as backward ops and
// look up (fwd thread id, sequence number); on a hit both activities share a
// new flow id. Otherwise, ops with a non-zero start thread id are treated as
// forward ops and indexed by (start thread id, sequence number), keeping the
// one with the latest start time.
// NOTE: currently unused; its only call site in finalizeCPUTrace() is
// commented out.
void generateForwardBackwardLink(
const KinetoEvent& kineto_event,
uint64_t& fwd_bwd_link_id,
libkineto::GenericTraceActivity& activity,
std::unordered_map<uint64_t, libkineto::GenericTraceActivity*>&
tidSeq2activity) {
if (kineto_event.fwdThreadId() > 0) {
// act is backward op.
uint64_t key = getForwardThreadKey(
kineto_event.fwdThreadId(), kineto_event.sequenceNr());
auto iter = tidSeq2activity.find(key);
if (iter != tidSeq2activity.end()) {
libkineto::GenericTraceActivity* fwd = iter->second;
fwd->flow.start = true;
activity.flow.id = fwd->flow.id = fwd_bwd_link_id;
activity.flow.type = fwd->flow.type = libkineto::kLinkFwdBwd;
++fwd_bwd_link_id;
}
} else if (kineto_event.startThreadId() != 0) {
// act is forward op.
uint64_t key = getForwardThreadKey(
kineto_event.startThreadId(), kineto_event.sequenceNr());
// Assumption: Among all ops with same sequence number,
// the one with biggest start time is most likely launching backward op.
auto iter = tidSeq2activity.find(key);
if (iter == tidSeq2activity.end()) {
tidSeq2activity[key] = &activity;
} else {
// Now the sequence number is only incremented on creating a "Node"
// object for backward pass, by calling
// "at::sequence_number::get_and_increment()". Among all ops with same
// sequence number, the one with biggest startTime is the one launching
// backward op.
if (activity.startTime >= iter->second->startTime) {
tidSeq2activity[key] = &activity;
}
}
}
}
#endif // USE_KINETO
// Appends a KinetoEvent for every non-null activity in the stopped Kineto
// trace whose type is not CPU_OP, CPU_INSTANT_EVENT, USER_ANNOTATION or
// PYTHON_FUNCTION; per the comment below, those types were already processed.
// No-op without USE_KINETO.
void addTraceEvents(torch::profiler::impl::kineto::ActivityTraceWrapper& trace) {
#ifdef USE_KINETO
const auto& events = *(trace.get()->activities());
for (const auto& ev_ptr : events) {
if (ev_ptr == nullptr) {
continue;
}
const auto& activity = *ev_ptr;
// These events are already processed
if (activity.type() != libkineto::ActivityType::CPU_OP &&
activity.type() != libkineto::ActivityType::CPU_INSTANT_EVENT &&
activity.type() != libkineto::ActivityType::USER_ANNOTATION &&
activity.type() != libkineto::ActivityType::PYTHON_FUNCTION) {
kineto_events_.emplace_back();
auto& kineto_event = kineto_events_.back();
kineto_event.name(activity.name())
.deviceIndex(activity.deviceId())
.deviceResourceId(activity.resourceId())
.startUs(activity.timestamp())
.durationUs(activity.duration())
.activityType((uint8_t)activity.type());
if (activity.linkedActivity()) {
kineto_event.linkedCorrelationId(
activity.linkedActivity()->correlationId());
}
kineto_event.deviceType(deviceTypeFromActivity(activity.type()));
}
}
#endif // USE_KINETO
}
// Session start time (getTimeUs()); also the start time of cpu_trace_.
uint64_t start_time_;
// Converts approximate timestamps to Unix time when records are materialized.
torch::profiler::impl::ApproximateClockToUnixTimeConverter clock_converter_;
// Activity types requested for this session.
std::set<torch::profiler::impl::ActivityType> activities_;
// Op and backend records from the RecordFunction callbacks and
// reportBackendEventToActiveKinetoProfiler().
torch::profiler::impl::RecordQueue record_queue_;
// Raw memory reports from reportMemoryUsage(); cleared in materializeOpEvents().
torch::profiler::impl::AppendOnlyList<MemoryEventData, 1024> memory_events_;
// CPU-side trace built in materializeOpEvents() and finalizeCPUTrace().
torch::profiler::impl::kineto::TraceWrapper cpu_trace_;
// Events handed to the caller through ProfilerResult.
std::vector<KinetoEvent> kineto_events_;
// Optional, if event post-processing is enabled.
std::function<void(std::vector<KinetoEvent>&)> event_post_process_cb_;
};
// Registers the thread-local RecordFunction callbacks that feed the current
// thread's KinetoThreadLocalState (which must already be installed).
//   - start: takes a fresh correlation id, pushes it to Kineto, and opens a
//     record in the state's record queue, returned as the observer context;
//   - end: stamps the record's end time and end thread id, records the CUDA
//     end event in KINETO_GPU_FALLBACK mode, pops the correlation id, and
//     records thread info.
// Both callbacks re-fetch the thread-local state and do nothing if no Kineto
// state is active. Inputs are captured iff config().report_input_shapes.
// The returned handle is stored on the state so it can be removed later.
void pushProfilingCallbacks(const std::unordered_set<at::RecordScope>& scopes) {
  auto registration_state_ptr = KinetoThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set");
  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          // Start callback. (It does not need the profiler config; the unused
          // `config` local that used to live here has been removed.)
          [](const at::RecordFunction& fn)
              -> std::unique_ptr<at::ObserverContext> {
            auto state_ptr = KinetoThreadLocalState::getTLS();
            if (!state_ptr) {
              return nullptr;
            }
            // The correlation id stays pushed for the duration of the op; the
            // end callback pops it.
            auto corr_id = next_correlation_id();
            torch::profiler::impl::kineto::pushCorrelationId(corr_id);
            return state_ptr->record_queue_.getSubqueue()->begin_op(
                fn, corr_id);
          },
          // End callback: completes the record opened by the start callback.
          [](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) {
            auto state_ptr = KinetoThreadLocalState::getTLS();
            if (!state_ptr) {
              return;
            }
            const auto& config = state_ptr->config();
            auto* kineto_ctx_ptr =
                static_cast<torch::profiler::impl::KinetoObserverContext*>(
                    ctx_ptr);
            TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr);
            kineto_ctx_ptr->event_->end_time_ =
                torch::profiler::impl::getApproximateTime();
            kineto_ctx_ptr->event_->end_thread_id_ =
                at::RecordFunction::currentThreadId();
            if (config.state == ProfilerState::KINETO_GPU_FALLBACK) {
              // Best effort: failing to record the CUDA end event is logged as
              // a warning and otherwise ignored.
              try {
                auto fallback = kineto_ctx_ptr->fallback_;
                TORCH_INTERNAL_ASSERT(fallback != nullptr);
                torch::profiler::impl::cudaStubs()->record(
                    nullptr, &fallback->cuda_event_end_, nullptr);
              } catch (const std::exception& e) {
                LOG(WARNING) << "Failed to record CUDA event. " << e.what();
              }
            }
            torch::profiler::impl::kineto::popCorrelationId();
            torch::profiler::impl::kineto::recordThreadInfo();
          })
          .needsInputs(registration_state_ptr->config().report_input_shapes)
          .scopes(scopes));
  registration_state_ptr->setCallbackHandle(handle);
}
} // namespace
// Records an event reported by a backend into the Kineto profiler active on
// the calling thread. Does nothing when no Kineto profiler state is active.
void reportBackendEventToActiveKinetoProfiler(
    const int64_t start_time_us,
    const int64_t end_time_us,
    const int64_t debug_handle,
    const at::RecordScope scope,
    const std::string& event_name,
    const std::string& backend_name) {
  auto* const active_state = KinetoThreadLocalState::getTLS();
  if (active_state == nullptr) {
    return;
  }
  torch::profiler::impl::BackendEvent backend_event{
      start_time_us,
      end_time_us,
      static_cast<uint8_t>(scope),
      debug_handle,
      event_name,
      backend_name};
  active_state->record_queue_.getSubqueue()->emplace_backend_event(
      std::move(backend_event));
  // TODO: input shapes and dtypes are not reported for backend events yet.
  torch::profiler::impl::kineto::recordThreadInfo();
}
// Prepares Kineto for an upcoming trace. NVTX profiling needs no preparation.
// Any state other than KINETO / KINETO_GPU_FALLBACK is rejected. The trace is
// CPU-only when CUDA is unavailable.
void prepareProfiler(
    const torch::profiler::impl::ProfilerConfig& config,
    const std::set<torch::profiler::impl::ActivityType>& activities) {
  const auto requested = config.state;
  if (requested == ProfilerState::NVTX) {
    return;
  }
  const bool is_kineto_state = requested == ProfilerState::KINETO ||
      requested == ProfilerState::KINETO_GPU_FALLBACK;
  TORCH_CHECK(is_kineto_state, "Supported only in Kineto profiler");
  const bool cpu_only = !at::hasCUDA();
  torch::profiler::impl::kineto::prepareTrace(cpu_only, activities);
}
// Same as enableProfiler(), but additionally installs `cb`, which receives the
// collected KinetoEvents when the trace is finalized. NVTX is not supported.
void enableProfilerWithEventPostProcess(
    const torch::profiler::impl::ProfilerConfig& config,
    const std::set<torch::profiler::impl::ActivityType>& activities,
    std::function<void(std::vector<KinetoEvent>&)>&& cb,
    const std::unordered_set<at::RecordScope>& scopes) {
  const bool is_nvtx = config.state == ProfilerState::NVTX;
  TORCH_CHECK(!is_nvtx, "NVTX does not support post processing callback.");
  enableProfiler(config, activities, scopes);
  KinetoThreadLocalState::getTLS()->setEventPostProcessingCallback(
      std::move(cb));
}
// Starts profiling on the calling thread.
// NVTX mode only installs the NVTX callbacks. Kineto modes do the following:
// push a fresh KinetoThreadLocalState as thread-local debug info; start the
// Python tracer when stack tracing applies; register the RecordFunction
// callbacks when CPU activity is requested; and finally start the Kineto trace.
void enableProfiler(
    const torch::profiler::impl::ProfilerConfig& config,
    const std::set<torch::profiler::impl::ActivityType>& activities,
    const std::unordered_set<at::RecordScope>& scopes) {
  TORCH_CHECK(!profilerEnabled(), "Profiler is already enabled on this thread");

  if (config.state == ProfilerState::NVTX) {
    torch::profiler::impl::pushNVTXCallbacks(config, scopes);
    return;
  }

  // NB: this check has no message, so its condition text is the error text.
  TORCH_CHECK(
      config.state == ProfilerState::KINETO ||
      config.state == ProfilerState::KINETO_GPU_FALLBACK);
  TORCH_CHECK(
      !activities.empty(), "No activities specified for Kineto profiler");

  auto kineto_state =
      std::make_shared<KinetoThreadLocalState>(config, activities);
  c10::ThreadLocalDebugInfo::_push(
      c10::DebugInfoKind::PROFILER_STATE, kineto_state);

  const bool trace_python = kineto_state->tracePython();
  if (trace_python) {
    python_tracer::call(python_tracer::Command::kStartOne);
  }

  const bool profile_cpu = activities.count(ActivityType::CPU) != 0;
  if (profile_cpu) {
    pushProfilingCallbacks(scopes);
  }

  torch::profiler::impl::kineto::startTrace();
}
// Stops profiling on the calling thread and returns what was collected.
// The profiler state is popped from the thread-local debug info, and its
// RecordFunction callback (if any) is removed.
//  - NVTX: returns an empty ProfilerResult.
//  - Kineto: sends kStop to the Python tracer (when Python tracing is on),
//    finalizes the trace, sends kClear, and returns the session's start time,
//    events and trace.
// Throws (TORCH_CHECK) if no Kineto/NVTX profiler state was active.
std::unique_ptr<ProfilerResult> disableProfiler() {
  // all the DebugInfoBase objects are scope based and supposed to use
  // DebugInfoGuard
  auto state =
      c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
  auto state_ptr = static_cast<ProfilerThreadLocalStateBase*>(state.get());
  // Validate the pointer before calling config(): dereferencing a null
  // state_ptr is undefined behavior and would bypass this very check.
  TORCH_CHECK(
      state_ptr != nullptr,
      "Can't disable Kineto profiler when it's not running");
  const auto& config = state_ptr->config();
  TORCH_CHECK(
      config.state == ProfilerState::KINETO ||
          config.state == ProfilerState::KINETO_GPU_FALLBACK ||
          config.state == ProfilerState::NVTX,
      "Can't disable Kineto profiler when it's not running");

  if (state_ptr->hasCallbackHandle()) {
    at::removeCallback(state_ptr->callbackHandle());
  }

  if (config.state == ProfilerState::NVTX) {
    return std::make_unique<ProfilerResult>();
  }

  auto kineto_state_ptr = static_cast<KinetoThreadLocalState*>(state_ptr);
  if (kineto_state_ptr->tracePython()) {
    python_tracer::call(python_tracer::Command::kStop);
  }

  auto trace = kineto_state_ptr->finalizeTrace();
  if (kineto_state_ptr->tracePython()) {
    python_tracer::call(python_tracer::Command::kClear);
  }

  return std::make_unique<ProfilerResult>(
      kineto_state_ptr->start_time_,
      std::move(kineto_state_ptr->kineto_events_),
      std::move(trace));
}
// Elapsed time between this op's recorded CUDA start and end events
// (microseconds, per the name). These events are populated in
// KINETO_GPU_FALLBACK mode.
// Returns -1 when either event is missing, or when the measurement throws;
// in the latter case a warning is logged.
int64_t KinetoEvent::cudaElapsedUs() const {
  if (!cuda_event_start_ || !cuda_event_end_) {
    return -1;
  }
  try {
    return static_cast<int64_t>(torch::profiler::impl::cudaStubs()->elapsed(
        &cuda_event_start_, &cuda_event_end_));
  } catch (const std::exception& e) {
    LOG(WARNING) << "Failed to measure time between two CUDA events. "
                 << e.what();
  }
  return -1;
}
// Result of a profiling session. It holds the trace start time, the
// collected KinetoEvents, and the underlying Kineto activity trace.
ProfilerResult::ProfilerResult(
    uint64_t trace_start_us,
    std::vector<KinetoEvent> collected_events,
    torch::profiler::impl::kineto::ActivityTraceWrapper activity_trace)
    : trace_start_us_(trace_start_us),
      events_(std::move(collected_events)),
      trace_(std::move(activity_trace)) {}

// Empty result; disableProfiler() returns one for NVTX sessions.
ProfilerResult::ProfilerResult() = default;

ProfilerResult::~ProfilerResult() = default;

// Writes the underlying trace to `path` via the trace wrapper's save().
void ProfilerResult::save(const std::string& path) {
  trace_.save(path);
}
} // namespace profiler
} // namespace autograd
} // namespace torch