Move the projects to net8.0 and replace the Npgsql type-handler classes with a
PgStreamingConverter / IPgTypeInfoResolver implementation.

NOTE(review): the .csproj hunks below carry only the text that survived in the
captured patch; their XML markup appears to have been lost before capture, so
those hunks cannot be applied as-is and must be regenerated from the repository.
NOTE(review): blob ids on the "index" lines of the two added C# files predate
the doc comments and fixes made in this revision.

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 3457e7f..6f8bf55 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -5,6 +5,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
     - uses: actions/checkout@v3
+    - uses: actions/setup-dotnet@v3
     - uses: ankane/setup-postgres@v1
       with:
         database: pgvector_dotnet_test
diff --git a/global.json b/global.json
new file mode 100644
index 0000000..817f0c3
--- /dev/null
+++ b/global.json
@@ -0,0 +1,7 @@
+{
+  "sdk": {
+    "version": "8.0.100-rc.2.23502.2",
+    "rollForward": "latestMajor",
+    "allowPrerelease": true
+  }
+}
diff --git a/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj b/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj
index f8ee1ec..7e6f0c5 100644
--- a/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj
+++ b/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj
@@ -9,7 +9,7 @@
 
 https://github.com/pgvector/pgvector-dotnet
 README.md
-net6.0
+net8.0
 enable
 enable
 latest
@@ -19,7 +19,7 @@
 
 
 
-
+
 
 
 
diff --git a/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs b/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs
index 1621f59..6ca564f 100644
--- a/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs
+++ b/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs
@@ -4,11 +4,11 @@
 
 namespace Pgvector.EntityFrameworkCore;
 
-public class VectorTypeMapping : NpgsqlTypeMapping
+public class VectorTypeMapping : RelationalTypeMapping
 {
-    public VectorTypeMapping(string storeType) : base(storeType, typeof(Vector), NpgsqlDbType.Unknown) { }
+    public VectorTypeMapping(string storeType) : base(storeType, typeof(Vector)) { }
 
-    protected VectorTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters, NpgsqlDbType.Unknown) { }
+    protected VectorTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { }
 
     protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters)
         => new VectorTypeMapping(parameters);
