diff --git a/src/Pgvector.EntityFrameworkCore/CHANGELOG.md b/src/Pgvector.EntityFrameworkCore/CHANGELOG.md index f8713b0..a6a8f70 100644 --- a/src/Pgvector.EntityFrameworkCore/CHANGELOG.md +++ b/src/Pgvector.EntityFrameworkCore/CHANGELOG.md @@ -1,6 +1,7 @@ ## 0.2.1 (unreleased) - Added support for compiled models +- Added `L1Distance` function ## 0.2.0 (2023-11-24) diff --git a/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsExtensions.cs b/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsExtensions.cs index 9cce058..131df01 100644 --- a/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsExtensions.cs +++ b/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsExtensions.cs @@ -12,4 +12,7 @@ public static double MaxInnerProduct(this Vector a, Vector b) public static double CosineDistance(this Vector a, Vector b) => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(CosineDistance))); + + public static double L1Distance(this Vector a, Vector b) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(L1Distance))); } diff --git a/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsTranslatorPlugin.cs b/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsTranslatorPlugin.cs index 02d1bf4..b19d00a 100644 --- a/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsTranslatorPlugin.cs +++ b/src/Pgvector.EntityFrameworkCore/VectorDbFunctionsTranslatorPlugin.cs @@ -49,6 +49,13 @@ private class VectorDbFunctionsTranslator : IMethodCallTranslator typeof(Vector), })!; + private static readonly MethodInfo _methodL1Distance = typeof(VectorDbFunctionsExtensions) + .GetRuntimeMethod(nameof(VectorDbFunctionsExtensions.L1Distance), new[] + { + typeof(Vector), + typeof(Vector), + })!; + public VectorDbFunctionsTranslator( ISqlExpressionFactory sqlExpressionFactory, IRelationalTypeMappingSource typeMappingSource @@ -71,6 +78,7 @@ IRelationalTypeMappingSource typeMappingSource _ when ReferenceEquals(method, _methodL2Distance) => "<->", _ when ReferenceEquals(method, _methodMaxInnerProduct) => "<#>", _ when ReferenceEquals(method, _methodCosineDistance) => "<=>", + _ when ReferenceEquals(method, _methodL1Distance) => "<+>", _ => null }; diff --git a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs index 0523e26..2de7ec8 100644 --- a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs +++ b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs @@ -69,6 +69,9 @@ public async Task Main() items = await ctx.Items.OrderBy(x => x.Embedding!.CosineDistance(embedding)).Take(5).ToListAsync(); Assert.Equal(3, items[2].Id); + items = await ctx.Items.OrderBy(x => x.Embedding!.L1Distance(embedding)).Take(5).ToListAsync(); + Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); + items = await ctx.Items .OrderBy(x => x.Id) .Where(x => x.Embedding!.L2Distance(embedding) < 1.5)