summaryrefslogtreecommitdiff
path: root/core/src/main/java/org/elasticsearch/action/termvectors/TermVectorsFilter.java
blob: e6904ee5ede07b50d061b9e048783c394a414fb0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.action.termvectors;

import org.apache.lucene.index.Fields;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.DefaultSimilarity;
import org.apache.lucene.search.similarities.TFIDFSimilarity;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.search.dfs.AggregatedDfs;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * Selects, per field, the terms of a term vector with the highest tf-idf score.
 * <p>
 * Terms are first filtered by word length, term frequency and document frequency
 * (see the {@code min*}/{@code max*} settings), then scored as
 * {@code termFreq * idf(docFreq, docCount)}. For every field at most
 * {@link #getMaxNumTerms()} of the best scoring terms are retained and can be looked
 * up afterwards with {@link #getScoreTerm(Term)}.
 * <p>
 * Call {@link #selectBestTerms()} once after configuring the filter; the lookup
 * methods only return results gathered by that call.
 */
public class TermVectorsFilter {
    public static final int DEFAULT_MAX_QUERY_TERMS = 25;
    public static final int DEFAULT_MIN_TERM_FREQ = 0;
    public static final int DEFAULT_MAX_TERM_FREQ = Integer.MAX_VALUE;
    public static final int DEFAULT_MIN_DOC_FREQ = 0;
    public static final int DEFAULT_MAX_DOC_FREQ = Integer.MAX_VALUE;
    public static final int DEFAULT_MIN_WORD_LENGTH = 0;
    // 0 disables the upper bound on the word length (see isNoise)
    public static final int DEFAULT_MAX_WORD_LENGTH = 0;

    private int maxNumTerms = DEFAULT_MAX_QUERY_TERMS;
    private int minTermFreq = DEFAULT_MIN_TERM_FREQ;
    private int maxTermFreq = DEFAULT_MAX_TERM_FREQ;
    private int minDocFreq = DEFAULT_MIN_DOC_FREQ;
    private int maxDocFreq = DEFAULT_MAX_DOC_FREQ;
    private int minWordLength = DEFAULT_MIN_WORD_LENGTH;
    private int maxWordLength = DEFAULT_MAX_WORD_LENGTH;

    private final Fields fields;
    private final Fields topLevelFields;
    private final Set<String> selectedFields;
    private final AggregatedDfs dfs;
    private final Map<Term, ScoreTerm> scoreTerms;
    private final Map<String, Integer> sizes = new HashMap<>();
    private final TFIDFSimilarity similarity;

    /**
     * Creates a filter over the given term vectors.
     *
     * @param termVectorsByField the term vectors whose terms are filtered and scored
     * @param topLevelFields     the fields used to look up document frequencies when no
     *                           {@code dfs} is given
     * @param selectedFields     the fields to process, or {@code null} to process every
     *                           field of {@code termVectorsByField}
     * @param dfs                pre-computed statistics to use instead of the ones from
     *                           {@code topLevelFields}, may be {@code null}
     */
    public TermVectorsFilter(Fields termVectorsByField, Fields topLevelFields, Set<String> selectedFields, @Nullable AggregatedDfs dfs) {
        this.fields = termVectorsByField;
        this.topLevelFields = topLevelFields;
        this.selectedFields = selectedFields;

        this.dfs = dfs;
        this.scoreTerms = new HashMap<>();
        this.similarity = new DefaultSimilarity();
    }

    /**
     * Applies every setting of the request that is set; settings that are
     * {@code null} keep their current value.
     */
    public void setSettings(TermVectorsRequest.FilterSettings settings) {
        if (settings.maxNumTerms != null) {
            setMaxNumTerms(settings.maxNumTerms);
        }
        if (settings.minTermFreq != null) {
            setMinTermFreq(settings.minTermFreq);
        }
        if (settings.maxTermFreq != null) {
            setMaxTermFreq(settings.maxTermFreq);
        }
        if (settings.minDocFreq != null) {
            setMinDocFreq(settings.minDocFreq);
        }
        if (settings.maxDocFreq != null) {
            setMaxDocFreq(settings.maxDocFreq);
        }
        if (settings.minWordLength != null) {
            setMinWordLength(settings.minWordLength);
        }
        if (settings.maxWordLength != null) {
            setMaxWordLength(settings.maxWordLength);
        }
    }

    /**
     * Returns the score entry retained for the given term, or {@code null} if the
     * term was not retained by {@link #selectBestTerms()}.
     */
    public ScoreTerm getScoreTerm(Term term) {
        return scoreTerms.get(term);
    }

    /**
     * Returns {@code true} if the given term was retained by {@link #selectBestTerms()}.
     */
    public boolean hasScoreTerm(Term term) {
        return getScoreTerm(term) != null;
    }

    /**
     * Returns the number of terms retained for the given field, or {@code 0} if the
     * field has not been processed by {@link #selectBestTerms()}.
     */
    public long size(String fieldName) {
        // avoid unboxing a null Integer for fields that were skipped or never seen
        Integer count = sizes.get(fieldName);
        return count == null ? 0 : count;
    }

    public int getMaxNumTerms() {
        return maxNumTerms;
    }

    public int getMinTermFreq() {
        return minTermFreq;
    }

    public int getMaxTermFreq() {
        return maxTermFreq;
    }

    public int getMinDocFreq() {
        return minDocFreq;
    }

    public int getMaxDocFreq() {
        return maxDocFreq;
    }

    public int getMinWordLength() {
        return minWordLength;
    }

    public int getMaxWordLength() {
        return maxWordLength;
    }

    public void setMaxNumTerms(int maxNumTerms) {
        this.maxNumTerms = maxNumTerms;
    }

    public void setMinTermFreq(int minTermFreq) {
        this.minTermFreq = minTermFreq;
    }

    public void setMaxTermFreq(int maxTermFreq) {
        this.maxTermFreq = maxTermFreq;
    }

    public void setMinDocFreq(int minDocFreq) {
        this.minDocFreq = minDocFreq;
    }

    public void setMaxDocFreq(int maxDocFreq) {
        this.maxDocFreq = maxDocFreq;
    }

    public void setMinWordLength(int minWordLength) {
        this.minWordLength = minWordLength;
    }

    public void setMaxWordLength(int maxWordLength) {
        this.maxWordLength = maxWordLength;
    }

    /**
     * A term of a field together with its score. Instances held by the selection
     * queue are overwritten in place through {@link #update(String, String, float)}.
     */
    public static final class ScoreTerm {
        public String field;
        public String word;
        public float score;

        ScoreTerm(String field, String word, float score) {
            this.field = field;
            this.word = word;
            this.score = score;
        }

        void update(String field, String word, float score) {
            this.field = field;
            this.word = word;
            this.score = score;
        }
    }

    /**
     * Filters and scores the terms of every selected field and retains, per field,
     * the best scoring ones. The retained terms are available through
     * {@link #getScoreTerm(Term)} and their count through {@link #size(String)}.
     *
     * @throws IOException if reading the terms or their postings fails
     */
    public void selectBestTerms() throws IOException {
        // handed back to postings() on every term so that it can be reused
        PostingsEnum docsEnum = null;

        for (String fieldName : fields) {
            if ((selectedFields != null) && (!selectedFields.contains(fieldName))) {
                continue;
            }

            Terms terms = fields.terms(fieldName);
            Terms topLevelTerms = topLevelFields.terms(fieldName);

            // if no terms found, take the retrieved term vector fields for stats
            if (topLevelTerms == null) {
                topLevelTerms = terms;
            }

            long numDocs = getDocCount(fieldName, topLevelTerms);

            // one queue per field name
            ScoreTermsQueue queue = new ScoreTermsQueue(queueCapacity(terms.size()));

            // select terms with highest tf-idf
            TermsEnum termsEnum = terms.iterator();
            TermsEnum topLevelTermsEnum = topLevelTerms.iterator();
            while (termsEnum.next() != null) {
                BytesRef termBytesRef = termsEnum.term();
                boolean foundTerm = topLevelTermsEnum.seekExact(termBytesRef);
                assert foundTerm : "Term: " + termBytesRef.utf8ToString() + " not found!";

                Term term = new Term(fieldName, termBytesRef);

                // remove noise words
                docsEnum = termsEnum.postings(docsEnum);
                int freq = getTermFreq(docsEnum);
                // decode the term text once, it is needed for filtering and for the result
                String word = term.bytes().utf8ToString();
                if (isNoise(word, freq)) {
                    continue;
                }

                // now call on docFreq
                long docFreq = getTermStatistics(topLevelTermsEnum, term).docFreq();
                if (!isAccepted(docFreq)) {
                    continue;
                }

                // filter based on score
                float score = computeScore(docFreq, freq, numDocs);
                queue.addOrUpdate(new ScoreTerm(term.field(), word, score));
            }

            // retain the best terms for quick lookups
            ScoreTerm scoreTerm;
            int count = 0;
            while ((scoreTerm = queue.pop()) != null) {
                scoreTerms.put(new Term(scoreTerm.field, scoreTerm.word), scoreTerm);
                count++;
            }
            sizes.put(fieldName, count);
        }
    }

    /**
     * Computes how many terms to retain for a field: the smaller of
     * {@code maxNumTerms} and the number of terms of the field, never negative.
     *
     * @param termCount the number of terms as reported by {@code Terms.size()}; a
     *                  negative value means the count is unknown and is ignored
     */
    private int queueCapacity(long termCount) {
        int capacity = Math.max(0, maxNumTerms);
        // compare as long so that a large term count cannot overflow the int cast
        if (termCount >= 0 && termCount < capacity) {
            capacity = (int) termCount;
        }
        return capacity;
    }

    /**
     * Returns {@code true} if the word must be dropped because of its length or its
     * frequency in the source. Lower bounds and the maximum word length only apply
     * when they are greater than zero.
     */
    private boolean isNoise(String word, int freq) {
        // filter out words based on length
        int len = word.length();
        if (minWordLength > 0 && len < minWordLength) {
            return true;
        }
        if (maxWordLength > 0 && len > maxWordLength) {
            return true;
        }
        // filter out words that don't occur enough times in the source
        if (minTermFreq > 0 && freq < minTermFreq) {
            return true;
        }
        // filter out words that occur too many times in the source
        if (freq > maxTermFreq) {
            return true;
        }
        return false;
    }

    /**
     * Returns {@code true} if a term with the given document frequency passes the
     * document frequency bounds. A document frequency of zero is always rejected.
     */
    private boolean isAccepted(long docFreq) {
        // filter out words that don't occur in enough docs
        if (minDocFreq > 0 && docFreq < minDocFreq) {
            return false;
        }
        // filter out words that occur in too many docs
        if (docFreq > maxDocFreq) {
            return false;
        }
        // index update problem?
        if (docFreq == 0) {
            return false;
        }
        return true;
    }

    /**
     * Returns the document count of the field, taken from {@code dfs} when it is
     * set and from the given top level terms otherwise.
     */
    private long getDocCount(String fieldName, Terms topLevelTerms) throws IOException {
        if (dfs != null) {
            return dfs.fieldStatistics().get(fieldName).docCount();
        }
        return topLevelTerms.getDocCount();
    }

    /**
     * Returns the statistics of the term, taken from {@code dfs} when it is set and
     * from the current position of the given terms enum otherwise.
     */
    private TermStatistics getTermStatistics(TermsEnum termsEnum, Term term) throws IOException {
        if (dfs != null) {
            return dfs.termStatistics().get(term);
        }
        return new TermStatistics(termsEnum.term(), termsEnum.docFreq(), termsEnum.totalTermFreq());
    }

    /**
     * Advances the given postings to their first document and returns the term
     * frequency reported for it.
     */
    private int getTermFreq(PostingsEnum docsEnum) throws IOException {
        docsEnum.nextDoc();
        return docsEnum.freq();
    }

    /**
     * Scores a term as its frequency multiplied by the idf of the similarity.
     */
    private float computeScore(long docFreq, int freq, long numDocs) {
        return freq * similarity.idf(docFreq, numDocs);
    }

    /**
     * Bounded queue ordered by ascending score, so that the top of the queue is the
     * lowest scoring term currently retained.
     */
    private static class ScoreTermsQueue extends org.apache.lucene.util.PriorityQueue<ScoreTerm> {
        private final int limit;

        ScoreTermsQueue(int maxSize) {
            // a negative size is treated as an empty queue
            super(Math.max(0, maxSize));
            this.limit = Math.max(0, maxSize);
        }

        @Override
        protected boolean lessThan(ScoreTerm a, ScoreTerm b) {
            return a.score < b.score;
        }

        /**
         * Adds the term while there is room; once the queue is full the lowest
         * scoring entry is overwritten if the new term scores higher. Does nothing
         * when the queue has no capacity.
         */
        public void addOrUpdate(ScoreTerm scoreTerm) {
            if (limit <= 0) {
                // nothing can be retained; top() would be null below
                return;
            }
            if (this.size() < limit) {
                // there is still space in the queue
                this.add(scoreTerm);
            } else {
                // otherwise update the smallest in the queue in place and update the queue
                ScoreTerm scoreTermTop = this.top();
                if (scoreTermTop.score < scoreTerm.score) {
                    scoreTermTop.update(scoreTerm.field, scoreTerm.word, scoreTerm.score);
                    this.updateTop();
                }
            }
        }
    }
}