Skip to content

Commit

Permalink
minScore and boost_mode_replace support in MultiFunctionScoreQuery (#578
Browse files Browse the repository at this point in the history
)

augment MultiFunctionScoreQuery with replace_boost_mode and minScore
  • Loading branch information
waziqi89 authored Jun 1, 2023
1 parent 7372b57 commit bc2c69e
Show file tree
Hide file tree
Showing 7 changed files with 1,266 additions and 905 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ sourceCompatibility = 1.14
targetCompatibility = 1.14

allprojects {
version = '0.23.0'
version = '0.24.0'
group = 'com.yelp.nrtsearch'
}

Expand Down
6 changes: 6 additions & 0 deletions clientlib/src/main/proto/yelp/nrtsearch/search.proto
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ message MultiFunctionScoreQuery {
BOOST_MODE_MULTIPLY = 0;
// Add scores together
BOOST_MODE_SUM = 1;
// Ignore the query score, and use the function score only
BOOST_MODE_REPLACE = 2;
}

// Main query to produce recalled docs and scores, which will be modified by the final function score
Expand All @@ -326,6 +328,10 @@ message MultiFunctionScoreQuery {
FunctionScoreMode score_mode = 3;
// Method to modify query document scores with final function score
BoostMode boost_mode = 4;
// Optional minimal score to match a document. By default, it's 0.
float min_score = 5;
// Determine minimal score is excluded or not. By default, it's false;
bool min_excluded = 6;
}

// Query that produces a score of 1.0 (modifiable by query boost value) for documents that match the filter query.
Expand Down
6 changes: 6 additions & 0 deletions docs/queries/multi_function_score.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ Proto definition:
BOOST_MODE_MULTIPLY = 0;
// Add scores together
BOOST_MODE_SUM = 1;
// Ignore the query score, and use the function score only
BOOST_MODE_REPLACE = 2;
}
// Main query to produce recalled docs and scores, which will be modified by the final function score
Expand All @@ -49,4 +51,8 @@ Proto definition:
FunctionScoreMode score_mode = 3;
// Method to modify query document scores with final function score
BoostMode boost_mode = 4;
// Optional minimal score to match a document. By default, it's 0.
float min_score = 5;
// Determine minimal score is excluded or not. By default, it's false;
bool min_excluded = 6;
}
14 changes: 12 additions & 2 deletions grpc-gateway/luceneserver.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -1763,10 +1763,11 @@
"type": "string",
"enum": [
"BOOST_MODE_MULTIPLY",
"BOOST_MODE_SUM"
"BOOST_MODE_SUM",
"BOOST_MODE_REPLACE"
],
"default": "BOOST_MODE_MULTIPLY",
"description": "- BOOST_MODE_MULTIPLY: Multiply scores together\n - BOOST_MODE_SUM: Add scores together",
"description": "- BOOST_MODE_MULTIPLY: Multiply scores together\n - BOOST_MODE_SUM: Add scores together\n - BOOST_MODE_REPLACE: Ignore the query score, and use the function score only",
"title": "How to combine final function score with query score"
},
"MultiFunctionScoreQueryFilterFunction": {
Expand Down Expand Up @@ -3934,6 +3935,15 @@
"boost_mode": {
"$ref": "#/definitions/MultiFunctionScoreQueryBoostMode",
"title": "Method to modify query document scores with final function score"
},
"min_score": {
"type": "number",
"format": "float",
"description": "Optional minimal score to match a document. By default, it's 0."
},
"min_excluded": {
"type": "boolean",
"title": "Determine minimal score is excluded or not. By default, it's false;"
}
},
"title": "A query to modify the score of documents with a given set of functions"
Expand Down
1,790 changes: 909 additions & 881 deletions grpc-gateway/search.pb.go

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilterScorer;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;

