Skip to content

Commit 8f72cbc

Browse files
Add filter functionality to TraceMeRecorder to filter events based on filter parameter.
PiperOrigin-RevId: 696613849
1 parent b807244 commit 8f72cbc

File tree

1 file changed

+44
-18
lines changed

1 file changed

+44
-18
lines changed

tsl/profiler/lib/traceme.h

+44-18
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ limitations under the License.
1515
#ifndef TENSORFLOW_TSL_PROFILER_LIB_TRACEME_H_
1616
#define TENSORFLOW_TSL_PROFILER_LIB_TRACEME_H_
1717

18+
#include <sys/types.h>
19+
1820
#include <cstdint>
21+
#include <limits>
1922
#include <string>
2023
#include <type_traits>
2124
#include <utility>
@@ -34,6 +37,9 @@ limitations under the License.
3437
namespace tsl {
3538
namespace profiler {
3639

40+
constexpr uint64_t kTraceMeDefaultFilterMask =
41+
std::numeric_limits<uint64_t>::max();
42+
3743
// Predefined levels:
3844
// - Level 1 (kCritical) is the default and used only for user instrumentation.
3945
// - Level 2 (kInfo) is used by profiler for instrumenting high level program
@@ -88,10 +94,12 @@ class TraceMe {
8894
// - Can be a value in enum TraceMeLevel.
8995
// Users are welcome to use level > 3 in their code, if they wish to filter
9096
// out their host traces based on verbosity.
91-
explicit TraceMe(absl::string_view name, int level = 1) {
97+
explicit TraceMe(absl::string_view name, int level = 1,
98+
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
9299
DCHECK_GE(level, 1);
93100
#if !defined(IS_MOBILE_PLATFORM)
94-
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
101+
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
102+
TraceMeRecorder::CheckFilter(filter_mask))) {
95103
name_.Emplace(std::string(name));
96104
start_time_ = GetCurrentTimeNanos();
97105
}
@@ -102,19 +110,22 @@ class TraceMe {
102110
// string should only be incurred when tracing is enabled. Wrap the temporary
103111
// string generation (e.g., StrCat) in a lambda and use the name_generator
104112
// template instead.
105-
explicit TraceMe(std::string&& name, int level = 1) = delete;
113+
explicit TraceMe(std::string&& name, int level = 1,
114+
uint64_t filter_mask = kTraceMeDefaultFilterMask) = delete;
106115

107116
// Do not allow passing strings by reference or value since the caller
108117
// may unintentionally maintain ownership of the name.
109118
// Explicitly wrap the name in a string_view if you really wish to maintain
110119
// ownership of a string already generated for other purposes. For temporary
111120
// strings (e.g., result of StrCat) use the name_generator template.
112-
explicit TraceMe(const std::string& name, int level = 1) = delete;
121+
explicit TraceMe(const std::string& name, int level = 1,
122+
uint64_t filter_mask = kTraceMeDefaultFilterMask) = delete;
113123

114124
// This overload is necessary to make TraceMe's with string literals work.
115125
// Otherwise, the name_generator template would be used.
116-
explicit TraceMe(const char* raw, int level = 1)
117-
: TraceMe(absl::string_view(raw), level) {}
126+
explicit TraceMe(const char* raw, int level = 1,
127+
uint64_t filter_mask = kTraceMeDefaultFilterMask)
128+
: TraceMe(absl::string_view(raw), level, filter_mask) {}
118129

119130
// This overload only generates the name (and possibly metadata) if tracing is
120131
// enabled. Useful for avoiding expensive operations (e.g., string
@@ -135,10 +146,12 @@ class TraceMe {
135146
// });
136147
template <typename NameGeneratorT,
137148
std::enable_if_t<std::is_invocable_v<NameGeneratorT>, bool> = true>
138-
explicit TraceMe(NameGeneratorT&& name_generator, int level = 1) {
149+
explicit TraceMe(NameGeneratorT&& name_generator, int level = 1,
150+
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
139151
DCHECK_GE(level, 1);
140152
#if !defined(IS_MOBILE_PLATFORM)
141-
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
153+
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
154+
TraceMeRecorder::CheckFilter(filter_mask))) {
142155
name_.Emplace(std::forward<NameGeneratorT>(name_generator)());
143156
start_time_ = GetCurrentTimeNanos();
144157
}
@@ -215,9 +228,12 @@ class TraceMe {
215228
// Calls `name_generator` to get the name for activity.
216229
template <typename NameGeneratorT,
217230
std::enable_if_t<std::is_invocable_v<NameGeneratorT>, bool> = true>
218-
static int64_t ActivityStart(NameGeneratorT&& name_generator, int level = 1) {
231+
static int64_t ActivityStart(
232+
NameGeneratorT&& name_generator, int level = 1,
233+
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
219234
#if !defined(IS_MOBILE_PLATFORM)
220-
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
235+
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
236+
TraceMeRecorder::CheckFilter(filter_mask))) {
221237
int64_t activity_id = TraceMeRecorder::NewActivityId();
222238
TraceMeRecorder::Record({std::forward<NameGeneratorT>(name_generator)(),
223239
GetCurrentTimeNanos(), -activity_id});
@@ -229,9 +245,12 @@ class TraceMe {
229245

230246
// Record the start time of an activity.
231247
// Returns the activity ID, which is used to stop the activity.
232-
static int64_t ActivityStart(absl::string_view name, int level = 1) {
248+
static int64_t ActivityStart(
249+
absl::string_view name, int level = 1,
250+
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
233251
#if !defined(IS_MOBILE_PLATFORM)
234-
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
252+
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
253+
TraceMeRecorder::CheckFilter(filter_mask))) {
235254
int64_t activity_id = TraceMeRecorder::NewActivityId();
236255
TraceMeRecorder::Record(
237256
{std::string(name), GetCurrentTimeNanos(), -activity_id});
@@ -242,13 +261,17 @@ class TraceMe {
242261
}
243262

244263
// Same as ActivityStart above, an overload for "const std::string&"
245-
static int64_t ActivityStart(const std::string& name, int level = 1) {
246-
return ActivityStart(absl::string_view(name), level);
264+
static int64_t ActivityStart(
265+
const std::string& name, int level = 1,
266+
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
267+
return ActivityStart(absl::string_view(name), level, filter_mask);
247268
}
248269

249270
// Same as ActivityStart above, an overload for "const char*"
250-
static int64_t ActivityStart(const char* name, int level = 1) {
251-
return ActivityStart(absl::string_view(name), level);
271+
static int64_t ActivityStart(
272+
const char* name, int level = 1,
273+
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
274+
return ActivityStart(absl::string_view(name), level, filter_mask);
252275
}
253276

254277
// Record the end time of an activity started by ActivityStart().
@@ -267,9 +290,12 @@ class TraceMe {
267290
// Records the time of an instant activity.
268291
template <typename NameGeneratorT,
269292
std::enable_if_t<std::is_invocable_v<NameGeneratorT>, bool> = true>
270-
static void InstantActivity(NameGeneratorT&& name_generator, int level = 1) {
293+
static void InstantActivity(
294+
NameGeneratorT&& name_generator, int level = 1,
295+
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
271296
#if !defined(IS_MOBILE_PLATFORM)
272-
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
297+
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
298+
TraceMeRecorder::CheckFilter(filter_mask))) {
273299
int64_t now = GetCurrentTimeNanos();
274300
TraceMeRecorder::Record({std::forward<NameGeneratorT>(name_generator)(),
275301
/*start_time=*/now, /*end_time=*/now});

0 commit comments

Comments
 (0)