multiple streams for inference
buptzyb committed Aug 12, 2024 · 1 parent 8f7e71c · commit 7e28117
Showing 60 changed files with 1,261 additions and 283 deletions.
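Two patterns recur throughout the diff: cached GPU resources get one copy per stream group, distinguished by a "_<id>" name suffix, and stream groups are mapped round-robin onto a pool of CUDA contexts whose size is set by the TF_GPU_CONTEXT_COUNT environment variable. A standalone sketch of both ideas (the helper names here are illustrative, not TensorFlow APIs):

#include <cstdlib>
#include <iostream>
#include <string>

// Per-stream-group resource naming, as used by the TRT and kernel-gen
// changes: group 0 keeps the legacy unsuffixed name, so single-stream
// deployments are unaffected.
std::string PerStreamName(const std::string& base, int stream_id) {
  return stream_id > 0 ? base + "_" + std::to_string(stream_id) : base;
}

// Stream-group -> context-slot mapping, as used by GpuDriver::CreateContext:
// groups are spread round-robin over TF_GPU_CONTEXT_COUNT context slots.
int ContextIndexFor(int stream_id) {
  const char* env = std::getenv("TF_GPU_CONTEXT_COUNT");
  int count = env != nullptr ? std::atoi(env) : 1;
  if (count <= 0) count = 1;  // guard against unset or bad values; default is 1
  return stream_id % count;
}

int main() {
  for (int sid = 0; sid < 4; ++sid) {
    std::cout << PerStreamName("trt_engine_cache", sid) << " -> context slot "
              << ContextIndexFor(sid) << "\n";
  }
  return 0;
}

Stream group 0 keeps the unsuffixed legacy name and, as the driver changes below show, the device's primary context, so single-stream behavior is unchanged.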
4 changes: 4 additions & 0 deletions tensorflow/compiler/jit/device_compiler.h
@@ -463,6 +463,10 @@ Status DeviceCompiler<ExecutableType, ClientType>::CompileImpl(
 
   if (state == DeviceCompileState::kUncompiled) {
     XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable");
+    if (options.stream_id > 0) {
+      VLOG(2) << "Not compiling for stream group " << options.stream_id;
+      return absl::OkStatus();
+    }
     if (!profiler->ShouldCompileCluster(function, compile_mode,
                                         current_request_count)) {
       VLOG(2) << "Not compiling for signature: " << human_signature;
3 changes: 2 additions & 1 deletion tensorflow/compiler/jit/variable_info_util.cc
@@ -58,7 +58,8 @@ Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
                              " to GetVariableInfosFromInputs.");
     }
     ResourceHandle handle = inputs[var_idx]->flat<ResourceHandle>()(0);
-    if (handle.device() != dev->attributes().name()) {
+    if (!DeviceNameUtils::HaveSameDeviceName(handle.device(),
+                                             dev->attributes().name())) {
       std::string definition_location =
           DefinitionLocationMsg(handle.definition_stack_trace());
       return errors::InvalidArgument(
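Raw string equality is too strict here: once the same device can be referred to by more or less fully qualified names, two spellings of one device compare unequal. DeviceNameUtils::HaveSameDeviceName compares the parsed names instead. A toy illustration of the failure mode (simplified comparison, not the TensorFlow helper):

#include <iostream>
#include <string>

// Toy stand-in for a parsed-name comparison: extract the trailing
// "device:TYPE:ID" component and compare only that, so differently
// qualified spellings of one device match. The real helper parses and
// compares every component of both names.
std::string DeviceComponent(const std::string& name) {
  const std::string key = "device:";
  std::string::size_type pos = name.rfind(key);
  return pos == std::string::npos ? name : name.substr(pos);
}

bool RoughlySameDevice(const std::string& a, const std::string& b) {
  return DeviceComponent(a) == DeviceComponent(b);
}

int main() {
  const std::string full = "/job:localhost/replica:0/task:0/device:GPU:0";
  const std::string bare = "/device:GPU:0";
  std::cout << std::boolalpha;
  std::cout << "string equality:   " << (full == bare) << "\n";  // false
  std::cout << "parsed comparison: " << RoughlySameDevice(full, bare)
            << "\n";  // true
}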
2 changes: 2 additions & 0 deletions tensorflow/compiler/jit/xla_compiler_options_util.cc
@@ -46,6 +46,7 @@ XlaCompiler::Options GenerateCompilerOptions(
   options.client = static_cast<xla::LocalClient*>(xla_device_compiler.client());
   if (stream != nullptr) {
     options.device_ordinal = stream->parent()->device_ordinal();
+    options.stream_id = stream->parent()->stream_id();
   }
   options.device_type = xla_device_compiler.device_type();
   options.flib_def = function_library.GetFunctionLibraryDefinition();
@@ -84,6 +85,7 @@ XlaCompiler::Options GenerateCompilerOptionsForPjRt(
     const DeviceBase* device_base, const XlaPlatformInfo& platform_info) {
   XlaCompiler::Options options;
   options.device_ordinal = device_base->parsed_name().id;
+  options.stream_id = device_base->GetStreamId();
   options.flib_def = function_library.GetFunctionLibraryDefinition();
   options.graph_def_version = function_library.graph_def_version();
   if (const auto* metadata = platform_info.xla_device_metadata();
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/tools/kernel_gen/BUILD
@@ -200,6 +200,8 @@ cc_library(
         "@llvm-project//mlir:mlir_runner_utils",
     ] + if_cuda_is_configured([
         "@local_config_cuda//cuda:cuda_headers",
+        "//tensorflow/compiler/xla/stream_executor/cuda:cuda_driver",
+        "//tensorflow/compiler/xla/stream_executor/cuda:cuda_gpu_executor_header",
         "//tensorflow/compiler/xla/stream_executor/cuda:stream_executor_cuda",
     ]) + if_rocm_is_configured([
         "@local_config_rocm//rocm:rocm_headers",
[file path not captured in this view]
@@ -19,6 +19,10 @@ limitations under the License.
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_cat.h"
+#if GOOGLE_CUDA
+#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_executor.h"
+#endif
 #include "tensorflow/compiler/xla/stream_executor/stream.h"
 #include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
 #include "tensorflow/core/platform/logging.h"
@@ -141,10 +145,20 @@ extern "C" void _mlir_ciface_tf_launch_kernel(void *ctx, void *module_blob,
     return;
   }
   GPURuntimeCache *cache = nullptr;
+#if GOOGLE_CUDA
+  auto *gpu_executor = static_cast<stream_executor::gpu::GpuExecutor *>(
+      op_kernel_ctx->op_device_context()->stream()->parent()->implementation());
+  int ctx_id = gpu_executor->gpu_context()->id();
+#else
+  int ctx_id = 0;
+#endif
+  std::string name =
+      ctx_id > 0
+          ? absl::StrCat(GPURuntimeCache::kDefaultResourceName, "_", ctx_id)
+          : GPURuntimeCache::kDefaultResourceName;
   OP_REQUIRES_OK(op_kernel_ctx, rm->LookupOrCreate<GPURuntimeCache>(
-                                     rm->default_container(),
-                                     GPURuntimeCache::kDefaultResourceName,
-                                     &cache, GPURuntimeCache::Create));
+                                     rm->default_container(), name, &cache,
+                                     GPURuntimeCache::Create));
   assert(cache != nullptr && "cache creation must not fail");
   tensorflow::core::ScopedUnref ref(cache);
 
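Note that the cache key here is the CUDA context id, not the stream id: stream groups that end up on the same GpuContext share one GPURuntimeCache, while each distinct context gets its own. A toy model of that grouping (the context ids are invented; the real ones come from gpu_executor->gpu_context()->id()):

#include <iostream>
#include <map>
#include <string>

// GPURuntimeCache instances keyed by context id, so stream groups that map
// to the same CUDA context share one cache.
int main() {
  std::map<int, int> stream_to_ctx = {{0, 1}, {1, 2}, {2, 1}, {3, 2}};
  std::map<std::string, int> lookups;
  for (auto [sid, ctx_id] : stream_to_ctx) {
    std::string name = ctx_id > 0
                           ? "gpu_runtime_cache_" + std::to_string(ctx_id)
                           : "gpu_runtime_cache";
    std::cout << "stream " << sid << " uses cache \"" << name << "\"\n";
    ++lookups[name];
  }
  std::cout << "distinct caches: " << lookups.size() << "\n";  // prints 2
}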
9 changes: 7 additions & 2 deletions tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -1085,9 +1085,14 @@ Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx,
     resource_name.remove_prefix(last_slash + 1);
   }
 
-  // Get engine cache.
+  // Get engine cache. Each stream group should have its own cache to avoid
+  // sharing the same engine_context.
+  int stream_id = ctx->device()->GetStreamId();
+  std::string name = stream_id > 0 ? strings::StrCat(std::string(resource_name),
+                                                     "_", stream_id)
+                                   : std::string(resource_name);
   return ctx->resource_manager()->LookupOrCreate(
-      std::string(kTfTrtContainerName), std::string(resource_name), cache_res,
+      std::string(kTfTrtContainerName), name, cache_res,
       {[this, ctx](TRTEngineCacheResource** cr) -> Status {
         *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
         return OkStatus();
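This is the same suffixing scheme applied to the TRT engine cache, so two stream groups never share an engine context. A toy model of LookupOrCreate against suffixed keys (illustrative types, not the TensorFlow ResourceMgr):

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct EngineCache {
  explicit EngineCache(int id) : stream_id(id) {}
  int stream_id;
};

// Minimal LookupOrCreate: return the resource under `name`, creating it on
// first use, mirroring how each stream group gets its own cache instance.
class ToyResourceMgr {
 public:
  EngineCache* LookupOrCreate(const std::string& name,
                              const std::function<EngineCache*()>& create) {
    auto it = resources_.find(name);
    if (it == resources_.end()) {
      it = resources_.emplace(name, std::unique_ptr<EngineCache>(create()))
               .first;
    }
    return it->second.get();
  }

 private:
  std::map<std::string, std::unique_ptr<EngineCache>> resources_;
};

int main() {
  ToyResourceMgr rm;
  for (int sid : {0, 1, 2, 1}) {
    std::string name =
        sid > 0 ? "trt_cache_" + std::to_string(sid) : "trt_cache";
    EngineCache* cache =
        rm.LookupOrCreate(name, [&] { return new EngineCache(sid); });
    std::cout << name << " -> cache for stream " << cache->stream_id << "\n";
  }
}

On the second lookup of a given name the existing cache is returned, so the create callback runs once per stream group.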
[file path not captured in this view]
@@ -97,6 +97,10 @@ class InitializeTRTResource : public OpKernel {
         tensorflow::profiler::TraceMeLevel::kInfo);
     ResourceHandle handle = HandleFromInput(ctx, 0);
     core::RefCountPtr<TRTEngineCacheResource> resource;
+    int stream_id = ctx->device()->GetStreamId();
+    if (stream_id > 0) {
+      handle.set_name(strings::StrCat(handle.name(), "_", stream_id));
+    }
     OP_REQUIRES_OK(
         ctx, LookupOrCreateResource<TRTEngineCacheResource>(
                  ctx, handle, &resource,
4 changes: 4 additions & 0 deletions tensorflow/compiler/tf2xla/xla_compiler.h
@@ -153,6 +153,10 @@ class XlaCompiler {
     // -1 indicates the default device should be used.
     int device_ordinal = -1;
 
+    // The stream group to use during compilation to execute instructions on.
+    // The compilation should only work on stream group 0 for now.
+    int stream_id = 0;
+
     xla::Client* client = nullptr;
 
     // Function library in which to find function definitions. Must be non-null.
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/stream_executor/cuda/BUILD
@@ -114,6 +114,7 @@ cc_library(
         "//tensorflow/compiler/xla/stream_executor/platform:dso_loader",
         "//tensorflow/tsl/platform:env",
         "//tensorflow/tsl/platform:static_threadlocal",
+        "//tensorflow/tsl/util:env_var",
     ] + tf_additional_cuda_driver_deps()) + select({
         # include dynamic loading implementation only when if_cuda_is_configured and build dynamically
         "//tensorflow/tsl:is_cuda_enabled_and_oss": ["cudart_stub"],
60 changes: 36 additions & 24 deletions tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc
@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/tsl/platform/stacktrace.h"
 #include "tensorflow/tsl/platform/static_threadlocal.h"
 #include "tensorflow/tsl/platform/threadpool.h"
+#include "tensorflow/tsl/util/env_var.h"
 
 bool FLAGS_gpuexec_cuda_driver_inject_init_error = false;
 bool FLAGS_gpuexec_cuda_sync_around_driver_calls = false;
@@ -76,6 +77,7 @@ namespace gpu {
 
 /* static */ absl::Mutex CreatedContexts::mu_{absl::kConstInit};
 /* static */ int64_t CreatedContexts::next_id_ = 1;  // 0 means "no context"
+static std::unordered_map<CUcontext, CUdevice> primary_ctx_used_;
 
 namespace {
 
@@ -328,8 +330,8 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,
 }
 
 /* static */ tsl::Status GpuDriver::CreateContext(
-    int device_ordinal, CUdevice device, const DeviceOptions& device_options,
-    GpuContext** context) {
+    int device_ordinal, int stream_id, CUdevice device,
+    const DeviceOptions& device_options, GpuContext** context) {
   *context = nullptr;
 
   int flags = 0;
@@ -359,31 +361,29 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,
 
   former_context = cuda::CurrentContextOrDie();
   res = cuDevicePrimaryCtxRetain(&new_context, device);
-  if (former_context != nullptr) {
-    CUdevice former_device;
-    if (cuCtxGetDevice(&former_device) == CUDA_SUCCESS) {
-      if (former_device == device) {
-        if (former_context == new_context) {
-          VLOG(2) << "The primary context " << former_context << " for device "
-                  << device
-                  << " exists before initializing the StreamExecutor.";
-        } else {
-          LOG(WARNING) << "A non-primary context " << former_context
-                       << " for device " << device
-                       << " exists before initializing the StreamExecutor. The "
-                       << "primary context is now " << new_context << ". We "
-                       << "haven't verified StreamExecutor works with that.";
-        }
-      }
-    } else {
-      LOG(ERROR) << "Failed to get the device of the current context "
-                 << former_context;
-    }
-  }
+  int64_t gpu_context_count;
+  TF_CHECK_OK(tsl::ReadInt64FromEnvVar("TF_GPU_CONTEXT_COUNT",
+                                       /*default_val=*/1, &gpu_context_count));
+  int context_idx = stream_id % gpu_context_count;
+  if (CreatedContexts::OrdinalHas(device_ordinal, context_idx)) {
+    new_context = CreatedContexts::OrdinalGet(device_ordinal, context_idx);
+    VLOG(2) << "Device " << device << " stream " << stream_id
+            << " use created context " << new_context;
+  } else if (stream_id == 0 &&
+             primary_ctx_used_.find(new_context) == primary_ctx_used_.end()) {
+    // Don't create new context. Use the primary context.
+    VLOG(2) << "No context for device " << device << " stream " << stream_id
+            << ", use cuDevicePrimaryCtxRetain context " << new_context;
+    primary_ctx_used_.insert(std::make_pair(new_context, device));
+  } else {
+    CHECK_EQ(CUDA_SUCCESS, cuCtxCreate(&new_context, flags, device));
+    VLOG(2) << "No context for device " << device << " stream " << stream_id
+            << ", cuCtxCreate context " << new_context;
+  }
   CHECK_EQ(CUDA_SUCCESS, cuCtxSetCurrent(former_context));
 
   if (res == CUDA_SUCCESS) {
-    *context = CreatedContexts::Add(new_context, device_ordinal);
+    *context = CreatedContexts::Add(new_context, device_ordinal, context_idx);
     CHECK(*context != nullptr)
         << "success in this call must entail non-null result";
     VLOG(2) << "created or reused context " << new_context
@@ -415,7 +415,19 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,
   cuCtxGetDevice(&device);
   cuCtxSetCurrent(former_context);
 
-  res = cuDevicePrimaryCtxRelease(device);
+  bool is_primary_ctx = false;
+  for (auto iter = primary_ctx_used_.begin(); iter != primary_ctx_used_.end();
+       ++iter) {
+    if (iter->second == device) {
+      res = cuDevicePrimaryCtxRelease(device);
+      primary_ctx_used_.erase(iter);
+      is_primary_ctx = true;
+      break;
+    }
+  }
+  if (!is_primary_ctx) {
+    res = cuCtxDestroy(context->context());
+  }
 
   if (res != CUDA_SUCCESS) {
     LOG(ERROR) << "failed to release CUDA context; leaking: " << ToString(res);
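The new CreateContext logic has three cases: reuse a context already registered for this (device ordinal, context slot) pair, retain the device's primary context the first time stream group 0 asks, or create a fresh context with cuCtxCreate. DestroyContext mirrors this by releasing the primary context only for the device that retained it and destroying created contexts otherwise. A standalone model of the selection policy, with strings standing in for CUcontext handles (toy code, no CUDA calls):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Per-ordinal context slots plus a record of which devices handed out their
// primary context, mirroring CreatedContexts and primary_ctx_used_.
struct Registry {
  std::map<int, std::vector<std::string>> per_ordinal;  // ordinal -> slots
  std::map<std::string, int> primary_used;              // primary ctx -> device
  int next_created = 1;
};

std::string CreateContextFor(Registry& r, int ordinal, int stream_id,
                             int gpu_context_count) {
  const int idx = stream_id % gpu_context_count;
  auto& slots = r.per_ordinal[ordinal];
  if (static_cast<int>(slots.size()) > idx && !slots[idx].empty()) {
    return slots[idx];  // OrdinalHas/OrdinalGet path: reuse the slot
  }
  const std::string primary = "primary(" + std::to_string(ordinal) + ")";
  std::string ctx;
  if (stream_id == 0 && r.primary_used.count(primary) == 0) {
    ctx = primary;  // cuDevicePrimaryCtxRetain path
    r.primary_used[primary] = ordinal;
  } else {
    ctx = "created#" + std::to_string(r.next_created++);  // cuCtxCreate path
  }
  if (static_cast<int>(slots.size()) <= idx) slots.resize(idx + 1);
  slots[idx] = ctx;
  return ctx;
}

int main() {
  Registry r;
  const int kContextCount = 2;  // e.g. TF_GPU_CONTEXT_COUNT=2
  for (int sid = 0; sid < 4; ++sid) {
    std::cout << "stream " << sid << " -> "
              << CreateContextFor(r, /*ordinal=*/0, sid, kContextCount) << "\n";
  }
  // stream 0 -> primary(0); stream 1 -> created#1;
  // stream 2 -> primary(0) (slot 0 reused); stream 3 -> created#1 (reused)
}

With the default TF_GPU_CONTEXT_COUNT=1, every stream group maps to slot 0 and shares the primary context; with TF_GPU_CONTEXT_COUNT=2, groups alternate between the primary context and one created context.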
50 changes: 38 additions & 12 deletions tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.h
@@ -75,8 +75,22 @@ class CreatedContexts {
     return Live()->find(context) != Live()->end();
   }
 
+  // Returns whether device ordinal is a member of the live ordinal set.
+  static bool OrdinalHas(int ordinal, int context_idx) {
+    absl::ReaderMutexLock lock(&mu_);
+    return ((LiveOrdinal()->find(ordinal) != LiveOrdinal()->end()) &&
+            ((*LiveOrdinal())[ordinal].size() > context_idx) &&
+            ((*LiveOrdinal())[ordinal][context_idx] != nullptr));
+  }
+
+  static CUcontext OrdinalGet(int ordinal, int context_idx) {
+    absl::ReaderMutexLock lock(&mu_);
+    return (*LiveOrdinal())[ordinal][context_idx];
+  }
+
   // Adds context to the live set, or returns it if it's already present.
-  static GpuContext* Add(CUcontext context, int device_ordinal) {
+  static GpuContext* Add(CUcontext context, int device_ordinal,
+                         int context_idx) {
     CHECK(context != nullptr);
     absl::MutexLock lock(&mu_);
 
@@ -85,7 +99,11 @@
     if (insert_result.second) {
       // context was not present in the map. Add it.
       it->second = std::make_unique<GpuContext>(context, next_id_++);
-      (*LiveOrdinal())[device_ordinal].push_back(context);
+      auto& ctx_vec = (*LiveOrdinal())[device_ordinal];
+      if (ctx_vec.size() <= context_idx) {
+        ctx_vec.resize(context_idx + 1);
+      }
+      ctx_vec[context_idx] = context;
     }
     return it->second.get();
   }
@@ -111,19 +129,27 @@
 
   // Return the context associated to that ptr.
   static CUcontext GetAnyContext(void* ptr) {
-    absl::ReaderMutexLock lock(&mu_);
-    int device_ordinal;
-    CUresult result = cuPointerGetAttribute(static_cast<void*>(&device_ordinal),
-                                            CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
-                                            reinterpret_cast<CUdeviceptr>(ptr));
+    static const bool use_cuda_malloc_async = [] {
+      const char* allocator_env = std::getenv("TF_GPU_ALLOCATOR");
+      bool result = allocator_env != nullptr &&
+                    std::strcmp(allocator_env, "cuda_malloc_async") == 0;
+#if CUDA_VERSION >= 11020
+      return result;
+#else
+      return false;
+#endif
+    }();
+    if (use_cuda_malloc_async) return nullptr;
+    CUcontext context;
+    CUresult result =
+        cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT,
+                              reinterpret_cast<CUdeviceptr>(ptr));
     if (result != CUDA_SUCCESS) {
-      LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << ptr
+      LOG(FATAL) << "Not able to get the CUDA context for ptr: " << ptr
                  << ". Error: " << ToString(result);
     }
-    CHECK_EQ(LiveOrdinal()->count(device_ordinal), 1);
-    CHECK(!LiveOrdinal()->at(device_ordinal).empty())
-        << "Need at least one context.";
-    return LiveOrdinal()->at(device_ordinal)[0];
+    CHECK(Has(context));
+    return context;
   }
 
  private:
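Two details in this header are worth calling out. Add() now writes into a fixed slot instead of appending, because slots can be claimed in any order and the same slot can be re-requested; unfilled slots stay nullptr, which is why OrdinalHas checks the element as well as the vector size. And GetAnyContext() can no longer assume one context per device ordinal, so instead of returning element [0] for the pointer's device, it asks the driver which context owns the pointer. A toy version of the slot bookkeeping:

#include <iostream>
#include <map>
#include <vector>

int main() {
  // ordinal -> context slots; int values stand in for CUcontext handles,
  // with 0 meaning "no context", as in the real registry.
  std::map<int, std::vector<int>> live_ordinal;

  // Grow the vector to fit and assign by index, like the new Add().
  auto add = [&](int ordinal, int context_idx, int ctx) {
    auto& vec = live_ordinal[ordinal];
    if (vec.size() <= static_cast<size_t>(context_idx)) {
      vec.resize(context_idx + 1);  // unfilled slots stay 0
    }
    vec[context_idx] = ctx;
  };

  add(/*ordinal=*/0, /*context_idx=*/2, /*ctx=*/42);  // out-of-order arrival
  add(/*ordinal=*/0, /*context_idx=*/0, /*ctx=*/7);
  for (int c : live_ordinal[0]) std::cout << c << " ";  // prints: 7 0 42
  std::cout << "\n";
}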
[file path not captured in this view]
@@ -127,9 +127,10 @@ GpuExecutor::~GpuExecutor() {
   }
 }
 
-tsl::Status GpuExecutor::Init(int device_ordinal,
+tsl::Status GpuExecutor::Init(int device_ordinal, int stream_id,
                               DeviceOptions device_options) {
   device_ordinal_ = device_ordinal;
+  stream_id_ = stream_id;
 
   auto status = GpuDriver::Init();
   if (!status.ok()) {
@@ -141,8 +142,8 @@ tsl::Status GpuExecutor::Init(int device_ordinal,
     return status;
   }
 
-  status = GpuDriver::CreateContext(device_ordinal_, device_, device_options,
-                                    &context_);
+  status = GpuDriver::CreateContext(device_ordinal_, stream_id_, device_,
+                                    device_options, &context_);
   if (!status.ok()) {
     return status;
   }
18 changes: 11 additions & 7 deletions tensorflow/compiler/xla/stream_executor/cuda/cuda_platform.cc
@@ -140,18 +140,21 @@ CudaPlatform::DescriptionForDevice(int ordinal) const {
   return GpuExecutor::CreateDeviceDescription(ordinal);
 }
 
-tsl::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDevice(int ordinal) {
+tsl::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDevice(int ordinal,
+                                                               int stream_id) {
   StreamExecutorConfig config;
   config.ordinal = ordinal;
+  config.stream_id = stream_id;
   config.plugin_config = PluginConfig();
   config.device_options = GetDeviceOptionsFromEnv();
   return GetExecutor(config);
 }
 
 tsl::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDeviceWithPluginConfig(
-    int device_ordinal, const PluginConfig& plugin_config) {
+    int device_ordinal, const PluginConfig& plugin_config, int stream_id) {
   StreamExecutorConfig config;
   config.ordinal = device_ordinal;
+  config.stream_id = stream_id;
   config.plugin_config = plugin_config;
   config.device_options = GetDeviceOptionsFromEnv();
   return GetExecutor(config);
@@ -172,15 +175,16 @@ tsl::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
 tsl::StatusOr<std::unique_ptr<StreamExecutor>>
 CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
   auto executor = std::make_unique<StreamExecutor>(
-      this, std::make_unique<GpuExecutor>(config.plugin_config),
-      config.ordinal);
+      this, std::make_unique<GpuExecutor>(config.plugin_config), config.ordinal,
+      config.stream_id);
   auto init_status = executor->Init(config.device_options);
   if (!init_status.ok()) {
     return tsl::Status(
         absl::StatusCode::kInternal,
-        absl::StrFormat(
-            "failed initializing StreamExecutor for CUDA device ordinal %d: %s",
-            config.ordinal, init_status.ToString()));
+        absl::StrFormat("failed initializing StreamExecutor for CUDA device "
                        "ordinal %d stream group %d: %s",
+                        config.ordinal, config.stream_id,
+                        init_status.ToString()));
  }
 
  return std::move(executor);
12 changes: 10 additions & 2 deletions tensorflow/compiler/xla/stream_executor/cuda/cuda_platform.h
@@ -67,10 +67,18 @@ class CudaPlatform : public Platform {
   tsl::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
       int ordinal) const override;
 
-  tsl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+  tsl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override {
+    return ExecutorForDevice(ordinal, 0);
+  }
+  tsl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal,
+                                                   int stream_id) override;
 
   tsl::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
-      int ordinal, const PluginConfig& config) override;
+      int ordinal, const PluginConfig& config) override {
+    return ExecutorForDeviceWithPluginConfig(ordinal, config, 0);
+  }
+  tsl::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
+      int ordinal, const PluginConfig& config, int stream_id) override;
 
   tsl::StatusOr<StreamExecutor*> GetExecutor(
       const StreamExecutorConfig& config) override;
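The old single-argument virtuals survive as inline shims that forward stream_id 0, so existing call sites compile and behave as before while stream-aware callers use the new overloads. A minimal sketch of the pattern (toy class, not the StreamExecutor API):

#include <iostream>

// The old virtual becomes an inline forwarder, so legacy callers implicitly
// get stream group 0 while new callers pass an explicit stream_id.
class Platform {
 public:
  virtual ~Platform() = default;
  virtual int ExecutorForDevice(int ordinal) = 0;
  virtual int ExecutorForDevice(int ordinal, int stream_id) = 0;
};

class CudaLikePlatform : public Platform {
 public:
  int ExecutorForDevice(int ordinal) override {
    return ExecutorForDevice(ordinal, /*stream_id=*/0);  // legacy path
  }
  int ExecutorForDevice(int ordinal, int stream_id) override {
    return ordinal * 100 + stream_id;  // stand-in for the executor lookup
  }
};

int main() {
  CudaLikePlatform p;
  std::cout << p.ExecutorForDevice(1) << "\n";     // 100: stream group 0
  std::cout << p.ExecutorForDevice(1, 3) << "\n";  // 103: stream group 3
}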
(The remaining changed files are not shown in this view.)