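I recently needed Elasticsearch (ES) for a project, and had to modify the scoring model that ships with ES.
Goal
Add a script that modifies _score. For example: search a set of articles for a given keyword, pass in weights, and rank the results by title length and content length.
In practice
First, a Chinese analyzer is needed for word segmentation: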
https://www.elastic.co/guide/en/elasticsearch/plugins/current/analysis-smartcn.html
Then create the index with smartcn as the default analyzer:
PUT test
{
"settings": {
"index": {
"analysis": {
"analyzer": {
"default": {
"type": "smartcn"
}
}
}
}
}
}
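Index some sample documents; each one has a title field and a content field.
Writing a custom scoring script
Assume that title length, content length, and the built-in score each relate linearly to the final result we want. Each factor then gets its own weight: CW (content length), TW (title length), and SW (built-in score).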
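Written out, the linear model assumed here looks like the following, where len(.) is the string length in characters and _score is the score ES computes on its own. The name "final score" is only a label for this note; the three weights are the ones passed in as script parameters.

```latex
\text{score}_{\text{final}}
  = CW \cdot \operatorname{len}(\text{content})
  + TW \cdot \operatorname{len}(\text{title})
  + SW \cdot \text{\_score}
```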
The scoring script (Python would be nicer, but the native script below is Java):
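import org.elasticsearch.common.Nullable;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.ScriptPlugin;
import org.elasticsearch.script.AbstractDoubleSearchScript;
import org.elasticsearch.script.ExecutableScript;
import org.elasticsearch.script.NativeScriptFactory;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class TestScript extends Plugin implements ScriptPlugin {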
@Override
public List<NativeScriptFactory> getNativeScripts() {
NativeScriptFactory nativeScriptFactory = new LengthNativeScriptFactory();
return Collections.singletonList(nativeScriptFactory);
}
public static class LengthNativeScriptFactory implements NativeScriptFactory {
@Override
public ExecutableScript newScript(@Nullable Map<String, Object> params) {
return new MyNativeScript(params);
}
@Override
public boolean needsScores() {
return true;
}
@Override
public String getName() {
return "length_score";
}
}
public static class MyNativeScript extends AbstractDoubleSearchScript {
protected Map<String, Object> params;
public MyNativeScript(Map<String, Object> params) {
this.params = params;
}
@Override
public double runAsDouble() {
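// Treat a missing field as length 0 rather than throwing a NullPointerException.
Object content = source().get("content");
Object title = source().get("title");
long contentLength = content == null ? 0 : content.toString().length();
long titleLength = title == null ? 0 : title.toString().length();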
Float esScore = null;
try {
esScore = score();
} catch (IOException e) {
esScore = 0f;
}
Double contentWeight = (Double) params.get("CW");
Double titleWeight = (Double) params.get("TW");
Double scoreWeight = (Double) params.get("SW");
double contentScore = computeScore(contentWeight, contentLength);
double titleScore = computeScore(titleWeight, titleLength);
double score = computeScore(scoreWeight, esScore);
return contentScore + titleScore + score;
}
private double computeScore(Double weight, long x) {
if (weight == null) {
return 0;
}
return weight * x;
}
private double computeScore(Double weight, Float x) {
if (weight == null || x == null) {
return 0;
}
return weight * x;
}
}
}
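Now run the searches. The first request below uses only the built-in score (CW and TW are 0); the second weights all three factors equally.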
{
"query" : {
"function_score": {
"query": {
"match": {
"content": "分布 语言"
}
},
"functions": [
{
"script_score": {
"script": {
"params": {
"CW": 0.0,
"TW": 0.0,
"SW": 1.0
},
"inline": "length_score",
"lang" : "native"
}
}
}
]
}
}
}
{
"query" : {
"function_score": {
"query": {
"match": {
"content": "分布 语言"
}
},
"functions": [
{
"script_score": {
"script": {
"params": {
"CW": 1.0,
"TW": 1.0,
"SW": 1.0
},
"inline": "length_score",
"lang" : "native"
}
}
}
]
}
}
}
Check the returned results: they match what we expected.
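The effect of the weights in the two requests above can also be checked outside Elasticsearch. The sketch below is a plain-Java illustration, not part of the plugin: the Doc class, the sample lengths, and the sample scores are made up for the example. Only the linear formula is taken from the native script.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Standalone illustration; it has no Elasticsearch dependency.
class LinearScoreDemo {

    // Hypothetical stand-in for one search hit.
    static final class Doc {
        final String id;
        final int titleLength;
        final int contentLength;
        final float esScore;

        Doc(String id, int titleLength, int contentLength, float esScore) {
            this.id = id;
            this.titleLength = titleLength;
            this.contentLength = contentLength;
            this.esScore = esScore;
        }
    }

    // Same linear combination as runAsDouble() in the native script.
    static double score(Doc d, double cw, double tw, double sw) {
        return cw * d.contentLength + tw * d.titleLength + sw * d.esScore;
    }

    // Returns document ids ordered by descending combined score.
    static List<String> rank(List<Doc> docs, double cw, double tw, double sw) {
        List<Doc> sorted = new ArrayList<>(docs);
        sorted.sort(Comparator.comparingDouble((Doc d) -> score(d, cw, tw, sw)).reversed());
        List<String> ids = new ArrayList<>();
        for (Doc d : sorted) {
            ids.add(d.id);
        }
        return ids;
    }

    // Made-up sample data: a short, highly relevant hit and a long, less relevant one.
    static List<Doc> sampleDocs() {
        List<Doc> docs = new ArrayList<>();
        docs.add(new Doc("short", 5, 40, 2.5f));
        docs.add(new Doc("long", 12, 900, 1.0f));
        return docs;
    }

    public static void main(String[] args) {
        List<Doc> docs = sampleDocs();
        System.out.println("SW only:     " + rank(docs, 0.0, 0.0, 1.0));
        System.out.println("all weights: " + rank(docs, 1.0, 1.0, 1.0));
    }
}
```

With these made-up numbers, the SW-only weights put the short document first, while setting every weight to 1.0 lets the raw character counts swamp the much smaller _score, so the long document wins. In real use the weights would probably need scaling to keep the three terms comparable.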
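Optimization and improvement
In production this kind of search became slow once the index reached about 100,000 documents: a single search took 500 to 600 ms. We had to improve it, so below we define our own similarity to replace BM25, the ES default.
By this point the business requirement had also changed: ranking was no longer based on length.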
VSMPlugin.java
import org.elasticsearch.index.IndexModule;
import org.elasticsearch.plugins.Plugin;
public class VSMPlugin extends Plugin {
@Override
public void onIndexModule(IndexModule indexModule) {
indexModule.addSimilarity("vsm", VSMSimilarityProvider::new);
}
}
VSMSimilarity.java
import org.apache.lucene.index.*;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class VSMSimilarity extends Similarity {
public VSMSimilarity() {}
@Override
public long computeNorm(FieldInvertState fieldInvertState) {
// Norms are not used, so return a constant.
return 1;
}
@Override
public SimWeight computeWeight(CollectionStatistics collectionStats, TermStatistics... termStats) {
return new VSMStats(collectionStats.field(), termStats);
}
@Override
public final SimScorer simScorer(SimWeight weight, LeafReaderContext context) throws IOException {
VSMStats stats = (VSMStats) weight;
return new VSMScorer(stats, context);
}
class VSMStats extends SimWeight {
private final String field;
private final TermStatistics[] termStats;
private float totalBoost;
public VSMStats(String field, TermStatistics[] termStats) {
this.field = field;
this.termStats = termStats;
}
@Override
public float getValueForNormalization() {
return 1f;
}
@Override
public void normalize(float v, float v1) {
this.totalBoost = v * v1;
}
}
final class VSMScorer extends SimScorer {
private final VSMStats stats;
private final LeafReaderContext context;
private final List<Explanation> explanations = new ArrayList<>();
public VSMScorer(VSMStats positionStats, LeafReaderContext context) throws IOException {
this.stats = positionStats;
this.context = context;
}
@Override
public float score(int doc, float freq) {
return freq;
}
@Override
public float computeSlopFactor(int i) {
return 1f / (i + 1);
}
@Override
public float computePayloadFactor(int i, int i1, int i2, BytesRef bytesRef) {
return 1;
}
@Override
public Explanation explain(int doc, Explanation freq) {
return Explanation.match(
score(doc, freq.getValue()),
"score(doc=" + doc + ", freq=" + freq.getValue() + "), sum of:",
explanations
);
}
}
}
VSMSimilarityProvider.java
import org.apache.lucene.search.similarities.Similarity;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.inject.assistedinject.Assisted;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.similarity.AbstractSimilarityProvider;
import org.elasticsearch.index.similarity.SimilarityProvider;
public class VSMSimilarityProvider extends AbstractSimilarityProvider {
private final VSMSimilarity similarity = new VSMSimilarity();
protected VSMSimilarityProvider(String name) {
super(name);
}
@Inject
public VSMSimilarityProvider(@Assisted String name, @Assisted Settings settings) {
super(name);
}
@Override
public String name() {
return "vsm";
}
@Override
public Similarity get() {
return similarity;
}
}
Configure the index
PUT all_question_index_test
{
"settings": {
"index": {
"similarity": {
"default": {
"type": "vsm"
}
}
},
"analysis": {
"analyzer": {
"default": {
"type": "whitespace"
}
}
}
}
}
Search
GET all_question_index_test/question/_search
{
"query": {
"function_score": {
"query": {
"multi_match": {
"query" : "用 公 式 法 解 方 程 x 2 - 2 x - 3 0",
"fields" : ["content"]
}
},
"functions": [
{
"script_score": {
"script" : {
"inline" : "_score/doc['docmod'].value",
"lang" : "painless"
}
}
}
]
}
}
}
This is a very simple similarity. A docmod field is used here to store the modulus (norm) of each document.
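To make the scoring concrete: VSMSimilarity returns the raw term frequency for each matching term, and the painless script divides the summed result by docmod. The sketch below reproduces that arithmetic in plain Java under two assumptions that the post does not spell out: docmod is taken to be the Euclidean norm of the document's term-frequency vector, and every query token is assumed to contribute its frequency in the document. The class and method names are hypothetical, and the sample texts are invented.

```java
import java.util.HashMap;
import java.util.Map;

// Standalone illustration of "_score / docmod"; it has no Elasticsearch dependency.
class VsmScoreDemo {

    // Term frequencies of a whitespace-tokenized text.
    static Map<String, Integer> termFreqs(String text) {
        Map<String, Integer> tf = new HashMap<>();
        for (String token : text.trim().split("\\s+")) {
            if (!token.isEmpty()) {
                tf.merge(token, 1, Integer::sum);
            }
        }
        return tf;
    }

    // Assumed definition of docmod: Euclidean norm of the term-frequency vector.
    static double docmod(String docText) {
        double sumSquares = 0;
        for (int f : termFreqs(docText).values()) {
            sumSquares += (double) f * f;
        }
        return Math.sqrt(sumSquares);
    }

    // Sum of the document's term frequency for each query token, divided by docmod.
    static double score(String query, String docText) {
        Map<String, Integer> docTf = termFreqs(docText);
        double sum = 0;
        for (String token : query.trim().split("\\s+")) {
            sum += docTf.getOrDefault(token, 0);
        }
        double mod = docmod(docText);
        return mod == 0 ? 0 : sum / mod;
    }

    public static void main(String[] args) {
        String query = "solve equation x";
        // A compact document versus one that repeats "x" and adds unrelated terms.
        System.out.println(score(query, "solve equation x 2"));
        System.out.println(score(query, "solve equation x x x x x x 2 3 4 5"));
    }
}
```

Under these assumptions the score is the dot product of the query and document term-frequency vectors divided by the document norm, which is cosine similarity up to a constant factor that depends only on the query. The second sample document has the larger raw frequency sum, but the division by docmod ranks it below the compact one.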
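With this setup a single search takes only about 50 ms, and for our workload the accuracy rose from 0.7 (with the default BM25) to 0.95.
The downside is that we have to define several extra fields ourselves to make this work, so the indexing requests and the search requests are now tightly coupled.
After the tokenization was modified as well, further tuning brought the accuracy to 0.99.
That concludes my first round of research into ES.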