diff --git a/src/Pgvector/Npgsql/VectorConverter.cs b/src/Pgvector/Npgsql/VectorConverter.cs
new file mode 100644
index 0000000..883bb7f
--- /dev/null
+++ b/src/Pgvector/Npgsql/VectorConverter.cs
@@ -0,0 +1,105 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Npgsql.Internal;
+
+namespace Pgvector.Npgsql;
+
+/// <summary>
+/// Converts <see cref="Vector"/> values to and from the binary layout handled below:
+/// a <c>ushort</c> element count, a <c>ushort</c> that must be zero, then one
+/// <c>float</c> per element.
+/// </summary>
+public class VectorConverter : PgStreamingConverter<Vector>
+{
+    /// <summary>Decodes one vector from <paramref name="reader"/>, buffering before each read when asked to.</summary>
+    /// <exception cref="InvalidCastException">The second header field is not zero.</exception>
+    public override Vector Read(PgReader reader)
+    {
+        if (reader.ShouldBuffer(2 * sizeof(ushort)))
+            reader.Buffer(2 * sizeof(ushort));
+
+        var dim = reader.ReadUInt16();
+        var unused = reader.ReadUInt16();
+        if (unused != 0)
+            throw new InvalidCastException("expected unused to be 0");
+
+        var vec = new float[dim];
+        for (var i = 0; i < dim; i++)
+        {
+            if (reader.ShouldBuffer(sizeof(float)))
+                reader.Buffer(sizeof(float));
+            vec[i] = reader.ReadFloat();
+        }
+
+        return new Vector(vec);
+    }
+
+    /// <summary>Asynchronous counterpart of <see cref="Read"/>; same layout and checks.</summary>
+    /// <exception cref="InvalidCastException">The second header field is not zero.</exception>
+    public override async ValueTask<Vector> ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
+    {
+        if (reader.ShouldBuffer(2 * sizeof(ushort)))
+            await reader.BufferAsync(2 * sizeof(ushort), cancellationToken).ConfigureAwait(false);
+
+        var dim = reader.ReadUInt16();
+        var unused = reader.ReadUInt16();
+        if (unused != 0)
+            throw new InvalidCastException("expected unused to be 0");
+
+        var vec = new float[dim];
+        for (var i = 0; i < dim; i++)
+        {
+            if (reader.ShouldBuffer(sizeof(float)))
+                await reader.BufferAsync(sizeof(float), cancellationToken).ConfigureAwait(false);
+            vec[i] = reader.ReadFloat();
+        }
+
+        return new Vector(vec);
+    }
+
+    /// <summary>Returns the encoded length: two <c>ushort</c> header fields plus one <c>float</c> per element.</summary>
+    public override Size GetSize(SizeContext context, Vector value, ref object? writeState)
+        => sizeof(ushort) * 2 + sizeof(float) * value.ToArray().Length;
+
+    /// <summary>Encodes <paramref name="value"/>: header first, then each element, flushing when the writer asks to.</summary>
+    /// <exception cref="OverflowException">The element count does not fit in a <c>ushort</c>.</exception>
+    public override void Write(PgWriter writer, Vector value)
+    {
+        if (writer.ShouldFlush(sizeof(ushort) * 2))
+            writer.Flush();
+
+        var vec = value.ToArray();
+        var dim = vec.Length;
+        writer.WriteUInt16(Convert.ToUInt16(dim));
+        writer.WriteUInt16(0);
+
+        for (int i = 0; i < dim; i++)
+        {
+            if (writer.ShouldFlush(sizeof(float)))
+                writer.Flush();
+            writer.WriteFloat(vec[i]);
+        }
+    }
+
+    /// <summary>Asynchronous counterpart of <see cref="Write"/>; same layout.</summary>
+    /// <exception cref="OverflowException">The element count does not fit in a <c>ushort</c>.</exception>
+    public override async ValueTask WriteAsync(
+        PgWriter writer, Vector value, CancellationToken cancellationToken = default)
+    {
+        if (writer.ShouldFlush(sizeof(ushort) * 2))
+            await writer.FlushAsync(cancellationToken).ConfigureAwait(false);
+
+        var vec = value.ToArray();
+        var dim = vec.Length;
+        writer.WriteUInt16(Convert.ToUInt16(dim));
+        writer.WriteUInt16(0);
+
+        for (int i = 0; i < dim; i++)
+        {
+            if (writer.ShouldFlush(sizeof(float)))
+                await writer.FlushAsync(cancellationToken).ConfigureAwait(false);
+            writer.WriteFloat(vec[i]);
+        }
+    }
+}
diff --git a/src/Pgvector/Npgsql/VectorExtensions.cs b/src/Pgvector/Npgsql/VectorExtensions.cs
index c06925d..ccfbe21 100644
--- a/src/Pgvector/Npgsql/VectorExtensions.cs
+++ b/src/Pgvector/Npgsql/VectorExtensions.cs
@@ -6,7 +6,7 @@ public static class VectorExtensions
 {
     public static INpgsqlTypeMapper UseVector(this INpgsqlTypeMapper mapper)
     {
-        mapper.AddTypeResolverFactory(new VectorTypeHandlerResolverFactory());
+        mapper.AddTypeInfoResolver(new VectorTypeInfoResolver());
         return mapper;
     }
 }
