How to use generic type with the database Context in EF6 Code First

后端 未结 3 1196
春和景丽
春和景丽 2021-02-06 10:25

For example, let's say I have 4 different entities that each implement an Add() method that adds the entity to the database:

public class Profile
{
    ...

    publi         


        
相关标签:
3条回答
  • 2021-02-06 10:44

    You could get around that by doing the following; you just have to make sure that at runtime this instance really is a TEntity:

    /// <summary>
    /// Adds this instance to the matching DbSet and saves immediately.
    /// </summary>
    /// <remarks>
    /// Converting through <c>object</c> lets the compiler accept the cast to the generic
    /// parameter; it is only checked at runtime, where it throws InvalidCastException if
    /// this instance is not actually a <c>TEntity</c>.
    /// </remarks>
    public void Add()
    {
        var targetSet = this._dbContext.Set<TEntity>();
        var self = (TEntity)(object)this;
        targetSet.Add(self);
        this._dbContext.SaveChanges();
    }
    

    This compiles because the compiler no longer checks the conversion once this is typed as object. If you get an InvalidCastException at runtime, it's because obj is not truly a TEntity. However, you may want to use a factory, repository, or other design pattern for working with the Entity Framework DbSet.

    0 讨论(0)
  • 2021-02-06 10:53

    The solution to your problem is to be more explicit in the definition of the generic constraint. Constrain TEntity to be a subclass of Entity<TEntity>, i.e. use where TEntity : Entity<TEntity> instead of where TEntity : class:

    /// <summary>
    /// Base class for entities that can persist themselves. <typeparamref name="TEntity"/>
    /// is meant to be the deriving class itself (e.g. <c>class Profile : Entity&lt;Profile&gt;</c>),
    /// which is what allows <see cref="Add"/> to pass <c>this</c> to <c>Set&lt;TEntity&gt;()</c>.
    /// </summary>
    /// <remarks>
    /// Every instance creates its own <c>SMTDBContext</c>. Dispose the entity (or use it in a
    /// <c>using</c> block) to release that context.
    /// </remarks>
    public abstract class Entity<TEntity> : IDisposable where TEntity : Entity<TEntity>
    {
        protected DbContext _dbContext;

        protected Entity()
        {
            this._dbContext = new SMTDBContext();
        }

        /// <summary>
        /// Adds this instance to the context's <c>Set&lt;TEntity&gt;()</c> and calls SaveChanges.
        /// </summary>
        /// <exception cref="InvalidCastException">
        /// The class derives from <c>Entity&lt;X&gt;</c> where X is not the class itself.
        /// </exception>
        public void Add()
        {
            // The constraint only guarantees TEntity derives from Entity<TEntity>; a declaration
            // such as "class A : Entity<B>" still compiles, so report that mistake clearly.
            TEntity entity = this as TEntity;
            if (entity == null)
            {
                throw new InvalidCastException(
                    GetType().Name + " must derive from Entity<" + GetType().Name +
                    ">, not Entity<" + typeof(TEntity).Name + ">.");
            }

            this._dbContext.Set<TEntity>().Add(entity);
            this._dbContext.SaveChanges();
        }

        /// <summary>Disposes the context created by the constructor.</summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases the owned context when <paramref name="disposing"/> is true; override to
        /// release additional resources in derived classes.
        /// </summary>
        protected virtual void Dispose(bool disposing)
        {
            if (disposing && this._dbContext != null)
            {
                this._dbContext.Dispose();
            }
        }
    }
    
    0 讨论(0)
  • 2021-02-06 11:02

    Try a generic repository; in the end you will develop something similar anyway. You need 3 interfaces:

    • IEntity
    • IEntityRepository
    • IEntityContext

    And the implementations to those interfaces:

    • EntityContext
    • EntityRepository

    Here is the code:

    IEntity.cs

    /// <summary>
    /// An entity that exposes an identifier of type <typeparamref name="TId"/>.
    /// </summary>
    /// <typeparam name="TId">Identifier type; must implement <see cref="IComparable"/>.</typeparam>
    public interface IEntity<TId>
        where TId : IComparable
    {
        /// <summary>Gets or sets the identifier.</summary>
        TId Id
        {
            get;
            set;
        }
    }
    

    IEntityContext.cs

    /// <summary>
    /// Abstraction over a DbContext used by the generic repository: typed set access,
    /// explicit change-tracker state changes, and persistence.
    /// </summary>
    public interface IEntityContext : IDisposable
    {
        /// <summary>Marks <paramref name="entity"/> as Added in the change tracker.</summary>
        void SetAsAdded<TEntity>(TEntity entity)
            where TEntity : class;

