diff --git a/src/IdentityServer4/Models/Contexts/ProfileDataRequestContext.cs b/src/IdentityServer4/Models/Contexts/ProfileDataRequestContext.cs
index 893ba555ad..ec0c317ff6 100644
--- a/src/IdentityServer4/Models/Contexts/ProfileDataRequestContext.cs
+++ b/src/IdentityServer4/Models/Contexts/ProfileDataRequestContext.cs
@@ -29,7 +29,10 @@ public ProfileDataRequestContext()
/// The requested claim types.
public ProfileDataRequestContext(ClaimsPrincipal subject, Client client, string caller, IEnumerable requestedClaimTypes)
{
- if (requestedClaimTypes.IsNullOrEmpty()) throw new ArgumentException("No claim types requested", nameof(requestedClaimTypes));
+ if (subject == null) throw new ArgumentNullException(nameof(subject));
+ if (client == null) throw new ArgumentNullException(nameof(client));
+ if (caller == null) throw new ArgumentNullException(nameof(caller));
+ if (requestedClaimTypes == null) throw new ArgumentNullException(nameof(requestedClaimTypes));
Subject = subject;
Client = client;
diff --git a/src/IdentityServer4/Services/DefaultClaimsService.cs b/src/IdentityServer4/Services/DefaultClaimsService.cs
index 4bd67f1310..a5075f2a5d 100644
--- a/src/IdentityServer4/Services/DefaultClaimsService.cs
+++ b/src/IdentityServer4/Services/DefaultClaimsService.cs
@@ -61,31 +61,31 @@ public virtual async Task> GetIdentityTokenClaimsAsync(Claims
// fetch all identity claims that need to go into the id token
if (includeAllIdentityClaims || client.AlwaysIncludeUserClaimsInIdToken)
{
- var additionalClaims = new List();
+ var additionalClaimTypes = new List();
foreach (var identityResource in resources.IdentityResources)
{
foreach (var userClaim in identityResource.UserClaims)
{
- additionalClaims.Add(userClaim);
+ additionalClaimTypes.Add(userClaim);
}
}
- if (additionalClaims.Count > 0)
- {
- var context = new ProfileDataRequestContext(
- subject,
- client,
- IdentityServerConstants.ProfileDataCallers.ClaimsProviderIdentityToken,
- additionalClaims);
+ // filter so we don't ask for claim types that we will eventually filter out
+ additionalClaimTypes = FilterRequestedClaimTypes(additionalClaimTypes).ToList();
- await Profile.GetProfileDataAsync(context);
+ var context = new ProfileDataRequestContext(
+ subject,
+ client,
+ IdentityServerConstants.ProfileDataCallers.ClaimsProviderIdentityToken,
+ additionalClaimTypes);
- var claims = FilterProtocolClaims(context.IssuedClaims);
- if (claims != null)
- {
- outputClaims.AddRange(claims);
- }
+ await Profile.GetProfileDataAsync(context);
+
+ var claims = FilterProtocolClaims(context.IssuedClaims);
+ if (claims != null)
+ {
+ outputClaims.AddRange(claims);
}
}
@@ -155,7 +155,7 @@ public virtual async Task> GetAccessTokenClaimsAsync(ClaimsPr
outputClaims.AddRange(GetOptionalClaims(subject));
// fetch all resource claims that need to go into the access token
- var additionalClaims = new List();
+ var additionalClaimTypes = new List();
foreach (var api in resources.ApiResources)
{
// add claims configured on api resource
@@ -163,7 +163,7 @@ public virtual async Task> GetAccessTokenClaimsAsync(ClaimsPr
{
foreach (var claim in api.UserClaims)
{
- additionalClaims.Add(claim);
+ additionalClaimTypes.Add(claim);
}
}
@@ -174,27 +174,27 @@ public virtual async Task> GetAccessTokenClaimsAsync(ClaimsPr
{
foreach (var claim in scope.UserClaims)
{
- additionalClaims.Add(claim);
+ additionalClaimTypes.Add(claim);
}
}
}
}
- if (additionalClaims.Count > 0)
- {
- var context = new ProfileDataRequestContext(
- subject,
- client,
- IdentityServerConstants.ProfileDataCallers.ClaimsProviderAccessToken,
- additionalClaims.Distinct());
+ // filter so we don't ask for claim types that we will eventually filter out
+ additionalClaimTypes = FilterRequestedClaimTypes(additionalClaimTypes).ToList();
- await Profile.GetProfileDataAsync(context);
+ var context = new ProfileDataRequestContext(
+ subject,
+ client,
+ IdentityServerConstants.ProfileDataCallers.ClaimsProviderAccessToken,
+ additionalClaimTypes.Distinct());
- var claims = FilterProtocolClaims(context.IssuedClaims);
- if (claims != null)
- {
- outputClaims.AddRange(claims);
- }
+ await Profile.GetProfileDataAsync(context);
+
+ var claims = FilterProtocolClaims(context.IssuedClaims);
+ if (claims != null)
+ {
+ outputClaims.AddRange(claims);
}
}
@@ -246,9 +246,19 @@ protected virtual IEnumerable FilterProtocolClaims(IEnumerable cla
if (claimsToFilter.Any())
{
var types = claimsToFilter.Select(x => x.Type);
- _logger.LogInformation("Claim types from profile service that were filtered: {claimTypes}", types);
+ _logger.LogDebug("Claim types from profile service that were filtered: {claimTypes}", types);
}
return claims.Except(claimsToFilter);
}
+
+ ///
+ /// Filters out protocol claims like amr, nonce etc..
+ ///
+ /// The claim types.
+ protected virtual IEnumerable FilterRequestedClaimTypes(IEnumerable claimTypes)
+ {
+ var claimTypesToFilter = claimTypes.Where(x => Constants.Filters.ClaimsServiceFilterClaimTypes.Contains(x));
+ return claimTypes.Except(claimTypesToFilter);
+ }
}
}
\ No newline at end of file
diff --git a/src/IdentityServer4/Services/DefaultProfileService.cs b/src/IdentityServer4/Services/DefaultProfileService.cs
index b946b000e4..bb3b5b5d56 100644
--- a/src/IdentityServer4/Services/DefaultProfileService.cs
+++ b/src/IdentityServer4/Services/DefaultProfileService.cs
@@ -4,6 +4,9 @@
using System.Threading.Tasks;
using IdentityServer4.Models;
+using IdentityServer4.Extensions;
+using Microsoft.Extensions.Logging;
+using System.Linq;
namespace IdentityServer4.Services
{
@@ -13,6 +16,13 @@ namespace IdentityServer4.Services
///
public class DefaultProfileService : IProfileService
{
+ private readonly ILogger _logger;
+
+ public DefaultProfileService(ILogger logger)
+ {
+ _logger = logger;
+ }
+
///
/// This method is called whenever claims about the user are requested (e.g. during token creation or via the userinfo endpoint)
///
@@ -20,7 +30,17 @@ public class DefaultProfileService : IProfileService
///
public Task GetProfileDataAsync(ProfileDataRequestContext context)
{
- context.AddFilteredClaims(context.Subject.Claims);
+ _logger.LogDebug("Get profile called for {subject} from {client} with {claimTypes} because {caller}",
+ context.Subject.GetSubjectId(),
+ context.Client.ClientName,
+ context.RequestedClaimTypes,
+ context.Caller);
+
+ if (context.RequestedClaimTypes.Any())
+ {
+ context.AddFilteredClaims(context.Subject.Claims);
+ }
+
return Task.FromResult(0);
}
diff --git a/src/IdentityServer4/Test/TestUserProfileService.cs b/src/IdentityServer4/Test/TestUserProfileService.cs
index bcb190eb63..d090f1d558 100644
--- a/src/IdentityServer4/Test/TestUserProfileService.cs
+++ b/src/IdentityServer4/Test/TestUserProfileService.cs
@@ -5,24 +5,36 @@
using IdentityServer4.Extensions;
using IdentityServer4.Models;
using IdentityServer4.Services;
+using Microsoft.Extensions.Logging;
+using System.Linq;
using System.Threading.Tasks;
namespace IdentityServer4.Test
{
public class TestUserProfileService : IProfileService
{
+ private readonly ILogger _logger;
private readonly TestUserStore _users;
- public TestUserProfileService(TestUserStore users)
+ public TestUserProfileService(TestUserStore users, ILogger logger)
{
_users = users;
+ _logger = logger;
}
public Task GetProfileDataAsync(ProfileDataRequestContext context)
{
- var user = _users.FindBySubjectId(context.Subject.GetSubjectId());
-
- context.AddFilteredClaims(user.Claims);
+ _logger.LogDebug("Get profile called for {subject} from {client} with {claimTypes} because {caller}",
+ context.Subject.GetSubjectId(),
+ context.Client.ClientName,
+ context.RequestedClaimTypes,
+ context.Caller);
+
+ if (context.RequestedClaimTypes.Any())
+ {
+ var user = _users.FindBySubjectId(context.Subject.GetSubjectId());
+ context.AddFilteredClaims(user.Claims);
+ }
return Task.FromResult(0);
}