Skip to content

Commit

Permalink
Fix collective-permute host memory not being unregistered.
Browse files Browse the repository at this point in the history
CUDA host memory was registered in Initialize() and unregistered in Cleanup(), but Cleanup() is not called. Now the host memory is instead stored as a stream_executor::MemoryAllocation object, which automatically unregisters it in its destructor.

PiperOrigin-RevId: 723243458
  • Loading branch information
reedwm authored and tensorflower-gardener committed Feb 4, 2025
1 parent 482e848 commit 91ae966
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 19 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/backends/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,7 @@ cc_library(
"//xla/service:global_device_id",
"//xla/service/gpu:backend_configs_cc",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:memory_allocation",
"//xla/stream_executor:stream",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:errors",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/nccl_collective_permute_thunk.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
Expand Down Expand Up @@ -43,6 +44,7 @@ limitations under the License.
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/stream.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
Expand Down Expand Up @@ -176,10 +178,10 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(
GetCurrentId(params.collective_params, config_));
absl::MutexLock lock(&barrier_mutex_);
if (barrier_flags_.find(current_id) == barrier_flags_.end()) {
if (!params.stream->parent()->HostMemoryRegister(
&barrier_flags_[current_id], sizeof(uint8_t))) {
LOG(ERROR) << "Registering barrier flag failed.";
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<se::MemoryAllocation> alloc,
params.stream->parent()->HostMemoryAllocate(sizeof(uint8_t)));
barrier_flags_[current_id] = std::move(alloc);
}

TF_ASSIGN_OR_RETURN(
Expand Down Expand Up @@ -212,18 +214,6 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(
return absl::OkStatus();
}

absl::Status NcclCollectivePermuteStartThunk::Cleanup(
const CleanupParams& params) {
TF_ASSIGN_OR_RETURN(const int64_t current_id,
GetCurrentId(params.collective_params, config_));

absl::MutexLock lock(&barrier_mutex_);
if (!params.executor->HostMemoryUnregister(&barrier_flags_[current_id])) {
LOG(ERROR) << "Unregistering barrier flag failed.";
}
return absl::OkStatus();
}

absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
const ExecuteParams& params, se::Stream& stream,
CommunicatorHandle comm_handle) {
Expand All @@ -248,7 +238,7 @@ absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
if (use_memcpy) {
se::DeviceMemoryBase sync_var_address =
se::DeviceMemoryBase((void*)(&barrier_flags_[current_id]));
se::DeviceMemoryBase(barrier_flags_[current_id]->opaque());
TF_RETURN_IF_ERROR(comm_handle.comm->AllReduce(
sync_var_address, sync_var_address, PrimitiveType::U8, 1,
ReductionKind::MIN, GpuCollectives::On(stream)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
#define XLA_BACKENDS_GPU_RUNTIME_NCCL_COLLECTIVE_PERMUTE_THUNK_H_

#include <cstdint>
#include <memory>
#include <unordered_map>

#include "absl/base/thread_annotations.h"
#include "absl/container/node_hash_map.h"
Expand All @@ -31,6 +33,7 @@ limitations under the License.
#include "xla/core/collectives/communicator.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/stream.h"
#include "xla/tsl/concurrency/async_value.h"
#include "xla/tsl/concurrency/async_value_ref.h"
Expand Down Expand Up @@ -104,7 +107,6 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
const std::vector<Buffer>& buffers,
bool p2p_memcpy_enabled);
absl::Status Initialize(const InitializeParams& params) override;
absl::Status Cleanup(const CleanupParams& params) override;

static const char* GetHloOpName() { return "collective-permute-start"; }

Expand All @@ -119,7 +121,8 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
std::vector<Buffer> buffers_;
RecvPtrMap recv_ptr_map_;
absl::Mutex barrier_mutex_;
std::unordered_map<int64_t, uint8_t> barrier_flags_;
std::unordered_map<int64_t, std::unique_ptr<se::MemoryAllocation>>
barrier_flags_;
bool p2p_memcpy_enabled_ = false;
int64_t device_count_;
};
Expand Down

0 comments on commit 91ae966

Please sign in to comment.