Skip to content

Commit

Permalink
Remove CUDA dependencies from jaxlib OSS wheel.
Browse files Browse the repository at this point in the history
With this change `jaxlib` wheel content is identical for CPU and GPU configurations. It enables reusing bazel cache when building all three targets together with `--config=cuda`: `build_wheel`, `build_gpu_plugin_wheel` and `build_gpu_kernels_wheel`.

PiperOrigin-RevId: 706016685
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Dec 27, 2024
1 parent 7d4367e commit 27bd5c2
Showing 1 changed file with 60 additions and 9 deletions.
69 changes: 60 additions & 9 deletions tsl/profiler/lib/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("@xla//xla/tsl:tsl.bzl", "if_not_android", "if_oss", "internal_visibility", "nvtx_headers")
load("@xla//xla/tsl:tsl.bzl", "if_google", "if_not_android", "if_oss", "internal_visibility", "nvtx_headers")
load("@xla//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable")
load("@xla//xla/tsl/platform:build_config.bzl", "tsl_cc_test")
load("@xla//xla/tsl/platform:build_config_root.bzl", "if_static")
Expand Down Expand Up @@ -288,6 +288,21 @@ cc_library(
deps = if_cuda_is_configured(nvtx_headers()),
)

cc_library(
name = "nvtx_utils_cpu_impl",
srcs = if_cuda_is_configured(
if_google(
["nvtx_utils.cc"],
["nvtx_utils_stub.cc"],
),
["nvtx_utils_stub.cc"],
),
hdrs = ["nvtx_utils.h"],
local_defines = if_oss(["NVTX_VERSION_3_1=1"]),
visibility = ["//visibility:public"],
deps = if_cuda_is_configured(if_google(nvtx_headers())),
)

cc_library(
name = "nvtx_utils_libtpu",
srcs = ["nvtx_utils_stub.cc"],
Expand All @@ -312,21 +327,42 @@ cc_library(
}),
)

cc_library(
name = "nvtx_utils_cpu",
hdrs = ["nvtx_utils.h"],
visibility = ["//visibility:public"],
deps = [":nvtx_utils_cpu_impl"],
)

SCOPED_ANNOTATION_DEPS = [
"@com_google_absl//absl/strings",
"//tsl/platform",
"//tsl/platform:macros",
"//tsl/platform:types",
] + if_not_android([
"@xla//xla/tsl/profiler/backends/cpu:annotation_stack",
])

cc_library(
name = "scoped_annotation",
hdrs = [
"scoped_annotation.h",
],
visibility = ["//visibility:public"],
deps = [
deps = SCOPED_ANNOTATION_DEPS + [
":nvtx_utils",
"//tsl/platform",
"//tsl/platform:macros",
"//tsl/platform:types",
"@com_google_absl//absl/strings",
] + if_not_android([
"@xla//xla/tsl/profiler/backends/cpu:annotation_stack",
]),
],
)

cc_library(
name = "scoped_annotation_cpu",
hdrs = [
"scoped_annotation.h",
],
visibility = ["//visibility:public"],
deps = SCOPED_ANNOTATION_DEPS + [
":nvtx_utils_cpu",
],
)

tsl_cc_test(
Expand All @@ -344,6 +380,21 @@ tsl_cc_test(
],
)

tsl_cc_test(
name = "scoped_annotation_cpu_test",
size = "small",
srcs = ["scoped_annotation_test.cc"],
deps = [
":scoped_annotation_cpu",
"//tsl/platform:test",
"//tsl/platform:test_benchmark",
"//tsl/platform:test_main",
"@com_google_absl//absl/strings",
"@xla//xla/tsl/profiler/backends/cpu:annotation_stack",
"@xla//xla/tsl/profiler/backends/cpu:annotation_stack_impl",
],
)

cc_library(
name = "connected_traceme",
hdrs = ["connected_traceme.h"],
Expand Down

0 comments on commit 27bd5c2

Please sign in to comment.