001/*
002 * oauth2-oidc-sdk
003 *
004 * Copyright 2012-2021, Connect2id Ltd and contributors.
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
007 * this file except in compliance with the License. You may obtain a copy of the
008 * License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software distributed
013 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
014 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
015 * specific language governing permissions and limitations under the License.
016 */
017
018package com.nimbusds.openid.connect.sdk.federation.utils;
019
020
021import java.security.PublicKey;
022import java.util.List;
023
024import com.nimbusds.jose.*;
025import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
026import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory;
027import com.nimbusds.jose.jwk.*;
028import com.nimbusds.jose.proc.BadJOSEException;
029import com.nimbusds.jose.proc.JWSVerifierFactory;
030import com.nimbusds.jose.util.Base64URL;
031import com.nimbusds.jwt.JWTClaimsSet;
032import com.nimbusds.jwt.SignedJWT;
033import com.nimbusds.jwt.proc.JWTClaimsSetVerifier;
034import com.nimbusds.oauth2.sdk.ParseException;
035
036
037/**
038 * Federation JWT utilities.
039 */
040public class JWTUtils {
041        
042        
043        /**
044         * Resolves the signing JWS algorithm for the specified JWK.
045         *
046         * @param jwk The JWK. Must not be {@code null}.
047         *
048         * @return The JWS algorithm.
049         *
050         * @throws JOSEException If the resolution failed.
051         */
052        public static JWSAlgorithm resolveSigningAlgorithm(final JWK jwk)
053                throws JOSEException {
054                
055                KeyType jwkType = jwk.getKeyType();
056                
057                if (KeyType.RSA.equals(jwkType)) {
058                        if (jwk.getAlgorithm() != null) {
059                                return new JWSAlgorithm(jwk.getAlgorithm().getName());
060                        } else {
061                                return JWSAlgorithm.RS256; // assume RS256 as default
062                        }
063                } else if (KeyType.EC.equals(jwkType)) {
064                        ECKey ecJWK = jwk.toECKey();
065                        if (jwk.getAlgorithm() != null) {
066                                return new JWSAlgorithm(ecJWK.getAlgorithm().getName());
067                        } else {
068                                if (Curve.P_256.equals(ecJWK.getCurve())) {
069                                        return JWSAlgorithm.ES256;
070                                } else if (Curve.P_384.equals(ecJWK.getCurve())) {
071                                        return JWSAlgorithm.ES384;
072                                } else if (Curve.P_521.equals(ecJWK.getCurve())) {
073                                        return JWSAlgorithm.ES512;
074                                } else if (Curve.SECP256K1.equals(ecJWK.getCurve())) {
075                                        return JWSAlgorithm.ES256K;
076                                } else {
077                                        throw new JOSEException("Unsupported ECDSA curve: " + ecJWK.getCurve());
078                                }
079                        }
080                } else if (KeyType.OKP.equals(jwkType)){
081                        OctetKeyPair okp = jwk.toOctetKeyPair();
082                        if (Curve.Ed25519.equals(okp.getCurve())) {
083                                return JWSAlgorithm.EdDSA;
084                        } else {
085                                throw new JOSEException("Unsupported EdDSA curve: " + okp.getCurve());
086                        }
087                } else {
088                        throw new JOSEException("Unsupported JWK type: " + jwkType);
089                }
090        }
091        
092        
093        /**
094         * Signs the specified JWT claims set.
095         *
096         * @param signingJWK The signing JWK. Must not be {@code null}.
097         * @param alg        The JWS algorithm. Must not be {@code null}.
098         * @param type       The JOSE object type, {@code null} if not
099         *                   specified,
100         * @param claimsSet  The JWT claims set.
101         *
102         * @return The signed JWT.
103         *
104         * @throws JOSEException If signing failed.
105         */
106        public static SignedJWT sign(final JWK signingJWK,
107                                     final JWSAlgorithm alg,
108                                     final JOSEObjectType type,
109                                     final JWTClaimsSet claimsSet)
110                throws JOSEException{
111                
112                JWSSigner jwsSigner = new DefaultJWSSignerFactory().createJWSSigner(signingJWK, alg);
113                
114                JWSHeader jwsHeader = new JWSHeader.Builder(alg)
115                        .type(type)
116                        .keyID(signingJWK.getKeyID())
117                        .build();
118                
119                SignedJWT jwt = new SignedJWT(jwsHeader, claimsSet);
120                jwt.sign(jwsSigner);
121                return jwt;
122        }
123        
124        
125        /**
126         * Verifies the signature of the specified JWT.
127         *
128         * @param jwt            The signed JWT. Must not be {@code null}.
129         * @param type           The expected JOSE object type. Must not be
130         *                       {@code null}.
131         * @param claimsVerifier The JWT claims verifier. Must not be
132         *                       {@code null}.
133         * @param jwkSet         The public JWK set. Must not be {@code null}.
134         *
135         * @return The thumbprint of the JWK used to successfully verify the
136         *         signature.
137         *
138         * @throws BadJOSEException If the JWT is invalid.
139         * @throws JOSEException    If the signature verification failed.
140         */
141        public static Base64URL verifySignature(final SignedJWT jwt,
142                                                final JOSEObjectType type,
143                                                final JWTClaimsSetVerifier<?> claimsVerifier,
144                                                final JWKSet jwkSet)
145                throws BadJOSEException, JOSEException {
146                
147                if (! type.equals(jwt.getHeader().getType())) {
148                        throw new BadJOSEException("JWT rejected: Invalid or missing JWT typ (type) header");
149                }
150                
151                // Check claims with JWT framework
152                
153                try {
154                        claimsVerifier.verify(jwt.getJWTClaimsSet(), null);
155                } catch (java.text.ParseException e) {
156                        throw new BadJOSEException(e.getMessage(), e);
157                }
158                
159                List<JWK> jwkMatches = new JWKSelector(JWKMatcher.forJWSHeader(jwt.getHeader())).select(jwkSet);
160                
161                if (jwkMatches.isEmpty()) {
162                        throw new BadJOSEException("JWT rejected: Another JOSE algorithm expected, or no matching key(s) found");
163                }
164                
165                JWSVerifierFactory verifierFactory = new DefaultJWSVerifierFactory();
166                
167                for (JWK candidateJWK: jwkMatches) {
168                        
169                        if (candidateJWK instanceof AsymmetricJWK) {
170                                PublicKey publicKey = ((AsymmetricJWK)candidateJWK).toPublicKey();
171                                JWSVerifier jwsVerifier = verifierFactory.createJWSVerifier(jwt.getHeader(), publicKey);
172                                if (jwt.verify(jwsVerifier)) {
173                                        // success
174                                        return candidateJWK.computeThumbprint();
175                                }
176                        }
177                }
178                
179                throw new BadJOSEException("JWT rejected: Invalid signature");
180        }
181        
182        
183        /**
184         * Parses the claims of the specified signed JWT.
185         *
186         * @param jwt The signed JWT. Must not be {@code null}.
187         *
188         * @return The JWT claims set.
189         *
190         * @throws ParseException If parsing failed.
191         */
192        public static JWTClaimsSet parseSignedJWTClaimsSet(final SignedJWT jwt)
193                throws ParseException {
194                
195                if (JWSObject.State.UNSIGNED.equals(jwt.getState())) {
196                        throw new ParseException("The JWT is not signed");
197                }
198                
199                try {
200                        return jwt.getJWTClaimsSet();
201                } catch (java.text.ParseException e) {
202                        throw new ParseException(e.getMessage(), e);
203                }
204        }
205        
206        
207        private JWTUtils() {}
208}