        /// <summary>Marks <paramref name="entity"/> as Modified in the change tracker.</summary>
        void SetAsModified<TEntity>(TEntity entity)
            where TEntity : class;

        /// <summary>Marks <paramref name="entity"/> as Deleted in the change tracker.</summary>
        void SetAsDeleted<TEntity>(TEntity entity)
            where TEntity : class;

        /// <summary>Returns the set for <typeparamref name="TEntity"/>.</summary>
        IDbSet<TEntity> Set<TEntity>()
            where TEntity : class;

        /// <summary>Persists pending changes; returns the value reported by the underlying context.</summary>
        int SaveChanges();
    }
    

    IEntityRepository.cs

    /// <summary>
    /// Generic repository contract for entities identified by <see cref="IEntity{TId}.Id"/>.
    /// </summary>
    /// <typeparam name="TEntity">Entity type.</typeparam>
    /// <typeparam name="TId">Identifier type.</typeparam>
    public interface IEntityRepository<TEntity, TId> : IDisposable
        where TEntity : class, IEntity<TId>
        where TId : IComparable
    {
        // ---- Queries ----

        /// <summary>Returns entities, optionally filtered by <paramref name="where"/> and ordered by <paramref name="orderBy"/>.</summary>
        IQueryable<TEntity> GetAll(Expression<Func<TEntity, bool>> where = null, Expression<Func<TEntity, object>> orderBy = null);

        /// <summary>Returns one page of entities. NOTE(review): whether pageIndex is 0- or 1-based is not visible here — confirm.</summary>
        PaginatedList<TEntity> Paginate(int pageIndex, int pageSize);

        /// <summary>Returns the entity whose Id equals <paramref name="id"/>.</summary>
        TEntity GetSingle(TId id);

        /// <summary>Like <see cref="GetAll"/>, additionally eager-loading <paramref name="includeProperties"/>.</summary>
        IQueryable<TEntity> GetAllIncluding(Expression<Func<TEntity, bool>> where, Expression<Func<TEntity, object>> orderBy, params Expression<Func<TEntity, object>>[] includeProperties);

        /// <summary>Like <see cref="GetSingle"/>, additionally eager-loading <paramref name="includeProperties"/>.</summary>
        TEntity GetSingleIncluding(TId id, params Expression<Func<TEntity, object>>[] includeProperties);

        // ---- Change tracking ----

        void Add(TEntity entity);

        void Attach(TEntity entity);

        void Edit(TEntity entity);

        void Delete(TEntity entity);

        // ---- Persistence ----

        /// <summary>Persists pending changes.</summary>
        int Save();
    }
    

    EntityRepository.cs

