diff --git a/src/IdentityServer4/Models/Contexts/ProfileDataRequestContext.cs b/src/IdentityServer4/Models/Contexts/ProfileDataRequestContext.cs index 893ba555ad..ec0c317ff6 100644 --- a/src/IdentityServer4/Models/Contexts/ProfileDataRequestContext.cs +++ b/src/IdentityServer4/Models/Contexts/ProfileDataRequestContext.cs @@ -29,7 +29,10 @@ public ProfileDataRequestContext() /// The requested claim types. public ProfileDataRequestContext(ClaimsPrincipal subject, Client client, string caller, IEnumerable requestedClaimTypes) { - if (requestedClaimTypes.IsNullOrEmpty()) throw new ArgumentException("No claim types requested", nameof(requestedClaimTypes)); + if (subject == null) throw new ArgumentNullException(nameof(subject)); + if (client == null) throw new ArgumentNullException(nameof(client)); + if (caller == null) throw new ArgumentNullException(nameof(caller)); + if (requestedClaimTypes == null) throw new ArgumentNullException(nameof(requestedClaimTypes)); Subject = subject; Client = client; diff --git a/src/IdentityServer4/Services/DefaultClaimsService.cs b/src/IdentityServer4/Services/DefaultClaimsService.cs index 4bd67f1310..a5075f2a5d 100644 --- a/src/IdentityServer4/Services/DefaultClaimsService.cs +++ b/src/IdentityServer4/Services/DefaultClaimsService.cs @@ -61,31 +61,31 @@ public virtual async Task> GetIdentityTokenClaimsAsync(Claims // fetch all identity claims that need to go into the id token if (includeAllIdentityClaims || client.AlwaysIncludeUserClaimsInIdToken) { - var additionalClaims = new List(); + var additionalClaimTypes = new List(); foreach (var identityResource in resources.IdentityResources) { foreach (var userClaim in identityResource.UserClaims) { - additionalClaims.Add(userClaim); + additionalClaimTypes.Add(userClaim); } } - if (additionalClaims.Count > 0) - { - var context = new ProfileDataRequestContext( - subject, - client, - IdentityServerConstants.ProfileDataCallers.ClaimsProviderIdentityToken, - additionalClaims); + // filter so we don't ask for claim types that we will eventually filter out + additionalClaimTypes = FilterRequestedClaimTypes(additionalClaimTypes).ToList(); - await Profile.GetProfileDataAsync(context); + var context = new ProfileDataRequestContext( + subject, + client, + IdentityServerConstants.ProfileDataCallers.ClaimsProviderIdentityToken, + additionalClaimTypes); - var claims = FilterProtocolClaims(context.IssuedClaims); - if (claims != null) - { - outputClaims.AddRange(claims); - } + await Profile.GetProfileDataAsync(context); + + var claims = FilterProtocolClaims(context.IssuedClaims); + if (claims != null) + { + outputClaims.AddRange(claims); } } @@ -155,7 +155,7 @@ public virtual async Task> GetAccessTokenClaimsAsync(ClaimsPr outputClaims.AddRange(GetOptionalClaims(subject)); // fetch all resource claims that need to go into the access token - var additionalClaims = new List(); + var additionalClaimTypes = new List(); foreach (var api in resources.ApiResources) { // add claims configured on api resource @@ -163,7 +163,7 @@ public virtual async Task> GetAccessTokenClaimsAsync(ClaimsPr { foreach (var claim in api.UserClaims) { - additionalClaims.Add(claim); + additionalClaimTypes.Add(claim); } } @@ -174,27 +174,27 @@ public virtual async Task> GetAccessTokenClaimsAsync(ClaimsPr { foreach (var claim in scope.UserClaims) { - additionalClaims.Add(claim); + additionalClaimTypes.Add(claim); } } } } - if (additionalClaims.Count > 0) - { - var context = new ProfileDataRequestContext( - subject, - client, - IdentityServerConstants.ProfileDataCallers.ClaimsProviderAccessToken, - additionalClaims.Distinct()); + // filter so we don't ask for claim types that we will eventually filter out + additionalClaimTypes = FilterRequestedClaimTypes(additionalClaimTypes).ToList(); - await Profile.GetProfileDataAsync(context); + var context = new ProfileDataRequestContext( + subject, + client, + IdentityServerConstants.ProfileDataCallers.ClaimsProviderAccessToken, + additionalClaimTypes.Distinct()); - var claims = FilterProtocolClaims(context.IssuedClaims); - if (claims != null) - { - outputClaims.AddRange(claims); - } + await Profile.GetProfileDataAsync(context); + + var claims = FilterProtocolClaims(context.IssuedClaims); + if (claims != null) + { + outputClaims.AddRange(claims); } } @@ -246,9 +246,19 @@ protected virtual IEnumerable FilterProtocolClaims(IEnumerable cla if (claimsToFilter.Any()) { var types = claimsToFilter.Select(x => x.Type); - _logger.LogInformation("Claim types from profile service that were filtered: {claimTypes}", types); + _logger.LogDebug("Claim types from profile service that were filtered: {claimTypes}", types); } return claims.Except(claimsToFilter); } + + /// + /// Filters out protocol claims like amr, nonce etc.. + /// + /// The claim types. + protected virtual IEnumerable FilterRequestedClaimTypes(IEnumerable claimTypes) + { + var claimTypesToFilter = claimTypes.Where(x => Constants.Filters.ClaimsServiceFilterClaimTypes.Contains(x)); + return claimTypes.Except(claimTypesToFilter); + } } } \ No newline at end of file diff --git a/src/IdentityServer4/Services/DefaultProfileService.cs b/src/IdentityServer4/Services/DefaultProfileService.cs index b946b000e4..bb3b5b5d56 100644 --- a/src/IdentityServer4/Services/DefaultProfileService.cs +++ b/src/IdentityServer4/Services/DefaultProfileService.cs @@ -4,6 +4,9 @@ using System.Threading.Tasks; using IdentityServer4.Models; +using IdentityServer4.Extensions; +using Microsoft.Extensions.Logging; +using System.Linq; namespace IdentityServer4.Services { @@ -13,6 +16,13 @@ namespace IdentityServer4.Services /// public class DefaultProfileService : IProfileService { + private readonly ILogger _logger; + + public DefaultProfileService(ILogger logger) + { + _logger = logger; + } + /// /// This method is called whenever claims about the user are requested (e.g. during token creation or via the userinfo endpoint) /// @@ -20,7 +30,17 @@ public class DefaultProfileService : IProfileService /// public Task GetProfileDataAsync(ProfileDataRequestContext context) { - context.AddFilteredClaims(context.Subject.Claims); + _logger.LogDebug("Get profile called for {subject} from {client} with {claimTypes} because {caller}", + context.Subject.GetSubjectId(), + context.Client.ClientName, + context.RequestedClaimTypes, + context.Caller); + + if (context.RequestedClaimTypes.Any()) + { + context.AddFilteredClaims(context.Subject.Claims); + } + return Task.FromResult(0); } diff --git a/src/IdentityServer4/Test/TestUserProfileService.cs b/src/IdentityServer4/Test/TestUserProfileService.cs index bcb190eb63..d090f1d558 100644 --- a/src/IdentityServer4/Test/TestUserProfileService.cs +++ b/src/IdentityServer4/Test/TestUserProfileService.cs @@ -5,24 +5,36 @@ using IdentityServer4.Extensions; using IdentityServer4.Models; using IdentityServer4.Services; +using Microsoft.Extensions.Logging; +using System.Linq; using System.Threading.Tasks; namespace IdentityServer4.Test { public class TestUserProfileService : IProfileService { + private readonly ILogger _logger; private readonly TestUserStore _users; - public TestUserProfileService(TestUserStore users) + public TestUserProfileService(TestUserStore users, ILogger logger) { _users = users; + _logger = logger; } public Task GetProfileDataAsync(ProfileDataRequestContext context) { - var user = _users.FindBySubjectId(context.Subject.GetSubjectId()); - - context.AddFilteredClaims(user.Claims); + _logger.LogDebug("Get profile called for {subject} from {client} with {claimTypes} because {caller}", + context.Subject.GetSubjectId(), + context.Client.ClientName, + context.RequestedClaimTypes, + context.Caller); + + if (context.RequestedClaimTypes.Any()) + { + var user = _users.FindBySubjectId(context.Subject.GetSubjectId()); + context.AddFilteredClaims(user.Claims); + } return Task.FromResult(0); }