Skip to content

Commit

Permalink
Added support for halfvec type to EF Core
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 17, 2024
1 parent f4282b3 commit 7f6d0e2
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/Pgvector.EntityFrameworkCore/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## 0.2.1 (unreleased)

- Added support for `halfvec` type
- Added support for compiled models

## 0.2.0 (2023-11-24)
Expand Down
19 changes: 19 additions & 0 deletions src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Microsoft.EntityFrameworkCore.Storage;
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
using NpgsqlTypes;

namespace Pgvector.EntityFrameworkCore;

public class HalfvecTypeMapping : RelationalTypeMapping
{
public static HalfvecTypeMapping Default { get; } = new();

public HalfvecTypeMapping() : base("halfvec", typeof(HalfVector)) { }

public HalfvecTypeMapping(string storeType) : base(storeType, typeof(HalfVector)) { }

protected HalfvecTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { }

protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters)
=> new HalfvecTypeMapping(parameters);
}
11 changes: 11 additions & 0 deletions src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using Microsoft.EntityFrameworkCore.Storage;

namespace Pgvector.EntityFrameworkCore;

public class HalfvecTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin
{
public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo)
=> mappingInfo.ClrType == typeof(HalfVector)
? new HalfvecTypeMapping(mappingInfo.StoreTypeName ?? "halfvec")
: null;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public void ApplyServices(IServiceCollection services)
.TryAdd<IMethodCallTranslatorPlugin, VectorDbFunctionsTranslatorPlugin>();

services.AddSingleton<IRelationalTypeMappingSourcePlugin, VectorTypeMappingSourcePlugin>();
services.AddSingleton<IRelationalTypeMappingSourcePlugin, HalfvecTypeMappingSourcePlugin>();
}

public void Validate(IDbContextOptions options) { }
Expand Down
10 changes: 7 additions & 3 deletions tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ public class Item

[Column("embedding", TypeName = "vector(3)")]
public Vector? Embedding { get; set; }

[Column("half_embedding", TypeName = "halfvec(3)")]
public HalfVector? HalfEmbedding { get; set; }
}

public class EntityFrameworkCoreTests
Expand All @@ -49,15 +52,16 @@ public async Task Main()
var databaseCreator = ctx.GetService<IRelationalDatabaseCreator>();
databaseCreator.CreateTables();

ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 1 }) });
ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 2, 2, 2 }) });
ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 2 }) });
ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 1 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }) });
ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 2, 2, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)2, (Half)2, (Half)2 }) });
ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)2 }) });
ctx.SaveChanges();

var embedding = new Vector(new float[] { 1, 1, 1 });
var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());
Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray());

items = await ctx.Items.OrderBy(x => x.Embedding!.L2Distance(embedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
Expand Down

0 comments on commit 7f6d0e2

Please sign in to comment.