    /// <summary>
    /// Generic repository over an <see cref="IEntityContext"/>. Query methods return deferred
    /// <see cref="IQueryable{T}"/> results; Add/Attach/Edit/Delete only change tracked state,
    /// and nothing is sent to the database until <see cref="Save"/> calls SaveChanges.
    /// </summary>
    /// <typeparam name="TEntity">Entity type, identified by <see cref="IEntity{TId}.Id"/>.</typeparam>
    /// <typeparam name="TId">Identifier type.</typeparam>
    public class EntityRepository<TEntity, TId>
        : IEntityRepository<TEntity, TId>
        where TEntity : class, IEntity<TId>
        where TId : IComparable
    {
        private readonly IEntityContext _dbContext;

        // NOTE(review): the original answer raised these events without declaring them, so the
        // class could not compile. The EventArgs types are assumed to live in the author's project.

        /// <summary>Raised after an entity is passed to <see cref="Add"/>.</summary>
        public event EventHandler<EntityAddedEventArgs<TEntity, TId>> EntityAdded;

        /// <summary>Raised after an entity is passed to <see cref="Attach"/>.</summary>
        public event EventHandler<EntityAddedEventArgs<TEntity, TId>> EntityAttach;

        /// <summary>Raised after an entity is passed to <see cref="Edit"/>.</summary>
        public event EventHandler<EntityModifiedEventArgs<TEntity, TId>> EntityModified;

        /// <summary>Raised after an entity is passed to <see cref="Delete"/>.</summary>
        public event EventHandler<EntityDeletedEventArgs<TEntity, TId>> EntityDeleted;

        // NOTE(review): IEntityRepository also declares Paginate(int, int), which the answer never
        // implemented (PaginatedList<TEntity> is not shown); it must still be added to compile.

        /// <param name="dbContext">Context to query and track entities with; disposed by <see cref="Dispose"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="dbContext"/> is null.</exception>
        public EntityRepository(IEntityContext dbContext)
        {
            if (dbContext == null)
                throw new ArgumentNullException("dbContext");

            _dbContext = dbContext;
        }

        /// <summary>
        /// Returns the entity set, filtered by <paramref name="where"/> and ordered by
        /// <paramref name="orderBy"/> when they are non-null.
        /// </summary>
        public IQueryable<TEntity> GetAll(
            Expression<Func<TEntity, bool>> where = null,
            Expression<Func<TEntity, object>> orderBy = null)
        {
            IQueryable<TEntity> queryable = _dbContext.Set<TEntity>();
            if (where != null)
                queryable = queryable.Where(where);

            return orderBy != null ? ApplyOrderBy(queryable, orderBy) : queryable;
        }

        /// <summary>
        /// Same as <see cref="GetAll"/>, with each of <paramref name="includeProperties"/> eager-loaded.
        /// A null <paramref name="includeProperties"/> array is treated as "no includes".
        /// </summary>
        public IQueryable<TEntity> GetAllIncluding(
            Expression<Func<TEntity, bool>> where,
            Expression<Func<TEntity, object>> orderBy,
            params Expression<Func<TEntity, object>>[] includeProperties)
        {
            IQueryable<TEntity> queryable = GetAll(where, orderBy);
            if (includeProperties == null)
                return queryable;

            foreach (Expression<Func<TEntity, object>> includeProperty in includeProperties)
            {
                queryable = queryable.Include<TEntity, object>(includeProperty);
            }
            return queryable;
        }

        /// <summary>Returns the entity with the given <paramref name="id"/>, or null if none matches.</summary>
        public TEntity GetSingle(TId id)
        {
            return Filter<TId>(GetAll(), x => x.Id, id).FirstOrDefault();
        }

        /// <summary>
        /// Returns the entity with the given <paramref name="id"/> with <paramref name="includeProperties"/>
        /// eager-loaded, or null if none matches.
        /// </summary>
        public TEntity GetSingleIncluding(
            TId id,
            params Expression<Func<TEntity, object>>[] includeProperties)
        {
            IQueryable<TEntity> entities = GetAllIncluding(null, null, includeProperties);
            return Filter<TId>(entities, x => x.Id, id).FirstOrDefault();
        }

        /// <summary>Adds <paramref name="entity"/> to the set, then raises <see cref="EntityAdded"/>.</summary>
        public void Add(TEntity entity)
        {
            _dbContext.Set<TEntity>().Add(entity);

            // Copy to a local so an unsubscribe on another thread cannot null it between check and call.
            var handler = EntityAdded;
            if (handler != null)
                handler(this, new EntityAddedEventArgs<TEntity, TId>(entity));
        }

        /// <summary>
        /// Attaches <paramref name="entity"/> to the context without marking it for insertion,
        /// then raises <see cref="EntityAttach"/>.
        /// </summary>
        public void Attach(TEntity entity)
        {
            // Bug fix: this previously called SetAsAdded, making Attach identical to Add, so
            // an already-persisted entity was INSERTed again on Save.
            _dbContext.Set<TEntity>().Attach(entity);

            var handler = EntityAttach;
            if (handler != null)
                handler(this, new EntityAddedEventArgs<TEntity, TId>(entity));
        }

        /// <summary>Marks <paramref name="entity"/> as Modified, then raises <see cref="EntityModified"/>.</summary>
        public void Edit(TEntity entity)
        {
            _dbContext.SetAsModified(entity);

            var handler = EntityModified;
            if (handler != null)
                handler(this, new EntityModifiedEventArgs<TEntity, TId>(entity));
        }

        /// <summary>Marks <paramref name="entity"/> as Deleted, then raises <see cref="EntityDeleted"/>.</summary>
        public void Delete(TEntity entity)
        {
            _dbContext.SetAsDeleted(entity);

            var handler = EntityDeleted;
            if (handler != null)
                handler(this, new EntityDeletedEventArgs<TEntity, TId>(entity));
        }

        /// <summary>Persists pending changes and returns the context's SaveChanges result.</summary>
        public int Save()
        {
            return _dbContext.SaveChanges();
        }

        /// <summary>
        /// Disposes the context passed to the constructor.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the repository does not create that context; if it is shared,
        /// disposing here affects other users of it.
        /// </remarks>
        public void Dispose()
        {
            _dbContext.Dispose();
        }

        #region Private

        /// <summary>
        /// Builds <c>dbSet.Where(e =&gt; e.Property == value)</c> from a property-access lambda.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="property"/> is not a simple property access.</exception>
        private IQueryable<TEntity> Filter<TProperty>(
            IQueryable<TEntity> dbSet,
            Expression<Func<TEntity, TProperty>> property, TProperty value)
            where TProperty : IComparable
        {
            var memberExpression = property.Body as MemberExpression;

            if (memberExpression == null ||
                !(memberExpression.Member is PropertyInfo))
                throw new ArgumentException
                    ("Property expected", "property");

            Expression left = property.Body;
            Expression right =
                Expression.Constant(value, typeof(TProperty));
            Expression searchExpression = Expression.Equal(left, right);

            Expression<Func<TEntity, bool>> lambda =
                Expression.Lambda<Func<TEntity, bool>>(
                    searchExpression,
                    new ParameterExpression[] { property.Parameters.Single() });

            return dbSet.Where(lambda);
        }

        /// <summary>
        /// Applies <paramref name="orderBy"/> to <paramref name="source"/> as a strongly typed
        /// Queryable.OrderBy call.
        /// </summary>
        /// <remarks>
        /// A value-typed key (int, DateTime, ...) returned from a Func&lt;TEntity, object&gt; is
        /// wrapped in a Convert-to-object node. LINQ to Entities rejects that node because it only
        /// supports casts to EDM primitive/enum types, so the node is removed and the key's real
        /// type is used instead. For in-memory providers the resulting order is unchanged.
        /// </remarks>
        private static IQueryable<TEntity> ApplyOrderBy(
            IQueryable<TEntity> source,
            Expression<Func<TEntity, object>> orderBy)
        {
            Expression keyBody = orderBy.Body;
            var conversion = keyBody as UnaryExpression;
            if (conversion != null
                && (conversion.NodeType == ExpressionType.Convert
                    || conversion.NodeType == ExpressionType.ConvertChecked)
                && conversion.Type == typeof(object))
            {
                keyBody = conversion.Operand;
            }

            LambdaExpression keySelector = Expression.Lambda(keyBody, orderBy.Parameters);
            MethodCallExpression orderByCall = Expression.Call(
                typeof(Queryable),
                "OrderBy",
                new[] { typeof(TEntity), keyBody.Type },
                source.Expression,
                Expression.Quote(keySelector));

            return source.Provider.CreateQuery<TEntity>(orderByCall);
        }

        #endregion
    }
    

