diff options
author | Tim Brooks <tim@uncontended.net> | 2017-06-28 10:51:20 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-06-28 10:51:20 -0500 |
commit | 5f8be0e090f788d9132c2c371f3ecb9c4129bd8e (patch) | |
tree | e2858917b831b7406844184ef8de7c4862c98d3d /test/framework | |
parent | 9ce9c21b836834070c584c97a35d2e232d8478d0 (diff) |
Introduce NioTransport into framework for testing (#24262)
This commit introduces a nio based tcp transport into framework for
testing.
Currently Elasticsearch uses a simple blocking tcp transport for
testing purposes (MockTcpTransport). This diverges from production
where our current transport (netty) is non-blocking.
The point of this commit is to introduce a testing variant that more
closely matches the behavior of production instances.
Diffstat (limited to 'test/framework')
46 files changed, 5492 insertions, 0 deletions
diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptingSelector.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptingSelector.java new file mode 100644 index 0000000000..c2c9ac03a2 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptingSelector.java @@ -0,0 +1,115 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; + +import java.io.IOException; +import java.nio.channels.CancelledKeyException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.ClosedSelectorException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.util.Iterator; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; + +/** + * Selector implementation that handles {@link NioServerSocketChannel}. It's main piece of functionality is + * accepting new channels. + */ +public class AcceptingSelector extends ESSelector { + + private final AcceptorEventHandler eventHandler; + private final ConcurrentLinkedQueue<NioServerSocketChannel> newChannels = new ConcurrentLinkedQueue<>(); + + public AcceptingSelector(AcceptorEventHandler eventHandler) throws IOException { + super(eventHandler); + this.eventHandler = eventHandler; + } + + public AcceptingSelector(AcceptorEventHandler eventHandler, Selector selector) throws IOException { + super(eventHandler, selector); + this.eventHandler = eventHandler; + } + + @Override + void doSelect(int timeout) throws IOException, ClosedSelectorException { + setUpNewServerChannels(); + + int ready = selector.select(timeout); + if (ready > 0) { + Set<SelectionKey> selectionKeys = selector.selectedKeys(); + Iterator<SelectionKey> keyIterator = selectionKeys.iterator(); + while (keyIterator.hasNext()) { + SelectionKey sk = keyIterator.next(); + keyIterator.remove(); + acceptChannel(sk); + } + } + } + + @Override + void cleanup() { + channelsToClose.addAll(registeredChannels); + closePendingChannels(); + } + + /** + * Registers a NioServerSocketChannel to be handled by this selector. The channel will by queued and + * eventually registered next time through the event loop. + * @param serverSocketChannel the channel to register + */ + public void registerServerChannel(NioServerSocketChannel serverSocketChannel) { + newChannels.add(serverSocketChannel); + wakeup(); + } + + private void setUpNewServerChannels() throws ClosedChannelException { + NioServerSocketChannel newChannel; + while ((newChannel = this.newChannels.poll()) != null) { + if (newChannel.register(this)) { + SelectionKey selectionKey = newChannel.getSelectionKey(); + selectionKey.attach(newChannel); + registeredChannels.add(newChannel); + eventHandler.serverChannelRegistered(newChannel); + } + } + } + + private void acceptChannel(SelectionKey sk) { + NioServerSocketChannel serverChannel = (NioServerSocketChannel) sk.attachment(); + if (sk.isValid()) { + try { + if (sk.isAcceptable()) { + try { + eventHandler.acceptChannel(serverChannel); + } catch (IOException e) { + eventHandler.acceptException(serverChannel, e); + } + } + } catch (CancelledKeyException ex) { + eventHandler.genericServerChannelException(serverChannel, ex); + } + } else { + eventHandler.genericServerChannelException(serverChannel, new CancelledKeyException()); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java new file mode 100644 index 0000000000..7ce3b93e17 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java @@ -0,0 +1,91 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.transport.nio.channel.ChannelFactory; +import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.SelectionKeyUtils; + +import java.io.IOException; +import java.util.function.Supplier; + +/** + * Event handler designed to handle events from server sockets + */ +public class AcceptorEventHandler extends EventHandler { + + private final Supplier<SocketSelector> selectorSupplier; + private final OpenChannels openChannels; + + public AcceptorEventHandler(Logger logger, OpenChannels openChannels, Supplier<SocketSelector> selectorSupplier) { + super(logger); + this.openChannels = openChannels; + this.selectorSupplier = selectorSupplier; + } + + /** + * This method is called when a NioServerSocketChannel is successfully registered. It should only be + * called once per channel. + * + * @param nioServerSocketChannel that was registered + */ + public void serverChannelRegistered(NioServerSocketChannel nioServerSocketChannel) { + SelectionKeyUtils.setAcceptInterested(nioServerSocketChannel); + openChannels.serverChannelOpened(nioServerSocketChannel); + } + + /** + * This method is called when a server channel signals it is ready to accept a connection. All of the + * accept logic should occur in this call. + * + * @param nioServerChannel that can accept a connection + */ + public void acceptChannel(NioServerSocketChannel nioServerChannel) throws IOException { + ChannelFactory channelFactory = nioServerChannel.getChannelFactory(); + NioSocketChannel nioSocketChannel = channelFactory.acceptNioChannel(nioServerChannel); + openChannels.acceptedChannelOpened(nioSocketChannel); + nioSocketChannel.getCloseFuture().setListener(openChannels::channelClosed); + selectorSupplier.get().registerSocketChannel(nioSocketChannel); + } + + /** + * This method is called when an attempt to accept a connection throws an exception. + * + * @param nioServerChannel that accepting a connection + * @param exception that occurred + */ + public void acceptException(NioServerSocketChannel nioServerChannel, Exception exception) { + logger.debug("exception while accepting new channel", exception); + } + + /** + * This method is called when handling an event from a channel fails due to an unexpected exception. + * An example would be if checking ready ops on a {@link java.nio.channels.SelectionKey} threw + * {@link java.nio.channels.CancelledKeyException}. + * + * @param channel that caused the exception + * @param exception that was thrown + */ + public void genericServerChannelException(NioServerSocketChannel channel, Exception exception) { + logger.debug("event handling exception", exception); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/ESSelector.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/ESSelector.java new file mode 100644 index 0000000000..c5cf7e2593 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/ESSelector.java @@ -0,0 +1,196 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.transport.nio.channel.NioChannel; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.channels.ClosedSelectorException; +import java.nio.channels.Selector; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.ReentrantLock; + +/** + * This is a basic selector abstraction used by {@link org.elasticsearch.transport.nio.NioTransport}. This + * selector wraps a raw nio {@link Selector}. When you call {@link #runLoop()}, the selector will run until + * {@link #close()} is called. This instance handles closing of channels. Users should call + * {@link #queueChannelClose(NioChannel)} to schedule a channel for close by this selector. + * <p> + * Children of this class should implement the specific {@link #doSelect(int)} and {@link #cleanup()} + * functionality. + */ +public abstract class ESSelector implements Closeable { + + final Selector selector; + final ConcurrentLinkedQueue<NioChannel> channelsToClose = new ConcurrentLinkedQueue<>(); + final Set<NioChannel> registeredChannels = Collections.newSetFromMap(new ConcurrentHashMap<NioChannel, Boolean>()); + + private final EventHandler eventHandler; + private final ReentrantLock runLock = new ReentrantLock(); + private final AtomicBoolean isClosed = new AtomicBoolean(false); + private final PlainActionFuture<Boolean> isRunningFuture = PlainActionFuture.newFuture(); + private volatile Thread thread; + + ESSelector(EventHandler eventHandler) throws IOException { + this(eventHandler, Selector.open()); + } + + ESSelector(EventHandler eventHandler, Selector selector) throws IOException { + this.eventHandler = eventHandler; + this.selector = selector; + } + + /** + * Starts this selector. The selector will run until {@link #close()} or {@link #close(boolean)} is + * called. + */ + public void runLoop() { + if (runLock.tryLock()) { + isRunningFuture.onResponse(true); + try { + setThread(); + while (isOpen()) { + singleLoop(); + } + } finally { + try { + cleanup(); + } finally { + runLock.unlock(); + } + } + } else { + throw new IllegalStateException("selector is already running"); + } + } + + void singleLoop() { + try { + closePendingChannels(); + doSelect(300); + } catch (ClosedSelectorException e) { + if (isOpen()) { + throw e; + } + } catch (IOException e) { + eventHandler.selectException(e); + } catch (Exception e) { + eventHandler.uncaughtException(e); + } + } + + /** + * Should implement the specific select logic. This will be called once per {@link #singleLoop()} + * + * @param timeout to pass to the raw select operation + * @throws IOException thrown by the raw select operation + * @throws ClosedSelectorException thrown if the raw selector is closed + */ + abstract void doSelect(int timeout) throws IOException, ClosedSelectorException; + + void setThread() { + thread = Thread.currentThread(); + } + + public boolean isOnCurrentThread() { + return Thread.currentThread() == thread; + } + + public void wakeup() { + // TODO: Do I need the wakeup optimizations that some other libraries use? + selector.wakeup(); + } + + public Set<NioChannel> getRegisteredChannels() { + return registeredChannels; + } + + @Override + public void close() throws IOException { + close(false); + } + + public void close(boolean shouldInterrupt) throws IOException { + if (isClosed.compareAndSet(false, true)) { + selector.close(); + if (shouldInterrupt && thread != null) { + thread.interrupt(); + } else { + wakeup(); + } + runLock.lock(); // wait for the shutdown to complete + } + } + + public void queueChannelClose(NioChannel channel) { + ensureOpen(); + channelsToClose.offer(channel); + wakeup(); + } + + void closePendingChannels() { + NioChannel channel; + while ((channel = channelsToClose.poll()) != null) { + closeChannel(channel); + } + } + + + /** + * Called once as the selector is being closed. + */ + abstract void cleanup(); + + public Selector rawSelector() { + return selector; + } + + public boolean isOpen() { + return isClosed.get() == false; + } + + public boolean isRunning() { + return runLock.isLocked(); + } + + public PlainActionFuture<Boolean> isRunningFuture() { + return isRunningFuture; + } + + private void closeChannel(NioChannel channel) { + try { + eventHandler.handleClose(channel); + } finally { + registeredChannels.remove(channel); + } + } + + private void ensureOpen() { + if (isClosed.get()) { + throw new IllegalStateException("selector is already closed"); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java new file mode 100644 index 0000000000..6ecf36343f --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java @@ -0,0 +1,71 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.transport.nio.channel.CloseFuture; +import org.elasticsearch.transport.nio.channel.NioChannel; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; + +import java.io.IOException; +import java.nio.channels.Selector; + +public abstract class EventHandler { + + protected final Logger logger; + + public EventHandler(Logger logger) { + this.logger = logger; + } + + /** + * This method handles an IOException that was thrown during a call to {@link Selector#select(long)}. + * + * @param exception that was uncaught + */ + public void selectException(IOException exception) { + logger.warn("io exception during select", exception); + } + + /** + * This method handles an exception that was uncaught during a select loop. + * + * @param exception that was uncaught + */ + public void uncaughtException(Exception exception) { + Thread thread = Thread.currentThread(); + thread.getUncaughtExceptionHandler().uncaughtException(thread, exception); + } + + /** + * This method handles the closing of an NioChannel + * + * @param channel that should be closed + */ + public void handleClose(NioChannel channel) { + channel.closeFromSelector(); + CloseFuture closeFuture = channel.getCloseFuture(); + assert closeFuture.isDone() : "Should always be done as we are on the selector thread"; + IOException closeException = closeFuture.getCloseException(); + if (closeException != null) { + logger.trace("exception while closing channel", closeException); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NetworkBytesReference.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NetworkBytesReference.java new file mode 100644 index 0000000000..cbccd7333d --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NetworkBytesReference.java @@ -0,0 +1,157 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; + +import java.nio.ByteBuffer; +import java.util.Iterator; + +public class NetworkBytesReference extends BytesReference { + + private final BytesArray bytesArray; + private final ByteBuffer writeBuffer; + private final ByteBuffer readBuffer; + + private int writeIndex; + private int readIndex; + + public NetworkBytesReference(BytesArray bytesArray, int writeIndex, int readIndex) { + this.bytesArray = bytesArray; + this.writeIndex = writeIndex; + this.readIndex = readIndex; + this.writeBuffer = ByteBuffer.wrap(bytesArray.array()); + this.readBuffer = ByteBuffer.wrap(bytesArray.array()); + } + + public static NetworkBytesReference wrap(BytesArray bytesArray) { + return wrap(bytesArray, 0, 0); + } + + public static NetworkBytesReference wrap(BytesArray bytesArray, int writeIndex, int readIndex) { + if (readIndex > writeIndex) { + throw new IndexOutOfBoundsException("Read index [" + readIndex + "] was greater than write index [" + writeIndex + "]"); + } + return new NetworkBytesReference(bytesArray, writeIndex, readIndex); + } + + @Override + public byte get(int index) { + return bytesArray.get(index); + } + + @Override + public int length() { + return bytesArray.length(); + } + + @Override + public NetworkBytesReference slice(int from, int length) { + BytesReference ref = bytesArray.slice(from, length); + BytesArray newBytesArray; + if (ref instanceof BytesArray) { + newBytesArray = (BytesArray) ref; + } else { + newBytesArray = new BytesArray(ref.toBytesRef()); + } + + int newReadIndex = Math.min(Math.max(readIndex - from, 0), length); + int newWriteIndex = Math.min(Math.max(writeIndex - from, 0), length); + + return wrap(newBytesArray, newWriteIndex, newReadIndex); + } + + @Override + public BytesRef toBytesRef() { + return bytesArray.toBytesRef(); + } + + @Override + public long ramBytesUsed() { + return bytesArray.ramBytesUsed(); + } + + public int getWriteIndex() { + return writeIndex; + } + + public void incrementWrite(int delta) { + int newWriteIndex = writeIndex + delta; + if (newWriteIndex > bytesArray.length()) { + throw new IndexOutOfBoundsException("New write index [" + newWriteIndex + "] would be greater than length" + + " [" + bytesArray.length() + "]"); + } + + writeIndex = newWriteIndex; + } + + public int getWriteRemaining() { + return bytesArray.length() - writeIndex; + } + + public boolean hasWriteRemaining() { + return getWriteRemaining() > 0; + } + + public int getReadIndex() { + return readIndex; + } + + public void incrementRead(int delta) { + int newReadIndex = readIndex + delta; + if (newReadIndex > writeIndex) { + throw new IndexOutOfBoundsException("New read index [" + newReadIndex + "] would be greater than write" + + " index [" + writeIndex + "]"); + } + readIndex = newReadIndex; + } + + public int getReadRemaining() { + return writeIndex - readIndex; + } + + public boolean hasReadRemaining() { + return getReadRemaining() > 0; + } + + public ByteBuffer getWriteByteBuffer() { + writeBuffer.position(bytesArray.offset() + writeIndex); + writeBuffer.limit(bytesArray.offset() + bytesArray.length()); + return writeBuffer; + } + + public ByteBuffer getReadByteBuffer() { + readBuffer.position(bytesArray.offset() + readIndex); + readBuffer.limit(bytesArray.offset() + writeIndex); + return readBuffer; + } + + public static void vectorizedIncrementReadIndexes(Iterable<NetworkBytesReference> references, int delta) { + Iterator<NetworkBytesReference> refs = references.iterator(); + while (delta != 0) { + NetworkBytesReference ref = refs.next(); + int amountToInc = Math.min(ref.getReadRemaining(), delta); + ref.incrementRead(amountToInc); + delta -= amountToInc; + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java new file mode 100644 index 0000000000..bc06ad0bc3 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java @@ -0,0 +1,155 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.transport.ConnectTransportException; +import org.elasticsearch.transport.nio.channel.ChannelFactory; +import org.elasticsearch.transport.nio.channel.ConnectFuture; +import org.elasticsearch.transport.nio.channel.NioChannel; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.LockSupport; +import java.util.function.Consumer; +import java.util.function.Supplier; + +public class NioClient { + + private static final int CLOSED = -1; + + private final Logger logger; + private final OpenChannels openChannels; + private final Supplier<SocketSelector> selectorSupplier; + private final TimeValue defaultConnectTimeout; + private final ChannelFactory channelFactory; + private final Semaphore semaphore = new Semaphore(Integer.MAX_VALUE); + + public NioClient(Logger logger, OpenChannels openChannels, Supplier<SocketSelector> selectorSupplier, TimeValue connectTimeout, + ChannelFactory channelFactory) { + this.logger = logger; + this.openChannels = openChannels; + this.selectorSupplier = selectorSupplier; + this.defaultConnectTimeout = connectTimeout; + this.channelFactory = channelFactory; + } + + public boolean connectToChannels(DiscoveryNode node, NioSocketChannel[] channels, TimeValue connectTimeout, + Consumer<NioChannel> closeListener) throws IOException { + boolean allowedToConnect = semaphore.tryAcquire(); + if (allowedToConnect == false) { + return false; + } + + final ArrayList<NioSocketChannel> connections = new ArrayList<>(channels.length); + connectTimeout = getConnectTimeout(connectTimeout); + final InetSocketAddress address = node.getAddress().address(); + try { + for (int i = 0; i < channels.length; i++) { + SocketSelector socketSelector = selectorSupplier.get(); + NioSocketChannel nioSocketChannel = channelFactory.openNioChannel(address); + openChannels.clientChannelOpened(nioSocketChannel); + nioSocketChannel.getCloseFuture().setListener(closeListener); + connections.add(nioSocketChannel); + socketSelector.registerSocketChannel(nioSocketChannel); + } + + Exception ex = null; + boolean allConnected = true; + for (NioSocketChannel socketChannel : connections) { + ConnectFuture connectFuture = socketChannel.getConnectFuture(); + boolean success = connectFuture.awaitConnectionComplete(connectTimeout.getMillis(), TimeUnit.MILLISECONDS); + if (success == false) { + allConnected = false; + Exception exception = connectFuture.getException(); + if (exception != null) { + ex = exception; + break; + } + } + } + + if (allConnected == false) { + if (ex == null) { + throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]"); + } else { + throw new ConnectTransportException(node, "connect_exception", ex); + } + } + addConnectionsToList(channels, connections); + return true; + + } catch (IOException | RuntimeException e) { + closeChannels(connections, e); + throw e; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + closeChannels(connections, e); + throw new ElasticsearchException(e); + } finally { + semaphore.release(); + } + } + + public void close() { + semaphore.acquireUninterruptibly(Integer.MAX_VALUE); + } + + private TimeValue getConnectTimeout(TimeValue connectTimeout) { + if (connectTimeout != null && connectTimeout.equals(defaultConnectTimeout) == false) { + return connectTimeout; + } else { + return defaultConnectTimeout; + } + } + + private static void addConnectionsToList(NioSocketChannel[] channels, ArrayList<NioSocketChannel> connections) { + final Iterator<NioSocketChannel> iterator = connections.iterator(); + for (int i = 0; i < channels.length; i++) { + assert iterator.hasNext(); + channels[i] = iterator.next(); + } + assert iterator.hasNext() == false : "not all created connection have been consumed"; + } + + private void closeChannels(ArrayList<NioSocketChannel> connections, Exception e) { + for (final NioSocketChannel socketChannel : connections) { + try { + socketChannel.closeAsync().awaitClose(); + } catch (InterruptedException inner) { + logger.trace("exception while closing channel", e); + e.addSuppressed(inner); + Thread.currentThread().interrupt(); + } catch (Exception inner) { + logger.trace("exception while closing channel", e); + e.addSuppressed(inner); + } + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioShutdown.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioShutdown.java new file mode 100644 index 0000000000..8dc87f80f8 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioShutdown.java @@ -0,0 +1,66 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.concurrent.CountDownLatch; + +public class NioShutdown { + + private final Logger logger; + + public NioShutdown(Logger logger) { + this.logger = logger; + } + + void orderlyShutdown(OpenChannels openChannels, NioClient client, ArrayList<AcceptingSelector> acceptors, + ArrayList<SocketSelector> socketSelectors) { + // Close the client. This ensures that no new send connections will be opened. Client could be null if exception was + // throw on start up + if (client != null) { + client.close(); + } + + // Start by closing the server channels. Once these are closed, we are guaranteed to no accept new connections + openChannels.closeServerChannels(); + + for (AcceptingSelector acceptor : acceptors) { + shutdownSelector(acceptor); + } + + openChannels.close(); + + for (SocketSelector selector : socketSelectors) { + shutdownSelector(selector); + } + } + + private void shutdownSelector(ESSelector selector) { + try { + selector.close(); + } catch (IOException | ElasticsearchException e) { + logger.warn("unexpected exception while stopping selector", e); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java new file mode 100644 index 0000000000..05c818476a --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -0,0 +1,289 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.network.NetworkService; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.ConnectionProfile; +import org.elasticsearch.transport.TcpTransport; +import org.elasticsearch.transport.TransportSettings; +import org.elasticsearch.transport.nio.channel.ChannelFactory; +import org.elasticsearch.transport.nio.channel.NioChannel; +import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ThreadFactory; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static org.elasticsearch.common.settings.Setting.intSetting; +import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; +import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory; + +public class NioTransport extends TcpTransport<NioChannel> { + + // TODO: Need to add to places where we check if transport thread + public static final String TRANSPORT_WORKER_THREAD_NAME_PREFIX = "transport_worker"; + public static final String TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX = "transport_acceptor"; + + public static final Setting<Integer> NIO_WORKER_COUNT = + new Setting<>("transport.nio.worker_count", + (s) -> Integer.toString(EsExecutors.numberOfProcessors(s) * 2), + (s) -> Setting.parseInt(s, 1, "transport.nio.worker_count"), Setting.Property.NodeScope); + + public static final Setting<Integer> NIO_ACCEPTOR_COUNT = + intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope); + + private final TcpReadHandler tcpReadHandler = new TcpReadHandler(this); + private final BigArrays bigArrays; + private final ConcurrentMap<String, ChannelFactory> profileToChannelFactory = newConcurrentMap(); + private final OpenChannels openChannels = new OpenChannels(logger); + private final ArrayList<AcceptingSelector> acceptors = new ArrayList<>(); + private final ArrayList<SocketSelector> socketSelectors = new ArrayList<>(); + private NioClient client; + private int acceptorNumber; + + public NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, + NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { + super("nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); + this.bigArrays = bigArrays; + } + + @Override + public long getNumOpenServerConnections() { + return openChannels.serverChannelsCount(); + } + + @Override + protected InetSocketAddress getLocalAddress(NioChannel channel) { + return channel.getLocalAddress(); + } + + @Override + protected NioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException { + ChannelFactory channelFactory = this.profileToChannelFactory.get(name); + NioServerSocketChannel serverSocketChannel = channelFactory.openNioServerSocketChannel(name, address); + acceptors.get(++acceptorNumber % NioTransport.NIO_ACCEPTOR_COUNT.get(settings)).registerServerChannel(serverSocketChannel); + return serverSocketChannel; + } + + @Override + protected void closeChannels(List<NioChannel> channels) throws IOException { + IOException closingExceptions = null; + for (final NioChannel channel : channels) { + if (channel != null && channel.isOpen()) { + try { + channel.closeAsync().awaitClose(); + } catch (Exception e) { + if (closingExceptions == null) { + closingExceptions = new IOException("failed to close channels"); + } + closingExceptions.addSuppressed(e.getCause()); + } + } + } + + if (closingExceptions != null) { + throw closingExceptions; + } + } + + @Override + protected void sendMessage(NioChannel channel, BytesReference reference, ActionListener<NioChannel> listener) { + if (channel instanceof NioSocketChannel) { + NioSocketChannel nioSocketChannel = (NioSocketChannel) channel; + nioSocketChannel.getWriteContext().sendMessage(reference, listener); + } else { + logger.error("cannot send message to channel of this type [{}]", channel.getClass()); + } + } + + @Override + protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer<NioChannel> onChannelClose) + throws IOException { + NioSocketChannel[] channels = new NioSocketChannel[profile.getNumConnections()]; + ClientChannelCloseListener closeListener = new ClientChannelCloseListener(onChannelClose); + boolean connected = client.connectToChannels(node, channels, profile.getConnectTimeout(), closeListener); + if (connected == false) { + throw new ElasticsearchException("client is shutdown"); + } + return new NodeChannels(node, channels, profile); + } + + @Override + protected boolean isOpen(NioChannel channel) { + return channel.isOpen(); + } + + @Override + protected void doStart() { + boolean success = false; + try { + if (NetworkService.NETWORK_SERVER.get(settings)) { + int workerCount = NioTransport.NIO_WORKER_COUNT.get(settings); + for (int i = 0; i < workerCount; ++i) { + SocketSelector selector = new SocketSelector(getSocketEventHandler()); + socketSelectors.add(selector); + } + + int acceptorCount = NioTransport.NIO_ACCEPTOR_COUNT.get(settings); + for (int i = 0; i < acceptorCount; ++i) { + Supplier<SocketSelector> selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors); + AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, openChannels, selectorSupplier); + AcceptingSelector acceptor = new AcceptingSelector(eventHandler); + acceptors.add(acceptor); + } + // loop through all profiles and start them up, special handling for default one + for (Map.Entry<String, Settings> entry : buildProfileSettings().entrySet()) { + // merge fallback settings with default settings with profile settings so we have complete settings with default values + final Settings settings = Settings.builder() + .put(createFallbackSettings()) + .put(entry.getValue()).build(); + profileToChannelFactory.putIfAbsent(entry.getKey(), new ChannelFactory(settings, tcpReadHandler)); + bindServer(entry.getKey(), settings); + } + } + client = createClient(); + + for (SocketSelector selector : socketSelectors) { + if (selector.isRunning() == false) { + ThreadFactory threadFactory = daemonThreadFactory(this.settings, TRANSPORT_WORKER_THREAD_NAME_PREFIX); + threadFactory.newThread(selector::runLoop).start(); + selector.isRunningFuture().actionGet(); + } + } + + for (AcceptingSelector acceptor : acceptors) { + if (acceptor.isRunning() == false) { + ThreadFactory threadFactory = daemonThreadFactory(this.settings, TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX); + threadFactory.newThread(acceptor::runLoop).start(); + acceptor.isRunningFuture().actionGet(); + } + } + + super.doStart(); + success = true; + } catch (IOException e) { + throw new ElasticsearchException(e); + } finally { + if (success == false) { + doStop(); + } + } + } + + @Override + protected void stopInternal() { + NioShutdown nioShutdown = new NioShutdown(logger); + nioShutdown.orderlyShutdown(openChannels, client, acceptors, socketSelectors); + + profileToChannelFactory.clear(); + socketSelectors.clear(); + } + + protected SocketEventHandler getSocketEventHandler() { + return new SocketEventHandler(logger, this::exceptionCaught); + } + + final void exceptionCaught(NioSocketChannel channel, Throwable cause) { + final Throwable unwrapped = ExceptionsHelper.unwrap(cause, ElasticsearchException.class); + final Throwable t = unwrapped != null ? unwrapped : cause; + onException(channel, t instanceof Exception ? (Exception) t : new ElasticsearchException(t)); + } + + private Settings createFallbackSettings() { + Settings.Builder fallbackSettingsBuilder = Settings.builder(); + + List<String> fallbackBindHost = TransportSettings.BIND_HOST.get(settings); + if (fallbackBindHost.isEmpty() == false) { + fallbackSettingsBuilder.putArray("bind_host", fallbackBindHost); + } + + List<String> fallbackPublishHost = TransportSettings.PUBLISH_HOST.get(settings); + if (fallbackPublishHost.isEmpty() == false) { + fallbackSettingsBuilder.putArray("publish_host", fallbackPublishHost); + } + + boolean fallbackTcpNoDelay = settings.getAsBoolean("transport.nio.tcp_no_delay", + NetworkService.TcpSettings.TCP_NO_DELAY.get(settings)); + fallbackSettingsBuilder.put("tcp_no_delay", fallbackTcpNoDelay); + + boolean fallbackTcpKeepAlive = settings.getAsBoolean("transport.nio.tcp_keep_alive", + NetworkService.TcpSettings.TCP_KEEP_ALIVE.get(settings)); + fallbackSettingsBuilder.put("tcp_keep_alive", fallbackTcpKeepAlive); + + boolean fallbackReuseAddress = settings.getAsBoolean("transport.nio.reuse_address", + NetworkService.TcpSettings.TCP_REUSE_ADDRESS.get(settings)); + fallbackSettingsBuilder.put("reuse_address", fallbackReuseAddress); + + ByteSizeValue fallbackTcpSendBufferSize = settings.getAsBytesSize("transport.nio.tcp_send_buffer_size", + TCP_SEND_BUFFER_SIZE.get(settings)); + if (fallbackTcpSendBufferSize.getBytes() >= 0) { + fallbackSettingsBuilder.put("tcp_send_buffer_size", fallbackTcpSendBufferSize); + } + + ByteSizeValue fallbackTcpBufferSize = settings.getAsBytesSize("transport.nio.tcp_receive_buffer_size", + TCP_RECEIVE_BUFFER_SIZE.get(settings)); + if (fallbackTcpBufferSize.getBytes() >= 0) { + fallbackSettingsBuilder.put("tcp_receive_buffer_size", fallbackTcpBufferSize); + } + + return fallbackSettingsBuilder.build(); + } + + private NioClient createClient() { + Supplier<SocketSelector> selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors); + ChannelFactory channelFactory = new ChannelFactory(settings, tcpReadHandler); + return new NioClient(logger, openChannels, selectorSupplier, defaultConnectionProfile.getConnectTimeout(), channelFactory); + } + + class ClientChannelCloseListener implements Consumer<NioChannel> { + + private final Consumer<NioChannel> consumer; + + private ClientChannelCloseListener(Consumer<NioChannel> consumer) { + this.consumer = consumer; + } + + @Override + public void accept(final NioChannel channel) { + consumer.accept(channel); + openChannels.channelClosed(channel); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java new file mode 100644 index 0000000000..eea353a6c1 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java @@ -0,0 +1,120 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.transport.nio.channel.NioChannel; +import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; + +import java.util.HashSet; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; + +import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; + +public class OpenChannels implements Releasable { + + // TODO: Maybe set concurrency levels? + private final ConcurrentMap<NioSocketChannel, Long> openClientChannels = newConcurrentMap(); + private final ConcurrentMap<NioSocketChannel, Long> openAcceptedChannels = newConcurrentMap(); + private final ConcurrentMap<NioServerSocketChannel, Long> openServerChannels = newConcurrentMap(); + + private final Logger logger; + + public OpenChannels(Logger logger) { + this.logger = logger; + } + + public void serverChannelOpened(NioServerSocketChannel channel) { + boolean added = openServerChannels.putIfAbsent(channel, System.nanoTime()) == null; + if (added && logger.isTraceEnabled()) { + logger.trace("server channel opened: {}", channel); + } + } + + public long serverChannelsCount() { + return openServerChannels.size(); + } + + public void acceptedChannelOpened(NioSocketChannel channel) { + boolean added = openAcceptedChannels.putIfAbsent(channel, System.nanoTime()) == null; + if (added && logger.isTraceEnabled()) { + logger.trace("accepted channel opened: {}", channel); + } + } + + public HashSet<NioSocketChannel> getAcceptedChannels() { + return new HashSet<>(openAcceptedChannels.keySet()); + } + + public void clientChannelOpened(NioSocketChannel channel) { + boolean added = openClientChannels.putIfAbsent(channel, System.nanoTime()) == null; + if (added && logger.isTraceEnabled()) { + logger.trace("client channel opened: {}", channel); + } + } + + public void channelClosed(NioChannel channel) { + boolean removed; + if (channel instanceof NioServerSocketChannel) { + removed = openServerChannels.remove(channel) != null; + } else { + NioSocketChannel socketChannel = (NioSocketChannel) channel; + removed = openClientChannels.remove(socketChannel) != null; + if (removed == false) { + removed = openAcceptedChannels.remove(socketChannel) != null; + } + } + if (removed && logger.isTraceEnabled()) { + logger.trace("channel closed: {}", channel); + } + } + + public void closeServerChannels() { + for (NioServerSocketChannel channel : openServerChannels.keySet()) { + ensureClosedInternal(channel); + } + + openServerChannels.clear(); + } + + @Override + public void close() { + for (NioSocketChannel channel : openClientChannels.keySet()) { + ensureClosedInternal(channel); + } + for (NioSocketChannel channel : openAcceptedChannels.keySet()) { + ensureClosedInternal(channel); + } + + openClientChannels.clear(); + openAcceptedChannels.clear(); + } + + private void ensureClosedInternal(NioChannel channel) { + try { + channel.closeAsync().get(); + } catch (Exception e) { + logger.trace("exception while closing channels", e); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/RoundRobinSelectorSupplier.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/RoundRobinSelectorSupplier.java new file mode 100644 index 0000000000..108242b1e0 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/RoundRobinSelectorSupplier.java @@ -0,0 +1,40 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +public class RoundRobinSelectorSupplier implements Supplier<SocketSelector> { + + private final ArrayList<SocketSelector> selectors; + private final int count; + private AtomicInteger counter = new AtomicInteger(0); + + public RoundRobinSelectorSupplier(ArrayList<SocketSelector> selectors) { + this.count = selectors.size(); + this.selectors = selectors; + } + + public SocketSelector get() { + return selectors.get(counter.getAndIncrement() % count); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java new file mode 100644 index 0000000000..6905f7957b --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java @@ -0,0 +1,154 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.SelectionKeyUtils; +import org.elasticsearch.transport.nio.channel.WriteContext; + +import java.io.IOException; +import java.util.function.BiConsumer; + +/** + * Event handler designed to handle events from non-server sockets + */ +public class SocketEventHandler extends EventHandler { + + private final BiConsumer<NioSocketChannel, Throwable> exceptionHandler; + private final Logger logger; + + public SocketEventHandler(Logger logger, BiConsumer<NioSocketChannel, Throwable> exceptionHandler) { + super(logger); + this.exceptionHandler = exceptionHandler; + this.logger = logger; + } + + /** + * This method is called when a NioSocketChannel is successfully registered. It should only be called + * once per channel. + * + * @param channel that was registered + */ + public void handleRegistration(NioSocketChannel channel) { + SelectionKeyUtils.setConnectAndReadInterested(channel); + } + + /** + * This method is called when an attempt to register a channel throws an exception. + * + * @param channel that was registered + * @param exception that occurred + */ + public void registrationException(NioSocketChannel channel, Exception exception) { + logger.trace("failed to register channel", exception); + exceptionCaught(channel, exception); + } + + /** + * This method is called when a NioSocketChannel is successfully connected. It should only be called + * once per channel. + * + * @param channel that was registered + */ + public void handleConnect(NioSocketChannel channel) { + SelectionKeyUtils.removeConnectInterested(channel); + } + + /** + * This method is called when an attempt to connect a channel throws an exception. + * + * @param channel that was connecting + * @param exception that occurred + */ + public void connectException(NioSocketChannel channel, Exception exception) { + logger.trace("failed to connect to channel", exception); + exceptionCaught(channel, exception); + + } + + /** + * This method is called when a channel signals it is ready for be read. All of the read logic should + * occur in this call. + * + * @param channel that can be read + */ + public void handleRead(NioSocketChannel channel) throws IOException { + int bytesRead = channel.getReadContext().read(); + if (bytesRead == -1) { + handleClose(channel); + } + } + + /** + * This method is called when an attempt to read from a channel throws an exception. + * + * @param channel that was being read + * @param exception that occurred + */ + public void readException(NioSocketChannel channel, Exception exception) { + logger.trace("failed to read from channel", exception); + exceptionCaught(channel, exception); + } + + /** + * This method is called when a channel signals it is ready to receive writes. All of the write logic + * should occur in this call. + * + * @param channel that can be read + */ + public void handleWrite(NioSocketChannel channel) throws IOException { + WriteContext channelContext = channel.getWriteContext(); + channelContext.flushChannel(); + if (channelContext.hasQueuedWriteOps()) { + SelectionKeyUtils.setWriteInterested(channel); + } else { + SelectionKeyUtils.removeWriteInterested(channel); + } + } + + /** + * This method is called when an attempt to write to a channel throws an exception. + * + * @param channel that was being written to + * @param exception that occurred + */ + public void writeException(NioSocketChannel channel, Exception exception) { + logger.trace("failed to write to channel", exception); + exceptionCaught(channel, exception); + } + + /** + * This method is called when handling an event from a channel fails due to an unexpected exception. + * An example would be if checking ready ops on a {@link java.nio.channels.SelectionKey} threw + * {@link java.nio.channels.CancelledKeyException}. + * + * @param channel that caused the exception + * @param exception that was thrown + */ + public void genericChannelException(NioSocketChannel channel, Exception exception) { + logger.trace("event handling failed", exception); + exceptionCaught(channel, exception); + } + + private void exceptionCaught(NioSocketChannel channel, Exception e) { + exceptionHandler.accept(channel, e); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketSelector.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketSelector.java new file mode 100644 index 0000000000..24f68504d8 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketSelector.java @@ -0,0 +1,216 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.SelectionKeyUtils; +import org.elasticsearch.transport.nio.channel.WriteContext; + +import java.io.IOException; +import java.nio.channels.CancelledKeyException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.ClosedSelectorException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.util.Iterator; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; + +/** + * Selector implementation that handles {@link NioSocketChannel}. It's main piece of functionality is + * handling connect, read, and write events. + */ +public class SocketSelector extends ESSelector { + + private final ConcurrentLinkedQueue<NioSocketChannel> newChannels = new ConcurrentLinkedQueue<>(); + private final ConcurrentLinkedQueue<WriteOperation> queuedWrites = new ConcurrentLinkedQueue<>(); + private final SocketEventHandler eventHandler; + + public SocketSelector(SocketEventHandler eventHandler) throws IOException { + super(eventHandler); + this.eventHandler = eventHandler; + } + + public SocketSelector(SocketEventHandler eventHandler, Selector selector) throws IOException { + super(eventHandler, selector); + this.eventHandler = eventHandler; + } + + @Override + void doSelect(int timeout) throws IOException, ClosedSelectorException { + setUpNewChannels(); + handleQueuedWrites(); + + int ready = selector.select(timeout); + if (ready > 0) { + Set<SelectionKey> selectionKeys = selector.selectedKeys(); + processKeys(selectionKeys); + } + + } + + @Override + void cleanup() { + WriteOperation op; + while ((op = queuedWrites.poll()) != null) { + op.getListener().onFailure(new ClosedSelectorException()); + } + channelsToClose.addAll(newChannels); + channelsToClose.addAll(registeredChannels); + closePendingChannels(); + } + + /** + * Registers a NioSocketChannel to be handled by this selector. The channel will by queued and eventually + * registered next time through the event loop. + * @param nioSocketChannel the channel to register + */ + public void registerSocketChannel(NioSocketChannel nioSocketChannel) { + newChannels.offer(nioSocketChannel); + wakeup(); + } + + + /** + * Queues a write operation to be handled by the event loop. This can be called by any thread and is the + * api available for non-selector threads to schedule writes. + * + * @param writeOperation to be queued + */ + public void queueWrite(WriteOperation writeOperation) { + queuedWrites.offer(writeOperation); + if (isOpen() == false) { + boolean wasRemoved = queuedWrites.remove(writeOperation); + if (wasRemoved) { + writeOperation.getListener().onFailure(new ClosedSelectorException()); + } + } else { + wakeup(); + } + } + + /** + * Queues a write operation directly in a channel's buffer. Channel buffers are only safe to be accessed + * by the selector thread. As a result, this method should only be called by the selector thread. + * + * @param writeOperation to be queued in a channel's buffer + */ + public void queueWriteInChannelBuffer(WriteOperation writeOperation) { + assert isOnCurrentThread() : "Must be on selector thread"; + NioSocketChannel channel = writeOperation.getChannel(); + WriteContext context = channel.getWriteContext(); + try { + SelectionKeyUtils.setWriteInterested(channel); + context.queueWriteOperations(writeOperation); + } catch (Exception e) { + writeOperation.getListener().onFailure(e); + } + } + + private void processKeys(Set<SelectionKey> selectionKeys) { + Iterator<SelectionKey> keyIterator = selectionKeys.iterator(); + while (keyIterator.hasNext()) { + SelectionKey sk = keyIterator.next(); + keyIterator.remove(); + NioSocketChannel nioSocketChannel = (NioSocketChannel) sk.attachment(); + if (sk.isValid()) { + try { + int ops = sk.readyOps(); + if ((ops & SelectionKey.OP_CONNECT) != 0) { + attemptConnect(nioSocketChannel); + } + + if (nioSocketChannel.isConnectComplete()) { + if ((ops & SelectionKey.OP_WRITE) != 0) { + handleWrite(nioSocketChannel); + } + + if ((ops & SelectionKey.OP_READ) != 0) { + handleRead(nioSocketChannel); + } + } + } catch (CancelledKeyException e) { + eventHandler.genericChannelException(nioSocketChannel, e); + } + } else { + eventHandler.genericChannelException(nioSocketChannel, new CancelledKeyException()); + } + } + } + + + private void handleWrite(NioSocketChannel nioSocketChannel) { + try { + eventHandler.handleWrite(nioSocketChannel); + } catch (Exception e) { + eventHandler.writeException(nioSocketChannel, e); + } + } + + private void handleRead(NioSocketChannel nioSocketChannel) { + try { + eventHandler.handleRead(nioSocketChannel); + } catch (Exception e) { + eventHandler.readException(nioSocketChannel, e); + } + } + + private void handleQueuedWrites() { + WriteOperation writeOperation; + while ((writeOperation = queuedWrites.poll()) != null) { + if (writeOperation.getChannel().isWritable()) { + queueWriteInChannelBuffer(writeOperation); + } else { + writeOperation.getListener().onFailure(new ClosedChannelException()); + } + } + } + + private void setUpNewChannels() { + NioSocketChannel newChannel; + while ((newChannel = this.newChannels.poll()) != null) { + setupChannel(newChannel); + } + } + + private void setupChannel(NioSocketChannel newChannel) { + try { + if (newChannel.register(this)) { + registeredChannels.add(newChannel); + SelectionKey key = newChannel.getSelectionKey(); + key.attach(newChannel); + eventHandler.handleRegistration(newChannel); + attemptConnect(newChannel); + } + } catch (Exception e) { + eventHandler.registrationException(newChannel, e); + } + } + + private void attemptConnect(NioSocketChannel newChannel) { + try { + if (newChannel.finishConnect()) { + eventHandler.handleConnect(newChannel); + } + } catch (Exception e) { + eventHandler.connectException(newChannel, e); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/TcpReadHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/TcpReadHandler.java new file mode 100644 index 0000000000..b41d87a0c0 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/TcpReadHandler.java @@ -0,0 +1,47 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; + +import java.io.IOException; + +public class TcpReadHandler { + + private final NioTransport transport; + + public TcpReadHandler(NioTransport transport) { + this.transport = transport; + } + + public void handleMessage(BytesReference reference, NioSocketChannel channel, String profileName, + int messageBytesLength) { + try { + transport.messageReceived(reference, channel, profileName, channel.getRemoteAddress(), messageBytesLength); + } catch (IOException e) { + handleException(channel, e); + } + } + + public void handleException(NioSocketChannel channel, Exception e) { + transport.exceptionCaught(channel, e); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/WriteOperation.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/WriteOperation.java new file mode 100644 index 0000000000..67ed2447f6 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/WriteOperation.java @@ -0,0 +1,81 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefIterator; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.transport.nio.channel.NioChannel; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; + +import java.io.IOException; +import java.util.ArrayList; + +public class WriteOperation { + + private final NioSocketChannel channel; + private final ActionListener<NioChannel> listener; + private final NetworkBytesReference[] references; + + public WriteOperation(NioSocketChannel channel, BytesReference bytesReference, ActionListener<NioChannel> listener) { + this.channel = channel; + this.listener = listener; + this.references = toArray(bytesReference); + } + + public NetworkBytesReference[] getByteReferences() { + return references; + } + + public ActionListener<NioChannel> getListener() { + return listener; + } + + public NioSocketChannel getChannel() { + return channel; + } + + public boolean isFullyFlushed() { + return references[references.length - 1].hasReadRemaining() == false; + } + + public int flush() throws IOException { + return channel.write(references); + } + + private static NetworkBytesReference[] toArray(BytesReference reference) { + BytesRefIterator byteRefIterator = reference.iterator(); + BytesRef r; + try { + // Most network messages are composed of three buffers + ArrayList<NetworkBytesReference> references = new ArrayList<>(3); + while ((r = byteRefIterator.next()) != null) { + references.add(NetworkBytesReference.wrap(new BytesArray(r), r.length, 0)); + } + return references.toArray(new NetworkBytesReference[references.size()]); + + } catch (IOException e) { + // this is really an error since we don't do IO in our bytesreferences + throw new AssertionError("won't happen", e); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java new file mode 100644 index 0000000000..be8dbe3f46 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java @@ -0,0 +1,205 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.transport.nio.ESSelector; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.NetworkChannel; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * This is a basic channel abstraction used by the {@link org.elasticsearch.transport.nio.NioTransport}. + * <p> + * A channel is open once it is constructed. The channel remains open and {@link #isOpen()} will return + * true until the channel is explicitly closed. + * <p> + * A channel lifecycle has four stages: + * <ol> + * <li>UNREGISTERED - When a channel is created and prior to it being registered with a selector. + * <li>REGISTERED - When a channel has been registered with a selector. This is the state of a channel that + * can perform normal operations. + * <li>CLOSING - When a channel has been marked for closed, but is not yet closed. {@link #isOpen()} will + * still return true. Normal operations should be rejected. The most common scenario for a channel to be + * CLOSING is when channel that was REGISTERED has {@link #closeAsync()} called, but the selector thread + * has not yet closed the channel. + * <li>CLOSED - The channel has been closed. + * </ol> + * + * @param <S> the type of raw channel this AbstractNioChannel uses + */ +public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkChannel> implements NioChannel { + + static final int UNREGISTERED = 0; + static final int REGISTERED = 1; + static final int CLOSING = 2; + static final int CLOSED = 3; + + final S socketChannel; + final AtomicInteger state = new AtomicInteger(UNREGISTERED); + + private final InetSocketAddress localAddress; + private final String profile; + private final CloseFuture closeFuture = new CloseFuture(); + private volatile ESSelector selector; + private SelectionKey selectionKey; + + public AbstractNioChannel(String profile, S socketChannel) throws IOException { + this.profile = profile; + this.socketChannel = socketChannel; + this.localAddress = (InetSocketAddress) socketChannel.getLocalAddress(); + } + + @Override + public boolean isOpen() { + return closeFuture.isClosed() == false; + } + + @Override + public InetSocketAddress getLocalAddress() { + return localAddress; + } + + @Override + public String getProfile() { + return profile; + } + + /** + * Schedules a channel to be closed by the selector event loop with which it is registered. + * <p> + * If the current state is UNREGISTERED, the call will attempt to transition the state from UNREGISTERED + * to CLOSING. If this transition is successful, the channel can no longer be registered with an event + * loop and the channel will be synchronously closed in this method call. + * <p> + * If the channel is REGISTERED and the state can be transitioned to CLOSING, the close operation will + * be scheduled with the event loop. + * <p> + * If the channel is CLOSING or CLOSED, nothing will be done. + * + * @return future that will be complete when the channel is closed + */ + @Override + public CloseFuture closeAsync() { + if (selector != null && selector.isOnCurrentThread()) { + closeFromSelector(); + return closeFuture; + } + + for (; ; ) { + int state = this.state.get(); + if (state == UNREGISTERED && this.state.compareAndSet(UNREGISTERED, CLOSING)) { + close0(); + break; + } else if (state == REGISTERED && this.state.compareAndSet(REGISTERED, CLOSING)) { + selector.queueChannelClose(this); + break; + } else if (state == CLOSING || state == CLOSED) { + break; + } + } + return closeFuture; + } + + /** + * Closes the channel synchronously. This method should only be called from the selector thread. + * <p> + * Once this method returns, the channel will be closed. + */ + @Override + public void closeFromSelector() { + // This will not exit the loop until this thread or someone else has set the state to CLOSED. + // Whichever thread succeeds in setting the state to CLOSED will close the raw channel. + for (; ; ) { + int state = this.state.get(); + if (state < CLOSING && this.state.compareAndSet(state, CLOSING)) { + close0(); + } else if (state == CLOSING) { + close0(); + } else if (state == CLOSED) { + break; + } + } + } + + /** + * This method attempts to registered a channel with a selector. If method returns true the channel was + * successfully registered. If it returns false, the registration failed. The reason a registered might + * fail is if something else closed this channel. + * + * @param selector to register the channel + * @return if the channel was successfully registered + * @throws ClosedChannelException if the raw channel was closed + */ + @Override + public boolean register(ESSelector selector) throws ClosedChannelException { + if (markRegistered(selector)) { + setSelectionKey(socketChannel.register(selector.rawSelector(), 0)); + return true; + } else { + return false; + } + } + + @Override + public ESSelector getSelector() { + return selector; + } + + @Override + public SelectionKey getSelectionKey() { + return selectionKey; + } + + @Override + public CloseFuture getCloseFuture() { + return closeFuture; + } + + @Override + public S getRawChannel() { + return socketChannel; + } + + // Package visibility for testing + void setSelectionKey(SelectionKey selectionKey) { + this.selectionKey = selectionKey; + } + + boolean markRegistered(ESSelector selector) { + this.selector = selector; + return state.compareAndSet(UNREGISTERED, REGISTERED); + } + + private void close0() { + if (this.state.compareAndSet(CLOSING, CLOSED)) { + try { + socketChannel.close(); + closeFuture.channelClosed(this); + } catch (IOException e) { + closeFuture.channelCloseThrewException(this, e); + } + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java new file mode 100644 index 0000000000..84c36d4110 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java @@ -0,0 +1,105 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.mocksocket.PrivilegedSocketAccess; +import org.elasticsearch.transport.TcpTransport; +import org.elasticsearch.transport.nio.TcpReadHandler; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; + +public class ChannelFactory { + + private final boolean tcpNoDelay; + private final boolean tcpKeepAlive; + private final boolean tcpReusedAddress; + private final int tcpSendBufferSize; + private final int tcpReceiveBufferSize; + private final TcpReadHandler handler; + + public ChannelFactory(Settings settings, TcpReadHandler handler) { + tcpNoDelay = TcpTransport.TCP_NO_DELAY.get(settings); + tcpKeepAlive = TcpTransport.TCP_KEEP_ALIVE.get(settings); + tcpReusedAddress = TcpTransport.TCP_REUSE_ADDRESS.get(settings); + tcpSendBufferSize = Math.toIntExact(TcpTransport.TCP_SEND_BUFFER_SIZE.get(settings).getBytes()); + tcpReceiveBufferSize = Math.toIntExact(TcpTransport.TCP_RECEIVE_BUFFER_SIZE.get(settings).getBytes()); + this.handler = handler; + } + + public NioSocketChannel openNioChannel(InetSocketAddress remoteAddress) throws IOException { + SocketChannel rawChannel = SocketChannel.open(); + configureSocketChannel(rawChannel); + PrivilegedSocketAccess.connect(rawChannel, remoteAddress); + NioSocketChannel channel = new NioSocketChannel(NioChannel.CLIENT, rawChannel); + channel.setContexts(new TcpReadContext(channel, handler), new TcpWriteContext(channel)); + return channel; + } + + public NioSocketChannel acceptNioChannel(NioServerSocketChannel serverChannel) throws IOException { + ServerSocketChannel serverSocketChannel = serverChannel.getRawChannel(); + SocketChannel rawChannel = PrivilegedSocketAccess.accept(serverSocketChannel); + configureSocketChannel(rawChannel); + NioSocketChannel channel = new NioSocketChannel(serverChannel.getProfile(), rawChannel); + channel.setContexts(new TcpReadContext(channel, handler), new TcpWriteContext(channel)); + return channel; + } + + public NioServerSocketChannel openNioServerSocketChannel(String profileName, InetSocketAddress address) + throws IOException { + ServerSocketChannel socketChannel = ServerSocketChannel.open(); + socketChannel.configureBlocking(false); + ServerSocket socket = socketChannel.socket(); + socket.setReuseAddress(tcpReusedAddress); + socketChannel.bind(address); + return new NioServerSocketChannel(profileName, socketChannel, this); + } + + private void configureSocketChannel(SocketChannel channel) throws IOException { + channel.configureBlocking(false); + Socket socket = channel.socket(); + socket.setTcpNoDelay(tcpNoDelay); + socket.setKeepAlive(tcpKeepAlive); + socket.setReuseAddress(tcpReusedAddress); + if (tcpSendBufferSize > 0) { + socket.setSendBufferSize(tcpSendBufferSize); + } + if (tcpReceiveBufferSize > 0) { + socket.setSendBufferSize(tcpReceiveBufferSize); + } + } + + private static <T> T getSocketChannel(CheckedSupplier<T, IOException> supplier) throws IOException { + try { + return AccessController.doPrivileged((PrivilegedExceptionAction<T>) supplier::get); + } catch (PrivilegedActionException e) { + throw (IOException) e.getCause(); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/CloseFuture.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/CloseFuture.java new file mode 100644 index 0000000000..e41632174a --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/CloseFuture.java @@ -0,0 +1,104 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.common.util.concurrent.BaseFuture; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.Consumer; + +public class CloseFuture extends BaseFuture<NioChannel> { + + private final SetOnce<Consumer<NioChannel>> listener = new SetOnce<>(); + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + throw new UnsupportedOperationException("Cannot cancel close future"); + } + + public void awaitClose() throws InterruptedException, IOException { + try { + super.get(); + } catch (ExecutionException e) { + throw (IOException) e.getCause(); + } + } + + public void awaitClose(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException, IOException { + try { + super.get(timeout, unit); + } catch (ExecutionException e) { + throw (IOException) e.getCause(); + } + } + + public IOException getCloseException() { + if (isDone()) { + try { + super.get(0, TimeUnit.NANOSECONDS); + return null; + } catch (ExecutionException e) { + // We only make a setter for IOException + return (IOException) e.getCause(); + } catch (TimeoutException e) { + return null; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } + } else { + return null; + } + } + + public boolean isClosed() { + return super.isDone(); + } + + public void setListener(Consumer<NioChannel> listener) { + this.listener.set(listener); + } + + void channelClosed(NioChannel channel) { + boolean set = set(channel); + if (set) { + Consumer<NioChannel> listener = this.listener.get(); + if (listener != null) { + listener.accept(channel); + } + } + } + + + void channelCloseThrewException(NioChannel channel, IOException ex) { + boolean set = setException(ex); + if (set) { + Consumer<NioChannel> listener = this.listener.get(); + if (listener != null) { + listener.accept(channel); + } + } + } + +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ConnectFuture.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ConnectFuture.java new file mode 100644 index 0000000000..4bc1ca6043 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ConnectFuture.java @@ -0,0 +1,94 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.common.util.concurrent.BaseFuture; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ConnectFuture extends BaseFuture<NioSocketChannel> { + + public boolean awaitConnectionComplete(long timeout, TimeUnit unit) throws InterruptedException { + try { + super.get(timeout, unit); + return true; + } catch (ExecutionException | TimeoutException e) { + return false; + } + } + + public Exception getException() { + if (isDone()) { + try { + // Get should always return without blocking as we already checked 'isDone' + // We are calling 'get' here in order to throw the ExecutionException + super.get(); + return null; + } catch (ExecutionException e) { + // We only make a public setters for IOException or RuntimeException + return (Exception) e.getCause(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } + } else { + return null; + } + } + + public boolean isConnectComplete() { + return getChannel() != null; + } + + public boolean connectFailed() { + return getException() != null; + } + + void setConnectionComplete(NioSocketChannel channel) { + set(channel); + } + + void setConnectionFailed(IOException e) { + setException(e); + } + + void setConnectionFailed(RuntimeException e) { + setException(e); + } + + private NioSocketChannel getChannel() { + if (isDone()) { + try { + // Get should always return without blocking as we already checked 'isDone' + return super.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } catch (ExecutionException e) { + return null; + } + } else { + return null; + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java new file mode 100644 index 0000000000..281e296391 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java @@ -0,0 +1,52 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.transport.nio.ESSelector; + +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.NetworkChannel; +import java.nio.channels.SelectionKey; + +public interface NioChannel { + + String CLIENT = "client-socket"; + + boolean isOpen(); + + InetSocketAddress getLocalAddress(); + + String getProfile(); + + CloseFuture closeAsync(); + + void closeFromSelector(); + + boolean register(ESSelector selector) throws ClosedChannelException; + + ESSelector getSelector(); + + SelectionKey getSelectionKey(); + + CloseFuture getCloseFuture(); + + NetworkChannel getRawChannel(); +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java new file mode 100644 index 0000000000..bc8d423a45 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java @@ -0,0 +1,37 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import java.io.IOException; +import java.nio.channels.ServerSocketChannel; + +public class NioServerSocketChannel extends AbstractNioChannel<ServerSocketChannel> { + + private final ChannelFactory channelFactory; + + public NioServerSocketChannel(String profile, ServerSocketChannel socketChannel, ChannelFactory channelFactory) throws IOException { + super(profile, socketChannel); + this.channelFactory = channelFactory; + } + + public ChannelFactory getChannelFactory() { + return channelFactory; + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java new file mode 100644 index 0000000000..62404403de --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java @@ -0,0 +1,189 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.transport.nio.NetworkBytesReference; +import org.elasticsearch.transport.nio.ESSelector; +import org.elasticsearch.transport.nio.SocketSelector; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SocketChannel; +import java.util.Arrays; + +public class NioSocketChannel extends AbstractNioChannel<SocketChannel> { + + private final InetSocketAddress remoteAddress; + private final ConnectFuture connectFuture = new ConnectFuture(); + private volatile SocketSelector socketSelector; + private WriteContext writeContext; + private ReadContext readContext; + + public NioSocketChannel(String profile, SocketChannel socketChannel) throws IOException { + super(profile, socketChannel); + this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress(); + } + + @Override + public CloseFuture closeAsync() { + clearQueuedWrites(); + + return super.closeAsync(); + } + + @Override + public void closeFromSelector() { + // Even if the channel has already been closed we will clear any pending write operations just in case + clearQueuedWrites(); + + super.closeFromSelector(); + } + + @Override + public SocketSelector getSelector() { + return socketSelector; + } + + @Override + boolean markRegistered(ESSelector selector) { + this.socketSelector = (SocketSelector) selector; + return super.markRegistered(selector); + } + + public int write(NetworkBytesReference[] references) throws IOException { + int written; + if (references.length == 1) { + written = socketChannel.write(references[0].getReadByteBuffer()); + } else { + ByteBuffer[] buffers = new ByteBuffer[references.length]; + for (int i = 0; i < references.length; ++i) { + buffers[i] = references[i].getReadByteBuffer(); + } + written = (int) socketChannel.write(buffers); + } + if (written <= 0) { + return written; + } + + NetworkBytesReference.vectorizedIncrementReadIndexes(Arrays.asList(references), written); + + return written; + } + + public int read(NetworkBytesReference reference) throws IOException { + int bytesRead = socketChannel.read(reference.getWriteByteBuffer()); + + if (bytesRead == -1) { + return bytesRead; + } + + reference.incrementWrite(bytesRead); + return bytesRead; + } + + public void setContexts(ReadContext readContext, WriteContext writeContext) { + this.readContext = readContext; + this.writeContext = writeContext; + } + + public WriteContext getWriteContext() { + return writeContext; + } + + public ReadContext getReadContext() { + return readContext; + } + + public InetSocketAddress getRemoteAddress() { + return remoteAddress; + } + + public boolean isConnectComplete() { + return connectFuture.isConnectComplete(); + } + + public boolean isWritable() { + return state.get() == REGISTERED; + } + + public boolean isReadable() { + return state.get() == REGISTERED; + } + + /** + * This method will attempt to complete the connection process for this channel. It should be called for + * new channels or for a channel that has produced a OP_CONNECT event. If this method returns true then + * the connection is complete and the channel is ready for reads and writes. If it returns false, the + * channel is not yet connected and this method should be called again when a OP_CONNECT event is + * received. + * + * @return true if the connection process is complete + * @throws IOException if an I/O error occurs + */ + public boolean finishConnect() throws IOException { + if (connectFuture.isConnectComplete()) { + return true; + } else if (connectFuture.connectFailed()) { + Exception exception = connectFuture.getException(); + if (exception instanceof IOException) { + throw (IOException) exception; + } else { + throw (RuntimeException) exception; + } + } + + boolean isConnected = socketChannel.isConnected(); + if (isConnected == false) { + isConnected = internalFinish(); + } + if (isConnected) { + connectFuture.setConnectionComplete(this); + } + return isConnected; + } + + public ConnectFuture getConnectFuture() { + return connectFuture; + } + + private boolean internalFinish() throws IOException { + try { + return socketChannel.finishConnect(); + } catch (IOException e) { + connectFuture.setConnectionFailed(e); + throw e; + } catch (RuntimeException e) { + connectFuture.setConnectionFailed(e); + throw e; + } + } + + private void clearQueuedWrites() { + // Even if the channel has already been closed we will clear any pending write operations just in case + if (state.get() > UNREGISTERED) { + SocketSelector selector = getSelector(); + if (selector != null && selector.isOnCurrentThread() && writeContext.hasQueuedWriteOps()) { + writeContext.clearQueuedWriteOps(new ClosedChannelException()); + } + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ReadContext.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ReadContext.java new file mode 100644 index 0000000000..9d2919b192 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ReadContext.java @@ -0,0 +1,28 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import java.io.IOException; + +public interface ReadContext { + + int read() throws IOException; + +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/SelectionKeyUtils.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/SelectionKeyUtils.java new file mode 100644 index 0000000000..b0cf555206 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/SelectionKeyUtils.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import java.nio.channels.CancelledKeyException; +import java.nio.channels.SelectionKey; + +public final class SelectionKeyUtils { + + private SelectionKeyUtils() {} + + public static void setWriteInterested(NioChannel channel) throws CancelledKeyException { + SelectionKey selectionKey = channel.getSelectionKey(); + selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_WRITE); + } + + public static void removeWriteInterested(NioChannel channel) throws CancelledKeyException { + SelectionKey selectionKey = channel.getSelectionKey(); + selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_WRITE); + } + + public static void setConnectAndReadInterested(NioChannel channel) throws CancelledKeyException { + SelectionKey selectionKey = channel.getSelectionKey(); + selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ); + } + + public static void removeConnectInterested(NioChannel channel) throws CancelledKeyException { + SelectionKey selectionKey = channel.getSelectionKey(); + selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_CONNECT); + } + + public static void setAcceptInterested(NioServerSocketChannel channel) { + SelectionKey selectionKey = channel.getSelectionKey(); + selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_ACCEPT); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpFrameDecoder.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpFrameDecoder.java new file mode 100644 index 0000000000..356af44c5b --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpFrameDecoder.java @@ -0,0 +1,118 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.monitor.jvm.JvmInfo; +import org.elasticsearch.transport.TcpHeader; +import org.elasticsearch.transport.TcpTransport; + +import java.io.IOException; +import java.io.StreamCorruptedException; + +public class TcpFrameDecoder { + + private static final long NINETY_PER_HEAP_SIZE = (long) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.9); + private static final int HEADER_SIZE = TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE; + + private int expectedMessageLength = -1; + + public BytesReference decode(BytesReference bytesReference, int currentBufferSize) throws IOException { + if (currentBufferSize >= 6) { + int messageLength = readHeaderBuffer(bytesReference); + int totalLength = messageLength + HEADER_SIZE; + if (totalLength > currentBufferSize) { + expectedMessageLength = totalLength; + return null; + } else if (totalLength == bytesReference.length()) { + expectedMessageLength = -1; + return bytesReference; + } else { + expectedMessageLength = -1; + return bytesReference.slice(0, totalLength); + } + } else { + return null; + } + } + + public int expectedMessageLength() { + return expectedMessageLength; + } + + private int readHeaderBuffer(BytesReference headerBuffer) throws IOException { + if (headerBuffer.get(0) != 'E' || headerBuffer.get(1) != 'S') { + if (appearsToBeHTTP(headerBuffer)) { + throw new TcpTransport.HttpOnTransportException("This is not a HTTP port"); + } + + throw new StreamCorruptedException("invalid internal transport message format, got (" + + Integer.toHexString(headerBuffer.get(0) & 0xFF) + "," + + Integer.toHexString(headerBuffer.get(1) & 0xFF) + "," + + Integer.toHexString(headerBuffer.get(2) & 0xFF) + "," + + Integer.toHexString(headerBuffer.get(3) & 0xFF) + ")"); + } + final int messageLength; + try (StreamInput input = headerBuffer.streamInput()) { + input.skip(TcpHeader.MARKER_BYTES_SIZE); + messageLength = input.readInt(); + } + + if (messageLength == -1) { + // This is a ping + return 0; + } + + if (messageLength <= 0) { + throw new StreamCorruptedException("invalid data length: " + messageLength); + } + + if (messageLength > NINETY_PER_HEAP_SIZE) { + throw new IllegalArgumentException("transport content length received [" + new ByteSizeValue(messageLength) + "] exceeded [" + + new ByteSizeValue(NINETY_PER_HEAP_SIZE) + "]"); + } + + return messageLength; + } + + private static boolean appearsToBeHTTP(BytesReference headerBuffer) { + return bufferStartsWith(headerBuffer, "GET") || + bufferStartsWith(headerBuffer, "POST") || + bufferStartsWith(headerBuffer, "PUT") || + bufferStartsWith(headerBuffer, "HEAD") || + bufferStartsWith(headerBuffer, "DELETE") || + // TODO: Actually 'OPTIONS'. But that does not currently fit in 6 bytes + bufferStartsWith(headerBuffer, "OPTION") || + bufferStartsWith(headerBuffer, "PATCH") || + bufferStartsWith(headerBuffer, "TRACE"); + } + + private static boolean bufferStartsWith(BytesReference buffer, String method) { + char[] chars = method.toCharArray(); + for (int i = 0; i < chars.length; i++) { + if (buffer.get(i) != chars[i]) { + return false; + } + } + return true; + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpReadContext.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpReadContext.java new file mode 100644 index 0000000000..c332adbd31 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpReadContext.java @@ -0,0 +1,109 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; +import org.elasticsearch.transport.nio.NetworkBytesReference; +import org.elasticsearch.transport.nio.TcpReadHandler; + +import java.io.IOException; +import java.util.Iterator; +import java.util.LinkedList; + +public class TcpReadContext implements ReadContext { + + private static final int DEFAULT_READ_LENGTH = 1 << 14; + + private final TcpReadHandler handler; + private final NioSocketChannel channel; + private final TcpFrameDecoder frameDecoder; + private final LinkedList<NetworkBytesReference> references = new LinkedList<>(); + private int rawBytesCount = 0; + + public TcpReadContext(NioSocketChannel channel, TcpReadHandler handler) { + this(channel, handler, new TcpFrameDecoder()); + } + + public TcpReadContext(NioSocketChannel channel, TcpReadHandler handler, TcpFrameDecoder frameDecoder) { + this.handler = handler; + this.channel = channel; + this.frameDecoder = frameDecoder; + this.references.add(NetworkBytesReference.wrap(new BytesArray(new byte[DEFAULT_READ_LENGTH]))); + } + + @Override + public int read() throws IOException { + NetworkBytesReference last = references.peekLast(); + if (last == null || last.hasWriteRemaining() == false) { + this.references.add(NetworkBytesReference.wrap(new BytesArray(new byte[DEFAULT_READ_LENGTH]))); + } + + int bytesRead = channel.read(references.getLast()); + + if (bytesRead == -1) { + return bytesRead; + } + + rawBytesCount += bytesRead; + + BytesReference message; + + while ((message = frameDecoder.decode(createCompositeBuffer(), rawBytesCount)) != null) { + int messageLengthWithHeader = message.length(); + NetworkBytesReference.vectorizedIncrementReadIndexes(references, messageLengthWithHeader); + trimDecodedMessages(messageLengthWithHeader); + rawBytesCount -= messageLengthWithHeader; + + try { + BytesReference messageWithoutHeader = message.slice(6, message.length() - 6); + handler.handleMessage(messageWithoutHeader, channel, channel.getProfile(), messageWithoutHeader.length()); + } catch (Exception e) { + handler.handleException(channel, e); + } + } + + return bytesRead; + } + + private CompositeBytesReference createCompositeBuffer() { + return new CompositeBytesReference(references.toArray(new BytesReference[references.size()])); + } + + private void trimDecodedMessages(int bytesToTrim) { + while (bytesToTrim != 0) { + NetworkBytesReference ref = references.getFirst(); + int readIndex = ref.getReadIndex(); + bytesToTrim -= readIndex; + if (readIndex == ref.length()) { + references.removeFirst(); + } else { + assert bytesToTrim == 0; + if (readIndex != 0) { + references.removeFirst(); + NetworkBytesReference slicedRef = ref.slice(readIndex, ref.length() - readIndex); + references.addFirst(slicedRef); + } + } + + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpWriteContext.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpWriteContext.java new file mode 100644 index 0000000000..a332ea89a3 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpWriteContext.java @@ -0,0 +1,108 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.transport.nio.SocketSelector; +import org.elasticsearch.transport.nio.WriteOperation; + +import java.io.IOException; +import java.nio.channels.ClosedChannelException; +import java.util.LinkedList; + +public class TcpWriteContext implements WriteContext { + + private final NioSocketChannel channel; + private final LinkedList<WriteOperation> queued = new LinkedList<>(); + + public TcpWriteContext(NioSocketChannel channel) { + this.channel = channel; + } + + @Override + public void sendMessage(BytesReference reference, ActionListener<NioChannel> listener) { + if (channel.isWritable() == false) { + listener.onFailure(new ClosedChannelException()); + return; + } + + WriteOperation writeOperation = new WriteOperation(channel, reference, listener); + SocketSelector selector = channel.getSelector(); + if (selector.isOnCurrentThread() == false) { + selector.queueWrite(writeOperation); + return; + } + + // TODO: Eval if we will allow writes from sendMessage + selector.queueWriteInChannelBuffer(writeOperation); + } + + @Override + public void queueWriteOperations(WriteOperation writeOperation) { + assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to queue writes"; + queued.add(writeOperation); + } + + @Override + public void flushChannel() throws IOException { + assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to flush writes"; + int ops = queued.size(); + if (ops == 1) { + singleFlush(queued.pop()); + } else if (ops > 1) { + multiFlush(); + } + } + + @Override + public boolean hasQueuedWriteOps() { + assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to access queued writes"; + return queued.isEmpty() == false; + } + + @Override + public void clearQueuedWriteOps(Exception e) { + assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to clear queued writes"; + for (WriteOperation op : queued) { + op.getListener().onFailure(e); + } + queued.clear(); + } + + private void singleFlush(WriteOperation headOp) throws IOException { + headOp.flush(); + + if (headOp.isFullyFlushed()) { + headOp.getListener().onResponse(channel); + } else { + queued.push(headOp); + } + } + + private void multiFlush() throws IOException { + boolean lastOpCompleted = true; + while (lastOpCompleted && queued.isEmpty() == false) { + WriteOperation op = queued.pop(); + singleFlush(op); + lastOpCompleted = op.isFullyFlushed(); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/WriteContext.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/WriteContext.java new file mode 100644 index 0000000000..1a14d279dd --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/WriteContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.transport.nio.WriteOperation; + +import java.io.IOException; + +public interface WriteContext { + + void sendMessage(BytesReference reference, ActionListener<NioChannel> listener); + + void queueWriteOperations(WriteOperation writeOperation); + + void flushChannel() throws IOException; + + boolean hasQueuedWriteOps(); + + void clearQueuedWriteOps(Exception e); + +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptingSelectorTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptingSelectorTests.java new file mode 100644 index 0000000000..e3cf9b0a7e --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptingSelectorTests.java @@ -0,0 +1,113 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.channel.NioChannel; +import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; +import org.elasticsearch.transport.nio.utils.TestSelectionKey; +import org.junit.Before; + +import java.io.IOException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.PrivilegedActionException; +import java.util.HashSet; +import java.util.Set; + +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class AcceptingSelectorTests extends ESTestCase { + + private AcceptingSelector selector; + private NioServerSocketChannel serverChannel; + private AcceptorEventHandler eventHandler; + private TestSelectionKey selectionKey; + private HashSet<SelectionKey> keySet = new HashSet<>(); + + @Before + public void setUp() throws Exception { + super.setUp(); + + eventHandler = mock(AcceptorEventHandler.class); + serverChannel = mock(NioServerSocketChannel.class); + + Selector rawSelector = mock(Selector.class); + selector = new AcceptingSelector(eventHandler, rawSelector); + this.selector.setThread(); + + selectionKey = new TestSelectionKey(0); + selectionKey.attach(serverChannel); + when(serverChannel.getSelectionKey()).thenReturn(selectionKey); + when(rawSelector.selectedKeys()).thenReturn(keySet); + when(rawSelector.select(0)).thenReturn(1); + } + + public void testRegisteredChannel() throws IOException, PrivilegedActionException { + selector.registerServerChannel(serverChannel); + + when(serverChannel.register(selector)).thenReturn(true); + + selector.doSelect(0); + + verify(eventHandler).serverChannelRegistered(serverChannel); + Set<NioChannel> registeredChannels = selector.getRegisteredChannels(); + assertEquals(1, registeredChannels.size()); + assertTrue(registeredChannels.contains(serverChannel)); + } + + public void testAcceptEvent() throws IOException { + selectionKey.setReadyOps(SelectionKey.OP_ACCEPT); + keySet.add(selectionKey); + + selector.doSelect(0); + + verify(eventHandler).acceptChannel(serverChannel); + } + + public void testAcceptException() throws IOException { + selectionKey.setReadyOps(SelectionKey.OP_ACCEPT); + keySet.add(selectionKey); + IOException ioException = new IOException(); + + doThrow(ioException).when(eventHandler).acceptChannel(serverChannel); + + selector.doSelect(0); + + verify(eventHandler).acceptException(serverChannel, ioException); + } + + public void testCleanup() throws IOException { + selector.registerServerChannel(serverChannel); + + when(serverChannel.register(selector)).thenReturn(true); + + selector.doSelect(0); + + assertEquals(1, selector.getRegisteredChannels().size()); + + selector.cleanup(); + + verify(eventHandler).handleClose(serverChannel); + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java new file mode 100644 index 0000000000..fc6829d594 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java @@ -0,0 +1,99 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.channel.ChannelFactory; +import org.elasticsearch.transport.nio.channel.DoNotRegisterServerChannel; +import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.junit.Before; + +import java.io.IOException; +import java.nio.channels.SelectionKey; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class AcceptorEventHandlerTests extends ESTestCase { + + private AcceptorEventHandler handler; + private SocketSelector socketSelector; + private ChannelFactory channelFactory; + private OpenChannels openChannels; + private NioServerSocketChannel channel; + + @Before + public void setUpHandler() throws IOException { + channelFactory = mock(ChannelFactory.class); + socketSelector = mock(SocketSelector.class); + openChannels = new OpenChannels(logger); + ArrayList<SocketSelector> selectors = new ArrayList<>(); + selectors.add(socketSelector); + handler = new AcceptorEventHandler(logger, openChannels, new RoundRobinSelectorSupplier(selectors)); + + channel = new DoNotRegisterServerChannel("", mock(ServerSocketChannel.class), channelFactory); + channel.register(mock(ESSelector.class)); + } + + public void testHandleRegisterAdjustsOpenChannels() { + assertEquals(0, openChannels.serverChannelsCount()); + + handler.serverChannelRegistered(channel); + + assertEquals(1, openChannels.serverChannelsCount()); + } + + public void testHandleRegisterSetsOP_ACCEPTInterest() { + assertEquals(0, channel.getSelectionKey().interestOps()); + + handler.serverChannelRegistered(channel); + + assertEquals(SelectionKey.OP_ACCEPT, channel.getSelectionKey().interestOps()); + } + + public void testHandleAcceptRegistersWithSelector() throws IOException { + NioSocketChannel childChannel = new NioSocketChannel("", mock(SocketChannel.class)); + when(channelFactory.acceptNioChannel(channel)).thenReturn(childChannel); + + handler.acceptChannel(channel); + + verify(socketSelector).registerSocketChannel(childChannel); + } + + public void testHandleAcceptAddsToOpenChannelsAndAddsCloseListenerToRemove() throws IOException { + NioSocketChannel childChannel = new NioSocketChannel("", SocketChannel.open()); + when(channelFactory.acceptNioChannel(channel)).thenReturn(childChannel); + + handler.acceptChannel(channel); + + assertEquals(new HashSet<>(Arrays.asList(childChannel)), openChannels.getAcceptedChannels()); + + childChannel.closeAsync(); + + assertEquals(new HashSet<>(), openChannels.getAcceptedChannels()); + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/ByteBufferReferenceTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/ByteBufferReferenceTests.java new file mode 100644 index 0000000000..335e3d2f77 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/ByteBufferReferenceTests.java @@ -0,0 +1,155 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.test.ESTestCase; + +import java.nio.ByteBuffer; + +public class ByteBufferReferenceTests extends ESTestCase { + + private NetworkBytesReference buffer; + + public void testBasicGetByte() { + byte[] bytes = new byte[10]; + initializeBytes(bytes); + buffer = NetworkBytesReference.wrap(new BytesArray(bytes)); + + assertEquals(10, buffer.length()); + for (int i = 0 ; i < bytes.length; ++i) { + assertEquals(i, buffer.get(i)); + } + } + + public void testBasicGetByteWithOffset() { + byte[] bytes = new byte[10]; + initializeBytes(bytes); + buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 8)); + + assertEquals(8, buffer.length()); + for (int i = 2 ; i < bytes.length; ++i) { + assertEquals(i, buffer.get(i - 2)); + } + } + + public void testBasicGetByteWithOffsetAndLimit() { + byte[] bytes = new byte[10]; + initializeBytes(bytes); + buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 6)); + + assertEquals(6, buffer.length()); + for (int i = 2 ; i < bytes.length - 2; ++i) { + assertEquals(i, buffer.get(i - 2)); + } + } + + public void testGetWriteBufferRespectsWriteIndex() { + byte[] bytes = new byte[10]; + + buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 8)); + + ByteBuffer writeByteBuffer = buffer.getWriteByteBuffer(); + + assertEquals(2, writeByteBuffer.position()); + assertEquals(10, writeByteBuffer.limit()); + + buffer.incrementWrite(2); + + writeByteBuffer = buffer.getWriteByteBuffer(); + assertEquals(4, writeByteBuffer.position()); + assertEquals(10, writeByteBuffer.limit()); + } + + public void testGetReadBufferRespectsReadIndex() { + byte[] bytes = new byte[10]; + + buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 3, 6), 6, 0); + + ByteBuffer readByteBuffer = buffer.getReadByteBuffer(); + + assertEquals(3, readByteBuffer.position()); + assertEquals(9, readByteBuffer.limit()); + + buffer.incrementRead(2); + + readByteBuffer = buffer.getReadByteBuffer(); + assertEquals(5, readByteBuffer.position()); + assertEquals(9, readByteBuffer.limit()); + } + + public void testWriteAndReadRemaining() { + byte[] bytes = new byte[10]; + + buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 8)); + + assertEquals(0, buffer.getReadRemaining()); + assertEquals(8, buffer.getWriteRemaining()); + + buffer.incrementWrite(3); + buffer.incrementRead(2); + + assertEquals(1, buffer.getReadRemaining()); + assertEquals(5, buffer.getWriteRemaining()); + } + + public void testBasicSlice() { + byte[] bytes = new byte[20]; + initializeBytes(bytes); + + buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 18)); + + NetworkBytesReference slice = buffer.slice(4, 14); + + assertEquals(14, slice.length()); + assertEquals(0, slice.getReadIndex()); + assertEquals(0, slice.getWriteIndex()); + + for (int i = 6; i < 20; ++i) { + assertEquals(i, slice.get(i - 6)); + } + } + + public void testSliceWithReadAndWriteIndexes() { + byte[] bytes = new byte[20]; + initializeBytes(bytes); + + buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 18)); + + buffer.incrementWrite(9); + buffer.incrementRead(5); + + NetworkBytesReference slice = buffer.slice(6, 12); + + assertEquals(12, slice.length()); + assertEquals(0, slice.getReadIndex()); + assertEquals(3, slice.getWriteIndex()); + + for (int i = 8; i < 20; ++i) { + assertEquals(i, slice.get(i - 8)); + } + } + + private void initializeBytes(byte[] bytes) { + for (int i = 0 ; i < bytes.length; ++i) { + bytes[i] = (byte) i; + } + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/ESSelectorTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/ESSelectorTests.java new file mode 100644 index 0000000000..e57b1bc4ef --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/ESSelectorTests.java @@ -0,0 +1,114 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.channel.NioChannel; +import org.junit.Before; + +import java.io.IOException; +import java.nio.channels.ClosedSelectorException; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +public class ESSelectorTests extends ESTestCase { + + private ESSelector selector; + private EventHandler handler; + + @Before + public void setUp() throws Exception { + super.setUp(); + handler = mock(EventHandler.class); + selector = new TestSelector(handler); + } + + public void testQueueChannelForClosed() throws IOException { + NioChannel channel = mock(NioChannel.class); + selector.registeredChannels.add(channel); + + selector.queueChannelClose(channel); + + assertEquals(1, selector.getRegisteredChannels().size()); + + selector.singleLoop(); + + verify(handler).handleClose(channel); + + assertEquals(0, selector.getRegisteredChannels().size()); + } + + public void testSelectorClosedExceptionIsNotCaughtWhileRunning() throws IOException { + ((TestSelector) this.selector).setClosedSelectorException(new ClosedSelectorException()); + + boolean closedSelectorExceptionCaught = false; + try { + this.selector.singleLoop(); + } catch (ClosedSelectorException e) { + closedSelectorExceptionCaught = true; + } + + assertTrue(closedSelectorExceptionCaught); + } + + public void testIOExceptionWhileSelect() throws IOException { + IOException ioException = new IOException(); + ((TestSelector) this.selector).setIOException(ioException); + + this.selector.singleLoop(); + + verify(handler).selectException(ioException); + } + + private static class TestSelector extends ESSelector { + + private ClosedSelectorException closedSelectorException; + private IOException ioException; + + protected TestSelector(EventHandler eventHandler) throws IOException { + super(eventHandler); + } + + @Override + void doSelect(int timeout) throws IOException, ClosedSelectorException { + if (closedSelectorException != null) { + throw closedSelectorException; + } + if (ioException != null) { + throw ioException; + } + } + + @Override + void cleanup() { + + } + + public void setClosedSelectorException(ClosedSelectorException exception) { + this.closedSelectorException = exception; + } + + public void setIOException(IOException ioException) { + this.ioException = ioException; + } + } + +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java new file mode 100644 index 0000000000..e9f6dfe7f7 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java @@ -0,0 +1,193 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.Version; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.ConnectTransportException; +import org.elasticsearch.transport.nio.channel.ChannelFactory; +import org.elasticsearch.transport.nio.channel.CloseFuture; +import org.elasticsearch.transport.nio.channel.ConnectFuture; +import org.elasticsearch.transport.nio.channel.NioChannel; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.junit.Before; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class NioClientTests extends ESTestCase { + + private NioClient client; + private SocketSelector selector; + private ChannelFactory channelFactory; + private OpenChannels openChannels = new OpenChannels(logger); + private NioSocketChannel[] channels; + private DiscoveryNode node; + private Consumer<NioChannel> listener; + private TransportAddress address; + + @Before + @SuppressWarnings("unchecked") + public void setUpClient() { + channelFactory = mock(ChannelFactory.class); + selector = mock(SocketSelector.class); + listener = mock(Consumer.class); + + ArrayList<SocketSelector> selectors = new ArrayList<>(); + selectors.add(selector); + Supplier<SocketSelector> selectorSupplier = new RoundRobinSelectorSupplier(selectors); + client = new NioClient(logger, openChannels, selectorSupplier, TimeValue.timeValueMillis(5), channelFactory); + + channels = new NioSocketChannel[2]; + address = new TransportAddress(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)); + node = new DiscoveryNode("node-id", address, Version.CURRENT); + } + + public void testCreateConnections() throws IOException, InterruptedException { + NioSocketChannel channel1 = mock(NioSocketChannel.class); + ConnectFuture connectFuture1 = mock(ConnectFuture.class); + CloseFuture closeFuture1 = mock(CloseFuture.class); + NioSocketChannel channel2 = mock(NioSocketChannel.class); + ConnectFuture connectFuture2 = mock(ConnectFuture.class); + CloseFuture closeFuture2 = mock(CloseFuture.class); + + when(channelFactory.openNioChannel(address.address())).thenReturn(channel1, channel2); + when(channel1.getCloseFuture()).thenReturn(closeFuture1); + when(channel1.getConnectFuture()).thenReturn(connectFuture1); + when(channel2.getCloseFuture()).thenReturn(closeFuture2); + when(channel2.getConnectFuture()).thenReturn(connectFuture2); + when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); + when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); + + client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); + + verify(closeFuture1).setListener(listener); + verify(closeFuture2).setListener(listener); + verify(selector).registerSocketChannel(channel1); + verify(selector).registerSocketChannel(channel2); + + assertEquals(channel1, channels[0]); + assertEquals(channel2, channels[1]); + } + + public void testWithADifferentConnectTimeout() throws IOException, InterruptedException { + NioSocketChannel channel1 = mock(NioSocketChannel.class); + ConnectFuture connectFuture1 = mock(ConnectFuture.class); + CloseFuture closeFuture1 = mock(CloseFuture.class); + + when(channelFactory.openNioChannel(address.address())).thenReturn(channel1); + when(channel1.getCloseFuture()).thenReturn(closeFuture1); + when(channel1.getConnectFuture()).thenReturn(connectFuture1); + when(connectFuture1.awaitConnectionComplete(3, TimeUnit.MILLISECONDS)).thenReturn(true); + + channels = new NioSocketChannel[1]; + client.connectToChannels(node, channels, TimeValue.timeValueMillis(3), listener); + + verify(closeFuture1).setListener(listener); + verify(selector).registerSocketChannel(channel1); + + assertEquals(channel1, channels[0]); + } + + public void testConnectionTimeout() throws IOException, InterruptedException { + NioSocketChannel channel1 = mock(NioSocketChannel.class); + ConnectFuture connectFuture1 = mock(ConnectFuture.class); + CloseFuture closeFuture1 = mock(CloseFuture.class); + NioSocketChannel channel2 = mock(NioSocketChannel.class); + ConnectFuture connectFuture2 = mock(ConnectFuture.class); + CloseFuture closeFuture2 = mock(CloseFuture.class); + + when(channelFactory.openNioChannel(address.address())).thenReturn(channel1, channel2); + when(channel1.getCloseFuture()).thenReturn(closeFuture1); + when(channel1.getConnectFuture()).thenReturn(connectFuture1); + when(channel2.getCloseFuture()).thenReturn(closeFuture2); + when(channel2.getConnectFuture()).thenReturn(connectFuture2); + when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); + when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(false); + + try { + client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); + fail("Should have thrown ConnectTransportException"); + } catch (ConnectTransportException e) { + assertTrue(e.getMessage().contains("connect_timeout[5ms]")); + } + + verify(channel1).closeAsync(); + verify(channel2).closeAsync(); + + assertNull(channels[0]); + assertNull(channels[1]); + } + + public void testConnectionException() throws IOException, InterruptedException { + NioSocketChannel channel1 = mock(NioSocketChannel.class); + ConnectFuture connectFuture1 = mock(ConnectFuture.class); + CloseFuture closeFuture1 = mock(CloseFuture.class); + NioSocketChannel channel2 = mock(NioSocketChannel.class); + ConnectFuture connectFuture2 = mock(ConnectFuture.class); + CloseFuture closeFuture2 = mock(CloseFuture.class); + IOException ioException = new IOException(); + + when(channelFactory.openNioChannel(address.address())).thenReturn(channel1, channel2); + when(channel1.getCloseFuture()).thenReturn(closeFuture1); + when(channel1.getConnectFuture()).thenReturn(connectFuture1); + when(channel2.getCloseFuture()).thenReturn(closeFuture2); + when(channel2.getConnectFuture()).thenReturn(connectFuture2); + when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); + when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(false); + when(connectFuture2.getException()).thenReturn(ioException); + + try { + client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); + fail("Should have thrown ConnectTransportException"); + } catch (ConnectTransportException e) { + assertTrue(e.getMessage().contains("connect_exception")); + assertSame(ioException, e.getCause()); + } + + verify(channel1).closeAsync(); + verify(channel2).closeAsync(); + + assertNull(channels[0]); + assertNull(channels[1]); + } + + public void testCloseDoesNotAllowConnections() throws IOException { + client.close(); + + assertFalse(client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener)); + + for (NioSocketChannel channel : channels) { + assertNull(channel); + } + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java new file mode 100644 index 0000000000..a35355a393 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java @@ -0,0 +1,132 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.Version; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.network.NetworkService; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.node.Node; +import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.AbstractSimpleTransportTestCase; +import org.elasticsearch.transport.BindTransportException; +import org.elasticsearch.transport.ConnectTransportException; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.transport.TransportSettings; +import org.elasticsearch.transport.nio.channel.NioChannel; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.Collections; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; + +public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase { + + public static MockTransportService nioFromThreadPool(Settings settings, ThreadPool threadPool, final Version version, + ClusterSettings clusterSettings, boolean doHandshake) { + NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); + NetworkService networkService = new NetworkService(settings, Collections.emptyList()); + Transport transport = new NioTransport(settings, threadPool, + networkService, + BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { + + @Override + protected Version executeHandshake(DiscoveryNode node, NioChannel channel, TimeValue timeout) throws IOException, + InterruptedException { + if (doHandshake) { + return super.executeHandshake(node, channel, timeout); + } else { + return version.minimumCompatibilityVersion(); + } + } + + @Override + protected Version getCurrentVersion() { + return version; + } + + @Override + protected SocketEventHandler getSocketEventHandler() { + return new TestingSocketEventHandler(logger, this::exceptionCaught); + } + }; + MockTransportService mockTransportService = + MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, clusterSettings); + mockTransportService.start(); + return mockTransportService; + } + + @Override + protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) { + settings = Settings.builder().put(settings).put(TransportSettings.PORT.getKey(), "0").build(); + MockTransportService transportService = nioFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake); + transportService.start(); + return transportService; + } + + public void testConnectException() throws UnknownHostException { + try { + serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876), + emptyMap(), emptySet(),Version.CURRENT)); + fail("Expected ConnectTransportException"); + } catch (ConnectTransportException e) { + assertThat(e.getMessage(), containsString("connect_exception")); + assertThat(e.getMessage(), containsString("[127.0.0.1:9876]")); + Throwable cause = e.getCause(); + assertThat(cause, instanceOf(IOException.class)); + assertEquals("Connection refused", cause.getMessage()); + } + } + + public void testBindUnavailableAddress() { + // this is on a lower level since it needs access to the TransportService before it's started + int port = serviceA.boundAddress().publishAddress().getPort(); + Settings settings = Settings.builder() + .put(Node.NODE_NAME_SETTING.getKey(), "foobar") + .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "") + .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") + .put("transport.tcp.port", port) + .build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> { + MockTransportService transportService = nioFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true); + try { + transportService.start(); + } finally { + transportService.stop(); + transportService.close(); + } + }); + assertEquals("Failed to bind to ["+ port + "]", bindTransportException.getMessage()); + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java new file mode 100644 index 0000000000..393b9dc7cc --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java @@ -0,0 +1,175 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.channel.CloseFuture; +import org.elasticsearch.transport.nio.channel.DoNotRegisterChannel; +import org.elasticsearch.transport.nio.channel.NioChannel; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.ReadContext; +import org.elasticsearch.transport.nio.channel.SelectionKeyUtils; +import org.elasticsearch.transport.nio.channel.TcpWriteContext; +import org.junit.Before; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.CancelledKeyException; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.util.function.BiConsumer; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SocketEventHandlerTests extends ESTestCase { + + private BiConsumer<NioSocketChannel, Throwable> exceptionHandler; + + private SocketEventHandler handler; + private NioSocketChannel channel; + private ReadContext readContext; + private SocketChannel rawChannel; + + @Before + @SuppressWarnings("unchecked") + public void setUpHandler() throws IOException { + exceptionHandler = mock(BiConsumer.class); + SocketSelector socketSelector = mock(SocketSelector.class); + handler = new SocketEventHandler(logger, exceptionHandler); + rawChannel = mock(SocketChannel.class); + channel = new DoNotRegisterChannel("", rawChannel); + readContext = mock(ReadContext.class); + when(rawChannel.finishConnect()).thenReturn(true); + + channel.setContexts(readContext, new TcpWriteContext(channel)); + channel.register(socketSelector); + channel.finishConnect(); + + when(socketSelector.isOnCurrentThread()).thenReturn(true); + } + + public void testRegisterAddsOP_CONNECTAndOP_READInterest() throws IOException { + handler.handleRegistration(channel); + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, channel.getSelectionKey().interestOps()); + } + + public void testRegistrationExceptionCallsExceptionHandler() throws IOException { + CancelledKeyException exception = new CancelledKeyException(); + handler.registrationException(channel, exception); + verify(exceptionHandler).accept(channel, exception); + } + + public void testConnectRemovesOP_CONNECTInterest() throws IOException { + SelectionKeyUtils.setConnectAndReadInterested(channel); + handler.handleConnect(channel); + assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps()); + } + + public void testConnectExceptionCallsExceptionHandler() throws IOException { + IOException exception = new IOException(); + handler.connectException(channel, exception); + verify(exceptionHandler).accept(channel, exception); + } + + public void testHandleReadDelegatesToReadContext() throws IOException { + when(readContext.read()).thenReturn(1); + + handler.handleRead(channel); + + verify(readContext).read(); + } + + public void testHandleReadMarksChannelForCloseIfPeerClosed() throws IOException { + NioSocketChannel nioSocketChannel = mock(NioSocketChannel.class); + CloseFuture closeFuture = mock(CloseFuture.class); + when(nioSocketChannel.getReadContext()).thenReturn(readContext); + when(readContext.read()).thenReturn(-1); + when(nioSocketChannel.getCloseFuture()).thenReturn(closeFuture); + when(closeFuture.isDone()).thenReturn(true); + + handler.handleRead(nioSocketChannel); + + verify(nioSocketChannel).closeFromSelector(); + } + + public void testReadExceptionCallsExceptionHandler() throws IOException { + IOException exception = new IOException(); + handler.readException(channel, exception); + verify(exceptionHandler).accept(channel, exception); + } + + @SuppressWarnings("unchecked") + public void testHandleWriteWithCompleteFlushRemovesOP_WRITEInterest() throws IOException { + SelectionKey selectionKey = channel.getSelectionKey(); + setWriteAndRead(channel); + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps()); + + BytesArray bytesArray = new BytesArray(new byte[1]); + NetworkBytesReference networkBuffer = NetworkBytesReference.wrap(bytesArray); + channel.getWriteContext().queueWriteOperations(new WriteOperation(channel, networkBuffer, mock(ActionListener.class))); + + when(rawChannel.write(ByteBuffer.wrap(bytesArray.array()))).thenReturn(1); + handler.handleWrite(channel); + + assertEquals(SelectionKey.OP_READ, selectionKey.interestOps()); + } + + @SuppressWarnings("unchecked") + public void testHandleWriteWithInCompleteFlushLeavesOP_WRITEInterest() throws IOException { + SelectionKey selectionKey = channel.getSelectionKey(); + setWriteAndRead(channel); + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps()); + + BytesArray bytesArray = new BytesArray(new byte[1]); + NetworkBytesReference networkBuffer = NetworkBytesReference.wrap(bytesArray, 1, 0); + channel.getWriteContext().queueWriteOperations(new WriteOperation(channel, networkBuffer, mock(ActionListener.class))); + + when(rawChannel.write(ByteBuffer.wrap(bytesArray.array()))).thenReturn(0); + handler.handleWrite(channel); + + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps()); + } + + public void testHandleWriteWithNoOpsRemovesOP_WRITEInterest() throws IOException { + SelectionKey selectionKey = channel.getSelectionKey(); + setWriteAndRead(channel); + assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps()); + + handler.handleWrite(channel); + + assertEquals(SelectionKey.OP_READ, selectionKey.interestOps()); + } + + private void setWriteAndRead(NioChannel channel) { + SelectionKeyUtils.setConnectAndReadInterested(channel); + SelectionKeyUtils.removeConnectInterested(channel); + SelectionKeyUtils.setWriteInterested(channel); + } + + public void testWriteExceptionCallsExceptionHandler() throws IOException { + IOException exception = new IOException(); + handler.writeException(channel, exception); + verify(exceptionHandler).accept(channel, exception); + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java new file mode 100644 index 0000000000..050cf85644 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java @@ -0,0 +1,336 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.channel.NioChannel; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.WriteContext; +import org.elasticsearch.transport.nio.utils.TestSelectionKey; +import org.junit.Before; + +import java.io.IOException; +import java.nio.channels.CancelledKeyException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.ClosedSelectorException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.util.HashSet; +import java.util.Set; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SocketSelectorTests extends ESTestCase { + + private SocketSelector socketSelector; + private SocketEventHandler eventHandler; + private NioSocketChannel channel; + private TestSelectionKey selectionKey; + private WriteContext writeContext; + private HashSet<SelectionKey> keySet = new HashSet<>(); + private ActionListener<NioChannel> listener; + private NetworkBytesReference bufferReference = NetworkBytesReference.wrap(new BytesArray(new byte[1])); + + @Before + @SuppressWarnings("unchecked") + public void setUp() throws Exception { + super.setUp(); + eventHandler = mock(SocketEventHandler.class); + channel = mock(NioSocketChannel.class); + writeContext = mock(WriteContext.class); + listener = mock(ActionListener.class); + selectionKey = new TestSelectionKey(0); + selectionKey.attach(channel); + Selector rawSelector = mock(Selector.class); + + this.socketSelector = new SocketSelector(eventHandler, rawSelector); + this.socketSelector.setThread(); + + when(rawSelector.selectedKeys()).thenReturn(keySet); + when(rawSelector.select(0)).thenReturn(1); + when(channel.getSelectionKey()).thenReturn(selectionKey); + when(channel.getWriteContext()).thenReturn(writeContext); + when(channel.isConnectComplete()).thenReturn(true); + } + + public void testRegisterChannel() throws Exception { + socketSelector.registerSocketChannel(channel); + + when(channel.register(socketSelector)).thenReturn(true); + + socketSelector.doSelect(0); + + verify(eventHandler).handleRegistration(channel); + + Set<NioChannel> registeredChannels = socketSelector.getRegisteredChannels(); + assertEquals(1, registeredChannels.size()); + assertTrue(registeredChannels.contains(channel)); + } + + public void testRegisterChannelFails() throws Exception { + socketSelector.registerSocketChannel(channel); + + when(channel.register(socketSelector)).thenReturn(false); + + socketSelector.doSelect(0); + + verify(channel, times(0)).finishConnect(); + + Set<NioChannel> registeredChannels = socketSelector.getRegisteredChannels(); + assertEquals(0, registeredChannels.size()); + assertFalse(registeredChannels.contains(channel)); + } + + public void testRegisterChannelFailsDueToException() throws Exception { + socketSelector.registerSocketChannel(channel); + + ClosedChannelException closedChannelException = new ClosedChannelException(); + when(channel.register(socketSelector)).thenThrow(closedChannelException); + + socketSelector.doSelect(0); + + verify(eventHandler).registrationException(channel, closedChannelException); + verify(channel, times(0)).finishConnect(); + + Set<NioChannel> registeredChannels = socketSelector.getRegisteredChannels(); + assertEquals(0, registeredChannels.size()); + assertFalse(registeredChannels.contains(channel)); + } + + public void testSuccessfullyRegisterChannelWillConnect() throws Exception { + socketSelector.registerSocketChannel(channel); + + when(channel.register(socketSelector)).thenReturn(true); + when(channel.finishConnect()).thenReturn(true); + + socketSelector.doSelect(0); + + verify(eventHandler).handleConnect(channel); + } + + public void testConnectIncompleteWillNotNotify() throws Exception { + socketSelector.registerSocketChannel(channel); + + when(channel.register(socketSelector)).thenReturn(true); + when(channel.finishConnect()).thenReturn(false); + + socketSelector.doSelect(0); + + verify(eventHandler, times(0)).handleConnect(channel); + } + + public void testQueueWriteWhenNotRunning() throws Exception { + socketSelector.close(false); + + socketSelector.queueWrite(new WriteOperation(channel, bufferReference, listener)); + + verify(listener).onFailure(any(ClosedSelectorException.class)); + } + + public void testQueueWriteChannelIsNoLongerWritable() throws Exception { + WriteOperation writeOperation = new WriteOperation(channel, bufferReference, listener); + socketSelector.queueWrite(writeOperation); + + when(channel.isWritable()).thenReturn(false); + socketSelector.doSelect(0); + + verify(writeContext, times(0)).queueWriteOperations(writeOperation); + verify(listener).onFailure(any(ClosedChannelException.class)); + } + + public void testQueueWriteSelectionKeyThrowsException() throws Exception { + SelectionKey selectionKey = mock(SelectionKey.class); + + WriteOperation writeOperation = new WriteOperation(channel, bufferReference, listener); + CancelledKeyException cancelledKeyException = new CancelledKeyException(); + socketSelector.queueWrite(writeOperation); + + when(channel.isWritable()).thenReturn(true); + when(channel.getSelectionKey()).thenReturn(selectionKey); + when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException); + socketSelector.doSelect(0); + + verify(writeContext, times(0)).queueWriteOperations(writeOperation); + verify(listener).onFailure(cancelledKeyException); + } + + public void testQueueWriteSuccessful() throws Exception { + WriteOperation writeOperation = new WriteOperation(channel, bufferReference, listener); + socketSelector.queueWrite(writeOperation); + + assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); + + when(channel.isWritable()).thenReturn(true); + socketSelector.doSelect(0); + + verify(writeContext).queueWriteOperations(writeOperation); + assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0); + } + + public void testQueueDirectlyInChannelBufferSuccessful() throws Exception { + WriteOperation writeOperation = new WriteOperation(channel, bufferReference, listener); + + assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); + + when(channel.isWritable()).thenReturn(true); + socketSelector.queueWriteInChannelBuffer(writeOperation); + + verify(writeContext).queueWriteOperations(writeOperation); + assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0); + } + + public void testQueueDirectlyInChannelBufferSelectionKeyThrowsException() throws Exception { + SelectionKey selectionKey = mock(SelectionKey.class); + + WriteOperation writeOperation = new WriteOperation(channel, bufferReference, listener); + CancelledKeyException cancelledKeyException = new CancelledKeyException(); + + when(channel.isWritable()).thenReturn(true); + when(channel.getSelectionKey()).thenReturn(selectionKey); + when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException); + socketSelector.queueWriteInChannelBuffer(writeOperation); + + verify(writeContext, times(0)).queueWriteOperations(writeOperation); + verify(listener).onFailure(cancelledKeyException); + } + + public void testConnectEvent() throws Exception { + keySet.add(selectionKey); + + selectionKey.setReadyOps(SelectionKey.OP_CONNECT); + + when(channel.finishConnect()).thenReturn(true); + socketSelector.doSelect(0); + + verify(eventHandler).handleConnect(channel); + } + + public void testConnectEventFinishUnsuccessful() throws Exception { + keySet.add(selectionKey); + + selectionKey.setReadyOps(SelectionKey.OP_CONNECT); + + when(channel.finishConnect()).thenReturn(false); + socketSelector.doSelect(0); + + verify(eventHandler, times(0)).handleConnect(channel); + } + + public void testConnectEventFinishThrowException() throws Exception { + keySet.add(selectionKey); + IOException ioException = new IOException(); + + selectionKey.setReadyOps(SelectionKey.OP_CONNECT); + + when(channel.finishConnect()).thenThrow(ioException); + socketSelector.doSelect(0); + + verify(eventHandler, times(0)).handleConnect(channel); + verify(eventHandler).connectException(channel, ioException); + } + + public void testWillNotConsiderWriteOrReadUntilConnectionComplete() throws Exception { + keySet.add(selectionKey); + IOException ioException = new IOException(); + + selectionKey.setReadyOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ); + + doThrow(ioException).when(eventHandler).handleWrite(channel); + + when(channel.isConnectComplete()).thenReturn(false); + socketSelector.doSelect(0); + + verify(eventHandler, times(0)).handleWrite(channel); + verify(eventHandler, times(0)).handleRead(channel); + } + + public void testSuccessfulWriteEvent() throws Exception { + keySet.add(selectionKey); + + selectionKey.setReadyOps(SelectionKey.OP_WRITE); + + socketSelector.doSelect(0); + + verify(eventHandler).handleWrite(channel); + } + + public void testWriteEventWithException() throws Exception { + keySet.add(selectionKey); + IOException ioException = new IOException(); + + selectionKey.setReadyOps(SelectionKey.OP_WRITE); + + doThrow(ioException).when(eventHandler).handleWrite(channel); + + socketSelector.doSelect(0); + + verify(eventHandler).writeException(channel, ioException); + } + + public void testSuccessfulReadEvent() throws Exception { + keySet.add(selectionKey); + + selectionKey.setReadyOps(SelectionKey.OP_READ); + + socketSelector.doSelect(0); + + verify(eventHandler).handleRead(channel); + } + + public void testReadEventWithException() throws Exception { + keySet.add(selectionKey); + IOException ioException = new IOException(); + + selectionKey.setReadyOps(SelectionKey.OP_READ); + + doThrow(ioException).when(eventHandler).handleRead(channel); + + socketSelector.doSelect(0); + + verify(eventHandler).readException(channel, ioException); + } + + public void testCleanup() throws Exception { + NioSocketChannel unRegisteredChannel = mock(NioSocketChannel.class); + + when(channel.register(socketSelector)).thenReturn(true); + socketSelector.registerSocketChannel(channel); + + socketSelector.doSelect(0); + + NetworkBytesReference networkBuffer = NetworkBytesReference.wrap(new BytesArray(new byte[1])); + socketSelector.queueWrite(new WriteOperation(mock(NioSocketChannel.class), networkBuffer, listener)); + socketSelector.registerSocketChannel(unRegisteredChannel); + + socketSelector.cleanup(); + + verify(listener).onFailure(any(ClosedSelectorException.class)); + verify(eventHandler).handleClose(channel); + verify(eventHandler).handleClose(unRegisteredChannel); + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java new file mode 100644 index 0000000000..29f595c87a --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java @@ -0,0 +1,72 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; + +import java.io.IOException; +import java.util.Collections; +import java.util.Set; +import java.util.WeakHashMap; +import java.util.function.BiConsumer; + +public class TestingSocketEventHandler extends SocketEventHandler { + + private final Logger logger; + + public TestingSocketEventHandler(Logger logger, BiConsumer<NioSocketChannel, Throwable> exceptionHandler) { + super(logger, exceptionHandler); + this.logger = logger; + } + + private Set<NioSocketChannel> hasConnectedMap = Collections.newSetFromMap(new WeakHashMap<>()); + + public void handleConnect(NioSocketChannel channel) { + assert hasConnectedMap.contains(channel) == false : "handleConnect should only be called once per channel"; + hasConnectedMap.add(channel); + super.handleConnect(channel); + } + + private Set<NioSocketChannel> hasConnectExceptionMap = Collections.newSetFromMap(new WeakHashMap<>()); + + public void connectException(NioSocketChannel channel, Exception e) { + assert hasConnectExceptionMap.contains(channel) == false : "connectException should only called at maximum once per channel"; + hasConnectExceptionMap.add(channel); + super.connectException(channel, e); + } + + public void handleRead(NioSocketChannel channel) throws IOException { + super.handleRead(channel); + } + + public void readException(NioSocketChannel channel, Exception e) { + super.readException(channel, e); + } + + public void handleWrite(NioSocketChannel channel) throws IOException { + super.handleWrite(channel); + } + + public void writeException(NioSocketChannel channel, Exception e) { + super.writeException(channel, e); + } + +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/WriteOperationTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/WriteOperationTests.java new file mode 100644 index 0000000000..d7284491d6 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/WriteOperationTests.java @@ -0,0 +1,78 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.channel.NioChannel; +import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.junit.Before; + +import java.io.IOException; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class WriteOperationTests extends ESTestCase { + + private NioSocketChannel channel; + private ActionListener<NioChannel> listener; + + @Before + @SuppressWarnings("unchecked") + public void setFields() { + channel = mock(NioSocketChannel.class); + listener = mock(ActionListener.class); + + } + + public void testFlush() throws IOException { + WriteOperation writeOp = new WriteOperation(channel, new BytesArray(new byte[10]), listener); + + + when(channel.write(any())).thenAnswer(invocationOnMock -> { + NetworkBytesReference[] refs = (NetworkBytesReference[]) invocationOnMock.getArguments()[0]; + refs[0].incrementRead(10); + return 10; + }); + + writeOp.flush(); + + assertTrue(writeOp.isFullyFlushed()); + } + + public void testPartialFlush() throws IOException { + WriteOperation writeOp = new WriteOperation(channel, new BytesArray(new byte[10]), listener); + + when(channel.write(any())).thenAnswer(invocationOnMock -> { + NetworkBytesReference[] refs = (NetworkBytesReference[]) invocationOnMock.getArguments()[0]; + refs[0].incrementRead(5); + return 5; + }); + + writeOp.flush(); + + assertFalse(writeOp.isFullyFlushed()); + assertEquals(5, writeOp.getByteReferences()[0].getReadRemaining()); + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/AbstractNioChannelTestCase.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/AbstractNioChannelTestCase.java new file mode 100644 index 0000000000..c3909a0644 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/AbstractNioChannelTestCase.java @@ -0,0 +1,99 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.common.CheckedRunnable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.mocksocket.MockServerSocket; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.TcpReadHandler; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.io.InputStream; +import java.net.Socket; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +import static org.mockito.Mockito.mock; + +public abstract class AbstractNioChannelTestCase extends ESTestCase { + + ChannelFactory channelFactory = new ChannelFactory(Settings.EMPTY, mock(TcpReadHandler.class)); + MockServerSocket mockServerSocket; + private Thread serverThread; + + @Before + public void serverSocketSetup() throws IOException { + mockServerSocket = new MockServerSocket(0); + serverThread = new Thread(() -> { + while (!mockServerSocket.isClosed()) { + try { + Socket socket = mockServerSocket.accept(); + InputStream inputStream = socket.getInputStream(); + socket.close(); + } catch (IOException e) { + } + } + }); + serverThread.start(); + } + + @After + public void serverSocketTearDown() throws IOException { + serverThread.interrupt(); + mockServerSocket.close(); + } + + public abstract NioChannel channelToClose() throws IOException; + + public void testClose() throws IOException, TimeoutException, InterruptedException { + AtomicReference<NioChannel> ref = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + NioChannel socketChannel = channelToClose(); + CloseFuture closeFuture = socketChannel.getCloseFuture(); + closeFuture.setListener((c) -> {ref.set(c); latch.countDown();}); + + assertFalse(closeFuture.isClosed()); + assertTrue(socketChannel.getRawChannel().isOpen()); + + socketChannel.closeAsync(); + + closeFuture.awaitClose(100, TimeUnit.SECONDS); + + assertFalse(socketChannel.getRawChannel().isOpen()); + assertTrue(closeFuture.isClosed()); + latch.await(); + assertSame(socketChannel, ref.get()); + } + + protected Runnable wrappedRunnable(CheckedRunnable<Exception> runnable) { + return () -> { + try { + runnable.run(); + } catch (Exception e) { + } + }; + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterChannel.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterChannel.java new file mode 100644 index 0000000000..38f381bfcc --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterChannel.java @@ -0,0 +1,44 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.transport.nio.ESSelector; +import org.elasticsearch.transport.nio.utils.TestSelectionKey; + +import java.io.IOException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SocketChannel; + +public class DoNotRegisterChannel extends NioSocketChannel { + + public DoNotRegisterChannel(String profile, SocketChannel socketChannel) throws IOException { + super(profile, socketChannel); + } + + @Override + public boolean register(ESSelector selector) throws ClosedChannelException { + if (markRegistered(selector)) { + setSelectionKey(new TestSelectionKey(0)); + return true; + } else { + return false; + } + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterServerChannel.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterServerChannel.java new file mode 100644 index 0000000000..e9e1fc207a --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/DoNotRegisterServerChannel.java @@ -0,0 +1,44 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.transport.nio.ESSelector; +import org.elasticsearch.transport.nio.utils.TestSelectionKey; + +import java.io.IOException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.ServerSocketChannel; + +public class DoNotRegisterServerChannel extends NioServerSocketChannel { + + public DoNotRegisterServerChannel(String profile, ServerSocketChannel channel, ChannelFactory channelFactory) throws IOException { + super(profile, channel, channelFactory); + } + + @Override + public boolean register(ESSelector selector) throws ClosedChannelException { + if (markRegistered(selector)) { + setSelectionKey(new TestSelectionKey(0)); + return true; + } else { + return false; + } + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java new file mode 100644 index 0000000000..c991263562 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java @@ -0,0 +1,33 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; + +public class NioServerSocketChannelTests extends AbstractNioChannelTestCase { + + @Override + public NioChannel channelToClose() throws IOException { + return channelFactory.openNioServerSocketChannel("nio", new InetSocketAddress(InetAddress.getLoopbackAddress(),0)); + } + +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java new file mode 100644 index 0000000000..d195e83569 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java @@ -0,0 +1,85 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import java.io.IOException; +import java.net.ConnectException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.LockSupport; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; + +public class NioSocketChannelTests extends AbstractNioChannelTestCase { + + private InetAddress loopbackAddress = InetAddress.getLoopbackAddress(); + + @Override + public NioChannel channelToClose() throws IOException { + return channelFactory.openNioChannel(new InetSocketAddress(loopbackAddress, mockServerSocket.getLocalPort())); + } + + public void testConnectSucceeds() throws IOException, InterruptedException { + InetSocketAddress remoteAddress = new InetSocketAddress(loopbackAddress, mockServerSocket.getLocalPort()); + NioSocketChannel socketChannel = channelFactory.openNioChannel(remoteAddress); + Thread thread = new Thread(wrappedRunnable(() -> ensureConnect(socketChannel))); + thread.start(); + ConnectFuture connectFuture = socketChannel.getConnectFuture(); + connectFuture.awaitConnectionComplete(100, TimeUnit.SECONDS); + + assertTrue(socketChannel.isConnectComplete()); + assertTrue(socketChannel.isOpen()); + assertFalse(connectFuture.connectFailed()); + assertNull(connectFuture.getException()); + + thread.join(); + } + + public void testConnectFails() throws IOException, InterruptedException { + mockServerSocket.close(); + InetSocketAddress remoteAddress = new InetSocketAddress(loopbackAddress, mockServerSocket.getLocalPort()); + NioSocketChannel socketChannel = channelFactory.openNioChannel(remoteAddress); + Thread thread = new Thread(wrappedRunnable(() -> ensureConnect(socketChannel))); + thread.start(); + ConnectFuture connectFuture = socketChannel.getConnectFuture(); + connectFuture.awaitConnectionComplete(100, TimeUnit.SECONDS); + + assertFalse(socketChannel.isConnectComplete()); + // Even if connection fails the channel is 'open' until close() is called + assertTrue(socketChannel.isOpen()); + assertTrue(connectFuture.connectFailed()); + assertThat(connectFuture.getException(), instanceOf(ConnectException.class)); + assertThat(connectFuture.getException().getMessage(), containsString("Connection refused")); + + thread.join(); + } + + private void ensureConnect(NioSocketChannel nioSocketChannel) throws IOException { + for (;;) { + boolean isConnected = nioSocketChannel.finishConnect(); + if (isConnected) { + return; + } + LockSupport.parkNanos(TimeUnit.MILLISECONDS.toNanos(1)); + } + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpFrameDecoderTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpFrameDecoderTests.java new file mode 100644 index 0000000000..519828592b --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpFrameDecoderTests.java @@ -0,0 +1,169 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.TcpTransport; + +import java.io.IOException; +import java.io.StreamCorruptedException; + +import static org.hamcrest.Matchers.instanceOf; + +public class TcpFrameDecoderTests extends ESTestCase { + + private TcpFrameDecoder frameDecoder = new TcpFrameDecoder(); + + public void testDefaultExceptedMessageLengthIsNegative1() { + assertEquals(-1, frameDecoder.expectedMessageLength()); + } + + public void testDecodeWithIncompleteHeader() throws IOException { + BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14); + streamOutput.write('E'); + streamOutput.write('S'); + streamOutput.write(1); + streamOutput.write(1); + streamOutput.write(0); + streamOutput.write(0); + + assertNull(frameDecoder.decode(streamOutput.bytes(), 4)); + assertEquals(-1, frameDecoder.expectedMessageLength()); + } + + public void testDecodePing() throws IOException { + BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14); + streamOutput.write('E'); + streamOutput.write('S'); + streamOutput.writeInt(-1); + + BytesReference message = frameDecoder.decode(streamOutput.bytes(), 6); + + assertEquals(-1, frameDecoder.expectedMessageLength()); + assertEquals(streamOutput.bytes(), message); + } + + public void testDecodePingWithStartOfSecondMessage() throws IOException { + BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14); + streamOutput.write('E'); + streamOutput.write('S'); + streamOutput.writeInt(-1); + streamOutput.write('E'); + streamOutput.write('S'); + + BytesReference message = frameDecoder.decode(streamOutput.bytes(), 8); + + assertEquals(6, message.length()); + assertEquals(streamOutput.bytes().slice(0, 6), message); + } + + public void testDecodeMessage() throws IOException { + BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14); + streamOutput.write('E'); + streamOutput.write('S'); + streamOutput.writeInt(2); + streamOutput.write('M'); + streamOutput.write('A'); + + BytesReference message = frameDecoder.decode(streamOutput.bytes(), 8); + + assertEquals(-1, frameDecoder.expectedMessageLength()); + assertEquals(streamOutput.bytes(), message); + } + + public void testDecodeIncompleteMessage() throws IOException { + BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14); + streamOutput.write('E'); + streamOutput.write('S'); + streamOutput.writeInt(3); + streamOutput.write('M'); + streamOutput.write('A'); + + BytesReference message = frameDecoder.decode(streamOutput.bytes(), 8); + + assertEquals(9, frameDecoder.expectedMessageLength()); + assertNull(message); + } + + public void testInvalidLength() throws IOException { + BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14); + streamOutput.write('E'); + streamOutput.write('S'); + streamOutput.writeInt(-2); + streamOutput.write('M'); + streamOutput.write('A'); + + try { + frameDecoder.decode(streamOutput.bytes(), 8); + fail("Expected exception"); + } catch (Exception ex) { + assertThat(ex, instanceOf(StreamCorruptedException.class)); + assertEquals("invalid data length: -2", ex.getMessage()); + } + } + + public void testInvalidHeader() throws IOException { + BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14); + streamOutput.write('E'); + streamOutput.write('C'); + byte byte1 = randomByte(); + byte byte2 = randomByte(); + streamOutput.write(byte1); + streamOutput.write(byte2); + streamOutput.write(randomByte()); + streamOutput.write(randomByte()); + streamOutput.write(randomByte()); + + try { + frameDecoder.decode(streamOutput.bytes(), 7); + fail("Expected exception"); + } catch (Exception ex) { + assertThat(ex, instanceOf(StreamCorruptedException.class)); + String expected = "invalid internal transport message format, got (45,43," + + Integer.toHexString(byte1 & 0xFF) + "," + + Integer.toHexString(byte2 & 0xFF) + ")"; + assertEquals(expected, ex.getMessage()); + } + } + + public void testHTTPHeader() throws IOException { + String[] httpHeaders = {"GET", "POST", "PUT", "HEAD", "DELETE", "OPTIONS", "PATCH", "TRACE"}; + + for (String httpHeader : httpHeaders) { + BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14); + + for (char c : httpHeader.toCharArray()) { + streamOutput.write((byte) c); + } + streamOutput.write(new byte[6]); + + try { + BytesReference bytes = streamOutput.bytes(); + frameDecoder.decode(bytes, bytes.length()); + fail("Expected exception"); + } catch (Exception ex) { + assertThat(ex, instanceOf(TcpTransport.HttpOnTransportException.class)); + assertEquals("This is not a HTTP port", ex.getMessage()); + } + } + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpReadContextTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpReadContextTests.java new file mode 100644 index 0000000000..fc8d7e48ab --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpReadContextTests.java @@ -0,0 +1,150 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.CompositeBytesReference; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.NetworkBytesReference; +import org.elasticsearch.transport.nio.TcpReadHandler; +import org.junit.Before; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +public class TcpReadContextTests extends ESTestCase { + + private static String PROFILE = "profile"; + private TcpReadHandler handler; + private int messageLength; + private NioSocketChannel channel; + private TcpReadContext readContext; + + @Before + public void init() throws IOException { + handler = mock(TcpReadHandler.class); + + messageLength = randomInt(96) + 4; + channel = mock(NioSocketChannel.class); + readContext = new TcpReadContext(channel, handler); + + when(channel.getProfile()).thenReturn(PROFILE); + } + + public void testSuccessfulRead() throws IOException { + byte[] bytes = createMessage(messageLength); + byte[] fullMessage = combineMessageAndHeader(bytes); + + final AtomicInteger bufferCapacity = new AtomicInteger(); + when(channel.read(any(NetworkBytesReference.class))).thenAnswer(invocationOnMock -> { + NetworkBytesReference reference = (NetworkBytesReference) invocationOnMock.getArguments()[0]; + ByteBuffer buffer = reference.getWriteByteBuffer(); + bufferCapacity.set(reference.getWriteRemaining()); + buffer.put(fullMessage); + reference.incrementWrite(fullMessage.length); + return fullMessage.length; + }); + + readContext.read(); + + verify(handler).handleMessage(new BytesArray(bytes), channel, PROFILE, messageLength); + assertEquals(1024 * 16, bufferCapacity.get()); + + BytesArray bytesArray = new BytesArray(new byte[10]); + bytesArray.slice(5, 5); + bytesArray.slice(5, 0); + } + + public void testPartialRead() throws IOException { + byte[] part1 = createMessage(messageLength); + byte[] fullPart1 = combineMessageAndHeader(part1, messageLength + messageLength); + byte[] part2 = createMessage(messageLength); + + final AtomicInteger bufferCapacity = new AtomicInteger(); + final AtomicReference<byte[]> bytes = new AtomicReference<>(); + + when(channel.read(any(NetworkBytesReference.class))).thenAnswer(invocationOnMock -> { + NetworkBytesReference reference = (NetworkBytesReference) invocationOnMock.getArguments()[0]; + ByteBuffer buffer = reference.getWriteByteBuffer(); + bufferCapacity.set(reference.getWriteRemaining()); + buffer.put(bytes.get()); + reference.incrementWrite(bytes.get().length); + return bytes.get().length; + }); + + + bytes.set(fullPart1); + readContext.read(); + + assertEquals(1024 * 16, bufferCapacity.get()); + verifyZeroInteractions(handler); + + bytes.set(part2); + readContext.read(); + + assertEquals(1024 * 16 - fullPart1.length, bufferCapacity.get()); + + CompositeBytesReference reference = new CompositeBytesReference(new BytesArray(part1), new BytesArray(part2)); + verify(handler).handleMessage(reference, channel, PROFILE, messageLength + messageLength); + } + + public void testReadThrowsIOException() throws IOException { + IOException ioException = new IOException(); + when(channel.read(any())).thenThrow(ioException); + + try { + readContext.read(); + fail("Expected exception"); + } catch (Exception ex) { + assertSame(ioException, ex); + } + } + + private static byte[] combineMessageAndHeader(byte[] bytes) { + return combineMessageAndHeader(bytes, bytes.length); + } + + private static byte[] combineMessageAndHeader(byte[] bytes, int messageLength) { + byte[] fullMessage = new byte[bytes.length + 6]; + ByteBuffer wrapped = ByteBuffer.wrap(fullMessage); + wrapped.put((byte) 'E'); + wrapped.put((byte) 'S'); + wrapped.putInt(messageLength); + wrapped.put(bytes); + return fullMessage; + } + + private static byte[] createMessage(int length) { + byte[] bytes = new byte[length]; + for (int i = 0; i < length; ++i) { + bytes[i] = randomByte(); + } + return bytes; + } + +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpWriteContextTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpWriteContextTests.java new file mode 100644 index 0000000000..d2a2f446e7 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpWriteContextTests.java @@ -0,0 +1,296 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.nio.SocketSelector; +import org.elasticsearch.transport.nio.WriteOperation; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SocketChannel; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TcpWriteContextTests extends ESTestCase { + + private SocketSelector selector; + private ActionListener<NioChannel> listener; + private TcpWriteContext writeContext; + private NioSocketChannel channel; + + @Before + @SuppressWarnings("unchecked") + public void setUp() throws Exception { + super.setUp(); + selector = mock(SocketSelector.class); + listener = mock(ActionListener.class); + channel = mock(NioSocketChannel.class); + writeContext = new TcpWriteContext(channel); + + when(channel.getSelector()).thenReturn(selector); + when(selector.isOnCurrentThread()).thenReturn(true); + } + + public void testWriteFailsIfChannelNotWritable() throws Exception { + when(channel.isWritable()).thenReturn(false); + + writeContext.sendMessage(new BytesArray(generateBytes(10)), listener); + + verify(listener).onFailure(any(ClosedChannelException.class)); + } + + public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception { + byte[] bytes = generateBytes(10); + BytesArray bytesArray = new BytesArray(bytes); + ArgumentCaptor<WriteOperation> writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class); + + when(selector.isOnCurrentThread()).thenReturn(false); + when(channel.isWritable()).thenReturn(true); + + writeContext.sendMessage(bytesArray, listener); + + verify(selector).queueWrite(writeOpCaptor.capture()); + WriteOperation writeOp = writeOpCaptor.getValue(); + + assertSame(listener, writeOp.getListener()); + assertSame(channel, writeOp.getChannel()); + assertEquals(ByteBuffer.wrap(bytes), writeOp.getByteReferences()[0].getReadByteBuffer()); + } + + public void testSendMessageFromSameThreadIsQueuedInChannel() throws Exception { + byte[] bytes = generateBytes(10); + BytesArray bytesArray = new BytesArray(bytes); + ArgumentCaptor<WriteOperation> writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class); + + when(channel.isWritable()).thenReturn(true); + + writeContext.sendMessage(bytesArray, listener); + + verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture()); + WriteOperation writeOp = writeOpCaptor.getValue(); + + assertSame(listener, writeOp.getListener()); + assertSame(channel, writeOp.getChannel()); + assertEquals(ByteBuffer.wrap(bytes), writeOp.getByteReferences()[0].getReadByteBuffer()); + } + + public void testWriteIsQueuedInChannel() throws Exception { + assertFalse(writeContext.hasQueuedWriteOps()); + + writeContext.queueWriteOperations(new WriteOperation(channel, new BytesArray(generateBytes(10)), listener)); + + assertTrue(writeContext.hasQueuedWriteOps()); + } + + public void testWriteOpsCanBeCleared() throws Exception { + assertFalse(writeContext.hasQueuedWriteOps()); + + writeContext.queueWriteOperations(new WriteOperation(channel, new BytesArray(generateBytes(10)), listener)); + + assertTrue(writeContext.hasQueuedWriteOps()); + + ClosedChannelException e = new ClosedChannelException(); + writeContext.clearQueuedWriteOps(e); + + verify(listener).onFailure(e); + + assertFalse(writeContext.hasQueuedWriteOps()); + } + + public void testQueuedWriteIsFlushedInFlushCall() throws Exception { + assertFalse(writeContext.hasQueuedWriteOps()); + + WriteOperation writeOperation = mock(WriteOperation.class); + writeContext.queueWriteOperations(writeOperation); + + assertTrue(writeContext.hasQueuedWriteOps()); + + when(writeOperation.isFullyFlushed()).thenReturn(true); + when(writeOperation.getListener()).thenReturn(listener); + writeContext.flushChannel(); + + verify(writeOperation).flush(); + verify(listener).onResponse(channel); + assertFalse(writeContext.hasQueuedWriteOps()); + } + + public void testPartialFlush() throws IOException { + assertFalse(writeContext.hasQueuedWriteOps()); + + WriteOperation writeOperation = mock(WriteOperation.class); + writeContext.queueWriteOperations(writeOperation); + + assertTrue(writeContext.hasQueuedWriteOps()); + + when(writeOperation.isFullyFlushed()).thenReturn(false); + writeContext.flushChannel(); + + verify(listener, times(0)).onResponse(channel); + assertTrue(writeContext.hasQueuedWriteOps()); + } + + @SuppressWarnings("unchecked") + public void testMultipleWritesPartialFlushes() throws IOException { + assertFalse(writeContext.hasQueuedWriteOps()); + + ActionListener listener2 = mock(ActionListener.class); + WriteOperation writeOperation1 = mock(WriteOperation.class); + WriteOperation writeOperation2 = mock(WriteOperation.class); + when(writeOperation1.getListener()).thenReturn(listener); + when(writeOperation2.getListener()).thenReturn(listener2); + writeContext.queueWriteOperations(writeOperation1); + writeContext.queueWriteOperations(writeOperation2); + + assertTrue(writeContext.hasQueuedWriteOps()); + + when(writeOperation1.isFullyFlushed()).thenReturn(true); + when(writeOperation2.isFullyFlushed()).thenReturn(false); + writeContext.flushChannel(); + + verify(listener).onResponse(channel); + verify(listener2, times(0)).onResponse(channel); + assertTrue(writeContext.hasQueuedWriteOps()); + + when(writeOperation2.isFullyFlushed()).thenReturn(true); + + writeContext.flushChannel(); + + verify(listener2).onResponse(channel); + assertFalse(writeContext.hasQueuedWriteOps()); + } + + private class ConsumeAllChannel extends NioSocketChannel { + + private byte[] bytes; + private byte[] bytes2; + + ConsumeAllChannel() throws IOException { + super("", mock(SocketChannel.class)); + } + + public int write(ByteBuffer buffer) throws IOException { + bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + return bytes.length; + } + + public long vectorizedWrite(ByteBuffer[] buffer) throws IOException { + if (buffer.length != 2) { + throw new IOException("Only allows 2 buffers"); + } + bytes = new byte[buffer[0].remaining()]; + buffer[0].get(bytes); + + bytes2 = new byte[buffer[1].remaining()]; + buffer[1].get(bytes2); + return bytes.length + bytes2.length; + } + } + + private class HalfConsumeChannel extends NioSocketChannel { + + private byte[] bytes; + private byte[] bytes2; + + HalfConsumeChannel() throws IOException { + super("", mock(SocketChannel.class)); + } + + public int write(ByteBuffer buffer) throws IOException { + bytes = new byte[buffer.limit() / 2]; + buffer.get(bytes); + return bytes.length; + } + + public long vectorizedWrite(ByteBuffer[] buffers) throws IOException { + if (buffers.length != 2) { + throw new IOException("Only allows 2 buffers"); + } + if (bytes == null) { + bytes = new byte[buffers[0].remaining()]; + bytes2 = new byte[buffers[1].remaining()]; + } + + if (buffers[0].remaining() != 0) { + buffers[0].get(bytes); + return bytes.length; + } else { + buffers[1].get(bytes2); + return bytes2.length; + } + } + } + + private class MultiWriteChannel extends NioSocketChannel { + + private byte[] write1Bytes; + private byte[] write1Bytes2; + private byte[] write2Bytes1; + private byte[] write2Bytes2; + + MultiWriteChannel() throws IOException { + super("", mock(SocketChannel.class)); + } + + public long vectorizedWrite(ByteBuffer[] buffers) throws IOException { + if (buffers.length != 4 && write1Bytes == null) { + throw new IOException("Only allows 4 buffers"); + } else if (buffers.length != 2 && write1Bytes != null) { + throw new IOException("Only allows 2 buffers on second write"); + } + if (write1Bytes == null) { + write1Bytes = new byte[buffers[0].remaining()]; + write1Bytes2 = new byte[buffers[1].remaining()]; + write2Bytes1 = new byte[buffers[2].remaining()]; + write2Bytes2 = new byte[buffers[3].remaining()]; + } + + if (buffers[0].remaining() != 0) { + buffers[0].get(write1Bytes); + buffers[1].get(write1Bytes2); + buffers[2].get(write2Bytes1); + return write1Bytes.length + write1Bytes2.length + write2Bytes1.length; + } else { + buffers[1].get(write2Bytes2); + return write2Bytes2.length; + } + } + } + + private byte[] generateBytes(int n) { + n += 10; + byte[] bytes = new byte[n]; + for (int i = 0; i < n; ++i) { + bytes[i] = randomByte(); + } + return bytes; + } + +} diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/utils/TestSelectionKey.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/utils/TestSelectionKey.java new file mode 100644 index 0000000000..0f0011f155 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/utils/TestSelectionKey.java @@ -0,0 +1,65 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.utils; + +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.spi.AbstractSelectionKey; + +public class TestSelectionKey extends AbstractSelectionKey { + + private int ops = 0; + private int readyOps; + + public TestSelectionKey(int ops) { + this.ops = ops; + } + + @Override + public SelectableChannel channel() { + return null; + } + + @Override + public Selector selector() { + return null; + } + + @Override + public int interestOps() { + return ops; + } + + @Override + public SelectionKey interestOps(int ops) { + this.ops = ops; + return this; + } + + @Override + public int readyOps() { + return readyOps; + } + + public void setReadyOps(int readyOps) { + this.readyOps = readyOps; + } +} |