Skip to content

Commit

Permalink
Add stream_id knobs to stream_executors and devices
Browse files Browse the repository at this point in the history
  • Loading branch information
buptzyb committed Jul 31, 2024
1 parent 8d0988e commit 9d0745b
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 26 deletions.
4 changes: 4 additions & 0 deletions tensorflow/compiler/jit/device_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,10 @@ Status DeviceCompiler<ExecutableType, ClientType>::CompileImpl(

if (state == DeviceCompileState::kUncompiled) {
XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable");
if (options.stream_id > 0) {
VLOG(2) << "Not compiling for stream group " << options.stream_id;
return absl::OkStatus();
}
if (!profiler->ShouldCompileCluster(function, compile_mode,
current_request_count)) {
VLOG(2) << "Not compiling for signature: " << human_signature;
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/jit/xla_compiler_options_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ XlaCompiler::Options GenerateCompilerOptions(
options.client = static_cast<xla::LocalClient*>(xla_device_compiler.client());
if (stream != nullptr) {
options.device_ordinal = stream->parent()->device_ordinal();
options.stream_id = stream->parent()->stream_id();
}
options.device_type = xla_device_compiler.device_type();
options.flib_def = function_library.GetFunctionLibraryDefinition();
Expand Down Expand Up @@ -112,6 +113,7 @@ XlaCompiler::Options GenerateCompilerOptionsForPjRt(
} else {
options.device_ordinal = device_base->parsed_name().id;
}
options.stream_id = device_base->GetStreamId();
options.flib_def = function_library_def;
options.graph_def_version = graph_def_version;
if (const auto* metadata = platform_info.xla_device_metadata();
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/compiler/tf2xla/xla_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ class XlaCompiler {
// -1 indicates the default device should be used.
int device_ordinal = -1;

// The stream group on which instructions are executed during compilation.
// Compilation is currently only supported for stream group 0.
int stream_id = 0;

xla::Client* client = nullptr;

// Function library in which to find function definitions. Must be non-null.
Expand Down
16 changes: 12 additions & 4 deletions tensorflow/core/common_runtime/device_id_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,21 @@ namespace tensorflow {
// Utility method for getting the associated executor given a TfDeviceId.
class DeviceIdUtil {
public:
static absl::StatusOr<stream_executor::StreamExecutor*> ExecutorForTfDeviceId(
const tsl::DeviceType& type, stream_executor::Platform* device_manager,
tsl::TfDeviceId tf_device_id) {
static absl::StatusOr<stream_executor::StreamExecutor*>
ExecutorForTfDeviceIdAndStream(const tsl::DeviceType& type,
stream_executor::Platform* device_manager,
tsl::TfDeviceId tf_device_id, int stream_id) {
tsl::PlatformDeviceId platform_device_id;
TF_RETURN_IF_ERROR(tsl::DeviceIdManager::TfToPlatformDeviceId(
type, tf_device_id, &platform_device_id));
return device_manager->ExecutorForDevice(platform_device_id.value());
return device_manager->ExecutorForDeviceAndStream(
platform_device_id.value(), stream_id);
}
// Convenience wrapper around ExecutorForTfDeviceIdAndStream() that requests
// the executor for stream group 0.
static absl::StatusOr<stream_executor::StreamExecutor*> ExecutorForTfDeviceId(
    const tsl::DeviceType& type, stream_executor::Platform* device_manager,
    tsl::TfDeviceId tf_device_id) {
  constexpr int kDefaultStreamId = 0;
  return ExecutorForTfDeviceIdAndStream(type, device_manager, tf_device_id,
                                        kDefaultStreamId);
}
};

Expand Down
10 changes: 5 additions & 5 deletions tensorflow/core/common_runtime/gpu/gpu_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ Status BaseGPUDevice::Init(const SessionOptions& options,
#else
Status BaseGPUDevice::Init(const SessionOptions& options) {
#endif // TF_GPU_USE_PJRT
auto executor_status = DeviceIdUtil::ExecutorForTfDeviceId(
DEVICE_GPU, se::GPUMachineManager(), tf_device_id_);
auto executor_status = DeviceIdUtil::ExecutorForTfDeviceIdAndStream(
DEVICE_GPU, se::GPUMachineManager(), tf_device_id_, stream_id_);
if (!executor_status.status().ok()) {
return errors::Internal("Failed to get StreamExecutor for device ",
tf_device_id_.value());
Expand All @@ -540,15 +540,15 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {

std::pair<StreamGroup*, bool> emplace_result =
StreamGroupFactory::Global().Emplace(
tf_device_id_, /*stream_group_within_gpu=*/0, stream_group);
tf_device_id_, /*stream_group_within_gpu=*/stream_id_, stream_group);
if (!emplace_result.second) {
LOG(WARNING) << "StreamGroup for tf_device_id: " << tf_device_id_.value()
<< " already exists. This usually only happens in unit tests.";
}
stream_ = emplace_result.first;
#else
stream_ = StreamGroupFactory::Global().GetOrCreate(
tf_device_id_, 0, executor_, options.config.gpu_options());
tf_device_id_, stream_id_, executor_, options.config.gpu_options());
#endif // TF_GPU_USE_PJRT

// Get an allocator that allocates pinned memory on host.
Expand All @@ -558,7 +558,7 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
Allocator* host_memory_allocator = GetAllocator(attr);

device_context_ =
new GPUDeviceContext(0, stream_->compute,
new GPUDeviceContext(stream_id_, stream_->compute,
#if TENSORFLOW_USE_ROCM
stream_->nccl,
#endif
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/core/framework/device_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ class DeviceContext : public core::RefCounted {

// Returns the pinned host memory allocator for the device. This base
// implementation returns nullptr.
virtual Allocator* host_memory_allocator() const { return nullptr; }

// Returns the stream group index of the stream device, or 0 if it's not a
// stream device. This base implementation always returns 0; presumably
// stream-aware contexts override it (TODO confirm against subclasses).
virtual int stream_id() const { return 0; }
};

class DeviceBase {
Expand Down Expand Up @@ -284,12 +288,21 @@ class DeviceBase {
"CopyTensorInSameDevice"));
}

// Sets the stream index of a stream device by storing `stream_id` in
// `stream_id_`. No validation of the value is performed here.
void SetStreamId(int stream_id) { stream_id_ = stream_id; }

// Gets the stream index of a stream device, i.e. the current value of
// `stream_id_` (initialized to 0 by its in-class initializer).
int GetStreamId() const { return stream_id_; }

protected:
// Stores `thread_pool` in `device_thread_pool_`. Does not take ownership.
void set_tensorflow_device_thread_pool(tsl::thread::ThreadPool* thread_pool) {
device_thread_pool_ = thread_pool;
}

// Stream group index that is managed by this device.
int stream_id_ = 0;

private:
tsl::Env* const env_;
CpuWorkerThreads* cpu_worker_threads_ = nullptr;
Expand Down
12 changes: 8 additions & 4 deletions third_party/xla/xla/stream_executor/cuda/cuda_platform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ CudaPlatform::DescriptionForDevice(int ordinal) const {
return GpuExecutor::CreateDeviceDescription(ordinal);
}

absl::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDevice(int ordinal) {
// Returns the executor for the given device ordinal and stream group by
// filling in a StreamExecutorConfig and delegating to GetExecutor().
absl::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDeviceAndStream(
    int ordinal, int stream_id) {
  StreamExecutorConfig executor_config;
  executor_config.stream_id = stream_id;
  executor_config.ordinal = ordinal;
  return GetExecutor(executor_config);
}

Expand All @@ -132,12 +134,14 @@ absl::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(

absl::StatusOr<std::unique_ptr<StreamExecutor>>
CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
auto executor = std::make_unique<GpuExecutor>(this, config.ordinal);
auto executor =
std::make_unique<GpuExecutor>(this, config.ordinal, config.stream_id);
auto init_status = executor->Init();
if (!init_status.ok()) {
return absl::InternalError(absl::StrFormat(
"failed initializing StreamExecutor for CUDA device ordinal %d: %s",
config.ordinal, init_status.ToString()));
"failed initializing StreamExecutor for CUDA device "
"ordinal %d stream group %d: %s",
config.ordinal, config.stream_id, init_status.ToString()));
}

return std::move(executor);
Expand Down
6 changes: 5 additions & 1 deletion third_party/xla/xla/stream_executor/cuda/cuda_platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ class CudaPlatform : public Platform {
absl::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
int ordinal) const override;

absl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
// Returns the executor for `ordinal` on stream group 0 by forwarding to
// ExecutorForDeviceAndStream().
absl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override {
  return ExecutorForDeviceAndStream(ordinal, /*stream_id=*/0);
}
absl::StatusOr<StreamExecutor*> ExecutorForDeviceAndStream(
int ordinal, int stream_id) override;

absl::StatusOr<StreamExecutor*> GetExecutor(
const StreamExecutorConfig& config) override;
Expand Down
11 changes: 7 additions & 4 deletions third_party/xla/xla/stream_executor/executor_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ absl::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
Entry* entry = nullptr;
{
absl::MutexLock lock{&mutex_};
entry = &cache_[config.ordinal];
entry = &cache_[{config.ordinal, config.stream_id}];
// Release the map lock; the address of 'entry' is stable because
// absl::node_hash_map guarantees reference stability.
}
Expand Down Expand Up @@ -90,18 +90,21 @@ absl::StatusOr<StreamExecutor*> ExecutorCache::Get(
absl::StrFormat("No executors own stream %p", config.gpu_stream));
}

if (auto it = cache_.find(config.ordinal); it != cache_.end()) {
if (auto it = cache_.find({config.ordinal, config.stream_id});
it != cache_.end()) {
entry = &it->second;
} else {
return absl::NotFoundError(absl::StrFormat(
"No executors registered for ordinal %d", config.ordinal));
"No executors registered for ordinal %d, stream group %d",
config.ordinal, config.stream_id));
}
}

absl::ReaderMutexLock lock{&entry->configurations_mutex};
if (entry->configurations.empty()) {
return absl::NotFoundError(absl::StrFormat(
"No executors registered for ordinal %d", config.ordinal));
"No executors registered for ordinal %d, stream group %d",
config.ordinal, config.stream_id));
}

for (auto& [entry_config, entry_executor] : entry->configurations) {
Expand Down
10 changes: 6 additions & 4 deletions third_party/xla/xla/stream_executor/executor_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ class ExecutorCache {
configurations ABSL_GUARDED_BY(configurations_mutex);
};

// Maps ordinal number to a list of cached executors for that ordinal.
// We key off of ordinal (instead of just looking up all fields in the
// StreamExecutorConfig) for a slight improvement in lookup time.
// Maps ordinal and stream_id to a list of cached executors for that ordinal
// and stream_id. We key off of the ordinal-stream pair (instead of just
// looking up all fields in the StreamExecutorConfig) for a slight improvement
// in lookup time.
absl::Mutex mutex_;
absl::node_hash_map<int, Entry> cache_ ABSL_GUARDED_BY(mutex_);
absl::node_hash_map<std::pair<int, int>, Entry> cache_
ABSL_GUARDED_BY(mutex_);

ExecutorCache(const ExecutorCache&) = delete;
void operator=(const ExecutorCache&) = delete;
Expand Down
8 changes: 7 additions & 1 deletion third_party/xla/xla/stream_executor/gpu/gpu_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,12 @@ class GpuExecutor : public StreamExecutorCommon {
public:
// sub_platform indicates the subplatform used in this executor; it must
// be a CUDA type.
GpuExecutor(Platform* platform, int device_ordinal)
GpuExecutor(Platform* platform, int device_ordinal, int stream_id = 0)
: StreamExecutorCommon(platform),
device_(0),
context_(nullptr),
device_ordinal_(device_ordinal),
stream_id_(stream_id),
cc_major_(0),
cc_minor_(0),
version_(0),
Expand All @@ -125,6 +126,8 @@ class GpuExecutor : public StreamExecutorCommon {

// Returns the device ordinal this executor was constructed with.
int device_ordinal() const override { return device_ordinal_; }

// Returns the stream group index this executor was constructed with.
int stream_id() const override { return stream_id_; }

absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
const MultiKernelLoaderSpec& spec) override;

Expand Down Expand Up @@ -360,6 +363,9 @@ class GpuExecutor : public StreamExecutorCommon {
// for use in getting device metadata. Immutable post-initialization.
int device_ordinal_;

// The stream group index value that this executor was initialized with.
int stream_id_;

// The major version of the compute capability for device_.
int cc_major_;

Expand Down
11 changes: 11 additions & 0 deletions third_party/xla/xla/stream_executor/platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ struct StreamExecutorConfig {

// The ordinal of the device to be managed by the returned StreamExecutor.
int ordinal;

// The ordinal of the stream group to be managed by the returned
// StreamExecutor.
int stream_id = 0;
};

// Abstract base class for a platform registered with the PlatformManager.
Expand Down Expand Up @@ -114,6 +118,13 @@ class Platform {
// the Platform owns the executors in a singleton-like fashion.
virtual absl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) = 0;

// Returns the executor with the given ordinal and stream_id on this
// platform. The stream_id is only meaningful for CUDA devices; this default
// implementation ignores it and forwards to ExecutorForDevice(ordinal).
virtual absl::StatusOr<StreamExecutor*> ExecutorForDeviceAndStream(
int ordinal, int stream_id) {
return ExecutorForDevice(ordinal);
}

// Returns a device constructed with the options specified in "config".
// Ownership of the executor is NOT transferred to the caller.
virtual absl::StatusOr<StreamExecutor*> GetExecutor(
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla/xla/stream_executor/stream_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class StreamExecutor {
// Returns the device ordinal. This base implementation returns -1.
virtual int device_ordinal() const { return -1; }

// Returns the stream group ordinal. This base implementation returns -1.
// NOTE(review): other stream_id defaults (DeviceContext::stream_id() and
// StreamExecutorConfig::stream_id) are 0 -- confirm that -1 is the intended
// value for executors that do not override this.
virtual int stream_id() const { return -1; }

// Creates and initializes a Stream.
virtual absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
std::optional<std::variant<StreamPriority, int>> priority) = 0;
Expand Down
9 changes: 6 additions & 3 deletions third_party/xla/xla/stream_executor/stream_executor_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ namespace stream_executor {

// A StreamExecutor manages a single device, in terms of executing work (kernel
// launches) and memory management (allocation/deallocation, memory copies to
// and from the device). It is conceptually the "handle" for a device -- Stream
// objects, which are used to enqueue work to run on the
// coprocessor have a StreamExecutor instance as their "parent" object.
// and from the device). One device can be managed by multiple StreamExecutors:
// e.g., when a GPU device has multiple stream groups enabled, each
// StreamExecutor manages the operations in one stream group. It is
// conceptually the "handle" for a device -- Stream objects, which are used to
// enqueue work to run on the coprocessor, have a StreamExecutor instance as
// their "parent" object.
//
// StreamExecutor objects have an underlying platform that is specified up
// front;
Expand Down

0 comments on commit 9d0745b

Please sign in to comment.