diff --git a/src/Pgvector.EntityFrameworkCore/CHANGELOG.md b/src/Pgvector.EntityFrameworkCore/CHANGELOG.md index f8713b0..1b69831 100644 --- a/src/Pgvector.EntityFrameworkCore/CHANGELOG.md +++ b/src/Pgvector.EntityFrameworkCore/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.2.1 (unreleased) +- Added support for `halfvec` type - Added support for compiled models ## 0.2.0 (2023-11-24) diff --git a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs new file mode 100644 index 0000000..d1d4bfa --- /dev/null +++ b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs @@ -0,0 +1,19 @@ +using Microsoft.EntityFrameworkCore.Storage; +using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping; +using NpgsqlTypes; + +namespace Pgvector.EntityFrameworkCore; + +public class HalfvecTypeMapping : RelationalTypeMapping +{ + public static HalfvecTypeMapping Default { get; } = new(); + + public HalfvecTypeMapping() : base("halfvec", typeof(HalfVector)) { } + + public HalfvecTypeMapping(string storeType) : base(storeType, typeof(HalfVector)) { } + + protected HalfvecTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { } + + protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters) + => new HalfvecTypeMapping(parameters); +} diff --git a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs new file mode 100644 index 0000000..663bf71 --- /dev/null +++ b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs @@ -0,0 +1,11 @@ +using Microsoft.EntityFrameworkCore.Storage; + +namespace Pgvector.EntityFrameworkCore; + +public class HalfvecTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin +{ + public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo) + => mappingInfo.ClrType == typeof(HalfVector) + ? new HalfvecTypeMapping(mappingInfo.StoreTypeName ?? "halfvec") + : null; +} diff --git a/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs b/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs index 570e4dc..96c4225 100644 --- a/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs +++ b/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs @@ -17,6 +17,7 @@ public void ApplyServices(IServiceCollection services) .TryAdd(); services.AddSingleton(); + services.AddSingleton(); } public void Validate(IDbContextOptions options) { } diff --git a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs index 0523e26..1846057 100644 --- a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs +++ b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs @@ -36,6 +36,9 @@ public class Item [Column("embedding", TypeName = "vector(3)")] public Vector? Embedding { get; set; } + + [Column("half_embedding", TypeName = "halfvec(3)")] + public HalfVector? HalfEmbedding { get; set; } } public class EntityFrameworkCoreTests @@ -49,15 +52,16 @@ public async Task Main() var databaseCreator = ctx.GetService(); databaseCreator.CreateTables(); - ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 1 }) }); - ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 2, 2, 2 }) }); - ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 2 }) }); + ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 1 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }) }); + ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 2, 2, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)2, (Half)2, (Half)2 }) }); + ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)2 }) }); ctx.SaveChanges(); var embedding = new Vector(new float[] { 1, 1, 1 }); var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync(); Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray()); + Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray()); items = await ctx.Items.OrderBy(x => x.Embedding!.L2Distance(embedding)).Take(5).ToListAsync(); Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());