Update parent and child collections on generic repository with EF Core

醉话见心 2020-12-29 14:14

Say I have a Sale class:

public class Sale : BaseEntity //BaseEntity only has an Id  
{        
    public ICollection Items { get;          


        
3 answers
  • 2020-12-29 14:55

    Apparently the question is about applying the modifications of a disconnected entity (otherwise you wouldn't need to do anything other than call SaveChanges) that contains collection navigation properties, which need to reflect the added/removed/updated items from the passed object.

    EF Core does not provide such a capability out of the box. It supports a simple upsert (insert or update) through the Update method for entities with auto-generated keys, but it doesn't detect and delete the removed items.

    So you need to do that detection yourself. Loading the existing items is a step in the right direction. The problem with your code is that it doesn't account for the new items, and instead does some useless state manipulation of the existing items retrieved from the database.

    Following is a correct implementation of the same idea. It uses some EF Core internals (IClrCollectionAccessor, returned by the GetCollectionAccessor() method; both require using Microsoft.EntityFrameworkCore.Metadata.Internal;) to manipulate the collection. Your code already uses the internal GetPropertyAccess() method, so I guess that shouldn't be a problem; if something changes in a future EF Core version, the code should be updated accordingly. The collection accessor is needed because, while IEnumerable<BaseEntity> can be used to access the collections generically thanks to covariance, the same cannot be said about ICollection<BaseEntity>, because it's invariant, and we need a way to access the Add/Remove methods. The internal accessor provides that capability, as well as a way to generically retrieve the property value from the passed entity.

    Update: Starting from EF Core 3.0, GetCollectionAccessor and IClrCollectionAccessor are part of the public API.
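    To illustrate the variance point above using only the BCL: in this sketch List<string> stands in for a navigation collection and object stands in for BaseEntity (both are assumptions for the demo, not EF Core types). The list is visible through the covariant IEnumerable<object>, but not through the invariant ICollection<object>, which is why generic Add/Remove calls need something like the collection accessor.

```csharp
using System;
using System.Collections.Generic;

// List<string> stands in for a navigation collection such as List<SaleItem>,
// and object stands in for BaseEntity; variance only applies to reference types.
object items = new List<string> { "a", "b" };

// IEnumerable<out T> is covariant, so the list can be read as IEnumerable<object>.
bool readable = items is IEnumerable<object>;

// ICollection<T> is invariant (T is an input of Add/Remove), so the same list
// is not an ICollection<object> and cannot be modified through that interface.
bool writable = items is ICollection<object>;

Console.WriteLine($"IEnumerable<object>: {readable}"); // IEnumerable<object>: True
Console.WriteLine($"ICollection<object>: {writable}"); // ICollection<object>: False
```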

    Here is the code:

    public async Task<int> UpdateAsync<T>(T entity, params Expression<Func<T, object>>[] navigations) where T : BaseEntity
    {
        var dbEntity = await _dbContext.FindAsync<T>(entity.Id);
    
        var dbEntry = _dbContext.Entry(dbEntity);
        dbEntry.CurrentValues.SetValues(entity);
    
        foreach (var property in navigations)
        {
            var propertyName = property.GetPropertyAccess().Name;
            var dbItemsEntry = dbEntry.Collection(propertyName);
            var accessor = dbItemsEntry.Metadata.GetCollectionAccessor();
    
            await dbItemsEntry.LoadAsync();
            var dbItemsMap = ((IEnumerable<BaseEntity>)dbItemsEntry.CurrentValue)
                .ToDictionary(e => e.Id);
    
            var items = (IEnumerable<BaseEntity>)accessor.GetOrCreate(entity);
    
            foreach (var item in items)
            {
                if (!dbItemsMap.TryGetValue(item.Id, out var oldItem))
                    accessor.Add(dbEntity, item);
                else
                {
                    _dbContext.Entry(oldItem).CurrentValues.SetValues(item);
                    dbItemsMap.Remove(item.Id);
                }
            }
    
            foreach (var oldItem in dbItemsMap.Values)
                accessor.Remove(dbEntity, oldItem);
        }
    
        return await _dbContext.SaveChangesAsync();
    }
    

    The algorithm is pretty standard. After loading the collection from the database, we create a dictionary containing the existing items keyed by Id (for fast lookup). Then we do a single pass over the new items, using the dictionary to find the corresponding existing item. If no match is found, the item is considered new and is simply added to the target (tracked) collection. Otherwise, the found item is updated from the source and removed from the dictionary. This way, after the loop finishes, the dictionary contains the items that need to be deleted, so all we need to do is remove them from the target (tracked) collection.

    And that's all. The rest of the work will be done by the EF Core change tracker: the items added to the target collection will be marked as Added, the updated ones as either Unchanged or Modified, and the removed ones, depending on the cascade delete behavior, will be either marked for deletion or updated (disassociated from the parent). If you want to force deletion, simply replace

    accessor.Remove(dbEntity, oldItem);
    

    with

    _dbContext.Remove(oldItem);
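    The dictionary-based sync described in this answer can be tried out without EF Core. This is a minimal sketch, assuming items are matched by an Id-like key; SyncCollection and the (Id, Name) tuples are illustrative names, not EF Core APIs. Because tuples are values, the update step here only records which ids matched; with tracked entities that is where CurrentValues.SetValues(item) would run.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// 'target' stands in for the tracked collection loaded from the database,
// 'source' for the detached collection sent by the client.
static void SyncCollection<T, TKey>(
    ICollection<T> target, IEnumerable<T> source,
    Func<T, TKey> keyOf, Action<T, T> update) where TKey : notnull
{
    // Index the existing items by key for O(1) lookup.
    var existing = target.ToDictionary(keyOf);

    foreach (var item in source)
    {
        if (existing.Remove(keyOf(item), out var old))
            update(old, item);   // matched: copy values onto the existing item
        else
            target.Add(item);    // no match: new item
    }

    // Items still in the dictionary were not in the source: the client removed them.
    foreach (var old in existing.Values)
        target.Remove(old);
}

var dbItems = new List<(int Id, string Name)> { (1, "pen"), (2, "ink"), (3, "pad") };
var sent = new List<(int Id, string Name)> { (2, "ink v2"), (0, "new item") };
var updated = new List<int>();

SyncCollection(dbItems, sent, x => x.Id, (old, src) => updated.Add(old.Id));

Console.WriteLine(string.Join(", ", dbItems.Select(x => x.Id))); // 2, 0
Console.WriteLine(string.Join(", ", updated));                   // 2
```

    Dictionary.Remove(key, out value) does the lookup and the removal in one call, which is equivalent to the TryGetValue/Remove pair used in the answer.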
    
  • 2020-12-29 14:59

    The easiest approach would be to get all Deleted entities, cast them to BaseEntity, and check their IDs against the current IDs in the entity's relationship collection.

    Something along the lines of:

    foreach (var property in navigations)
    {
        var propertyName = property.GetPropertyAccess().Name;
        var collectionEntry = dbEntry.Collection(propertyName);

        await collectionEntry.LoadAsync();

        // The property type is the collection type (e.g. ICollection<SomeType>),
        // so take its generic argument to get the child entity type.
        // Assumes the navigation is declared as a generic collection.
        var childType = property.GetPropertyAccess().PropertyType.GetGenericArguments()[0];

        // Ids of tracked children of that type which are already marked as Deleted.
        var deletedIds = _dbContext.ChangeTracker
            .Entries()
            .Where(x => x.State == EntityState.Deleted && x.Entity.GetType() == childType)
            .Select(x => ((BaseEntity)x.Entity).Id)
            .ToArray();

        List<BaseEntity> dbChilds = collectionEntry.CurrentValue.Cast<BaseEntity>().ToList();

        foreach (BaseEntity child in dbChilds)
        {
            if (child.Id == 0)
            {
                // Not saved yet: insert it.
                _dbContext.Entry(child).State = EntityState.Added;
            }
            else if (deletedIds.Contains(child.Id))
            {
                _dbContext.Entry(child).State = EntityState.Deleted;
            }
            else
            {
                // Existing child: update it.
                _dbContext.Entry(child).State = EntityState.Modified;
            }
        }
    }
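    Another way to find the removed children, instead of relying on entries that are already marked as Deleted, is to compare the ids loaded from the database with the ids in the incoming collection (the first answer does this with its dictionary). A minimal sketch of that comparison, assuming int keys where 0 means "not saved yet"; FindRemovedIds is a hypothetical helper:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Children whose Id exists in the database but not in the incoming (detached)
// collection were removed by the client. FindRemovedIds is a hypothetical helper.
static int[] FindRemovedIds(IEnumerable<int> dbIds, IEnumerable<int> incomingIds)
{
    var incoming = new HashSet<int>(incomingIds); // O(1) membership checks
    return dbIds.Where(id => !incoming.Contains(id)).ToArray();
}

// Id 0 marks a new, not-yet-saved child (assumption: database-generated int keys),
// so it never matches a database row and is never reported as removed.
var removed = FindRemovedIds(dbIds: new[] { 1, 2, 3 }, incomingIds: new[] { 2, 0 });

Console.WriteLine(string.Join(", ", removed)); // 1, 3
```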
    
  • 2020-12-29 15:09

    @craigmoliver Here's my solution. It is not the best, I know - if you find a more elegant way, please share.

    Repository:

    public async Task<TEntity> UpdateAsync<TEntity, TId>(TEntity entity, bool save = true, params Expression<Func<TEntity, object>>[] navigations)
        where TEntity : class, IIdEntity<TId>
    {
        TEntity dbEntity = await _context.FindAsync<TEntity>(entity.Id);

        EntityEntry<TEntity> dbEntry = _context.Entry(dbEntity);
        dbEntry.CurrentValues.SetValues(entity);

        foreach (Expression<Func<TEntity, object>> property in navigations)
        {
            var propertyName = property.GetPropertyAccess().Name;
            CollectionEntry dbItemsEntry = dbEntry.Collection(propertyName);
            IClrCollectionAccessor accessor = dbItemsEntry.Metadata.GetCollectionAccessor();

            await dbItemsEntry.LoadAsync();

            // Key each existing child by its (possibly composite) primary key, e.g. "1|42".
            var dbItemsMap = ((IEnumerable<object>)dbItemsEntry.CurrentValue)
                .ToDictionary(e => string.Join('|', _context.FindPrimaryKeyValues(e)));

            foreach (var item in (IEnumerable<object>)accessor.GetOrCreate(entity))
            {
                var key = string.Join('|', _context.FindPrimaryKeyValues(item));

                if (!dbItemsMap.TryGetValue(key, out object oldItem))
                {
                    // No matching child in the database: add it.
                    accessor.Add(dbEntity, item);
                }
                else
                {
                    // Existing child: copy the new values and mark it as handled.
                    _context.Entry(oldItem).CurrentValues.SetValues(item);
                    dbItemsMap.Remove(key);
                }
            }

            // Whatever is left in the map was removed by the client.
            foreach (var oldItem in dbItemsMap.Values)
            {
                accessor.Remove(dbEntity, oldItem);
                await DeleteAsync(oldItem as IEntity, false);
            }
        }

        if (save)
        {
            await SaveChangesAsync();
        }

        return entity;
    }
    

    Context:

    public IReadOnlyList<IProperty> FindPrimaryKeyProperties<T>(T entity)
    {
        return Model.FindEntityType(entity.GetType()).FindPrimaryKey().Properties;
    }

    public IEnumerable<object> FindPrimaryKeyValues<TEntity>(TEntity entity) where TEntity : class
    {
        // GetPropertyValue is a helper extension that is not shown in this answer
        // (presumably a reflection-based property getter).
        return from p in FindPrimaryKeyProperties(entity)
               select entity.GetPropertyValue(p.Name);
    }
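    One caveat with keying the dictionary by string.Join('|', ...): for string keys, different key combinations can map to the same text, e.g. ("a|b", "c") and ("a", "b|c") both become "a|b|c". Purely numeric keys are not affected. A sketch of a collision-free alternative, assuming the key values format sensibly with ToString; CompositeKey is a hypothetical helper, not part of EF Core:

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;

// Hypothetical helper: encodes each key value as "<length>:<text>", so different
// value combinations cannot produce the same dictionary key.
static string CompositeKey(IEnumerable<object> keyValues) =>
    string.Concat(keyValues.Select(v =>
    {
        // Invariant culture keeps numbers and dates formatted the same on every machine.
        string text = Convert.ToString(v, CultureInfo.InvariantCulture) ?? "";
        return text.Length + ":" + text;
    }));

var first = new object[] { "a|b", "c" };
var second = new object[] { "a", "b|c" };

// Joining with '|' (as in the answer) collides when a string key contains '|'.
bool joinCollides = string.Join('|', first) == string.Join('|', second);
// The length-prefixed form keeps them apart.
bool prefixCollides = CompositeKey(first) == CompositeKey(second);

Console.WriteLine(joinCollides);                         // True
Console.WriteLine(prefixCollides);                       // False
Console.WriteLine(CompositeKey(new object[] { 42, 7 })); // 2:421:7
```

    It could replace the string.Join('|', _context.FindPrimaryKeyValues(...)) calls in the repository code without other changes, since it also returns a string.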
    