Skip to content

Commit

Permalink
test dense and sparse vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
moshebla committed Sep 25, 2018
1 parent 76bc3da commit 5f036bd
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 64 deletions.
51 changes: 51 additions & 0 deletions src/java/com/github/saaay71/solr/VectorPayloadEncoder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.github.saaay71.solr;


import org.apache.commons.lang3.tuple.Pair;
import org.apache.lucene.analysis.payloads.AbstractEncoder;
import org.apache.lucene.analysis.payloads.PayloadEncoder;
import org.apache.lucene.analysis.payloads.PayloadHelper;
import org.apache.lucene.util.BytesRef;

import java.nio.CharBuffer;

/**
 * Payload encoder that turns the textual payload of a single vector element into
 * a fixed-size binary payload of {@code Integer.BYTES + Float.BYTES} bytes.
 *
 * <p>Two textual forms are accepted:
 * <ul>
 *   <li><b>sparse</b> - {@code "<index>,<value>"}, e.g. {@code "3,2.3"}. The index is
 *       written to the payload decremented by one.</li>
 *   <li><b>dense</b> - {@code "<value>"}, e.g. {@code "2.3"}. {@link #DENSE_VECTOR_PREFIX}
 *       is written in place of the index.</li>
 * </ul>
 *
 * <p>Payload layout: an int (index or dense marker) at byte 0, followed by the float
 * value at byte {@code Integer.BYTES}.
 *
 * <p>Note: a sparse index written as {@code "0"} is stored as {@code -1}, which is
 * indistinguishable from {@link #DENSE_VECTOR_PREFIX} when decoded.
 */
public class VectorPayloadEncoder extends AbstractEncoder implements PayloadEncoder {
    private final static char delimiter = ',';
    private final static int SIZE = Float.BYTES + Integer.BYTES;
    public final static int DENSE_VECTOR_PREFIX = -1;

    public VectorPayloadEncoder() { }

    /**
     * Encodes the payload text found in {@code buffer[offset, offset + length)}.
     *
     * @param buffer characters containing the payload text
     * @param offset index of the first payload character in {@code buffer}
     * @param length number of payload characters
     * @return a {@link BytesRef} wrapping the int + float payload bytes
     * @throws NumberFormatException if the index or the value cannot be parsed
     */
    public BytesRef encode(char[] buffer, int offset, int length) {
        final int end = offset + length;

        // Locate the first delimiter inside the payload window; delimiterPos == end means "none".
        int delimiterPos = offset;
        while (delimiterPos < end && buffer[delimiterPos] != delimiter) {
            ++delimiterPos;
        }

        final byte[] bytes = new byte[SIZE];
        final boolean isSparse = delimiterPos < end;
        final float vectorElem;
        if (isSparse) {
            final String indexText = CharBuffer.wrap(buffer, offset, delimiterPos - offset).toString();
            final int sparseIndex = Integer.parseInt(indexText) - 1;
            PayloadHelper.encodeInt(sparseIndex, bytes, 0);

            // delimiterPos is already an absolute index into buffer, so the value starts right
            // after it. (Adding offset again is only correct when offset happens to be 1.)
            final int valueStart = delimiterPos + 1;
            vectorElem = Float.parseFloat(CharBuffer.wrap(buffer, valueStart, end - valueStart).toString());
        } else {
            PayloadHelper.encodeInt(DENSE_VECTOR_PREFIX, bytes, 0);
            vectorElem = Float.parseFloat(CharBuffer.wrap(buffer, offset, length).toString());
        }

        PayloadHelper.encodeFloat(vectorElem, bytes, Integer.BYTES);
        return new BytesRef(bytes);
    }

    /**
     * Decodes a payload that starts at the beginning of {@code buffer}.
     *
     * @param buffer bytes produced by {@link #encode(char[], int, int)}
     * @return pair of (stored index or {@link #DENSE_VECTOR_PREFIX}, element value)
     */
    public static Pair<Integer, Float> decode(byte[] buffer) {
        return decode(buffer, 0);
    }

