diff --git a/src/NRedisStack/PublicAPI/PublicAPI.Shipped.txt b/src/NRedisStack/PublicAPI/PublicAPI.Shipped.txt index 6c5161bf..278d4a2e 100644 --- a/src/NRedisStack/PublicAPI/PublicAPI.Shipped.txt +++ b/src/NRedisStack/PublicAPI/PublicAPI.Shipped.txt @@ -1435,7 +1435,7 @@ static NRedisStack.Search.VectorData.Lease(int dimension) -> NRedisStack.Sear static NRedisStack.Search.VectorData.LeaseWithValues(params System.ReadOnlySpan values) -> NRedisStack.Search.VectorData! static NRedisStack.Search.VectorData.Parameter(string! name) -> NRedisStack.Search.VectorData! static NRedisStack.Search.VectorData.Raw(System.ReadOnlyMemory bytes) -> NRedisStack.Search.VectorData! -static NRedisStack.Search.VectorSearchMethod.NearestNeighbour(int count = 10, int? maxCandidates = null) -> NRedisStack.Search.VectorSearchMethod! +static NRedisStack.Search.VectorSearchMethod.NearestNeighbour(int count, int? maxCandidates) -> NRedisStack.Search.VectorSearchMethod! static NRedisStack.Search.VectorSearchMethod.Range(double radius, double? epsilon = null) -> NRedisStack.Search.VectorSearchMethod! static NRedisStack.SearchCommandBuilder.Aggregate(string! index, NRedisStack.Search.AggregationRequest! query) -> NRedisStack.RedisStackCommands.SerializedCommand! static NRedisStack.SearchCommandBuilder.AliasAdd(string! alias, string! index) -> NRedisStack.RedisStackCommands.SerializedCommand! diff --git a/src/NRedisStack/PublicAPI/PublicAPI.Unshipped.txt b/src/NRedisStack/PublicAPI/PublicAPI.Unshipped.txt index de27282e..37b6fe25 100644 --- a/src/NRedisStack/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/NRedisStack/PublicAPI/PublicAPI.Unshipped.txt @@ -9,6 +9,7 @@ NRedisStack.TsAggregations.IsEmpty.get -> bool NRedisStack.TsAggregations.Length.get -> int NRedisStack.TsAggregations.this[int index].get -> NRedisStack.Literals.Enums.TsAggregation NRedisStack.TsAggregations.TsAggregations(NRedisStack.Literals.Enums.TsAggregation aggregation) -> void +static NRedisStack.Search.VectorSearchMethod.NearestNeighbour(int? count = 10, int? maxTopCandidates = null, string? distanceAlias = null, double? shardRatio = null) -> NRedisStack.Search.VectorSearchMethod! [NRS002]NRedisStack.TsAggregations.TsAggregations(params NRedisStack.Literals.Enums.TsAggregation[]! aggregations) -> void static NRedisStack.TsAggregations.implicit operator NRedisStack.TsAggregations(NRedisStack.Literals.Enums.TsAggregation aggregation) -> NRedisStack.TsAggregations static NRedisStack.TsAggregations.implicit operator NRedisStack.TsAggregations(NRedisStack.Literals.Enums.TsAggregation? aggregation) -> NRedisStack.TsAggregations diff --git a/src/NRedisStack/Search/VectorSearchMethod.cs b/src/NRedisStack/Search/VectorSearchMethod.cs index b5565cc8..5e84d71b 100644 --- a/src/NRedisStack/Search/VectorSearchMethod.cs +++ b/src/NRedisStack/Search/VectorSearchMethod.cs @@ -21,29 +21,30 @@ internal static VectorSearchMethod Range(double radius, double? epsilon, string? => RangeVectorSearchMethod.Create(radius, epsilon, distanceAlias); public static VectorSearchMethod NearestNeighbour( - int count = NearestNeighbourVectorSearchMethod.DEFAULT_NEAREST_NEIGHBOUR_COUNT, int? maxCandidates = null) - => NearestNeighbourVectorSearchMethod.Create(count, maxCandidates, null); + int count, int? maxCandidates) // retained for binary compat + => NearestNeighbourVectorSearchMethod.Create(count, maxCandidates, null, null); - internal static VectorSearchMethod NearestNeighbour( - int? count, int? maxTopCandidates, string? distanceAlias = null) - => NearestNeighbourVectorSearchMethod.Create(count ?? NearestNeighbourVectorSearchMethod.DEFAULT_NEAREST_NEIGHBOUR_COUNT, maxTopCandidates, distanceAlias); + public static VectorSearchMethod NearestNeighbour( + int? count = NearestNeighbourVectorSearchMethod.DEFAULT_NEAREST_NEIGHBOUR_COUNT, int? maxTopCandidates = null, string? distanceAlias = null, double? shardRatio = null) + => NearestNeighbourVectorSearchMethod.Create(count ?? NearestNeighbourVectorSearchMethod.DEFAULT_NEAREST_NEIGHBOUR_COUNT, maxTopCandidates, distanceAlias, shardRatio); private sealed class NearestNeighbourVectorSearchMethod : VectorSearchMethod { private static NearestNeighbourVectorSearchMethod? s_Default; internal static NearestNeighbourVectorSearchMethod Create(int count, int? maxTopCandidates, - string? distanceAlias) - => count == DEFAULT_NEAREST_NEIGHBOUR_COUNT & maxTopCandidates == null & distanceAlias == null - ? (s_Default ??= new NearestNeighbourVectorSearchMethod(DEFAULT_NEAREST_NEIGHBOUR_COUNT, null, null)) - : new(count, maxTopCandidates, distanceAlias); + string? distanceAlias, double? shardRatio) + => count == DEFAULT_NEAREST_NEIGHBOUR_COUNT & maxTopCandidates == null & distanceAlias == null & !shardRatio.HasValue + ? (s_Default ??= new NearestNeighbourVectorSearchMethod(DEFAULT_NEAREST_NEIGHBOUR_COUNT, null, null, null)) + : new(count, maxTopCandidates, distanceAlias, shardRatio); private NearestNeighbourVectorSearchMethod(int nearestNeighbourCount, int? maxTopCandidates, - string? distanceAlias) + string? distanceAlias, double? shardRatio) { NearestNeighbourCount = nearestNeighbourCount; MaxTopCandidates = maxTopCandidates; DistanceAlias = distanceAlias; + ShardRatio = shardRatio; } internal const int DEFAULT_NEAREST_NEIGHBOUR_COUNT = 10; @@ -65,11 +66,18 @@ private NearestNeighbourVectorSearchMethod(int nearestNeighbourCount, int? maxTo /// public string? DistanceAlias { get; } + /// + /// Limits the number of documents processed per shard. Only relevant for cluster scenarios. This corresponds + /// to the "SHARD_K_RATIO" parameter. + /// + public double? ShardRatio { get; } + internal override int GetOwnArgsCount() { int count = 4; - if (MaxTopCandidates != null) count += 2; + if (MaxTopCandidates.HasValue) count += 2; if (DistanceAlias != null) count += 2; + if (ShardRatio.HasValue) count += 2; return count; } @@ -77,22 +85,28 @@ internal override void AddOwnArgs(List args) { args.Add(Method); int tokens = 2; - if (MaxTopCandidates != null) tokens += 2; + if (MaxTopCandidates.HasValue) tokens += 2; if (DistanceAlias != null) tokens += 2; + if (ShardRatio.HasValue) tokens += 2; args.Add(tokens); args.Add("K"); args.Add(NearestNeighbourCount); - if (MaxTopCandidates != null) + if (MaxTopCandidates.HasValue) { args.Add("EF_RUNTIME"); - args.Add(MaxTopCandidates); + args.Add(MaxTopCandidates.GetValueOrDefault()); } - if (DistanceAlias != null) { args.Add("YIELD_DISTANCE_AS"); args.Add(DistanceAlias); } + + if (ShardRatio.HasValue) + { + args.Add("SHARD_K_RATIO"); + args.Add(ShardRatio.GetValueOrDefault()); + } } } diff --git a/tests/NRedisStack.Tests/AbstractNRedisStackTest.cs b/tests/NRedisStack.Tests/AbstractNRedisStackTest.cs index cb850bd3..8e848153 100644 --- a/tests/NRedisStack.Tests/AbstractNRedisStackTest.cs +++ b/tests/NRedisStack.Tests/AbstractNRedisStackTest.cs @@ -36,11 +36,20 @@ protected internal AbstractNRedisStackTest(EndpointsFixture endpointsFixture, IT this.log = log; } - protected void AssertVersion(IDatabase db, [CallerMemberName] string testName = "") + protected void AssertVersion(IDatabase db, [CallerMemberName] string testName = "", string? overrideVersion = null) { + Version? version = null; + if (overrideVersion is not null) + { + // ignore the attribs; apply "at least the override" + version ??= db.Multiplexer.GetServer((RedisKey)"any key").Version; + Log($"Validating with detected server version: {version}", demand: false); + var skip = new SkipIfRedisCore(Comparison.LessThan, overrideVersion).GetSkip(version); + Assert.SkipWhen(skip is not null, skip ?? ""); + return; + } // this is used to reapply "Skip" logic after auto-discovery of the server version is possible var attributes = GetType().GetMethod(testName)?.GetCustomAttributes(true) ?? []; - Version? version = null; foreach (var attribute in attributes) { SkipIfRedisCore? core = attribute switch diff --git a/tests/NRedisStack.Tests/Search/HybridSearchIntegrationTests.cs b/tests/NRedisStack.Tests/Search/HybridSearchIntegrationTests.cs index f9bfe29b..b5581091 100644 --- a/tests/NRedisStack.Tests/Search/HybridSearchIntegrationTests.cs +++ b/tests/NRedisStack.Tests/Search/HybridSearchIntegrationTests.cs @@ -26,10 +26,11 @@ private readonly struct Api(SearchCommands ft, string index, IDatabase db) private const int V1DIM = 5; private async Task CreateIndexAsync(string endpointId, [CallerMemberName] string caller = "", - bool populate = true) + bool populate = true, string? overrideVersion = null) { var index = $"ix_{caller}"; var db = GetCleanDatabase(endpointId); + AssertVersion(db, caller, overrideVersion); // ReSharper disable once RedundantArgumentDefaultValue var ft = db.FT(2); @@ -80,7 +81,7 @@ private async Task CreateIndexAsync(string endpointId, [CallerMemberName] s return new(ft, index, db); } - [SkipIfRedisTheory(Comparison.LessThan, "8.3.224")] + [SkipIfRedisTheory(Comparison.LessThan, "8.3.224", deferVersionCheck: true)] [MemberData(nameof(EndpointsFixture.Env.AllEnvironments), MemberType = typeof(EndpointsFixture.Env))] public async Task TestSetup(string endpointId) { @@ -98,7 +99,7 @@ public async Task TestSetup(string endpointId) Assert.Empty(result.Results); } - [SkipIfRedisTheory(Comparison.LessThan, "8.3.224")] + [SkipIfRedisTheory(Comparison.LessThan, "8.3.224", deferVersionCheck: true)] [MemberData(nameof(EndpointsFixture.Env.AllEnvironments), MemberType = typeof(EndpointsFixture.Env))] public async Task TestSearch(string endpointId) { @@ -162,6 +163,8 @@ public enum Scenario VectorWithNumericFilter, VectorWithNearest, VectorWithNearestCount, + VectorWithNearestWithRatio, + VectorWithNearestCountWithRatio, PreFilterByTag, PreFilterByNumeric, ParamSearch, @@ -170,6 +173,7 @@ public enum Scenario ParamMultiPreFilter, VectorWithRangeAndEpsilon, VectorWithNearestMaxCandidates, + VectorWithNearestMaxCandidatesWithRatio, [NotYetImplemented] ExplainScore, [NotYetImplemented] LinearWithScore, @@ -178,6 +182,7 @@ public enum Scenario [NotYetImplemented] SearchWithComplexScorer, [NotYetImplemented] VectorWithRangeAndDistanceAlias, [NotYetImplemented] VectorWithNearestDistAlias, + [NotYetImplemented] VectorWithNearestDistAliasWithRatio, [NotYetImplemented] ParamPostFilter, [NotYetImplemented] ParamMultiPostFilter, } @@ -219,7 +224,7 @@ internal static IEnumerable CrossJoin(Func> e public static IEnumerable AllEnvironments_Scenarios() => CrossJoin(EndpointsFixture.Env.AllEnvironments); - [SkipIfRedisTheory(Comparison.LessThan, "8.3.224")] + [SkipIfRedisTheory(Comparison.LessThan, "8.3.224", deferVersionCheck: true)] [MemberData(nameof(AllEnvironments_Scenarios))] public async Task TestSearchScenarios(string endpointId, Scenario scenario) { @@ -228,7 +233,16 @@ public async Task TestSearchScenarios(string endpointId, Scenario scenario) // throw SkipException.ForSkip("Not expected to work right now"); return; } - var api = await CreateIndexAsync(endpointId, populate: true); + + string? overrideVersion = scenario switch + { + Scenario.VectorWithNearestMaxCandidatesWithRatio + or Scenario.VectorWithNearestCountWithRatio + or Scenario.VectorWithNearestWithRatio + or Scenario.VectorWithNearestDistAliasWithRatio => "8.6.1", + _ => null, + }; + var api = await CreateIndexAsync(endpointId, populate: true, overrideVersion: overrideVersion); if (api.IsNull) return; var hash = (await api.DB.HashGetAllAsync($"{api.Index}_entry2")).ToDictionary(k => k.Name, v => v.Value); @@ -261,12 +275,20 @@ public async Task TestSearchScenarios(string endpointId, Scenario scenario) method: VectorSearchMethod.Range(42, epsilon: 0.1))), Scenario.VectorWithNearest => query.VectorSearch(new("@vector1", VectorData.Raw(vec), method: VectorSearchMethod.NearestNeighbour())), + Scenario.VectorWithNearestWithRatio => query.VectorSearch(new("@vector1", VectorData.Raw(vec), + method: VectorSearchMethod.NearestNeighbour(shardRatio: 0.5))), Scenario.VectorWithNearestCount => query.VectorSearch(new("@vector1", VectorData.Raw(vec), method: VectorSearchMethod.NearestNeighbour(20))), + Scenario.VectorWithNearestCountWithRatio => query.VectorSearch(new("@vector1", VectorData.Raw(vec), + method: VectorSearchMethod.NearestNeighbour(20, shardRatio: 0.5))), Scenario.VectorWithNearestDistAlias => query.VectorSearch(new("@vector1", VectorData.Raw(vec), method: VectorSearchMethod.NearestNeighbour(null, null, distanceAlias: "dist_alias"))), + Scenario.VectorWithNearestDistAliasWithRatio => query.VectorSearch(new("@vector1", VectorData.Raw(vec), + method: VectorSearchMethod.NearestNeighbour(null, null, distanceAlias: "dist_alias", shardRatio: 0.5))), Scenario.VectorWithNearestMaxCandidates => query.VectorSearch(new("@vector1", VectorData.Raw(vec), method: VectorSearchMethod.NearestNeighbour(null, maxTopCandidates: 10))), + Scenario.VectorWithNearestMaxCandidatesWithRatio => query.VectorSearch(new("@vector1", VectorData.Raw(vec), + method: VectorSearchMethod.NearestNeighbour(null, maxTopCandidates: 10, shardRatio: 0.5))), Scenario.VectorWithTagFilter => query.VectorSearch(new("@vector1", VectorData.Raw(vec), filter: "@tag1:{foo}")), Scenario.VectorWithNumericFilter => query.VectorSearch(new("@vector1", VectorData.Raw(vec), diff --git a/tests/NRedisStack.Tests/Search/HybridSearchUnitTests.cs b/tests/NRedisStack.Tests/Search/HybridSearchUnitTests.cs index f131cfd3..4d8a2f9d 100644 --- a/tests/NRedisStack.Tests/Search/HybridSearchUnitTests.cs +++ b/tests/NRedisStack.Tests/Search/HybridSearchUnitTests.cs @@ -174,33 +174,50 @@ public void BasicNonZeroLengthVectorSearch() } [Theory] - [InlineData(false, false)] - [InlineData(false, true)] - [InlineData(true, false)] - [InlineData(true, true)] - public void BasicVectorSearch_WithKNN(bool withScoreAlias, bool withDistanceAlias) + [InlineData(false, false, false)] + [InlineData(false, true, false)] + [InlineData(true, false, false)] + [InlineData(true, true, false)] + [InlineData(false, false, true)] + [InlineData(false, true, true)] + [InlineData(true, false, true)] + [InlineData(true, true, true)] + public void BasicVectorSearch_WithKNN(bool withScoreAlias, bool withDistanceAlias, bool withShardRatio) { HybridSearchQuery query = new(); var searchConfig = new HybridSearchQuery.VectorSearchConfig("vField", SomeRandomDataHere); if (withScoreAlias) searchConfig = searchConfig.WithScoreAlias("my_score_alias"); - searchConfig = searchConfig.WithMethod(VectorSearchMethod.NearestNeighbour(null, null, - distanceAlias: withDistanceAlias ? "my_distance_alias" : null)); + var method = withShardRatio + ? VectorSearchMethod.NearestNeighbour(null, null, distanceAlias: withDistanceAlias ? "my_distance_alias" : null, shardRatio: 0.5) + : VectorSearchMethod.NearestNeighbour(null, null, distanceAlias: withDistanceAlias ? "my_distance_alias" : null); + searchConfig = searchConfig.WithMethod(method); query.VectorSearch(searchConfig); + var knnCount = 2; + if (withDistanceAlias) knnCount += 2; + if (withShardRatio) knnCount += 2; object[] expected = - [Index, "VSIM", "vField", "$v", "KNN", withDistanceAlias ? 4 : 2, "K", 10]; + [Index, "VSIM", "vField", "$v", "KNN", knnCount, "K", 10]; if (withDistanceAlias) { expected = [.. expected, "YIELD_DISTANCE_AS", "my_distance_alias"]; } + if (withShardRatio) + { + expected = [.. expected, "SHARD_K_RATIO", 0.5]; + } + if (withScoreAlias) { expected = [.. expected, "YIELD_SCORE_AS", "my_score_alias"]; } expected = [.. expected, "PARAMS", 2, "v", SomeRandomVectorValue]; - Assert.Equivalent(expected, GetArgs(query)); + var args = GetArgs(query); + log?.WriteLine($"expected: {string.Join(" ", expected)}"); + log?.WriteLine($"actual: {string.Join(" ", args)}"); + Assert.Equivalent(expected, args); } [Theory] diff --git a/tests/NRedisStack.Tests/SkipIfRedisTheoryAttribute.cs b/tests/NRedisStack.Tests/SkipIfRedisTheoryAttribute.cs index 1f34c793..8906df82 100644 --- a/tests/NRedisStack.Tests/SkipIfRedisTheoryAttribute.cs +++ b/tests/NRedisStack.Tests/SkipIfRedisTheoryAttribute.cs @@ -128,6 +128,9 @@ protected override ValueTask> CreateTestCase [XunitTestCaseDiscoverer(typeof(ExpandingTheoryDiscoverer))] public class SkipIfRedisTheoryAttribute : TheoryAttribute { + // historically, this tests against an env arg; however, the AssertVersion(db) + // method can be used to check against the actual discovered/declared server version; + // to avoid xunit using the env check (because you're using AssertVersion), add deferVersionCheck: true internal SkipIfRedisCore Core { get; } public SkipIfRedisTheoryAttribute( @@ -135,28 +138,31 @@ public SkipIfRedisTheoryAttribute( Comparison comparison = Comparison.LessThan, string targetVersion = "0.0.0", [CallerFilePath] string? sourceFilePath = null, - [CallerLineNumber] int sourceLineNumber = -1) : base(sourceFilePath, sourceLineNumber) + [CallerLineNumber] int sourceLineNumber = -1, + bool deferVersionCheck = false) : base(sourceFilePath, sourceLineNumber) { Core = new(environment, comparison, targetVersion); - Skip = Core.Skip; + if (!deferVersionCheck) Skip = Core.Skip; } public SkipIfRedisTheoryAttribute( string targetVersion, [CallerFilePath] string? sourceFilePath = null, - [CallerLineNumber] int sourceLineNumber = -1) : base(sourceFilePath, sourceLineNumber) // defaults to LessThan + [CallerLineNumber] int sourceLineNumber = -1, + bool deferVersionCheck = false) : base(sourceFilePath, sourceLineNumber) // defaults to LessThan { Core = new(targetVersion); - Skip = Core.Skip; + if (!deferVersionCheck) Skip = Core.Skip; } public SkipIfRedisTheoryAttribute( Comparison comparison, string targetVersion, [CallerFilePath] string? sourceFilePath = null, - [CallerLineNumber] int sourceLineNumber = -1) : base(sourceFilePath, sourceLineNumber) + [CallerLineNumber] int sourceLineNumber = -1, + bool deferVersionCheck = false) : base(sourceFilePath, sourceLineNumber) { Core = new(comparison, targetVersion); - Skip = Core.Skip; + if (!deferVersionCheck) Skip = Core.Skip; } } @@ -184,15 +190,20 @@ public override ValueTask> Discover(ITestFra [XunitTestCaseDiscoverer(typeof(ExpandingFactDiscoverer))] public class SkipIfRedisFactAttribute : FactAttribute { + // historically, this tests against an env arg; however, the AssertVersion(db) + // method can be used to check against the actual discovered/declared server version; + // to avoid xunit using the env check (because you're using AssertVersion), add deferVersionCheck: true + public SkipIfRedisFactAttribute( Is environment, Comparison comparison = Comparison.LessThan, string targetVersion = "0.0.0", [CallerFilePath] string? sourceFilePath = null, - [CallerLineNumber] int sourceLineNumber = -1) : base(sourceFilePath, sourceLineNumber) + [CallerLineNumber] int sourceLineNumber = -1, + bool deferVersionCheck = false) : base(sourceFilePath, sourceLineNumber) { Core = new(environment, comparison, targetVersion); - Skip = Core.Skip; + if (!deferVersionCheck) Skip = Core.Skip; } internal SkipIfRedisCore Core { get; } @@ -200,20 +211,22 @@ public SkipIfRedisFactAttribute( public SkipIfRedisFactAttribute( // defaults to LessThan string targetVersion, [CallerFilePath] string? sourceFilePath = null, - [CallerLineNumber] int sourceLineNumber = -1) : base(sourceFilePath, sourceLineNumber) + [CallerLineNumber] int sourceLineNumber = -1, + bool deferVersionCheck = false) : base(sourceFilePath, sourceLineNumber) { Core = new(targetVersion); - Skip = Core.Skip; + if (!deferVersionCheck) Skip = Core.Skip; } public SkipIfRedisFactAttribute( Comparison comparison, string targetVersion, [CallerFilePath] string? sourceFilePath = null, - [CallerLineNumber] int sourceLineNumber = -1) : base(sourceFilePath, sourceLineNumber) + [CallerLineNumber] int sourceLineNumber = -1, + bool deferVersionCheck = false) : base(sourceFilePath, sourceLineNumber) { Core = new(comparison, targetVersion); - Skip = Core.Skip; + if (!deferVersionCheck) Skip = Core.Skip; } }