Processing math: 100%

学习ES

修改ES自带评分模型算法

Posted by Jie on March 13, 2017

最近做项目需要用到ES。这里需要修改ES自带评分模型算法。

目的

添加脚本来修改 _score。例如,从几篇文章中搜索包含特定关键字的文章,并且根据传入的权重,按 title 长度、内容长度进行排序。

实践

首先肯定需要用到中文分词

https://www.elastic.co/guide/en/elasticsearch/plugins/current/analysis-smartcn.html

然后

PUT  test
{
  "settings": {
    "index": {
      "analysis": {
        "analyzer": {
          "default": {
            "type": "smartcn"
          }
        }
      }
    }
  }
}

添加语料,文本有title和content字段。

定制自己的模型算法脚本

这里假设 title 长度、content 长度和自带评分与我们想要的最终结果成线性关系。那么给其中每一个因素分配一个权重:CW(content 长度)、TW(title 长度)、SW(自带评分)。

算法脚本(Java)

/**
 * Elasticsearch plugin that registers the native script {@code length_score}.
 *
 * <p>The script scores a document as a linear combination of three factors:
 * {@code CW * content.length() + TW * title.length() + SW * _score}, where the
 * weights {@code CW}, {@code TW} and {@code SW} are read from the script params.
 */
public class TestScript extends Plugin implements ScriptPlugin {

    /**
     * Exposes the single native script factory provided by this plugin.
     *
     * @return a one-element list containing the {@code length_score} factory
     */
    @Override
    public List<NativeScriptFactory> getNativeScripts() {
        NativeScriptFactory nativeScriptFactory = new LengthNativeScriptFactory();
        return Collections.singletonList(nativeScriptFactory);
    }

    /** Factory that creates a {@link MyNativeScript} per request. */
    public static class LengthNativeScriptFactory implements NativeScriptFactory {
        @Override
        public ExecutableScript newScript(@Nullable Map<String, Object> params) {
            return new MyNativeScript(params);
        }

        /** The script multiplies the query score by {@code SW}, so scores are required. */
        @Override
        public boolean needsScores() {
            return true;
        }

        @Override
        public String getName() {
            return "length_score";
        }
    }

    /** Script computing the weighted sum of content length, title length and query score. */
    public static class MyNativeScript extends AbstractDoubleSearchScript {
        /** Script parameters as supplied by the request; may be {@code null}. */
        protected Map<String, Object> params;

        /**
         * @param params script parameters ({@code CW}, {@code TW}, {@code SW}); may be {@code null},
         *               in which case every weight is treated as absent
         */
        public MyNativeScript(Map<String, Object> params) {
            this.params = params;
        }

        /**
         * Computes {@code CW * |content| + TW * |title| + SW * _score}.
         *
         * <p>A missing weight contributes 0; a missing source field has length 0.
         *
         * @return the combined score
         * @throws IllegalArgumentException if a weight is present but is not a number
         */
        @Override
        public double runAsDouble() {
            long contentLength = fieldLength("content");
            long titleLength = fieldLength("title");

            Float esScore;
            try {
                esScore = score();
            } catch (IOException ignored) {
                // Best effort: if the query score cannot be read, score the document
                // from the length terms only.
                esScore = 0f;
            }

            double contentScore = computeScore(weight("CW"), contentLength);
            double titleScore = computeScore(weight("TW"), titleLength);
            double score = computeScore(weight("SW"), esScore);
            return contentScore + titleScore + score;
        }

        /** Returns the string length of a {@code _source} field, or 0 when the field is absent. */
        private long fieldLength(String field) {
            Object value = source().get(field);
            return value == null ? 0L : value.toString().length();
        }

        /**
         * Reads a weight from the params. Any {@link Number} is accepted because a JSON
         * value such as {@code 1} arrives as an Integer rather than a Double.
         *
         * @return the weight, or {@code null} when params or the key is absent
         */
        private Double weight(String key) {
            if (params == null) {
                return null;
            }
            Object raw = params.get(key);
            if (raw == null) {
                return null;
            }
            if (raw instanceof Number) {
                return ((Number) raw).doubleValue();
            }
            throw new IllegalArgumentException(
                    "script param [" + key + "] must be a number but was [" + raw + "]");
        }

        /** Returns {@code weight * x}, or 0 when the weight is absent. */
        private double computeScore(Double weight, long x) {
            if (weight == null) {
                return 0;
            }
            return weight * x;
        }

        /** Returns {@code weight * x}, or 0 when either operand is absent. */
        private double computeScore(Double weight, Float x) {
            if (weight == null || x == null) {
                return 0;
            }
            return weight * x;
        }
    }

}

开始搜索

{
    "query" : {
        "function_score": {
      "query": {
        "match": {
          "content": "分布 语言"
        }
      },
      "functions": [
        {
          "script_score": {
            "script": {
                "params": {
            "CW": 0.0,
            "TW": 0.0,
            "SW": 1.0
        },
                "inline": "length_score",
                "lang" : "native"
            }
          }
        }
      ]
    }
    }
}
{
    "query" : {
        "function_score": {
      "query": {
        "match": {
          "content": "分布 语言"
        }
      },
      "functions": [
        {
          "script_score": {
            "script": {
                "params": {
            "CW": 1.0,
            "TW": 1.0,
            "SW": 1.0
        },
                "inline": "length_score",
                "lang" : "native"
            }
          }
        }
      ]
    }
    }
}

查看返回结果。 嗯,是预期的结果。

优化改进

我们业务上这样搜索,当索引达到 10 万条的时候就变得很慢,一次搜索需要 500~600ms。只能改进,下面定制我们自己的搜索策略,替换 ES 默认的相似度算法 BM25。

这里业务需求变了,不是长度排序了。

