Skip to content

Commit 3584da2

Browse files
jakeharmon8copybara-github
authored andcommitted
Create tf_vendored and load TSL with it
See openxla@9517b9b PiperOrigin-RevId: 508395488
1 parent f87dc12 commit 3584da2

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

third_party/repo.bzl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,22 @@ def tf_http_archive(name, sha256, urls, **kwargs):
139139
urls = urls,
140140
**kwargs
141141
)
142+
143+
def _tf_vendored_impl(repository_ctx):
144+
parent_path = repository_ctx.path(repository_ctx.attr.parent).dirname
145+
146+
# get_child doesn't allow slashes. Yes this is silly. bazel_skylib paths
147+
# doesn't work with path objects.
148+
relpath_parts = repository_ctx.attr.relpath.split("/")
149+
vendored_path = parent_path
150+
for part in relpath_parts:
151+
vendored_path = vendored_path.get_child(part)
152+
repository_ctx.symlink(vendored_path, ".")
153+
154+
tf_vendored = repository_rule(
155+
implementation = _tf_vendored_impl,
156+
attrs = {
157+
"parent": attr.label(default = "//:WORKSPACE"),
158+
"relpath": attr.string(),
159+
},
160+
)

workspace2.bzl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
# Import third party config rules.
44
load("@bazel_skylib//lib:versions.bzl", "versions")
5-
load("@bazel_skylib//lib:paths.bzl", "paths")
65
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
76
load("//third_party/gpus:rocm_configure.bzl", "rocm_configure")
87
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
@@ -13,7 +12,7 @@ load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure")
1312
load("//tools/toolchains:cpus/aarch64/aarch64_compiler_configure.bzl", "aarch64_compiler_configure")
1413
load("//tools/toolchains:cpus/arm/arm_compiler_configure.bzl", "arm_compiler_configure")
1514
load("//tools/toolchains/embedded/arm-linux:arm_linux_toolchain_configure.bzl", "arm_linux_toolchain_configure")
16-
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
15+
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls", "tf_vendored")
1716
load("//third_party/clang_toolchain:cc_configure_clang.bzl", "cc_download_clang_toolchain")
1817
load("//third_party/llvm:setup.bzl", "llvm_setup")
1918

@@ -39,7 +38,7 @@ load("//tools/toolchains/remote_config:configs.bzl", "initialize_rbe_configs")
3938
load("//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
4039
load("//tools/toolchains/clang6:repo.bzl", "clang6_configure")
4140

42-
def _initialize_third_party(xla_path):
41+
def _initialize_third_party():
4342
""" Load third party repositories. See above load() statements. """
4443
absl()
4544
benchmark()
@@ -54,7 +53,7 @@ def _initialize_third_party(xla_path):
5453
tensorrt()
5554
triton()
5655

57-
native.local_repository(name = "tsl", path = paths.join(xla_path + "third_party/tsl"))
56+
tf_vendored(name = "tsl", relpath = "third_party/tsl")
5857

5958
# Toolchains & platforms required by Tensorflow to build.
6059
def _tf_toolchains():
@@ -605,7 +604,7 @@ def _tf_repositories():
605604

606605
# buildifier: disable=function-docstring
607606
# buildifier: disable=unnamed-macro
608-
def workspace(xla_path = "./"):
607+
def workspace():
609608
# Check the bazel version before executing any repository rules, in case
610609
# those rules rely on the version we require here.
611610
versions.check("1.0.0")
@@ -614,7 +613,7 @@ def workspace(xla_path = "./"):
614613
_tf_toolchains()
615614

616615
# Import third party repositories according to go/tfbr-thirdparty.
617-
_initialize_third_party(xla_path)
616+
_initialize_third_party()
618617

619618
# Import all other repositories. This should happen before initializing
620619
# any external repositories, because those come with their own

0 commit comments

Comments
 (0)