Skip to content

Commit

Permalink
Add comments and use previous_pool_id var
Browse files Browse the repository at this point in the history
  • Loading branch information
buptzyb committed Jun 26, 2024
1 parent dabdccc commit 85db0c6
Showing 1 changed file with 12 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,13 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
static auto* all_ids_ = new std::vector<tsl::PlatformDeviceId>();
if (!create_new_pool_) {
DCHECK(all_pools_->size() == all_ids_->size());

// If pool_ is found in all_pools_, it has already been initialized. This
// can happen, for example, when multiple virtual devices are created from
// one physical GPU: those virtual devices actually share the same CUDA
// memory pool. In that case, the pool initialization steps below should be
// skipped to avoid initializing the same pool twice.
for (auto& pool_item_ : *all_pools_) {
if (*pool_item_ == pool_) {
VLOG(2) << Name()
Expand Down Expand Up @@ -282,11 +289,11 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
// Set the previous pool's access to the current GPU.
map.location.id = platform_device_id.value();

VLOG(2) << "Set access to the pool id: " << (*all_ids_)[i].value()
int previous_pool_id = (*all_ids_)[i].value();
VLOG(2) << "Set access to the pool id: " << previous_pool_id
<< " location id: " << map.location.id;
if (auto status =
cuDeviceCanAccessPeer(&canAccessPeer, (*all_ids_)[i].value(),
platform_device_id.value())) {
if (auto status = cuDeviceCanAccessPeer(&canAccessPeer, previous_pool_id,
platform_device_id.value())) {
pool_ = nullptr;
LOG(FATAL) // Crash OK.
<< "cuDeviceCanAccessPeer failed: " << GetCudaErrorMessage(status);
Expand All @@ -296,7 +303,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
pool_ = nullptr;
LOG(FATAL) // Crash OK.
<< "Error when setting access to the pool id: "
<< (*all_ids_)[i].value() << " location id: " << map.location.id
<< previous_pool_id << " location id: " << map.location.id
<< " error: " << GetCudaErrorMessage(status);
}
}
Expand Down

0 comments on commit 85db0c6

Please sign in to comment.