diff --git a/samples/IdentitySample.Mvc/Startup.cs b/samples/IdentitySample.Mvc/Startup.cs index d54b6eaac..92097c656 100644 --- a/samples/IdentitySample.Mvc/Startup.cs +++ b/samples/IdentitySample.Mvc/Startup.cs @@ -53,7 +53,7 @@ public void ConfigureServices(IServiceCollection services) options.Cookies.ApplicationCookie.DataProtectionProvider = DataProtectionProvider.Create(new DirectoryInfo("C:\\Github\\Identity\\artifacts")); }) - .AddLinqToDBStores(new DefaultConnectionFactory()) + .AddLinqToDBStores(new DefaultConnectionFactory()) .AddDefaultTokenProviders(); services.AddMvc(); diff --git a/samples/IdentitySample.Mvc/Views/Home/Index.cshtml b/samples/IdentitySample.Mvc/Views/Home/Index.cshtml index b07d5cd07..221bc0040 100644 --- a/samples/IdentitySample.Mvc/Views/Home/Index.cshtml +++ b/samples/IdentitySample.Mvc/Views/Home/Index.cshtml @@ -16,8 +16,7 @@
Initialize ASP.NET Identity
- You can initialize ASP.NET Identity when the application starts. Since ASP.NET Identity is Entity Framework based in this sample, - you can create DatabaseInitializer which is configured to get called each time the app starts. + You can initialize ASP.NET Identity when the application starts. Please look in App_Start\IdentityConfig.cs This code shows the following
    diff --git a/src/LinqToDB.Identity/DefaultConnectionFactory.cs b/src/LinqToDB.Identity/DefaultConnectionFactory.cs index a0e3e6184..40d37a297 100644 --- a/src/LinqToDB.Identity/DefaultConnectionFactory.cs +++ b/src/LinqToDB.Identity/DefaultConnectionFactory.cs @@ -3,13 +3,9 @@ namespace LinqToDB.Identity { /// - /// Represents default + /// Represents default /// - /// The type of the data getContext class used to access the store. - /// The type repewsenting database getConnection - public class DefaultConnectionFactory : IConnectionFactory - where TContext : class, IDataContext, new() - where TConnection : DataConnection, new() + public class DefaultConnectionFactory : IConnectionFactory { /// /// Creates with default parameters @@ -17,9 +13,9 @@ public class DefaultConnectionFactory : IConnectionFactor /// /// /// - public TConnection GetConnection() + public DataConnection GetConnection() { - return new TConnection(); + return new DataConnection(); } /// @@ -28,9 +24,9 @@ public TConnection GetConnection() /// /// /// - public TContext GetContext() + public IDataContext GetContext() { - return new TContext(); + return new DataContext(); } } } \ No newline at end of file diff --git a/src/LinqToDB.Identity/IConnectionFactory.cs b/src/LinqToDB.Identity/IConnectionFactory.cs index f0cc9e697..cb53c3fd7 100644 --- a/src/LinqToDB.Identity/IConnectionFactory.cs +++ b/src/LinqToDB.Identity/IConnectionFactory.cs @@ -5,15 +5,7 @@ namespace LinqToDB.Identity /// /// Represents connection factory /// - /// - /// - /// - /// - /// - /// - public interface IConnectionFactory - where TContext : IDataContext - where TConnection : DataConnection + public interface IConnectionFactory { /// /// Gets new instance of @@ -21,7 +13,7 @@ public interface IConnectionFactory /// /// /// - TContext GetContext(); + IDataContext GetContext(); /// /// Gets new instance of @@ -29,6 +21,6 @@ public interface IConnectionFactory /// /// /// - TConnection GetConnection(); + DataConnection GetConnection(); } } \ No newline at end of file diff --git a/src/LinqToDB.Identity/IdentityLinqToDbBuilderExtensions.cs b/src/LinqToDB.Identity/IdentityLinqToDbBuilderExtensions.cs index 02286630e..856e7a115 100644 --- a/src/LinqToDB.Identity/IdentityLinqToDbBuilderExtensions.cs +++ b/src/LinqToDB.Identity/IdentityLinqToDbBuilderExtensions.cs @@ -13,77 +13,130 @@ namespace Microsoft.Extensions.DependencyInjection { /// - /// Contains extension methods to for adding entity framework stores. + /// Contains extension methods to for adding linq2db stores. /// public static class IdentityLinqToDbBuilderExtensions { /// - /// Adds an Entity Framework implementation of identity information stores. + /// Adds an linq2db plementation of identity information stores. /// - /// - /// The type of the class for , - /// - /// - /// - /// The type of the class for , - /// - /// /// The instance this method extends. /// - /// + /// /// /// The instance this method extends. // ReSharper disable once InconsistentNaming - public static IdentityBuilder AddLinqToDBStores(this IdentityBuilder builder, - IConnectionFactory factory) - where TContext : IDataContext - where TConnection : DataConnection + public static IdentityBuilder AddLinqToDBStores(this IdentityBuilder builder, IConnectionFactory factory) { - builder.Services.AddSingleton(factory); + return AddLinqToDBStores(builder, factory, + typeof(string), + typeof(IdentityUserClaim), + typeof(IdentityUserRole), + typeof(IdentityUserLogin), + typeof(IdentityUserToken), + typeof(IdentityRoleClaim)); + } - builder.Services.TryAdd( - GetDefaultServices(builder.UserType, builder.RoleType, typeof(TContext), typeof(TConnection))); - return builder; + /// + /// Adds an linq2db implementation of identity information stores. + /// + /// The type of the primary key used for the users and roles. + /// The instance this method extends. + /// + /// + /// + /// The instance this method extends. + // ReSharper disable once InconsistentNaming + public static IdentityBuilder AddLinqToDBStores(this IdentityBuilder builder, IConnectionFactory factory) + where TKey : IEquatable + { + return AddLinqToDBStores(builder, factory, + typeof(TKey), + typeof(IdentityUserClaim), + typeof(IdentityUserRole), + typeof(IdentityUserLogin), + typeof(IdentityUserToken), + typeof(IdentityRoleClaim)); } /// - /// Adds an Entity Framework implementation of identity information stores. + /// Adds an linq2db implementation of identity information stores. /// - /// - /// The type of the class for , - /// - /// - /// - /// The type of the class for , - /// - /// /// The type of the primary key used for the users and roles. + /// The type representing a claim. + /// The type representing a user role. + /// The type representing a user external login. + /// The type representing a user token. + /// The type of the class representing a role claim. /// The instance this method extends. /// - /// + /// /// /// The instance this method extends. // ReSharper disable once InconsistentNaming - public static IdentityBuilder AddLinqToDBStores(this IdentityBuilder builder, - IConnectionFactory factory) - where TContext : IDataContext - where TConnection : DataConnection + public static IdentityBuilder AddLinqToDBStores< + TKey, + TUserClaim, + TUserRole, + TUserLogin, + TUserToken, + TRoleClaim>(this IdentityBuilder builder, IConnectionFactory factory) + where TUserClaim : class, IIdentityUserClaim + where TUserRole : class, IIdentityUserRole + where TUserLogin : class, IIdentityUserLogin + where TUserToken : class, IIdentityUserToken where TKey : IEquatable + where TRoleClaim : class, IIdentityRoleClaim + { + + return AddLinqToDBStores(builder, factory, + typeof(TKey), + typeof(TUserClaim), + typeof(TUserRole), + typeof(TUserLogin), + typeof(TUserToken), + typeof(TRoleClaim)); + } + + /// + /// Adds an linq2db implementation of identity information stores. + /// + /// The instance this method extends. + /// + /// + /// + /// The type of the primary key used for the users and roles. + /// The type representing a claim. + /// The type representing a user role. + /// The type representing a user external login. + /// The type representing a user token. + /// The type of the class representing a role claim. + /// The instance this method extends. + // ReSharper disable once InconsistentNaming + public static IdentityBuilder AddLinqToDBStores(this IdentityBuilder builder, IConnectionFactory factory, + Type keyType, Type userClaimType, Type userRoleType, Type userLoginType, Type userTokenType, Type roleClaimType) { builder.Services.AddSingleton(factory); - builder.Services.TryAdd(GetDefaultServices(builder.UserType, builder.RoleType, typeof(TContext), typeof(TConnection), typeof(TKey))); + builder.Services.TryAdd(GetDefaultServices( + keyType, + builder.UserType, + userClaimType, + userRoleType, + userLoginType, + userTokenType, + builder.RoleType, + roleClaimType)); + return builder; } - private static IServiceCollection GetDefaultServices(Type userType, Type roleType, Type contextType, - Type connectionType, Type keyType = null) + private static IServiceCollection GetDefaultServices(Type keyType, Type userType, Type userClaimType, Type userRoleType, Type userLoginType, Type userTokenType, Type roleType, Type roleClaimType) { - Type userStoreType; - Type roleStoreType; - keyType = keyType ?? typeof(string); - userStoreType = typeof(UserStore<,,,,>).MakeGenericType(contextType, connectionType, userType, roleType, keyType); - roleStoreType = typeof(RoleStore<,,,>).MakeGenericType(contextType, connectionType, roleType, keyType); + //UserStore + var userStoreType = typeof(UserStore<,,,,,,>).MakeGenericType(keyType, userType, roleType, userClaimType, userRoleType, userLoginType, userTokenType); + // RoleStore + var roleStoreType = typeof(RoleStore<,,>).MakeGenericType(keyType, roleType, roleClaimType); var services = new ServiceCollection(); services.AddScoped( diff --git a/src/LinqToDB.Identity/LinqToDB.Identity.csproj b/src/LinqToDB.Identity/LinqToDB.Identity.csproj index eb6779336..3a970c25c 100644 --- a/src/LinqToDB.Identity/LinqToDB.Identity.csproj +++ b/src/LinqToDB.Identity/LinqToDB.Identity.csproj @@ -14,15 +14,16 @@ linq2db.Identity aspnetcore;linq2db;identity;membership;LinqToDB http://www.gravatar.com/avatar/fc2e509b6ed116b9aa29a7988fdb8990?s=320 - https://github.com/ili/LinqToDB.Identity + https://github.com/linq2db/LinqToDB.Identity https://opensource.org/licenses/MIT git - git://github.com/ili/LinqToDB.Identity + git://github.com/linq2db/LinqToDB.Identity 1.6.1 false false false false + 1.2.0 @@ -31,7 +32,7 @@ - + diff --git a/src/LinqToDB.Identity/Properties/AssemblyInfo.cs b/src/LinqToDB.Identity/Properties/AssemblyInfo.cs index 432823c9b..bed11dd72 100644 --- a/src/LinqToDB.Identity/Properties/AssemblyInfo.cs +++ b/src/LinqToDB.Identity/Properties/AssemblyInfo.cs @@ -7,5 +7,5 @@ [assembly: AssemblyMetadata("Serviceable", "True")] [assembly: NeutralResourcesLanguage("en-us")] [assembly: AssemblyCompany("blog.linq2db.com")] -[assembly: AssemblyCopyright("\xA9 2011-2016 blog.linq2db.com")] +[assembly: AssemblyCopyright("© 2011-2017 blog.linq2db.com")] [assembly: AssemblyProduct("Linq to DB")] \ No newline at end of file diff --git a/src/LinqToDB.Identity/RoleStore.cs b/src/LinqToDB.Identity/RoleStore.cs index cc017facd..e6c47a103 100644 --- a/src/LinqToDB.Identity/RoleStore.cs +++ b/src/LinqToDB.Identity/RoleStore.cs @@ -17,27 +17,17 @@ namespace LinqToDB.Identity /// Creates a new instance of a persistence store for roles. /// /// The type of the class representing a role. - /// - /// The type of the class for , - /// - /// - /// - /// The type of the class for , - /// - /// - public class RoleStore : RoleStore + public class RoleStore : RoleStore where TRole : IdentityRole - where TContext : IDataContext - where TConnection : DataConnection { /// - /// Constructs a new instance of . + /// Constructs a new instance of . /// /// - /// + /// /// /// The . - public RoleStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) + public RoleStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) : base(factory, describer) { } @@ -48,47 +38,23 @@ public RoleStore(IConnectionFactory factory, IdentityErro /// /// The type of the class representing a role. /// The type of the primary key for a role. - /// - /// The type of the class for , - /// - /// - /// - /// The type of the class for , - /// - /// - public class RoleStore : - RoleStore, IdentityRoleClaim>, - IQueryableRoleStore, - IRoleClaimStore + public class RoleStore : + RoleStore> where TRole : IdentityRole where TKey : IEquatable - where TContext : IDataContext - where TConnection : DataConnection { /// - /// Constructs a new instance of . + /// Constructs a new instance of . /// /// - /// + /// /// /// The . - public RoleStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) + public RoleStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) : base(factory, describer) { } - /// - /// Creates a entity representing a role claim. - /// - /// The associated role. - /// The associated claim. - /// The role claim entity. - protected override IdentityRoleClaim CreateRoleClaim(TRole role, Claim claim) - { - var roleClaim = new IdentityRoleClaim {RoleId = role.Id}; - roleClaim.InitializeFromClaim(claim); - return roleClaim; - } } /// @@ -96,38 +62,37 @@ protected override IdentityRoleClaim CreateRoleClaim(TRole role, Claim cla /// /// The type of the class representing a role. /// The type of the primary key for a role. - /// The type of the class representing a user role. /// The type of the class representing a role claim. - /// - /// The type of the class for , - /// - /// - /// - /// The type of the class for , - /// - /// - public abstract class RoleStore : + public class RoleStore : IQueryableRoleStore, IRoleClaimStore where TRole : class, IIdentityRole where TKey : IEquatable - where TUserRole : class, IIdentityUserRole - where TRoleClaim : class, IIdentityRoleClaim - where TContext : IDataContext - where TConnection : DataConnection + where TRoleClaim : class, IIdentityRoleClaim, new() { - private readonly IConnectionFactory _factory; + private readonly IConnectionFactory _factory; + + /// + /// Gets from supplied + /// + /// + protected DataConnection GetConnection() => _factory.GetConnection(); + /// + /// Gets from supplied + /// + /// + protected IDataContext GetContext() => _factory.GetContext(); private bool _disposed; /// - /// Constructs a new instance of . + /// Constructs a new instance of . /// /// - /// + /// /// /// The . - public RoleStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) + public RoleStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) { if (factory == null) throw new ArgumentNullException(nameof(factory)); @@ -138,11 +103,6 @@ public RoleStore(IConnectionFactory factory, IdentityErro } - /// - /// Gets the database context for this store. - /// - private IDataContext Context => _factory.GetContext(); - /// /// Gets or sets the for any error that occurred with the current operation. /// @@ -158,7 +118,7 @@ public RoleStore(IConnectionFactory factory, IdentityErro /// should be canceled. /// /// A that represents the of the asynchronous query. - public virtual async Task CreateAsync(TRole role, + public async Task CreateAsync(TRole role, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); @@ -166,11 +126,20 @@ public virtual async Task CreateAsync(TRole role, if (role == null) throw new ArgumentNullException(nameof(role)); + using (var db = GetConnection()) + return await CreateAsync(db, role, cancellationToken); + } - await Task.Run(() => Context.TryInsertAndSetIdentity(role), cancellationToken); + /// + protected virtual async Task CreateAsync(DataConnection db, TRole role, CancellationToken cancellationToken) + { + await Task.Run(() => db.TryInsertAndSetIdentity(role), cancellationToken); return IdentityResult.Success; + } + + /// /// Updates a role in a store as an asynchronous operation. /// @@ -180,7 +149,7 @@ public virtual async Task CreateAsync(TRole role, /// should be canceled. /// /// A that represents the of the asynchronous query. - public virtual async Task UpdateAsync(TRole role, + public async Task UpdateAsync(TRole role, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); @@ -188,7 +157,14 @@ public virtual async Task UpdateAsync(TRole role, if (role == null) throw new ArgumentNullException(nameof(role)); - var result = await Task.Run(() => _factory.GetContext().UpdateConcurrent(role), cancellationToken); + using (var db = GetConnection()) + return await UpdateAsync(db, role, cancellationToken); + } + + /// + protected virtual async Task UpdateAsync(DataConnection db, TRole role, CancellationToken cancellationToken) + { + var result = await Task.Run(() => db.UpdateConcurrent(role), cancellationToken); return result == 1 ? IdentityResult.Success : IdentityResult.Failed(ErrorDescriber.ConcurrencyFailure()); } @@ -201,7 +177,7 @@ public virtual async Task UpdateAsync(TRole role, /// should be canceled. /// /// A that represents the of the asynchronous query. - public virtual async Task DeleteAsync(TRole role, + public async Task DeleteAsync(TRole role, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); @@ -209,10 +185,15 @@ public virtual async Task DeleteAsync(TRole role, if (role == null) throw new ArgumentNullException(nameof(role)); + using (var db = GetConnection()) + return await DeleteAsync(db, role, cancellationToken); + } + + /// + private async Task DeleteAsync(DataConnection db, TRole role, CancellationToken cancellationToken) + { var result = await Task.Run(() => - _factory - .GetContext() - .GetTable() + db.GetTable() .Where(_ => _.Id.Equals(role.Id) && _.ConcurrencyStamp == role.ConcurrencyStamp) .Delete(), cancellationToken); @@ -285,12 +266,20 @@ public Task SetRoleNameAsync(TRole role, string roleName, /// should be canceled. /// /// A that result of the look up. - public virtual Task FindByIdAsync(string id, CancellationToken cancellationToken = default(CancellationToken)) + public async Task FindByIdAsync(string id, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); var roleId = ConvertIdFromString(id); - return Roles.FirstOrDefaultAsync(u => u.Id.Equals(roleId), cancellationToken); + + using (var db = GetConnection()) + return await FindByIdAsync(db, roleId, cancellationToken); + } + + /// + protected virtual async Task FindByIdAsync(DataConnection db, TKey roleId, CancellationToken cancellationToken) + { + return await db.GetTable().FirstOrDefaultAsync(u => u.Id.Equals(roleId), cancellationToken); } /// @@ -302,12 +291,22 @@ public Task SetRoleNameAsync(TRole role, string roleName, /// should be canceled. /// /// A that result of the look up. - public virtual Task FindByNameAsync(string normalizedName, + public async Task FindByNameAsync(string normalizedName, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); - return Roles.FirstOrDefaultAsync(r => r.NormalizedName == normalizedName, cancellationToken); + using (var db = GetConnection()) + { + return await FindByNameAsync(db, normalizedName, cancellationToken); + } + + } + + /// + protected virtual async Task FindByNameAsync(DataConnection db, string normalizedName, CancellationToken cancellationToken) + { + return await db.GetTable().FirstOrDefaultAsync(r => r.NormalizedName == normalizedName, cancellationToken); } /// @@ -361,7 +360,7 @@ public void Dispose() /// /// A navigation property for the roles the store contains. /// - public virtual IQueryable Roles => Context.GetTable(); + public virtual IQueryable Roles => GetContext().GetTable(); /// /// Get the claims associated with the specified as an asynchronous operation. @@ -379,7 +378,16 @@ public async Task> GetClaimsAsync(TRole role, if (role == null) throw new ArgumentNullException(nameof(role)); - return await Context.GetTable() + using (var db = GetConnection()) + { + return await GetClaimsAsync(db, role, cancellationToken); + } + } + + /// + protected virtual async Task> GetClaimsAsync(DataConnection db, TRole role, CancellationToken cancellationToken) + { + return await db.GetTable() .Where(rc => rc.RoleId.Equals(role.Id)) .Select(c => c.ToClaim()) .ToListAsync(cancellationToken); @@ -395,8 +403,7 @@ public async Task> GetClaimsAsync(TRole role, /// should be canceled. /// /// The that represents the asynchronous operation. - public virtual Task AddClaimAsync(TRole role, Claim claim, - CancellationToken cancellationToken = default(CancellationToken)) + public async Task AddClaimAsync(TRole role, Claim claim, CancellationToken cancellationToken = default(CancellationToken)) { ThrowIfDisposed(); if (role == null) @@ -404,9 +411,15 @@ public virtual Task AddClaimAsync(TRole role, Claim claim, if (claim == null) throw new ArgumentNullException(nameof(claim)); - Context.TryInsertAndSetIdentity(CreateRoleClaim(role, claim)); + using (var db = GetConnection()) + await AddClaimAsync(db, role, claim, cancellationToken); + } - return Task.FromResult(false); + /// + protected virtual async Task AddClaimAsync(DataConnection db, TRole role, Claim claim, CancellationToken cancellationToken) + { + await Task.Run(() => db.TryInsertAndSetIdentity(CreateRoleClaim(role, claim)), + cancellationToken); } /// @@ -428,8 +441,17 @@ public async Task RemoveClaimAsync(TRole role, Claim claim, if (claim == null) throw new ArgumentNullException(nameof(claim)); + using (var db = GetConnection()) + { + await RemoveClaimAsync(db, role, claim, cancellationToken); + } + } + + /// + protected virtual async Task RemoveClaimAsync(DataConnection db, TRole role, Claim claim, CancellationToken cancellationToken) + { await Task.Run(() => - Context.GetTable() + db.GetTable() .Where(rc => rc.RoleId.Equals(role.Id) && rc.ClaimValue == claim.Value && rc.ClaimType == claim.Type) .Delete(), cancellationToken); @@ -475,6 +497,11 @@ protected void ThrowIfDisposed() /// The associated role. /// The associated claim. /// The role claim entity. - protected abstract TRoleClaim CreateRoleClaim(TRole role, Claim claim); + protected virtual TRoleClaim CreateRoleClaim(TRole role, Claim claim) + { + var roleClaim = new TRoleClaim(){ RoleId = role.Id }; + roleClaim.InitializeFromClaim(claim); + return roleClaim; + } } } \ No newline at end of file diff --git a/src/LinqToDB.Identity/UserStore.cs b/src/LinqToDB.Identity/UserStore.cs index 7a139f838..8a45705b8 100644 --- a/src/LinqToDB.Identity/UserStore.cs +++ b/src/LinqToDB.Identity/UserStore.cs @@ -20,21 +20,17 @@ namespace LinqToDB.Identity /// Creates a new instance of a persistence store for the specified user type. /// /// The type representing a user. - /// The type of the data getContext class used to access the store. - /// The type repewsenting database getConnection - public class UserStore : UserStore + public class UserStore : UserStore where TUser : IdentityUser, new() - where TContext : IDataContext - where TConnection : DataConnection { /// /// Constructs a new instance of . /// /// - /// + /// /// /// The . - public UserStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) + public UserStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) : base(factory, describer) { } @@ -45,22 +41,18 @@ public UserStore(IConnectionFactory factory, IdentityErro /// /// The type representing a user. /// The type representing a role. - /// The type of the data getContext class used to access the store. - /// The type repewsenting database getConnection - public class UserStore : UserStore + public class UserStore : UserStore where TUser : IdentityUser where TRole : IdentityRole - where TContext : IDataContext - where TConnection : DataConnection { /// - /// Constructs a new instance of . + /// Constructs a new instance of . /// /// - /// + /// /// /// The . - public UserStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) + public UserStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) : base(factory, describer) { } @@ -71,27 +63,22 @@ public UserStore(IConnectionFactory factory, IdentityErro /// /// The type representing a user. /// The type representing a role. - /// The type of the data getContext class used to access the store. /// The type of the primary key for a role. - /// The type repewsenting database getConnection - public class UserStore : - UserStore - , IdentityUserRole, IdentityUserLogin, + public class UserStore : + UserStore, IdentityUserRole, IdentityUserLogin, IdentityUserToken> where TUser : IdentityUser where TRole : IdentityRole - where TContext : IDataContext where TKey : IEquatable - where TConnection : DataConnection { /// - /// Constructs a new instance of . + /// Constructs a new instance of . /// /// - /// + /// /// /// The . - public UserStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) + public UserStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) : base(factory, describer) { } @@ -162,62 +149,18 @@ protected override IdentityUserToken CreateUserToken(TUser user, string lo } } - /// - /// Represents a new instance of a persistence store for the specified user and role types. - /// - /// The type representing a user. - /// The type representing a role. - /// The type of the data getContext class used to access the store. - /// The type of the primary key for a role. - /// The type representing a claim. - /// The type representing a user role. - /// The type representing a user external login. - /// The type representing a user token. - /// The type repewsenting database getConnection - public abstract class UserStore : - UserStore - > - where TUser : IdentityUser - where TRole : IdentityRole> - where TContext : IDataContext - where TKey : IEquatable - where TUserClaim : IdentityUserClaim - where TUserRole : IdentityUserRole - where TUserLogin : IdentityUserLogin - where TUserToken : IdentityUserToken - where TConnection : DataConnection - { - /// - /// Creates a new instance of - /// . - /// - /// - /// - /// - /// The used to describe store errors. - public UserStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) - : base(factory, describer) - { - } - } - /// /// Represents a new instance of a persistence store for the specified user and role types. /// /// The type representing a user. /// The type representing a role. - /// The type of the data getContext class used to access the store. /// The type of the primary key for a role. /// The type representing a claim. /// The type representing a user role. /// The type representing a user external login. /// The type representing a user token. - /// The type representing a role claim. - /// The type repewsenting database getConnection - public abstract class UserStore : + public class UserStore : IUserLoginStore, IUserRoleStore, IUserClaimStore, @@ -231,30 +174,38 @@ public abstract class UserStore where TUser : class, IIdentityUser where TRole : class, IIdentityRole - where TUserClaim : class, IIdentityUserClaim - where TUserRole : class, IIdentityUserRole - where TUserLogin : class, IIdentityUserLogin - where TUserToken : class, IIdentityUserToken - where TRoleClaim : class, IIdentityRoleClaim - where TContext : IDataContext - where TConnection : DataConnection + where TUserClaim : class, IIdentityUserClaim, new() + where TUserRole : class, IIdentityUserRole, new() + where TUserLogin : class, IIdentityUserLogin, new() + where TUserToken : class, IIdentityUserToken, new () where TKey : IEquatable { - private readonly IConnectionFactory _factory; + private readonly IConnectionFactory _factory; + + /// + /// Gets from supplied + /// + /// + protected DataConnection GetConnection() => _factory.GetConnection(); + /// + /// Gets from supplied + /// + /// + protected IDataContext GetContext() => _factory.GetContext(); private bool _disposed; /// /// Creates a new instance of /// + /// cref="LinqToDB.Identity.UserStore{TUser,TRole,TKey,TUserClaim,TUserRole,TUserLogin,TUserToken}" /> /// . /// /// - /// + /// /// /// The used to describe store errors. - public UserStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) + public UserStore(IConnectionFactory factory, IdentityErrorDescriber describer = null) { if (factory == null) throw new ArgumentNullException(nameof(factory)); @@ -272,7 +223,7 @@ public UserStore(IConnectionFactory factory, IdentityErro /// /// A navigation property for the users the store contains. /// - public virtual IQueryable Users => _factory.GetContext().GetTable(); + public virtual IQueryable Users => GetContext().GetTable(); /// /// Sets the token value for a particular user. @@ -286,7 +237,7 @@ public UserStore(IConnectionFactory factory, IdentityErro /// should be canceled. /// /// The that represents the asynchronous operation. - public virtual async Task SetTokenAsync(TUser user, string loginProvider, string name, string value, + public async Task SetTokenAsync(TUser user, string loginProvider, string name, string value, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -295,25 +246,32 @@ public virtual async Task SetTokenAsync(TUser user, string loginProvider, string if (user == null) throw new ArgumentNullException(nameof(user)); + using (var db = GetConnection()) + { + await SetTokenAsync(db, user, loginProvider, name, value, cancellationToken); + } + } + + /// + protected virtual async Task SetTokenAsync(DataConnection db, TUser user, string loginProvider, string name, string value, + CancellationToken cancellationToken) + { await Task.Run(() => { - using (var dc = _factory.GetConnection()) + var q = db.GetTable() + .Where(_ => _.UserId.Equals(user.Id) && _.LoginProvider == loginProvider && _.Name == name); + + var token = q.FirstOrDefault(); + + if (token == null) { - var q = dc.GetTable() - .Where(_ => _.UserId.Equals(user.Id) && _.LoginProvider == loginProvider && _.Name == name); - - var token = q.FirstOrDefault(); - - if (token == null) - { - dc.Insert(CreateUserToken(user, loginProvider, name, value)); - } - else - { - token.Value = value; - q.Set(_ => _.Value, value) - .Update(); - } + db.Insert(CreateUserToken(user, loginProvider, name, value)); + } + else + { + token.Value = value; + q.Set(_ => _.Value, value) + .Update(); } }, cancellationToken); } @@ -337,9 +295,19 @@ public async Task RemoveTokenAsync(TUser user, string loginProvider, string name if (user == null) throw new ArgumentNullException(nameof(user)); + + using (var db = GetConnection()) + { + await RemoveTokenAsync(db, user, loginProvider, name, cancellationToken); + } + } + + /// + protected virtual async Task RemoveTokenAsync(DataConnection db, TUser user, string loginProvider, string name, + CancellationToken cancellationToken) + { await Task.Run(() => - _factory.GetContext() - .GetTable() + db.GetTable() .Where(_ => _.UserId.Equals(user.Id) && _.LoginProvider == loginProvider && _.Name == name) .Delete(), cancellationToken); @@ -365,7 +333,17 @@ public async Task GetTokenAsync(TUser user, string loginProvider, string if (user == null) throw new ArgumentNullException(nameof(user)); - var entry = await _factory.GetContext() + using (var db = GetConnection()) + { + return await GetTokenAsync(db, user, loginProvider, name, cancellationToken); + } + } + + /// + protected virtual async Task GetTokenAsync(DataConnection db, TUser user, string loginProvider, string name, + CancellationToken cancellationToken) + { + var entry = await db .GetTable() .Where(_ => _.UserId.Equals(user.Id) && _.LoginProvider == loginProvider && _.Name == name) .FirstOrDefaultAsync(cancellationToken); @@ -382,15 +360,25 @@ public async Task GetTokenAsync(TUser user, string loginProvider, string /// should be canceled. /// /// A that contains the claims granted to a user. - public virtual async Task> GetClaimsAsync(TUser user, + public async Task> GetClaimsAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken)) { ThrowIfDisposed(); if (user == null) throw new ArgumentNullException(nameof(user)); + + using (var db = GetConnection()) + { + return await GetClaimsAsync(db, user, cancellationToken); + } + } + + /// + protected virtual async Task> GetClaimsAsync(DataConnection db, TUser user, CancellationToken cancellationToken) + { return await - _factory.GetContext() + db .GetTable() .Where(uc => uc.UserId.Equals(user.Id)) .Select(c => c.ToClaim()) @@ -407,7 +395,7 @@ public virtual async Task> GetClaimsAsync(TUser user, /// should be canceled. /// /// The that represents the asynchronous operation. - public virtual Task AddClaimsAsync(TUser user, IEnumerable claims, + public Task AddClaimsAsync(TUser user, IEnumerable claims, CancellationToken cancellationToken = default(CancellationToken)) { ThrowIfDisposed(); @@ -417,12 +405,17 @@ public virtual Task AddClaimsAsync(TUser user, IEnumerable claims, throw new ArgumentNullException(nameof(claims)); var data = claims.Select(_ => CreateUserClaim(user, _)); - using (var dc = _factory.GetConnection()) + using (var dc = GetConnection()) { - dc.BulkCopy(data); + return AddClaimsAsync(dc, data, cancellationToken); } + } - return Task.FromResult(false); + /// + protected virtual Task AddClaimsAsync(DataConnection dc, IEnumerable data, CancellationToken cancellationToken) + { + dc.BulkCopy(data); + return Task.FromResult(true); } /// @@ -437,7 +430,7 @@ public virtual Task AddClaimsAsync(TUser user, IEnumerable claims, /// should be canceled. /// /// The that represents the asynchronous operation. - public virtual async Task ReplaceClaimAsync(TUser user, Claim claim, Claim newClaim, + public async Task ReplaceClaimAsync(TUser user, Claim claim, Claim newClaim, CancellationToken cancellationToken = default(CancellationToken)) { ThrowIfDisposed(); @@ -448,9 +441,19 @@ public virtual async Task ReplaceClaimAsync(TUser user, Claim claim, Claim newCl if (newClaim == null) throw new ArgumentNullException(nameof(newClaim)); + using (var db = GetConnection()) + { + await ReplaceClaimAsync(user, claim, newClaim, cancellationToken, db); + } + } + + /// + protected virtual async Task ReplaceClaimAsync(TUser user, Claim claim, Claim newClaim, CancellationToken cancellationToken, + DataConnection db) + { await Task.Run(() => { - var q = _factory.GetContext() + var q = db .GetTable() .Where(uc => uc.UserId.Equals(user.Id) && uc.ClaimValue == claim.Value && uc.ClaimType == claim.Type); @@ -470,7 +473,7 @@ await Task.Run(() => /// should be canceled. /// /// The that represents the asynchronous operation. - public virtual async Task RemoveClaimsAsync(TUser user, IEnumerable claims, + public async Task RemoveClaimsAsync(TUser user, IEnumerable claims, CancellationToken cancellationToken = default(CancellationToken)) { ThrowIfDisposed(); @@ -479,16 +482,25 @@ public virtual async Task RemoveClaimsAsync(TUser user, IEnumerable claim if (claims == null) throw new ArgumentNullException(nameof(claims)); + using (var db = GetConnection()) + { + await RemoveClaimsAsync(db, user, claims, cancellationToken); + } + } + + /// + protected virtual async Task RemoveClaimsAsync(DataConnection db, TUser user, IEnumerable claims, + CancellationToken cancellationToken) + { await Task.Run(() => { - var dc = _factory.GetContext(); - var q = dc.GetTable(); + var q = db.GetTable(); var userId = Expression.PropertyOrField(Expression.Constant(user, typeof(TUser)), nameof(user.Id)); var equals = typeof(TKey).GetMethod(nameof(IEquatable.Equals), new[] {typeof(TKey)}); var uc = Expression.Parameter(typeof(TUserClaim)); Expression body = null; var ucUserId = Expression.PropertyOrField(uc, nameof(IIdentityUserClaim.UserId)); - var userIdEquals = Expression.Call(ucUserId, equals, userId); + var userIdEquals = Expression.Call(ucUserId, @equals, userId); foreach (var claim in claims) { @@ -527,18 +539,24 @@ await Task.Run(() => /// /// The contains a list of users, if any, that contain the specified claim. /// - public virtual async Task> GetUsersForClaimAsync(Claim claim, - CancellationToken cancellationToken = default(CancellationToken)) + public async Task> GetUsersForClaimAsync(Claim claim, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); if (claim == null) throw new ArgumentNullException(nameof(claim)); - var dc = _factory.GetContext(); + using (var db = GetConnection()) + { + return await UsersForClaimAsync(db, claim, cancellationToken); + } + } - var query = from userclaims in dc.GetTable() - join user in Users on userclaims.UserId equals user.Id + /// + protected virtual async Task> UsersForClaimAsync(DataConnection db, Claim claim, CancellationToken cancellationToken) + { + var query = from userclaims in db.GetTable() + join user in db.GetTable() on userclaims.UserId equals user.Id where userclaims.ClaimValue == claim.Value && userclaims.ClaimType == claim.Type select user; @@ -693,12 +711,22 @@ public virtual Task SetNormalizedEmailAsync(TUser user, string normalizedEmail, /// The task object containing the results of the asynchronous lookup operation, the user if any associated with the /// specified normalized email address. /// - public virtual Task FindByEmailAsync(string normalizedEmail, + public async Task FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); - return Users.FirstOrDefaultAsync(u => u.NormalizedEmail == normalizedEmail, cancellationToken); + + using (var db = GetConnection()) + { + return await FindByEmailAsync(db, normalizedEmail, cancellationToken); + } + } + + /// + protected virtual async Task FindByEmailAsync(DataConnection db, string normalizedEmail, CancellationToken cancellationToken) + { + return await db.GetTable().FirstOrDefaultAsync(u => u.NormalizedEmail == normalizedEmail, cancellationToken); } /// @@ -976,14 +1004,23 @@ public virtual Task SetNormalizedUserNameAsync(TUser user, string normalizedName /// The that represents the asynchronous operation, containing the /// of the creation operation. /// - public virtual async Task CreateAsync(TUser user, - CancellationToken cancellationToken = default(CancellationToken)) + public async Task CreateAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); if (user == null) throw new ArgumentNullException(nameof(user)); - await Task.Run(() => _factory.GetContext().TryInsertAndSetIdentity(user), cancellationToken); + + using (var db = GetConnection()) + { + return await CreateAsync(db, user, cancellationToken); + } + } + + /// + protected virtual async Task CreateAsync(DataConnection db, TUser user, CancellationToken cancellationToken) + { + await Task.Run(() => db.TryInsertAndSetIdentity(user), cancellationToken); return IdentityResult.Success; } @@ -999,15 +1036,23 @@ public virtual async Task CreateAsync(TUser user, /// The that represents the asynchronous operation, containing the /// of the update operation. /// - public virtual async Task UpdateAsync(TUser user, + public async Task UpdateAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); if (user == null) throw new ArgumentNullException(nameof(user)); + using (var db = GetConnection()) + { + return await UpdateAsync(db, user, cancellationToken); + } + } - var result = await Task.Run(() => _factory.GetContext().UpdateConcurrent(user), cancellationToken); + /// + protected virtual async Task UpdateAsync(DataConnection db, TUser user, CancellationToken cancellationToken) + { + var result = await Task.Run(() => db.UpdateConcurrent(user), cancellationToken); return result == 1 ? IdentityResult.Success : IdentityResult.Failed(ErrorDescriber.ConcurrencyFailure()); } @@ -1023,7 +1068,7 @@ public virtual async Task UpdateAsync(TUser user, /// The that represents the asynchronous operation, containing the /// of the update operation. /// - public virtual async Task DeleteAsync(TUser user, + public async Task DeleteAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); @@ -1031,10 +1076,17 @@ public virtual async Task DeleteAsync(TUser user, if (user == null) throw new ArgumentNullException(nameof(user)); + using (var db = GetConnection()) + { + return await DeleteAsync(db, user, cancellationToken); + } + } + + /// + protected virtual async Task DeleteAsync(DataConnection db, TUser user, CancellationToken cancellationToken) + { var result = await Task.Run(() => - _factory - .GetContext() - .GetTable() + db.GetTable() .Where(_ => _.Id.Equals(user.Id) && _.ConcurrencyStamp == user.ConcurrencyStamp) .Delete(), cancellationToken); return result == 1 ? IdentityResult.Success : IdentityResult.Failed(ErrorDescriber.ConcurrencyFailure()); @@ -1052,13 +1104,23 @@ public virtual async Task DeleteAsync(TUser user, /// The that represents the asynchronous operation, containing the user matching the specified /// if it exists. /// - public virtual Task FindByIdAsync(string userId, + public async Task FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); var id = ConvertIdFromString(userId); - return _factory.GetContext().GetTable().FirstOrDefaultAsync(_ => _.Id.Equals(id), cancellationToken); + + using (var db = GetConnection()) + { + return await FindByIdAsync(db, id, cancellationToken); + } + } + + /// + protected virtual async Task FindByIdAsync(DataConnection db, TKey id, CancellationToken cancellationToken) + { + return await db.GetTable().FirstOrDefaultAsync(_ => _.Id.Equals(id), cancellationToken); } /// @@ -1073,12 +1135,23 @@ public virtual Task FindByIdAsync(string userId, /// The that represents the asynchronous operation, containing the user matching the specified /// if it exists. /// - public virtual Task FindByNameAsync(string normalizedUserName, + public async Task FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); - return Users.FirstOrDefaultAsync(u => u.NormalizedUserName == normalizedUserName, cancellationToken); + using (var db = GetConnection()) + { + return await FindByNameAsync(db, normalizedUserName, cancellationToken); + } + } + + /// + protected virtual async Task FindByNameAsync(DataConnection db, string normalizedUserName, + CancellationToken cancellationToken) + { + return await db.GetTable() + .FirstOrDefaultAsync(u => u.NormalizedUserName == normalizedUserName, cancellationToken); } /// @@ -1099,7 +1172,7 @@ public void Dispose() /// should be canceled. /// /// The that represents the asynchronous operation. - public virtual Task AddLoginAsync(TUser user, UserLoginInfo login, + public async Task AddLoginAsync(TUser user, UserLoginInfo login, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); @@ -1109,9 +1182,17 @@ public virtual Task AddLoginAsync(TUser user, UserLoginInfo login, if (login == null) throw new ArgumentNullException(nameof(login)); - _factory.GetContext().Insert(CreateUserLogin(user, login)); + using (var db = GetConnection()) + { + await AddLoginAsync(db, user, login, cancellationToken); + } + } - return Task.FromResult(false); + /// + protected virtual async Task AddLoginAsync(DataConnection db, TUser user, UserLoginInfo login, + CancellationToken cancellationToken) + { + await Task.Run(() => db.Insert(CreateUserLogin(user, login)), cancellationToken); } /// @@ -1125,15 +1206,26 @@ public virtual Task AddLoginAsync(TUser user, UserLoginInfo login, /// should be canceled. /// /// The that represents the asynchronous operation. - public virtual async Task RemoveLoginAsync(TUser user, string loginProvider, string providerKey, + public async Task RemoveLoginAsync(TUser user, string loginProvider, string providerKey, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); if (user == null) throw new ArgumentNullException(nameof(user)); + + using (var db = GetConnection()) + { + await RemoveLoginAsync(db, user, loginProvider, providerKey, cancellationToken); + } + } + + /// + protected virtual async Task RemoveLoginAsync(DataConnection db, TUser user, string loginProvider, string providerKey, + CancellationToken cancellationToken) + { await Task.Run(() => - _factory.GetContext() + db .GetTable() .Delete( userLogin => @@ -1156,15 +1248,25 @@ await Task.Run(() => /// The for the asynchronous operation, containing a list of for the /// specified , if any. /// - public virtual async Task> GetLoginsAsync(TUser user, + public async Task> GetLoginsAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); if (user == null) throw new ArgumentNullException(nameof(user)); + + using (var db = GetConnection()) + { + return await GetLoginsAsync(db, user, cancellationToken); + } + } + + /// + protected virtual async Task> GetLoginsAsync(DataConnection db, TUser user, CancellationToken cancellationToken) + { var userId = user.Id; - return await _factory.GetContext() + return await db .GetTable() .Where(l => l.UserId.Equals(userId)) .Select(l => new UserLoginInfo(l.LoginProvider, l.ProviderKey, l.ProviderDisplayName)) @@ -1184,15 +1286,24 @@ public virtual async Task> GetLoginsAsync(TUser user, /// The for the asynchronous operation, containing the user, if any which matched the specified /// login provider and key. /// - public virtual async Task FindByLoginAsync(string loginProvider, string providerKey, + public async Task FindByLoginAsync(string loginProvider, string providerKey, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); - var dc = _factory.GetContext(); - var q = from ul in dc.GetTable() - join u in dc.GetTable() on ul.UserId equals u.Id + using (var db = GetConnection()) + { + return await FindByLoginAsync(db, loginProvider, providerKey, cancellationToken); + } + } + + /// + protected virtual async Task FindByLoginAsync(DataConnection db, string loginProvider, string providerKey, + CancellationToken cancellationToken) + { + var q = from ul in db.GetTable() + join u in db.GetTable() on ul.UserId equals u.Id where ul.LoginProvider == loginProvider && ul.ProviderKey == providerKey select u; @@ -1355,7 +1466,7 @@ public virtual Task SetPhoneNumberConfirmedAsync(TUser user, bool confirmed, /// should be canceled. /// /// The that represents the asynchronous operation. - public virtual async Task AddToRoleAsync(TUser user, string normalizedRoleName, + public async Task AddToRoleAsync(TUser user, string normalizedRoleName, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); @@ -1364,17 +1475,25 @@ public virtual async Task AddToRoleAsync(TUser user, string normalizedRoleName, throw new ArgumentNullException(nameof(user)); if (string.IsNullOrWhiteSpace(normalizedRoleName)) throw new ArgumentException(Resources.ValueCannotBeNullOrEmpty, nameof(normalizedRoleName)); + + using (var db = GetConnection()) + { + await AddToRoleAsync(db, user, normalizedRoleName, cancellationToken); + } + } + + /// + protected virtual async Task AddToRoleAsync(DataConnection db, TUser user, string normalizedRoleName, + CancellationToken cancellationToken) + { await Task.Run(() => { - using (var dc = _factory.GetConnection()) - { - var roleEntity = dc.GetTable() - .SingleOrDefault(r => r.NormalizedName == normalizedRoleName); - if (roleEntity == null) - throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, Resources.RoleNotFound, - normalizedRoleName)); - dc.TryInsertAndSetIdentity(CreateUserRole(user, roleEntity)); - } + var roleEntity = db.GetTable() + .SingleOrDefault(r => r.NormalizedName == normalizedRoleName); + if (roleEntity == null) + throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, Resources.RoleNotFound, + normalizedRoleName)); + db.TryInsertAndSetIdentity(CreateUserRole(user, roleEntity)); }, cancellationToken); } @@ -1388,7 +1507,7 @@ await Task.Run(() => /// should be canceled. /// /// The that represents the asynchronous operation. - public virtual async Task RemoveFromRoleAsync(TUser user, string normalizedRoleName, + public async Task RemoveFromRoleAsync(TUser user, string normalizedRoleName, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); @@ -1398,13 +1517,21 @@ public virtual async Task RemoveFromRoleAsync(TUser user, string normalizedRoleN if (string.IsNullOrWhiteSpace(normalizedRoleName)) throw new ArgumentException(Resources.ValueCannotBeNullOrEmpty, nameof(normalizedRoleName)); + using (var db = GetConnection()) + { + await RemoveFromRoleAsync(db, user, normalizedRoleName, cancellationToken); + } + } + + /// + protected virtual async Task RemoveFromRoleAsync(DataConnection db, TUser user, string normalizedRoleName, + CancellationToken cancellationToken) + { await Task.Run(() => { - var dc = _factory.GetContext(); - var q = - from ur in dc.GetTable() - join r in dc.GetTable() on ur.RoleId equals r.Id + from ur in db.GetTable() + join r in db.GetTable() on ur.RoleId equals r.Id where r.NormalizedName == normalizedRoleName && ur.UserId.Equals(user.Id) select ur; @@ -1422,17 +1549,26 @@ join r in dc.GetTable() on ur.RoleId equals r.Id /// should be canceled. /// /// A that contains the roles the user is a member of. - public virtual async Task> GetRolesAsync(TUser user, + public async Task> GetRolesAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); ThrowIfDisposed(); if (user == null) throw new ArgumentNullException(nameof(user)); + + using (var db = GetConnection()) + { + return await GetRolesAsync(db, user, cancellationToken); + } + } + + /// + protected virtual async Task> GetRolesAsync(DataConnection db, TUser user, CancellationToken cancellationToken) + { var userId = user.Id; - var dc = _factory.GetContext(); - var query = from userRole in dc.GetTable() - join role in dc.GetTable() on userRole.RoleId equals role.Id + var query = from userRole in db.GetTable() + join role in db.GetTable() on userRole.RoleId equals role.Id where userRole.UserId.Equals(userId) select role.Name; @@ -1453,7 +1589,7 @@ where userRole.UserId.Equals(userId) /// If the /// user is a member of the group the returned value with be true, otherwise it will be false. /// - public virtual async Task IsInRoleAsync(TUser user, string normalizedRoleName, + public async Task IsInRoleAsync(TUser user, string normalizedRoleName, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); @@ -1463,10 +1599,18 @@ public virtual async Task IsInRoleAsync(TUser user, string normalizedRoleN if (string.IsNullOrWhiteSpace(normalizedRoleName)) throw new ArgumentException(Resources.ValueCannotBeNullOrEmpty, nameof(normalizedRoleName)); - var dc = _factory.GetContext(); + using (var db = GetConnection()) + { + return await IsInRoleAsync(db, user, normalizedRoleName, cancellationToken); + } + } - var q = from ur in dc.GetTable() - join r in dc.GetTable() on ur.RoleId equals r.Id + /// + protected virtual async Task IsInRoleAsync(DataConnection db, TUser user, string normalizedRoleName, + CancellationToken cancellationToken) + { + var q = from ur in db.GetTable() + join r in db.GetTable() on ur.RoleId equals r.Id where r.NormalizedName == normalizedRoleName && ur.UserId.Equals(user.Id) select ur; @@ -1484,7 +1628,7 @@ join r in dc.GetTable() on ur.RoleId equals r.Id /// /// The contains a list of users, if any, that are in the specified role. /// - public virtual async Task> GetUsersInRoleAsync(string normalizedRoleName, + public async Task> GetUsersInRoleAsync(string normalizedRoleName, CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); @@ -1492,11 +1636,19 @@ public virtual async Task> GetUsersInRoleAsync(string normalizedRol if (string.IsNullOrEmpty(normalizedRoleName)) throw new ArgumentNullException(nameof(normalizedRoleName)); - var dc = _factory.GetContext(); + using (var db = GetConnection()) + { + return await GetUsersInRoleAsync(db, normalizedRoleName, cancellationToken); + } + } - var query = from userrole in dc.GetTable() - join user in Users on userrole.UserId equals user.Id - join role in dc.GetTable() on userrole.RoleId equals role.Id + /// + protected virtual async Task> GetUsersInRoleAsync(DataConnection db, string normalizedRoleName, + CancellationToken cancellationToken) + { + var query = from userrole in db.GetTable() + join user in db.GetTable() on userrole.UserId equals user.Id + join role in db.GetTable() on userrole.RoleId equals role.Id where role.NormalizedName == normalizedRoleName select user; @@ -1602,7 +1754,14 @@ public virtual Task GetTwoFactorEnabledAsync(TUser user, /// /// /// - protected abstract TUserRole CreateUserRole(TUser user, TRole role); + protected virtual TUserRole CreateUserRole(TUser user, TRole role) + { + return new TUserRole() + { + UserId = user.Id, + RoleId = role.Id + }; + } /// /// Create a new entity representing a user claim. @@ -1610,7 +1769,12 @@ public virtual Task GetTwoFactorEnabledAsync(TUser user, /// /// /// - protected abstract TUserClaim CreateUserClaim(TUser user, Claim claim); + protected virtual TUserClaim CreateUserClaim(TUser user, Claim claim) + { + var res = new TUserClaim() {UserId = user.Id}; + res.InitializeFromClaim(claim); + return res; + } /// /// Create a new entity representing a user login. @@ -1618,7 +1782,16 @@ public virtual Task GetTwoFactorEnabledAsync(TUser user, /// /// /// - protected abstract TUserLogin CreateUserLogin(TUser user, UserLoginInfo login); + protected virtual TUserLogin CreateUserLogin(TUser user, UserLoginInfo login) + { + return new TUserLogin() + { + UserId = user.Id, + LoginProvider = login.LoginProvider, + ProviderDisplayName = login.ProviderDisplayName, + ProviderKey = login.ProviderKey + }; + } /// /// Create a new entity representing a user token. @@ -1628,7 +1801,16 @@ public virtual Task GetTwoFactorEnabledAsync(TUser user, /// /// /// - protected abstract TUserToken CreateUserToken(TUser user, string loginProvider, string name, string value); + protected virtual TUserToken CreateUserToken(TUser user, string loginProvider, string name, string value) + { + return new TUserToken() + { + UserId = user.Id, + LoginProvider = loginProvider, + Name = name, + Value = value + }; + } /// /// Converts the provided to a strongly typed key object. diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/InMemoryEFUserStoreTest.cs b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/InMemoryEFUserStoreTest.cs index c8ce509e9..b153c92f5 100644 --- a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/InMemoryEFUserStoreTest.cs +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/InMemoryEFUserStoreTest.cs @@ -64,12 +64,12 @@ protected override TestConnectionFactory CreateTestContext() protected override void AddUserStore(IServiceCollection services, TestConnectionFactory context = null) { services.AddSingleton>( - new UserStore(context ?? CreateTestContext())); + new UserStore(context ?? CreateTestContext())); } protected override void AddRoleStore(IServiceCollection services, TestConnectionFactory context = null) { - var store = new RoleStore(context ?? CreateTestContext()); + var store = new RoleStore(context ?? CreateTestContext()); services.AddSingleton>(store); } diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/InMemoryStoreWithGenericsTest.cs b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/InMemoryStoreWithGenericsTest.cs index 6ded8ead0..f66937644 100644 --- a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/InMemoryStoreWithGenericsTest.cs +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/InMemoryStoreWithGenericsTest.cs @@ -220,8 +220,8 @@ public IdentityUserWithGenerics() } } - public class UserStoreWithGenerics : UserStore { public UserStoreWithGenerics(TestConnectionFactory factory, string loginContext) @@ -283,8 +283,7 @@ protected override IdentityUserTokenWithStuff CreateUserToken(IdentityUserWithGe } } - public class RoleStoreWithGenerics : RoleStore> + public class RoleStoreWithGenerics : RoleStore> { private string _loginContext; diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test.csproj b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test.csproj index e50d6c756..52af409a5 100644 --- a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test.csproj +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test.csproj @@ -27,7 +27,7 @@ - + diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/RoleStoreTest.cs b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/RoleStoreTest.cs index ab233efbc..c85077d11 100644 --- a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/RoleStoreTest.cs +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/RoleStoreTest.cs @@ -22,7 +22,7 @@ public RoleStoreTest(InMemoryStorage storage) private readonly InMemoryStorage _storage; - private IConnectionFactory GetConnectionFactory() + private IConnectionFactory GetConnectionFactory() { var connectionString = _storage.ConnectionString; @@ -37,7 +37,7 @@ public async Task CanCreateRoleWithSingletonManager() var services = TestIdentityFactory.CreateTestServices(); //services.AddEntityFrameworkInMemoryDatabase(); services.AddSingleton(GetConnectionFactory()); - services.AddTransient, RoleStore>(); + services.AddTransient, RoleStore>(); services.AddSingleton>(); var provider = services.BuildServiceProvider(); var manager = provider.GetRequiredService>(); @@ -69,7 +69,7 @@ public async Task CanUpdateRoleName() [Fact] public async Task RoleStoreMethodsThrowWhenDisposedTest() { - var store = new RoleStore(GetConnectionFactory()); + var store = new RoleStore(GetConnectionFactory()); store.Dispose(); await Assert.ThrowsAsync(async () => await store.FindByIdAsync(null)); await Assert.ThrowsAsync(async () => await store.FindByNameAsync(null)); @@ -85,8 +85,8 @@ public async Task RoleStoreMethodsThrowWhenDisposedTest() public async Task RoleStorePublicNullCheckTest() { Assert.Throws("factory", - () => new RoleStore(null)); - var store = new RoleStore(GetConnectionFactory()); + () => new RoleStore(null)); + var store = new RoleStore(GetConnectionFactory()); await Assert.ThrowsAsync("role", async () => await store.GetRoleIdAsync(null)); await Assert.ThrowsAsync("role", async () => await store.GetRoleNameAsync(null)); await Assert.ThrowsAsync("role", async () => await store.SetRoleNameAsync(null, null)); diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/TestIdentityFactory.cs b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/TestIdentityFactory.cs index 4e803dfe8..497989664 100644 --- a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/TestIdentityFactory.cs +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.InMemory.Test/TestIdentityFactory.cs @@ -33,10 +33,10 @@ public static IServiceCollection CreateTestServices() return services; } - public static RoleManager CreateRoleManager(IConnectionFactory factory) + public static RoleManager CreateRoleManager(IConnectionFactory factory) { var services = CreateTestServices(); - services.AddSingleton>(new RoleStore(factory)); + services.AddSingleton>(new RoleStore(factory)); return services.BuildServiceProvider().GetRequiredService>(); } //{ diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/DefaultPocoTest.cs b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/DefaultPocoTest.cs index bc2513c08..137a5826b 100644 --- a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/DefaultPocoTest.cs +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/DefaultPocoTest.cs @@ -26,11 +26,13 @@ static DefaultPocoTest() { MappingSchema.Default .GetFluentMappingBuilder() + .Entity() .HasPrimaryKey(_ => _.Id) .Property(_ => _.Id) .HasLength(255) .IsNullable(false) + .Entity() .HasPrimaryKey(_ => _.Id) .Property(_ => _.Id) diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test.csproj b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test.csproj index 2c2cc2774..b736bb97c 100644 --- a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test.csproj +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test.csproj @@ -35,7 +35,7 @@ - + diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/ProvideConfigutayionTest.cs b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/ProvideConfigutayionTest.cs new file mode 100644 index 000000000..21dd328d3 --- /dev/null +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/ProvideConfigutayionTest.cs @@ -0,0 +1,59 @@ +using LinqToDB.Identity; +using Microsoft.Extensions.DependencyInjection; +using Xunit; +// ReSharper disable InconsistentNaming + +namespace Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test +{ + public class ProvideConfigutayionTest + { + [Fact] + public void AddLinqToDBStores0() + { + var services = new ServiceCollection(); + services + .AddIdentity() + .AddLinqToDBStores(new DefaultConnectionFactory()); + + var sp = services.BuildServiceProvider(); + + Assert.NotNull(sp.GetService>()); + Assert.NotNull(sp.GetService>()); + } + + [Fact] + public void AddLinqToDBStores1() + { + var services = new ServiceCollection(); + services + .AddIdentity, IdentityRole>() + .AddLinqToDBStores(new DefaultConnectionFactory()); + + var sp = services.BuildServiceProvider(); + + Assert.NotNull(sp.GetService>>()); + Assert.NotNull(sp.GetService>>()); + } + + [Fact] + public void AddLinqToDBStores6() + { + var services = new ServiceCollection(); + services + .AddIdentity, IdentityRole>() + .AddLinqToDBStores< + decimal, + IdentityUserClaim, + IdentityUserRole, + IdentityUserLogin, + IdentityUserToken, + IdentityRoleClaim>(new DefaultConnectionFactory()); + + var sp = services.BuildServiceProvider(); + + Assert.NotNull(sp.GetService>>()); + Assert.NotNull(sp.GetService>>()); + } + + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/SqlStoreTestBase.cs b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/SqlStoreTestBase.cs index e49dbd1ab..099640519 100644 --- a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/SqlStoreTestBase.cs +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/SqlStoreTestBase.cs @@ -111,13 +111,13 @@ protected override TestConnectionFactory CreateTestContext() protected override void AddUserStore(IServiceCollection services, TestConnectionFactory context = null) { services.AddSingleton>( - new UserStore(CreateTestContext())); + new UserStore(CreateTestContext())); } protected override void AddRoleStore(IServiceCollection services, TestConnectionFactory context = null) { services.AddSingleton>( - new RoleStore(CreateTestContext())); + new RoleStore(CreateTestContext())); } protected override void SetUserPasswordHash(TUser user, string hashedPassword) diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreGuidKeyTest.cs b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreGuidKeyTest.cs index 693144b40..c25143a18 100644 --- a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreGuidKeyTest.cs +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreGuidKeyTest.cs @@ -45,16 +45,16 @@ protected override void AddRoleStore(IServiceCollection services, TestConnection services.AddSingleton>(new ApplicationRoleStore(context ?? CreateTestContext())); } - public class ApplicationUserStore : UserStore + public class ApplicationUserStore : UserStore { - public ApplicationUserStore(IConnectionFactory factory) : base(factory) + public ApplicationUserStore(IConnectionFactory factory) : base(factory) { } } - public class ApplicationRoleStore : RoleStore + public class ApplicationRoleStore : RoleStore { - public ApplicationRoleStore(IConnectionFactory factory) : base(factory) + public ApplicationRoleStore(IConnectionFactory factory) : base(factory) { } } diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreTest.cs b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreTest.cs index d1fdeb0fe..c90fcf9e0 100644 --- a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreTest.cs +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreTest.cs @@ -84,13 +84,13 @@ public void EnsureDatabase() protected override void AddUserStore(IServiceCollection services, TestConnectionFactory context = null) { services.AddSingleton>( - new UserStore(context ?? CreateTestContext())); + new UserStore(context ?? CreateTestContext())); } protected override void AddRoleStore(IServiceCollection services, TestConnectionFactory context = null) { services.AddSingleton>( - new RoleStore(context ?? CreateTestContext())); + new RoleStore(context ?? CreateTestContext())); } [ConditionalFact] @@ -327,7 +327,7 @@ protected override Expression> UserNameStartsWithPredic [Fact] public async Task SqlUserStoreMethodsThrowWhenDisposedTest() { - var store = new UserStore(CreateTestContext()); + var store = new UserStore(CreateTestContext()); store.Dispose(); await Assert.ThrowsAsync(async () => await store.AddClaimsAsync(null, null)); await Assert.ThrowsAsync(async () => await store.AddLoginAsync(null, null)); @@ -361,8 +361,8 @@ await Assert.ThrowsAsync( public async Task UserStorePublicNullCheckTest() { Assert.Throws("factory", - () => new UserStore(null)); - var store = new UserStore(CreateTestContext()); + () => new UserStore(null)); + var store = new UserStore(CreateTestContext()); await Assert.ThrowsAsync("user", async () => await store.GetUserIdAsync(null)); await Assert.ThrowsAsync("user", async () => await store.GetUserNameAsync(null)); await Assert.ThrowsAsync("user", async () => await store.SetUserNameAsync(null, null)); diff --git a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreWithGenericsTest.cs b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreWithGenericsTest.cs index 59f10b5ec..c128b3a8b 100644 --- a/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreWithGenericsTest.cs +++ b/test/Microsoft.AspNetCore.Identity.EntityFrameworkCore.Test/UserStoreWithGenericsTest.cs @@ -233,19 +233,16 @@ public IdentityUserWithGenerics() public class UserStoreWithGenerics : UserStore< - DataContext, - DataConnection, + string, IdentityUserWithGenerics, MyIdentityRole, - string, IdentityUserClaimWithIssuer, IdentityUserRoleWithDate, IdentityUserLoginWithContext, - IdentityUserTokenWithStuff, - IdentityRoleClaimWithIssuer> + IdentityUserTokenWithStuff> { - public UserStoreWithGenerics(IConnectionFactory fasctory, - string loginContext) : base(fasctory) + public UserStoreWithGenerics(IConnectionFactory factory, + string loginContext) : base(factory) { LoginContext = loginContext; } @@ -299,12 +296,11 @@ protected override IdentityUserTokenWithStuff CreateUserToken(IdentityUserWithGe } } - public class RoleStoreWithGenerics : RoleStore + public class RoleStoreWithGenerics : RoleStore { private string _loginContext; - public RoleStoreWithGenerics(IConnectionFactory factory, + public RoleStoreWithGenerics(IConnectionFactory factory, string loginContext) : base(factory) { _loginContext = loginContext; diff --git a/test/Shared/UserManagerTestBase.cs b/test/Shared/UserManagerTestBase.cs index 85045e334..18ea8a665 100644 --- a/test/Shared/UserManagerTestBase.cs +++ b/test/Shared/UserManagerTestBase.cs @@ -23,7 +23,7 @@ namespace Microsoft.AspNetCore.Identity.Test { - public class TestConnectionFactory : IConnectionFactory + public class TestConnectionFactory : IConnectionFactory { private static readonly Dictionary> _tables = new Dictionary>(); private readonly string _configuration; @@ -41,7 +41,7 @@ public TestConnectionFactory(IDataProvider provider, string configuration, strin _key = _configuration + "$$" + _connectionString; } - public DataContext GetContext() + public IDataContext GetContext() { return new DataContext(_provider, _connectionString); } @@ -182,7 +182,7 @@ protected virtual UserManager CreateManager(TestConnectionFactory context if (context == null) context = CreateTestContext(); - services.AddSingleton>(context); + services.AddSingleton(context); SetupIdentityServices(services, context); if (configureServices != null)