/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.activemq.transport.auto;

import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;

import javax.net.ServerSocketFactory;

import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import org.apache.activemq.openwire.OpenWireFormatFactory;
import org.apache.activemq.transport.InactivityIOException;
import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportFactory;
import org.apache.activemq.transport.TransportServer;
import org.apache.activemq.transport.protocol.AmqpProtocolVerifier;
import org.apache.activemq.transport.protocol.MqttProtocolVerifier;
import org.apache.activemq.transport.protocol.OpenWireProtocolVerifier;
import org.apache.activemq.transport.protocol.ProtocolVerifier;
import org.apache.activemq.transport.protocol.StompProtocolVerifier;
import org.apache.activemq.transport.tcp.TcpTransport;
import org.apache.activemq.transport.tcp.TcpTransport.InitBuffer;
import org.apache.activemq.transport.tcp.TcpTransportFactory;
import org.apache.activemq.transport.tcp.TcpTransportServer;
import org.apache.activemq.util.FactoryFinder;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.IntrospectionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.wireformat.WireFormat;
import org.apache.activemq.wireformat.WireFormatFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

065/**
066 * A TCP based implementation of {@link TransportServer}
067 */
068public class AutoTcpTransportServer extends TcpTransportServer {
069
070    private static final Logger LOG = LoggerFactory.getLogger(AutoTcpTransportServer.class);
071
072    protected Map<String, Map<String, Object>> wireFormatOptions;
073    protected Map<String, Object> autoTransportOptions;
074    protected Set<String> enabledProtocols;
075    protected final Map<String, ProtocolVerifier> protocolVerifiers = new ConcurrentHashMap<String, ProtocolVerifier>();
076
077    protected BrokerService brokerService;
078
079    protected int maxConnectionThreadPoolSize = Integer.MAX_VALUE;
080    protected int protocolDetectionTimeOut = 30000;
081
082    private static final FactoryFinder TRANSPORT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/transport/");
083    private final ConcurrentMap<String, TransportFactory> transportFactories = new ConcurrentHashMap<String, TransportFactory>();
084
085    private static final FactoryFinder WIREFORMAT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/wireformat/");
086
087    public WireFormatFactory findWireFormatFactory(String scheme, Map<String, Map<String, Object>> options) throws IOException {
088        WireFormatFactory wff = null;
089        try {
090            wff = (WireFormatFactory)WIREFORMAT_FACTORY_FINDER.newInstance(scheme);
091            if (options != null) {
092                IntrospectionSupport.setProperties(wff, options.get(AutoTransportUtils.ALL));
093                IntrospectionSupport.setProperties(wff, options.get(scheme));
094            }
095            if (wff instanceof OpenWireFormatFactory) {
096                protocolVerifiers.put(AutoTransportUtils.OPENWIRE, new OpenWireProtocolVerifier((OpenWireFormatFactory) wff));
097            }
098            return wff;
099        } catch (Throwable e) {
100           throw IOExceptionSupport.create("Could not create wire format factory for: " + scheme + ", reason: " + e, e);
101        }
102    }
103
104    public TransportFactory findTransportFactory(String scheme, Map<String, ?> options) throws IOException {
105        scheme = append(scheme, "nio");
106        scheme = append(scheme, "ssl");
107
108        if (scheme.isEmpty()) {
109            scheme = "tcp";
110        }
111
112        TransportFactory tf = transportFactories.get(scheme);
113        if (tf == null) {
114            // Try to load if from a META-INF property.
115            try {
116                tf = (TransportFactory)TRANSPORT_FACTORY_FINDER.newInstance(scheme);
117                if (options != null) {
118                    IntrospectionSupport.setProperties(tf, options);
119                }
120                transportFactories.put(scheme, tf);
121            } catch (Throwable e) {
122                throw IOExceptionSupport.create("Transport scheme NOT recognized: [" + scheme + "]", e);
123            }
124        }
125        return tf;
126    }
127
128    protected String append(String currentScheme, String scheme) {
129        if (this.getBindLocation().getScheme().contains(scheme)) {
130            if (!currentScheme.isEmpty()) {
131                currentScheme += "+";
132            }
133            currentScheme += scheme;
134        }
135        return currentScheme;
136    }
137
138    /**
139     * @param transportFactory
140     * @param location
141     * @param serverSocketFactory
142     * @throws IOException
143     * @throws URISyntaxException
144     */
145    public AutoTcpTransportServer(TcpTransportFactory transportFactory,
146            URI location, ServerSocketFactory serverSocketFactory, BrokerService brokerService,
147            Set<String> enabledProtocols)
148            throws IOException, URISyntaxException {
149        super(transportFactory, location, serverSocketFactory);
150
151        //Use an executor service here to handle new connections.  Setting the max number
152        //of threads to the maximum number of connections the thread count isn't unbounded
153        service = new ThreadPoolExecutor(maxConnectionThreadPoolSize,
154                maxConnectionThreadPoolSize,
155                30L, TimeUnit.SECONDS,
156                new LinkedBlockingQueue<Runnable>());
157        //allow the thread pool to shrink if the max number of threads isn't needed
158        service.allowCoreThreadTimeOut(true);
159
160        this.brokerService = brokerService;
161        this.enabledProtocols = enabledProtocols;
162        initProtocolVerifiers();
163    }
164
165    public int getMaxConnectionThreadPoolSize() {
166        return maxConnectionThreadPoolSize;
167    }
168
169    public void setMaxConnectionThreadPoolSize(int maxConnectionThreadPoolSize) {
170        this.maxConnectionThreadPoolSize = maxConnectionThreadPoolSize;
171        service.setCorePoolSize(maxConnectionThreadPoolSize);
172        service.setMaximumPoolSize(maxConnectionThreadPoolSize);
173    }
174
175    public void setProtocolDetectionTimeOut(int protocolDetectionTimeOut) {
176        this.protocolDetectionTimeOut = protocolDetectionTimeOut;
177    }
178
179    @Override
180    public void setWireFormatFactory(WireFormatFactory factory) {
181        super.setWireFormatFactory(factory);
182        initOpenWireProtocolVerifier();
183    }
184
185    protected void initProtocolVerifiers() {
186        initOpenWireProtocolVerifier();
187
188        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.AMQP)) {
189            protocolVerifiers.put(AutoTransportUtils.AMQP, new AmqpProtocolVerifier());
190        }
191        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.STOMP)) {
192            protocolVerifiers.put(AutoTransportUtils.STOMP, new StompProtocolVerifier());
193        }
194        if (isAllProtocols()|| enabledProtocols.contains(AutoTransportUtils.MQTT)) {
195            protocolVerifiers.put(AutoTransportUtils.MQTT, new MqttProtocolVerifier());
196        }
197    }
198
199    protected void initOpenWireProtocolVerifier() {
200        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.OPENWIRE)) {
201            OpenWireProtocolVerifier owpv;
202            if (wireFormatFactory instanceof OpenWireFormatFactory) {
203                owpv = new OpenWireProtocolVerifier((OpenWireFormatFactory) wireFormatFactory);
204            } else {
205                owpv = new OpenWireProtocolVerifier(new OpenWireFormatFactory());
206            }
207            protocolVerifiers.put(AutoTransportUtils.OPENWIRE, owpv);
208        }
209    }
210
211    protected boolean isAllProtocols() {
212        return enabledProtocols == null || enabledProtocols.isEmpty();
213    }
214
215
216    protected final ThreadPoolExecutor service;
217
218
219    /**
220     * This holds the initial buffer that has been read to detect the protocol.
221     */
222    public InitBuffer initBuffer;
223
224    @Override
225    protected void handleSocket(final Socket socket) {
226        final AutoTcpTransportServer server = this;
227        //This needs to be done in a new thread because
228        //the socket might be waiting on the client to send bytes
229        //doHandleSocket can't complete until the protocol can be detected
230        service.submit(new Runnable() {
231            @Override
232            public void run() {
233                server.doHandleSocket(socket);
234            }
235        });
236    }
237
238    @Override
239    protected TransportInfo configureTransport(final TcpTransportServer server, final Socket socket) throws Exception {
240        final InputStream is = socket.getInputStream();
241        ExecutorService executor = Executors.newSingleThreadExecutor();
242
243        final AtomicInteger readBytes = new AtomicInteger(0);
244        final ByteBuffer data = ByteBuffer.allocate(8);
245        // We need to peak at the first 8 bytes of the buffer to detect the protocol
246        Future<?> future = executor.submit(new Runnable() {
247            @Override
248            public void run() {
249                try {
250                    do {
251                        int read = is.read();
252                        if (read == -1) {
253                            throw new IOException("Connection failed, stream is closed.");
254                        }
255                        data.put((byte) read);
256                        readBytes.incrementAndGet();
257                    } while (readBytes.get() < 8);
258                } catch (Exception e) {
259                    throw new IllegalStateException(e);
260                }
261            }
262        });
263
264        waitForProtocolDetectionFinish(future, readBytes);
265        data.flip();
266        ProtocolInfo protocolInfo = detectProtocol(data.array());
267
268        initBuffer = new InitBuffer(readBytes.get(), ByteBuffer.allocate(readBytes.get()));
269        initBuffer.buffer.put(data.array());
270
271        if (protocolInfo.detectedTransportFactory instanceof BrokerServiceAware) {
272            ((BrokerServiceAware) protocolInfo.detectedTransportFactory).setBrokerService(brokerService);
273        }
274
275        WireFormat format = protocolInfo.detectedWireFormatFactory.createWireFormat();
276        Transport transport = createTransport(socket, format,protocolInfo.detectedTransportFactory);
277
278        return new TransportInfo(format, transport, protocolInfo.detectedTransportFactory);
279    }
280
281    protected void waitForProtocolDetectionFinish(final Future<?> future, final AtomicInteger readBytes) throws Exception {
282        try {
283            //Wait for protocolDetectionTimeOut if defined
284            if (protocolDetectionTimeOut > 0) {
285                future.get(protocolDetectionTimeOut, TimeUnit.MILLISECONDS);
286            } else {
287                future.get();
288            }
289        } catch (TimeoutException e) {
290            throw new InactivityIOException("Client timed out before wire format could be detected. " +
291                    " 8 bytes are required to detect the protocol but only: " + readBytes.get() + " byte(s) were sent.");
292        }
293    }
294
295    @Override
296    protected TcpTransport createTransport(Socket socket, WireFormat format) throws IOException {
297        return new TcpTransport(format, socket, this.initBuffer);
298    }
299
300    /**
301     * @param socket
302     * @param format
303     * @param detectedTransportFactory
304     * @return
305     */
306    protected TcpTransport createTransport(Socket socket, WireFormat format,
307            TcpTransportFactory detectedTransportFactory) throws IOException {
308        return createTransport(socket, format);
309    }
310
311    public void setWireFormatOptions(Map<String, Map<String, Object>> wireFormatOptions) {
312        this.wireFormatOptions = wireFormatOptions;
313    }
314
315    public void setEnabledProtocols(Set<String> enabledProtocols) {
316        this.enabledProtocols = enabledProtocols;
317    }
318
319    public void setAutoTransportOptions(Map<String, Object> autoTransportOptions) {
320        this.autoTransportOptions = autoTransportOptions;
321        if (autoTransportOptions.get("protocols") != null) {
322            this.enabledProtocols = AutoTransportUtils.parseProtocols((String) autoTransportOptions.get("protocols"));
323        }
324    }
325    @Override
326    protected void doStop(ServiceStopper stopper) throws Exception {
327        if (service != null) {
328            service.shutdown();
329        }
330        super.doStop(stopper);
331    }
332
333    protected ProtocolInfo detectProtocol(byte[] buffer) throws IOException {
334        TcpTransportFactory detectedTransportFactory = transportFactory;
335        WireFormatFactory detectedWireFormatFactory = wireFormatFactory;
336
337        boolean found = false;
338        for (String scheme : protocolVerifiers.keySet()) {
339            if (protocolVerifiers.get(scheme).isProtocol(buffer)) {
340                LOG.debug("Detected protocol " + scheme);
341                detectedWireFormatFactory = findWireFormatFactory(scheme, wireFormatOptions);
342
343                if (scheme.equals("default")) {
344                    scheme = "";
345                }
346
347                detectedTransportFactory = (TcpTransportFactory) findTransportFactory(scheme, transportOptions);
348                found = true;
349                break;
350            }
351        }
352
353        if (!found) {
354            throw new IllegalStateException("Could not detect the wire format");
355        }
356
357        return new ProtocolInfo(detectedTransportFactory, detectedWireFormatFactory);
358
359    }
360
361    protected class ProtocolInfo {
362        public final TcpTransportFactory detectedTransportFactory;
363        public final WireFormatFactory detectedWireFormatFactory;
364
365        public ProtocolInfo(TcpTransportFactory detectedTransportFactory,
366                WireFormatFactory detectedWireFormatFactory) {
367            super();
368            this.detectedTransportFactory = detectedTransportFactory;
369            this.detectedWireFormatFactory = detectedWireFormatFactory;
370        }
371    }
372
373}