001package com.nimbusds.oauth2.sdk.jose.jwk;
002
003
004import java.security.Key;
005import java.security.PrivateKey;
006import java.util.Collections;
007import java.util.LinkedList;
008import java.util.List;
009
010import com.nimbusds.jose.EncryptionMethod;
011import com.nimbusds.jose.JWEAlgorithm;
012import com.nimbusds.jose.JWEHeader;
013import com.nimbusds.jose.jwk.*;
014import com.nimbusds.jose.proc.JWEKeySelector;
015import com.nimbusds.jose.proc.SecurityContext;
016import com.nimbusds.oauth2.sdk.id.Identifier;
017import net.jcip.annotations.ThreadSafe;
018
019
020/**
021 * Key selector for decrypting JWE objects used in OpenID Connect.
022 *
023 * <p>Can be used to select RSA and EC key candidates for the decryption of:
024 *
025 * <ul>
026 *     <li>Encrypted ID tokens
027 *     <li>Encrypted JWT-encoded UserInfo responses
028 *     <li>Encrypted OpenID request objects
029 * </ul>
030 */
031@ThreadSafe
032public class JWEDecryptionKeySelector extends AbstractJWKSelectorWithSource implements JWEKeySelector {
033
034
035        /**
036         * The expected JWE algorithm.
037         */
038        private final JWEAlgorithm jweAlg;
039
040
041        /**
042         * The expected JWE encryption method.
043         */
044        private final EncryptionMethod jweEnc;
045
046
047        /**
048         * Ensures the specified JWE algorithm is RSA or EC based.
049         *
050         * @param jweAlg The JWE algorithm to check.
051         */
052        private static void ensureAsymmetricEncryptionAlgorithm(final JWEAlgorithm jweAlg) {
053
054                if (! JWEAlgorithm.Family.RSA.contains(jweAlg) && ! JWEAlgorithm.Family.ECDH_ES.contains(jweAlg)) {
055                        throw new IllegalArgumentException("The JWE algorithm must be RSA or EC based");
056                }
057        }
058
059
060        /**
061         * Creates a new decryption key selector.
062         *
063         * @param id        Identifier for the JWE recipient, typically an
064         *                  OAuth 2.0 server issuer ID, or client ID. Must not
065         *                  be {@code null}.
066         * @param jweAlg    The expected JWE algorithm for the objects to be
067         *                  decrypted. Must not be {@code null}.
068         * @param jweEnc    The expected JWE encryption method for the objects
069         *                  to be decrypted. Must be RSA or EC based. Must not
070         *                  be {@code null}.
071         * @param jwkSource The JWK source. Must include the private keys and
072         *                  must not be {@code null}.
073         */
074        public JWEDecryptionKeySelector(final Identifier id,
075                                        final JWEAlgorithm jweAlg,
076                                        final EncryptionMethod jweEnc,
077                                        final JWKSource jwkSource) {
078                super(id, jwkSource);
079                if (jweAlg == null) {
080                        throw new IllegalArgumentException("The JWE algorithm must not be null");
081                }
082                ensureAsymmetricEncryptionAlgorithm(jweAlg);
083                this.jweAlg = jweAlg;
084                if (jweEnc == null) {
085                        throw new IllegalArgumentException("The JWE encryption method must not be null");
086                }
087                this.jweEnc = jweEnc;
088        }
089
090
091        /**
092         * Returns the expected JWE algorithm.
093         *
094         * @return The expected JWE algorithm.
095         */
096        public JWEAlgorithm getExpectedJWEAlgorithm() {
097                return jweAlg;
098        }
099
100
101        /**
102         * The expected JWE encryption method.
103         *
104         * @return The expected JWE encryption method.
105         */
106        public EncryptionMethod getExpectedJWEEncryptionMethod() {
107                return jweEnc;
108        }
109
110
111        /**
112         * Creates a JWK matcher for the expected JWE algorithms and the
113         * specified JWE header.
114         *
115         * @param jweHeader The JWE header. Must not be {@code null}.
116         *
117         * @return The JWK matcher, {@code null} if none could be created.
118         */
119        protected JWKMatcher createJWKMatcher(final JWEHeader jweHeader) {
120
121                if (! getExpectedJWEAlgorithm().equals(jweHeader.getAlgorithm())) {
122                        return null;
123                }
124
125                if (! getExpectedJWEEncryptionMethod().equals(jweHeader.getEncryptionMethod())) {
126                        return null;
127                }
128
129                return new JWKMatcher.Builder()
130                        .keyType(KeyType.forAlgorithm(getExpectedJWEAlgorithm()))
131                        .keyID(jweHeader.getKeyID())
132                        .keyUses(KeyUse.ENCRYPTION, null)
133                        .algorithms(getExpectedJWEAlgorithm(), null)
134                        .build();
135        }
136
137
138        @Override
139        public List<Key> selectJWEKeys(final JWEHeader jweHeader, final SecurityContext context) {
140
141                if (! jweAlg.equals(jweHeader.getAlgorithm()) || ! jweEnc.equals(jweHeader.getEncryptionMethod())) {
142                        // Unexpected JWE alg or enc
143                        return Collections.emptyList();
144                }
145
146                JWKMatcher jwkMatcher = createJWKMatcher(jweHeader);
147                List<JWK> jwkMatches = getJWKSource().get(getIdentifier(), new JWKSelector(jwkMatcher));
148
149                List<Key> sanitizedKeyList = new LinkedList<>();
150
151                for (Key key: KeyConverter.toJavaKeys(jwkMatches)) {
152                        if (key instanceof PrivateKey) {
153                                sanitizedKeyList.add(key);
154                        } // skip public keys
155                }
156
157                return sanitizedKeyList;
158        }
159}