diff --git a/samples/UnitOfWork.Host/Controllers/ValuesController.cs b/samples/UnitOfWork.Host/Controllers/ValuesController.cs index 5fcfe36..83d0052 100644 --- a/samples/UnitOfWork.Host/Controllers/ValuesController.cs +++ b/samples/UnitOfWork.Host/Controllers/ValuesController.cs @@ -13,7 +13,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Host.Controllers public class ValuesController : Controller { private readonly IUnitOfWork _unitOfWork; - private ILogger _logger; + private readonly ILogger _logger; // 1. IRepositoryFactory used for readonly scenario; // 2. IUnitOfWork used for read/write scenario; @@ -27,135 +27,137 @@ public ValuesController(IUnitOfWork unitOfWork, ILogger logger var repo = _unitOfWork.GetRepository(hasCustomRepository: true); if (repo.Count() == 0) { - repo.Insert(new Blog - { - Id = 1, - Url = "/a/" + 1, - Title = $"a{1}", - Posts = new List{ - new Post - { - Id = 1, - Title = "A", - Content = "A's content", - Comments = new List - { - new Comment - { - Id = 1, - Title = "A", - Content = "A's content", - }, - new Comment - { - Id = 2, - Title = "b", - Content = "b's content", - }, - new Comment - { - Id = 3, - Title = "c", - Content = "c's content", - } - }, - }, - new Post - { - Id = 2, - Title = "B", - Content = "B's content", - Comments = new List - { - new Comment - { - Id = 4, - Title = "A", - Content = "A's content", - }, - new Comment - { - Id = 5, - Title = "b", - Content = "b's content", - }, - new Comment - { - Id = 6, - Title = "c", - Content = "c's content", - } - }, - }, - new Post - { - Id = 3, - Title = "C", - Content = "C's content", - Comments = new List - { - new Comment - { - Id = 7, - Title = "A", - Content = "A's content", - }, - new Comment - { - Id = 8, - Title = "b", - Content = "b's content", - }, - new Comment - { - Id = 9, - Title = "c", - Content = "c's content", - } - }, - }, - new Post - { - Id = 4, - Title = "D", - Content = "D's content", - Comments = new List - { - new Comment - { - Id = 10, - Title = "A", - Content = "A's content", - }, - new Comment - { - Id = 11, - Title = "b", - Content = "b's content", - }, - new Comment - { - Id = 12, - Title = "c", - Content = "c's content", - } - }, - } - }, - }); + SeedInitialEntities(repo); _unitOfWork.SaveChanges(); } } + private static void SeedInitialEntities(IRepository repo) + => repo.Insert(new Blog + { + Id = 1, + Url = "/a/" + 1, + Title = $"a{1}", + Posts = new List{ + new Post + { + Id = 1, + Title = "A", + Content = "A's content", + Comments = new List + { + new Comment + { + Id = 1, + Title = "A", + Content = "A's content", + }, + new Comment + { + Id = 2, + Title = "b", + Content = "b's content", + }, + new Comment + { + Id = 3, + Title = "c", + Content = "c's content", + } + }, + }, + new Post + { + Id = 2, + Title = "B", + Content = "B's content", + Comments = new List + { + new Comment + { + Id = 4, + Title = "A", + Content = "A's content", + }, + new Comment + { + Id = 5, + Title = "b", + Content = "b's content", + }, + new Comment + { + Id = 6, + Title = "c", + Content = "c's content", + } + }, + }, + new Post + { + Id = 3, + Title = "C", + Content = "C's content", + Comments = new List + { + new Comment + { + Id = 7, + Title = "A", + Content = "A's content", + }, + new Comment + { + Id = 8, + Title = "b", + Content = "b's content", + }, + new Comment + { + Id = 9, + Title = "c", + Content = "c's content", + } + }, + }, + new Post + { + Id = 4, + Title = "D", + Content = "D's content", + Comments = new List + { + new Comment + { + Id = 10, + Title = "A", + Content = "A's content", + }, + new Comment + { + Id = 11, + Title = "b", + Content = "b's content", + }, + new Comment + { + Id = 12, + Title = "c", + Content = "c's content", + } + }, + } + }, + }); + // GET api/values [HttpGet] - public async Task> Get() - { - return await _unitOfWork.GetRepository().GetAllAsync(include: source => source.Include(blog => blog.Posts).ThenInclude(post => post.Comments)); - } + public async Task> Get() + => await _unitOfWork.GetRepository() + .GetAllAsync(include: source => source.Include(blog => blog.Posts).ThenInclude(post => post.Comments)); // GET api/values/Page/5/10 - [HttpGet("Page/{pageIndex}/{pageSize}")] + [HttpGet("Page/{pageIndex:int}/{pageSize:int}")] public async Task> Get(int pageIndex, int pageSize) { // projection @@ -170,11 +172,11 @@ public async Task> Get(string term) { _logger.LogInformation("demo about first or default with include"); - var item = _unitOfWork.GetRepository().GetFirstOrDefault(predicate: x => x.Title.Contains(term), include: source => source.Include(blog => blog.Posts).ThenInclude(post => post.Comments)); + var item = await _unitOfWork.GetRepository().GetFirstOrDefaultAsync(predicate: x => x.Title.Contains(term), include: source => source.Include(blog => blog.Posts).ThenInclude(post => post.Comments)); _logger.LogInformation("demo about first or default without include"); - item = _unitOfWork.GetRepository().GetFirstOrDefault(predicate: x => x.Title.Contains(term), orderBy: source => source.OrderByDescending(b => b.Id)); + item = await _unitOfWork.GetRepository().GetFirstOrDefaultAsync(predicate: x => x.Title.Contains(term), orderBy: source => source.OrderByDescending(b => b.Id)); _logger.LogInformation("demo about first or default with projection"); @@ -184,11 +186,9 @@ public async Task> Get(string term) } // GET api/values/4 - [HttpGet("{id}")] - public async Task Get(int id) - { - return await _unitOfWork.GetRepository().FindAsync(id); - } + [HttpGet("{id:int}")] + public async Task Get(int id) + => await _unitOfWork.GetRepository().FindAsync(id); // POST api/values [HttpPost] diff --git a/samples/UnitOfWork.Host/Models/BlogggingContext.cs b/samples/UnitOfWork.Host/Models/BlogggingContext.cs index 1099928..e5fc31d 100644 --- a/samples/UnitOfWork.Host/Models/BlogggingContext.cs +++ b/samples/UnitOfWork.Host/Models/BlogggingContext.cs @@ -12,10 +12,7 @@ public BloggingContext(DbContextOptions options) public DbSet Blogs { get; set; } public DbSet Posts { get; set; } - protected override void OnModelCreating(ModelBuilder modelBuilder) - { - modelBuilder.EnableAutoHistory(null); - } + protected override void OnModelCreating(ModelBuilder modelBuilder) => modelBuilder.EnableAutoHistory(null); } public class Blog diff --git a/samples/UnitOfWork.Host/Models/CustomBlogRepository.cs b/samples/UnitOfWork.Host/Models/CustomBlogRepository.cs index c706628..571954d 100644 --- a/samples/UnitOfWork.Host/Models/CustomBlogRepository.cs +++ b/samples/UnitOfWork.Host/Models/CustomBlogRepository.cs @@ -1,10 +1,9 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Host.Models { - public class CustomBlogRepository : Repository, IRepository + public class CustomBlogRepository : Repository { public CustomBlogRepository(BloggingContext dbContext) : base(dbContext) { - } } } diff --git a/samples/UnitOfWork.Host/Startup.cs b/samples/UnitOfWork.Host/Startup.cs index b706c5d..a29ae19 100644 --- a/samples/UnitOfWork.Host/Startup.cs +++ b/samples/UnitOfWork.Host/Startup.cs @@ -33,8 +33,8 @@ public void ConfigureServices(IServiceCollection services) { // use in memory for testing. services - .AddDbContext(opt => opt.UseMySql("Server=localhost;database=uow;uid=root;pwd=root1234;")) - //.AddDbContext(opt => opt.UseInMemoryDatabase("UnitOfWork")) + //.AddDbContext(opt => opt.UseMySql("Server=localhost;database=uow;uid=root;pwd=root1234;")) + .AddDbContext(opt => opt.UseInMemoryDatabase("UnitOfWork")) .AddUnitOfWork() .AddCustomRepository(); diff --git a/samples/UnitOfWork.Host/UnitOfWork.Host.csproj b/samples/UnitOfWork.Host/UnitOfWork.Host.csproj index 3d001d5..03e55c7 100644 --- a/samples/UnitOfWork.Host/UnitOfWork.Host.csproj +++ b/samples/UnitOfWork.Host/UnitOfWork.Host.csproj @@ -1,23 +1,26 @@  + - netcoreapp3.1 + net6.0 + Arch.EntityFrameworkCore.UnitOfWork.Host true Exe + - - - - - - - - - + + + + + + + + + diff --git a/src/UnitOfWork/Collections/IEnumerablePagedListExtensions.cs b/src/UnitOfWork/Collections/EnumerablePagedListExtensions.cs similarity index 97% rename from src/UnitOfWork/Collections/IEnumerablePagedListExtensions.cs rename to src/UnitOfWork/Collections/EnumerablePagedListExtensions.cs index 01c8ef5..42697f7 100644 --- a/src/UnitOfWork/Collections/IEnumerablePagedListExtensions.cs +++ b/src/UnitOfWork/Collections/EnumerablePagedListExtensions.cs @@ -8,7 +8,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Collections /// /// Provides some extension methods for to provide paging capability. /// - public static class IEnumerablePagedListExtensions + public static class EnumerablePagedListExtensions { /// /// Converts the specified source to by the specified and . diff --git a/src/UnitOfWork/Collections/PagedList.cs b/src/UnitOfWork/Collections/PagedList.cs index 7dca0db..13d9c40 100644 --- a/src/UnitOfWork/Collections/PagedList.cs +++ b/src/UnitOfWork/Collections/PagedList.cs @@ -70,15 +70,15 @@ internal PagedList(IEnumerable source, int pageIndex, int pageSize, int index throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, must indexFrom <= pageIndex"); } - if (source is IQueryable querable) + if (source is IQueryable queryable) { PageIndex = pageIndex; PageSize = pageSize; IndexFrom = indexFrom; - TotalCount = querable.Count(); + TotalCount = queryable.Count(); TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize); - Items = querable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList(); + Items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList(); } else { @@ -95,7 +95,7 @@ internal PagedList(IEnumerable source, int pageIndex, int pageSize, int index /// /// Initializes a new instance of the class. /// - internal PagedList() => Items = new T[0]; + internal PagedList() => Items = Array.Empty(); } @@ -165,15 +165,15 @@ public PagedList(IEnumerable source, Func, IEnumer throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, must indexFrom <= pageIndex"); } - if (source is IQueryable querable) + if (source is IQueryable queryable) { PageIndex = pageIndex; PageSize = pageSize; IndexFrom = indexFrom; - TotalCount = querable.Count(); + TotalCount = queryable.Count(); TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize); - var items = querable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray(); + var items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray(); Items = new List(converter(items)); } diff --git a/src/UnitOfWork/Collections/IQueryablePageListExtensions.cs b/src/UnitOfWork/Collections/QueryablePageListExtensions.cs similarity index 84% rename from src/UnitOfWork/Collections/IQueryablePageListExtensions.cs rename to src/UnitOfWork/Collections/QueryablePageListExtensions.cs index 9578aae..3d198cf 100644 --- a/src/UnitOfWork/Collections/IQueryablePageListExtensions.cs +++ b/src/UnitOfWork/Collections/QueryablePageListExtensions.cs @@ -6,7 +6,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Collections { - public static class IQueryablePageListExtensions + public static class QueryablePageListExtensions { /// /// Converts the specified source to by the specified and . @@ -20,7 +20,7 @@ public static class IQueryablePageListExtensions /// /// The start index value. /// An instance of the inherited from interface. - public static async Task> ToPagedListAsync(this IQueryable source, int pageIndex, int pageSize, int indexFrom = 0, CancellationToken cancellationToken = default(CancellationToken)) + public static async Task> ToPagedListAsync(this IQueryable source, int pageIndex, int pageSize, int indexFrom = 0, CancellationToken cancellationToken = default) { if (indexFrom > pageIndex) { @@ -28,8 +28,11 @@ public static class IQueryablePageListExtensions } var count = await source.CountAsync(cancellationToken).ConfigureAwait(false); - var items = await source.Skip((pageIndex - indexFrom) * pageSize) - .Take(pageSize).ToListAsync(cancellationToken).ConfigureAwait(false); + var items = await source + .Skip((pageIndex - indexFrom) * pageSize) + .Take(pageSize) + .ToListAsync(cancellationToken) + .ConfigureAwait(false); var pagedList = new PagedList() { diff --git a/src/UnitOfWork/IRepository.cs b/src/UnitOfWork/IRepository.cs index a31fef2..6f96596 100644 --- a/src/UnitOfWork/IRepository.cs +++ b/src/UnitOfWork/IRepository.cs @@ -33,7 +33,7 @@ public interface IRepository where TEntity : class void ChangeTable(string table); /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -44,16 +44,17 @@ public interface IRepository where TEntity : class /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - IPagedList GetPagedList(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - bool ignoreQueryFilters = false); + IPagedList GetPagedList( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + bool ignoreQueryFilters = false); /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -67,17 +68,18 @@ IPagedList GetPagedList(Expression> predicate = nul /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - Task> GetPagedListAsync(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - CancellationToken cancellationToken = default(CancellationToken), - bool ignoreQueryFilters = false); + Task> GetPagedListAsync( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + CancellationToken cancellationToken = default, + bool ignoreQueryFilters = false); /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -89,17 +91,18 @@ Task> GetPagedListAsync(Expression> pred /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - IPagedList GetPagedList(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - bool ignoreQueryFilters = false) where TResult : class; + IPagedList GetPagedList( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + bool ignoreQueryFilters = false) where TResult : class; /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -114,18 +117,19 @@ IPagedList GetPagedList(Expression> sel /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - Task> GetPagedListAsync(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - CancellationToken cancellationToken = default(CancellationToken), - bool ignoreQueryFilters = false) where TResult : class; + Task> GetPagedListAsync( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + CancellationToken cancellationToken = default, + bool ignoreQueryFilters = false) where TResult : class; /// - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method defaults to a read-only, no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method defaults to a read-only, no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -134,14 +138,15 @@ Task> GetPagedListAsync(ExpressionIgnore query filters /// An that contains elements that satisfy the condition specified by . /// This method defaults to a read-only, no-tracking query. - TEntity GetFirstOrDefault(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false); + TEntity GetFirstOrDefault( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false); /// - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method defaults to a read-only, no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method defaults to a read-only, no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -151,15 +156,16 @@ TEntity GetFirstOrDefault(Expression> predicate = null, /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method defaults to a read-only, no-tracking query. - TResult GetFirstOrDefault(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false); + TResult GetFirstOrDefault( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false); /// - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method defaults to a read-only, no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method defaults to a read-only, no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -177,7 +183,7 @@ Task GetFirstOrDefaultAsync(Expression> bool ignoreQueryFilters = false); /// - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method defaults to a read-only, no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method defaults to a read-only, no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -238,11 +244,12 @@ Task GetFirstOrDefaultAsync(Expression> predicate = /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// Ex: This method defaults to a read-only, no-tracking query. - IQueryable GetAll(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false); + IQueryable GetAll( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false); /// /// Gets all entities. This method is not recommended @@ -278,11 +285,12 @@ IQueryable GetAll(Expression> selector, /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// Ex: This method defaults to a read-only, no-tracking query. - Task> GetAllAsync(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false); + Task> GetAllAsync( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false); /// /// Gets all entities. This method is not recommended @@ -333,66 +341,74 @@ Task> GetAllAsync(Expression> sel /// /// Gets the max based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - T Max(Expression> predicate = null, Expression> selector = null); + T Max(Expression> selector, Expression> predicate = null); /// /// Gets the async max based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - Task MaxAsync(Expression> predicate = null, Expression> selector = null); + Task MaxAsync(Expression> selector, Expression> predicate = null); /// /// Gets the min based on a predicate. /// - /// /// + /// /// decimal - T Min(Expression> predicate = null, Expression> selector = null); + T Min(Expression> selector, Expression> predicate = null); /// /// Gets the async min based on a predicate. /// - /// /// + /// /// decimal - Task MinAsync(Expression> predicate = null, Expression> selector = null); + Task MinAsync(Expression> selector, Expression> predicate = null); /// /// Gets the average based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - decimal Average (Expression> predicate = null, Expression> selector = null); + decimal Average(Expression> selector, Expression> predicate = null); /// - /// Gets the async average based on a predicate. - /// - /// - /// /// - /// decimal - Task AverageAsync(Expression> predicate = null, Expression> selector = null); + /// Gets the async average based on a predicate. + /// + /// + /// + /// /// + /// decimal + Task AverageAsync(Expression> selector, + Expression> predicate = null); /// /// Gets the sum based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - decimal Sum (Expression> predicate = null, Expression> selector = null); + decimal Sum(Expression> selector, Expression> predicate = null); /// /// Gets the async sum based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - Task SumAsync (Expression> predicate = null, Expression> selector = null); + Task SumAsync(Expression> selector, + Expression> predicate = null); /// /// Gets the Exists record based on a predicate. @@ -431,7 +447,7 @@ Task> GetAllAsync(Expression> sel /// The entity to insert. /// A to observe while waiting for the task to complete. /// A that represents the asynchronous insert operation. - ValueTask> InsertAsync(TEntity entity, CancellationToken cancellationToken = default(CancellationToken)); + ValueTask> InsertAsync(TEntity entity, CancellationToken cancellationToken = default); /// /// Inserts a range of entities asynchronously. @@ -446,7 +462,7 @@ Task> GetAllAsync(Expression> sel /// The entities to insert. /// A to observe while waiting for the task to complete. /// A that represents the asynchronous insert operation. - Task InsertAsync(IEnumerable entities, CancellationToken cancellationToken = default(CancellationToken)); + Task InsertAsync(IEnumerable entities, CancellationToken cancellationToken = default); /// /// Updates the specified entity. @@ -466,6 +482,25 @@ Task> GetAllAsync(Expression> sel /// The entities. void Update(IEnumerable entities); + /// + /// Check Given Entity is exists + /// + /// The entity. + bool Exists(TEntity entity); + + /// + /// Inserts or Updates the specified entities. + /// + /// The entity. + void InsertOrUpdate(TEntity entity); + + /// + /// Inserts or Updates the specified entities. + /// + /// The entities. + void InsertOrUpdate(IEnumerable entities); + + /// /// Deletes the entity by the specified primary key. /// @@ -490,6 +525,103 @@ Task> GetAllAsync(Expression> sel /// The entities. void Delete(IEnumerable entities); + + /// + /// Gets the based on a predicate, orderBy delegate. This method default no-tracking query. + /// + /// A function to test each element for a condition. + /// A function to order elements. + /// A function to include navigation properties + /// True to disable changing tracking; otherwise, false. Default to true. + /// Ignore query filters + /// + /// A to observe while waiting for the task to complete. + /// + /// An that contains elements that satisfy the condition specified by . + /// This method default no-tracking query. + Task> GetListAsync( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false, + CancellationToken cancellationToken = default); + + /// + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. + /// + /// A function to test each element for a condition. + /// A function to order elements. + /// A function to include navigation properties + /// True to disable changing tracking; otherwise, false. Default to true. + /// Ignore query filters + /// An that contains elements that satisfy the condition specified by . + /// This method default no-tracking query. + List GetList( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false); + + + /// + /// Finds next entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The values of the primary key for the entity to be found. + /// The found entity or null. + TEntity GetNextById(params object[] keyValues); + + + /// + /// Finds next entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The values of the primary key for the entity to be found. + /// The found entity or null. + Task GetNextByIdAsync(params object[] keyValues); + + + /// + /// Finds previous entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The values of the primary key for the entity to be found. + /// The found entity or null. + TEntity GetPreviousById(params object[] keyValues); + + /// + /// Finds previous entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The values of the primary key for the entity to be found. + /// The found entity or null. + Task GetPreviousByIdAsync(params object[] keyValues); + + /// + /// Finds the first entity with order by primary key. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The found entity or null. + TEntity GetFirst(); + + + /// + /// Finds the first entity with order by primary key. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The found entity or null. + Task GetFirstAsync(); + + /// + /// Finds the Last entity with order by primary key. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The found entity or null. + TEntity GetLast(); + + + /// + /// Finds the last entity with order by primary key. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The found entity or null. + Task GetLastAsync(); + + /// /// Change entity state for patch method on web api. /// diff --git a/src/UnitOfWork/IRepositoryFactory.cs b/src/UnitOfWork/IRepositoryFactory.cs index c5f017f..885ef48 100644 --- a/src/UnitOfWork/IRepositoryFactory.cs +++ b/src/UnitOfWork/IRepositoryFactory.cs @@ -14,7 +14,7 @@ public interface IRepositoryFactory /// /// Gets the specified repository for the . /// - /// True if providing custom repositry + /// True if providing custom repository /// The type of the entity. /// An instance of type inherited from interface. IRepository GetRepository(bool hasCustomRepository = false) where TEntity : class; diff --git a/src/UnitOfWork/IUnitOfWork.cs b/src/UnitOfWork/IUnitOfWork.cs index c78ca54..06fb276 100644 --- a/src/UnitOfWork/IUnitOfWork.cs +++ b/src/UnitOfWork/IUnitOfWork.cs @@ -7,14 +7,16 @@ namespace Arch.EntityFrameworkCore.UnitOfWork { using System; + using System.Data; using System.Linq; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.ChangeTracking; + using Microsoft.EntityFrameworkCore.Storage; /// /// Defines the interface(s) for unit of work. /// - public interface IUnitOfWork : IDisposable + public interface IUnitOfWork : IDisposable, IAsyncDisposable { /// /// Changes the database name. This require the databases in the same machine. NOTE: This only work for MySQL right now. @@ -28,7 +30,7 @@ public interface IUnitOfWork : IDisposable /// /// Gets the specified repository for the . /// - /// True if providing custom repositry + /// True if providing custom repository /// The type of the entity. /// An instance of type inherited from interface. IRepository GetRepository(bool hasCustomRepository = false) where TEntity : class; @@ -36,7 +38,7 @@ public interface IUnitOfWork : IDisposable /// /// Saves all changes made in this context to the database. /// - /// True if sayve changes ensure auto record the change history. + /// True if save changes ensure auto record the change history. /// The number of state entries written to the database. int SaveChanges(bool ensureAutoHistory = false); @@ -55,6 +57,15 @@ public interface IUnitOfWork : IDisposable /// The number of state entities written to database. int ExecuteSqlCommand(string sql, params object[] parameters); + + /// + /// Executes the specified raw SQL command. + /// + /// The raw SQL. + /// The parameters. + /// The number of state entities written to database. + DataTable ExecuteDtSqlCommand(string sql, params object[] parameters); + /// /// Uses raw SQL queries to fetch the specified data. /// @@ -68,7 +79,14 @@ public interface IUnitOfWork : IDisposable /// Uses TrakGrap Api to attach disconnected entities /// /// Root entity - /// Delegate to convert Object's State properities to Entities entry state. + /// Delegate to convert Object's State properties to Entities entry state. void TrackGraph(object rootEntity, Action callback); + + /// + /// Starts DatabaseLevel Transaction + /// + /// The IsolationLevel + /// Transaction Context + IDbContextTransaction BeginTransaction(IsolationLevel isolation = IsolationLevel.ReadCommitted); } } \ No newline at end of file diff --git a/src/UnitOfWork/IUnitOfWorkOfT.cs b/src/UnitOfWork/IUnitOfWorkOfT.cs index 6ea0c7e..cadb711 100644 --- a/src/UnitOfWork/IUnitOfWorkOfT.cs +++ b/src/UnitOfWork/IUnitOfWorkOfT.cs @@ -2,13 +2,15 @@ using System.Threading.Tasks; using Microsoft.EntityFrameworkCore; +using System.Data; +using System.Transactions; namespace Arch.EntityFrameworkCore.UnitOfWork { /// /// Defines the interface(s) for generic unit of work. /// - public interface IUnitOfWork : IUnitOfWork where TContext : DbContext { + public interface IUnitOfWork : IUnitOfWork where TContext : DbContext { /// /// Gets the db context. /// @@ -22,5 +24,15 @@ public interface IUnitOfWork : IUnitOfWork where TContext : DbContext /// An optional array. /// A that represents the asynchronous save operation. The task result contains the number of state entities written to database. Task SaveChangesAsync(bool ensureAutoHistory = false, params IUnitOfWork[] unitOfWorks); + + /// + /// Saves all changes made in this context to the database with distributed transaction. + /// + /// The transaction to use + /// True if save changes ensure auto record the change history. + /// An optional array. + /// A that represents the asynchronous save operation. The task result contains the number of state entities written to database. + Task SaveChangesAsync(Transaction transaction, bool ensureAutoHistory = false, params IUnitOfWork[] unitOfWorks); + } } diff --git a/src/UnitOfWork/Microsoft - Backup.EntityFrameworkCore.UnitOfWork.csproj b/src/UnitOfWork/Microsoft - Backup.EntityFrameworkCore.UnitOfWork.csproj new file mode 100644 index 0000000..7916b7b --- /dev/null +++ b/src/UnitOfWork/Microsoft - Backup.EntityFrameworkCore.UnitOfWork.csproj @@ -0,0 +1,28 @@ + + + A plugin for Microsoft.EntityFrameworkCore to support repository, unit of work patterns, and multiple database with distributed transaction supported. + 2.1.0 + rigofunc;rigofunc@outlook.com; + netstandard2.0 + $(NoWarn);CS1591 + true + true + Microsoft.EntityFrameworkCore.UnitOfWork + Microsoft.EntityFrameworkCore.UnitOfWork + Entity Framework Core;entity-framework-core;EF;Data;O/RM;unitofwork;Unit Of Work;unit-of-work + https://github.com/arch/UnitOfWork + https://github.com/arch/UnitOfWork/blob/master/LICENSE + git + https://github.com/arch/UnitOfWork.git + + + + + + + + + C:\Program Files\dotnet\sdk\NuGetFallbackFolder\system.data.sqlclient\4.4.0\ref\netstandard2.0\System.Data.SqlClient.dll + + + diff --git a/src/UnitOfWork/Repository.cs b/src/UnitOfWork/Repository.cs index c3fff96..2e35e3a 100644 --- a/src/UnitOfWork/Repository.cs +++ b/src/UnitOfWork/Repository.cs @@ -8,6 +8,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore; +using EFCore.BulkExtensions; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query; using Arch.EntityFrameworkCore.UnitOfWork.Collections; @@ -21,8 +22,8 @@ namespace Arch.EntityFrameworkCore.UnitOfWork /// The type of the entity. public class Repository : IRepository where TEntity : class { - protected readonly DbContext _dbContext; - protected readonly DbSet _dbSet; + protected readonly DbContext DbContext; + protected readonly DbSet DbSet; /// /// Initializes a new instance of the class. @@ -30,8 +31,8 @@ public class Repository : IRepository where TEntity : class /// The database context. public Repository(DbContext dbContext) { - _dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext)); - _dbSet = _dbContext.Set(); + DbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext)); + DbSet = DbContext.Set(); } /// @@ -43,7 +44,7 @@ public Repository(DbContext dbContext) /// public virtual void ChangeTable(string table) { - if (_dbContext.Model.FindEntityType(typeof(TEntity)) is IConventionEntityType relational) + if (DbContext.Model.FindEntityType(typeof(TEntity)) is IConventionEntityType relational) { relational.SetTableName(table); } @@ -53,10 +54,7 @@ public virtual void ChangeTable(string table) /// Gets all entities. This method is not recommended /// /// The . - public IQueryable GetAll() - { - return _dbSet; - } + public IQueryable GetAll() => DbSet; /// /// Gets all entities. This method is not recommended @@ -73,7 +71,7 @@ public IQueryable GetAll( Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -99,10 +97,8 @@ public IQueryable GetAll( { return orderBy(query); } - else - { - return query; - } + + return query; } /// @@ -116,12 +112,13 @@ public IQueryable GetAll( /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// Ex: This method defaults to a read-only, no-tracking query. - public IQueryable GetAll(Expression> selector, + public IQueryable GetAll( + Expression> selector, Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -147,14 +144,12 @@ public IQueryable GetAll(Expression> sel { return orderBy(query).Select(selector); } - else - { - return query.Select(selector); - } + + return query.Select(selector); } /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -165,15 +160,16 @@ public IQueryable GetAll(Expression> sel /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual IPagedList GetPagedList(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - bool ignoreQueryFilters = false) + public virtual IPagedList GetPagedList( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -199,14 +195,12 @@ public virtual IPagedList GetPagedList(Expression> { return orderBy(query).ToPagedList(pageIndex, pageSize); } - else - { - return query.ToPagedList(pageIndex, pageSize); - } + + return query.ToPagedList(pageIndex, pageSize); } /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -220,16 +214,17 @@ public virtual IPagedList GetPagedList(Expression> /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual Task> GetPagedListAsync(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - CancellationToken cancellationToken = default(CancellationToken), - bool ignoreQueryFilters = false) + public virtual Task> GetPagedListAsync( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + CancellationToken cancellationToken = default, + bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -255,14 +250,12 @@ public virtual Task> GetPagedListAsync(Expression - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -274,17 +267,18 @@ public virtual Task> GetPagedListAsync(ExpressionIgnore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual IPagedList GetPagedList(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - bool ignoreQueryFilters = false) + public virtual IPagedList GetPagedList( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + bool ignoreQueryFilters = false) where TResult : class { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -310,14 +304,12 @@ public virtual IPagedList GetPagedList(Expression - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -332,18 +324,19 @@ public virtual IPagedList GetPagedList(ExpressionIgnore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual Task> GetPagedListAsync(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - CancellationToken cancellationToken = default(CancellationToken), - bool ignoreQueryFilters = false) + public virtual Task> GetPagedListAsync( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + CancellationToken cancellationToken = default, + bool ignoreQueryFilters = false) where TResult : class { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -369,14 +362,12 @@ public virtual Task> GetPagedListAsync(Expression - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method default no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -385,13 +376,14 @@ public virtual Task> GetPagedListAsync(ExpressionIgnore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual TEntity GetFirstOrDefault(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false) + public virtual TEntity GetFirstOrDefault( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -417,21 +409,20 @@ public virtual TEntity GetFirstOrDefault(Expression> predica { return orderBy(query).FirstOrDefault(); } - else - { - return query.FirstOrDefault(); - } + + return query.FirstOrDefault(); } /// - public virtual async Task GetFirstOrDefaultAsync(Expression> predicate = null, + public virtual async Task GetFirstOrDefaultAsync( + Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -457,14 +448,12 @@ public virtual async Task GetFirstOrDefaultAsync(Expression - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method default no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method default no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -474,14 +463,15 @@ public virtual async Task GetFirstOrDefaultAsync(ExpressionIgnore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual TResult GetFirstOrDefault(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false) + public virtual TResult GetFirstOrDefault( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -507,20 +497,19 @@ public virtual TResult GetFirstOrDefault(Expression - public virtual async Task GetFirstOrDefaultAsync(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, bool ignoreQueryFilters = false) + public virtual async Task GetFirstOrDefaultAsync( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -546,10 +535,8 @@ public virtual async Task GetFirstOrDefaultAsync(Expression @@ -558,21 +545,24 @@ public virtual async Task GetFirstOrDefaultAsync(ExpressionThe raw SQL. /// The parameters. /// An that contains elements that satisfy the condition specified by raw SQL. - public virtual IQueryable FromSql(string sql, params object[] parameters) => _dbSet.FromSqlRaw(sql, parameters); + public virtual IQueryable FromSql(string sql, params object[] parameters) + => DbSet.FromSqlRaw(sql, parameters); /// /// Finds an entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. /// /// The values of the primary key for the entity to be found. /// The found entity or null. - public virtual TEntity Find(params object[] keyValues) => _dbSet.Find(keyValues); + public virtual TEntity Find(params object[] keyValues) + => DbSet.Find(keyValues); /// /// Finds an entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. /// /// The values of the primary key for the entity to be found. /// A that represents the asynchronous insert operation. - public virtual ValueTask FindAsync(params object[] keyValues) => _dbSet.FindAsync(keyValues); + public virtual ValueTask FindAsync(params object[] keyValues) + => DbSet.FindAsync(keyValues); /// /// Finds an entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. @@ -580,240 +570,162 @@ public virtual async Task GetFirstOrDefaultAsync(ExpressionThe values of the primary key for the entity to be found. /// A to observe while waiting for the task to complete. /// A that represents the asynchronous find operation. The task result contains the found entity or null. - public virtual ValueTask FindAsync(object[] keyValues, CancellationToken cancellationToken) => _dbSet.FindAsync(keyValues, cancellationToken); + public virtual ValueTask FindAsync(object[] keyValues, CancellationToken cancellationToken) => DbSet.FindAsync(keyValues, cancellationToken); /// /// Gets the count based on a predicate. /// /// /// - public virtual int Count(Expression> predicate = null) - { - if (predicate == null) - { - return _dbSet.Count(); - } - else - { - return _dbSet.Count(predicate); - } - } + public virtual int Count(Expression> predicate = null) + => predicate == null ? DbSet.Count() : DbSet.Count(predicate); /// /// Gets async the count based on a predicate. /// /// /// - public virtual async Task CountAsync(Expression> predicate = null) - { - if (predicate == null) - { - return await _dbSet.CountAsync(); - } - else - { - return await _dbSet.CountAsync(predicate); - } - } + public virtual async Task CountAsync(Expression> predicate = null) + => predicate == null ? await DbSet.CountAsync() : await DbSet.CountAsync(predicate); /// /// Gets the long count based on a predicate. /// /// /// - public virtual long LongCount(Expression> predicate = null) - { - if (predicate == null) - { - return _dbSet.LongCount(); - } - else - { - return _dbSet.LongCount(predicate); - } - } + public virtual long LongCount(Expression> predicate = null) + => predicate == null ? DbSet.LongCount() : DbSet.LongCount(predicate); /// /// Gets async the long count based on a predicate. /// /// /// - public virtual async Task LongCountAsync(Expression> predicate = null) - { - if (predicate == null) - { - return await _dbSet.LongCountAsync(); - } - else - { - return await _dbSet.LongCountAsync(predicate); - } - } + public virtual async Task LongCountAsync(Expression> predicate = null) + => predicate == null ? await DbSet.LongCountAsync() : await DbSet.LongCountAsync(predicate); /// /// Gets the max based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual T Max(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return _dbSet.Max(selector); - else - return _dbSet.Where(predicate).Max(selector); - } + public virtual T Max(Expression> selector, Expression> predicate = null) + => predicate == null ? DbSet.Max(selector) : DbSet.Where(predicate).Max(selector); /// /// Gets the async max based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual async Task MaxAsync(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return await _dbSet.MaxAsync(selector); - else - return await _dbSet.Where(predicate).MaxAsync(selector); - } + public virtual async Task MaxAsync(Expression> selector, + Expression> predicate = null) + => predicate == null ? await DbSet.MaxAsync(selector) : await DbSet.Where(predicate).MaxAsync(selector); /// /// Gets the min based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual T Min(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return _dbSet.Min(selector); - else - return _dbSet.Where(predicate).Min(selector); - } + public virtual T Min(Expression> selector, Expression> predicate = null) + => predicate == null ? DbSet.Min(selector) : DbSet.Where(predicate).Min(selector); /// /// Gets the async min based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual async Task MinAsync(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return await _dbSet.MinAsync(selector); - else - return await _dbSet.Where(predicate).MinAsync(selector); - } + public virtual async Task MinAsync(Expression> selector, + Expression> predicate = null) + => predicate == null ? await DbSet.MinAsync(selector) : await DbSet.Where(predicate).MinAsync(selector); /// /// Gets the average based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual decimal Average(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return _dbSet.Average(selector); - else - return _dbSet.Where(predicate).Average(selector); - } + public virtual decimal Average(Expression> selector, + Expression> predicate = null) + => predicate == null ? DbSet.Average(selector) : DbSet.Where(predicate).Average(selector); /// /// Gets the async average based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual async Task AverageAsync(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return await _dbSet.AverageAsync(selector); - else - return await _dbSet.Where(predicate).AverageAsync(selector); - } + public virtual async Task AverageAsync(Expression> selector, + Expression> predicate = null) + => predicate == null ? await DbSet.AverageAsync(selector) : await DbSet.Where(predicate).AverageAsync(selector); /// /// Gets the sum based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual decimal Sum(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return _dbSet.Sum(selector); - else - return _dbSet.Where(predicate).Sum(selector); - } + public virtual decimal Sum(Expression> selector, + Expression> predicate = null) + => predicate == null ? DbSet.Sum(selector) : DbSet.Where(predicate).Sum(selector); /// /// Gets the async sum based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual async Task SumAsync(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return await _dbSet.SumAsync(selector); - else - return await _dbSet.Where(predicate).SumAsync(selector); - } + public virtual async Task SumAsync(Expression> selector, + Expression> predicate = null) + => predicate == null ? await DbSet.SumAsync(selector) : await DbSet.Where(predicate).SumAsync(selector); /// /// Gets the exists based on a predicate. /// /// /// - public bool Exists(Expression> selector = null) - { - if (selector == null) - { - return _dbSet.Any(); - } - else - { - return _dbSet.Any(selector); - } - } + public bool Exists(Expression> selector = null) + => selector == null ? DbSet.Any() : DbSet.Any(selector); + /// /// Gets the async exists based on a predicate. /// /// /// - public async Task ExistsAsync(Expression> selector = null) - { - if (selector == null) - { - return await _dbSet.AnyAsync(); - } - else - { - return await _dbSet.AnyAsync(selector); - } - } + public async Task ExistsAsync(Expression> selector = null) + => selector == null ? await DbSet.AnyAsync() : await DbSet.AnyAsync(selector); + /// /// Inserts a new entity synchronously. /// /// The entity to insert. - public virtual TEntity Insert(TEntity entity) - { - return _dbSet.Add(entity).Entity; - } + public virtual TEntity Insert(TEntity entity) + => DbSet.Add(entity).Entity; /// /// Inserts a range of entities synchronously. /// /// The entities to insert. - public virtual void Insert(params TEntity[] entities) => _dbSet.AddRange(entities); + public virtual void Insert(params TEntity[] entities) + => DbSet.AddRange(entities); /// /// Inserts a range of entities synchronously. /// /// The entities to insert. - public virtual void Insert(IEnumerable entities) => _dbSet.AddRange(entities); + public virtual void Insert(IEnumerable entities) + => DbSet.AddRange(entities); /// /// Inserts a new entity asynchronously. @@ -821,23 +733,21 @@ public virtual TEntity Insert(TEntity entity) /// The entity to insert. /// A to observe while waiting for the task to complete. /// A that represents the asynchronous insert operation. - public virtual ValueTask> InsertAsync(TEntity entity, CancellationToken cancellationToken = default(CancellationToken)) - { - return _dbSet.AddAsync(entity, cancellationToken); - - // Shadow properties? - //var property = _dbContext.Entry(entity).Property("Created"); - //if (property != null) { - //property.CurrentValue = DateTime.Now; - //} - } - + public virtual ValueTask> InsertAsync(TEntity entity, CancellationToken cancellationToken = default) + => DbSet.AddAsync(entity, cancellationToken); + + // Shadow properties? + //var property = _dbContext.Entry(entity).Property("Created"); + //if (property != null) { + //property.CurrentValue = DateTime.Now; + //} /// /// Inserts a range of entities asynchronously. /// /// The entities to insert. /// A that represents the asynchronous insert operation. - public virtual Task InsertAsync(params TEntity[] entities) => _dbSet.AddRangeAsync(entities); + public virtual Task InsertAsync(params TEntity[] entities) + => DbSet.AddRangeAsync(entities); /// /// Inserts a range of entities asynchronously. @@ -845,44 +755,43 @@ public virtual TEntity Insert(TEntity entity) /// The entities to insert. /// A to observe while waiting for the task to complete. /// A that represents the asynchronous insert operation. - public virtual Task InsertAsync(IEnumerable entities, CancellationToken cancellationToken = default(CancellationToken)) => _dbSet.AddRangeAsync(entities, cancellationToken); + public virtual Task InsertAsync(IEnumerable entities, CancellationToken cancellationToken = default) + => DbSet.AddRangeAsync(entities, cancellationToken); /// /// Updates the specified entity. /// /// The entity. - public virtual void Update(TEntity entity) - { - _dbSet.Update(entity); - } + public virtual void Update(TEntity entity) + => DbSet.Update(entity); /// /// Updates the specified entity. /// /// The entity. - public virtual void UpdateAsync(TEntity entity) - { - _dbSet.Update(entity); - - } + public virtual void UpdateAsync(TEntity entity) + => DbSet.Update(entity); /// /// Updates the specified entities. /// /// The entities. - public virtual void Update(params TEntity[] entities) => _dbSet.UpdateRange(entities); + public virtual void Update(params TEntity[] entities) + => DbSet.UpdateRange(entities); /// /// Updates the specified entities. /// /// The entities. - public virtual void Update(IEnumerable entities) => _dbSet.UpdateRange(entities); + public virtual void Update(IEnumerable entities) + => DbSet.UpdateRange(entities); /// /// Deletes the specified entity. /// /// The entity to delete. - public virtual void Delete(TEntity entity) => _dbSet.Remove(entity); + public virtual void Delete(TEntity entity) + => DbSet.Remove(entity); /// /// Deletes the entity by the specified primary key. @@ -892,17 +801,17 @@ public virtual void Delete(object id) { // using a stub entity to mark for deletion var typeInfo = typeof(TEntity).GetTypeInfo(); - var key = _dbContext.Model.FindEntityType(typeInfo).FindPrimaryKey().Properties.FirstOrDefault(); + var key = DbContext.Model.FindEntityType(typeInfo).FindPrimaryKey().Properties.FirstOrDefault(); var property = typeInfo.GetProperty(key?.Name); if (property != null) { var entity = Activator.CreateInstance(); property.SetValue(entity, id); - _dbContext.Entry(entity).State = EntityState.Deleted; + DbContext.Entry(entity).State = EntityState.Deleted; } else { - var entity = _dbSet.Find(id); + var entity = DbSet.Find(id); if (entity != null) { Delete(entity); @@ -914,22 +823,22 @@ public virtual void Delete(object id) /// Deletes the specified entities. /// /// The entities. - public virtual void Delete(params TEntity[] entities) => _dbSet.RemoveRange(entities); + public virtual void Delete(params TEntity[] entities) + => DbSet.RemoveRange(entities); /// /// Deletes the specified entities. /// /// The entities. - public virtual void Delete(IEnumerable entities) => _dbSet.RemoveRange(entities); + public virtual void Delete(IEnumerable entities) + => DbSet.RemoveRange(entities); /// /// Gets all entities. This method is not recommended /// /// The . - public async Task> GetAllAsync() - { - return await _dbSet.ToListAsync(); - } + public async Task> GetAllAsync() + => await DbSet.ToListAsync(); /// /// Gets all entities. This method is not recommended @@ -941,12 +850,13 @@ public async Task> GetAllAsync() /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// Ex: This method defaults to a read-only, no-tracking query. - public async Task> GetAllAsync(Expression> predicate = null, + public async Task> GetAllAsync( + Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -977,6 +887,714 @@ public async Task> GetAllAsync(Expression> pr return await query.ToListAsync(); } } + + private bool ExistsUpdateTimestamp(TEntity entity, out TEntity entityForUpdate) + { + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IKey key = entityType.FindPrimaryKey(); + object[] objArr = key.Properties.Select(q => entity.GetType().GetProperty(q.Name).GetValue(entity, null)).ToArray(); + TEntity obj = DbSet.Find(objArr); + if (obj != null && obj.GetType().GetProperty("Timestamp") != null) + { + entity.GetType().GetProperty("Timestamp").SetValue(entity, obj.GetType().GetProperty("Timestamp").GetValue(obj, null)); + DbContext.Entry(obj).State = EntityState.Detached; + } + entityForUpdate = entity; + + return obj != null; + + } + public virtual bool Exists(TEntity entity) + { + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IKey key = entityType.FindPrimaryKey(); + object[] objArr = key.Properties.Select(q => entity.GetType().GetProperty(q.Name).GetValue(entity, null)).ToArray(); + TEntity obj = DbSet.Find(objArr); + if (obj != null) DbContext.Entry(obj).State = EntityState.Detached; + return obj != null; + } + + public virtual void InsertOrUpdate(TEntity entity) + { + if (ExistsUpdateTimestamp(entity, out var entityForUpdate)) { + Update(entityForUpdate); + } + else { + Insert(entity); + } + } + + public virtual void InsertOrUpdate(IEnumerable entities) + => DbContext.BulkInsertOrUpdate(entities.ToList()); + + /// + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. + /// + /// A function to test each element for a condition. + /// A function to order elements. + /// A function to include navigation properties + /// True to disable changing tracking; otherwise, false. Default to true. + /// + /// An that contains elements that satisfy the condition specified by . + /// This method default no-tracking query. + public virtual List GetList( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false) + { + IQueryable query = DbSet; + if (disableTracking) + { + query = query.AsNoTracking(); + } + + if (include != null) + { + query = include(query); + } + + if (predicate != null) + { + query = query.Where(predicate); + } + + if (ignoreQueryFilters) + { + query = query.IgnoreQueryFilters(); + } + + if (orderBy != null) + { + return orderBy(query).ToList(); + } + + return query.ToList(); + } + + /// + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. + /// + /// A function to test each element for a condition. + /// A function to order elements. + /// A function to include navigation properties + /// True to disable changing tracking; otherwise, false. Default to true. + /// Ignore query filters + /// + /// A to observe while waiting for the task to complete. + /// + /// An that contains elements that satisfy the condition specified by . + /// This method default no-tracking query. + public virtual Task> GetListAsync( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false, + CancellationToken cancellationToken = default) + { + IQueryable query = DbSet; + if (disableTracking) + { + query = query.AsNoTracking(); + } + + if (include != null) + { + query = include(query); + } + + if (predicate != null) + { + query = query.Where(predicate); + } + + if (ignoreQueryFilters) + { + query = query.IgnoreQueryFilters(); + } + + if (orderBy != null) + { + return orderBy(query).ToListAsync(cancellationToken); + } + + return query.ToListAsync(cancellationToken); + } + + + /// + /// Finds next entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The values of the primary key for the entity to be found. + /// The found entity or null. + public virtual TEntity GetNextById(params object[] keyValues) + { + TEntity res = DbSet.Find(IncrementKey(keyValues)); + if (res != null) + { + return res; + } + + //No Result Found. So Order the Entity with key column and select next Entity + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IKey key = entityType.FindPrimaryKey(); + List keyColums = key.Properties.Select(q => q.Name).ToList(); + //var ordByExp = GetOrderBy(keyColums[0],"asc"); + var ordByExp = GetOrderByExpression(keyColums); + + List lstObjs = GetList(null, ordByExp.Compile()); + + if (lstObjs != null && lstObjs.Count > 0) + { + //Form Where Condition + Expression> expr = GetWhereConditionExpression(key, DecrementKey(keyValues)); + Func func = expr.Compile(); + Predicate pred = func.Invoke; + TEntity currObj = lstObjs.Find(pred); + + int curobj = lstObjs.IndexOf(currObj); + if (curobj != -1) + { + int nxt = curobj + 1; + return lstObjs.ElementAtOrDefault(nxt); + } + else + { + return null; + } + } + else + { + return null; + } + } + + /// + /// Finds next entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The values of the primary key for the entity to be found. + /// The found entity or null. + public virtual Task GetNextByIdAsync(params object[] keyValues) + { + TEntity res = DbSet.Find(IncrementKey(keyValues)); + if (res != null) + { + return Task.Factory.StartNew(() => res); + } + + //No Result Found. So Order the Entity with key column and select next Entity + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IKey key = entityType.FindPrimaryKey(); + List keyColums = key.Properties.Select(q => q.Name).ToList(); + //var ordByExp = GetOrderBy(keyColums[0],"asc"); + var ordByExp = GetOrderByExpression(keyColums); + + List lstObjs = GetList(null, ordByExp.Compile()); + + if (lstObjs != null && lstObjs.Count > 0) + { + //Form Where Condition + Expression> expr = GetWhereConditionExpression(key, DecrementKey(keyValues)); + Func func = expr.Compile(); + Predicate pred = func.Invoke; + TEntity currObj = lstObjs.Find(pred); + + int curobj = lstObjs.IndexOf(currObj); + if(curobj != -1) + { + int nxt = curobj + 1; + return Task.Factory.StartNew(() => lstObjs.ElementAtOrDefault(nxt)); + }else + { + return Task.FromResult(null); + } + }else + { + return Task.FromResult(null); + } + } + + /// + /// Finds previous entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The values of the primary key for the entity to be found. + /// The found entity or null. + public virtual TEntity GetPreviousById(params object[] keyValues) + { + TEntity res = DbSet.Find(DecrementKey(keyValues)); + if (res != null) + { + return res; + } + + //No Result Found. So Order the Entity with key column and select next Entity + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IKey key = entityType.FindPrimaryKey(); + List keyColums = key.Properties.Select(q => q.Name).ToList(); + //var ordByExp = GetOrderBy(keyColums[0],"asc"); + var ordByExp = GetOrderByExpression(keyColums); + + List lstObjs = GetList(null, ordByExp.Compile()); + + if (lstObjs != null && lstObjs.Count > 0) + { + //Form Where Condition + Expression> expr = GetWhereConditionExpression(key, IncrementKey(keyValues)); + Func func = expr.Compile(); + Predicate pred = func.Invoke; + TEntity currObj = lstObjs.Find(pred); + + int curobj = lstObjs.IndexOf(currObj); + if (curobj != -1) + { + int prev = curobj - 1; + return lstObjs.ElementAtOrDefault(prev); + } + else + { + return null; + } + } + else + { + return null; + } + } + + /// + /// Finds previous entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The values of the primary key for the entity to be found. + /// The found entity or null. + public virtual Task GetPreviousByIdAsync(params object[] keyValues) + { + TEntity res = DbSet.Find(DecrementKey(keyValues)); + if (res != null) + { + return Task.Factory.StartNew(() => res); + } + + //No Result Found. So Order the Entity with key column and select next Entity + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IKey key = entityType.FindPrimaryKey(); + List keyColums = key.Properties.Select(q => q.Name).ToList(); + //var ordByExp = GetOrderBy(keyColums[0],"asc"); + var ordByExp = GetOrderByExpression(keyColums); + + List lstObjs = GetList(null, ordByExp.Compile()); + + if (lstObjs != null && lstObjs.Count > 0) + { + //Form Where Condition + Expression> expr = GetWhereConditionExpression(key, IncrementKey(keyValues)); + Func func = expr.Compile(); + Predicate pred = func.Invoke; + TEntity currObj = lstObjs.Find(pred); + + int curobj = lstObjs.IndexOf(currObj); + if (curobj != -1) + { + int prev = curobj - 1; + return Task.Factory.StartNew(() => lstObjs.ElementAtOrDefault(prev)); + } + else + { + return Task.FromResult(null); + } + } + else + { + return Task.FromResult(null); + } + } + + /// + /// Finds the first entity with order by primary key. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The found entity or null. + public virtual TEntity GetFirst() + { + //No Result Found. So Order the Entity with key column and select next Entity + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IKey key = entityType.FindPrimaryKey(); + List keyColums = key.Properties.Select(q => q.Name).ToList(); + var ordByExp = GetOrderByExpression(keyColums); + + List lstObjs = GetList(null, ordByExp.Compile()); + + if (lstObjs != null && lstObjs.Count > 0) + { + return lstObjs.FirstOrDefault(); + } + else + { + return null; + } + } + + /// + /// Finds the first entity with order by primary key. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The found entity or null. + public virtual Task GetFirstAsync() + { + //No Result Found. So Order the Entity with key column and select next Entity + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IKey key = entityType.FindPrimaryKey(); + List keyColums = key.Properties.Select(q => q.Name).ToList(); + var ordByExp = GetOrderByExpression(keyColums); + + List lstObjs = GetList(null, ordByExp.Compile()); + + if (lstObjs != null && lstObjs.Count > 0) + { + return Task.Factory.StartNew(() => lstObjs.FirstOrDefault()); + } + else + { + return Task.FromResult(null); + } + } + + /// + /// Finds the Last entity with order by primary key. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The found entity or null. + public virtual TEntity GetLast() + { + //No Result Found. So Order the Entity with key column and select next Entity + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IKey key = entityType.FindPrimaryKey(); + List keyColums = key.Properties.Select(q => q.Name).ToList(); + var ordByExp = GetOrderByExpression(keyColums,true); + + List lstObjs = GetList(null, ordByExp.Compile()); + + if (lstObjs != null && lstObjs.Count > 0) + { + return lstObjs.FirstOrDefault(); + } + else + { + return null; + } + } + + /// + /// Finds the Last entity with order by primary key. If found, is attached to the context and returned. If no entity is found, then null is returned. + /// + /// The found entity or null. + public virtual Task GetLastAsync() + { + //No Result Found. So Order the Entity with key column and select next Entity + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IKey key = entityType.FindPrimaryKey(); + List keyColums = key.Properties.Select(q => q.Name).ToList(); + var ordByExp = GetOrderByExpression(keyColums,true); + + List lstObjs = GetList(null, ordByExp.Compile()); + + if (lstObjs != null && lstObjs.Count > 0) + { + return Task.Factory.StartNew(() => lstObjs.FirstOrDefault()); + } + else + { + return Task.FromResult(null); + } + } + + #region Next, Previous Support Methos + + private object[] IncrementKey(object[] id) + { + int idx = id.Length -1; + var val = id[idx]; + + if(val.GetType() == typeof(int)) + { + int iVal = (int)val; + id[idx] = ++iVal; + }else if(val.GetType() == typeof(long)) + { + long lVal = (long)val; + id[idx] = ++lVal; + } + + return id; + } + + private object[] DecrementKey(object[] id) + { + int idx = id.Length - 1; + var val = id[idx]; + + if (val.GetType() == typeof(int)) + { + int iVal = (int)val; + id[idx] = --iVal; + } + else if (val.GetType() == typeof(long)) + { + long lVal = (long)val; + id[idx] = --lVal; + } + return id; + } + + private static MemberExpression GetMemberExpression(Expression param, string propertyName) + { + if (propertyName.Contains(".")) + { + int index = propertyName.IndexOf("."); + var subParam = Expression.Property(param, propertyName.Substring(0, index)); + return GetMemberExpression(subParam, propertyName.Substring(index + 1)); + } + return Expression.Property(param, propertyName); + } + + public static Func, IOrderedQueryable> GetOrderBy(string orderColumn, string orderType) + { + Type typeQueryable = typeof(IQueryable); + ParameterExpression argQueryable = Expression.Parameter(typeQueryable, "p"); + var outerExpression = Expression.Lambda(argQueryable, argQueryable); + string[] props = orderColumn.Split('.'); + IQueryable query = new List().AsQueryable(); + Type type = typeof(T); + ParameterExpression arg = Expression.Parameter(type, "x"); + + Expression expr = arg; + foreach (string prop in props) + { + PropertyInfo pi = type.GetProperty(prop, BindingFlags.IgnoreCase | BindingFlags.Public | BindingFlags.Instance); + expr = Expression.Property(expr, pi); + type = pi.PropertyType; + } + LambdaExpression lambda = Expression.Lambda(expr, arg); + string methodName = orderType == "asc" ? "OrderBy" : "OrderByDescending"; + + MethodCallExpression resultExp = + Expression.Call(typeof(Queryable), methodName, new Type[] { typeof(T), type }, outerExpression.Body, Expression.Quote(lambda)); + var finalLambda = Expression.Lambda(resultExp, argQueryable); + return (Func, IOrderedQueryable>)finalLambda.Compile(); + } + + public static Expression, IOrderedQueryable>> GetOrderByExpression(IEnumerable lstSelection, bool isDescending = false) + { + bool isThenBy = false; + ParameterExpression inParameter = Expression.Parameter(typeof(T), "s"); + foreach (string propName in lstSelection) + { + MemberExpression prop = GetMemberExpression(inParameter, propName); //s.mfrId + var propertyInfo = (PropertyInfo)prop.Member; + var lambda = Expression.Lambda(prop, inParameter); // s => s.mfrId + Type pType = propertyInfo.PropertyType; + Type[] argumentTypes = new[] { typeof(T),pType }; + if (isThenBy) + { + var thenByMethod = typeof(Queryable).GetMethods() + .First(method => method.Name == "ThenBy" + && method.GetParameters().Count() == 2) + .MakeGenericMethod(argumentTypes); + + var ThenByDescending = typeof(Queryable).GetMethods() + .First(method => method.Name == "ThenByDescending" + && method.GetParameters().Count() == 2) + .MakeGenericMethod(argumentTypes); + if (isDescending) + { + return query => (IOrderedQueryable) + ThenByDescending.Invoke(null, new object[] { query, lambda }); + } + else + { + return query => (IOrderedQueryable) + thenByMethod.Invoke(null, new object[] { query, lambda }); + } + } + else + { + isThenBy = true; + var orderByMethod = typeof(Queryable).GetMethods() + .First(method => method.Name == "OrderBy" + && method.GetParameters().Count() == 2) + .MakeGenericMethod(argumentTypes); + + var orderByDescMethod = typeof(Queryable).GetMethods() + .First(method => method.Name == "OrderByDescending" + && method.GetParameters().Count() == 2) + .MakeGenericMethod(argumentTypes); + + if (isDescending) { + return query => (IOrderedQueryable) + orderByDescMethod.Invoke(null, new object[] { query, lambda }); + } + else { + return query => (IOrderedQueryable) + orderByMethod.Invoke(null, new object[] { query, lambda }); + } + } + } + return null; + } + + static MethodInfo LikeMethod = typeof(DbFunctionsExtensions).GetMethod("Like", new Type[] { typeof(DbFunctions), typeof(string), typeof(string) }); + + static MethodInfo StartsWithMethod = typeof(String).GetMethod("StartsWith", new Type[] { typeof(String) }); + + static MethodInfo ContainsMethod = typeof(String).GetMethod("Contains", new Type[] { typeof(String) }); + + static MethodInfo EndsWithMethod = typeof(String).GetMethod("EndsWith", new Type[] { typeof(String) }); + + public static Expression> GetWhereConditionExpression(IKey key, params object[] keyValues) + { + if (key == null || keyValues == null) return null; + //Create the expression parameters + ParameterExpression inParameter = Expression.Parameter(typeof(T)); + + Expression whereExp = null; + int idx = 0; + foreach (IProperty p in key.Properties) + { + string dataPropertyName = p.Name; + Type dataType = p.PropertyInfo.PropertyType; + Object propVale = keyValues[idx++]; + + if (propVale != null) + { + Expression LHS = GetMemberExpression(inParameter, dataPropertyName); //Expression.Property(inParameter, dataPropertyName); + Expression RHS = Expression.Convert(Expression.Constant(propVale), dataType); + + if (dataType == typeof(DateTime)) + { + //MethodInfo truncTimeMethod = typeof(EF).GetProperty("Functions").GetType().GetMethod("TruncateTime", new Type[] { typeof(DateTime?) }); + //MethodInfo conMethod = typeof(System.Data.Entity.DbFunctions).GetMethod("TruncateTime", new Type[] { typeof(DateTime?) }); + //LHS = Expression.Call(conMethod, Expression.Convert(Expression.Property(inParameter, dataPropertyName), typeof(DateTime?))); + RHS = Expression.Convert(Expression.Constant(((DateTime)propVale).Date), typeof(DateTime?)); + } + string conOperator = "eq"; + Expression expr = null; + switch (conOperator) + { + case "<": + case "lt": + expr = Expression.LessThan(LHS, RHS); + break; + + case ">": + case "gt": + expr = Expression.GreaterThan(LHS, RHS); + break; + + case "<=": + case "le": + expr = Expression.LessThanOrEqual(LHS, RHS); + break; + + case ">=": + case "ge": + expr = Expression.GreaterThanOrEqual(LHS, RHS); + break; + + case "!=": + case "<>": + case "ne": + expr = Expression.NotEqual(LHS, RHS); + break; + + case "IsNull": + expr = Expression.Equal(LHS, Expression.Constant(null, dataType)); + break; + + case "IsNotNull": + expr = Expression.NotEqual(LHS, Expression.Constant(null, dataType)); + break; + + case "Like": + if (LHS.Type != typeof(string)) + { + LHS = Expression.Convert(Expression.Convert(LHS, typeof(object)), typeof(string)); + } + RHS = Expression.Convert(Expression.Constant(propVale.ToString().Replace(" ", "%") + "%"), dataType); + expr = Expression.Call(LikeMethod, Expression.Convert(Expression.Constant(EF.Functions), typeof(DbFunctions)), LHS, RHS); + break; + + case "Contains": + if (LHS.Type != typeof(string)) + { + LHS = Expression.Convert(Expression.Convert(LHS, typeof(object)), typeof(string)); + RHS = Expression.Convert(Expression.Constant("%" + propVale + "%"), dataType); + expr = Expression.Call(LikeMethod, Expression.Convert(Expression.Constant(EF.Functions), typeof(DbFunctions)), LHS, RHS); + } + else + { + expr = Expression.Call(LHS, ContainsMethod, Expression.Constant(propVale.ToString())); + } + break; + + case "StartsWith": + if (LHS.Type != typeof(string)) + { + LHS = Expression.Convert(Expression.Convert(LHS, typeof(object)), typeof(string)); + RHS = Expression.Convert(Expression.Constant(propVale + "%"), dataType); + expr = Expression.Call(LikeMethod, Expression.Convert(Expression.Constant(EF.Functions), typeof(DbFunctions)), LHS, RHS); + } + else + { + expr = Expression.Call(LHS, StartsWithMethod, Expression.Constant(propVale.ToString())); + } + break; + + case "EndsWith": + if (LHS.Type != typeof(string)) + { + LHS = Expression.Convert(Expression.Convert(LHS, typeof(object)), typeof(string)); + RHS = Expression.Convert(Expression.Constant("%" + propVale), dataType); + expr = Expression.Call(LikeMethod, Expression.Convert(Expression.Constant(EF.Functions), typeof(DbFunctions)), LHS, RHS); + } + else + { + expr = Expression.Call(LHS, EndsWithMethod, Expression.Constant(propVale.ToString())); + } + break; + + case "=": + case "==": + case "eq": + default: + expr = Expression.Equal(LHS, RHS); + break; + } + + if (whereExp == null) + { + whereExp = expr; + } + else + { + String condi = "AND"; + if (condi != null && condi.Equals("AND", StringComparison.OrdinalIgnoreCase)) + { + whereExp = Expression.AndAlso(whereExp, expr); + } + else + { + whereExp = Expression.OrElse(whereExp, expr); + } + } + } + + } + + if (whereExp == null) + { + whereExp = Expression.Constant(true); + } + return Expression.Lambda>(whereExp, inParameter); + } + #endregion /// /// Gets all entities. This method is not recommended @@ -995,7 +1613,7 @@ public async Task> GetAllAsync(Expression, IIncludableQueryable> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -1028,14 +1646,15 @@ public async Task> GetAllAsync(Expression /// Change entity state for patch method on web api. /// /// The entity. /// /// The entity state. - public void ChangeEntityState(TEntity entity, EntityState state) - { - _dbContext.Entry(entity).State = state; - } + public void ChangeEntityState(TEntity entity, EntityState state) => DbContext.Entry(entity).State = state; + + ValueTask IRepository.FindAsync(params object[] keyValues) => DbSet.FindAsync(keyValues); + ValueTask IRepository.FindAsync(object[] keyValues, CancellationToken cancellationToken) => DbSet.FindAsync(keyValues, cancellationToken); } } diff --git a/src/UnitOfWork/UnitOfWork.cs b/src/UnitOfWork/UnitOfWork.cs index 4a1758d..87da947 100644 --- a/src/UnitOfWork/UnitOfWork.cs +++ b/src/UnitOfWork/UnitOfWork.cs @@ -7,10 +7,12 @@ using System.Text.RegularExpressions; using System.Threading.Tasks; using System.Transactions; +using Microsoft.Data.SqlClient; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.ChangeTracking; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Storage; namespace Arch.EntityFrameworkCore.UnitOfWork { @@ -18,26 +20,22 @@ namespace Arch.EntityFrameworkCore.UnitOfWork /// Represents the default implementation of the and interface. /// /// The type of the db context. - public class UnitOfWork : IRepositoryFactory, IUnitOfWork, IUnitOfWork where TContext : DbContext + public class UnitOfWork : IRepositoryFactory, IUnitOfWork where TContext : DbContext { - private readonly TContext _context; - private bool disposed = false; - private Dictionary repositories; + private bool _disposed; + private Dictionary _repositories; /// /// Initializes a new instance of the class. /// /// The context. - public UnitOfWork(TContext context) - { - _context = context ?? throw new ArgumentNullException(nameof(context)); - } + public UnitOfWork(TContext context) => DbContext = context ?? throw new ArgumentNullException(nameof(context)); /// /// Gets the db context. /// /// The instance of type . - public TContext DbContext => _context; + public TContext DbContext { get; } /// /// Changes the database name. This require the databases in the same machine. NOTE: This only work for MySQL right now. @@ -48,7 +46,7 @@ public UnitOfWork(TContext context) /// public void ChangeDatabase(string database) { - var connection = _context.Database.GetDbConnection(); + var connection = DbContext.Database.GetDbConnection(); if (connection.State.HasFlag(ConnectionState.Open)) { connection.ChangeDatabase(database); @@ -60,7 +58,7 @@ public void ChangeDatabase(string database) } // Following code only working for mysql. - var items = _context.Model.GetEntityTypes(); + var items = DbContext.Model.GetEntityTypes(); foreach (var item in items) { if (item is IConventionEntityType entityType) @@ -73,20 +71,20 @@ public void ChangeDatabase(string database) /// /// Gets the specified repository for the . /// - /// True if providing custom repositry + /// True if providing custom repository /// The type of the entity. /// An instance of type inherited from interface. public IRepository GetRepository(bool hasCustomRepository = false) where TEntity : class { - if (repositories == null) + if (_repositories == null) { - repositories = new Dictionary(); + _repositories = new Dictionary(); } - // what's the best way to support custom reposity? + // what's the best way to support custom repository? if (hasCustomRepository) { - var customRepo = _context.GetService>(); + var customRepo = DbContext.GetService>(); if (customRepo != null) { return customRepo; @@ -94,12 +92,12 @@ public IRepository GetRepository(bool hasCustomRepository = fa } var type = typeof(TEntity); - if (!repositories.ContainsKey(type)) + if (!_repositories.ContainsKey(type)) { - repositories[type] = new Repository(_context); + _repositories[type] = new Repository(DbContext); } - return (IRepository)repositories[type]; + return (IRepository)_repositories[type]; } /// @@ -108,7 +106,41 @@ public IRepository GetRepository(bool hasCustomRepository = fa /// The raw SQL. /// The parameters. /// The number of state entities written to database. - public int ExecuteSqlCommand(string sql, params object[] parameters) => _context.Database.ExecuteSqlRaw(sql, parameters); + public int ExecuteSqlCommand(string sql, params object[] parameters) => DbContext.Database.ExecuteSqlRaw(sql, parameters); + + /// + /// Executes the specified raw SQL command. + /// + /// The raw SQL. + /// The parameters. + /// The DataTable. + public DataTable ExecuteDtSqlCommand(string sql, params object[] parameters) + { + SqlConnection conn = (SqlConnection) DbContext.Database.GetDbConnection(); + SqlCommand cmd = new SqlCommand(sql, conn); + cmd.CommandTimeout = 0; + + if(parameters != null && parameters.Count() > 0) + { + foreach(object obj in parameters) + { + cmd.Parameters.Add(obj); + } + } + + conn.Open(); + // create data adapter + SqlDataAdapter da = new SqlDataAdapter(cmd); + // this will query your database and return the result to your datatable + DataTable dataTable = new DataTable(); + DataSet ds = new DataSet(); + da.Fill(ds); + da.Fill(dataTable); + da.Dispose(); + conn.Close(); + return dataTable; + } + /// /// Uses raw SQL queries to fetch the specified data. @@ -117,7 +149,14 @@ public IRepository GetRepository(bool hasCustomRepository = fa /// The raw SQL. /// The parameters. /// An that contains elements that satisfy the condition specified by raw SQL. - public IQueryable FromSql(string sql, params object[] parameters) where TEntity : class => _context.Set().FromSqlRaw(sql, parameters); + public IQueryable FromSql(string sql, params object[] parameters) where TEntity : class => DbContext.Set().FromSqlRaw(sql, parameters); + + /// + /// Starts Databaselevel Transaction + /// + /// The IsolationLevel + /// Transaction Context + public IDbContextTransaction BeginTransaction(System.Data.IsolationLevel isolation = System.Data.IsolationLevel.ReadCommitted) => DbContext.Database.BeginTransaction(isolation); /// /// Saves all changes made in this context to the database. @@ -128,10 +167,10 @@ public int SaveChanges(bool ensureAutoHistory = false) { if (ensureAutoHistory) { - _context.EnsureAutoHistory(); + DbContext.EnsureAutoHistory(); } - return _context.SaveChanges(); + return DbContext.SaveChanges(); } /// @@ -143,10 +182,10 @@ public async Task SaveChangesAsync(bool ensureAutoHistory = false) { if (ensureAutoHistory) { - _context.EnsureAutoHistory(); + DbContext.EnsureAutoHistory(); } - return await _context.SaveChangesAsync(); + return await DbContext.SaveChangesAsync(); } /// @@ -157,21 +196,48 @@ public async Task SaveChangesAsync(bool ensureAutoHistory = false) /// A that represents the asynchronous save operation. The task result contains the number of state entities written to database. public async Task SaveChangesAsync(bool ensureAutoHistory = false, params IUnitOfWork[] unitOfWorks) { - using (var ts = new TransactionScope(TransactionScopeAsyncFlowOption.Enabled)) + using var ts = new TransactionScope(TransactionScopeAsyncFlowOption.Enabled); + var count = 0; + foreach (var unitOfWork in unitOfWorks) { - var count = 0; - foreach (var unitOfWork in unitOfWorks) - { - count += await unitOfWork.SaveChangesAsync(ensureAutoHistory).ConfigureAwait(false); - } + count += await unitOfWork.SaveChangesAsync(ensureAutoHistory).ConfigureAwait(false); + } - count += await SaveChangesAsync(ensureAutoHistory); + count += await SaveChangesAsync(ensureAutoHistory); - ts.Complete(); + ts.Complete(); - return count; + return count; + } + + + /// + /// Saves all changes made in this context to the database with distributed transaction. + /// + /// The transaction to use + /// True if save changes ensure auto record the change history. + /// An optional array. + /// A that represents the asynchronous save operation. The task result contains the number of state entities written to database. + public async Task SaveChangesAsync(Transaction transaction, bool ensureAutoHistory = false, params IUnitOfWork[] unitOfWorks) + { + using var ts = new TransactionScope(transaction); + var count = 0; + foreach (var unitOfWork in unitOfWorks) + { + count += await unitOfWork.SaveChangesAsync(ensureAutoHistory); } + + count += await SaveChangesAsync(ensureAutoHistory); + + ts.Complete(); + + return count; } + + public void TrackGraph(object rootEntity, Action callback) => DbContext.ChangeTracker.TrackGraph(rootEntity, callback); + + IDbContextTransaction IUnitOfWork.BeginTransaction(System.Data.IsolationLevel isolation) => DbContext.Database.BeginTransaction(isolation); + /// /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. @@ -189,27 +255,50 @@ public void Dispose() /// The disposing. protected virtual void Dispose(bool disposing) { - if (!disposed) + if (!_disposed) { if (disposing) { // clear repositories - if (repositories != null) - { - repositories.Clear(); - } + _repositories?.Clear(); // dispose the db context. - _context.Dispose(); + DbContext.Dispose(); } } - disposed = true; + _disposed = true; } + + /// + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + /// + public async ValueTask DisposeAsync() + { + await DisposeAsync(true); - public void TrackGraph(object rootEntity, Action callback) + GC.SuppressFinalize(this); + } + + /// + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + /// + /// The disposing. + protected virtual async ValueTask DisposeAsync(bool disposing) { - _context.ChangeTracker.TrackGraph(rootEntity, callback); + if (!_disposed) + { + if (disposing) + { + // clear repositories + _repositories?.Clear(); + + // dispose the db context. + await DbContext.DisposeAsync(); + } + } + + _disposed = true; } } } diff --git a/src/UnitOfWork/UnitOfWork.csproj b/src/UnitOfWork/UnitOfWork.csproj index fdf335f..d5880ec 100644 --- a/src/UnitOfWork/UnitOfWork.csproj +++ b/src/UnitOfWork/UnitOfWork.csproj @@ -3,10 +3,11 @@ A plugin for Microsoft.EntityFrameworkCore to support repository, unit of work patterns, and multiple database with distributed transaction supported. 3.1.0 rigofunc;rigofunc@outlook.com; - netstandard2.0 + net6.0 $(NoWarn);CS1591 true true + Arch.EntityFrameworkCore.UnitOfWork Microsoft.EntityFrameworkCore.UnitOfWork Microsoft.EntityFrameworkCore.UnitOfWork Entity Framework Core;entity-framework-core;EF;Data;O/RM;unitofwork;Unit Of Work;unit-of-work @@ -19,7 +20,14 @@ snupkg - - + + + + + + + + C:\Program Files\dotnet\sdk\NuGetFallbackFolder\system.data.sqlclient\4.4.0\ref\netstandard2.0\System.Data.SqlClient.dll + diff --git a/src/UnitOfWork/UnitOfWorkServiceCollectionExtensions.cs b/src/UnitOfWork/UnitOfWorkServiceCollectionExtensions.cs index 433ea00..6335145 100644 --- a/src/UnitOfWork/UnitOfWorkServiceCollectionExtensions.cs +++ b/src/UnitOfWork/UnitOfWorkServiceCollectionExtensions.cs @@ -22,7 +22,7 @@ public static class UnitOfWorkServiceCollectionExtensions public static IServiceCollection AddUnitOfWork(this IServiceCollection services) where TContext : DbContext { services.AddScoped>(); - // Following has a issue: IUnitOfWork cannot support multiple dbcontext/database, + // Following has a issue: IUnitOfWork cannot support multiple DbContext/Database, // that means cannot call AddUnitOfWork multiple times. // Solution: check IUnitOfWork whether or null services.AddScoped>(); @@ -104,7 +104,7 @@ public static IServiceCollection AddUnitOfWork. /// /// The type of the entity. - /// The type of the custom repositry. + /// The type of the custom repository. /// The to add services to. /// The same service collection so that multiple calls can be chained. public static IServiceCollection AddCustomRepository(this IServiceCollection services) diff --git a/test/UnitOfWork.Tests/Entities/City.cs b/test/UnitOfWork.Tests/Entities/City.cs index f632058..10539d2 100644 --- a/test/UnitOfWork.Tests/Entities/City.cs +++ b/test/UnitOfWork.Tests/Entities/City.cs @@ -3,7 +3,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests.Entities { - public class City + public record City { public int Id { get; set; } public string Name { get; set; } diff --git a/test/UnitOfWork.Tests/Entities/Country.cs b/test/UnitOfWork.Tests/Entities/Country.cs index d4d03d8..f23d2b0 100644 --- a/test/UnitOfWork.Tests/Entities/Country.cs +++ b/test/UnitOfWork.Tests/Entities/Country.cs @@ -2,7 +2,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests.Entities { - public class Country + public record Country { public int Id { get; set; } public string Name { get; set; } diff --git a/test/UnitOfWork.Tests/Entities/Customer.cs b/test/UnitOfWork.Tests/Entities/Customer.cs index d2188d5..094881a 100644 --- a/test/UnitOfWork.Tests/Entities/Customer.cs +++ b/test/UnitOfWork.Tests/Entities/Customer.cs @@ -1,6 +1,6 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests.Entities { - public class Customer + public record Customer { public int Id { get; set; } public string Name { get; set; } diff --git a/test/UnitOfWork.Tests/Entities/Town.cs b/test/UnitOfWork.Tests/Entities/Town.cs index aca070a..16ee910 100644 --- a/test/UnitOfWork.Tests/Entities/Town.cs +++ b/test/UnitOfWork.Tests/Entities/Town.cs @@ -1,6 +1,6 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests.Entities { - public class Town + public record Town { public int Id { get; set; } public string Name { get; set; } diff --git a/test/UnitOfWork.Tests/IQueryablePageListExtensionsTests.cs b/test/UnitOfWork.Tests/IQueryablePageListExtensionsTests.cs index 5554b51..ecebdec 100644 --- a/test/UnitOfWork.Tests/IQueryablePageListExtensionsTests.cs +++ b/test/UnitOfWork.Tests/IQueryablePageListExtensionsTests.cs @@ -13,40 +13,35 @@ public class IQueryablePageListExtensionsTests [Fact] public async Task ToPagedListAsyncTest() { - using (var db = new InMemoryContext()) - { - var testItems = TestItems(); - await db.AddRangeAsync(testItems); - db.SaveChanges(); + await using var db = new InMemoryContext(); + var testItems = TestItems(); + await db.AddRangeAsync(testItems); + await db.SaveChangesAsync(); - var items = db.Customers.Where(t => t.Age > 1); + var items = db.Customers.Where(t => t.Age > 1); - var page = await items.ToPagedListAsync(1, 2); - Assert.NotNull(page); + var page = await items.ToPagedListAsync(1, 2); + Assert.NotNull(page); - Assert.Equal(4, page.TotalCount); - Assert.Equal(2, page.Items.Count); - Assert.Equal("E", page.Items[0].Name); + Assert.Equal(4, page.TotalCount); + Assert.Equal(2, page.Items.Count); + Assert.Equal("E", page.Items[0].Name); - page = await items.ToPagedListAsync(0, 2); - Assert.NotNull(page); - Assert.Equal(4, page.TotalCount); - Assert.Equal(2, page.Items.Count); - Assert.Equal("C", page.Items[0].Name); - } + page = await items.ToPagedListAsync(0, 2); + Assert.NotNull(page); + Assert.Equal(4, page.TotalCount); + Assert.Equal(2, page.Items.Count); + Assert.Equal("C", page.Items[0].Name); } - public List TestItems() - { - return new List() + private static IEnumerable TestItems() => new List() { - new Customer(){Name="A", Age=1}, - new Customer(){Name="B", Age=1}, - new Customer(){Name="C", Age=2}, - new Customer(){Name="D", Age=3}, - new Customer(){Name="E", Age=4}, - new Customer(){Name="F", Age=5}, + new(){Name="A", Age=1}, + new(){Name="B", Age=1}, + new(){Name="C", Age=2}, + new(){Name="D", Age=3}, + new(){Name="E", Age=4}, + new(){Name="F", Age=5}, }; - } } } diff --git a/test/UnitOfWork.Tests/IRepositoryGetPagedListTest.cs b/test/UnitOfWork.Tests/IRepositoryGetPagedListTest.cs index 504a3d2..e95ee93 100644 --- a/test/UnitOfWork.Tests/IRepositoryGetPagedListTest.cs +++ b/test/UnitOfWork.Tests/IRepositoryGetPagedListTest.cs @@ -8,23 +8,23 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests { public class IRepositoryGetPagedListTest { - private static readonly InMemoryContext db; + private static readonly InMemoryContext Db; static IRepositoryGetPagedListTest() { - db = new InMemoryContext(); + Db = new InMemoryContext(); - db.AddRange(TestCountries); - db.AddRange(TestCities); - db.AddRange(TestTowns); + Db.AddRange(TestCountries); + Db.AddRange(TestCities); + Db.AddRange(TestTowns); - db.SaveChanges(); + Db.SaveChanges(); } [Fact] public void GetPagedList() { - var repository = new Repository(db); + var repository = new Repository(Db); var page = repository.GetPagedList(predicate: t => t.Name == "C", include: source => source.Include(t => t.Country), pageSize: 1); @@ -39,7 +39,7 @@ public void GetPagedList() [Fact] public async Task GetPagedListAsync() { - var repository = new Repository(db); + var repository = new Repository(Db); var page = await repository.GetPagedListAsync(predicate: t => t.Name == "C", include: source => source.Include(t => t.Country), pageSize: 1); @@ -54,7 +54,7 @@ public async Task GetPagedListAsync() [Fact] public async Task GetPagedListWithIncludingMultipleLevelsAsync() { - var repository = new Repository(db); + var repository = new Repository(Db); var page = await repository.GetPagedListAsync(predicate: t => t.Name == "A", include: country => country.Include(c => c.Cities).ThenInclude(city => city.Towns), pageSize: 1); @@ -67,7 +67,7 @@ public async Task GetPagedListWithIncludingMultipleLevelsAsync() [Fact] public void GetPagedListWithoutInclude() { - var repository = new Repository(db); + var repository = new Repository(Db); var page = repository.GetPagedList(pageIndex: 0, pageSize: 1); @@ -75,30 +75,30 @@ public void GetPagedListWithoutInclude() Assert.Null(page.Items[0].Country); } - protected static List TestCountries => new List + private static IEnumerable TestCountries => new List { - new Country {Id = 1, Name = "A"}, - new Country {Id = 2, Name = "B"} + new() {Id = 1, Name = "A"}, + new() {Id = 2, Name = "B"} }; - public static List TestCities => new List + private static IEnumerable TestCities => new List { - new City { Id = 1, Name = "A", CountryId = 1}, - new City { Id = 2, Name = "B", CountryId = 2}, - new City { Id = 3, Name = "C", CountryId = 1}, - new City { Id = 4, Name = "D", CountryId = 2}, - new City { Id = 5, Name = "E", CountryId = 1}, - new City { Id = 6, Name = "F", CountryId = 2}, + new() { Id = 1, Name = "A", CountryId = 1}, + new() { Id = 2, Name = "B", CountryId = 2}, + new() { Id = 3, Name = "C", CountryId = 1}, + new() { Id = 4, Name = "D", CountryId = 2}, + new() { Id = 5, Name = "E", CountryId = 1}, + new() { Id = 6, Name = "F", CountryId = 2}, }; - public static List TestTowns => new List + private static IEnumerable TestTowns => new List { - new Town { Id = 1, Name="A", CityId = 1 }, - new Town { Id = 2, Name="B", CityId = 2 }, - new Town { Id = 3, Name="C", CityId = 3 }, - new Town { Id = 4, Name="D", CityId = 4 }, - new Town { Id = 5, Name="E", CityId = 5 }, - new Town { Id = 6, Name="F", CityId = 6 }, + new() { Id = 1, Name="A", CityId = 1 }, + new() { Id = 2, Name="B", CityId = 2 }, + new() { Id = 3, Name="C", CityId = 3 }, + new() { Id = 4, Name="D", CityId = 4 }, + new() { Id = 5, Name="E", CityId = 5 }, + new() { Id = 6, Name="F", CityId = 6 }, }; } } diff --git a/test/UnitOfWork.Tests/InMemoryContext.cs b/test/UnitOfWork.Tests/InMemoryContext.cs index cf60438..e0c4099 100644 --- a/test/UnitOfWork.Tests/InMemoryContext.cs +++ b/test/UnitOfWork.Tests/InMemoryContext.cs @@ -5,12 +5,9 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests { public class InMemoryContext : DbContext { - public DbSet Countries { get; set; } - public DbSet Customers { get; set; } + public DbSet Countries => Set(); + public DbSet Customers => Set(); - protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) - { - optionsBuilder.UseInMemoryDatabase("test"); - } + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) => optionsBuilder.UseInMemoryDatabase("test"); } } diff --git a/test/UnitOfWork.Tests/TestGetFirstOrDefaultAsync.cs b/test/UnitOfWork.Tests/TestGetFirstOrDefaultAsync.cs index ded0b57..a178593 100644 --- a/test/UnitOfWork.Tests/TestGetFirstOrDefaultAsync.cs +++ b/test/UnitOfWork.Tests/TestGetFirstOrDefaultAsync.cs @@ -52,30 +52,30 @@ public async void TestGetFirstOrDefaultAsyncCanInclude() } - protected static List TestCountries => new List + private static IEnumerable TestCountries => new List { - new Country {Id = 1, Name = "A"}, - new Country {Id = 2, Name = "B"} + new() {Id = 1, Name = "A"}, + new() {Id = 2, Name = "B"} }; - public static List TestCities => new List + private static IEnumerable TestCities => new List { - new City { Id = 1, Name = "A", CountryId = 1}, - new City { Id = 2, Name = "B", CountryId = 2}, - new City { Id = 3, Name = "C", CountryId = 1}, - new City { Id = 4, Name = "D", CountryId = 2}, - new City { Id = 5, Name = "E", CountryId = 1}, - new City { Id = 6, Name = "F", CountryId = 2}, + new() { Id = 1, Name = "A", CountryId = 1}, + new() { Id = 2, Name = "B", CountryId = 2}, + new() { Id = 3, Name = "C", CountryId = 1}, + new() { Id = 4, Name = "D", CountryId = 2}, + new() { Id = 5, Name = "E", CountryId = 1}, + new() { Id = 6, Name = "F", CountryId = 2}, }; - public static List TestTowns => new List + private static IEnumerable TestTowns => new List { - new Town { Id = 1, Name="TownA", CityId = 1 }, - new Town { Id = 2, Name="TownB", CityId = 2 }, - new Town { Id = 3, Name="TownC", CityId = 3 }, - new Town { Id = 4, Name="TownD", CityId = 4 }, - new Town { Id = 5, Name="TownE", CityId = 5 }, - new Town { Id = 6, Name="TownF", CityId = 6 }, + new() { Id = 1, Name="TownA", CityId = 1 }, + new() { Id = 2, Name="TownB", CityId = 2 }, + new() { Id = 3, Name="TownC", CityId = 3 }, + new() { Id = 4, Name="TownD", CityId = 4 }, + new() { Id = 5, Name="TownE", CityId = 5 }, + new() { Id = 6, Name="TownF", CityId = 6 }, }; } } diff --git a/test/UnitOfWork.Tests/UnitOfWork.Tests.csproj b/test/UnitOfWork.Tests/UnitOfWork.Tests.csproj index 1597ff6..72d5c91 100644 --- a/test/UnitOfWork.Tests/UnitOfWork.Tests.csproj +++ b/test/UnitOfWork.Tests/UnitOfWork.Tests.csproj @@ -1,14 +1,16 @@  - netcoreapp3.1 + net6.0 + + Arch.EntityFrameworkCore.UnitOfWork.Tests - - + + - + all runtime; build; native; contentfiles; analyzers