diff --git a/src/Pgvector/Npgsql/VectorHandler.cs b/src/Pgvector/Npgsql/VectorHandler.cs
deleted file mode 100644
index fc29021..0000000
--- a/src/Pgvector/Npgsql/VectorHandler.cs
+++ /dev/null
@@ -1,77 +0,0 @@
-using Npgsql;
-using Npgsql.BackendMessages;
-using Npgsql.Internal;
-using Npgsql.Internal.TypeHandling;
-using Npgsql.PostgresTypes;
-using System;
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace Pgvector.Npgsql;
-
-public class VectorHandler : NpgsqlTypeHandler<Vector>
-{
-    public VectorHandler(PostgresType pgType) : base(pgType) { }
-
-    public override async ValueTask<Vector> Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null)
-    {
-        await buf.Ensure(2 * sizeof(ushort), async);
-        var dim = buf.ReadUInt16();
-        var unused = buf.ReadUInt16();
-        if (unused != 0)
-            throw new InvalidCastException("expected unused to be 0");
-
-        var vec = new float[dim];
-        for (var i = 0; i < dim; i++)
-        {
-            await buf.Ensure(sizeof(float), async);
-            vec[i] = buf.ReadSingle();
-        }
-
-        return new Vector(vec);
-    }
-
-    public override int ValidateAndGetLength(Vector value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter)
-        => sizeof(ushort) * 2 + sizeof(float) * value.ToArray().Length;
-
-    public override async Task Write(
-        Vector value,
-        NpgsqlWriteBuffer buf,
-        NpgsqlLengthCache? lengthCache,
-        NpgsqlParameter? parameter,
-        bool async,
-        CancellationToken cancellationToken = default)
-    {
-        if (buf.WriteSpaceLeft < sizeof(ushort) * 2)
-            await buf.Flush(async, cancellationToken);
-
-        var vec = value.ToArray();
-        var dim = vec.Length;
-        buf.WriteUInt16(Convert.ToUInt16(dim));
-        buf.WriteUInt16(0);
-
-        for (int i = 0; i < dim; i++)
-        {
-            if (buf.WriteSpaceLeft < sizeof(float))
-                await buf.Flush(async, cancellationToken);
-            buf.WriteSingle(vec[i]);
-        }
-    }
-
-    public override int ValidateObjectAndGetLength(object? value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter)
-        => value switch
-        {
-            Vector converted => ValidateAndGetLength(converted, ref lengthCache, parameter),
-            DBNull or null => 0,
-            _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type VectorHandler")
-        };
-
-    public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default)
-        => value switch
-        {
-            Vector converted => WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken),
-            DBNull or null => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken),
-            _ => throw new InvalidCastException(
-                $"Can't write CLR type {value.GetType()} with handler type VectorHandler")
-        };
-}
diff --git a/src/Pgvector/Npgsql/VectorTypeHandlerResolver.cs b/src/Pgvector/Npgsql/VectorTypeHandlerResolver.cs
deleted file mode 100644
index 020b973..0000000
--- a/src/Pgvector/Npgsql/VectorTypeHandlerResolver.cs
+++ /dev/null
@@ -1,52 +0,0 @@
-using Npgsql.Internal;
-using Npgsql.Internal.TypeHandling;
-using Npgsql.PostgresTypes;
-using NpgsqlTypes;
-using System;
-
-namespace Pgvector.Npgsql;
-
-public class VectorTypeHandlerResolver : TypeHandlerResolver
-{
-    readonly NpgsqlDatabaseInfo _databaseInfo;
-    readonly VectorHandler? _vectorHandler;
-
-    internal VectorTypeHandlerResolver(NpgsqlConnector connector)
-    {
-        _databaseInfo = connector.DatabaseInfo;
-
-        var pgVectorType = PgType("vector");
-        if (pgVectorType != null)
-        {
-            _vectorHandler = new VectorHandler(pgVectorType);
-        }
-    }
-
-    public override NpgsqlTypeHandler? ResolveByDataTypeName(string typeName)
-        => typeName == "vector" ? _vectorHandler : null;
-
-    public override NpgsqlTypeHandler? ResolveByClrType(Type type)
-    {
-        var dataTypeName = ClrTypeToDataTypeName(type);
-        if (dataTypeName != null)
-        {
-            var handler = ResolveByDataTypeName(dataTypeName);
-            if (handler != null)
-                return handler;
-        }
-
-        return null;
-    }
-
-    public override TypeMappingInfo? GetMappingByDataTypeName(string dataTypeName)
-        => DoGetMappingByDataTypeName(dataTypeName);
-
-    internal static string? ClrTypeToDataTypeName(Type type)
-        => type == typeof(Vector) ? "vector" : null;
-
-    internal static TypeMappingInfo? DoGetMappingByDataTypeName(string dataTypeName)
-        => dataTypeName == "vector" ? new TypeMappingInfo(NpgsqlDbType.Unknown, "vector") : null;
-
-    PostgresType? PgType(string pgTypeName)
-        => _databaseInfo.TryGetPostgresTypeByName(pgTypeName, out var pgType) ? pgType : null;
-}
diff --git a/src/Pgvector/Npgsql/VectorTypeHandlerResolverFactory.cs b/src/Pgvector/Npgsql/VectorTypeHandlerResolverFactory.cs
deleted file mode 100644
index 273a899..0000000
--- a/src/Pgvector/Npgsql/VectorTypeHandlerResolverFactory.cs
+++ /dev/null
@@ -1,17 +0,0 @@
-using Npgsql.Internal;
-using Npgsql.Internal.TypeHandling;
-using System;
-
-namespace Pgvector.Npgsql;
-
-public class VectorTypeHandlerResolverFactory : TypeHandlerResolverFactory
-{
-    public override TypeHandlerResolver Create(NpgsqlConnector connector)
-        => new VectorTypeHandlerResolver(connector);
-
-    public override string? GetDataTypeNameByClrType(Type type)
-        => VectorTypeHandlerResolver.ClrTypeToDataTypeName(type);
-
-    public override TypeMappingInfo? GetMappingByDataTypeName(string dataTypeName)
-        => VectorTypeHandlerResolver.DoGetMappingByDataTypeName(dataTypeName);
-}
diff --git a/src/Pgvector/Npgsql/VectorTypeInfoResolver.cs b/src/Pgvector/Npgsql/VectorTypeInfoResolver.cs
new file mode 100644
index 0000000..c12ead2
--- /dev/null
+++ b/src/Pgvector/Npgsql/VectorTypeInfoResolver.cs
@@ -0,0 +1,34 @@
+using System;
+using Npgsql.Internal;
+using Npgsql.Internal.Postgres;
+
+namespace Pgvector.Npgsql;
+
+/// <summary>
+/// Resolves the <c>vector</c> data type and <see cref="Vector"/> to a <see cref="VectorConverter"/>.
+/// </summary>
+public class VectorTypeInfoResolver : IPgTypeInfoResolver
+{
+    TypeInfoMappingCollection Mappings { get; }
+
+    public VectorTypeInfoResolver()
+    {
+        Mappings = new TypeInfoMappingCollection();
+        AddInfos(Mappings);
+        // TODO: Opt-in only
+        AddArrayInfos(Mappings);
+    }
+
+    /// <summary>Looks the request up in the mappings registered by the constructor.</summary>
+    public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options)
+        => Mappings.Find(type, dataTypeName, options);
+
+    // Registers the scalar "vector" mapping, flagged as the default, backed by VectorConverter.
+    static void AddInfos(TypeInfoMappingCollection mappings)
+        => mappings.AddType<Vector>("vector",
+            static (options, mapping, _) => mapping.CreateInfo(options, new VectorConverter()), isDefault: true);
+
+    // Registers the array mapping for "vector".
+    static void AddArrayInfos(TypeInfoMappingCollection mappings)
+        => mappings.AddArrayType<Vector>("vector");
+}
diff --git a/src/Pgvector/Pgvector.csproj b/src/Pgvector/Pgvector.csproj
index ce4660f..208cf96 100644
--- a/src/Pgvector/Pgvector.csproj
+++ b/src/Pgvector/Pgvector.csproj
@@ -17,7 +17,7 @@
 
 
 
