Skip to content

Commit 046e113

Browse files
committed
Removed PH creation from Executor
1 parent 0d0131c commit 046e113

File tree

5 files changed

+44
-49
lines changed

5 files changed

+44
-49
lines changed

include/glow/Runtime/RuntimeTypes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ struct DAGNode {
8787
/// runtime.
8888
std::unique_ptr<RuntimeBundle> runtimeBundle;
8989

90+
/// Pointer to module the function came from. This is so the executor can
91+
/// access the associated PHs for the function that are stored in the Module.
92+
Module *module{nullptr};
93+
9094
DeviceIDTy getNextDevice() {
9195
currentDeviceIdx++;
9296
return deviceIDs[currentDeviceIdx % deviceIDs.size()];

lib/Partitioner/Partitioner.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ void Partitioner::doPartitioning(Function *F, NodeToFunctionMap &mapping) {
460460
nodesDAGNodeTy nodes;
461461
DAGRoot->logicalDevices = {0};
462462
DAGRoot->name = F->getName();
463+
DAGRoot->module = module_;
463464
DAGRoot->deviceIDs = {0};
464465
DAGNode *root = DAGRoot.get();
465466
llvm::DenseMap<Node *, Node *> currToNew;
@@ -580,6 +581,7 @@ llvm::Error Partitioner::Partition() {
580581
std::unique_ptr<DAGNode> DAG0 = llvm::make_unique<DAGNode>();
581582
DAG0->logicalDevices = {0};
582583
DAG0->name = F->getName();
584+
DAG0->module = module_;
583585
std::unique_ptr<DAGNode> DAG1 = llvm::make_unique<DAGNode>();
584586
DAG1->logicalDevices = {0};
585587
DAG1->name = F->getName();

lib/Runtime/Executor/ThreadPoolExecutor.cpp

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include <queue>
2323
#include <unordered_set>
2424

25+
#include <glog/logging.h>
26+
2527
namespace glow {
2628
namespace runtime {
2729

@@ -59,7 +61,7 @@ ExecutionState::ExecutionState(RunIdentifierTy id, const DAGNode *root,
5961
std::unique_ptr<ExecutionContext> resultContext,
6062
ResultCBTy doneCb)
6163
: runId_(id), cb_(doneCb), resultCtx_(std::move(resultContext)),
62-
inflightNodes_(0) {
64+
inflightNodes_(0), module_(root->module) {
6365
// Create a queue for the breadth-first traversal through the graph.
6466
std::queue<const DAGNode *> bfsQueue;
6567

@@ -101,8 +103,10 @@ ExecutionState::ExecutionState(RunIdentifierTy id, const DAGNode *root,
101103
const auto &symbolInfo = symbolPair.second;
102104

103105
if (symbolInfo.symbolCategory == SymbolCategory::Placeholder) {
104-
nodeInputPhBindings->allocate(
105-
createOrGetPlaceholder(symbolName, &symbolInfo.type));
106+
auto PH = module_->getPlaceholderByName(symbolName);
107+
// If a PH name is provided it had to come from the Module originally.
108+
DCHECK(PH) << "Placeholder: " << symbolName << " is not in the module";
109+
nodeInputPhBindings->allocate(PH);
106110
}
107111
}
108112

@@ -241,27 +245,6 @@ ExecutionContext *ExecutionState::getRawResultContextPtr() const {
241245
return resultCtx_.get();
242246
}
243247

244-
Placeholder *ExecutionState::createOrGetPlaceholder(llvm::StringRef name,
245-
TypeRef type) {
246-
auto it = intermediatePlaceholders_.find(name);
247-
Placeholder *ph;
248-
249-
if (it != intermediatePlaceholders_.end()) {
250-
// If the Placeholder already exists, return a pointer to it.
251-
auto &storedPh = it->second;
252-
ph = storedPh.get();
253-
} else {
254-
// If the Placeholder does not exist, create one, remember it, and return a
255-
// pointer to it.
256-
auto newPh =
257-
llvm::make_unique<Placeholder>(name, type, /*isTrainable=*/false);
258-
ph = newPh.get();
259-
intermediatePlaceholders_.insert(std::make_pair(name, std::move(newPh)));
260-
}
261-
262-
return ph;
263-
}
264-
265248
void ThreadPoolExecutor::shutdown() {
266249
// Prevent more requests from being processed.
267250
shuttingDown_ = true;

lib/Runtime/Executor/ThreadPoolExecutor.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,6 @@ class ExecutionState final {
111111
RunIdentifierTy getRunId() const { return runId_; }
112112

113113
private:
114-
/// Create a Placeholder with name \p name and type \p type and store it in
115-
/// intermediatePlaceholders_. If a Placeholder already exists, return that.
116-
Placeholder *createOrGetPlaceholder(llvm::StringRef name, TypeRef type);
117-
118114
/// The run identifier for this execution of a DAG.
119115
RunIdentifierTy runId_;
120116
/// The callback that should be called when execution is done.
@@ -141,6 +137,10 @@ class ExecutionState final {
141137
/// Mutex used by bindings insertion functions to make sure only one thread
142138
/// writes to an ExecutionContext at a time.
143139
std::mutex bindingsMtx_;
140+
141+
/// Module for the network. This contains the PHs used by the functions in
142+
/// this network.
143+
Module *module_{nullptr};
144144
};
145145

146146
/// This implementation of the Executor interface uses a thread pool to

tests/unittests/ExecutorTest.cpp

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -201,16 +201,20 @@ class ExecutorTest final {
201201
public:
202202
/// Constructor.
203203
ExecutorTest(const std::shared_ptr<Executor> &executor,
204-
std::unique_ptr<DAGNode> root, std::unique_ptr<Type> type,
205-
DAGNodeNameMapTy nodes, PlaceholderNameMapTy placeholders,
204+
std::unique_ptr<DAGNode> root, std::unique_ptr<Module> module,
205+
std::unique_ptr<Type> type, DAGNodeNameMapTy nodes,
206+
PlaceholderNameMapTy placeholders,
206207
std::unique_ptr<ExecutionContext> inputContext,
207208
std::unique_ptr<ExecutionContext> outputContext,
208209
RunIdentifierTy runId, bool expectSuccess)
209-
: executor_(executor), root_(std::move(root)), type_(std::move(type)),
210-
nodes_(std::move(nodes)), placeholders_(std::move(placeholders)),
210+
: executor_(executor), root_(std::move(root)), module_(std::move(module)),
211+
type_(std::move(type)), nodes_(std::move(nodes)),
212+
placeholders_(std::move(placeholders)),
211213
inputContext_(std::move(inputContext)),
212214
outputContext_(std::move(outputContext)), runId_(runId),
213-
expectSuccess_(expectSuccess), testRun_(false) {}
215+
expectSuccess_(expectSuccess), testRun_(false) {
216+
root_->module = module_.get();
217+
}
214218

215219
/// Run the test.
216220
bool run() {
@@ -260,6 +264,8 @@ class ExecutorTest final {
260264
std::shared_ptr<Executor> executor_;
261265
/// The root node of the DAG being tested.
262266
std::unique_ptr<DAGNode> root_;
267+
/// The Module containing the PHs.
268+
std::unique_ptr<Module> module_;
263269
/// The Type for all of the Placeholders that will be used during execution.
264270
std::unique_ptr<Type> type_;
265271
/// All nodes in the DAG.
@@ -295,7 +301,8 @@ class ExecutorTestBuilder final {
295301
/// between ExecutionContexts correctly.
296302
ExecutorTestBuilder(const std::shared_ptr<Executor> &executor,
297303
const DeviceManagerMapTy &deviceManagers)
298-
: executor_(executor), root_(llvm::make_unique<DAGNode>()),
304+
: executor_(executor), module_(llvm::make_unique<Module>()),
305+
root_(llvm::make_unique<DAGNode>()),
299306
bindings_(llvm::make_unique<PlaceholderBindings>()),
300307
type_(
301308
std::unique_ptr<Type>(new Type(ElemKind::FloatTy, {32, 64, 128}))),
@@ -465,16 +472,16 @@ class ExecutorTestBuilder final {
465472
insertSymbolIntoPlaceholderBindings(
466473
symbol, outputContext->getPlaceholderBindings());
467474
}
468-
469475
// Create the test object.
470-
ExecutorTest test(executor_, std::move(root_), std::move(type_),
471-
std::move(nodes_), std::move(placeholders_),
472-
std::move(inputContext), std::move(outputContext), runId_,
473-
success_);
476+
ExecutorTest test(executor_, std::move(root_), std::move(module_),
477+
std::move(type_), std::move(nodes_),
478+
std::move(placeholders_), std::move(inputContext),
479+
std::move(outputContext), runId_, success_);
474480

475481
// Reset builder state to allow a new test to be constructed with this
476482
// instance.
477483
root_ = llvm::make_unique<DAGNode>();
484+
module_ = llvm::make_unique<Module>();
478485
bindings_->clear();
479486
type_ = std::unique_ptr<Type>(new Type(ElemKind::FloatTy, {1, 2, 2}));
480487
nodes_.clear();
@@ -537,28 +544,27 @@ class ExecutorTestBuilder final {
537544
/// mapped for the test being created, reuse the existing value.
538545
void insertSymbolIntoPlaceholderBindings(llvm::StringRef name,
539546
PlaceholderBindings *bindings) {
540-
auto it = placeholders_.find(name);
547+
auto ph = module_->getPlaceholderByName(name);
541548

542-
if (it == placeholders_.end()) {
549+
if (!ph) {
543550
// This is a new symbol. Create a Placeholder and an initialize and new
544551
// Tensor for it.
545-
auto placeholder = llvm::make_unique<Placeholder>(name, type_.get(),
546-
/*trainable=*/false);
547-
auto *tensor = bindings_->allocate(placeholder.get());
552+
auto placeholder = module_->createPlaceholder(type_.get(), name, false);
553+
auto *tensor = bindings_->allocate(placeholder);
548554
tensor->init(Tensor::InitKind::Xavier, 1.0, rng_);
549-
bindings->insert(placeholder.get(), tensor->clone());
550-
placeholders_[name] = std::move(placeholder);
555+
bindings->insert(placeholder, tensor->clone());
551556
} else {
552557
// This is a symbol that already has an associated Placeholder and Tensor.
553558
// Copy that Tensor.
554-
auto *placeholder = (it->second).get();
555-
const auto *tensor = bindings_->get(placeholder);
556-
bindings->insert(placeholder, tensor->clone());
559+
const auto *tensor = bindings_->get(ph);
560+
bindings->insert(ph, tensor->clone());
557561
}
558562
}
559563

560564
/// The Executor being tested.
561565
std::shared_ptr<Executor> executor_;
566+
/// Module for holding PHs
567+
std::unique_ptr<Module> module_;
562568
/// The root of the DAG being constructed.
563569
std::unique_ptr<DAGNode> root_;
564570
/// This PlaceholderBindings holds all created and initialized Placeholders

0 commit comments

Comments
 (0)