Skip to content

Commit

Permalink
Select stream group for execution
Browse files Browse the repository at this point in the history
  • Loading branch information
buptzyb committed Aug 21, 2024
1 parent 790479e commit 18f850f
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 57 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/common_runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2143,6 +2143,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/gpu:gpu_serving_device_selector",
"//tensorflow/core/debug:debug_graph_utils",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/nccl:collective_communicator",
Expand All @@ -2151,6 +2152,8 @@ tf_cuda_library(
"//tensorflow/core/profiler/lib:profiler_backends",
"//tensorflow/core/profiler/lib:traceme_encode",
"@com_google_absl//absl/container:flat_hash_set",
"@local_xla//xla/tsl/framework:serving_device_selector",
"@local_xla//xla/tsl/framework:serving_device_selector_policies",
],
alwayslink = 1,
)
Expand Down
157 changes: 100 additions & 57 deletions tensorflow/core/common_runtime/direct_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/local_session_selection.h"
Expand Down Expand Up @@ -80,6 +81,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/env_var.h"
#include "xla/tsl/framework/serving_device_selector_policies.h"

namespace tensorflow {

Expand Down Expand Up @@ -325,7 +327,10 @@ DirectSession::DirectSession(const SessionOptions& options,
device_mgr_(device_mgr),
factory_(factory),
cancellation_manager_(new CancellationManager()),
operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()),
stream_selector_(absl::make_unique<gpu::GpuServingDeviceSelector>(
std::max(device_mgr->StreamGroupCount(), 1),
std::make_unique<tsl::RoundRobinPolicy>())) {
const int thread_pool_size =
options_.config.session_inter_op_thread_pool_size();
if (thread_pool_size > 0) {
Expand Down Expand Up @@ -735,11 +740,19 @@ Status DirectSession::RunInternal(
}
};

// Get the stream group for execution.
tsl::DeviceReservation reservation =
stream_selector_->ReserveDevice("StreamSelector");
int stream_group_idx = reservation.device_index();

if (can_execute_synchronously) {
PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
args.rendezvous = &rendezvous;

const auto& item = executors_and_keys->items[0];
const auto& item =
executors_and_keys->stream_items[0].size() <= stream_group_idx
? executors_and_keys->items[0]
: executors_and_keys->stream_items[0][stream_group_idx];
set_threadpool_args_for_item(item, &args);
run_status = item.executor->Run(args);
} else {
Expand All @@ -759,7 +772,11 @@ Status DirectSession::RunInternal(
executors_done.Notify();
});

for (const auto& item : executors_and_keys->items) {
for (int i = 0; i < executors_and_keys->items.size(); ++i) {
const auto& item =
executors_and_keys->stream_items[i].size() <= stream_group_idx
? executors_and_keys->items[i]
: executors_and_keys->stream_items[i][stream_group_idx];
set_threadpool_args_for_item(item, &args);
item.executor->RunAsync(args, barrier->Get());
}
Expand Down Expand Up @@ -1358,6 +1375,7 @@ Status DirectSession::CreateExecutors(
}
}
ek->items.reserve(graphs.size());
ek->stream_items.reserve(graphs.size());
const auto& optimizer_opts =
options_.config.graph_options().optimizer_options();

Expand All @@ -1378,6 +1396,7 @@ Status DirectSession::CreateExecutors(
return absl::OkStatus();
}}));

int stream_group_count = device_mgr_->StreamGroupCount();
GraphOptimizer optimizer(optimizer_opts);
for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
const string& partition_name = iter->first;
Expand All @@ -1387,65 +1406,89 @@ Status DirectSession::CreateExecutors(
TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));

ek->items.resize(ek->items.size() + 1);
auto* item = &(ek->items.back());
auto lib = func_info->proc_flr->GetFLR(partition_name);
if (lib == nullptr) {
return errors::Internal("Could not find device: ", partition_name);
}
item->flib = lib;

