Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/NRedisStack/PublicAPI/PublicAPI.Shipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1435,7 +1435,7 @@ static NRedisStack.Search.VectorData.Lease<T>(int dimension) -> NRedisStack.Sear
static NRedisStack.Search.VectorData.LeaseWithValues<T>(params System.ReadOnlySpan<T> values) -> NRedisStack.Search.VectorData<T>!
static NRedisStack.Search.VectorData.Parameter(string! name) -> NRedisStack.Search.VectorData!
static NRedisStack.Search.VectorData.Raw(System.ReadOnlyMemory<byte> bytes) -> NRedisStack.Search.VectorData!
static NRedisStack.Search.VectorSearchMethod.NearestNeighbour(int count = 10, int? maxCandidates = null) -> NRedisStack.Search.VectorSearchMethod!
static NRedisStack.Search.VectorSearchMethod.NearestNeighbour(int count, int? maxCandidates) -> NRedisStack.Search.VectorSearchMethod!
static NRedisStack.Search.VectorSearchMethod.Range(double radius, double? epsilon = null) -> NRedisStack.Search.VectorSearchMethod!
static NRedisStack.SearchCommandBuilder.Aggregate(string! index, NRedisStack.Search.AggregationRequest! query) -> NRedisStack.RedisStackCommands.SerializedCommand!
static NRedisStack.SearchCommandBuilder.AliasAdd(string! alias, string! index) -> NRedisStack.RedisStackCommands.SerializedCommand!
Expand Down
1 change: 1 addition & 0 deletions src/NRedisStack/PublicAPI/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ NRedisStack.TsAggregations.IsEmpty.get -> bool
NRedisStack.TsAggregations.Length.get -> int
NRedisStack.TsAggregations.this[int index].get -> NRedisStack.Literals.Enums.TsAggregation
NRedisStack.TsAggregations.TsAggregations(NRedisStack.Literals.Enums.TsAggregation aggregation) -> void
static NRedisStack.Search.VectorSearchMethod.NearestNeighbour(int? count = 10, int? maxTopCandidates = null, string? distanceAlias = null, double? shardRatio = null) -> NRedisStack.Search.VectorSearchMethod!
[NRS002]NRedisStack.TsAggregations.TsAggregations(params NRedisStack.Literals.Enums.TsAggregation[]! aggregations) -> void
static NRedisStack.TsAggregations.implicit operator NRedisStack.TsAggregations(NRedisStack.Literals.Enums.TsAggregation aggregation) -> NRedisStack.TsAggregations
static NRedisStack.TsAggregations.implicit operator NRedisStack.TsAggregations(NRedisStack.Literals.Enums.TsAggregation? aggregation) -> NRedisStack.TsAggregations
Expand Down
44 changes: 29 additions & 15 deletions src/NRedisStack/Search/VectorSearchMethod.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,30 @@ internal static VectorSearchMethod Range(double radius, double? epsilon, string?
=> RangeVectorSearchMethod.Create(radius, epsilon, distanceAlias);

public static VectorSearchMethod NearestNeighbour(
int count = NearestNeighbourVectorSearchMethod.DEFAULT_NEAREST_NEIGHBOUR_COUNT, int? maxCandidates = null)
=> NearestNeighbourVectorSearchMethod.Create(count, maxCandidates, null);
int count, int? maxCandidates) // retained for binary compat
=> NearestNeighbourVectorSearchMethod.Create(count, maxCandidates, null, null);

internal static VectorSearchMethod NearestNeighbour(
int? count, int? maxTopCandidates, string? distanceAlias = null)
=> NearestNeighbourVectorSearchMethod.Create(count ?? NearestNeighbourVectorSearchMethod.DEFAULT_NEAREST_NEIGHBOUR_COUNT, maxTopCandidates, distanceAlias);
public static VectorSearchMethod NearestNeighbour(
int? count = NearestNeighbourVectorSearchMethod.DEFAULT_NEAREST_NEIGHBOUR_COUNT, int? maxTopCandidates = null, string? distanceAlias = null, double? shardRatio = null)
=> NearestNeighbourVectorSearchMethod.Create(count ?? NearestNeighbourVectorSearchMethod.DEFAULT_NEAREST_NEIGHBOUR_COUNT, maxTopCandidates, distanceAlias, shardRatio);

