diff --git a/src/SQLite.cs b/src/SQLite.cs index f255f5e4..cdaf384c 100644 --- a/src/SQLite.cs +++ b/src/SQLite.cs @@ -364,10 +364,17 @@ public int CreateTable(Type ty, CreateFlags createFlags = CreateFlags.None) _tables.Add (ty.FullName, map); } var query = "create table if not exists \"" + map.TableName + "\"(\n"; - - var decls = map.Columns.Select (p => Orm.SqlDecl (p, StoreDateTimeAsTicks)); + + var pkCols = map.Columns.Where(p => p.IsPK); + int numPkCols = pkCols.Count(); + + var decls = map.Columns.Select(p => Orm.SqlDecl(p, StoreDateTimeAsTicks, numPkCols == 1)); var decl = string.Join (",\n", decls.ToArray ()); query += decl; + + if (numPkCols > 1) + query += string.Format(",\nprimary key ({0})\n", string.Join(", ", pkCols.Select(p => "\"" + p.Name + "\""))); + query += ")"; var count = Execute (query); @@ -1234,35 +1241,39 @@ public int Insert (object obj, string extra, Type objType) var map = GetMapping (objType); #if NETFX_CORE - if (map.PK != null && map.PK.IsAutoGuid) - { - // no GetProperty so search our way up the inheritance chain till we find it - PropertyInfo prop; - while (objType != null) - { - var info = objType.GetTypeInfo(); - prop = info.GetDeclaredProperty(map.PK.PropertyName); - if (prop != null) - { - if (prop.GetValue(obj, null).Equals(Guid.Empty)) - { - prop.SetValue(obj, Guid.NewGuid(), null); - } - break; - } - - objType = info.BaseType; - } - } + foreach (var pk in map.PKs) + { + if (pk.IsAutoGuid) + { + // no GetProperty so search our way up the inheritance chain till we find it + PropertyInfo prop; + while (objType != null) + { + var info = objType.GetTypeInfo(); + prop = info.GetDeclaredProperty(pk.PropertyName); + if (prop != null) + { + if (prop.GetValue(obj, null).Equals(Guid.Empty)) + { + prop.SetValue(obj, Guid.NewGuid(), null); + } + break; + } + objType = info.BaseType; + } + } + } #else - if (map.PK != null && map.PK.IsAutoGuid) { - var prop = objType.GetProperty(map.PK.PropertyName); - if (prop != null) { - if (prop.GetValue(obj, null).Equals(Guid.Empty)) { - prop.SetValue(obj, Guid.NewGuid(), null); - } - } - } + foreach (var pk in map.PKs) { + if (pk.IsAutoGuid) { + var prop = objType.GetProperty(pk.PropertyName); + if (prop != null) { + if (prop.GetValue(obj, null).Equals(Guid.Empty)) { + prop.SetValue(obj, Guid.NewGuid(), null); + } + } + } + } #endif @@ -1342,21 +1353,22 @@ public int Update (object obj, Type objType) var map = GetMapping (objType); - var pk = map.PK; - - if (pk == null) { + var pks = map.PKs; + if (pks.Count == 0) { throw new NotSupportedException ("Cannot update " + map.TableName + ": it has no PK"); } var cols = from p in map.Columns - where p != pk + where !p.IsPK select p; var vals = from c in cols select c.GetValue (obj); - var ps = new List (vals); - ps.Add (pk.GetValue (obj)); - var q = string.Format ("update \"{0}\" set {1} where {2} = ? ", map.TableName, string.Join (",", (from c in cols - select "\"" + c.Name + "\" = ? ").ToArray ()), pk.Name); + + var q = string.Format ("update \"{0}\" set {1} {2} ", map.TableName, string.Join (",", (from c in cols + select "\"" + c.Name + "\" = ? ").ToArray ()), map.GetPrimaryKeyClause()); + + var ps = new List(vals); + ps.AddRange(pks.Select(pk => pk.GetValue(obj))); try { rowsAffected = Execute (q, ps.ToArray ()); @@ -1408,12 +1420,12 @@ public int UpdateAll (System.Collections.IEnumerable objects) public int Delete (object objectToDelete) { var map = GetMapping (objectToDelete.GetType ()); - var pk = map.PK; - if (pk == null) { + var pks = map.PKs; + if (pks.Count == 0) { throw new NotSupportedException ("Cannot delete " + map.TableName + ": it has no PK"); } - var q = string.Format ("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); - var count = Execute (q, pk.GetValue (objectToDelete)); + var q = string.Format ("delete from \"{0}\" {1}", map.TableName, map.GetPrimaryKeyClause()); + var count = Execute (q, pks.Select(pk => pk.GetValue (objectToDelete)).ToArray()); if (count > 0) OnTableChanged (map, NotifyTableChangedAction.Delete); return count; @@ -1434,11 +1446,14 @@ public int Delete (object objectToDelete) public int Delete (object primaryKey) { var map = GetMapping (typeof (T)); - var pk = map.PK; - if (pk == null) { + var pks = map.PKs; + if (pks.Count == 0) { throw new NotSupportedException ("Cannot delete " + map.TableName + ": it has no PK"); } - var q = string.Format ("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); + if (pks.Count > 1) { + throw new NotSupportedException("Cannot delete " + map.TableName + ": it has > 1 PK"); + } + var q = string.Format ("delete from \"{0}\" {1}", map.TableName, map.GetPrimaryKeyClause()); var count = Execute (q, primaryKey); if (count > 0) OnTableChanged (map, NotifyTableChangedAction.Delete); @@ -1658,7 +1673,7 @@ public class TableMapping public Column[] Columns { get; private set; } - public Column PK { get; private set; } + public List PKs { get; private set; } public string GetByPrimaryKeySql { get; private set; } @@ -1698,19 +1713,20 @@ public TableMapping(Type type, CreateFlags createFlags = CreateFlags.None) } } Columns = cols.ToArray (); + PKs = new List(); foreach (var c in Columns) { if (c.IsAutoInc && c.IsPK) { _autoPk = c; } if (c.IsPK) { - PK = c; + PKs.Add(c); } } HasAutoIncPK = _autoPk != null; - if (PK != null) { - GetByPrimaryKeySql = string.Format ("select * from \"{0}\" where \"{1}\" = ?", TableName, PK.Name); + if (PKs.Count > 0) { + GetByPrimaryKeySql = string.Format("select * from \"{0}\" {1}", TableName, GetPrimaryKeyClause()); ; } else { // People should not be calling Get/Find without a PK @@ -1718,6 +1734,22 @@ public TableMapping(Type type, CreateFlags createFlags = CreateFlags.None) } } + public string GetPrimaryKeyClause() + { + string clause = String.Empty; + bool first = true; + foreach (Column pk in PKs) { + if (first) { + clause += "where "; + first = false; + } else { + clause += " and "; + } + clause += string.Format("\"{0}\" = ?", pk.Name); + } + return clause; + } + public bool HasAutoIncPK { get; private set; } public void SetAutoIncPK (object obj, long id) @@ -1883,11 +1915,11 @@ public static class Orm public const string ImplicitPkName = "Id"; public const string ImplicitIndexSuffix = "Id"; - public static string SqlDecl (TableMapping.Column p, bool storeDateTimeAsTicks) + public static string SqlDecl (TableMapping.Column p, bool storeDateTimeAsTicks, bool trySetAsPrimaryKey = true) { string decl = "\"" + p.Name + "\" " + SqlType (p, storeDateTimeAsTicks) + " "; - if (p.IsPK) { + if (trySetAsPrimaryKey && p.IsPK) { decl += "primary key "; } if (p.IsAutoInc) { diff --git a/src/SQLiteAsync.cs b/src/SQLiteAsync.cs index 79b91cba..e5aa96b3 100644 --- a/src/SQLiteAsync.cs +++ b/src/SQLiteAsync.cs @@ -410,8 +410,10 @@ public Entry (SQLiteConnectionString connectionString, SQLiteOpenFlags openFlags public void OnApplicationSuspended () { - Connection.Dispose (); - Connection = null; + if (Connection != null) { + Connection.Dispose (); + Connection = null; + } } } diff --git a/tests/CreateTableImplicitTest.cs b/tests/CreateTableImplicitTest.cs index 416b42cf..41fd3d74 100644 --- a/tests/CreateTableImplicitTest.cs +++ b/tests/CreateTableImplicitTest.cs @@ -54,7 +54,7 @@ public void WithoutImplicitMapping () var mapping = db.GetMapping(); - Assert.IsNull (mapping.PK); + Assert.AreEqual(mapping.PKs.Count, 0); var column = mapping.Columns[2]; Assert.AreEqual("IndexedId", column.Name); @@ -72,10 +72,11 @@ public void ImplicitPK() var mapping = db.GetMapping(); - Assert.IsNotNull(mapping.PK); - Assert.AreEqual("Id", mapping.PK.Name); - Assert.IsTrue(mapping.PK.IsPK); - Assert.IsFalse(mapping.PK.IsAutoInc); + Assert.AreNotEqual(mapping.PKs.Count, 0); + var pk = mapping.PKs.First(); + Assert.AreEqual("Id", pk.Name); + Assert.IsTrue(pk.IsPK); + Assert.IsFalse(pk.IsAutoInc); CheckPK(db); } @@ -90,10 +91,11 @@ public void ImplicitAutoInc() var mapping = db.GetMapping(); - Assert.IsNotNull(mapping.PK); - Assert.AreEqual("Id", mapping.PK.Name); - Assert.IsTrue(mapping.PK.IsPK); - Assert.IsTrue(mapping.PK.IsAutoInc); + Assert.AreNotEqual(mapping.PKs.Count, 0); + var pk = mapping.PKs.First(); + Assert.AreEqual("Id", pk.Name); + Assert.IsTrue(pk.IsPK); + Assert.IsTrue(pk.IsAutoInc); } [Test] @@ -118,10 +120,11 @@ public void ImplicitPKAutoInc() var mapping = db.GetMapping(); - Assert.IsNotNull(mapping.PK); - Assert.AreEqual("Id", mapping.PK.Name); - Assert.IsTrue(mapping.PK.IsPK); - Assert.IsTrue(mapping.PK.IsAutoInc); + Assert.AreNotEqual(mapping.PKs.Count, 0); + var pk = mapping.PKs.First(); + Assert.AreEqual("Id", pk.Name); + Assert.IsTrue(pk.IsPK); + Assert.IsTrue(pk.IsAutoInc); } [Test] @@ -133,10 +136,11 @@ public void ImplicitAutoIncAsPassedInTypes() var mapping = db.GetMapping(); - Assert.IsNotNull(mapping.PK); - Assert.AreEqual("Id", mapping.PK.Name); - Assert.IsTrue(mapping.PK.IsPK); - Assert.IsTrue(mapping.PK.IsAutoInc); + Assert.AreNotEqual(mapping.PKs.Count, 0); + var pk = mapping.PKs.First(); + Assert.AreEqual("Id", pk.Name); + Assert.IsTrue(pk.IsPK); + Assert.IsTrue(pk.IsAutoInc); } [Test] @@ -148,10 +152,11 @@ public void ImplicitPkAsPassedInTypes() var mapping = db.GetMapping(); - Assert.IsNotNull(mapping.PK); - Assert.AreEqual("Id", mapping.PK.Name); - Assert.IsTrue(mapping.PK.IsPK); - Assert.IsFalse(mapping.PK.IsAutoInc); + Assert.AreNotEqual(mapping.PKs.Count, 0); + var pk = mapping.PKs.First(); + Assert.AreEqual("Id", pk.Name); + Assert.IsTrue(pk.IsPK); + Assert.IsFalse(pk.IsAutoInc); } [Test] @@ -163,10 +168,11 @@ public void ImplicitPKAutoIncAsPassedInTypes() var mapping = db.GetMapping(); - Assert.IsNotNull(mapping.PK); - Assert.AreEqual("Id", mapping.PK.Name); - Assert.IsTrue(mapping.PK.IsPK); - Assert.IsTrue(mapping.PK.IsAutoInc); + Assert.AreNotEqual(mapping.PKs.Count, 0); + var pk = mapping.PKs.First(); + Assert.AreEqual("Id", pk.Name); + Assert.IsTrue(pk.IsPK); + Assert.IsTrue(pk.IsAutoInc); } } } diff --git a/tests/InheritanceTest.cs b/tests/InheritanceTest.cs index 1705d77e..a520125a 100644 --- a/tests/InheritanceTest.cs +++ b/tests/InheritanceTest.cs @@ -38,7 +38,7 @@ public void InheritanceWorks () var mapping = db.GetMapping (); Assert.AreEqual (3, mapping.Columns.Length); - Assert.AreEqual ("Id", mapping.PK.Name); + Assert.AreEqual ("Id", mapping.PKs.First().Name); } } } diff --git a/tests/MultiplePrimaryKeyTest.cs b/tests/MultiplePrimaryKeyTest.cs new file mode 100644 index 00000000..fd0f17da --- /dev/null +++ b/tests/MultiplePrimaryKeyTest.cs @@ -0,0 +1,87 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +#if NETFX_CORE +using Microsoft.VisualStudio.TestPlatform.UnitTestFramework; +using SetUp = Microsoft.VisualStudio.TestPlatform.UnitTestFramework.TestInitializeAttribute; +using TestFixture = Microsoft.VisualStudio.TestPlatform.UnitTestFramework.TestClassAttribute; +using Test = Microsoft.VisualStudio.TestPlatform.UnitTestFramework.TestMethodAttribute; +#else +using NUnit.Framework; +#endif + + +namespace SQLite.Tests +{ + [Table("MultiplePrimaryKeysTable")] + public class MultiplePrimaryKeys + { + [PrimaryKey, Indexed] + public int Id { get; set; } + + [PrimaryKey, NotNull, Indexed] + public string Id2 { get; set; } + + [NotNull] + public string Value1 { get; set; } + + public bool Value2 { get; set; } + } + + [TestFixture] + public class MultiplePrimaryKeyTest + { + [Test] + public void CreateTableWithMultiplePrimaryKeys() + { + var db = new TestDb(); + db.CreateTable(); + + MultiplePrimaryKeys[] insertRows = new MultiplePrimaryKeys[] + { + new MultiplePrimaryKeys() { Id = 1, Id2 = "Foo", Value1 = "One", Value2 = true }, + new MultiplePrimaryKeys() { Id = 2, Id2 = "Bar", Value1 = "Two", Value2 = true }, + new MultiplePrimaryKeys() { Id = 3, Id2 = "Baz", Value1 = "Three", Value2 = true }, + }; + + db.InsertAll(insertRows); + + var queryRows = (from r in db.Table() select r).ToArray(); + Assert.AreEqual(insertRows.Length, queryRows.Length); + + for (int i = 0; i < insertRows.Length; i ++) + { + Assert.AreEqual(insertRows[i].Id, queryRows[i].Id); + Assert.AreEqual(insertRows[i].Id2, queryRows[i].Id2); + } + + // + // Test Updates Async + // + + TestAsyncDb asyncDb = new TestAsyncDb(db.DatabasePath); + var query = asyncDb.Table(); + + Task> rowsTask = (from row in query where row.Value2 select row).ToListAsync(); + rowsTask.Wait(); + + List rows = rowsTask.Result; + Assert.AreEqual(rows.Count, 3); + Assert.AreEqual(rows[0].Value1, "One"); + Assert.AreEqual(rows[1].Value1, "Two"); + Assert.AreEqual(rows[2].Value1, "Three"); + + rows[0].Value1 = "Three"; + rows[0].Value2 = false; + rows[1].Value1 = "Four"; + rows[1].Value2 = false; + rows[2].Value1 = "Five"; + rows[2].Value2 = false; + + Task insertTask = asyncDb.UpdateAllAsync(rows); + insertTask.Wait(); + } + } +} diff --git a/tests/SQLite.Tests.csproj b/tests/SQLite.Tests.csproj index aa911663..ceebb131 100644 --- a/tests/SQLite.Tests.csproj +++ b/tests/SQLite.Tests.csproj @@ -47,6 +47,7 @@ + diff --git a/tests/TestDb.cs b/tests/TestDb.cs index 609c59f2..f7762912 100644 --- a/tests/TestDb.cs +++ b/tests/TestDb.cs @@ -38,10 +38,10 @@ public class OrderLine { [AutoIncrement, PrimaryKey] public int Id { get; set; } - [Indexed("IX_OrderProduct", 1)] + [Indexed("IX_OrderProduct", 1)] public int OrderId { get; set; } - [Indexed("IX_OrderProduct", 2)] - public int ProductId { get; set; } + [Indexed("IX_OrderProduct", 2)] + public int ProductId { get; set; } public int Quantity { get; set; } public decimal UnitPrice { get; set; } public OrderLineStatus Status { get; set; } @@ -59,6 +59,18 @@ public TestDb (bool storeDateTimeAsTicks = false) : base (TestPath.GetTempFileNa } } + public class TestAsyncDb : SQLiteAsyncConnection + { + public TestAsyncDb(bool storeDateTimeAsTicks = false) : base(TestPath.GetTempFileName(), storeDateTimeAsTicks) + { + } + + public TestAsyncDb(string databasePath, bool storeDateTimeAsTicks = false) + : base(databasePath, storeDateTimeAsTicks) + { + } + } + public class TestPath { public static string GetTempFileName ()