Skip to content

Commit

Permalink
fix predictor interfaces.
Browse files Browse the repository at this point in the history
  • Loading branch information
Yusuke Oda committed Apr 21, 2017
1 parent e733504 commit 72a3f92
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 8 deletions.
7 changes: 5 additions & 2 deletions nmtkit/binary_code_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@ BinaryCodePredictor::BinaryCodePredictor(
, loss_type_(loss_type)
, converter_({input_size, ecc->getNumBits(bc->getNumBits())}, model) {}

void BinaryCodePredictor::prepare(dynet::ComputationGraph * cg) {
void BinaryCodePredictor::prepare(
dynet::ComputationGraph * cg,
const bool /* is_training */) {
converter_.prepare(cg);
}

DE::Expression BinaryCodePredictor::computeLoss(
const DE::Expression & input,
const vector<unsigned> & target_ids,
dynet::ComputationGraph * cg) {
dynet::ComputationGraph * cg,
const bool /* is_training */) {
const unsigned batch_size = target_ids.size();

// Retrieves target bits.
Expand Down
7 changes: 5 additions & 2 deletions nmtkit/binary_code_predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@ class BinaryCodePredictor : public Predictor {

~BinaryCodePredictor() override {}

void prepare(dynet::ComputationGraph * cg) override;
void prepare(
dynet::ComputationGraph * cg,
const bool is_training) override;

dynet::expr::Expression computeLoss(
const dynet::expr::Expression & input,
const std::vector<unsigned> & target_ids,
dynet::ComputationGraph * cg) override;
dynet::ComputationGraph * cg,
const bool is_training) override;

std::vector<Predictor::Result> predictKBest(
const dynet::expr::Expression & input,
Expand Down
7 changes: 5 additions & 2 deletions nmtkit/hybrid_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,17 @@ HybridPredictor::HybridPredictor(
, converter_(
{input_size, softmax_size + ecc->getNumBits(bc->getNumBits())}, model) {}

void HybridPredictor::prepare(dynet::ComputationGraph * cg) {
void HybridPredictor::prepare(
dynet::ComputationGraph * cg,
const bool /* is_training */) {
converter_.prepare(cg);
}

DE::Expression HybridPredictor::computeLoss(
const DE::Expression & input,
const vector<unsigned> & target_ids,
dynet::ComputationGraph * cg) {
dynet::ComputationGraph * cg,
const bool /* is_training */) {
const unsigned batch_size = target_ids.size();

// Calculates inner variables.
Expand Down
7 changes: 5 additions & 2 deletions nmtkit/hybrid_predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,15 @@ class HybridPredictor : public Predictor {

~HybridPredictor() override {}

void prepare(dynet::ComputationGraph * cg) override;
void prepare(
dynet::ComputationGraph * cg,
const bool is_training) override;

dynet::expr::Expression computeLoss(
const dynet::expr::Expression & input,
const std::vector<unsigned> & target_ids,
dynet::ComputationGraph * cg) override;
dynet::ComputationGraph * cg,
const bool is_training) override;

std::vector<Predictor::Result> predictKBest(
const dynet::expr::Expression & input,
Expand Down

0 comments on commit 72a3f92

Please sign in to comment.