diff --git a/Achilles.Entities.Sqlite/Linq/EntityCollectionOfT.cs b/Achilles.Entities.Sqlite/Linq/EntityCollectionOfT.cs index 886c644..8cf1182 100644 --- a/Achilles.Entities.Sqlite/Linq/EntityCollectionOfT.cs +++ b/Achilles.Entities.Sqlite/Linq/EntityCollectionOfT.cs @@ -14,19 +14,24 @@ using System.Collections; using System.Collections.Generic; using System.ComponentModel; +using System.Linq; +using System.Linq.Expressions; #endregion namespace Achilles.Entities.Linq { - public sealed class EntityCollection : IEntityCollection, IEntityCollection, ICollection, IListSource + public sealed class EntityCollection : IEntityCollection, IEntitySource, IEntityCollection, ICollection, IListSource where TEntity : class { #region Private Fields - private IEnumerable _source; - private HashSet _entities; - + private EntitySet _source; + private List _entities; + + private object _foreignKeyValue; + private string _referenceKey; + private bool _isLoaded; #endregion @@ -35,59 +40,79 @@ public sealed class EntityCollection : IEntityCollection, IEnt public EntityCollection() { - var test = 6; } #endregion - #region IEntityCollection Implementation + #region Internal IEntitySource Implementation + + bool IEntitySource.HasSource => (_source != null); + + void IEntitySource.SetSource( IEntitySet source, string referenceKey, object foreignKeyValue ) + { + _source = source as EntitySet; + _referenceKey = referenceKey; + _foreignKeyValue = foreignKeyValue; + } + + #endregion + + #region Private Properties - void IEntityCollection.AttachSource( IEnumerable source ) + private List Entities { - _source = source; + get + { + if ( !_isLoaded ) + { + Load(); + } + + return _entities; + } } - + #endregion #region ICollection Implementation - public int Count => ((ICollection)_entities).Count; + public int Count => ((ICollection)Entities).Count; - public bool IsReadOnly => ((ICollection)_entities).IsReadOnly; + public bool IsReadOnly => ((ICollection)Entities).IsReadOnly; public void Add( TEntity item ) { - ((ICollection)_entities).Add( item ); + ((ICollection)Entities).Add( item ); } public void Clear() { - ((ICollection)_entities).Clear(); + ((ICollection)Entities).Clear(); } public bool Contains( TEntity item ) { - return ((ICollection)_entities).Contains( item ); + return ((ICollection)Entities).Contains( item ); } public void CopyTo( TEntity[] array, int arrayIndex ) { - ((ICollection)_entities).CopyTo( array, arrayIndex ); + ((ICollection)Entities).CopyTo( array, arrayIndex ); } public IEnumerator GetEnumerator() { - return ((ICollection)_entities).GetEnumerator(); + return ((ICollection)Entities).GetEnumerator(); } public bool Remove( TEntity item ) { - return ((ICollection)_entities).Remove( item ); + return ((ICollection)Entities).Remove( item ); } IEnumerator IEnumerable.GetEnumerator() { - return ((ICollection)_entities).GetEnumerator(); + return ((ICollection)Entities).GetEnumerator(); } #endregion @@ -96,12 +121,35 @@ IEnumerator IEnumerable.GetEnumerator() bool IListSource.ContainsListCollection => throw new NotImplementedException(); + public bool IsLoaded => _isLoaded; + IList IListSource.GetList() { throw new NotImplementedException(); } - + #endregion + + #region Private Implementation + + private void Load() + { + if ( !_isLoaded && _source != null ) + { + _entities = _source.Where( FilterByForeignKeyPredicate( _foreignKeyValue ) ).ToList(); + _isLoaded = true; + } + } + + private Expression> FilterByForeignKeyPredicate( object foreignKey ) + { + var entity = Expression.Parameter( typeof( TEntity ), "e" ); + var referenceKey = Expression.Property( entity, _referenceKey ); + var value = Expression.Constant( foreignKey, referenceKey.Type ); + var body = Expression.Equal( referenceKey, value ); + + return Expression.Lambda>( body, entity ); + } #endregion } diff --git a/Achilles.Entities.Sqlite/Linq/EntityReferenceOfT.cs b/Achilles.Entities.Sqlite/Linq/EntityReferenceOfT.cs index fff7478..b49cfcc 100644 --- a/Achilles.Entities.Sqlite/Linq/EntityReferenceOfT.cs +++ b/Achilles.Entities.Sqlite/Linq/EntityReferenceOfT.cs @@ -19,7 +19,7 @@ namespace Achilles.Entities.Linq { - public sealed class EntityReference : IEntityReference, IEntityReference, IEntityReferenceSource + public sealed class EntityReference : IEntityReference, IEntityReference, IEntitySource where TEntity : class { #region Private Fields @@ -66,9 +66,9 @@ public TEntity Value #region Internal IEntityReferenceSource API - bool IEntityReferenceSource.HasSource => (_source != null); + bool IEntitySource.HasSource => (_source != null); - void IEntityReferenceSource.SetSource( IEntitySet source, string referenceKey, object foreignKeyValue ) + void IEntitySource.SetSource( IEntitySet source, string referenceKey, object foreignKeyValue ) { _source = source as EntitySet; _referenceKey = referenceKey; diff --git a/Achilles.Entities.Sqlite/Linq/IEntityCollectionOfT.cs b/Achilles.Entities.Sqlite/Linq/IEntityCollectionOfT.cs index 5d67251..9916695 100644 --- a/Achilles.Entities.Sqlite/Linq/IEntityCollectionOfT.cs +++ b/Achilles.Entities.Sqlite/Linq/IEntityCollectionOfT.cs @@ -16,8 +16,8 @@ namespace Achilles.Entities.Linq { - internal interface IEntityCollection + public interface IEntityCollection { - void AttachSource( IEnumerable source ); + bool IsLoaded { get; } } } diff --git a/Achilles.Entities.Sqlite/Linq/IEntityReferenceSource.cs b/Achilles.Entities.Sqlite/Linq/IEntitySource.cs similarity index 92% rename from Achilles.Entities.Sqlite/Linq/IEntityReferenceSource.cs rename to Achilles.Entities.Sqlite/Linq/IEntitySource.cs index 6216cef..b6214d8 100644 --- a/Achilles.Entities.Sqlite/Linq/IEntityReferenceSource.cs +++ b/Achilles.Entities.Sqlite/Linq/IEntitySource.cs @@ -16,7 +16,7 @@ namespace Achilles.Entities.Linq { - internal interface IEntityReferenceSource + internal interface IEntitySource { bool HasSource { get; } diff --git a/Achilles.Entities.Sqlite/Modelling/Mapping/Accessors/EntityCollectionAccessor.cs b/Achilles.Entities.Sqlite/Modelling/Mapping/Accessors/EntityCollectionAccessor.cs index 2e613ab..9a2bb34 100644 --- a/Achilles.Entities.Sqlite/Modelling/Mapping/Accessors/EntityCollectionAccessor.cs +++ b/Achilles.Entities.Sqlite/Modelling/Mapping/Accessors/EntityCollectionAccessor.cs @@ -8,23 +8,60 @@ #endregion +#region Namespaces + +using Achilles.Entities.Linq; using System; -using System.Collections.Generic; using System.Reflection; -using System.Text; + +#endregion namespace Achilles.Entities.Modelling.Mapping.Accessors { internal class EntityCollectionAccessor : MemberAccessor + where TEntity : class { MemberInfo _entityCollectionInfo; - IEntityMapping _entityMapping; + EntityMapping _entityMapping; - public EntityCollectionAccessor( IEntityMapping entityMapping, MemberInfo entityReferenceInfo ) - : base( entityReferenceInfo ) + public EntityCollectionAccessor( EntityMapping entityMapping, MemberInfo entityCollectionInfo ) + : base( entityCollectionInfo ) { _entityMapping = entityMapping; - _entityCollectionInfo = entityReferenceInfo; + _entityCollectionInfo = entityCollectionInfo; + } + + public Type EntityType => typeof( TEntity ); + + public override object GetValue( TEntity entity ) + { + return base.GetValue( entity ); + } + + public override void SetValue( TEntity entity, object value ) + { + // The base.GetValue gets the entityReference class + var entityCollection = base.GetValue( entity ) as IEntitySource; + + // The The foreign key mapping comes from the value passed to this method + IForeignKeyMapping foreignKeyMapping = (IForeignKeyMapping)value; + + var entityCollectionType = foreignKeyMapping.ForeignKeyProperty.DeclaringType; + + // TJT: Clean the EntitySet access up! + var entitySetSource = _entityMapping.Model.DataContext.EntitySets[ entityCollectionType ]; + + var keyName = foreignKeyMapping.ForeignKeyProperty.Name; + var keyValue = _entityMapping.GetColumn( entity, foreignKeyMapping.ReferenceKeyProperty.Name ); + + if ( !entityCollection.HasSource ) + { + entityCollection.SetSource( entitySetSource, keyName, keyValue ); + } + else + { + throw new InvalidOperationException( "Entity reference source already set." ); + } } } } diff --git a/Achilles.Entities.Sqlite/Modelling/Mapping/Accessors/EntityReferenceAccessor.cs b/Achilles.Entities.Sqlite/Modelling/Mapping/Accessors/EntityReferenceAccessor.cs index d93ee67..0130e5c 100644 --- a/Achilles.Entities.Sqlite/Modelling/Mapping/Accessors/EntityReferenceAccessor.cs +++ b/Achilles.Entities.Sqlite/Modelling/Mapping/Accessors/EntityReferenceAccessor.cs @@ -45,7 +45,7 @@ public override object GetValue( TEntity entity ) public override void SetValue( TEntity entity, object value ) { // The base.GetValue gets the entityReference class - var entityReference = base.GetValue( entity ) as IEntityReferenceSource; + var entityReference = base.GetValue( entity ) as IEntitySource; // The The foreign key mapping comes from the value passed to this method IForeignKeyMapping foreignKeyMapping = (IForeignKeyMapping)value; diff --git a/Achilles.Entities.Sqlite/Modelling/Mapping/EntityMapping.cs b/Achilles.Entities.Sqlite/Modelling/Mapping/EntityMapping.cs index 26821dc..a11f408 100644 --- a/Achilles.Entities.Sqlite/Modelling/Mapping/EntityMapping.cs +++ b/Achilles.Entities.Sqlite/Modelling/Mapping/EntityMapping.cs @@ -78,6 +78,9 @@ public object GetForeignKey( T entity, string propertyName ) where T : class public void SetEntityReference( T Entity, string propertyName, object source ) where T : class => EntityReferenceAccessors[ propertyName ].SetValue( Entity as TEntity, source ); + public void SetEntityCollection( T Entity, string propertyName, object source ) where T : class + => EntityCollectionAccessors[ propertyName ].SetValue( Entity as TEntity, source ); + public List ColumnMappings { get; set; } = new List(); public List IndexMappings { get; set; } = new List(); diff --git a/Achilles.Entities.Sqlite/Modelling/Mapping/IEntityMapping.cs b/Achilles.Entities.Sqlite/Modelling/Mapping/IEntityMapping.cs index 95a7b7c..f1e4cf3 100644 --- a/Achilles.Entities.Sqlite/Modelling/Mapping/IEntityMapping.cs +++ b/Achilles.Entities.Sqlite/Modelling/Mapping/IEntityMapping.cs @@ -44,6 +44,8 @@ public interface IEntityMapping void SetEntityReference( T entity, string propertyName, object value ) where T : class; + void SetEntityCollection( T entity, string propertyName, object value ) where T : class; + object GetForeignKey( T entity, string propertyName ) where T : class; void Compile(); diff --git a/Achilles.Entities.Sqlite/Querying/EntityMaterializer.cs b/Achilles.Entities.Sqlite/Querying/EntityMaterializer.cs index 9b7479a..282d310 100644 --- a/Achilles.Entities.Sqlite/Querying/EntityMaterializer.cs +++ b/Achilles.Entities.Sqlite/Querying/EntityMaterializer.cs @@ -259,10 +259,10 @@ private void SetDeferredLoading ( object entity ) if ( relationshipMapping.IsMany ) { - //entityMapping.SetEntityCollection( - // entity, - // relationshipMapping.RelationshipProperty.Name, - // relationshipMapping.ForeignKeyMapping ); + entityMapping.SetEntityCollection( + entity, + relationshipMapping.RelationshipProperty.Name, + relationshipMapping.ForeignKeyMapping ); } else { diff --git a/Entities.Sqlite.Tests/DbContext/Database/DatabaseQueryTest.cs b/Entities.Sqlite.Tests/DbContext/Database/DatabaseQueryTest.cs index 7f62e55..b4c5c3e 100644 --- a/Entities.Sqlite.Tests/DbContext/Database/DatabaseQueryTest.cs +++ b/Entities.Sqlite.Tests/DbContext/Database/DatabaseQueryTest.cs @@ -92,6 +92,39 @@ public void Database_ComplexQuery_CanReadListWithLazyEntityReference() } } + [Fact] + public void Database_ComplexQuery_CanReadListWithLazyEntityCollection() + { + const string connectionString = "Data Source=:memory:"; + var options = new DataContextOptionsBuilder().UseSqlite( connectionString ).Options; + + using ( var context = new TestDataContext( options ) ) + { + InitializeContext( context ); + + Assert.Equal( 2, context.Products.Count() ); + + var query = from p in context.Products + where p.Name == "Banana" + select p; + + var products = query.ToList(); + + Assert.False( products[ 0 ].Parts.IsLoaded ); + + var count = products[ 0 ].Parts.Count; + Assert.Equal( 3, count ); + + var parts = products[ 0 ].Parts.ToList(); + + Assert.Equal( "Bolt", parts[ 0 ].Name ); + Assert.Equal( "Wrench", parts[ 1 ].Name ); + Assert.Equal( "Hammer", parts[ 2 ].Name ); + + Assert.True( products[ 0 ].Parts.IsLoaded ); + } + } + [Fact] public void Database_ComplexQuery_CanReadList() {