private sealed class NearestNeighbourVectorSearchMethod : VectorSearchMethod
{
private static NearestNeighbourVectorSearchMethod? s_Default;

internal static NearestNeighbourVectorSearchMethod Create(int count, int? maxTopCandidates,
string? distanceAlias)
=> count == DEFAULT_NEAREST_NEIGHBOUR_COUNT & maxTopCandidates == null & distanceAlias == null
? (s_Default ??= new NearestNeighbourVectorSearchMethod(DEFAULT_NEAREST_NEIGHBOUR_COUNT, null, null))
: new(count, maxTopCandidates, distanceAlias);
string? distanceAlias, double? shardRatio)
=> count == DEFAULT_NEAREST_NEIGHBOUR_COUNT & maxTopCandidates == null & distanceAlias == null & !shardRatio.HasValue
? (s_Default ??= new NearestNeighbourVectorSearchMethod(DEFAULT_NEAREST_NEIGHBOUR_COUNT, null, null, null))
: new(count, maxTopCandidates, distanceAlias, shardRatio);

private NearestNeighbourVectorSearchMethod(int nearestNeighbourCount, int? maxTopCandidates,
string? distanceAlias)
string? distanceAlias, double? shardRatio)
{
NearestNeighbourCount = nearestNeighbourCount;
MaxTopCandidates = maxTopCandidates;
DistanceAlias = distanceAlias;
ShardRatio = shardRatio;
}

internal const int DEFAULT_NEAREST_NEIGHBOUR_COUNT = 10;
Expand All @@ -65,34 +66,47 @@ private NearestNeighbourVectorSearchMethod(int nearestNeighbourCount, int? maxTo
/// </summary>
public string? DistanceAlias { get; }

/// <summary>
/// Limits the number of documents processed per shard. Only relevant for cluster scenarios. This corresponds
/// to the "SHARD_K_RATIO" parameter.
/// </summary>
public double? ShardRatio { get; }

internal override int GetOwnArgsCount()
{
int count = 4;
if (MaxTopCandidates != null) count += 2;
if (MaxTopCandidates.HasValue) count += 2;
if (DistanceAlias != null) count += 2;
if (ShardRatio.HasValue) count += 2;
return count;
}

internal override void AddOwnArgs(List<object> args)
{
args.Add(Method);
int tokens = 2;
if (MaxTopCandidates != null) tokens += 2;
if (MaxTopCandidates.HasValue) tokens += 2;
if (DistanceAlias != null) tokens += 2;
if (ShardRatio.HasValue) tokens += 2;
args.Add(tokens);
args.Add("K");
args.Add(NearestNeighbourCount);
if (MaxTopCandidates != null)
if (MaxTopCandidates.HasValue)
{
args.Add("EF_RUNTIME");
args.Add(MaxTopCandidates);
args.Add(MaxTopCandidates.GetValueOrDefault());
}

if (DistanceAlias != null)
{
args.Add("YIELD_DISTANCE_AS");
args.Add(DistanceAlias);
}

if (ShardRatio.HasValue)
{
args.Add("SHARD_K_RATIO");
args.Add(ShardRatio.GetValueOrDefault());
}
}
}

Expand Down
13 changes: 11 additions & 2 deletions tests/NRedisStack.Tests/AbstractNRedisStackTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,20 @@ protected internal AbstractNRedisStackTest(EndpointsFixture endpointsFixture, IT
this.log = log;
}

