@@ -201,16 +201,20 @@ class ExecutorTest final {
201201public:
202202 // / Constructor.
203203 ExecutorTest (const std::shared_ptr<Executor> &executor,
204- std::unique_ptr<DAGNode> root, std::unique_ptr<Type> type,
205- DAGNodeNameMapTy nodes, PlaceholderNameMapTy placeholders,
204+ std::unique_ptr<DAGNode> root, std::unique_ptr<Module> module ,
205+ std::unique_ptr<Type> type, DAGNodeNameMapTy nodes,
206+ PlaceholderNameMapTy placeholders,
206207 std::unique_ptr<ExecutionContext> inputContext,
207208 std::unique_ptr<ExecutionContext> outputContext,
208209 RunIdentifierTy runId, bool expectSuccess)
209- : executor_(executor), root_(std::move(root)), type_(std::move(type)),
210- nodes_ (std::move(nodes)), placeholders_(std::move(placeholders)),
210+ : executor_(executor), root_(std::move(root)), module_(std::move(module )),
211+ type_ (std::move(type)), nodes_(std::move(nodes)),
212+ placeholders_(std::move(placeholders)),
211213 inputContext_(std::move(inputContext)),
212214 outputContext_(std::move(outputContext)), runId_(runId),
213- expectSuccess_(expectSuccess), testRun_(false ) {}
215+ expectSuccess_(expectSuccess), testRun_(false ) {
216+ root_->module = module_.get ();
217+ }
214218
215219 // / Run the test.
216220 bool run () {
@@ -260,6 +264,8 @@ class ExecutorTest final {
260264 std::shared_ptr<Executor> executor_;
261265 // / The root node of the DAG being tested.
262266 std::unique_ptr<DAGNode> root_;
267+ // / The Module containing the PHs.
268+ std::unique_ptr<Module> module_;
263269 // / The Type for all of the Placeholders that will be used during execution.
264270 std::unique_ptr<Type> type_;
265271 // / All nodes in the DAG.
@@ -295,7 +301,8 @@ class ExecutorTestBuilder final {
295301 // / between ExecutionContexts correctly.
296302 ExecutorTestBuilder (const std::shared_ptr<Executor> &executor,
297303 const DeviceManagerMapTy &deviceManagers)
298- : executor_(executor), root_(llvm::make_unique<DAGNode>()),
304+ : executor_(executor), module_(llvm::make_unique<Module>()),
305+ root_ (llvm::make_unique<DAGNode>()),
299306 bindings_(llvm::make_unique<PlaceholderBindings>()),
300307 type_(
301308 std::unique_ptr<Type>(new Type(ElemKind::FloatTy, {32 , 64 , 128 }))),
@@ -465,16 +472,16 @@ class ExecutorTestBuilder final {
465472 insertSymbolIntoPlaceholderBindings (
466473 symbol, outputContext->getPlaceholderBindings ());
467474 }
468-
469475 // Create the test object.
470- ExecutorTest test (executor_, std::move (root_), std::move (type_ ),
471- std::move (nodes_ ), std::move (placeholders_ ),
472- std::move (inputContext ), std::move (outputContext), runId_ ,
473- success_);
476+ ExecutorTest test (executor_, std::move (root_), std::move (module_ ),
477+ std::move (type_ ), std::move (nodes_ ),
478+ std::move (placeholders_ ), std::move (inputContext) ,
479+ std::move (outputContext), runId_, success_);
474480
475481 // Reset builder state to allow a new test to be constructed with this
476482 // instance.
477483 root_ = llvm::make_unique<DAGNode>();
484+ module_ = llvm::make_unique<Module>();
478485 bindings_->clear ();
479486 type_ = std::unique_ptr<Type>(new Type (ElemKind::FloatTy, {1 , 2 , 2 }));
480487 nodes_.clear ();
@@ -537,28 +544,27 @@ class ExecutorTestBuilder final {
537544 // / mapped for the test being created, reuse the existing value.
538545 void insertSymbolIntoPlaceholderBindings (llvm::StringRef name,
539546 PlaceholderBindings *bindings) {
540- auto it = placeholders_. find (name);
547+ auto ph = module_-> getPlaceholderByName (name);
541548
542- if (it == placeholders_. end () ) {
549+ if (!ph ) {
543550 // This is a new symbol. Create a Placeholder and an initialize and new
544551 // Tensor for it.
545- auto placeholder = llvm::make_unique<Placeholder>(name, type_.get (),
546- /* trainable=*/ false );
547- auto *tensor = bindings_->allocate (placeholder.get ());
552+ auto placeholder = module_->createPlaceholder (type_.get (), name, false );
553+ auto *tensor = bindings_->allocate (placeholder);
548554 tensor->init (Tensor::InitKind::Xavier, 1.0 , rng_);
549- bindings->insert (placeholder.get (), tensor->clone ());
550- placeholders_[name] = std::move (placeholder);
555+ bindings->insert (placeholder, tensor->clone ());
551556 } else {
552557 // This is a symbol that already has an associated Placeholder and Tensor.
553558 // Copy that Tensor.
554- auto *placeholder = (it->second ).get ();
555- const auto *tensor = bindings_->get (placeholder);
556- bindings->insert (placeholder, tensor->clone ());
559+ const auto *tensor = bindings_->get (ph);
560+ bindings->insert (ph, tensor->clone ());
557561 }
558562 }
559563
560564 // / The Executor being tested.
561565 std::shared_ptr<Executor> executor_;
566+ // / Module for holding PHs
567+ std::unique_ptr<Module> module_;
562568 // / The root of the DAG being constructed.
563569 std::unique_ptr<DAGNode> root_;
564570 // / This PlaceholderBindings holds all created and initialized Placeholders
0 commit comments