001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataInputStream; 021import java.io.DataOutputStream; 022import java.io.EOFException; 023import java.io.IOException; 024import java.net.Socket; 025import java.net.SocketTimeoutException; 026import java.net.URI; 027import java.net.UnknownHostException; 028import java.nio.ByteBuffer; 029import java.nio.channels.SelectionKey; 030import java.nio.channels.Selector; 031import java.security.cert.X509Certificate; 032 033import javax.net.SocketFactory; 034import javax.net.ssl.SSLContext; 035import javax.net.ssl.SSLEngine; 036import javax.net.ssl.SSLEngineResult; 037import javax.net.ssl.SSLEngineResult.HandshakeStatus; 038import javax.net.ssl.SSLPeerUnverifiedException; 039import javax.net.ssl.SSLSession; 040 041import org.apache.activemq.command.ConnectionInfo; 042import org.apache.activemq.openwire.OpenWireFormat; 043import org.apache.activemq.thread.TaskRunnerFactory; 044import org.apache.activemq.util.IOExceptionSupport; 045import org.apache.activemq.util.ServiceStopper; 046import org.apache.activemq.wireformat.WireFormat; 047import org.slf4j.Logger; 048import org.slf4j.LoggerFactory; 049 050public class NIOSSLTransport extends NIOTransport { 051 052 private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class); 053 054 protected boolean needClientAuth; 055 protected boolean wantClientAuth; 056 protected String[] enabledCipherSuites; 057 protected String[] enabledProtocols; 058 059 protected SSLContext sslContext; 060 protected SSLEngine sslEngine; 061 protected SSLSession sslSession; 062 063 protected volatile boolean handshakeInProgress = false; 064 protected SSLEngineResult.Status status = null; 065 protected SSLEngineResult.HandshakeStatus handshakeStatus = null; 066 protected TaskRunnerFactory taskRunnerFactory; 067 068 public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 069 super(wireFormat, socketFactory, remoteLocation, localLocation); 070 } 071 072 public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer, 073 ByteBuffer inputBuffer) throws IOException { 074 super(wireFormat, socket, initBuffer); 075 this.sslEngine = engine; 076 if (engine != null) { 077 this.sslSession = engine.getSession(); 078 } 079 this.inputBuffer = inputBuffer; 080 } 081 082 public void setSslContext(SSLContext sslContext) { 083 this.sslContext = sslContext; 084 } 085 086 volatile boolean hasSslEngine = false; 087 088 @Override 089 protected void initializeStreams() throws IOException { 090 if (sslEngine != null) { 091 hasSslEngine = true; 092 } 093 NIOOutputStream outputStream = null; 094 try { 095 channel = socket.getChannel(); 096 channel.configureBlocking(false); 097 098 if (sslContext == null) { 099 sslContext = SSLContext.getDefault(); 100 } 101 102 String remoteHost = null; 103 int remotePort = -1; 104 105 try { 106 URI remoteAddress = new URI(this.getRemoteAddress()); 107 remoteHost = remoteAddress.getHost(); 108 remotePort = remoteAddress.getPort(); 109 } catch (Exception e) { 110 } 111 112 // initialize engine, the initial sslSession we get will need to be 113 // updated once the ssl handshake process is completed. 114 if (!hasSslEngine) { 115 if (remoteHost != null && remotePort != -1) { 116 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 117 } else { 118 sslEngine = sslContext.createSSLEngine(); 119 } 120 121 sslEngine.setUseClientMode(false); 122 if (enabledCipherSuites != null) { 123 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 124 } 125 126 if (enabledProtocols != null) { 127 sslEngine.setEnabledProtocols(enabledProtocols); 128 } 129 130 if (wantClientAuth) { 131 sslEngine.setWantClientAuth(wantClientAuth); 132 } 133 134 if (needClientAuth) { 135 sslEngine.setNeedClientAuth(needClientAuth); 136 } 137 138 sslSession = sslEngine.getSession(); 139 140 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 141 inputBuffer.clear(); 142 } 143 144 outputStream = new NIOOutputStream(channel); 145 outputStream.setEngine(sslEngine); 146 this.dataOut = new DataOutputStream(outputStream); 147 this.buffOut = outputStream; 148 149 //If the sslEngine was not passed in, then handshake 150 if (!hasSslEngine) { 151 sslEngine.beginHandshake(); 152 } 153 handshakeStatus = sslEngine.getHandshakeStatus(); 154 if (!hasSslEngine) { 155 doHandshake(); 156 } 157 158 // if (hasSslEngine) { 159 selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { 160 @Override 161 public void onSelect(SelectorSelection selection) { 162 serviceRead(); 163 } 164 165 @Override 166 public void onError(SelectorSelection selection, Throwable error) { 167 if (error instanceof IOException) { 168 onException((IOException) error); 169 } else { 170 onException(IOExceptionSupport.create(error)); 171 } 172 } 173 }); 174 doInit(); 175 176 } catch (Exception e) { 177 try { 178 if(outputStream != null) { 179 outputStream.close(); 180 } 181 super.closeStreams(); 182 } catch (Exception ex) {} 183 throw new IOException(e); 184 } 185 } 186 187 protected void doInit() throws Exception { 188 189 } 190 191 protected void doOpenWireInit() throws Exception { 192 //Do this later to let wire format negotiation happen 193 if (initBuffer != null && this.wireFormat instanceof OpenWireFormat) { 194 initBuffer.buffer.flip(); 195 if (initBuffer.buffer.hasRemaining()) { 196 nextFrameSize = -1; 197 receiveCounter += initBuffer.readSize; 198 processCommand(initBuffer.buffer); 199 processCommand(initBuffer.buffer); 200 initBuffer.buffer.clear(); 201 } 202 } 203 } 204 205 protected void finishHandshake() throws Exception { 206 if (handshakeInProgress) { 207 handshakeInProgress = false; 208 nextFrameSize = -1; 209 210 // Once handshake completes we need to ask for the now real sslSession 211 // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the 212 // cipher suite. 213 sslSession = sslEngine.getSession(); 214 215 // listen for events telling us when the socket is readable. 216 selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { 217 @Override 218 public void onSelect(SelectorSelection selection) { 219 serviceRead(); 220 } 221 222 @Override 223 public void onError(SelectorSelection selection, Throwable error) { 224 if (error instanceof IOException) { 225 onException((IOException) error); 226 } else { 227 onException(IOExceptionSupport.create(error)); 228 } 229 } 230 }); 231 } 232 } 233 234 @Override 235 public void serviceRead() { 236 try { 237 if (handshakeInProgress) { 238 doHandshake(); 239 } 240 241 doOpenWireInit(); 242 243 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 244 plain.position(plain.limit()); 245 246 while (true) { 247 if (!plain.hasRemaining()) { 248 249 int readCount = secureRead(plain); 250 251 if (readCount == 0) { 252 break; 253 } 254 255 // channel is closed, cleanup 256 if (readCount == -1) { 257 onException(new EOFException()); 258 selection.close(); 259 break; 260 } 261 262 receiveCounter += readCount; 263 } 264 265 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 266 processCommand(plain); 267 } 268 } 269 } catch (IOException e) { 270 onException(e); 271 } catch (Throwable e) { 272 onException(IOExceptionSupport.create(e)); 273 } 274 } 275 276 protected void processCommand(ByteBuffer plain) throws Exception { 277 278 // Are we waiting for the next Command or are we building on the current one 279 if (nextFrameSize == -1) { 280 281 // We can get small packets that don't give us enough for the frame size 282 // so allocate enough for the initial size value and 283 if (plain.remaining() < Integer.SIZE) { 284 if (currentBuffer == null) { 285 currentBuffer = ByteBuffer.allocate(4); 286 } 287 288 // Go until we fill the integer sized current buffer. 289 while (currentBuffer.hasRemaining() && plain.hasRemaining()) { 290 currentBuffer.put(plain.get()); 291 } 292 293 // Didn't we get enough yet to figure out next frame size. 294 if (currentBuffer.hasRemaining()) { 295 return; 296 } else { 297 currentBuffer.flip(); 298 nextFrameSize = currentBuffer.getInt(); 299 } 300 301 } else { 302 303 // Either we are completing a previous read of the next frame size or its 304 // fully contained in plain already. 305 if (currentBuffer != null) { 306 307 // Finish the frame size integer read and get from the current buffer. 308 while (currentBuffer.hasRemaining()) { 309 currentBuffer.put(plain.get()); 310 } 311 312 currentBuffer.flip(); 313 nextFrameSize = currentBuffer.getInt(); 314 315 } else { 316 nextFrameSize = plain.getInt(); 317 } 318 } 319 320 if (wireFormat instanceof OpenWireFormat) { 321 long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize(); 322 if (nextFrameSize > maxFrameSize) { 323 throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + 324 " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB"); 325 } 326 } 327 328 // now we got the data, lets reallocate and store the size for the marshaler. 329 // if there's more data in plain, then the next call will start processing it. 330 currentBuffer = ByteBuffer.allocate(nextFrameSize + 4); 331 currentBuffer.putInt(nextFrameSize); 332 333 } else { 334 // If its all in one read then we can just take it all, otherwise take only 335 // the current frame size and the next iteration starts a new command. 336 if (currentBuffer != null) { 337 if (currentBuffer.remaining() >= plain.remaining()) { 338 currentBuffer.put(plain); 339 } else { 340 byte[] fill = new byte[currentBuffer.remaining()]; 341 plain.get(fill); 342 currentBuffer.put(fill); 343 } 344 345 // Either we have enough data for a new command or we have to wait for some more. 346 if (currentBuffer.hasRemaining()) { 347 return; 348 } else { 349 currentBuffer.flip(); 350 Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer))); 351 doConsume(command); 352 nextFrameSize = -1; 353 currentBuffer = null; 354 } 355 } 356 } 357 } 358 359 protected int secureRead(ByteBuffer plain) throws Exception { 360 361 if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 362 int bytesRead = channel.read(inputBuffer); 363 364 if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) { 365 return 0; 366 } 367 368 if (bytesRead == -1) { 369 sslEngine.closeInbound(); 370 if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 371 return -1; 372 } 373 } 374 } 375 376 plain.clear(); 377 378 inputBuffer.flip(); 379 SSLEngineResult res; 380 do { 381 res = sslEngine.unwrap(inputBuffer, plain); 382 } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP 383 && res.bytesProduced() == 0); 384 385 if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) { 386 finishHandshake(); 387 } 388 389 status = res.getStatus(); 390 handshakeStatus = res.getHandshakeStatus(); 391 392 // TODO deal with BUFFER_OVERFLOW 393 394 if (status == SSLEngineResult.Status.CLOSED) { 395 sslEngine.closeInbound(); 396 return -1; 397 } 398 399 inputBuffer.compact(); 400 plain.flip(); 401 402 return plain.remaining(); 403 } 404 405 protected void doHandshake() throws Exception { 406 handshakeInProgress = true; 407 Selector selector = null; 408 SelectionKey key = null; 409 boolean readable = true; 410 try { 411 while (true) { 412 HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); 413 switch (handshakeStatus) { 414 case NEED_UNWRAP: 415 if (readable) { 416 secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize())); 417 } 418 if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 419 long now = System.currentTimeMillis(); 420 if (selector == null) { 421 selector = Selector.open(); 422 key = channel.register(selector, SelectionKey.OP_READ); 423 } else { 424 key.interestOps(SelectionKey.OP_READ); 425 } 426 int keyCount = selector.select(this.getSoTimeout()); 427 if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) { 428 throw new SocketTimeoutException("Timeout during handshake"); 429 } 430 readable = key.isReadable(); 431 } 432 break; 433 case NEED_TASK: 434 Runnable task; 435 while ((task = sslEngine.getDelegatedTask()) != null) { 436 task.run(); 437 } 438 break; 439 case NEED_WRAP: 440 ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0)); 441 break; 442 case FINISHED: 443 case NOT_HANDSHAKING: 444 finishHandshake(); 445 return; 446 } 447 } 448 } finally { 449 if (key!=null) try {key.cancel();} catch (Exception ignore) {} 450 if (selector!=null) try {selector.close();} catch (Exception ignore) {} 451 } 452 } 453 454 @Override 455 protected void doStart() throws Exception { 456 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 457 // no need to init as we can delay that until demand (eg in doHandshake) 458 super.doStart(); 459 } 460 461 @Override 462 protected void doStop(ServiceStopper stopper) throws Exception { 463 if (taskRunnerFactory != null) { 464 taskRunnerFactory.shutdownNow(); 465 taskRunnerFactory = null; 466 } 467 if (channel != null) { 468 channel.close(); 469 channel = null; 470 } 471 super.doStop(stopper); 472 } 473 474 /** 475 * Overriding in order to add the client's certificates to ConnectionInfo Commands. 476 * 477 * @param command 478 * The Command coming in. 479 */ 480 @Override 481 public void doConsume(Object command) { 482 if (command instanceof ConnectionInfo) { 483 ConnectionInfo connectionInfo = (ConnectionInfo) command; 484 connectionInfo.setTransportContext(getPeerCertificates()); 485 } 486 super.doConsume(command); 487 } 488 489 /** 490 * @return peer certificate chain associated with the ssl socket 491 */ 492 public X509Certificate[] getPeerCertificates() { 493 494 X509Certificate[] clientCertChain = null; 495 try { 496 if (sslEngine.getSession() != null) { 497 clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates(); 498 } 499 } catch (SSLPeerUnverifiedException e) { 500 if (LOG.isTraceEnabled()) { 501 LOG.trace("Failed to get peer certificates.", e); 502 } 503 } 504 505 return clientCertChain; 506 } 507 508 public boolean isNeedClientAuth() { 509 return needClientAuth; 510 } 511 512 public void setNeedClientAuth(boolean needClientAuth) { 513 this.needClientAuth = needClientAuth; 514 } 515 516 public boolean isWantClientAuth() { 517 return wantClientAuth; 518 } 519 520 public void setWantClientAuth(boolean wantClientAuth) { 521 this.wantClientAuth = wantClientAuth; 522 } 523 524 public String[] getEnabledCipherSuites() { 525 return enabledCipherSuites; 526 } 527 528 public void setEnabledCipherSuites(String[] enabledCipherSuites) { 529 this.enabledCipherSuites = enabledCipherSuites; 530 } 531 532 public String[] getEnabledProtocols() { 533 return enabledProtocols; 534 } 535 536 public void setEnabledProtocols(String[] enabledProtocols) { 537 this.enabledProtocols = enabledProtocols; 538 } 539}