LocalExecutorParams params;
params.device = device;
params.session_metadata = session_metadata;
params.function_library = lib;
auto opseg = device->op_segment();
params.create_kernel =
[this, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
OpKernel** kernel) {
// NOTE(mrry): We must not share function kernels (implemented
// using `CallOp`) between subgraphs, because `CallOp::handle_`
// is tied to a particular subgraph. Even if the function itself
// is stateful, the `CallOp` that invokes it is not.
if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
return lib->CreateKernel(props, kernel);
}
auto create_fn = [lib, &props](OpKernel** kernel) {
return lib->CreateKernel(props, kernel);
ek->stream_items.resize(ek->stream_items.size() + 1);
std::vector<PerPartitionExecutorsAndLib>* stream_item =
&(ek->stream_items.back());
bool use_multistream = device_mgr_->DeviceHasMultipleStreams(device);
if (use_multistream) {
stream_item->resize(stream_group_count);
}
for (int exec_idx = 0;
exec_idx <= (use_multistream ? stream_group_count : 0); ++exec_idx) {
PerPartitionExecutorsAndLib* item;
Device* exec_device;
if (exec_idx == 0) {
// Create the original executor first, then stream-related executors.
item = &(ek->items.back());
exec_device = device;
} else {
item = &(stream_item->at(exec_idx - 1));
exec_device = device_mgr_->LookupStream(device, exec_idx - 1);
}
auto lib = func_info->proc_flr->GetFLR(exec_device->name());
if (lib == nullptr) {
return errors::Internal("Could not find device: ", exec_device->name());
}
item->flib = lib;

LocalExecutorParams params;
params.device = exec_device;
params.session_metadata = session_metadata;
params.function_library = lib;
auto opseg = device->op_segment();
params.create_kernel =
[this, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
OpKernel** kernel) {
// NOTE(mrry): We must not share function kernels (implemented
// using `CallOp`) between subgraphs, because `CallOp::handle_`
// is tied to a particular subgraph. Even if the function itself
// is stateful, the `CallOp` that invokes it is not.
if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
return lib->CreateKernel(props, kernel);
}
auto create_fn = [lib, &props](OpKernel** kernel) {
return lib->CreateKernel(props, kernel);
};
// Kernels created for subgraph nodes need to be cached. On
// cache miss, create_fn() is invoked to create a kernel based
// on the function library here + global op registry.
return opseg->FindOrCreate(session_handle_, props->node_def.name(),
kernel, create_fn);
};
// Kernels created for subgraph nodes need to be cached. On
// cache miss, create_fn() is invoked to create a kernel based
// on the function library here + global op registry.
return opseg->FindOrCreate(session_handle_, props->node_def.name(),
kernel, create_fn);
};
params.delete_kernel = [lib](OpKernel* kernel) {
if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
delete kernel;
};
params.delete_kernel = [lib](OpKernel* kernel) {
if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
delete kernel;
};

optimizer.Optimize(lib, options_.env, device, &partition_graph,
GraphOptimizer::Options());
if (exec_idx == 0) {
optimizer.Optimize(lib, options_.env, device, &partition_graph,
GraphOptimizer::Options());

// TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
const DebugOptions& debug_options =
options.callable_options.run_options().debug_options();
if (!debug_options.debug_tensor_watch_opts().empty()) {
TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
debug_options, partition_graph.get(), params.device));
}
// TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
const DebugOptions& debug_options =
options.callable_options.run_options().debug_options();
if (!debug_options.debug_tensor_watch_opts().empty()) {
TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
debug_options, partition_graph.get(), params.device));
}

TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
device->name(),
partition_graph.get()));
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
device->name(),
partition_graph.get()));
} else {
partition_graph = std::make_unique<Graph>(func_info->flib_def.get());
CopyGraph(*ek->items.back().graph, partition_graph.get());
}

item->executor = nullptr;
item->device = device;
auto executor_type = options_.config.experimental().executor_type();
TF_RETURN_IF_ERROR(
NewExecutor(executor_type, params, *partition_graph, &item->executor));
if (!options_.config.experimental().disable_output_partition_graphs() ||
options_.config.graph_options().build_cost_model() > 0) {
item->graph = std::move(partition_graph);
item->executor = nullptr;
item->device = exec_device;
auto executor_type = options_.config.experimental().executor_type();
TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *partition_graph,
&item->executor));
if (!options_.config.experimental().disable_output_partition_graphs() ||
options_.config.graph_options().build_cost_model() > 0) {
item->graph = std::move(partition_graph);
}
}
}

Expand Down
6 changes: 6 additions & 0 deletions tensorflow/core/common_runtime/direct_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "xla/tsl/framework/serving_device_selector.h"

namespace tensorflow {

Expand Down Expand Up @@ -165,6 +166,7 @@ class DirectSession : public Session {
std::unique_ptr<Graph> graph;
NameNodeMap name_to_node;
std::vector<PerPartitionExecutorsAndLib> items;
std::vector<std::vector<PerPartitionExecutorsAndLib>> stream_items;
std::unordered_map<string, size_t> input_name_to_index;
std::unordered_map<string, string> input_name_to_rendezvous_key;
std::unordered_map<string, size_t> output_name_to_index;
Expand Down Expand Up @@ -438,6 +440,10 @@ class DirectSession : public Session {
// pool according to other specifications of RunOptions and ConfigProto.
bool run_in_caller_thread_ = false;

// Selects which stream group is used to execute GPU graphs when multiple
// stream groups are available.
std::unique_ptr<tsl::ServingDeviceSelector> stream_selector_;

DirectSession(const DirectSession&) = delete;
void operator=(const DirectSession&) = delete;

Expand Down

0 comments on commit 18f850f

Please sign in to comment.