-
+
 
 
 
diff --git a/tests/Pgvector.Tests/EntityFrameworkCoreTests.cs b/tests/Pgvector.Tests/EntityFrameworkCoreTests.cs
index f75627a..9a263f2 100644
--- a/tests/Pgvector.Tests/EntityFrameworkCoreTests.cs
+++ b/tests/Pgvector.Tests/EntityFrameworkCoreTests.cs
@@ -11,7 +11,7 @@ public class ItemContext : DbContext
     protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
     {
         var connString = "Host=localhost;Database=pgvector_dotnet_test";
-        optionsBuilder.UseNpgsql(connString, o => o.UseVector()).UseSnakeCaseNamingConvention();
+        optionsBuilder.UseNpgsql(connString, o => o.UseVector());
     }
 
     protected override void OnModelCreating(ModelBuilder modelBuilder)
@@ -30,7 +30,7 @@ public class Item
 {
     public int Id { get; set; }
 
-    [Column(TypeName = "vector(3)")]
+    [Column("embedding", TypeName = "vector(3)")]
     public Vector? Embedding { get; set; }
 }
 
diff --git a/tests/Pgvector.Tests/Pgvector.Tests.csproj b/tests/Pgvector.Tests/Pgvector.Tests.csproj
index 5368732..a23e1d7 100644
--- a/tests/Pgvector.Tests/Pgvector.Tests.csproj
+++ b/tests/Pgvector.Tests/Pgvector.Tests.csproj
@@ -1,7 +1,7 @@
 
 
 
-net7.0
+net8.0
 enable
 enable
 