    EntityContext.cs

    /// <summary>
    /// Abstract <see cref="DbContext"/> that implements <see cref="IEntityContext"/>, so repositories
    /// can depend on the interface rather than on a concrete context.
    /// </summary>
    /// <remarks>
    /// Constructors are protected because the class is abstract and can only be constructed
    /// through a derived context (CA1012).
    /// </remarks>
    public abstract class EntityContext : DbContext, IEntityContext
    {
        /// <summary>
        /// Constructs a new context instance using conventions to create the name of
        /// the database to which a connection will be made. The by-convention name is
        /// the full name (namespace + class name) of the derived context class. See
        /// the <see cref="DbContext"/> remarks for how this is used to create a connection.
        /// </summary>
        protected EntityContext() : base() { }

        /// <summary>
        /// Constructs a new context instance using conventions to create the name of
        /// the database to which a connection will be made, and initializes it from
        /// the given model. The by-convention name is the full name (namespace + class
        /// name) of the derived context class. See the <see cref="DbContext"/> remarks
        /// for how this is used to create a connection.
        /// </summary>
        /// <param name="model">The model that will back this context.</param>
        protected EntityContext(DbCompiledModel model) : base(model) { }

        /// <summary>
        /// Constructs a new context instance using the given string as the name or connection
        /// string for the database to which a connection will be made. See the
        /// <see cref="DbContext"/> remarks for how this is used to create a connection.
        /// </summary>
        /// <param name="nameOrConnectionString">Either the database name or a connection string.</param>
        protected EntityContext(string nameOrConnectionString)
            : base(nameOrConnectionString) { }

