Skip to content

Commit caea5e4

Browse files
committed
Ensure ID Token is updated after refresh token (Reactive)
Closes gh-17188 Signed-off-by: Evgeniy Cheban <[email protected]>
1 parent ffd6e3c commit caea5e4

File tree

6 files changed

+782
-9
lines changed

6 files changed

+782
-9
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
/*
2+
* Copyright 2002-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.oauth2.client;
18+
19+
import java.time.Duration;
20+
import java.util.Collection;
21+
import java.util.HashSet;
22+
import java.util.List;
23+
import java.util.Map;
24+
import java.util.Set;
25+
26+
import reactor.core.publisher.Mono;
27+
28+
import org.springframework.security.core.Authentication;
29+
import org.springframework.security.core.GrantedAuthority;
30+
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
31+
import org.springframework.security.core.context.SecurityContext;
32+
import org.springframework.security.core.context.SecurityContextImpl;
33+
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
34+
import org.springframework.security.oauth2.client.oidc.authentication.ReactiveOidcIdTokenDecoderFactory;
35+
import org.springframework.security.oauth2.client.oidc.userinfo.OidcReactiveOAuth2UserService;
36+
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
37+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
38+
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
39+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
40+
import org.springframework.security.oauth2.core.OAuth2Error;
41+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
42+
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
43+
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
44+
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
45+
import org.springframework.security.oauth2.jwt.JwtException;
46+
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
47+
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
48+
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
49+
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
50+
import org.springframework.util.Assert;
51+
import org.springframework.util.StringUtils;
52+
import org.springframework.web.server.ServerWebExchange;
53+
54+
/**
55+
* A {@link ReactiveOAuth2AuthorizationSuccessHandler} that refreshes an {@link OidcUser}
56+
* in the {@link SecurityContext} if the refreshed {@link OidcIdToken} is valid according
57+
* to <a href=
58+
* "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse">OpenID
59+
* Connect Core 1.0 - Section 12.2 Successful Refresh Response</a>
60+
*
61+
* @author Evgeniy Cheban
62+
* @since 7.0
63+
*/
64+
public final class RefreshTokenReactiveOAuth2AuthorizationSuccessHandler
65+
implements ReactiveOAuth2AuthorizationSuccessHandler {
66+
67+
private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";
68+
69+
private static final String INVALID_NONCE_ERROR_CODE = "invalid_nonce";
70+
71+
private static final String REFRESH_TOKEN_RESPONSE_ERROR_URI = "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse";
72+
73+
// @formatter:off
74+
private static final Mono<ServerWebExchange> currentServerWebExchangeMono = Mono.deferContextual(Mono::just)
75+
.filter((c) -> c.hasKey(ServerWebExchange.class))
76+
.map((c) -> c.get(ServerWebExchange.class));
77+
// @formatter:on
78+
79+
private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
80+
81+
private ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new ReactiveOidcIdTokenDecoderFactory();
82+
83+
private ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = new OidcReactiveOAuth2UserService();
84+
85+
private GrantedAuthoritiesMapper authoritiesMapper = (authorities) -> authorities;
86+
87+
private Duration clockSkew = Duration.ofSeconds(60);
88+
89+
@Override
90+
public Mono<Void> onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal,
91+
Map<String, Object> attributes) {
92+
if (!(principal instanceof OAuth2AuthenticationToken authenticationToken)
93+
|| authenticationToken.getClass() != OAuth2AuthenticationToken.class) {
94+
// If the application customizes the authentication result, then a custom
95+
// handler should be provided.
96+
return Mono.empty();
97+
}
98+
// The current principal must be an OidcUser.
99+
if (!(authenticationToken.getPrincipal() instanceof OidcUser existingOidcUser)) {
100+
return Mono.empty();
101+
}
102+
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
103+
// The registrationId must match the one used to log in.
104+
if (!authenticationToken.getAuthorizedClientRegistrationId().equals(clientRegistration.getRegistrationId())) {
105+
return Mono.empty();
106+
}
107+
// Create, validate OidcIdToken and refresh OidcUser in the SecurityContext.
108+
return Mono.zip(serverWebExchange(attributes), accessTokenResponse(attributes)).flatMap((t2) -> {
109+
ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
110+
Map<String, Object> additionalParameters = t2.getT2().getAdditionalParameters();
111+
return jwtDecoder.decode((String) additionalParameters.get(OidcParameterNames.ID_TOKEN))
112+
.onErrorMap(JwtException.class, (ex) -> {
113+
OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(),
114+
null);
115+
return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);
116+
})
117+
.map((jwt) -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(),
118+
jwt.getClaims()))
119+
.doOnNext((idToken) -> validateIdToken(existingOidcUser, idToken))
120+
.flatMap((idToken) -> {
121+
OidcUserRequest userRequest = new OidcUserRequest(clientRegistration,
122+
authorizedClient.getAccessToken(), idToken);
123+
return this.userService.loadUser(userRequest);
124+
})
125+
.flatMap((oidcUser) -> refreshSecurityContext(t2.getT1(), clientRegistration, authenticationToken,
126+
oidcUser));
127+
});
128+
}
129+
130+
private Mono<ServerWebExchange> serverWebExchange(Map<String, Object> attributes) {
131+
if (attributes.get(ServerWebExchange.class.getName()) instanceof ServerWebExchange exchange) {
132+
return Mono.just(exchange);
133+
}
134+
return currentServerWebExchangeMono;
135+
}
136+
137+
private Mono<OAuth2AccessTokenResponse> accessTokenResponse(Map<String, Object> attributes) {
138+
if (attributes.get(OAuth2AccessTokenResponse.class.getName()) instanceof OAuth2AccessTokenResponse response) {
139+
return Mono.just(response);
140+
}
141+
return Mono.empty();
142+
}
143+
144+
private void validateIdToken(OidcUser existingOidcUser, OidcIdToken idToken) {
145+
// OpenID Connect Core 1.0 - Section 12.2 Successful Refresh Response
146+
// If an ID Token is returned as a result of a token refresh request, the
147+
// following requirements apply:
148+
// its iss Claim Value MUST be the same as in the ID Token issued when the
149+
// original authentication occurred,
150+
validateIssuer(existingOidcUser, idToken);
151+
// its sub Claim Value MUST be the same as in the ID Token issued when the
152+
// original authentication occurred,
153+
validateSubject(existingOidcUser, idToken);
154+
// its iat Claim MUST represent the time that the new ID Token is issued,
155+
validateIssuedAt(existingOidcUser, idToken);
156+
// its aud Claim Value MUST be the same as in the ID Token issued when the
157+
// original authentication occurred,
158+
validateAudience(existingOidcUser, idToken);
159+
// if the ID Token contains an auth_time Claim, its value MUST represent the time
160+
// of the original authentication - not the time that the new ID token is issued,
161+
validateAuthenticatedAt(existingOidcUser, idToken);
162+
// it SHOULD NOT have a nonce Claim, even when the ID Token issued at the time of
163+
// the original authentication contained nonce; however, if it is present, its
164+
// value MUST be the same as in the ID Token issued at the time of the original
165+
// authentication,
166+
validateNonce(existingOidcUser, idToken);
167+
}
168+
169+
private void validateIssuer(OidcUser existingOidcUser, OidcIdToken idToken) {
170+
if (!idToken.getIssuer().toString().equals(existingOidcUser.getIdToken().getIssuer().toString())) {
171+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issuer",
172+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
173+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
174+
}
175+
}
176+
177+
private void validateSubject(OidcUser existingOidcUser, OidcIdToken idToken) {
178+
if (!idToken.getSubject().equals(existingOidcUser.getIdToken().getSubject())) {
179+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid subject",
180+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
181+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
182+
}
183+
}
184+
185+
private void validateIssuedAt(OidcUser existingOidcUser, OidcIdToken idToken) {
186+
if (!idToken.getIssuedAt().isAfter(existingOidcUser.getIdToken().getIssuedAt().minus(this.clockSkew))) {
187+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issued at time",
188+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
189+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
190+
}
191+
}
192+
193+
private void validateAudience(OidcUser existingOidcUser, OidcIdToken idToken) {
194+
if (!isValidAudience(existingOidcUser, idToken)) {
195+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid audience",
196+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
197+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
198+
}
199+
}
200+
201+
private boolean isValidAudience(OidcUser existingOidcUser, OidcIdToken idToken) {
202+
List<String> idTokenAudiences = idToken.getAudience();
203+
Set<String> oidcUserAudiences = new HashSet<>(existingOidcUser.getIdToken().getAudience());
204+
if (idTokenAudiences.size() != oidcUserAudiences.size()) {
205+
return false;
206+
}
207+
for (String audience : idTokenAudiences) {
208+
if (!oidcUserAudiences.contains(audience)) {
209+
return false;
210+
}
211+
}
212+
return true;
213+
}
214+
215+
private void validateAuthenticatedAt(OidcUser existingOidcUser, OidcIdToken idToken) {
216+
if (idToken.getAuthenticatedAt() == null) {
217+
return;
218+
}
219+
if (!idToken.getAuthenticatedAt().equals(existingOidcUser.getIdToken().getAuthenticatedAt())) {
220+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid authenticated at time",
221+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
222+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
223+
}
224+
}
225+
226+
private void validateNonce(OidcUser existingOidcUser, OidcIdToken idToken) {
227+
if (!StringUtils.hasText(idToken.getNonce())) {
228+
return;
229+
}
230+
if (!idToken.getNonce().equals(existingOidcUser.getIdToken().getNonce())) {
231+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE, "Invalid nonce",
232+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
233+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
234+
}
235+
}
236+
237+
private Mono<Void> refreshSecurityContext(ServerWebExchange exchange, ClientRegistration clientRegistration,
238+
OAuth2AuthenticationToken authenticationToken, OidcUser oidcUser) {
239+
Collection<? extends GrantedAuthority> mappedAuthorities = this.authoritiesMapper
240+
.mapAuthorities(oidcUser.getAuthorities());
241+
OAuth2AuthenticationToken authenticationResult = new OAuth2AuthenticationToken(oidcUser, mappedAuthorities,
242+
clientRegistration.getRegistrationId());
243+
authenticationResult.setDetails(authenticationToken.getDetails());
244+
SecurityContextImpl securityContext = new SecurityContextImpl(authenticationResult);
245+
return this.serverSecurityContextRepository.save(exchange, securityContext);
246+
}
247+
248+
/**
249+
* Sets a {@link ServerSecurityContextRepository} to use for refreshing a
250+
* {@link SecurityContext}, defaults to
251+
* {@link WebSessionServerSecurityContextRepository}.
252+
* @param serverSecurityContextRepository the {@link ServerSecurityContextRepository}
253+
* to use
254+
*/
255+
public void setServerSecurityContextRepository(ServerSecurityContextRepository serverSecurityContextRepository) {
256+
Assert.notNull(serverSecurityContextRepository, "serverSecurityContextRepository cannot be null");
257+
this.serverSecurityContextRepository = serverSecurityContextRepository;
258+
}
259+
260+
/**
261+
* Sets a {@link ReactiveJwtDecoderFactory} to use for decoding refreshed oidc
262+
* id-token, defaults to {@link ReactiveOidcIdTokenDecoderFactory}.
263+
* @param jwtDecoderFactory the {@link ReactiveJwtDecoderFactory} to use
264+
*/
265+
public void setJwtDecoderFactory(ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {
266+
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
267+
this.jwtDecoderFactory = jwtDecoderFactory;
268+
}
269+
270+
/**
271+
* Sets a {@link GrantedAuthoritiesMapper} to use for mapping
272+
* {@link GrantedAuthority}s, defaults to no-op implementation.
273+
* @param authoritiesMapper the {@link GrantedAuthoritiesMapper} to use
274+
*/
275+
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
276+
Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null");
277+
this.authoritiesMapper = authoritiesMapper;
278+
}
279+
280+
/**
281+
* Sets a {@link ReactiveOAuth2UserService} to use for loading an {@link OidcUser}
282+
* from refreshed oidc id-token, defaults to {@link OidcReactiveOAuth2UserService}.
283+
* @param userService the {@link ReactiveOAuth2UserService} to use
284+
*/
285+
public void setUserService(ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService) {
286+
Assert.notNull(userService, "userService cannot be null");
287+
this.userService = userService;
288+
}
289+
290+
/**
291+
* Sets the maximum acceptable clock skew, which is used when checking the
292+
* {@link OidcIdToken#getIssuedAt()} to match the existing
293+
* {@link OidcUser#getIdToken()}'s issuedAt time, defaults to 60 seconds.
294+
* @param clockSkew the maximum acceptable clock skew to use
295+
*/
296+
public void setClockSkew(Duration clockSkew) {
297+
Assert.notNull(clockSkew, "clockSkew cannot be null");
298+
Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0");
299+
this.clockSkew = clockSkew;
300+
}
301+
302+
}

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -21,7 +21,9 @@
2121
import java.time.Instant;
2222
import java.util.Arrays;
2323
import java.util.Collections;
24+
import java.util.HashMap;
2425
import java.util.HashSet;
26+
import java.util.Map;
2527
import java.util.Set;
2628

