【记录】Java NIO实现网络模块遇到的BUG

1.背景

通过Java NIO实现一个简单的网络模块,线程模型类似Netty:一个线程(AcceptThread)负责建立新连接,并把新连接绑定到某个SelectorThread上,由SelectorThread处理读/写。

  • AcceptThread:拥有一个Selector,上面只注册了一个ServerSocketChannel,监听客户端新连接,当接收到新连接时,把新连接注册到SelectorThread上。
  • SelectorThread:拥有一个Selector,用于注册客户端连接,并且处理读/写事件。

2.代码

/**
 * Original (problematic) version of a small Java NIO server with a Netty-like thread layout:
 * one {@link AcceptThread} accepts connections and hands each one to a {@link SelectorThread}
 * picked round-robin; each SelectorThread owns a {@link Selector} and handles reads/writes.
 *
 * <p>NOTE(review): this is the version in which the accept thread hangs. The accept thread calls
 * {@code SocketChannel.register} on a selector whose owning SelectorThread may be inside
 * {@code selector.select()}; in the SelectorImpl code quoted in this article both
 * {@code lockAndDoSelect} and {@code register} synchronize on {@code publicKeys}.
 */
public class NonblockingServer {
    /** Threads that handle client reads/writes. */
    private SelectorThread[] selectorThreads;
    /** Thread that accepts new connections. */
    private AcceptThread acceptThread;
    /** Chooses which SelectorThread receives the next connection. */
    private SelectNextSelectorThread nextSelectorThread;
    /** Default number of SelectorThreads. */
    private static int defaultSelectorNum = 3;

    /** Creates a server with {@code defaultSelectorNum} (3) SelectorThreads. */
    public NonblockingServer() {
        this(defaultSelectorNum);
    }

    /**
     * Creates the SelectorThreads, the AcceptThread and the round-robin chooser. No thread is
     * started here.
     *
     * @param selectorThreadNum number of SelectorThreads; must be at least 1
     * @throws IllegalArgumentException if {@code selectorThreadNum < 1}
     * @throws RuntimeException wrapping any exception raised while opening selectors
     */
    public NonblockingServer(int selectorThreadNum) {
        if (selectorThreadNum < 1) {
            throw new IllegalArgumentException("SelectorThread线程数量不能低于1");
        }
        selectorThreads = new SelectorThread[selectorThreadNum];
        try {
            for (int i = 0; i < selectorThreads.length; i++) {
                selectorThreads[i] = new SelectorThread("selector-thread-" + i);
            }
            acceptThread = new AcceptThread("accept-thread");
            nextSelectorThread = new SelectNextSelectorThread(Arrays.asList(selectorThreads));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Binds to {@code host:port} and starts the accept thread.
     *
     * @throws RuntimeException wrapping an {@link IOException} from opening/binding the socket
     */
    public void listen(String host, int port) {
        try {
            acceptThread.listen(host, port);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Handles accept events and registers each new connection with a SelectorThread.
     */
    class AcceptThread extends Thread {
        /** Selector holding only the server channel's OP_ACCEPT registration. */
        private Selector selector;
        private ServerSocketChannel serverChannel;

        AcceptThread(String name) throws IOException {
            super(name);
            selector = Selector.open();
        }

        /**
         * Opens a non-blocking server channel, registers it for OP_ACCEPT, binds it and starts
         * this thread. The registration happens before {@code start()}, so it does not race
         * with this thread's own select loop.
         */
        public void listen(String host, int port) throws IOException {
            serverChannel = ServerSocketChannel.open();
            serverChannel.configureBlocking(false);
            serverChannel.register(selector, SelectionKey.OP_ACCEPT);
            serverChannel.bind(new InetSocketAddress(host, port));
            this.start();
        }

        /** Select loop: runs until this thread is interrupted; errors are printed and ignored. */
        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                try {
                    int select = selector.select();
                    if (select == 0) {
                        continue;
                    }
                    Set<SelectionKey> keys = selector.selectedKeys();
                    Iterator<SelectionKey> iterator = keys.iterator();
                    while (iterator.hasNext()) {
                        SelectionKey key = iterator.next();
                        // Remove the key so it is not processed again on the next select().
                        iterator.remove();
                        if (key.isAcceptable()) {
                            handAccept();
                        }
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }

        /** Accepts one connection and registers it with the next SelectorThread. */
        private void handAccept() throws IOException {
            // NOTE(review): in non-blocking mode accept() may return null; not handled here.
            SocketChannel clientChannel = serverChannel.accept();
            SelectorThread selectorThread = nextSelectorThread.nextThread();
            // Lazy start: a SelectorThread is started the first time it is chosen.
            if (!selectorThread.isStart()) {
                selectorThread.start();
            }
            // BUG (the hang analysed in this article): register() runs on the accept thread and
            // touches the SelectorThread's selector directly, while that thread may already be
            // inside selector.select().
            selectorThread.register(clientChannel);
        }
    }

    /**
     * Handles read/write events for the client connections registered with its own selector.
     */
    class SelectorThread extends Thread {
        private Selector selector;
        /**
         * Loop flag. volatile because it is written by start()/shutdown(), which run on other
         * threads (start() is called by AcceptThread), and read by run().
         */
        private volatile boolean start = false;

        SelectorThread(String name) throws IOException {
            super(name);
            selector = Selector.open();
        }

        /**
         * Select loop: runs while {@code start} is true and dispatches readable/writable keys.
         *
         * <p>NOTE(review): connections are registered for OP_WRITE as well and handWrite does
         * nothing; if the socket stays writable, select() returns immediately on every pass,
         * which would explain the constant looping in lockAndDoSelect seen while debugging.
         */
        @Override
        public void run() {
            while (start) {
                try {
                    selector.select();
                    Set<SelectionKey> keys = selector.selectedKeys();
                    Iterator<SelectionKey> iterator = keys.iterator();
                    while (start && iterator.hasNext()) {
                        SelectionKey key = iterator.next();
                        iterator.remove();
                        SocketChannel channel = (SocketChannel) key.channel();
                        if (key.isReadable()) {
                            handRead(channel);
                        } else if (key.isWritable()) {
                            handWrite(channel);
                        }
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }

        /**
         * Reads up to 1024 bytes and simply prints them. On end-of-stream (read == -1) prints a
         * disconnect message and closes the channel.
         */
        private void handRead(SocketChannel channel) throws IOException {
            ByteBuffer buffer = ByteBuffer.allocate(1024);
            int read = channel.read(buffer);
            if (read == -1) {
                System.out.println(channel.getRemoteAddress() + " 断开连接");
                channel.close();
            } else {
                byte[] bytes = new byte[read];
                buffer.flip();
                buffer.get(bytes);
                System.out.println(new String(bytes));
                buffer.clear();
            }
        }

        private void handWrite(SocketChannel channel) {
            // Intentionally empty: this demo never writes anything.
        }

        /**
         * Registers a new connection with this thread's selector for OP_READ | OP_WRITE.
         *
         * <p>BUG: this is called from AcceptThread, not from this thread, so it competes with
         * this thread's selector.select() for the selector's publicKeys lock.
         */
        public void register(SocketChannel clientChannel) throws IOException {
            clientChannel.configureBlocking(false);
            clientChannel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE);
        }

        /** Sets the loop flag before starting the thread so that run() enters its loop. */
        @Override
        public synchronized void start() {
            start = true;
            super.start();
        }

        /** Clears the loop flag and interrupts the thread so it leaves its select loop. */
        public void shutdown() {
            this.start = false;
            this.interrupt();
        }

        /** Returns whether start() has been called and shutdown() has not. */
        public boolean isStart() {
            return start;
        }
    }

    /**
     * Round-robin chooser over a fixed collection of SelectorThreads. Not synchronized; in this
     * file it is only used from AcceptThread.
     */
    static class SelectNextSelectorThread {
        private final Collection<? extends SelectorThread> threads;
        private Iterator<? extends SelectorThread> iterator;

        public <T extends SelectorThread> SelectNextSelectorThread(Collection<T> threads) {
            this.threads = threads;
            iterator = this.threads.iterator();
        }

        /**
         * Returns the next SelectorThread, starting again from the first one after the last
         * (round-robin).
         *
         * @return SelectorThread
         */
        public SelectorThread nextThread() {
            if (!iterator.hasNext()) {
                iterator = threads.iterator();
            }
            return iterator.next();
        }
    }
}
NonblockingServer

3.问题

用客户端连接几次后,发现AcceptThread被卡死了,新连接无法建立!
用jvisualVM可以看到AcceptThread的状态为"监视"(Monitor),说明此时AcceptThread正在等待某个锁。


通过IDEA的dump threads也可以看到:

观察SelectorThread线程:

SelectorImpl.lockAndDoSelect代码:

    /**
     * {@code SelectorImpl.lockAndDoSelect} as quoted here (the {@code var1}-style names suggest
     * decompiled JDK source; presumably JDK 8 — TODO confirm the version). Presumably the common
     * path behind select()/selectNow(); not shown here.
     */
    private int lockAndDoSelect(long var1) throws IOException {
        // Lock order: the selector itself, then publicKeys, then publicSelectedKeys.
        synchronized(this) {
            if (!this.isOpen()) {
                throw new ClosedSelectorException();
            } else {
                int var10000;
                synchronized(this.publicKeys) {
                    synchronized(this.publicSelectedKeys) {
                        // doSelect runs with all three monitors held, so any other thread that
                        // needs publicKeys (such as register) waits until doSelect returns.
                        // var1 is presumably the select timeout; how long doSelect waits is
                        // platform-specific and not shown here.
                        var10000 = this.doSelect(var1);
                    }
                }

                return var10000;
            }
        }
    }

SelectorImpl.register代码:

    /**
     * {@code SelectorImpl.register} as quoted here (presumably JDK 8 — TODO confirm). Creates a
     * key for channel {@code var1}, attaches {@code var3}, adds the key to this selector under the
     * {@code publicKeys} lock, and then sets its interest set to {@code var2}.
     */
    protected final SelectionKey register(AbstractSelectableChannel var1, int var2, Object var3) {
        if (!(var1 instanceof SelChImpl)) {
            throw new IllegalSelectorException();
        } else {
            SelectionKeyImpl var4 = new SelectionKeyImpl((SelChImpl)var1, this);
            var4.attach(var3);
            // Same monitor that lockAndDoSelect holds while in doSelect: a thread calling this
            // waits as long as another thread is selecting on this selector.
            synchronized(this.publicKeys) {
                this.implRegister(var4);
            }

            var4.interestOps(var2);
            return var4;
        }
    }

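这两个方法都需要获取this.publicKeys的锁。SelectorThread调用selector.select()时,lockAndDoSelect会在持有this.publicKeys锁的情况下进入doSelect等待事件;而AcceptThread调用clientChannel.register()注册新连接时,同样要获取这把锁(JDK 8中SelectableChannel.register的文档也说明:该方法会在selector的key set上同步,因此与同一selector上正在进行的选择操作并发调用时可能会阻塞)。于是会出现两种情况:

  • 如果SelectorThread上还没有任何就绪事件(例如刚懒加载启动、尚未注册任何连接),select()会一直阻塞并持有锁,register()永远拿不到锁;
  • 经过调试还发现,由于连接同时注册了OP_WRITE,而socket几乎总是可写,select()每次都会立即返回,SelectorThread一直在不断地循环执行SelectorImpl.lockAndDoSelect方法,不停地获取this.publicKeys的锁;synchronized锁并不保证公平,所以register方法很难有机会拿到this.publicKeys的锁。

无论哪种情况,AcceptThread都会卡死在获取锁的过程中。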

4.解决方法

Netty在注册新连接时,是把注册过程封装成一个任务,交给该连接所属的SelectorThread(EventLoop)执行,因此不会发生线程间的锁竞争。

Thrift则是把新连接放进SelectorThread中的BlockingQueue,并调用selector.wakeup()唤醒阻塞在select()上的线程,同样由SelectorThread自己完成注册。下面修正后的代码采用了类似的做法。

5.修正后的代码

/**
 * Small Java NIO server with a Netty-like thread layout: one {@link AcceptThread} accepts new
 * connections and hands each one to a {@link SelectorThread} chosen round-robin; every
 * SelectorThread owns its own {@link Selector} and handles reads for its connections.
 *
 * <p>Channel registration is always executed by the SelectorThread that owns the selector: the
 * accept thread only enqueues a task and calls {@code selector.wakeup()}, so it never blocks on a
 * selector that another thread is selecting on.
 */
public class NonblockingServer {
    /** Threads that handle client I/O; each is started lazily the first time it is chosen. */
    private final SelectorThread[] selectorThreads;
    /** Thread that accepts new connections. */
    private final AcceptThread acceptThread;
    /** Chooses which SelectorThread receives the next connection. */
    private final SelectNextSelectorThread nextSelectorThread;
    /** Default number of SelectorThreads. */
    private static final int DEFAULT_SELECTOR_NUM = 3;

    /** Creates a server with {@value #DEFAULT_SELECTOR_NUM} SelectorThreads. */
    public NonblockingServer() {
        this(DEFAULT_SELECTOR_NUM);
    }

    /**
     * Creates (but does not start) the SelectorThreads, the AcceptThread and the chooser.
     *
     * @param selectorThreadNum number of SelectorThreads; must be at least 1
     * @throws IllegalArgumentException if {@code selectorThreadNum < 1}
     * @throws RuntimeException wrapping any exception raised while opening selectors
     */
    public NonblockingServer(int selectorThreadNum) {
        if (selectorThreadNum < 1) {
            throw new IllegalArgumentException("SelectorThread线程数量不能低于1");
        }
        selectorThreads = new SelectorThread[selectorThreadNum];
        try {
            for (int i = 0; i < selectorThreads.length; i++) {
                selectorThreads[i] = new SelectorThread("selector-thread-" + i);
            }
            acceptThread = new AcceptThread("accept-thread");
            nextSelectorThread = new SelectNextSelectorThread(Arrays.asList(selectorThreads));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Binds to {@code host:port} and starts the accept thread.
     *
     * @throws RuntimeException wrapping an {@link IOException} from opening/binding the socket
     */
    public void listen(String host, int port) {
        try {
            acceptThread.listen(host, port);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Closes {@code resource}, printing (not propagating) any failure; ignores {@code null}. */
    private static void closeQuietly(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Handles accept events and passes each new connection to a SelectorThread.
     */
    class AcceptThread extends Thread {
        /** Selector holding only the server channel's OP_ACCEPT registration. */
        private final Selector selector;
        private ServerSocketChannel serverChannel;

        AcceptThread(String name) throws IOException {
            super(name);
            selector = Selector.open();
        }

        /**
         * Opens a non-blocking server channel, registers it for OP_ACCEPT, binds it and starts
         * this thread. The server channel is closed again if any setup step fails.
         *
         * @throws IOException if the channel cannot be opened, configured, registered or bound
         */
        public void listen(String host, int port) throws IOException {
            ServerSocketChannel channel = ServerSocketChannel.open();
            try {
                channel.configureBlocking(false);
                // Registered before start(), so this never races with this thread's select().
                channel.register(selector, SelectionKey.OP_ACCEPT);
                channel.bind(new InetSocketAddress(host, port));
            } catch (IOException | RuntimeException e) {
                closeQuietly(channel);
                throw e;
            }
            serverChannel = channel;
            this.start();
        }

        /**
         * Select loop: runs until this thread is interrupted; per-iteration errors are printed
         * and the loop continues. The server channel and selector are closed on exit.
         */
        @Override
        public void run() {
            try {
                while (!Thread.currentThread().isInterrupted()) {
                    try {
                        if (selector.select() == 0) {
                            continue;
                        }
                        Iterator<SelectionKey> iterator = selector.selectedKeys().iterator();
                        while (iterator.hasNext()) {
                            SelectionKey key = iterator.next();
                            // Remove the key so it is not processed again on the next select().
                            iterator.remove();
                            if (key.isValid() && key.isAcceptable()) {
                                handAccept();
                            }
                        }
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            } finally {
                closeQuietly(serverChannel);
                closeQuietly(selector);
            }
        }

        /** Accepts one pending connection and hands it to the next SelectorThread. */
        private void handAccept() throws IOException {
            SocketChannel clientChannel = serverChannel.accept();
            if (clientChannel == null) {
                // A non-blocking accept() returns null when no connection is actually pending.
                return;
            }
            try {
                SelectorThread selectorThread = nextSelectorThread.nextThread();
                // Lazy start: a SelectorThread is started the first time it is chosen.
                if (!selectorThread.isStart()) {
                    selectorThread.start();
                }
                selectorThread.register(clientChannel);
            } catch (RuntimeException e) {
                // e.g. starting a thread that was already shut down; don't leak the connection.
                closeQuietly(clientChannel);
                throw e;
            }
        }
    }

    /**
     * Handles I/O events for the client connections registered with its own selector. All
     * selector mutations run on this thread, via tasks queued by {@link #submit}.
     */
    class SelectorThread extends Thread {
        private final Selector selector;
        /** Work queued by other threads and executed by this thread after each select(). */
        private final BlockingQueue<Runnable> tasks;
        /**
         * Loop flag. volatile because it is written by start()/shutdown(), which run on other
         * threads (start() is called by AcceptThread), and read by run().
         */
        private volatile boolean start = false;

        SelectorThread(String name) throws IOException {
            super(name);
            selector = Selector.open();
            tasks = new LinkedBlockingQueue<>();
        }

        /**
         * Select loop: runs while {@code start} is true; after each select() it first runs the
         * queued tasks (e.g. registrations) and then dispatches the ready keys.
         */
        @Override
        public void run() {
            while (start) {
                try {
                    selector.select();
                    handTask();
                    Iterator<SelectionKey> iterator = selector.selectedKeys().iterator();
                    while (start && iterator.hasNext()) {
                        SelectionKey key = iterator.next();
                        iterator.remove();
                        handKey(key);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }

        /**
         * Dispatches one ready key. A connection whose I/O fails is closed, which cancels its
         * key, so a broken socket cannot keep firing and spin the loop.
         */
        private void handKey(SelectionKey key) {
            if (!key.isValid()) {
                return;
            }
            SocketChannel channel = (SocketChannel) key.channel();
            try {
                if (key.isReadable()) {
                    handRead(channel);
                } else if (key.isWritable()) {
                    handWrite(channel);
                }
            } catch (IOException e) {
                e.printStackTrace();
                closeQuietly(channel);
            }
        }

        /** Runs every task currently queued, stopping early if the thread is shut down. */
        private void handTask() {
            Runnable task;
            while (start && (task = tasks.poll()) != null) {
                task.run();
            }
        }

        /**
         * Reads up to 1024 bytes and simply prints them. On end-of-stream (read == -1) prints a
         * disconnect message and closes the channel; a read of 0 bytes prints nothing.
         */
        private void handRead(SocketChannel channel) throws IOException {
            ByteBuffer buffer = ByteBuffer.allocate(1024);
            int read = channel.read(buffer);
            if (read == -1) {
                System.out.println(channel.getRemoteAddress() + " 断开连接");
                channel.close();
            } else if (read > 0) {
                buffer.flip();
                byte[] bytes = new byte[buffer.remaining()];
                buffer.get(bytes);
                System.out.println(new String(bytes));
            }
        }

        /** Placeholder for write handling; OP_WRITE is currently never registered. */
        private void handWrite(SocketChannel channel) {
            // nothing
        }

        /**
         * Hands a new connection to this thread. The actual registration is queued as a task
         * and the selector is woken up, so the caller never touches this thread's selector
         * while it is selecting. If registration fails, the error is printed and the channel
         * is closed.
         */
        public void register(SocketChannel clientChannel) {
            submit(() -> {
                try {
                    clientChannel.configureBlocking(false);
                    // OP_READ only: a connected socket is writable whenever its send buffer has
                    // room, so OP_WRITE interest would make select() return immediately on every
                    // call (a busy loop) while there is nothing to write.
                    clientChannel.register(selector, SelectionKey.OP_READ);
                } catch (Exception e) {
                    e.printStackTrace();
                    closeQuietly(clientChannel);
                }
            });
            selector.wakeup();
        }

        /** Queues {@code task} to run on this thread after its next select() returns. */
        public void submit(Runnable task) {
            tasks.offer(task);
        }

        /** Sets the loop flag before starting the thread so that run() enters its loop. */
        @Override
        public synchronized void start() {
            start = true;
            super.start();
        }

        /** Clears the loop flag and interrupts the thread, which also aborts a blocking select(). */
        public void shutdown() {
            this.start = false;
            this.interrupt();
        }

        /** Returns whether start() has been called and shutdown() has not. */
        public boolean isStart() {
            return start;
        }
    }

    /**
     * Round-robin chooser over a fixed collection of SelectorThreads. Not synchronized; in this
     * file it is only used from AcceptThread.
     */
    static class SelectNextSelectorThread {
        private final Collection<? extends SelectorThread> threads;
        private Iterator<? extends SelectorThread> iterator;

        public <T extends SelectorThread> SelectNextSelectorThread(Collection<T> threads) {
            this.threads = threads;
            iterator = this.threads.iterator();
        }

        /**
         * Returns the next SelectorThread, starting again from the first one after the last
         * (round-robin).
         *
         * @return SelectorThread
         */
        public SelectorThread nextThread() {
            if (!iterator.hasNext()) {
                iterator = threads.iterator();
            }
            return iterator.next();
        }
    }
}
NonblockingServer

猜你喜欢

转载自www.cnblogs.com/Alpharun/p/12092623.html
今日推荐