Skip to content

Commit fa285b3

Browse files
authored
Add Nn::ScaledLabelScorer and use it in Nn::CombineLabelScorer (#171)
1 parent ce6b87e commit fa285b3

File tree

9 files changed

+182
-44
lines changed

9 files changed

+182
-44
lines changed

src/Nn/LabelScorer/CombineLabelScorer.cc

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,37 +21,34 @@ namespace Nn {
2121
Core::ParameterInt CombineLabelScorer::paramNumLabelScorers(
2222
"num-scorers", "Number of label scorers to combine", 1, 1);
2323

24-
Core::ParameterFloat CombineLabelScorer::paramScale(
25-
"scale", "Scores of a sub-label-scorer are scaled by this factor", 1.0);
26-
2724
CombineLabelScorer::CombineLabelScorer(Core::Configuration const& config)
2825
: Core::Component(config),
2926
Precursor(config, TransitionPresetType::ALL) {
3027
size_t numLabelScorers = paramNumLabelScorers(config);
3128
for (size_t i = 0ul; i < numLabelScorers; ++i) {
3229
Core::Configuration subConfig = select(std::string("scorer-") + std::to_string(i + 1));
33-
scaledScorers_.push_back({Nn::Module::instance().labelScorerFactory().createLabelScorer(subConfig), static_cast<Score>(paramScale(subConfig))});
30+
scorers_.push_back(Nn::Module::instance().labelScorerFactory().createLabelScorer(subConfig));
3431
}
3532
}
3633

3734
void CombineLabelScorer::reset() {
38-
for (auto& scaledScorer : scaledScorers_) {
39-
scaledScorer.scorer->reset();
35+
for (auto& scorer : scorers_) {
36+
scorer->reset();
4037
}
4138
}
4239

4340
void CombineLabelScorer::signalNoMoreFeatures() {
44-
for (auto& scaledScorer : scaledScorers_) {
45-
scaledScorer.scorer->signalNoMoreFeatures();
41+
for (auto& scorer : scorers_) {
42+
scorer->signalNoMoreFeatures();
4643
}
4744
}
4845

4946
ScoringContextRef CombineLabelScorer::getInitialScoringContext() {
5047
std::vector<ScoringContextRef> scoringContexts;
51-
scoringContexts.reserve(scaledScorers_.size());
48+
scoringContexts.reserve(scorers_.size());
5249

53-
for (const auto& scaledScorer : scaledScorers_) {
54-
scoringContexts.push_back(scaledScorer.scorer->getInitialScoringContext());
50+
for (const auto& scorer : scorers_) {
51+
scoringContexts.push_back(scorer->getInitialScoringContext());
5552
}
5653
return Core::ref(new CombineScoringContext(std::move(scoringContexts)));
5754
}
@@ -63,41 +60,41 @@ void CombineLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef>
6360
combineContexts.push_back(dynamic_cast<const CombineScoringContext*>(activeContext.get()));
6461
}
6562

66-
for (size_t scorerIdx = 0ul; scorerIdx < scaledScorers_.size(); ++scorerIdx) {
67-
auto const& scaledScorer = scaledScorers_[scorerIdx];
63+
for (size_t scorerIdx = 0ul; scorerIdx < scorers_.size(); ++scorerIdx) {
64+
auto const& scorer = scorers_[scorerIdx];
6865
Core::CollapsedVector<ScoringContextRef> subScoringContexts;
6966
for (auto const& combineContext : combineContexts) {
7067
subScoringContexts.push_back(combineContext->scoringContexts[scorerIdx]);
7168
}
7269

73-
scaledScorer.scorer->cleanupCaches(subScoringContexts);
70+
scorer->cleanupCaches(subScoringContexts);
7471
}
7572
}
7673

7774
void CombineLabelScorer::addInput(DataView const& input) {
78-
for (auto& scaledScorer : scaledScorers_) {
79-
scaledScorer.scorer->addInput(input);
75+
for (auto& scorer : scorers_) {
76+
scorer->addInput(input);
8077
}
8178
}
8279

8380
void CombineLabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
84-
for (auto& scaledScorer : scaledScorers_) {
85-
scaledScorer.scorer->addInputs(input, nTimesteps);
81+
for (auto& scorer : scorers_) {
82+
scorer->addInputs(input, nTimesteps);
8683
}
8784
}
8885

8986
ScoringContextRef CombineLabelScorer::extendedScoringContextInternal(Request const& request) {
9087
auto combineContext = dynamic_cast<const CombineScoringContext*>(request.context.get());
9188

9289
std::vector<ScoringContextRef> extScoringContexts;
93-
extScoringContexts.reserve(scaledScorers_.size());
90+
extScoringContexts.reserve(scorers_.size());
9491

95-
auto scorerIt = scaledScorers_.begin();
92+
auto scorerIt = scorers_.begin();
9693
auto contextIt = combineContext->scoringContexts.begin();
9794

98-
for (; scorerIt != scaledScorers_.end(); ++scorerIt, ++contextIt) {
95+
for (; scorerIt != scorers_.end(); ++scorerIt, ++contextIt) {
9996
Request subRequest{*contextIt, request.nextToken, request.transitionType};
100-
extScoringContexts.push_back(scorerIt->scorer->extendedScoringContext(subRequest));
97+
extScoringContexts.push_back((*scorerIt)->extendedScoringContext(subRequest));
10198
}
10299
return Core::ref(new CombineScoringContext(std::move(extScoringContexts)));
103100
}
@@ -109,22 +106,22 @@ std::optional<LabelScorer::ScoreWithTime> CombineLabelScorer::computeScoreWithTi
109106
auto combineContext = dynamic_cast<const CombineScoringContext*>(request.context.get());
110107

111108
// Iterate over all the scorers and accumulate their results into `accumResult`
112-
auto scorerIt = scaledScorers_.begin();
109+
auto scorerIt = scorers_.begin();
113110
auto contextIt = combineContext->scoringContexts.begin();
114-
for (; scorerIt != scaledScorers_.end(); ++scorerIt, ++contextIt) {
111+
for (; scorerIt != scorers_.end(); ++scorerIt, ++contextIt) {
115112
// Prepare sub-request for the current scorer by extracting the appropriate
116113
// ScoringContext from the combined ScoringContext
117114
Request subRequest{*contextIt, request.nextToken, request.transitionType};
118115

119116
// Run current scorer
120-
auto result = scorerIt->scorer->computeScoreWithTime(subRequest);
117+
auto result = (*scorerIt)->computeScoreWithTime(subRequest);
121118
if (!result) {
122119
return {};
123120
}
124121

125122
// Merge results of current scorer into `accumResult`
126123
// Scores are weighted sum, timeframes are maximum
127-
accumResult.score += result->score * scorerIt->scale;
124+
accumResult.score += result->score;
128125
accumResult.timeframe = std::max(accumResult.timeframe, result->timeframe);
129126
}
130127

@@ -147,7 +144,7 @@ std::optional<LabelScorer::ScoresWithTimes> CombineLabelScorer::computeScoresWit
147144
}
148145

149146
// Iterate over all the scorers and accumulate their results into `accumResult`
150-
for (size_t scorerIdx = 0ul; scorerIdx < scaledScorers_.size(); ++scorerIdx) {
147+
for (size_t scorerIdx = 0ul; scorerIdx < scorers_.size(); ++scorerIdx) {
151148
// Prepare sub-requests for the current scorer by extracting the appropriate
152149
// ScoringContext from all the CombineScoringContexts
153150
std::vector<Request> subRequests;
@@ -159,7 +156,7 @@ std::optional<LabelScorer::ScoresWithTimes> CombineLabelScorer::computeScoresWit
159156
}
160157

161158
// Run current scorer
162-
auto subResults = scaledScorers_[scorerIdx].scorer->computeScoresWithTimes(subRequests);
159+
auto subResults = scorers_[scorerIdx]->computeScoresWithTimes(subRequests);
163160
if (!subResults) {
164161
return {};
165162
}
@@ -168,7 +165,7 @@ std::optional<LabelScorer::ScoresWithTimes> CombineLabelScorer::computeScoresWit
168165
// Scores are weighted sum, timeframes are maximum
169166
Core::CollapsedVector<Speech::TimeframeIndex> newTimeframes;
170167
for (size_t requestIdx = 0ul; requestIdx < requests.size(); ++requestIdx) {
171-
accumResult.scores[requestIdx] += subResults->scores[requestIdx] * scaledScorers_[scorerIdx].scale;
168+
accumResult.scores[requestIdx] += subResults->scores[requestIdx];
172169
newTimeframes.push_back(std::max(accumResult.timeframes[requestIdx], subResults->timeframes[requestIdx]));
173170
}
174171
accumResult.timeframes = newTimeframes;

src/Nn/LabelScorer/CombineLabelScorer.hh

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ class CombineLabelScorer : public LabelScorer {
3131
using Precursor = LabelScorer;
3232

3333
public:
34-
static Core::ParameterInt paramNumLabelScorers;
35-
static Core::ParameterFloat paramScale;
34+
static Core::ParameterInt paramNumLabelScorers;
3635

3736
CombineLabelScorer(const Core::Configuration& config);
3837
virtual ~CombineLabelScorer() = default;
@@ -56,12 +55,7 @@ public:
5655
virtual void addInputs(DataView const& input, size_t nTimesteps) override;
5756

5857
protected:
59-
struct ScaledLabelScorer {
60-
Core::Ref<LabelScorer> scorer;
61-
Score scale;
62-
};
63-
64-
std::vector<ScaledLabelScorer> scaledScorers_;
58+
std::vector<Core::Ref<LabelScorer>> scorers_;
6559

6660
// Combine extended ScoringContexts from all sub-scorers
6761
ScoringContextRef extendedScoringContextInternal(Request const& request) override;

src/Nn/LabelScorer/LabelScorerFactory.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ void LabelScorerFactory::registerLabelScorer(const char* name, CreationFunction
2525
registry_.push_back(std::move(creationFunction));
2626
}
2727

28-
Core::Ref<LabelScorer> LabelScorerFactory::createLabelScorer(Core::Configuration const& config) const {
29-
return registry_.at(paramLabelScorerType(config))(config);
28+
Core::Ref<ScaledLabelScorer> LabelScorerFactory::createLabelScorer(Core::Configuration const& config) const {
29+
auto subScorer = registry_.at(paramLabelScorerType(config))(config);
30+
return Core::ref(new ScaledLabelScorer(config, subScorer));
3031
}
3132

3233
} // namespace Nn

src/Nn/LabelScorer/LabelScorerFactory.hh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <Core/ReferenceCounting.hh>
2424

2525
#include "LabelScorer.hh"
26+
#include "ScaledLabelScorer.hh"
2627

2728
namespace Nn {
2829

@@ -49,9 +50,9 @@ public:
4950
void registerLabelScorer(const char* name, CreationFunction creationFunction);
5051

5152
/*
52-
* Create a LabelScorer instance of the type given by `paramLabelScorerType` using the config object
53+
* Create a ScaledLabelScorer instance of the type given by `paramLabelScorerType` using the config object
5354
*/
54-
Core::Ref<LabelScorer> createLabelScorer(Core::Configuration const& config) const;
55+
Core::Ref<ScaledLabelScorer> createLabelScorer(Core::Configuration const& config) const;
5556

5657
private:
5758
typedef std::vector<CreationFunction> Registry;

src/Nn/LabelScorer/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ LIBSPRINTLABELSCORER_O = \
2121
$(OBJDIR)/FixedContextOnnxLabelScorer.o \
2222
$(OBJDIR)/NoContextOnnxLabelScorer.o \
2323
$(OBJDIR)/NoOpLabelScorer.o \
24+
$(OBJDIR)/ScaledLabelScorer.o \
2425
$(OBJDIR)/ScoringContext.o \
2526
$(OBJDIR)/StatefulOnnxLabelScorer.o \
2627
$(OBJDIR)/TransitionLabelScorer.o
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/** Copyright 2025 RWTH Aachen University. All rights reserved.
2+
*
3+
* Licensed under the RWTH ASR License (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
#include "ScaledLabelScorer.hh"
17+
18+
namespace Nn {
19+
20+
const Core::ParameterFloat ScaledLabelScorer::paramScale(
21+
"scale",
22+
"Scale to multiply the scores of the sub-scorer by.",
23+
1.0);
24+
25+
ScaledLabelScorer::ScaledLabelScorer(Core::Configuration const& config, Core::Ref<LabelScorer> const& scorer)
26+
: Core::Component(config),
27+
LabelScorer(config, TransitionPresetType::ALL),
28+
scale_(paramScale(config)) {
29+
}
30+
31+
void ScaledLabelScorer::reset() {
32+
scorer_->reset();
33+
}
34+
35+
void ScaledLabelScorer::signalNoMoreFeatures() {
36+
scorer_->signalNoMoreFeatures();
37+
}
38+
39+
ScoringContextRef ScaledLabelScorer::getInitialScoringContext() {
40+
return scorer_->getInitialScoringContext();
41+
}
42+
43+
void ScaledLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
44+
scorer_->cleanupCaches(activeContexts);
45+
}
46+
47+
void ScaledLabelScorer::addInput(DataView const& input) {
48+
scorer_->addInput(input);
49+
}
50+
51+
void ScaledLabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
52+
scorer_->addInputs(input, nTimesteps);
53+
}
54+
55+
ScoringContextRef ScaledLabelScorer::extendedScoringContextInternal(Request const& request) {
56+
return scorer_->extendedScoringContext(request);
57+
}
58+
59+
std::optional<LabelScorer::ScoreWithTime> ScaledLabelScorer::computeScoreWithTimeInternal(Request const& request) {
60+
auto result = scorer_->computeScoreWithTime(request);
61+
if (result and scale_ != 1) {
62+
result->score *= scale_;
63+
}
64+
return result;
65+
}
66+
67+
std::optional<LabelScorer::ScoresWithTimes> ScaledLabelScorer::computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) {
68+
auto result = scorer_->computeScoresWithTimes(requests);
69+
if (result and scale_ != 1) {
70+
for (auto& score : result->scores) {
71+
score *= scale_;
72+
}
73+
}
74+
return result;
75+
}
76+
77+
} // namespace Nn
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/** Copyright 2025 RWTH Aachen University. All rights reserved.
2+
*
3+
* Licensed under the RWTH ASR License (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
#ifndef SCALED_LABEL_SCORER_HH
17+
#define SCALED_LABEL_SCORER_HH
18+
19+
#include <Core/Configuration.hh>
20+
21+
#include "LabelScorer.hh"
22+
23+
namespace Nn {
24+
25+
/*
26+
* Wraps a sub label scorer and scales all the scores by a given factor
27+
*/
28+
class ScaledLabelScorer : public LabelScorer {
29+
public:
30+
static const Core::ParameterFloat paramScale;
31+
32+
ScaledLabelScorer(Core::Configuration const& config, Core::Ref<LabelScorer> const& scorer);
33+
34+
// Reset sub-scorer
35+
void reset() override;
36+
37+
// Forward signal to sub-scorer
38+
void signalNoMoreFeatures() override;
39+
40+
// Initial ScoringContext from sub-scorer
41+
ScoringContextRef getInitialScoringContext() override;
42+
43+
// Cleanup sub-scorer
44+
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;
45+
46+
// Add input to sub-scorer
47+
void addInput(DataView const& input) override;
48+
49+
// Add inputs to sub-scorer
50+
virtual void addInputs(DataView const& input, size_t nTimesteps) override;
51+
52+
protected:
53+
// Extended ScoringContext from sub-scorer
54+
ScoringContextRef extendedScoringContextInternal(Request const& request) override;
55+
56+
// Compute scaled score of request with sub-scorer
57+
std::optional<ScoreWithTime> computeScoreWithTimeInternal(Request const& request) override;
58+
59+
// Compute scaled scores of requests with sub-scorer
60+
std::optional<ScoresWithTimes> computeScoresWithTimesInternal(std::vector<Request> const& requests) override;
61+
62+
private:
63+
Core::Ref<LabelScorer> scorer_;
64+
Score scale_;
65+
};
66+
67+
} // namespace Nn
68+
69+
#endif // SCALED_LABEL_SCORER_HH

src/Nn/Module.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "LabelScorer/FixedContextOnnxLabelScorer.hh"
2424
#include "LabelScorer/NoContextOnnxLabelScorer.hh"
2525
#include "LabelScorer/NoOpLabelScorer.hh"
26+
#include "LabelScorer/ScaledLabelScorer.hh"
2627
#include "LabelScorer/StatefulOnnxLabelScorer.hh"
2728
#include "LabelScorer/TransitionLabelScorer.hh"
2829
#include "Statistics.hh"

src/Nn/Module.hh

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#include <Flow/Module.hh>
2424

2525
#include "LabelScorer/EncoderFactory.hh"
26-
#include "LabelScorer/LabelScorer.hh"
2726
#include "LabelScorer/LabelScorerFactory.hh"
2827

2928
namespace Core {
@@ -63,8 +62,6 @@ public:
6362
*/
6463
LabelScorerFactory& labelScorerFactory();
6564

66-
Core::Ref<LabelScorer> createLabelScorer(const Core::Configuration& config) const;
67-
6865
private:
6966
Core::FormatSet* formats_;
7067
EncoderFactory encoderFactory_;

0 commit comments

Comments
 (0)