    /**
     * Decodes a payload that starts at {@code offset} within {@code buffer}.
     *
     * @param buffer bytes containing a payload produced by {@link #encode(char[], int, int)}
     * @param offset index of the first payload byte
     * @return pair of (stored index or {@link #DENSE_VECTOR_PREFIX}, element value)
     */
    public static Pair<Integer, Float> decode(byte[] buffer, int offset) {
        final int vecIndex = PayloadHelper.decodeInt(buffer, offset);
        final float vecValue = PayloadHelper.decodeFloat(buffer, offset + Integer.BYTES);
        return Pair.of(vecIndex, vecValue);
    }
}
2 changes: 1 addition & 1 deletion src/java/com/github/saaay71/solr/VectorQParserPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public Query parse() throws SyntaxError {
q.setQueryString(localParams.toLocalParamsString());
query = q;
}


if (query == null) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Query is null");
Expand Down
10 changes: 6 additions & 4 deletions src/java/com/github/saaay71/solr/VectorScoreQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.lucene.analysis.payloads.PayloadHelper;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
Expand All @@ -26,12 +28,12 @@ public VectorScoreQuery(Query subQuery, String Vector, String field, boolean cos
super(subQuery);
this.field = field;
this.cosine = cosine;
this.vector = new ArrayList<Double>();
this.vector = new ArrayList<>();
String[] vectorArray = Vector.split(",");
for(int i=0;i<vectorArray.length;i++){
double v = Double.parseDouble(vectorArray[i]);
vector.add(v);
if (cosine){
if (cosine) {
queryVectorNorm += Math.pow(v, 2.0);
}
}
Expand All @@ -46,7 +48,7 @@ public float customScore(int docID, float subQueryScore, float valSrcScore) thro
LeafReader reader = context.reader();
Terms terms = reader.getTermVector(docID, field);
if(vector == null || vector.size() == 0){
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "vector was not indexed");
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "vector could not be parsed");
}
TermsEnum iter = terms.iterator();
BytesRef text;
Expand All @@ -60,7 +62,7 @@ public float customScore(int docID, float subQueryScore, float valSrcScore) thro
for(int freq = 0; freq < freqs; ++freq) {
int currPos = postings.nextPosition();
BytesRef payload = postings.getPayload();
payloadValue = PayloadHelper.decodeFloat(payload.bytes, payload.offset);
payloadValue = PayloadHelper.decodeFloat(payload.bytes, payload.offset + Integer.BYTES);

if (cosine)
docVectorNorm += Math.pow(payloadValue, 2.0);
Expand Down
2 changes: 1 addition & 1 deletion src/test-files/solr/collection1/conf/schema-vector.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
<fieldType name="VectorField" class="solr.TextField" indexed="true" termOffsets="true" stored="true" termPayloads="true" termPositions="true" termVectors="true" storeOffsetsWithPositions="true">
<analyzer>
<tokenizer class="solr.WhitespaceTokenizerFactory"/>
<filter class="solr.DelimitedPayloadTokenFilterFactory" encoder="float"/>
<filter class="solr.DelimitedPayloadTokenFilterFactory" encoder="com.github.saaay71.solr.VectorPayloadEncoder"/>
</analyzer>
</fieldType>
<dynamicField name="*" type="string" indexed="true" stored="true"/>
Expand Down
58 changes: 0 additions & 58 deletions src/test/com/github/saaay17/solr/VectorQueryTest.java

This file was deleted.

117 changes: 117 additions & 0 deletions src/test/com/github/saaay71/solr/VectorQueryTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package com.github.saaay71.solr;

import com.google.common.collect.Iterables;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.util.StrUtils;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;


/**
 * Tests the {@code vp} query parser against a {@code vector} field indexed from both
 * sparse ({@code |index,value}) and dense ({@code |value}) sample documents.
 *
 * <p>Each test indexes {@link #DOC_COUNT} documents by cycling through two sample vectors,
 * so half of the documents carry the first sample vector.
 */
public class VectorQueryTest extends SolrTestCaseJ4 {

    /** Number of documents indexed by every test. */
    private static final int DOC_COUNT = 10;

    /** Query vector that equals neither sample vector. */
    private static final String UNRELATED_QUERY_VECTOR = "0.1,4.75,0.3,1.2,0.7,4.0";

    /** Query vector equal to the first sample vector, written out literally. */
    private static final String FIRST_SAMPLE_QUERY_VECTOR = "1.55,3.53,2.3,0.7,3.44,2.33";

    private static final AtomicInteger idCounter = new AtomicInteger();

    private static final String[] vectors = {
        "|1,1.55 |2,3.53 |3,2.3 |4,0.7 |5,3.44 |6,2.33",
        "|1,3.54 |2,0.4 |3,4.16 |4,4.88 |5,4.28 |6,4.25"
    };

    private static final String[] denseVectors = {
        "|1.55 |3.53 |2.3 |0.7 |3.44 |2.33",
        "|3.54 |0.4 |4.16 |4.88 |4.28 |4.25"
    };

    @BeforeClass
    public static void beforeClass() throws Exception {
        initCore("solrconfig.xml", "schema-vector.xml");
    }

    /** Empties the index and restarts document ids so every test begins from the same state. */
    @Before
    public void before() throws Exception {
        deleteByQueryAndGetVersion("*:*", params());
        idCounter.set(0);
    }

    @Test
    public void denseDataTest() throws Exception {
        indexSampleDenseData();

        assertAllDocsFound();
        assertFirstSampleScoresOne(FIRST_SAMPLE_QUERY_VECTOR);
        // Same query, but derived from the indexed text: strip the payload markers, spaces -> commas.
        assertFirstSampleScoresOne(denseVectors[0].replaceAll("\\|", "").replaceAll(" ", ","));
    }

    @Test
    public void sparseDataTest() throws Exception {
        indexSampleData();

        assertAllDocsFound();
        assertFirstSampleScoresOne(FIRST_SAMPLE_QUERY_VECTOR);
        assertFirstSampleScoresOne(sparseToDenseVector(vectors[0]));
    }

    /** Asserts that a match-all query and an unrelated vector query both return every document. */
    private void assertAllDocsFound() {
        final String allFound = "//*[@numFound='" + DOC_COUNT + "']";

        assertQ(req("q", "*:*"), allFound);

        assertQ(req("q", vectorQuery(UNRELATED_QUERY_VECTOR),
                    "fl", "name,score,vector"),
                allFound);
    }

    /**
     * Asserts that querying with the first sample vector returns every document, that the top
     * document scores exactly 1.0, and that exactly half of the documents score 1.0.
     */
    private void assertFirstSampleScoresOne(String queryVector) {
        assertQ(req("q", vectorQuery(queryVector),
                    "fl", "name,score,vector"),
                "//*[@numFound='" + DOC_COUNT + "']",
                "//doc[1]/float[@name='score'][.='1.0']",
                "count(//float[@name='score'][.='1.0'])=" + (DOC_COUNT / 2));
    }

    /** Builds the local-params query string for the {@code vp} parser on the {@code vector} field. */
    private static String vectorQuery(String queryVector) {
        return "{!vp f=vector vector=\"" + queryVector + "\"}";
    }

    private void indexSampleData() throws Exception {
        indexSamples(vectors);
    }

    private void indexSampleDenseData() throws Exception {
        indexSamples(denseVectors);
    }

    /**
     * Indexes {@link #DOC_COUNT} documents, cycling through {@code samples}, then commits.
     * A fresh cycle is started on every call so that which sample each document receives
     * does not depend on how many documents earlier tests indexed.
     */
    private void indexSamples(String[] samples) throws Exception {
        final Iterator<String> sampleCycle = Iterables.cycle(samples).iterator();
        for (int i = 0; i < DOC_COUNT; i++) {
            assertU(adoc(sdoc("id", idCounter.incrementAndGet(), "vector", sampleCycle.next())));
        }
        assertU(commit());
    }

    /**
     * Converts sparse sample text such as {@code "|1,1.55 |2,3.53"} into the comma-separated
     * value list {@code "1.55,3.53"} by dropping the payload markers and the indices.
     */
    private String sparseToDenseVector(String sparseVec) {
        final String withoutMarkers = sparseVec.replaceAll("\\|", "");
        final List<String> values = StrUtils.splitSmart(withoutMarkers, ' ')
            .stream()
            .map(indexAndValue -> indexAndValue.split(",")[1])
            .collect(Collectors.toList());

        return StrUtils.join(values, ',');
    }
}

0 comments on commit 5f036bd

Please sign in to comment.