2729
import reactor.core.publisher.Mono;
@@ -33,13 +35,15 @@
3335
import org.springframework.security.oauth2.core.AuthorizationGrantType;
3436
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
3537
import org.springframework.security.oauth2.core.OAuth2Token;
38+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
3639
import org.springframework.util.Assert;
3740

3841
/**
3942
* An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} for the
4043
* {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant.
4144
*
4245
* @author Joe Grandja
46+
* @author Evgeniy Cheban
4347
* @since 5.2
4448
* @see ReactiveOAuth2AuthorizedClientProvider
4549
* @see WebClientReactiveRefreshTokenTokenResponseClient
@@ -49,6 +53,8 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider
4953

5054
private ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = new WebClientReactiveRefreshTokenTokenResponseClient();
5155

56+
private ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();
57+
5258
private Duration clockSkew = Duration.ofSeconds(60);
5359

5460
private Clock clock = Clock.systemUTC();
@@ -96,8 +102,16 @@ public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context
96102
.flatMap(this.accessTokenResponseClient::getTokenResponse)
97103
.onErrorMap(OAuth2AuthorizationException.class,
98104
(e) -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), e))
99-
.map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
100-
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()));
105+
.flatMap((tokenResponse) -> {
106+
OAuth2AuthorizedClient refreshedAuthorizedClient = new OAuth2AuthorizedClient(clientRegistration,
107+
context.getPrincipal().getName(), tokenResponse.getAccessToken(),
108+
tokenResponse.getRefreshToken());
109+
Map<String, Object> attributes = new HashMap<>(context.getAttributes());
110+
attributes.put(OAuth2AccessTokenResponse.class.getName(), tokenResponse);
111+
return this.refreshTokenSuccessHandler
112+
.onAuthorizationSuccess(refreshedAuthorizedClient, context.getPrincipal(), attributes)
113+
.then(Mono.just(refreshedAuthorizedClient));
114+
});
101115
}
102116

103117
private boolean hasTokenExpired(OAuth2Token token) {
@@ -116,6 +130,20 @@ public void setAccessTokenResponseClient(
116130
this.accessTokenResponseClient = accessTokenResponseClient;
117131
}
118132

133+
/**
134+
* Sets a {@link ReactiveOAuth2AuthorizationSuccessHandler} that is called after the
135+
* client is re-authorized, defaults to
136+
* {@link RefreshTokenReactiveOAuth2AuthorizationSuccessHandler}.
137+
* @param refreshTokenSuccessHandler the
138+
* {@link ReactiveOAuth2AuthorizationSuccessHandler} to use, defaults to
139+
* {@link RefreshTokenReactiveOAuth2AuthorizationSuccessHandler}
140+
* @since 7.0
141+
*/
142+
public void setRefreshTokenSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler) {
143+
Assert.notNull(refreshTokenSuccessHandler, "refreshTokenSuccessHandler cannot be null");
144+
this.refreshTokenSuccessHandler = refreshTokenSuccessHandler;
145+
}
146+
119147
/**
120148
* Sets the maximum acceptable clock skew, which is used when checking the
121149
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is

0 commit comments

Comments
 (0)