diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 3457e7f..6f8bf55 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -5,6 +5,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
+ - uses: actions/setup-dotnet@v3
- uses: ankane/setup-postgres@v1
with:
database: pgvector_dotnet_test
diff --git a/global.json b/global.json
new file mode 100644
index 0000000..817f0c3
--- /dev/null
+++ b/global.json
@@ -0,0 +1,7 @@
+{
+ "sdk": {
+ "version": "8.0.100-rc.2.23502.2",
+ "rollForward": "latestMajor",
+ "allowPrerelease": "true"
+ }
+}
diff --git a/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj b/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj
index f8ee1ec..7e6f0c5 100644
--- a/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj
+++ b/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj
@@ -9,7 +9,7 @@
https://github.com/pgvector/pgvector-dotnet
README.md
- net6.0
+ net8.0
enable
enable
latest
@@ -19,7 +19,7 @@
-
+
diff --git a/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs b/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs
index 1621f59..6ca564f 100644
--- a/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs
+++ b/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs
@@ -4,11 +4,11 @@
namespace Pgvector.EntityFrameworkCore;
-public class VectorTypeMapping : NpgsqlTypeMapping
+public class VectorTypeMapping : RelationalTypeMapping
{
- public VectorTypeMapping(string storeType) : base(storeType, typeof(Vector), NpgsqlDbType.Unknown) { }
+ public VectorTypeMapping(string storeType) : base(storeType, typeof(Vector)) { }
- protected VectorTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters, NpgsqlDbType.Unknown) { }
+ protected VectorTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { }
protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters)
=> new VectorTypeMapping(parameters);
diff --git a/src/Pgvector/Npgsql/VectorConverter.cs b/src/Pgvector/Npgsql/VectorConverter.cs
new file mode 100644
index 0000000..883bb7f
--- /dev/null
+++ b/src/Pgvector/Npgsql/VectorConverter.cs
@@ -0,0 +1,91 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Npgsql.Internal;
+
+namespace Pgvector.Npgsql;
+
+public class VectorConverter : PgStreamingConverter
+{
+ public override Vector Read(PgReader reader)
+ {
+ if (reader.ShouldBuffer(2 * sizeof(ushort)))
+ reader.Buffer(2 * sizeof(ushort));
+
+ var dim = reader.ReadUInt16();
+ var unused = reader.ReadUInt16();
+ if (unused != 0)
+ throw new InvalidCastException("expected unused to be 0");
+
+ var vec = new float[dim];
+ for (var i = 0; i < dim; i++)
+ {
+ if (reader.ShouldBuffer(sizeof(float)))
+ reader.Buffer(sizeof(float));
+ vec[i] = reader.ReadFloat();
+ }
+
+ return new Vector(vec);
+ }
+
+ public override async ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default)
+ {
+ if (reader.ShouldBuffer(2 * sizeof(ushort)))
+ await reader.BufferAsync(2 * sizeof(ushort), cancellationToken).ConfigureAwait(false);
+
+ var dim = reader.ReadUInt16();
+ var unused = reader.ReadUInt16();
+ if (unused != 0)
+ throw new InvalidCastException("expected unused to be 0");
+
+ var vec = new float[dim];
+ for (var i = 0; i < dim; i++)
+ {
+ if (reader.ShouldBuffer(sizeof(float)))
+ await reader.BufferAsync(sizeof(float), cancellationToken).ConfigureAwait(false);
+ vec[i] = reader.ReadFloat();
+ }
+
+ return new Vector(vec);
+ }
+
+ public override Size GetSize(SizeContext context, Vector value, ref object? writeState)
+ => sizeof(ushort) * 2 + sizeof(float) * value.ToArray().Length;
+
+ public override void Write(PgWriter writer, Vector value)
+ {
+ if (writer.ShouldFlush(sizeof(ushort) * 2))
+ writer.Flush();
+
+ var vec = value.ToArray();
+ var dim = vec.Length;
+ writer.WriteUInt16(Convert.ToUInt16(dim));
+ writer.WriteUInt16(0);
+
+ for (int i = 0; i < dim; i++)
+ {
+ if (writer.ShouldFlush(sizeof(float)))
+ writer.Flush();
+ writer.WriteFloat(vec[i]);
+ }
+ }
+
+ public override async ValueTask WriteAsync(
+ PgWriter writer, Vector value, CancellationToken cancellationToken = default)
+ {
+ if (writer.ShouldFlush(sizeof(ushort) * 2))
+ await writer.FlushAsync(cancellationToken);
+
+ var vec = value.ToArray();
+ var dim = vec.Length;
+ writer.WriteUInt16(Convert.ToUInt16(dim));
+ writer.WriteUInt16(0);
+
+ for (int i = 0; i < dim; i++)
+ {
+ if (writer.ShouldFlush(sizeof(float)))
+ await writer.FlushAsync(cancellationToken);
+ writer.WriteFloat(vec[i]);
+ }
+ }
+}
diff --git a/src/Pgvector/Npgsql/VectorExtensions.cs b/src/Pgvector/Npgsql/VectorExtensions.cs
index c06925d..ccfbe21 100644
--- a/src/Pgvector/Npgsql/VectorExtensions.cs
+++ b/src/Pgvector/Npgsql/VectorExtensions.cs
@@ -6,7 +6,7 @@ public static class VectorExtensions
{
public static INpgsqlTypeMapper UseVector(this INpgsqlTypeMapper mapper)
{
- mapper.AddTypeResolverFactory(new VectorTypeHandlerResolverFactory());
+ mapper.AddTypeInfoResolver(new VectorTypeInfoResolver());
return mapper;
}
}
diff --git a/src/Pgvector/Npgsql/VectorHandler.cs b/src/Pgvector/Npgsql/VectorHandler.cs
deleted file mode 100644
index fc29021..0000000
--- a/src/Pgvector/Npgsql/VectorHandler.cs
+++ /dev/null
@@ -1,77 +0,0 @@
-using Npgsql;
-using Npgsql.BackendMessages;
-using Npgsql.Internal;
-using Npgsql.Internal.TypeHandling;
-using Npgsql.PostgresTypes;
-using System;
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace Pgvector.Npgsql;
-
-public class VectorHandler : NpgsqlTypeHandler
-{
- public VectorHandler(PostgresType pgType) : base(pgType) { }
-
- public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null)
- {
- await buf.Ensure(2 * sizeof(ushort), async);
- var dim = buf.ReadUInt16();
- var unused = buf.ReadUInt16();
- if (unused != 0)
- throw new InvalidCastException("expected unused to be 0");
-
- var vec = new float[dim];
- for (var i = 0; i < dim; i++)
- {
- await buf.Ensure(sizeof(float), async);
- vec[i] = buf.ReadSingle();
- }
-
- return new Vector(vec);
- }
-
- public override int ValidateAndGetLength(Vector value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter)
- => sizeof(ushort) * 2 + sizeof(float) * value.ToArray().Length;
-
- public override async Task Write(
- Vector value,
- NpgsqlWriteBuffer buf,
- NpgsqlLengthCache? lengthCache,
- NpgsqlParameter? parameter,
- bool async,
- CancellationToken cancellationToken = default)
- {
- if (buf.WriteSpaceLeft < sizeof(ushort) * 2)
- await buf.Flush(async, cancellationToken);
-
- var vec = value.ToArray();
- var dim = vec.Length;
- buf.WriteUInt16(Convert.ToUInt16(dim));
- buf.WriteUInt16(0);
-
- for (int i = 0; i < dim; i++)
- {
- if (buf.WriteSpaceLeft < sizeof(float))
- await buf.Flush(async, cancellationToken);
- buf.WriteSingle(vec[i]);
- }
- }
-
- public override int ValidateObjectAndGetLength(object? value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter)
- => value switch
- {
- Vector converted => ValidateAndGetLength(converted, ref lengthCache, parameter),
- DBNull or null => 0,
- _ => throw new InvalidCastException($"Can't write CLR type {value.GetType()} with handler type VectorHandler")
- };
-
- public override Task WriteObjectWithLength(object? value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default)
- => value switch
- {
- Vector converted => WriteWithLength(converted, buf, lengthCache, parameter, async, cancellationToken),
- DBNull or null => WriteWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken),
- _ => throw new InvalidCastException(
- $"Can't write CLR type {value.GetType()} with handler type VectorHandler")
- };
-}
diff --git a/src/Pgvector/Npgsql/VectorTypeHandlerResolver.cs b/src/Pgvector/Npgsql/VectorTypeHandlerResolver.cs
deleted file mode 100644
index 020b973..0000000
--- a/src/Pgvector/Npgsql/VectorTypeHandlerResolver.cs
+++ /dev/null
@@ -1,52 +0,0 @@
-using Npgsql.Internal;
-using Npgsql.Internal.TypeHandling;
-using Npgsql.PostgresTypes;
-using NpgsqlTypes;
-using System;
-
-namespace Pgvector.Npgsql;
-
-public class VectorTypeHandlerResolver : TypeHandlerResolver
-{
- readonly NpgsqlDatabaseInfo _databaseInfo;
- readonly VectorHandler? _vectorHandler;
-
- internal VectorTypeHandlerResolver(NpgsqlConnector connector)
- {
- _databaseInfo = connector.DatabaseInfo;
-
- var pgVectorType = PgType("vector");
- if (pgVectorType != null)
- {
- _vectorHandler = new VectorHandler(pgVectorType);
- }
- }
-
- public override NpgsqlTypeHandler? ResolveByDataTypeName(string typeName)
- => typeName == "vector" ? _vectorHandler : null;
-
- public override NpgsqlTypeHandler? ResolveByClrType(Type type)
- {
- var dataTypeName = ClrTypeToDataTypeName(type);
- if (dataTypeName != null)
- {
- var handler = ResolveByDataTypeName(dataTypeName);
- if (handler != null)
- return handler;
- }
-
- return null;
- }
-
- public override TypeMappingInfo? GetMappingByDataTypeName(string dataTypeName)
- => DoGetMappingByDataTypeName(dataTypeName);
-
- internal static string? ClrTypeToDataTypeName(Type type)
- => type == typeof(Vector) ? "vector" : null;
-
- internal static TypeMappingInfo? DoGetMappingByDataTypeName(string dataTypeName)
- => dataTypeName == "vector" ? new TypeMappingInfo(NpgsqlDbType.Unknown, "vector") : null;
-
- PostgresType? PgType(string pgTypeName)
- => _databaseInfo.TryGetPostgresTypeByName(pgTypeName, out var pgType) ? pgType : null;
-}
diff --git a/src/Pgvector/Npgsql/VectorTypeHandlerResolverFactory.cs b/src/Pgvector/Npgsql/VectorTypeHandlerResolverFactory.cs
deleted file mode 100644
index 273a899..0000000
--- a/src/Pgvector/Npgsql/VectorTypeHandlerResolverFactory.cs
+++ /dev/null
@@ -1,17 +0,0 @@
-using Npgsql.Internal;
-using Npgsql.Internal.TypeHandling;
-using System;
-
-namespace Pgvector.Npgsql;
-
-public class VectorTypeHandlerResolverFactory : TypeHandlerResolverFactory
-{
- public override TypeHandlerResolver Create(NpgsqlConnector connector)
- => new VectorTypeHandlerResolver(connector);
-
- public override string? GetDataTypeNameByClrType(Type type)
- => VectorTypeHandlerResolver.ClrTypeToDataTypeName(type);
-
- public override TypeMappingInfo? GetMappingByDataTypeName(string dataTypeName)
- => VectorTypeHandlerResolver.DoGetMappingByDataTypeName(dataTypeName);
-}
diff --git a/src/Pgvector/Npgsql/VectorTypeInfoResolver.cs b/src/Pgvector/Npgsql/VectorTypeInfoResolver.cs
new file mode 100644
index 0000000..c12ead2
--- /dev/null
+++ b/src/Pgvector/Npgsql/VectorTypeInfoResolver.cs
@@ -0,0 +1,28 @@
+using System;
+using Npgsql.Internal;
+using Npgsql.Internal.Postgres;
+
+namespace Pgvector.Npgsql;
+
+public class VectorTypeInfoResolver : IPgTypeInfoResolver
+{
+ TypeInfoMappingCollection Mappings { get; }
+
+ public VectorTypeInfoResolver()
+ {
+ Mappings = new TypeInfoMappingCollection();
+ AddInfos(Mappings);
+ // TODO: Opt-in only
+ AddArrayInfos(Mappings);
+ }
+
+ public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options)
+ => Mappings.Find(type, dataTypeName, options);
+
+ static void AddInfos(TypeInfoMappingCollection mappings)
+ => mappings.AddType("vector",
+ static (options, mapping, _) => mapping.CreateInfo(options, new VectorConverter()), isDefault: true);
+
+ static void AddArrayInfos(TypeInfoMappingCollection mappings)
+ => mappings.AddArrayType("vector");
+}
diff --git a/src/Pgvector/Pgvector.csproj b/src/Pgvector/Pgvector.csproj
index ce4660f..208cf96 100644
--- a/src/Pgvector/Pgvector.csproj
+++ b/src/Pgvector/Pgvector.csproj
@@ -17,7 +17,7 @@
-
+
diff --git a/tests/Pgvector.Tests/EntityFrameworkCoreTests.cs b/tests/Pgvector.Tests/EntityFrameworkCoreTests.cs
index f75627a..9a263f2 100644
--- a/tests/Pgvector.Tests/EntityFrameworkCoreTests.cs
+++ b/tests/Pgvector.Tests/EntityFrameworkCoreTests.cs
@@ -11,7 +11,7 @@ public class ItemContext : DbContext
protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
{
var connString = "Host=localhost;Database=pgvector_dotnet_test";
- optionsBuilder.UseNpgsql(connString, o => o.UseVector()).UseSnakeCaseNamingConvention();
+ optionsBuilder.UseNpgsql(connString, o => o.UseVector());
}
protected override void OnModelCreating(ModelBuilder modelBuilder)
@@ -30,7 +30,7 @@ public class Item
{
public int Id { get; set; }
- [Column(TypeName = "vector(3)")]
+ [Column("embedding", TypeName = "vector(3)")]
public Vector? Embedding { get; set; }
}
diff --git a/tests/Pgvector.Tests/Pgvector.Tests.csproj b/tests/Pgvector.Tests/Pgvector.Tests.csproj
index 5368732..a23e1d7 100644
--- a/tests/Pgvector.Tests/Pgvector.Tests.csproj
+++ b/tests/Pgvector.Tests/Pgvector.Tests.csproj
@@ -1,7 +1,7 @@
- net7.0
+ net8.0
enable
enable