package com.nimbusds.oauth2.sdk.jose.jwk;


import java.security.Key;
import java.security.PrivateKey;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.jwk.*;
import com.nimbusds.jose.proc.JWEKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.oauth2.sdk.id.Identifier;
import net.jcip.annotations.ThreadSafe;


/**
 * Key selector for decrypting JWE objects used in OpenID Connect.
 *
 * <p>Can be used to select RSA and EC key candidates for the decryption of:
 *
 * <ul>
 *     <li>Encrypted ID tokens
 *     <li>Encrypted JWT-encoded UserInfo responses
 *     <li>Encrypted OpenID request objects
 * </ul>
 *
 * <p>Only private keys are returned as candidates; a JWE header whose
 * algorithm or encryption method differs from the expected ones yields no
 * candidates.
 */
@ThreadSafe
public class JWEDecryptionKeySelector extends AbstractJWKSelectorWithSource implements JWEKeySelector {


	/**
	 * The expected JWE algorithm.
	 */
	private final JWEAlgorithm jweAlg;


	/**
	 * The expected JWE encryption method.
	 */
	private final EncryptionMethod jweEnc;


	/**
	 * Ensures the specified JWE algorithm is RSA or EC based.
	 *
	 * @param jweAlg The JWE algorithm to check.
	 *
	 * @throws IllegalArgumentException If the algorithm belongs neither
	 *                                  to the RSA nor to the ECDH-ES
	 *                                  family.
	 */
	private static void ensureAsymmetricEncryptionAlgorithm(final JWEAlgorithm jweAlg) {

		if (! JWEAlgorithm.Family.RSA.contains(jweAlg) && ! JWEAlgorithm.Family.ECDH_ES.contains(jweAlg)) {
			throw new IllegalArgumentException("The JWE algorithm must be RSA or EC based");
		}
	}


	/**
	 * Creates a new decryption key selector.
	 *
	 * @param id        Identifier for the JWE recipient, typically an
	 *                  OAuth 2.0 server issuer ID, or client ID. Must not
	 *                  be {@code null}.
	 * @param jweAlg    The expected JWE algorithm for the objects to be
	 *                  decrypted. Must be RSA or EC based. Must not be
	 *                  {@code null}.
	 * @param jweEnc    The expected JWE encryption method for the objects
	 *                  to be decrypted. Must not be {@code null}.
	 * @param jwkSource The JWK source. Must include the private keys and
	 *                  must not be {@code null}.
	 *
	 * @throws IllegalArgumentException If the JWE algorithm is
	 *                                  {@code null} or not RSA or EC
	 *                                  based, or if the JWE encryption
	 *                                  method is {@code null}.
	 */
	public JWEDecryptionKeySelector(final Identifier id,
					final JWEAlgorithm jweAlg,
					final EncryptionMethod jweEnc,
					final JWKSource jwkSource) {
		super(id, jwkSource);
		if (jweAlg == null) {
			throw new IllegalArgumentException("The JWE algorithm must not be null");
		}
		ensureAsymmetricEncryptionAlgorithm(jweAlg);
		this.jweAlg = jweAlg;
		if (jweEnc == null) {
			throw new IllegalArgumentException("The JWE encryption method must not be null");
		}
		this.jweEnc = jweEnc;
	}


	/**
	 * Returns the expected JWE algorithm.
	 *
	 * @return The expected JWE algorithm.
	 */
	public JWEAlgorithm getExpectedJWEAlgorithm() {
		return jweAlg;
	}


	/**
	 * Returns the expected JWE encryption method.
	 *
	 * @return The expected JWE encryption method.
	 */
	public EncryptionMethod getExpectedJWEEncryptionMethod() {
		return jweEnc;
	}


	/**
	 * Creates a JWK matcher for the expected JWE algorithms and the
	 * specified JWE header.
	 *
	 * <p>The matcher selects keys of the type implied by the expected JWE
	 * algorithm, with the key ID taken from the header, whose use is
	 * encryption or unspecified, and whose algorithm is the expected one
	 * or unspecified.
	 *
	 * @param jweHeader The JWE header. Must not be {@code null}.
	 *
	 * @return The JWK matcher, {@code null} if none could be created
	 *         because the header algorithm or encryption method differs
	 *         from the expected one.
	 */
	protected JWKMatcher createJWKMatcher(final JWEHeader jweHeader) {

		if (! getExpectedJWEAlgorithm().equals(jweHeader.getAlgorithm())) {
			return null;
		}

		if (! getExpectedJWEEncryptionMethod().equals(jweHeader.getEncryptionMethod())) {
			return null;
		}

		return new JWKMatcher.Builder()
			.keyType(KeyType.forAlgorithm(getExpectedJWEAlgorithm()))
			.keyID(jweHeader.getKeyID())
			.keyUses(KeyUse.ENCRYPTION, null)
			.algorithms(getExpectedJWEAlgorithm(), null)
			.build();
	}


	/**
	 * Selects the candidate private keys for decrypting a JWE object with
	 * the specified header.
	 *
	 * @param jweHeader The JWE header. Must not be {@code null}.
	 * @param context   Optional security context, not used by this
	 *                  implementation.
	 *
	 * @return The candidate private keys, empty list if the header
	 *         algorithm or encryption method is not the expected one, if
	 *         no JWK matcher could be created, or if no private key
	 *         matched.
	 */
	@Override
	public List<Key> selectJWEKeys(final JWEHeader jweHeader, final SecurityContext context) {

		if (! jweAlg.equals(jweHeader.getAlgorithm()) || ! jweEnc.equals(jweHeader.getEncryptionMethod())) {
			// Unexpected JWE alg or enc
			return Collections.emptyList();
		}

		JWKMatcher jwkMatcher = createJWKMatcher(jweHeader);

		if (jwkMatcher == null) {
			// createJWKMatcher is overridable and documented to
			// return null when no matcher can be created; treat
			// that as "no candidates" instead of handing a null
			// matcher to the JWK selector
			return Collections.emptyList();
		}

		List<JWK> jwkMatches = getJWKSource().get(getIdentifier(), new JWKSelector(jwkMatcher));

		List<Key> sanitizedKeyList = new LinkedList<>();

		for (Key key: KeyConverter.toJavaKeys(jwkMatches)) {
			if (key instanceof PrivateKey) {
				sanitizedKeyList.add(key);
			} // skip public keys
		}

		return sanitizedKeyList;
	}
}