diff --git a/N.EntityFrameworkCore.Extensions.Test/Data/Product.cs b/N.EntityFrameworkCore.Extensions.Test/Data/Product.cs index 0266890..7d1253c 100644 --- a/N.EntityFrameworkCore.Extensions.Test/Data/Product.cs +++ b/N.EntityFrameworkCore.Extensions.Test/Data/Product.cs @@ -19,8 +19,11 @@ public class Product [Column("Status")] [StringLength(25)] public string StatusString { get; set; } + public int? ProductCategoryId { get; set; } public ProductStatus? StatusEnum { get; set; } public DateTime? UpdatedDateTime { get; set; } + + public virtual ProductCategory ProductCategory { get; set; } public Product() { diff --git a/N.EntityFrameworkCore.Extensions.Test/Data/ProductCategory.cs b/N.EntityFrameworkCore.Extensions.Test/Data/ProductCategory.cs new file mode 100644 index 0000000..edb09b3 --- /dev/null +++ b/N.EntityFrameworkCore.Extensions.Test/Data/ProductCategory.cs @@ -0,0 +1,17 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using N.EntityFrameworkCore.Extensions.Test.Data.Enums; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace N.EntityFrameworkCore.Extensions.Test.Data +{ + public class ProductCategory + { + public int Id { get; set; } + public string Name { get; set; } + public bool Active { get; internal set; } + } +} diff --git a/N.EntityFrameworkCore.Extensions.Test/Data/SqlExpression.cs b/N.EntityFrameworkCore.Extensions.Test/Data/SqlExpression.cs deleted file mode 100644 index bcaa187..0000000 --- a/N.EntityFrameworkCore.Extensions.Test/Data/SqlExpression.cs +++ /dev/null @@ -1,80 +0,0 @@ -using Microsoft.EntityFrameworkCore.Metadata.Internal; -using Microsoft.EntityFrameworkCore.Storage.ValueConversion; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Linq.Expressions; -using System.Runtime.InteropServices.ObjectiveC; -using System.Text; -using System.Threading.Tasks; - -namespace N.EntityFrameworkCore.Extensions.Sql -{ - internal class SqlExpression - { - SqlExpressionType ExpressionType { get; } - List Items { get; set; } - string Sql => ToSql(); - string Alias { get; } - bool IsEmpty => Items.Count == 0; - - internal SqlExpression(SqlExpressionType expressionType, object item, string alias = null) - { - ExpressionType = expressionType; - Items = new List(); - if (item is IEnumerable values) - { - Items.AddRange(values.ToArray()); - } - else - { - Items.Add(item); - } - Alias = alias; - } - internal SqlExpression(SqlExpressionType expressionType, object[] items, string alias = null) - { - ExpressionType = expressionType; - Items = new List(); - Items.AddRange(items); - Alias = alias; - } - internal static SqlExpression Columns(IEnumerable columns) - { - return new SqlExpression(SqlExpressionType.Columns, columns); - } - - internal static SqlExpression String(string joinOnCondition) - { - return new SqlExpression(SqlExpressionType.String, joinOnCondition); - } - - internal static SqlExpression Table(string tableName, string alias = null) - { - return new SqlExpression(SqlExpressionType.Table, tableName, alias); - } - - private string ToSql() - { - var values = Items.Select(o => o.ToString()).ToArray(); - StringBuilder sbSql = new StringBuilder(); - if (ExpressionType == SqlExpressionType.Columns) - { - sbSql.Append(string.Join(",", values.Select(c => c.StartsWith("$") || c.StartsWith("[") ? c : $"[{c}]"))); - } - else - { - sbSql.Append(string.Join(",", Items.Select(o => o.ToString()))); - } - if (Alias != null) - { - sbSql.Append(" "); - sbSql.Append(SqlKeyword.As.ToString().ToUpper()); - sbSql.Append(" "); - sbSql.Append(Alias); - } - //var test = Items.Select(o => o.ToString()).ToArray(); - return sbSql.ToString(); - } - } -} diff --git a/N.EntityFrameworkCore.Extensions.Test/Data/TestDbContext.cs b/N.EntityFrameworkCore.Extensions.Test/Data/TestDbContext.cs index 01d2ea4..17546bb 100644 --- a/N.EntityFrameworkCore.Extensions.Test/Data/TestDbContext.cs +++ b/N.EntityFrameworkCore.Extensions.Test/Data/TestDbContext.cs @@ -13,6 +13,7 @@ public class TestDbContext : DbContext public virtual DbSet ProductsWithCustomSchema { get; set; } public virtual DbSet ProductsWithComplexKey { get; set; } public virtual DbSet Orders { get; set; } + public virtual DbSet ProductCategories { get; set; } public virtual DbSet TpcPeople { get; set; } public virtual DbSet TphPeople { get; set; } public virtual DbSet TphCustomers { get; set; } diff --git a/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DbContextExtensionsBase.cs b/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DbContextExtensionsBase.cs index af3c6a7..dc54d03 100644 --- a/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DbContextExtensionsBase.cs +++ b/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DbContextExtensionsBase.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Linq; namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions { @@ -29,6 +30,7 @@ protected TestDbContext SetupDbContext(bool populateData, PopulateDataMode mode TestDbContext dbContext = new TestDbContext(); dbContext.Orders.Truncate(); dbContext.Products.Truncate(); + dbContext.ProductCategories.Clear(); dbContext.ProductsWithCustomSchema.Truncate(); dbContext.Database.ClearTable("TpcCustomer"); dbContext.Database.ClearTable("TpcVendor"); @@ -81,11 +83,20 @@ protected TestDbContext SetupDbContext(bool populateData, PopulateDataMode mode Debug.WriteLine("Last Id for Order is {0}", id); dbContext.BulkInsert(orders, new BulkInsertOptions() { KeepIdentity = true }); + + var productCategories = new List() + { + new ProductCategory { Id=1, Name="Category-1", Active=true}, + new ProductCategory { Id=2, Name="Category-2", Active=true}, + new ProductCategory { Id=3, Name="Category-3", Active=true}, + new ProductCategory { Id=4, Name="Category-4", Active=false}, + }; + dbContext.BulkInsert(productCategories, o => { o.KeepIdentity = true; o.UsePermanentTable = true; }); var products = new List(); id = 1; for (int i = 0; i < 2050; i++) { - products.Add(new Product { Id = i.ToString(), Price = 1.25M, OutOfStock = false }); + products.Add(new Product { Id = i.ToString(), Price = 1.25M, OutOfStock = false, ProductCategoryId = 4 }); id++; } for (int i = 2050; i < 7000; i++) diff --git a/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DeleteFromQuery.cs b/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DeleteFromQuery.cs index 14d723c..4c6bf45 100644 --- a/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DeleteFromQuery.cs +++ b/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DeleteFromQuery.cs @@ -22,6 +22,19 @@ public void With_Boolean_Value() Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); } [TestMethod] + public void With_Child_Relationship() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => !p.ProductCategory.Active); + int oldTotal = products.Count(); + int rowsDeleted = products.DeleteFromQuery(); + int newTotal = products.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (ProductCategory.Active == false)"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows update must match the count of rows that match the condition (ProductCategory.Active == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] public void With_Decimal_Using_IQuerable() { var dbContext = SetupDbContext(true); diff --git a/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DeleteFromQueryAsync.cs b/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DeleteFromQueryAsync.cs index 116ae1f..8f09a1a 100644 --- a/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DeleteFromQueryAsync.cs +++ b/N.EntityFrameworkCore.Extensions.Test/DbContextExtensions/DeleteFromQueryAsync.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading.Tasks; namespace N.EntityFrameworkCore.Extensions.Test.DbContextExtensions @@ -22,6 +23,19 @@ public async Task With_Boolean_Value() Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were updated"); } [TestMethod] + public async Task With_Child_Relationship() + { + var dbContext = SetupDbContext(true); + var products = dbContext.Products.Where(p => !p.ProductCategory.Active); + int oldTotal = products.Count(); + int rowsDeleted = await products.DeleteFromQueryAsync(); + int newTotal = products.Count(); + + Assert.IsTrue(oldTotal > 0, "There must be products in database that match this condition (ProductCategory.Active == false)"); + Assert.IsTrue(rowsDeleted == oldTotal, "The number of rows update must match the count of rows that match the condition (ProductCategory.Active == false)"); + Assert.IsTrue(newTotal == 0, "The new count must be 0 to indicate all records were deleted"); + } + [TestMethod] public async Task With_Decimal_Using_IQuerable() { var dbContext = SetupDbContext(true); @@ -90,11 +104,11 @@ public async Task With_Different_Values() Assert.IsTrue(oldTotal - newTotal == rowsDeleted, "The rows deleted must match the new count minues the old count"); } [TestMethod] - public void With_Empty_List() + public async Task With_Empty_List() { var dbContext = SetupDbContext(false); int oldTotal = dbContext.Orders.Count(); - int rowsDeleted = dbContext.Orders.DeleteFromQuery(); + int rowsDeleted = await dbContext.Orders.DeleteFromQueryAsync(); int newTotal = dbContext.Orders.Count(); Assert.IsTrue(oldTotal == 0, "There must be no orders in database that match this condition"); diff --git a/N.EntityFrameworkCore.Extensions/Data/BulkOperation.cs b/N.EntityFrameworkCore.Extensions/Data/BulkOperation.cs index 079af5d..e081024 100644 --- a/N.EntityFrameworkCore.Extensions/Data/BulkOperation.cs +++ b/N.EntityFrameworkCore.Extensions/Data/BulkOperation.cs @@ -60,7 +60,7 @@ internal BulkInsertResult BulkInsertStagingData(IEnumerable entities, bool return DbContextExtensions.BulkInsert(entities, Options, TableMapping, Connection, Transaction, StagingTableName, columnsToInsert, SqlBulkCopyOptions.KeepIdentity, useInternalId); } internal BulkMergeResult ExecuteMerge(Dictionary entityMap, Expression> mergeOnCondition, - bool autoMapOutput, bool insertIfNotExists, bool update = false, bool delete = false) + bool autoMapOutput, bool keepIdentity, bool insertIfNotExists, bool update = false, bool delete = false) { var rowsInserted = new Dictionary(); var rowsUpdated = new Dictionary(); @@ -76,14 +76,19 @@ internal BulkMergeResult ExecuteMerge(Dictionary entityMap, Expressi rowsAffected[entityType] = 0; var columnsToInsert = TableMapping.GetColumnNames(entityType).Intersect(GetColumnNames(entityType)); + if(keepIdentity) + { + columnsToInsert = columnsToInsert.Union(TableMapping.GetPrimaryKeyColumns()); + } var columnsToUpdate = update ? TableMapping.GetColumnNames(entityType).Intersect(GetColumnNames(entityType)) : new string[] { }; var autoGeneratedColumns = autoMapOutput ? TableMapping.GetAutoGeneratedColumns(entityType) : new string[] { }; - var columnsToOutput = GetMergeOutputColumns(autoGeneratedColumns, delete); + var columnsToOutput = autoMapOutput ? GetMergeOutputColumns(autoGeneratedColumns, delete) : new string[] { }; var deleteEntityType = TableMapping.EntityType == entityType & delete ? delete : false; string mergeOnConditionSql = insertIfNotExists ? CommonUtil.GetJoinConditionSql(mergeOnCondition, PrimaryKeyColumnNames, "t", "s") : "1=2"; + bool toggleIdentity = keepIdentity && TableMapping.HasIdentityColumn; var mergeStatement = SqlStatement.CreateMerge(StagingTableName, entityType.GetSchemaQualifiedTableName(), - mergeOnConditionSql, columnsToInsert, columnsToUpdate, columnsToOutput, deleteEntityType); + mergeOnConditionSql, columnsToInsert, columnsToUpdate, columnsToOutput, deleteEntityType, toggleIdentity); if (autoMapOutput) { diff --git a/N.EntityFrameworkCore.Extensions/Data/DbContextExtensions.cs b/N.EntityFrameworkCore.Extensions/Data/DbContextExtensions.cs index 6c18578..dd45878 100644 --- a/N.EntityFrameworkCore.Extensions/Data/DbContextExtensions.cs +++ b/N.EntityFrameworkCore.Extensions/Data/DbContextExtensions.cs @@ -253,7 +253,7 @@ public static int BulkInsert(this DbContext context, IEnumerable entities, { var bulkInsertResult = bulkOperation.BulkInsertStagingData(entities, options.KeepIdentity, true); var bulkMergeResult = bulkOperation.ExecuteMerge(bulkInsertResult.EntityMap, options.InsertOnCondition, - options.AutoMapOutput, options.InsertIfNotExists); + options.AutoMapOutput, options.KeepIdentity, options.InsertIfNotExists); rowsAffected = bulkMergeResult.RowsAffected; bulkOperation.DbTransactionContext.Commit(); } @@ -414,7 +414,7 @@ private static BulkMergeResult InternalBulkMerge(this DbContext context, I bulkOperation.ValidateBulkMerge(options.MergeOnCondition); var bulkInsertResult = bulkOperation.BulkInsertStagingData(entities, true, true); bulkMergeResult = bulkOperation.ExecuteMerge(bulkInsertResult.EntityMap, options.MergeOnCondition, options.AutoMapOutput, - true, true, options.DeleteIfNotMatched); + false, true, true, options.DeleteIfNotMatched); bulkOperation.DbTransactionContext.Commit(); } catch (Exception) diff --git a/N.EntityFrameworkCore.Extensions/N.EntityFrameworkCore.Extensions.csproj b/N.EntityFrameworkCore.Extensions/N.EntityFrameworkCore.Extensions.csproj index 16c047a..9ec5a89 100644 --- a/N.EntityFrameworkCore.Extensions/N.EntityFrameworkCore.Extensions.csproj +++ b/N.EntityFrameworkCore.Extensions/N.EntityFrameworkCore.Extensions.csproj @@ -2,7 +2,7 @@ net8.0 - 8.0.0.5 + 8.0.0.6 true https://github.com/NorthernLight1/N.EntityFrameworkCore.Extensions/ Northern25 diff --git a/N.EntityFrameworkCore.Extensions/Sql/SqlBuilder.cs b/N.EntityFrameworkCore.Extensions/Sql/SqlBuilder.cs index 21574f8..2b0789c 100644 --- a/N.EntityFrameworkCore.Extensions/Sql/SqlBuilder.cs +++ b/N.EntityFrameworkCore.Extensions/Sql/SqlBuilder.cs @@ -125,7 +125,9 @@ public void ChangeToDelete() if(sqlClause != null) { sqlClause.Name = "DELETE"; - sqlClause.InputText = sqlFromClause.InputText.Substring(sqlFromClause.InputText.LastIndexOf("AS ") + 3); + int aliasStartIndex = sqlFromClause.InputText.IndexOf("AS ") + 3; + int aliasLength = sqlFromClause.InputText.IndexOf("]", aliasStartIndex) - aliasStartIndex + 1; + sqlClause.InputText = sqlFromClause.InputText.Substring(aliasStartIndex, aliasLength); } } public void ChangeToUpdate(string updateExpression, string setExpression) diff --git a/N.EntityFrameworkCore.Extensions/Sql/SqlKeyword.cs b/N.EntityFrameworkCore.Extensions/Sql/SqlKeyword.cs index d870a02..bf461bb 100644 --- a/N.EntityFrameworkCore.Extensions/Sql/SqlKeyword.cs +++ b/N.EntityFrameworkCore.Extensions/Sql/SqlKeyword.cs @@ -28,6 +28,9 @@ public enum SqlKeyword As, By, Source, - Target + Target, + Off, + Identity_Insert, + Semicolon, } } diff --git a/N.EntityFrameworkCore.Extensions/Sql/SqlStatement.cs b/N.EntityFrameworkCore.Extensions/Sql/SqlStatement.cs index a1497e7..53d8910 100644 --- a/N.EntityFrameworkCore.Extensions/Sql/SqlStatement.cs +++ b/N.EntityFrameworkCore.Extensions/Sql/SqlStatement.cs @@ -35,15 +35,28 @@ internal void CreatePart(SqlKeyword keyword, SqlExpression expression = null) { SqlParts.Add(new SqlPart(keyword, expression)); } + internal void SetIdentityInsert(string tableName, bool enable) + { + this.CreatePart(SqlKeyword.Set); + this.CreatePart(SqlKeyword.Identity_Insert, SqlExpression.Table(tableName)); + if (enable) + this.CreatePart(SqlKeyword.On); + else + this.CreatePart(SqlKeyword.Off); + this.CreatePart(SqlKeyword.Semicolon); + } //internal static SqlStatement CreateMergeInsert(string sourceTableName, string targetTableName, string mergeOnCondition, // IEnumerable insertColumns, IEnumerable outputColumns, bool deleteIfNotMatched = false) //{ //} internal static SqlStatement CreateMerge(string sourceTableName, string targetTableName, string joinOnCondition, - IEnumerable insertColumns, IEnumerable updateColumns, IEnumerable outputColumns, bool deleteIfNotMatched=false) + IEnumerable insertColumns, IEnumerable updateColumns, IEnumerable outputColumns, + bool deleteIfNotMatched=false, bool hasIdentityColumn=false) { var statement = new SqlStatement(); + if (hasIdentityColumn) + statement.SetIdentityInsert(targetTableName, true); statement.CreatePart(SqlKeyword.Merge, SqlExpression.Table(targetTableName, "t")); statement.CreatePart(SqlKeyword.Using, SqlExpression.Table(sourceTableName, "s")); statement.CreatePart(SqlKeyword.On, SqlExpression.String(joinOnCondition)); @@ -71,7 +84,12 @@ internal static SqlStatement CreateMerge(string sourceTableName, string targetTa statement.CreatePart(SqlKeyword.Then); statement.CreatePart(SqlKeyword.Delete); } - statement.CreatePart(SqlKeyword.Output, SqlExpression.Columns(outputColumns)); + if(outputColumns.Any()) + statement.CreatePart(SqlKeyword.Output, SqlExpression.Columns(outputColumns)); + statement.CreatePart(SqlKeyword.Semicolon); + + if (hasIdentityColumn) + statement.SetIdentityInsert(targetTableName, false); return statement; } @@ -80,9 +98,23 @@ private string ToSql() StringBuilder sbSql = new StringBuilder(); foreach(var part in SqlParts) { - if (!part.IgnoreOutput) + if (part.Keyword == SqlKeyword.Semicolon) { - sbSql.Append(part.Keyword.ToString().ToUpper() + " "); + int lastIndex = sbSql.Length - 1; + if (lastIndex > -1 && sbSql[lastIndex] == ' ') + { + sbSql[lastIndex] = ';'; + sbSql.Append("\n"); + } + else + { + sbSql.Append(";\n"); + } + } + else if (!part.IgnoreOutput) + { + sbSql.Append(part.Keyword.ToString().ToUpper()); + sbSql.Append(" "); bool useParenthese = part.Keyword == SqlKeyword.Insert || part.Keyword == SqlKeyword.Values; string format = useParenthese ? "({0})" : "{0}"; @@ -94,10 +126,10 @@ private string ToSql() } } //Output a semicolon for certain SQL Statments - if(SqlParts.First().Keyword == SqlKeyword.Merge) - { - sbSql.Append(";"); - } + //if(SqlParts.First().Keyword == SqlKeyword.Merge) + //{ + // sbSql.Append(";"); + //} return sbSql.ToString(); } }