Expand All @@ -53,6 +55,18 @@ public class MultiFunctionScoreQuery extends Query {
private final FilterFunction[] functions;
private final FunctionScoreMode scoreMode;
private final BoostMode boostMode;
private final float minScore;
private final boolean minExcluded;

private static boolean isFilteredByMinScore(
float currentScore, float minimalScore, boolean minimalExcluded) {
if (currentScore > minimalScore) {
return true;
} else if (!minimalExcluded && currentScore == minimalScore) {
return true;
}
return false;
}

/**
* Builder method that creates a {@link MultiFunctionScoreQuery} from its gRPC message definiton.
Expand All @@ -75,7 +89,9 @@ public static MultiFunctionScoreQuery build(
innerQuery,
functions,
multiFunctionScoreQueryGrpc.getScoreMode(),
multiFunctionScoreQueryGrpc.getBoostMode());
multiFunctionScoreQueryGrpc.getBoostMode(),
multiFunctionScoreQueryGrpc.getMinScore(),
multiFunctionScoreQueryGrpc.getMinExcluded());
}

/**
Expand All @@ -85,16 +101,27 @@ public static MultiFunctionScoreQuery build(
* @param functions functions used to produce function score for documents
* @param scoreMode mode to combine function scores
* @param boostMode mode to combine function and document scores
* @param minScore min score to match
* @param minExcluded is min score excluded
*/
public MultiFunctionScoreQuery(
Query innerQuery,
FilterFunction[] functions,
FunctionScoreMode scoreMode,
BoostMode boostMode) {
BoostMode boostMode,
float minScore,
boolean minExcluded) {
this.innerQuery = innerQuery;
this.functions = functions;
this.scoreMode = scoreMode;
this.boostMode = boostMode;
this.minScore = minScore;
this.minExcluded = minExcluded;

if (minScore < 0) {
throw new IllegalArgumentException(
"minScore must be a non-negative number, but got " + minScore);
}
}

@Override
Expand All @@ -111,7 +138,8 @@ public Query rewrite(IndexReader reader) throws IOException {
needsRewrite |= (rewrittenFunctions[i] != functions[i]);
}
if (needsRewrite) {
return new MultiFunctionScoreQuery(rewrittenInner, rewrittenFunctions, scoreMode, boostMode);
return new MultiFunctionScoreQuery(
rewrittenInner, rewrittenFunctions, scoreMode, boostMode, minScore, minExcluded);
} else {
return this;
}
Expand All @@ -121,7 +149,9 @@ public Query rewrite(IndexReader reader) throws IOException {
public Weight createWeight(
IndexSearcher searcher, org.apache.lucene.search.ScoreMode scoreMode, float boost)
throws IOException {
if (scoreMode == ScoreMode.COMPLETE_NO_SCORES) {
if (scoreMode == ScoreMode.COMPLETE_NO_SCORES && !isMinScoreWrapperUsed()) {
// Even if the outer query doesn't require score, inner score is needed if the MinScoreWrapper
// is used for filtering
return innerQuery.createWeight(searcher, scoreMode, boost);
}
Weight[] filterWeights = new Weight[functions.length];
Expand All @@ -134,7 +164,13 @@ public Weight createWeight(
1.0f);
}
}
Weight innerWeight = innerQuery.createWeight(searcher, ScoreMode.COMPLETE, boost);
Weight innerWeight =
innerQuery.createWeight(
searcher,
boostMode == BoostMode.BOOST_MODE_REPLACE
? ScoreMode.COMPLETE_NO_SCORES
: ScoreMode.COMPLETE,
boost);
return new MultiFunctionWeight(this, innerWeight, filterWeights);
}

Expand Down Expand Up @@ -211,6 +247,17 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
}
expl = explainBoost(expl, factorExplanation);
}
float curScore = expl.getValue().floatValue();
if (isFilteredByMinScore(curScore, minScore, minExcluded)) {
expl =
Explanation.noMatch(
"Score value is too low, expected at least "
+ minScore
+ (minExcluded ? " (excluded)" : " (included)")
+ " but got "
+ curScore,
expl);
}
return expl;
}

Expand All @@ -228,6 +275,12 @@ private Explanation explainBoost(Explanation queryExpl, Explanation funcExpl) {
"sum of",
queryExpl,
funcExpl);
case BOOST_MODE_REPLACE:
return Explanation.match(
funcExpl.getValue().floatValue(),
"Ignoring query score, function score of",
queryExpl,
funcExpl);
default:
throw new IllegalStateException("Unknown boost mode type: " + boostMode);
}
Expand All @@ -252,13 +305,107 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
docSets[i] = new Bits.MatchAllBits(context.reader().maxDoc());
}
}
return new MultiFunctionScorer(
innerScorer, this, scoreMode, boostMode, leafFunctions, docSets);

Scorer scorer =
new MultiFunctionScorer(innerScorer, this, scoreMode, boostMode, leafFunctions, docSets);
if (isMinScoreWrapperUsed()) {
scorer = new MinScoreWrapper(scorer.getWeight(), scorer, minScore, minExcluded);
}
return scorer;
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
// When not using MinScoreWrapper, it is cacheable.
return !isMinScoreWrapperUsed();
}
}