        /// <summary>
        /// Constructs a new context instance that connects through an existing connection.
        /// Whether that connection is disposed with the context is controlled by
        /// <paramref name="contextOwnsConnection"/>.
        /// </summary>
        /// <param name="existingConnection">An existing connection to use for the new context.</param>
        /// <param name="contextOwnsConnection">
        /// If set to true the connection is disposed when the context is disposed, otherwise
        /// the caller must dispose the connection.
        /// </param>
        protected EntityContext
            (DbConnection existingConnection, bool contextOwnsConnection)
            : base(existingConnection, contextOwnsConnection) { }

        /// <summary>
        /// Constructs a new context instance around an existing ObjectContext.
        /// </summary>
        /// <param name="objectContext">An existing ObjectContext to wrap with the new context.</param>
        /// <param name="EntityContextOwnsObjectContext">
        /// If set to true the ObjectContext is disposed when this context is disposed,
        /// otherwise the caller must dispose it.
        /// </param>
        protected EntityContext(
            ObjectContext objectContext,
            bool EntityContextOwnsObjectContext)
            : base(objectContext, EntityContextOwnsObjectContext)
        { }

        /// <summary>
        /// Constructs a new context instance using the given string as the name or connection
        /// string for the database to which a connection will be made, and initializes
        /// it from the given model. See the <see cref="DbContext"/> remarks for how this is
        /// used to create a connection.
        /// </summary>
        /// <param name="nameOrConnectionString">Either the database name or a connection string.</param>
        /// <param name="model">The model that will back this context.</param>
        protected EntityContext(
            string nameOrConnectionString,
            DbCompiledModel model)
            : base(nameOrConnectionString, model)
        { }

        /// <summary>
        /// Constructs a new context instance that connects through an existing connection and
        /// initializes it from the given model. Whether the connection is disposed with the
        /// context is controlled by <paramref name="contextOwnsConnection"/>.
        /// </summary>
        /// <param name="existingConnection">An existing connection to use for the new context.</param>
        /// <param name="model">The model that will back this context.</param>
        /// <param name="contextOwnsConnection">
        /// If set to true the connection is disposed when the context is disposed, otherwise
        /// the caller must dispose the connection.
        /// </param>
        protected EntityContext(
            DbConnection existingConnection,
            DbCompiledModel model, bool contextOwnsConnection)
            : base(existingConnection, model, contextOwnsConnection)
        { }

        /// <summary>
        /// Returns <c>DbContext.Set&lt;TEntity&gt;()</c> typed as <c>IDbSet&lt;TEntity&gt;</c>,
        /// the return type required by <see cref="IEntityContext"/>.
        /// </summary>
        public new IDbSet<TEntity> Set<TEntity>() where TEntity : class
        {
            return base.Set<TEntity>();
        }

        /// <summary>Sets the tracked state of <paramref name="entity"/> to Added, attaching it first if needed.</summary>
        public void SetAsAdded<TEntity>(TEntity entity) where TEntity : class
        {
            GetDbEntityEntrySafely(entity).State = EntityState.Added;
        }

        /// <summary>Sets the tracked state of <paramref name="entity"/> to Modified, attaching it first if needed.</summary>
        public void SetAsModified<TEntity>(TEntity entity) where TEntity : class
        {
            GetDbEntityEntrySafely(entity).State = EntityState.Modified;
        }

        /// <summary>Sets the tracked state of <paramref name="entity"/> to Deleted, attaching it first if needed.</summary>
        public void SetAsDeleted<TEntity>(TEntity entity) where TEntity : class
        {
            GetDbEntityEntrySafely(entity).State = EntityState.Deleted;
        }

        /// <summary>Persists pending changes; delegates to <see cref="DbContext.SaveChanges"/>.</summary>
        public override int SaveChanges()
        {
            return base.SaveChanges();
        }

        /// <summary>Disposes the context; delegates to <see cref="DbContext.Dispose()"/>.</summary>
        public new void Dispose()
        {
            base.Dispose();
        }

        #region Private

        /// <summary>
        /// Returns the change-tracker entry for <paramref name="entity"/>, first attaching it via
        /// <c>Set&lt;TEntity&gt;().Attach</c> when its state is Detached.
        /// </summary>
        private DbEntityEntry GetDbEntityEntrySafely<TEntity>(
            TEntity entity) where TEntity : class
        {
            DbEntityEntry dbEntityEntry = base.Entry<TEntity>(entity);
            if (dbEntityEntry.State == EntityState.Detached)
                Set<TEntity>().Attach(entity);

            return dbEntityEntry;
        }

        #endregion
    }
    

    Long answer, but worth it... Have a nice day :) It's part of a huge personal project :D

    0 讨论(0)
提交回复
热议问题