Skip to content

Commit

Permalink
Merge pull request tensorflow#60887 from tensorflow/r2.13-7a0fa5541fd
Browse files Browse the repository at this point in the history
r2.13 cherry-pick: 7a0fa55 "* Fix memory free error in tpu_execute.cc when zero addresses are returned by TpuEmbedding. * Add VLOGs to tpu_execute.cc to improve debugging."
  • Loading branch information
learning-to-play authored Jun 15, 2023
2 parents 035d7b1 + a949616 commit 8c602e6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
2 changes: 2 additions & 0 deletions tensorflow/core/tpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ cc_library(
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/log",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
)

Expand Down
26 changes: 19 additions & 7 deletions tensorflow/core/tpu/tpu_execute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ limitations under the License.
#include "absl/cleanup/cleanup.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h"
Expand Down Expand Up @@ -462,9 +464,6 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
host_transfer_manager->Initialize(
host_transfers, rendezvous_key_base, ctx));

VLOG(2) << "Cloud TPU: Executing computation on device "
<< node_context->device_ordinal();

xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_device_assignment(device_assignment);
Expand Down Expand Up @@ -504,17 +503,25 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
prefetch.offset());
}

VLOG(1) << "TPUExecute: Updating dynamic HLO inputs on "
<< node_context->device_ordinal();

TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, backend->memory_allocator(),
&arguments, input_shapes));

// Retrieve the TPU embedding memory addresses to be fed to the TPU. The
// memory addresses are communicated with a dynamically allocated C array
// (which needs to be freed once the function terminates).
SE_DeviceMemoryBase* device_memory_addrs;
VLOG(1) << "TPUExecute: Updating TPUEmbedding memory addresses on "
<< node_context->device_ordinal();

SE_DeviceMemoryBase* device_memory_addrs = nullptr;
size_t device_memory_addrs_count;
auto device_memory_cleanup = absl::MakeCleanup([&device_memory_addrs]() {
stream_executor::tpu::OpsApiFn()->SE_DeviceMemoryBase_FreeArrayFn(
device_memory_addrs);
if (device_memory_addrs != nullptr) {
stream_executor::tpu::OpsApiFn()->SE_DeviceMemoryBase_FreeArrayFn(
device_memory_addrs);
}
});

SE_StreamExecutor executor{stream->parent()};
Expand All @@ -529,10 +536,15 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(

// Add the TPU embedding memory addresses as additional arguments for the TPU
// executable.
VLOG(1) << "TPUExecute: Adding " << device_memory_addrs_count
<< " TPUEmbedding memory addresses to HLO parameters.";
for (int i = 0; i < device_memory_addrs_count; ++i) {
xla::ShapeTree<xla::MaybeOwningDeviceMemory> tree(
xla::ShapeUtil::MakeOpaqueShape());
*tree.mutable_element({}) = ApiConverter::FromC(device_memory_addrs[i]);
const SE_DeviceMemoryBase& addr = device_memory_addrs[i];
VLOG(2) << absl::StrFormat("Device memory addr[%i] = {%p, %llu, %llu}", i,
addr.opaque, addr.size, addr.payload);
*tree.mutable_element({}) = ApiConverter::FromC(addr);
xla::ExecutionInput input(std::move(tree));
arguments.push_back(std::move(input));
}
Expand Down

0 comments on commit 8c602e6

Please sign in to comment.