/*
* minScoreWrapper is used under either condition:
* 1. minScore is set to be a positive number (no matter min is excluded or not)
* 2. minScore is zero, but zero is excluded
* * */
private boolean isMinScoreWrapperUsed() {
return minScore > 0 || minExcluded;
}

/**
* A port with minimal modification of Elasticsearch <a
* href="https://github.com/elastic/elasticsearch/blob/v7.2.0/server/src/main/java/org/elasticsearch/common/lucene/search/function/MinScoreScorer.java">MinScoreScorer</a>.
* We add minExcluded to make the boundary clear for inclusion/exclusion.
*/
public static class MinScoreWrapper extends Scorer {
private final Scorer in;
private final float minScore;
private float curScore;
private final boolean minExcluded;

public MinScoreWrapper(Weight weight, Scorer in, float minScore, boolean minExcluded) {
super(weight);
this.in = in;
this.minScore = minScore;
this.minExcluded = minExcluded;
}

@Override
public TwoPhaseIterator twoPhaseIterator() {
TwoPhaseIterator inTwoPhase = in.twoPhaseIterator();
DocIdSetIterator approximation;
if (inTwoPhase == null) {
approximation = in.iterator();
if (TwoPhaseIterator.unwrap(approximation) != null) {
inTwoPhase = TwoPhaseIterator.unwrap(approximation);
approximation = inTwoPhase.approximation();
}
} else {
approximation = inTwoPhase.approximation();
}
final TwoPhaseIterator finalTwoPhase = inTwoPhase;
return new TwoPhaseIterator(approximation) {

@Override
public boolean matches() throws IOException {
if (finalTwoPhase != null && finalTwoPhase.matches() == false) {
return false;
}
// we need to check the two-phase iterator first
// otherwise calling score() is illegal
curScore = in.score();
return isFilteredByMinScore(curScore, minScore, minExcluded);
}

@Override
public float matchCost() {
return 1000f // random constant for the score computation
+ (finalTwoPhase == null ? 0 : finalTwoPhase.matchCost());
}
};
}

@Override
public DocIdSetIterator iterator() {
return TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator());
}

@Override
public float getMaxScore(int upTo) throws IOException {
return in.getMaxScore(upTo);
}

@Override
public float score() throws IOException {
return curScore;
}

@Override
public int advanceShallow(int target) throws IOException {
return in.advanceShallow(target);
}

@Override
public int docID() {
return in.docID();
}
}

Expand Down Expand Up @@ -340,6 +487,8 @@ private float computeFinalScore(float innerQueryScore, double functionScore) {
return (float) (innerQueryScore * functionScore);
case BOOST_MODE_SUM:
return (float) (innerQueryScore + functionScore);
case BOOST_MODE_REPLACE:
return (float) functionScore;
default:
throw new IllegalStateException("Unknown boost mode type: " + boostMode);
}
Expand All @@ -359,26 +508,27 @@ public String toString(String field) {
sb.append("{" + (function == null ? "" : function.toString()) + "}");
}
sb.append("])");
sb.append(", minScore: " + minScore).append(minExcluded ? " (excluded)" : " (included)");
return sb.toString();
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (!sameClassAs(obj)) {
return false;
}
MultiFunctionScoreQuery other = (MultiFunctionScoreQuery) obj;
return Objects.equals(this.innerQuery, other.innerQuery)
&& Objects.equals(this.scoreMode, other.scoreMode)
&& Objects.equals(this.boostMode, other.boostMode)
&& Arrays.equals(this.functions, other.functions);
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof MultiFunctionScoreQuery)) return false;
MultiFunctionScoreQuery that = (MultiFunctionScoreQuery) o;
return Float.compare(that.minScore, minScore) == 0
&& minExcluded == that.minExcluded
&& Objects.equals(innerQuery, that.innerQuery)
&& Arrays.equals(functions, that.functions)
&& scoreMode == that.scoreMode
&& boostMode == that.boostMode;
}

@Override
public int hashCode() {
return Objects.hash(classHash(), innerQuery, scoreMode, boostMode, Arrays.hashCode(functions));
int result = Objects.hash(innerQuery, scoreMode, boostMode, minScore, minExcluded);
result = 31 * result + Arrays.hashCode(functions);
return result;
}
}
Loading

0 comments on commit bc2c69e

Please sign in to comment.