@@ -15,7 +15,10 @@ limitations under the License.
15
15
#ifndef TENSORFLOW_TSL_PROFILER_LIB_TRACEME_H_
16
16
#define TENSORFLOW_TSL_PROFILER_LIB_TRACEME_H_
17
17
18
+ #include < sys/types.h>
19
+
18
20
#include < cstdint>
21
+ #include < limits>
19
22
#include < string>
20
23
#include < type_traits>
21
24
#include < utility>
@@ -34,6 +37,9 @@ limitations under the License.
34
37
namespace tsl {
35
38
namespace profiler {
36
39
40
+ constexpr uint64_t kTraceMeDefaultFilterMask =
41
+ std::numeric_limits<uint64_t >::max();
42
+
37
43
// Predefined levels:
38
44
// - Level 1 (kCritical) is the default and used only for user instrumentation.
39
45
// - Level 2 (kInfo) is used by profiler for instrumenting high level program
@@ -88,10 +94,12 @@ class TraceMe {
88
94
// - Can be a value in enum TraceMeLevel.
89
95
// Users are welcome to use level > 3 in their code, if they wish to filter
90
96
// out their host traces based on verbosity.
91
- explicit TraceMe (absl::string_view name, int level = 1 ) {
97
+ explicit TraceMe (absl::string_view name, int level = 1 ,
98
+ uint64_t filter_mask = kTraceMeDefaultFilterMask ) {
92
99
DCHECK_GE (level, 1 );
93
100
#if !defined(IS_MOBILE_PLATFORM)
94
- if (TF_PREDICT_FALSE (TraceMeRecorder::Active (level))) {
101
+ if (TF_PREDICT_FALSE (TraceMeRecorder::Active (level) &&
102
+ TraceMeRecorder::CheckFilter (filter_mask))) {
95
103
name_.Emplace (std::string (name));
96
104
start_time_ = GetCurrentTimeNanos ();
97
105
}
@@ -102,19 +110,22 @@ class TraceMe {
102
110
// string should only be incurred when tracing is enabled. Wrap the temporary
103
111
// string generation (e.g., StrCat) in a lambda and use the name_generator
104
112
// template instead.
105
- explicit TraceMe (std::string&& name, int level = 1 ) = delete;
113
+ explicit TraceMe (std::string&& name, int level = 1 ,
114
+ uint64_t filter_mask = kTraceMeDefaultFilterMask ) = delete;
106
115
107
116
// Do not allow passing strings by reference or value since the caller
108
117
// may unintentionally maintain ownership of the name.
109
118
// Explicitly wrap the name in a string_view if you really wish to maintain
110
119
// ownership of a string already generated for other purposes. For temporary
111
120
// strings (e.g., result of StrCat) use the name_generator template.
112
- explicit TraceMe (const std::string& name, int level = 1 ) = delete;
121
+ explicit TraceMe (const std::string& name, int level = 1 ,
122
+ uint64_t filter_mask = kTraceMeDefaultFilterMask ) = delete;
113
123
114
124
// This overload is necessary to make TraceMe's with string literals work.
115
125
// Otherwise, the name_generator template would be used.
116
- explicit TraceMe (const char * raw, int level = 1 )
117
- : TraceMe(absl::string_view(raw), level) {}
126
+ explicit TraceMe (const char * raw, int level = 1 ,
127
+ uint64_t filter_mask = kTraceMeDefaultFilterMask )
128
+ : TraceMe(absl::string_view(raw), level, filter_mask) {}
118
129
119
130
// This overload only generates the name (and possibly metadata) if tracing is
120
131
// enabled. Useful for avoiding expensive operations (e.g., string
@@ -135,10 +146,12 @@ class TraceMe {
135
146
// });
136
147
template <typename NameGeneratorT,
137
148
std::enable_if_t <std::is_invocable_v<NameGeneratorT>, bool > = true >
138
- explicit TraceMe (NameGeneratorT&& name_generator, int level = 1 ) {
149
+ explicit TraceMe (NameGeneratorT&& name_generator, int level = 1 ,
150
+ uint64_t filter_mask = kTraceMeDefaultFilterMask ) {
139
151
DCHECK_GE (level, 1 );
140
152
#if !defined(IS_MOBILE_PLATFORM)
141
- if (TF_PREDICT_FALSE (TraceMeRecorder::Active (level))) {
153
+ if (TF_PREDICT_FALSE (TraceMeRecorder::Active (level) &&
154
+ TraceMeRecorder::CheckFilter (filter_mask))) {
142
155
name_.Emplace (std::forward<NameGeneratorT>(name_generator)());
143
156
start_time_ = GetCurrentTimeNanos ();
144
157
}
@@ -215,9 +228,12 @@ class TraceMe {
215
228
// Calls `name_generator` to get the name for activity.
216
229
template <typename NameGeneratorT,
217
230
std::enable_if_t <std::is_invocable_v<NameGeneratorT>, bool > = true >
218
- static int64_t ActivityStart (NameGeneratorT&& name_generator, int level = 1 ) {
231
+ static int64_t ActivityStart (
232
+ NameGeneratorT&& name_generator, int level = 1 ,
233
+ uint64_t filter_mask = kTraceMeDefaultFilterMask ) {
219
234
#if !defined(IS_MOBILE_PLATFORM)
220
- if (TF_PREDICT_FALSE (TraceMeRecorder::Active (level))) {
235
+ if (TF_PREDICT_FALSE (TraceMeRecorder::Active (level) &&
236
+ TraceMeRecorder::CheckFilter (filter_mask))) {
221
237
int64_t activity_id = TraceMeRecorder::NewActivityId ();
222
238
TraceMeRecorder::Record ({std::forward<NameGeneratorT>(name_generator)(),
223
239
GetCurrentTimeNanos (), -activity_id});
@@ -229,9 +245,12 @@ class TraceMe {
229
245
230
246
// Record the start time of an activity.
231
247
// Returns the activity ID, which is used to stop the activity.
232
- static int64_t ActivityStart (absl::string_view name, int level = 1 ) {
248
+ static int64_t ActivityStart (
249
+ absl::string_view name, int level = 1 ,
250
+ uint64_t filter_mask = kTraceMeDefaultFilterMask ) {
233
251
#if !defined(IS_MOBILE_PLATFORM)
234
- if (TF_PREDICT_FALSE (TraceMeRecorder::Active (level))) {
252
+ if (TF_PREDICT_FALSE (TraceMeRecorder::Active (level) &&
253
+ TraceMeRecorder::CheckFilter (filter_mask))) {
235
254
int64_t activity_id = TraceMeRecorder::NewActivityId ();
236
255
TraceMeRecorder::Record (
237
256
{std::string (name), GetCurrentTimeNanos (), -activity_id});
@@ -242,13 +261,17 @@ class TraceMe {
242
261
}
243
262
244
263
// Same as ActivityStart above, an overload for "const std::string&"
245
- static int64_t ActivityStart (const std::string& name, int level = 1 ) {
246
- return ActivityStart (absl::string_view (name), level);
264
+ static int64_t ActivityStart (
265
+ const std::string& name, int level = 1 ,
266
+ uint64_t filter_mask = kTraceMeDefaultFilterMask ) {
267
+ return ActivityStart (absl::string_view (name), level, filter_mask);
247
268
}
248
269
249
270
// Same as ActivityStart above, an overload for "const char*"
250
- static int64_t ActivityStart (const char * name, int level = 1 ) {
251
- return ActivityStart (absl::string_view (name), level);
271
+ static int64_t ActivityStart (
272
+ const char * name, int level = 1 ,
273
+ uint64_t filter_mask = kTraceMeDefaultFilterMask ) {
274
+ return ActivityStart (absl::string_view (name), level, filter_mask);
252
275
}
253
276
254
277
// Record the end time of an activity started by ActivityStart().
@@ -267,9 +290,12 @@ class TraceMe {
267
290
// Records the time of an instant activity.
268
291
template <typename NameGeneratorT,
269
292
std::enable_if_t <std::is_invocable_v<NameGeneratorT>, bool > = true >
270
- static void InstantActivity (NameGeneratorT&& name_generator, int level = 1 ) {
293
+ static void InstantActivity (
294
+ NameGeneratorT&& name_generator, int level = 1 ,
295
+ uint64_t filter_mask = kTraceMeDefaultFilterMask ) {
271
296
#if !defined(IS_MOBILE_PLATFORM)
272
- if (TF_PREDICT_FALSE (TraceMeRecorder::Active (level))) {
297
+ if (TF_PREDICT_FALSE (TraceMeRecorder::Active (level) &&
298
+ TraceMeRecorder::CheckFilter (filter_mask))) {
273
299
int64_t now = GetCurrentTimeNanos ();
274
300
TraceMeRecorder::Record ({std::forward<NameGeneratorT>(name_generator)(),
275
301
/* start_time=*/ now, /* end_time=*/ now});
0 commit comments