VSMPlugin.java

import org.elasticsearch.index.IndexModule;
import org.elasticsearch.plugins.Plugin;

public class VSMPlugin extends Plugin {
    public void onIndexModule(IndexModule indexModule) {
        indexModule.addSimilarity("vsm", VSMSimilarityProvider::new);
    }
}

VSMSimilarity.java

import org.apache.lucene.index.*;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Minimal Lucene {@link Similarity} whose score for a term is simply its raw
 * frequency in the document. Norms, collection statistics and query boost do
 * not influence the score.
 */
public class VSMSimilarity extends Similarity {
    public VSMSimilarity() {}

    /**
     * Norms are not used by this similarity, so a constant is stored for every field.
     *
     * @return always 1
     */
    @Override
    public long computeNorm(FieldInvertState fieldInvertState) {
        return 1;
    }

    /** Captures the field name and term statistics; no weight is actually derived from them. */
    @Override
    public SimWeight computeWeight(CollectionStatistics collectionStats, TermStatistics... termStats) {
        return new VSMStats(collectionStats.field(), termStats);
    }

    /**
     * Creates the per-segment scorer.
     *
     * @param weight  the {@link VSMStats} previously returned by {@link #computeWeight}
     * @param context the segment the scorer will run against
     */
    @Override
    public final SimScorer simScorer(SimWeight weight, LeafReaderContext context) throws IOException {
        VSMStats stats = (VSMStats) weight;

        return new VSMScorer(stats, context);
    }

    /**
     * Query-level state. Declared static because it never touches the enclosing
     * {@link VSMSimilarity} instance.
     */
    static class VSMStats extends SimWeight {
        private final String field;
        private final TermStatistics[] termStats;
        // Recorded by normalize(); score() currently ignores it.
        private float totalBoost;

        public VSMStats(String field, TermStatistics[] termStats) {
            this.field = field;
            this.termStats = termStats;
        }

        @Override
        public float getValueForNormalization() {
            return 1f;
        }

        /** Stores the product of the two normalization factors. */
        @Override
        public void normalize(float v, float v1) {
            this.totalBoost = v * v1;
        }
    }

    /**
     * Per-segment scorer returning the raw term frequency. Declared static because
     * it never touches the enclosing {@link VSMSimilarity} instance.
     */
    static final class VSMScorer extends SimScorer {
        private final VSMStats stats;
        private final LeafReaderContext context;

        public VSMScorer(VSMStats positionStats, LeafReaderContext context) throws IOException {
            this.stats = positionStats;
            this.context = context;
        }

        /** @return {@code freq} unchanged */
        @Override
        public float score(int doc, float freq)  {
            return freq;
        }

        /** @return {@code 1 / (distance + 1)} */
        @Override
        public float computeSlopFactor(int i) {
            return 1f / (i + 1);
        }

        /** Payloads are ignored. */
        @Override
        public float computePayloadFactor(int i, int i1, int i2, BytesRef bytesRef) {
            return 1;
        }

        /**
         * Explains the score for {@code doc}. There are no sub-explanations because
         * the score is the frequency itself.
         */
        @Override
        public Explanation explain(int doc, Explanation freq) {
            return Explanation.match(
                    score(doc, freq.getValue()),
                    "score(doc=" + doc + ", freq=" + freq.getValue() + "), sum of:"
            );
        }

    }
}

VSMSimilarityProvider.java

import org.apache.lucene.search.similarities.Similarity;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.inject.assistedinject.Assisted;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.similarity.AbstractSimilarityProvider;
import org.elasticsearch.index.similarity.SimilarityProvider;

/**
 * Similarity provider that hands out a single shared {@link VSMSimilarity} instance.
 */
public class VSMSimilarityProvider extends AbstractSimilarityProvider {

    private final VSMSimilarity similarity = new VSMSimilarity();

    /**
     * @param name the name this similarity is configured under in the index settings
     */
    protected VSMSimilarityProvider(String name) {
        super(name);
    }

    /**
     * @param name     the name this similarity is configured under in the index settings
     * @param settings the similarity's settings; currently unused because the
     *                 similarity has no tunable parameters
     */
    @Inject
    public VSMSimilarityProvider(@Assisted String name, @Assisted Settings settings) {
        super(name);
    }

    /**
     * Returns the name this provider was created with, not the type name
     * {@code "vsm"}, so that a similarity configured as e.g. {@code default} or
     * {@code my_sim} reports its own name.
     */
    @Override
    public String name() {
        return super.name();
    }

    /** @return the shared {@link VSMSimilarity} instance */
    @Override
    public Similarity get() {
        return similarity;
    }
}

对索引进行配置

PUT all_question_index_test
{
  "settings": {
    "index": {
      "similarity": {
        "default": {
          "type": "vsm"
        }
      }
    },
    "analysis": {
      "analyzer": "whitespace"
    }
  }
}

搜索

GET all_question_index_test/question/_search
{
    "query": {
         "function_score": {
             "query": {
        "multi_match": {
            "query" : "用 公 式 法 解 方 程 x 2 - 2 x - 3 0",
            "fields" : ["content"]
        }}, "functions": [
        {
          "script_score": {
            "script" : {
                        "inline" : "_score/doc['docmod'].value",
                        "lang"   : "painless"
            }
          }
        }
      ]
      }
    }
}

一个很简单的相似度,这里使用了docmod字段来存储文档的模。

这样搜索一次只需要50ms,对于我们的业务,业务精度也从0.7(默认的BM25)上升到0.95。

不过这样我们需要自己定义几个字段才能满足需求,建索引的请求和搜索请求耦合很大了。

再对切词也做了修改,继续优化能达到0.99的精度。

第一次ES研究到此结束。