protected void AssertVersion(IDatabase db, [CallerMemberName] string testName = "")
protected void AssertVersion(IDatabase db, [CallerMemberName] string testName = "", string? overrideVersion = null)
{
Version? version = null;
if (overrideVersion is not null)
{
// ignore the attribs; apply "at least the override"
version ??= db.Multiplexer.GetServer((RedisKey)"any key").Version;
Log($"Validating with detected server version: {version}", demand: false);
var skip = new SkipIfRedisCore(Comparison.LessThan, overrideVersion).GetSkip(version);
Assert.SkipWhen(skip is not null, skip ?? "");
return;
}
// this is used to reapply "Skip" logic after auto-discovery of the server version is possible
var attributes = GetType().GetMethod(testName)?.GetCustomAttributes(true) ?? [];
Version? version = null;
foreach (var attribute in attributes)
{
SkipIfRedisCore? core = attribute switch
Expand Down
32 changes: 27 additions & 5 deletions tests/NRedisStack.Tests/Search/HybridSearchIntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ private readonly struct Api(SearchCommands ft, string index, IDatabase db)
private const int V1DIM = 5;

private async Task<Api> CreateIndexAsync(string endpointId, [CallerMemberName] string caller = "",
bool populate = true)
bool populate = true, string? overrideVersion = null)
{
var index = $"ix_{caller}";
var db = GetCleanDatabase(endpointId);
AssertVersion(db, caller, overrideVersion);
// ReSharper disable once RedundantArgumentDefaultValue
var ft = db.FT(2);

Expand Down Expand Up @@ -80,7 +81,7 @@ private async Task<Api> CreateIndexAsync(string endpointId, [CallerMemberName] s
return new(ft, index, db);
}

[SkipIfRedisTheory(Comparison.LessThan, "8.3.224")]
[SkipIfRedisTheory(Comparison.LessThan, "8.3.224", deferVersionCheck: true)]
[MemberData(nameof(EndpointsFixture.Env.AllEnvironments), MemberType = typeof(EndpointsFixture.Env))]
public async Task TestSetup(string endpointId)
{
Expand All @@ -98,7 +99,7 @@ public async Task TestSetup(string endpointId)
Assert.Empty(result.Results);
}

[SkipIfRedisTheory(Comparison.LessThan, "8.3.224")]
[SkipIfRedisTheory(Comparison.LessThan, "8.3.224", deferVersionCheck: true)]
[MemberData(nameof(EndpointsFixture.Env.AllEnvironments), MemberType = typeof(EndpointsFixture.Env))]
public async Task TestSearch(string endpointId)
{
Expand Down Expand Up @@ -162,6 +163,8 @@ public enum Scenario
VectorWithNumericFilter,
VectorWithNearest,
VectorWithNearestCount,
VectorWithNearestWithRatio,
VectorWithNearestCountWithRatio,
PreFilterByTag,
PreFilterByNumeric,
ParamSearch,
Expand All @@ -170,6 +173,7 @@ public enum Scenario
ParamMultiPreFilter,
VectorWithRangeAndEpsilon,
VectorWithNearestMaxCandidates,
VectorWithNearestMaxCandidatesWithRatio,

[NotYetImplemented] ExplainScore,
[NotYetImplemented] LinearWithScore,
Expand All @@ -178,6 +182,7 @@ public enum Scenario
[NotYetImplemented] SearchWithComplexScorer,
[NotYetImplemented] VectorWithRangeAndDistanceAlias,
[NotYetImplemented] VectorWithNearestDistAlias,
[NotYetImplemented] VectorWithNearestDistAliasWithRatio,
[NotYetImplemented] ParamPostFilter,
[NotYetImplemented] ParamMultiPostFilter,
}
Expand Down Expand Up @@ -219,7 +224,7 @@ internal static IEnumerable<object[]> CrossJoin<T>(Func<IEnumerable<object[]>> e
public static IEnumerable<object[]> AllEnvironments_Scenarios() =>
CrossJoin<Scenario>(EndpointsFixture.Env.AllEnvironments);

[SkipIfRedisTheory(Comparison.LessThan, "8.3.224")]
[SkipIfRedisTheory(Comparison.LessThan, "8.3.224", deferVersionCheck: true)]
[MemberData(nameof(AllEnvironments_Scenarios))]
public async Task TestSearchScenarios(string endpointId, Scenario scenario)
{
Expand All @@ -228,7 +233,16 @@ public async Task TestSearchScenarios(string endpointId, Scenario scenario)
// throw SkipException.ForSkip("Not expected to work right now");
return;
}
var api = await CreateIndexAsync(endpointId, populate: true);

string? overrideVersion = scenario switch
{
Scenario.VectorWithNearestMaxCandidatesWithRatio
or Scenario.VectorWithNearestCountWithRatio
or Scenario.VectorWithNearestWithRatio
or Scenario.VectorWithNearestDistAliasWithRatio => "8.6.1",
_ => null,
};
var api = await CreateIndexAsync(endpointId, populate: true, overrideVersion: overrideVersion);
if (api.IsNull) return;

var hash = (await api.DB.HashGetAllAsync($"{api.Index}_entry2")).ToDictionary(k => k.Name, v => v.Value);
Expand Down Expand Up @@ -261,12 +275,20 @@ public async Task TestSearchScenarios(string endpointId, Scenario scenario)
method: VectorSearchMethod.Range(42, epsilon: 0.1))),
Scenario.VectorWithNearest => query.VectorSearch(new("@vector1", VectorData.Raw(vec),
method: VectorSearchMethod.NearestNeighbour())),
Scenario.VectorWithNearestWithRatio => query.VectorSearch(new("@vector1", VectorData.Raw(vec),
method: VectorSearchMethod.NearestNeighbour(shardRatio: 0.5))),
Scenario.VectorWithNearestCount => query.VectorSearch(new("@vector1", VectorData.Raw(vec),
method: VectorSearchMethod.NearestNeighbour(20))),
Scenario.VectorWithNearestCountWithRatio => query.VectorSearch(new("@vector1", VectorData.Raw(vec),
method: VectorSearchMethod.NearestNeighbour(20, shardRatio: 0.5))),
Scenario.VectorWithNearestDistAlias => query.VectorSearch(new("@vector1", VectorData.Raw(vec),
method: VectorSearchMethod.NearestNeighbour(null, null, distanceAlias: "dist_alias"))),
Scenario.VectorWithNearestDistAliasWithRatio => query.VectorSearch(new("@vector1", VectorData.Raw(vec),
method: VectorSearchMethod.NearestNeighbour(null, null, distanceAlias: "dist_alias", shardRatio: 0.5))),
Scenario.VectorWithNearestMaxCandidates => query.VectorSearch(new("@vector1", VectorData.Raw(vec),
method: VectorSearchMethod.NearestNeighbour(null, maxTopCandidates: 10))),
Scenario.VectorWithNearestMaxCandidatesWithRatio => query.VectorSearch(new("@vector1", VectorData.Raw(vec),
method: VectorSearchMethod.NearestNeighbour(null, maxTopCandidates: 10, shardRatio: 0.5))),
Scenario.VectorWithTagFilter => query.VectorSearch(new("@vector1", VectorData.Raw(vec),
filter: "@tag1:{foo}")),
Scenario.VectorWithNumericFilter => query.VectorSearch(new("@vector1", VectorData.Raw(vec),
Expand Down
35 changes: 26 additions & 9 deletions tests/NRedisStack.Tests/Search/HybridSearchUnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,33 +174,50 @@ public void BasicNonZeroLengthVectorSearch()
}

[Theory]
[InlineData(false, false)]
[InlineData(false, true)]
[InlineData(true, false)]
[InlineData(true, true)]
public void BasicVectorSearch_WithKNN(bool withScoreAlias, bool withDistanceAlias)
[InlineData(false, false, false)]
[InlineData(false, true, false)]
[InlineData(true, false, false)]
[InlineData(true, true, false)]
[InlineData(false, false, true)]
[InlineData(false, true, true)]
[InlineData(true, false, true)]
[InlineData(true, true, true)]
public void BasicVectorSearch_WithKNN(bool withScoreAlias, bool withDistanceAlias, bool withShardRatio)
{
HybridSearchQuery query = new();
var searchConfig = new HybridSearchQuery.VectorSearchConfig("vField", SomeRandomDataHere);
if (withScoreAlias) searchConfig = searchConfig.WithScoreAlias("my_score_alias");
searchConfig = searchConfig.WithMethod(VectorSearchMethod.NearestNeighbour(null, null,
distanceAlias: withDistanceAlias ? "my_distance_alias" : null));
var method = withShardRatio
? VectorSearchMethod.NearestNeighbour(null, null, distanceAlias: withDistanceAlias ? "my_distance_alias" : null, shardRatio: 0.5)
: VectorSearchMethod.NearestNeighbour(null, null, distanceAlias: withDistanceAlias ? "my_distance_alias" : null);
searchConfig = searchConfig.WithMethod(method);
query.VectorSearch(searchConfig);

var knnCount = 2;
if (withDistanceAlias) knnCount += 2;
if (withShardRatio) knnCount += 2;
object[] expected =
[Index, "VSIM", "vField", "$v", "KNN", withDistanceAlias ? 4 : 2, "K", 10];
[Index, "VSIM", "vField", "$v", "KNN", knnCount, "K", 10];
if (withDistanceAlias)
{
expected = [.. expected, "YIELD_DISTANCE_AS", "my_distance_alias"];
}

if (withShardRatio)
{
expected = [.. expected, "SHARD_K_RATIO", 0.5];
}

if (withScoreAlias)
{
expected = [.. expected, "YIELD_SCORE_AS", "my_score_alias"];
}

expected = [.. expected, "PARAMS", 2, "v", SomeRandomVectorValue];
Assert.Equivalent(expected, GetArgs(query));
var args = GetArgs(query);
log?.WriteLine($"expected: {string.Join(" ", expected)}");
log?.WriteLine($"actual: {string.Join(" ", args)}");
Assert.Equivalent(expected, args);
}

[Theory]
Expand Down
Loading
Loading