Skip to content

Commit d5d6e26

Browse files
authored
Merge pull request #118 from steleal/gh-117-add_radius_to_knn
Gh-117 add parameter radius to knn search Closes gh-117
2 parents ef9b9b7 + f8b46ae commit d5d6e26

11 files changed

Lines changed: 420 additions & 113 deletions

File tree

src/main/java/ru/rt/restream/reindexer/binding/Consts.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
*/
2121
public final class Consts {
2222

23-
public static final String REINDEXER_VERSION = "v5.0.0";
23+
public static final String REINDEXER_VERSION = "v5.4.0";
2424
public static final String DEF_APP_NAME = "java-connector";
2525
public static final String APP_PROPERTY_NAME = "app.name";
2626

@@ -84,7 +84,7 @@ public final class Consts {
8484
public static final int KNN_QUERY_TYPE_HNSW = 2;
8585
public static final int KNN_QUERY_TYPE_IVF = 3;
8686

87-
public static final int KNN_QUERY_PARAMS_VERSION = 0;
87+
public static final int KNN_QUERY_PARAMS_VERSION = 1;
8888

8989
public static final int RESULTS_FORMAT_MASK = 0xF;
9090
public static final int RESULTS_PURE = 0x0;

src/main/java/ru/rt/restream/reindexer/binding/cproto/ByteBuffer.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,22 @@ public ByteBuffer(float expandFactor, int initialCapacity) {
9797
buffer = new byte[initialCapacity];
9898
}
9999

100+
101+
/**
102+
* Encodes an integer value into an unsigned 8-bit integer.
103+
* Increments buffer position.
104+
*
105+
* @param value value to encode
106+
* @return the {@link ByteBuffer} for further customizations
107+
*/
108+
public ByteBuffer putUInt8(int value) {
109+
if (value < 0 || value > 0xFF) {
110+
throw new IllegalArgumentException();
111+
}
112+
putIntBits(value, Byte.BYTES, -1);
113+
return this;
114+
}
115+
100116
/**
101117
* Encodes an integer value into an unsigned 16-bit integer.
102118
* Increments buffer position.

src/main/java/ru/rt/restream/reindexer/vector/params/BaseKnnSearchParam.java

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,86 @@
2020
import lombok.Getter;
2121
import ru.rt.restream.reindexer.binding.cproto.ByteBuffer;
2222

23-
import java.util.Collections;
23+
import java.util.ArrayList;
2424
import java.util.List;
2525

2626
import static ru.rt.restream.reindexer.binding.Consts.KNN_QUERY_PARAMS_VERSION;
2727
import static ru.rt.restream.reindexer.binding.Consts.KNN_QUERY_TYPE_BASE;
2828

29+
/**
30+
* Common parameters for all types of KNN indices.
31+
*
32+
* <p>If both parameters are specified, the results are filtered so that both conditions are met.
33+
* At least one of these parameters must be specified.
34+
*/
2935
@Getter
3036
@AllArgsConstructor(access = AccessLevel.PACKAGE)
3137
public class BaseKnnSearchParam implements KnnSearchParam {
38+
private static final int KNN_SERIALIZE_WITH_K = 1;
39+
private static final int KNN_SERIALIZE_WITH_RADIUS = 1 << 1;
40+
3241
/**
3342
* The maximum number of documents returned from the index for subsequent filtering.
3443
*/
35-
private final int k;
44+
private final Integer k;
45+
/**
46+
* Parameter for filtering vectors by ranks.
47+
*
48+
* <p>{@code rank() < radius} for L2 metrics and {@code rank() > radius} for cosine and inner product metrics.
49+
* For default values and usage, see the
50+
* <a href="https://reindexer.io/reindexer-docs/select/vector_search/float_vector/#range-search">range search documentation</a>.
51+
*/
52+
private final Float radius;
3653

3754
/**
3855
* {@inheritDoc}
3956
*/
4057
@Override
4158
public void serializeBy(ByteBuffer buffer) {
4259
buffer.putVarUInt32(KNN_QUERY_TYPE_BASE)
43-
.putVarUInt32(KNN_QUERY_PARAMS_VERSION)
44-
.putVarUInt32(k);
60+
.putVarUInt32(KNN_QUERY_PARAMS_VERSION);
61+
serializeKAndRadius(buffer);
62+
}
63+
64+
void serializeKAndRadius(ByteBuffer buffer) {
65+
checkValues();
66+
int mask = 0;
67+
if (k != null) {
68+
mask |= KNN_SERIALIZE_WITH_K;
69+
}
70+
if (radius != null) {
71+
mask |= KNN_SERIALIZE_WITH_RADIUS;
72+
}
73+
buffer.putUInt8(mask);
74+
if (k != null) {
75+
buffer.putVarUInt32(k);
76+
}
77+
if (radius != null) {
78+
buffer.putFloat(radius);
79+
}
80+
}
81+
82+
private void checkValues() {
83+
if (k == null && radius == null) {
84+
throw new IllegalArgumentException("Both params (k and radius) cannot be null");
85+
}
86+
if (k != null && k <= 0) {
87+
throw new IllegalArgumentException("'k' must be greater than 0");
88+
}
4589
}
4690

4791
/**
4892
* {@inheritDoc}
4993
*/
5094
@Override
5195
public List<String> toLog() {
52-
return Collections.singletonList("k=" + k);
96+
List<String> values = new ArrayList<>(2);
97+
if (k != null) {
98+
values.add("k=" + k);
99+
}
100+
if (radius != null) {
101+
values.add("radius=" + radius);
102+
}
103+
return values;
53104
}
54105
}

src/main/java/ru/rt/restream/reindexer/vector/params/IndexBfSearchParam.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
import lombok.AccessLevel;
1919
import lombok.AllArgsConstructor;
2020
import lombok.Getter;
21+
import lombok.NonNull;
2122
import ru.rt.restream.reindexer.binding.cproto.ByteBuffer;
2223

23-
import java.util.Arrays;
24-
import java.util.Collections;
24+
import java.util.ArrayList;
2525
import java.util.List;
2626

2727
import static ru.rt.restream.reindexer.binding.Consts.KNN_QUERY_PARAMS_VERSION;
@@ -31,25 +31,26 @@
3131
@AllArgsConstructor(access = AccessLevel.PACKAGE)
3232
public class IndexBfSearchParam implements KnnSearchParam {
3333
/**
34-
* The maximum number of documents returned from the index for subsequent filtering.
34+
* Common parameters for KNN search.
3535
*/
36-
private final int k;
36+
@NonNull
37+
private final BaseKnnSearchParam base;
3738

3839
/**
3940
* {@inheritDoc}
4041
*/
4142
@Override
4243
public void serializeBy(ByteBuffer buffer) {
4344
buffer.putVarUInt32(KNN_QUERY_TYPE_BRUTE_FORCE)
44-
.putVarUInt32(KNN_QUERY_PARAMS_VERSION)
45-
.putVarUInt32(k);
45+
.putVarUInt32(KNN_QUERY_PARAMS_VERSION);
46+
base.serializeKAndRadius(buffer);
4647
}
4748

4849
/**
4950
* {@inheritDoc}
5051
*/
5152
@Override
5253
public List<String> toLog() {
53-
return Collections.singletonList("k=" + k);
54+
return new ArrayList<>(base.toLog());
5455
}
5556
}

src/main/java/ru/rt/restream/reindexer/vector/params/IndexHnswSearchParam.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
import lombok.AccessLevel;
1919
import lombok.AllArgsConstructor;
2020
import lombok.Getter;
21+
import lombok.NonNull;
2122
import ru.rt.restream.reindexer.binding.cproto.ByteBuffer;
2223

23-
import java.util.Arrays;
24-
import java.util.Collections;
24+
import java.util.ArrayList;
2525
import java.util.List;
2626

2727
import static ru.rt.restream.reindexer.binding.Consts.KNN_QUERY_PARAMS_VERSION;
@@ -31,9 +31,10 @@
3131
@AllArgsConstructor(access = AccessLevel.PACKAGE)
3232
public class IndexHnswSearchParam implements KnnSearchParam {
3333
/**
34-
* The maximum number of documents returned from the index for subsequent filtering.
34+
* Common parameters for KNN search.
3535
*/
36-
private final int k;
36+
@NonNull
37+
private final BaseKnnSearchParam base;
3738

3839
/**
3940
* The size of the dynamic list for the nearest neighbors.
@@ -50,16 +51,19 @@ public class IndexHnswSearchParam implements KnnSearchParam {
5051
@Override
5152
public void serializeBy(ByteBuffer buffer) {
5253
buffer.putVarUInt32(KNN_QUERY_TYPE_HNSW)
53-
.putVarUInt32(KNN_QUERY_PARAMS_VERSION)
54-
.putVarUInt32(k)
55-
.putVarInt32(ef);
54+
.putVarUInt32(KNN_QUERY_PARAMS_VERSION);
55+
base.serializeKAndRadius(buffer);
56+
buffer.putVarInt32(ef);
5657
}
5758

5859
/**
5960
* {@inheritDoc}
6061
*/
6162
@Override
6263
public List<String> toLog() {
63-
return Arrays.asList("k=" + k, "ef=" + ef);
64+
List<String> values = new ArrayList<>(3);
65+
values.addAll(base.toLog());
66+
values.add("ef=" + ef);
67+
return values;
6468
}
6569
}

src/main/java/ru/rt/restream/reindexer/vector/params/IndexIvfSearchParam.java

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
import lombok.AccessLevel;
1919
import lombok.AllArgsConstructor;
2020
import lombok.Getter;
21+
import lombok.NonNull;
2122
import ru.rt.restream.reindexer.binding.cproto.ByteBuffer;
2223

23-
import java.util.Arrays;
24+
import java.util.ArrayList;
2425
import java.util.List;
2526

2627
import static ru.rt.restream.reindexer.binding.Consts.KNN_QUERY_PARAMS_VERSION;
@@ -30,9 +31,10 @@
3031
@AllArgsConstructor(access = AccessLevel.PACKAGE)
3132
public class IndexIvfSearchParam implements KnnSearchParam {
3233
/**
33-
* The maximum number of documents returned from the index for subsequent filtering.
34+
* Common parameters for KNN search.
3435
*/
35-
private final int k;
36+
@NonNull
37+
private final BaseKnnSearchParam base;
3638

3739
/**
3840
* The number of clusters to be looked at during the search.
@@ -49,16 +51,19 @@ public class IndexIvfSearchParam implements KnnSearchParam {
4951
@Override
5052
public void serializeBy(ByteBuffer buffer) {
5153
buffer.putVarUInt32(KNN_QUERY_TYPE_IVF)
52-
.putVarUInt32(KNN_QUERY_PARAMS_VERSION)
53-
.putVarUInt32(k)
54-
.putVarUInt32(nProbe);
54+
.putVarUInt32(KNN_QUERY_PARAMS_VERSION);
55+
base.serializeKAndRadius(buffer);
56+
buffer.putVarUInt32(nProbe);
5557
}
5658

5759
/**
5860
* {@inheritDoc}
5961
*/
6062
@Override
6163
public List<String> toLog() {
62-
return Arrays.asList("k=" + k, "nprobe=" + nProbe);
64+
List<String> values = new ArrayList<>(3);
65+
values.addAll(base.toLog());
66+
values.add("nprobe=" + nProbe);
67+
return values;
6368
}
6469
}

src/main/java/ru/rt/restream/reindexer/vector/params/KnnParams.java

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,66 @@
1515
*/
1616
package ru.rt.restream.reindexer.vector.params;
1717

18+
import lombok.NonNull;
19+
1820
/**
1921
* Factories for KnnSearchParams.
2022
*/
2123
public class KnnParams {
24+
@Deprecated
2225
public static BaseKnnSearchParam base(int k) {
2326
checkK(k);
24-
return new BaseKnnSearchParam(k);
27+
return new BaseKnnSearchParam(k, null);
2528
}
2629

27-
public static IndexHnswSearchParam hnsw(int k, int ef) {
30+
public static BaseKnnSearchParam base(int k, float radius) {
31+
checkK(k);
32+
return new BaseKnnSearchParam(k, radius);
33+
}
34+
35+
public static BaseKnnSearchParam k(int k) {
2836
checkK(k);
37+
return new BaseKnnSearchParam(k, null);
38+
}
39+
40+
public static BaseKnnSearchParam radius(float radius) {
41+
return new BaseKnnSearchParam(null, radius);
42+
}
43+
44+
public static IndexHnswSearchParam hnsw(int k, int ef) {
2945
if (ef < k) {
3046
throw new IllegalArgumentException("Minimal value of 'ef' must be greater than or equal to 'k'");
3147
}
32-
return new IndexHnswSearchParam(k, ef);
48+
return new IndexHnswSearchParam(k(k), ef);
49+
}
50+
51+
public static IndexHnswSearchParam hnsw(@NonNull BaseKnnSearchParam base, int ef) {
52+
if (base.getK() != null && ef < base.getK()) {
53+
throw new IllegalArgumentException("Minimal value of 'ef' must be greater than or equal to 'k'");
54+
}
55+
return new IndexHnswSearchParam(base, ef);
3356
}
3457

3558
public static IndexBfSearchParam bf(int k) {
36-
checkK(k);
37-
return new IndexBfSearchParam(k);
59+
return new IndexBfSearchParam(k(k));
60+
}
61+
62+
public static IndexBfSearchParam bf(@NonNull BaseKnnSearchParam base) {
63+
return new IndexBfSearchParam(base);
3864
}
3965

4066
public static IndexIvfSearchParam ivf(int k, int nProbe) {
41-
checkK(k);
4267
if (nProbe <= 0) {
4368
throw new IllegalArgumentException("Minimal value of 'nProbe' must be greater than 0");
4469
}
45-
return new IndexIvfSearchParam(k, nProbe);
70+
return new IndexIvfSearchParam(k(k), nProbe);
71+
}
72+
73+
public static IndexIvfSearchParam ivf(@NonNull BaseKnnSearchParam base, int nProbe) {
74+
if (nProbe <= 0) {
75+
throw new IllegalArgumentException("Minimal value of 'nProbe' must be greater than 0");
76+
}
77+
return new IndexIvfSearchParam(base, nProbe);
4678
}
4779

4880
private static void checkK(int k) {

src/main/java/ru/rt/restream/reindexer/vector/params/KnnSearchParam.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,6 @@
2323
* Common interface for KNN search parameters.
2424
*/
2525
public interface KnnSearchParam {
26-
/**
27-
* K - the maximum number of documents returned from the index for subsequent filtering.
28-
*
29-
* <p>Only required parameter for all vector index types.
30-
*/
31-
int getK();
32-
3326
/**
3427
* Utility method for serializing KNN parameters to CJSON avoiding switch.
3528
*/

0 commit comments

Comments
 (0)