001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018
019package org.apache.hadoop.security;
020
import java.io.ByteArrayInputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.charset.Charset;
import java.security.PrivilegedExceptionAction;
import java.security.Security;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import javax.security.sasl.SaslServerFactory;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.RetriableException;
import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.ipc.Server.Connection;
import org.apache.hadoop.ipc.StandbyException;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.SecretManager.InvalidToken;
import org.apache.hadoop.security.token.TokenIdentifier;
059
060/**
 * A utility class for setting up SASL authentication on the RPC server.
062 */
063@InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"})
064@InterfaceStability.Evolving
065public class SaslRpcServer {
  /** Logger shared by this class and its nested callback handlers. */
  public static final Log LOG = LogFactory.getLog(SaslRpcServer.class);
  /** Server id used for TOKEN (DIGEST-MD5) authentication by the constructor. */
  public static final String SASL_DEFAULT_REALM = "default";
  /**
   * Cached SASL server factory, assigned by {@link #init(Configuration)} and
   * used by {@link #create}; it is null until init has been called.
   * NOTE(review): not volatile — presumably init() runs before any thread
   * calls create(); confirm against the RPC server start-up sequence.
   */
  private static SaslServerFactory saslFactory;
069
070  public static enum QualityOfProtection {
071    AUTHENTICATION("auth"),
072    INTEGRITY("auth-int"),
073    PRIVACY("auth-conf");
074    
075    public final String saslQop;
076    
077    private QualityOfProtection(String saslQop) {
078      this.saslQop = saslQop;
079    }
080    
081    public String getSaslQop() {
082      return saslQop;
083    }
084  }
085
  /** The authentication method this server instance was constructed for. */
  @InterfaceAudience.Private
  @InterfaceStability.Unstable
  public AuthMethod authMethod;
  /** SASL mechanism name taken from {@link #authMethod}. */
  public String mechanism;
  /**
   * SASL protocol name: "" for TOKEN, the first component of the current
   * user's principal for KERBEROS; left unset for SIMPLE.
   */
  public String protocol;
  /**
   * SASL server name: {@link #SASL_DEFAULT_REALM} for TOKEN, the second
   * component of the principal (or "" if absent) for KERBEROS; left unset
   * for SIMPLE.
   */
  public String serverId;
092  
093  @InterfaceAudience.Private
094  @InterfaceStability.Unstable
095  public SaslRpcServer(AuthMethod authMethod) throws IOException {
096    this.authMethod = authMethod;
097    mechanism = authMethod.getMechanismName();    
098    switch (authMethod) {
099      case SIMPLE: {
100        return; // no sasl for simple
101      }
102      case TOKEN: {
103        protocol = "";
104        serverId = SaslRpcServer.SASL_DEFAULT_REALM;
105        break;
106      }
107      case KERBEROS: {
108        String fullName = UserGroupInformation.getCurrentUser().getUserName();
109        if (LOG.isDebugEnabled())
110          LOG.debug("Kerberos principal name is " + fullName);
111        // don't use KerberosName because we don't want auth_to_local
112        String[] parts = fullName.split("[/@]", 3);
113        protocol = parts[0];
114        // should verify service host is present here rather than in create()
115        // but lazy tests are using a UGI that isn't a SPN...
116        serverId = (parts.length < 2) ? "" : parts[1];
117        break;
118      }
119      default:
120        // we should never be able to get here
121        throw new AccessControlException(
122            "Server does not support SASL " + authMethod);
123    }
124  }
125  
126  @InterfaceAudience.Private
127  @InterfaceStability.Unstable
128  public SaslServer create(final Connection connection,
129                           final Map<String,?> saslProperties,
130                           SecretManager<TokenIdentifier> secretManager
131      ) throws IOException, InterruptedException {
132    UserGroupInformation ugi = null;
133    final CallbackHandler callback;
134    switch (authMethod) {
135      case TOKEN: {
136        callback = new SaslDigestCallbackHandler(secretManager, connection);
137        break;
138      }
139      case KERBEROS: {
140        ugi = UserGroupInformation.getCurrentUser();
141        if (serverId.isEmpty()) {
142          throw new AccessControlException(
143              "Kerberos principal name does NOT have the expected "
144                  + "hostname part: " + ugi.getUserName());
145        }
146        callback = new SaslGssCallbackHandler();
147        break;
148      }
149      default:
150        // we should never be able to get here
151        throw new AccessControlException(
152            "Server does not support SASL " + authMethod);
153    }
154    
155    final SaslServer saslServer;
156    if (ugi != null) {
157      saslServer = ugi.doAs(
158        new PrivilegedExceptionAction<SaslServer>() {
159          @Override
160          public SaslServer run() throws SaslException  {
161            return saslFactory.createSaslServer(mechanism, protocol, serverId,
162                saslProperties, callback);
163          }
164        });
165    } else {
166      saslServer = saslFactory.createSaslServer(mechanism, protocol, serverId,
167          saslProperties, callback);
168    }
169    if (saslServer == null) {
170      throw new AccessControlException(
171          "Unable to find SASL server implementation for " + mechanism);
172    }
173    if (LOG.isDebugEnabled()) {
174      LOG.debug("Created SASL server with mechanism = " + mechanism);
175    }
176    return saslServer;
177  }
178
179  public static void init(Configuration conf) {
180    Security.addProvider(new SaslPlainServer.SecurityProvider());
181    // passing null so factory is populated with all possibilities.  the
182    // properties passed when instantiating a server are what really matter
183    saslFactory = new FastSaslServerFactory(null);
184  }
185  
186  static String encodeIdentifier(byte[] identifier) {
187    return new String(Base64.encodeBase64(identifier));
188  }
189
190  static byte[] decodeIdentifier(String identifier) {
191    return Base64.decodeBase64(identifier.getBytes());
192  }
193
194  public static <T extends TokenIdentifier> T getIdentifier(String id,
195      SecretManager<T> secretManager) throws InvalidToken {
196    byte[] tokenId = decodeIdentifier(id);
197    T tokenIdentifier = secretManager.createIdentifier();
198    try {
199      tokenIdentifier.readFields(new DataInputStream(new ByteArrayInputStream(
200          tokenId)));
201    } catch (IOException e) {
202      throw (InvalidToken) new InvalidToken(
203          "Can't de-serialize tokenIdentifier").initCause(e);
204    }
205    return tokenIdentifier;
206  }
207
208  static char[] encodePassword(byte[] password) {
209    return new String(Base64.encodeBase64(password)).toCharArray();
210  }
211
212  /** Splitting fully qualified Kerberos name into parts */
213  public static String[] splitKerberosName(String fullName) {
214    return fullName.split("[/@]");
215  }
216
217  /** Authentication method */
218  @InterfaceStability.Evolving
219  public static enum AuthMethod {
220    SIMPLE((byte) 80, ""),
221    KERBEROS((byte) 81, "GSSAPI"),
222    @Deprecated
223    DIGEST((byte) 82, "DIGEST-MD5"),
224    TOKEN((byte) 82, "DIGEST-MD5"),
225    PLAIN((byte) 83, "PLAIN");
226
227    /** The code for this method. */
228    public final byte code;
229    public final String mechanismName;
230
231    private AuthMethod(byte code, String mechanismName) { 
232      this.code = code;
233      this.mechanismName = mechanismName;
234    }
235
236    private static final int FIRST_CODE = values()[0].code;
237
238    /** Return the object represented by the code. */
239    private static AuthMethod valueOf(byte code) {
240      final int i = (code & 0xff) - FIRST_CODE;
241      return i < 0 || i >= values().length ? null : values()[i];
242    }
243
244    /** Return the SASL mechanism name */
245    public String getMechanismName() {
246      return mechanismName;
247    }
248
249    /** Read from in */
250    public static AuthMethod read(DataInput in) throws IOException {
251      return valueOf(in.readByte());
252    }
253
254    /** Write to out */
255    public void write(DataOutput out) throws IOException {
256      out.write(code);
257    }
258  };
259
260  /** CallbackHandler for SASL DIGEST-MD5 mechanism */
261  @InterfaceStability.Evolving
262  public static class SaslDigestCallbackHandler implements CallbackHandler {
263    private SecretManager<TokenIdentifier> secretManager;
264    private Server.Connection connection; 
265    
266    public SaslDigestCallbackHandler(
267        SecretManager<TokenIdentifier> secretManager,
268        Server.Connection connection) {
269      this.secretManager = secretManager;
270      this.connection = connection;
271    }
272
273    private char[] getPassword(TokenIdentifier tokenid) throws InvalidToken,
274        StandbyException, RetriableException, IOException {
275      return encodePassword(secretManager.retriableRetrievePassword(tokenid));
276    }
277
278    @Override
279    public void handle(Callback[] callbacks) throws InvalidToken,
280        UnsupportedCallbackException, StandbyException, RetriableException,
281        IOException {
282      NameCallback nc = null;
283      PasswordCallback pc = null;
284      AuthorizeCallback ac = null;
285      for (Callback callback : callbacks) {
286        if (callback instanceof AuthorizeCallback) {
287          ac = (AuthorizeCallback) callback;
288        } else if (callback instanceof NameCallback) {
289          nc = (NameCallback) callback;
290        } else if (callback instanceof PasswordCallback) {
291          pc = (PasswordCallback) callback;
292        } else if (callback instanceof RealmCallback) {
293          continue; // realm is ignored
294        } else {
295          throw new UnsupportedCallbackException(callback,
296              "Unrecognized SASL DIGEST-MD5 Callback");
297        }
298      }
299      if (pc != null) {
300        TokenIdentifier tokenIdentifier = getIdentifier(nc.getDefaultName(),
301            secretManager);
302        char[] password = getPassword(tokenIdentifier);
303        UserGroupInformation user = null;
304        user = tokenIdentifier.getUser(); // may throw exception
305        connection.attemptingUser = user;
306        
307        if (LOG.isDebugEnabled()) {
308          LOG.debug("SASL server DIGEST-MD5 callback: setting password "
309              + "for client: " + tokenIdentifier.getUser());
310        }
311        pc.setPassword(password);
312      }
313      if (ac != null) {
314        String authid = ac.getAuthenticationID();
315        String authzid = ac.getAuthorizationID();
316        if (authid.equals(authzid)) {
317          ac.setAuthorized(true);
318        } else {
319          ac.setAuthorized(false);
320        }
321        if (ac.isAuthorized()) {
322          if (LOG.isDebugEnabled()) {
323            String username =
324              getIdentifier(authzid, secretManager).getUser().getUserName();
325            LOG.debug("SASL server DIGEST-MD5 callback: setting "
326                + "canonicalized client ID: " + username);
327          }
328          ac.setAuthorizedID(authzid);
329        }
330      }
331    }
332  }
333
334  /** CallbackHandler for SASL GSSAPI Kerberos mechanism */
335  @InterfaceStability.Evolving
336  public static class SaslGssCallbackHandler implements CallbackHandler {
337
338    @Override
339    public void handle(Callback[] callbacks) throws
340        UnsupportedCallbackException {
341      AuthorizeCallback ac = null;
342      for (Callback callback : callbacks) {
343        if (callback instanceof AuthorizeCallback) {
344          ac = (AuthorizeCallback) callback;
345        } else {
346          throw new UnsupportedCallbackException(callback,
347              "Unrecognized SASL GSSAPI Callback");
348        }
349      }
350      if (ac != null) {
351        String authid = ac.getAuthenticationID();
352        String authzid = ac.getAuthorizationID();
353        if (authid.equals(authzid)) {
354          ac.setAuthorized(true);
355        } else {
356          ac.setAuthorized(false);
357        }
358        if (ac.isAuthorized()) {
359          if (LOG.isDebugEnabled())
360            LOG.debug("SASL server GSSAPI callback: setting "
361                + "canonicalized client ID: " + authzid);
362          ac.setAuthorizedID(authzid);
363        }
364      }
365    }
366  }
367  
368  // Sasl.createSaslServer is 100-200X slower than caching the factories!
369  private static class FastSaslServerFactory implements SaslServerFactory {
370    private final Map<String,List<SaslServerFactory>> factoryCache =
371        new HashMap<String,List<SaslServerFactory>>();
372
373    FastSaslServerFactory(Map<String,?> props) {
374      final Enumeration<SaslServerFactory> factories =
375          Sasl.getSaslServerFactories();
376      while (factories.hasMoreElements()) {
377        SaslServerFactory factory = factories.nextElement();
378        for (String mech : factory.getMechanismNames(props)) {
379          if (!factoryCache.containsKey(mech)) {
380            factoryCache.put(mech, new ArrayList<SaslServerFactory>());
381          }
382          factoryCache.get(mech).add(factory);
383        }
384      }
385    }
386
387    @Override
388    public SaslServer createSaslServer(String mechanism, String protocol,
389        String serverName, Map<String,?> props, CallbackHandler cbh)
390        throws SaslException {
391      SaslServer saslServer = null;
392      List<SaslServerFactory> factories = factoryCache.get(mechanism);
393      if (factories != null) {
394        for (SaslServerFactory factory : factories) {
395          saslServer = factory.createSaslServer(
396              mechanism, protocol, serverName, props, cbh);
397          if (saslServer != null) {
398            break;
399          }
400        }
401      }
402      return saslServer;
403    }
404
405    @Override
406    public String[] getMechanismNames(Map<String, ?> props) {
407      return factoryCache.keySet().toArray(new String[0]);
408    }
409  }
410}