在开源中国里面看到一个入门级RPC框架实现的项目,使用Spring + Netty + Protostuff + ZooKeeper实现的。第一眼看到介绍就感觉这是个简化版本的dubbo,虽然只实现了里面rpc的这一块功能感觉也很值得学习一下。
目录:
1.介绍
2.使用方式与实现思路
3.部分源码分析
<!----分割线---->
1.介绍
rpc过程图:
spring:主要是用来依赖注入,通过对象序列化之后使用其代理调用目标接口
netty:简化nio开发,以数据流转过程中添加编码解码器来实现通信协议的开发,这里用到了nio,同时将序列化框架以编码解码器的方式整合进去
protostuff:序列化框架,类型产品很多,主要是因为jdk自带的序列化备受诟病
zk:用来维护一张服务列表,主要是利用其强数据一致性实现服务的动态上下线,用来实现服务的暴露、注册、发现等
整合起来就是目前rpc常见的一种实现方式了,rpc可以基于应用层协议实现也可以基于传输层协议实现,各有好坏。基于tcp实现rpc效率更高,没有http请求那么多冗余信息,但是对很多问题例如握手连接、断线重连、心跳检测等问题需要自己来开发增加开发难度,而http请求结果这儿长时间的发展已经把各种问题都考虑进去了。
2.使用方式与实现思路
使用方式和实现思路和dubbo类似的。
使用方式:
服务提供者:
(a)写服务提供者接口,实现服务提供者接口,将服务提供者配置到sprimg中,同时指定一个发布服务的端口。作者添加了注解的支持。
(b)封装了zk客户端实例到bean中,配置到spring中,同时指定zk服务ip地址用作服务注册。
(c)加载spring配置文件,服务提供者跑起来。
服务消费者:
(a)将服务提供者接口配置到spring中。用户调用他,可以远程调用服务提供者
(b)和上面一样,将zk客户端实例封装后配置到spring中,同时指定zk服务的ip地址,消费者从zk服务上面发现服务提供者(就是获取到服务提供者是否上线,跑在那个端口下)
(c)创建代理接口进行远程调用
实现思路:
核心就在自定义一个rpc协议利用netty在消费者端发送消息(消息中包含想要调用的方法、类、参数等),在提供者端接收消息本地调用后生成调用结果返回消息(消息中包含了调用状态、error、结果等)。
在消费者端利用jdk的动态代理生成一个提供者接口的代理对象,调用这个代理对象的目标方法(触发了上面的过程)并返回远程调用的结果。
3.部分源码分析
(a)封装zk客户端用来服务注册:
首先手动建立永久节点/registry,注册的服务都在这个节点下面建立临时节点。核心就是ServiceRegistry中的createNode方法。
//常量,zk客户端超时时间,临时节点的前后缀等 public interface Constant { int ZK_SESSION_TIMEOUT = 5000; String ZK_REGISTRY_PATH = "/registry"; String ZK_DATA_PATH = ZK_REGISTRY_PATH + "/data"; } //注册中心实现(注入到spring中) public class ServiceRegistry { private static final Logger LOGGER = LoggerFactory.getLogger(ServiceRegistry.class); private CountDownLatch latch = new CountDownLatch(1); private String registryAddress; public ServiceRegistry(String registryAddress) { this.registryAddress = registryAddress; } public void register(String data) { if (data != null) { ZooKeeper zk = connectServer(); if (zk != null) { createNode(zk, data); } } } private ZooKeeper connectServer() { ZooKeeper zk = null; try { zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() { @Override public void process(WatchedEvent event) { if (event.getState() == Event.KeeperState.SyncConnected) { latch.countDown(); } } }); latch.await(); } catch (IOException | InterruptedException e) { LOGGER.error("", e); } return zk; } private void createNode(ZooKeeper zk, String data) { try { byte[] bytes = data.getBytes(); String path = zk.create(Constant.ZK_DATA_PATH, bytes, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL); LOGGER.debug("create zookeeper node ({} => {})", path, data); } catch (KeeperException | InterruptedException e) { LOGGER.error("", e); } } }
(b)封装zk客户端用来服务发现: 核心就是watchNode方法用来读取zk下面的临时节点,读取到就证明服务在线且可用。
public class ServiceDiscovery { private static final Logger LOGGER = LoggerFactory.getLogger(ServiceDiscovery.class); private CountDownLatch latch = new CountDownLatch(1); private volatile List<String> dataList = new ArrayList<>(); private String registryAddress; public ServiceDiscovery(String registryAddress) { this.registryAddress = registryAddress; ZooKeeper zk = connectServer(); if (zk != null) { watchNode(zk); } } public String discover() { String data = null; int size = dataList.size(); if (size > 0) { if (size == 1) { data = dataList.get(0); LOGGER.debug("using only data: {}", data); } else { data = dataList.get(ThreadLocalRandom.current().nextInt(size)); LOGGER.debug("using random data: {}", data); } } return data; } private ZooKeeper connectServer() { ZooKeeper zk = null; try { zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() { @Override public void process(WatchedEvent event) { if (event.getState() == Event.KeeperState.SyncConnected) { latch.countDown(); } } }); latch.await(); } catch (IOException | InterruptedException e) { LOGGER.error("", e); } return zk; } private void watchNode(final ZooKeeper zk) { try { List<String> nodeList = zk.getChildren(Constant.ZK_REGISTRY_PATH, new Watcher() { @Override public void process(WatchedEvent event) { if (event.getType() == Event.EventType.NodeChildrenChanged) { watchNode(zk); } } }); List<String> dataList = new ArrayList<>(); for (String node : nodeList) { byte[] bytes = zk.getData(Constant.ZK_REGISTRY_PATH + "/" + node, false, null); dataList.add(new String(bytes)); } LOGGER.debug("node data: {}", dataList); this.dataList = dataList; } catch (KeeperException | InterruptedException e) { LOGGER.error("", e); } } }
(c)开发netty:
服务的注册于发现主要是用来动态得管理服务,进行服务治理。
正正的rpc功能还是要靠netty来实现。作者的思路其实是利用netty自己基于tcp的基础上封装了一个rpc协议(rpcRequest用来传输序列化对象包括想要远程调用的接口类名、方法名、参数等信息,rpcRespnse用来返回调用状态、调用结果等信息)。
然后自定义rpc协议的编码器、解码器进行消息通信,具体如下:首先从zk上面获取注册的服务地址,在消费者端使用netty发送rpcRequest到服务提供者,服务提供者端利用自定义的rpc解码器解析rpcRequest获取到消费者想要调用的某个方法,然后利用反射进行调用,接着将结果构造成rpcResponse发回给消费者。消费者再次利用netty解析rpcResponse获取到调用结果。
整个过程利用的相关技术包括:自定义rpcRequest类、rpcResponse类,两个解码器,zk客户端读取临时节点,解码编码之间使用SimpleChannelInboundHandler来处理rpc请求(真正利用反射invoke()进行调用就是在这个类里面),序列化过程(序列化主要在消费者端进行,反射调用在提供者端进行,这个时候高效率的序列化框架就用上了)
rpcRequest/rpcResponse:
public class RpcRequest { private String requestId; private String className; private String methodName; private Class<?>[] parameterTypes; private Object[] parameters; // getter/setter... } public class RpcResponse { private String requestId; private Throwable error; private Object result; // getter/setter... }
rpcRequest/rpcResponse的编解码器:
public class RpcDecoder extends ByteToMessageDecoder { private Class<?> genericClass; public RpcDecoder(Class<?> genericClass) { this.genericClass = genericClass; } @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { if (in.readableBytes() < 4) { return; } in.markReaderIndex(); int dataLength = in.readInt(); if (dataLength < 0) { ctx.close(); } if (in.readableBytes() < dataLength) { in.resetReaderIndex(); return; } byte[] data = new byte[dataLength]; in.readBytes(data); Object obj = SerializationUtil.deserialize(data, genericClass); out.add(obj); } } public class RpcEncoder extends MessageToByteEncoder { private Class<?> genericClass; public RpcEncoder(Class<?> genericClass) { this.genericClass = genericClass; } @Override public void encode(ChannelHandlerContext ctx, Object in, ByteBuf out) throws Exception { if (genericClass.isInstance(in)) { byte[] data = SerializationUtil.serialize(in); out.writeInt(data.length); out.writeBytes(data); } } }
序列化工具类:
整合了序列化框架,当然先用jdk原生的也可以,想用其他的如marshling、protobuf也可以,直接在工具类中修改实现就可以了。
public class SerializationUtil { private static Map<Class<?>, Schema<?>> cachedSchema = new ConcurrentHashMap<>(); private static Objenesis objenesis = new ObjenesisStd(true); private SerializationUtil() { } @SuppressWarnings("unchecked") private static <T> Schema<T> getSchema(Class<T> cls) { Schema<T> schema = (Schema<T>) cachedSchema.get(cls); if (schema == null) { schema = RuntimeSchema.createFrom(cls); if (schema != null) { cachedSchema.put(cls, schema); } } return schema; } @SuppressWarnings("unchecked") public static <T> byte[] serialize(T obj) { Class<T> cls = (Class<T>) obj.getClass(); LinkedBuffer buffer = LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE); try { Schema<T> schema = getSchema(cls); return ProtostuffIOUtil.toByteArray(obj, schema, buffer); } catch (Exception e) { throw new IllegalStateException(e.getMessage(), e); } finally { buffer.clear(); } } public static <T> T deserialize(byte[] data, Class<T> cls) { try { T message = (T) objenesis.newInstance(cls); Schema<T> schema = getSchema(cls); ProtostuffIOUtil.mergeFrom(data, message, schema); return message; } catch (Exception e) { throw new IllegalStateException(e.getMessage(), e); } } }
提供者端处理rpc请求的handler:
直接继承netty的SimpleChannelInboundHandler即可。
public class RpcHandler extends SimpleChannelInboundHandler<RpcRequest> { private static final Logger LOGGER = LoggerFactory.getLogger(RpcHandler.class); private final Map<String, Object> handlerMap; public RpcHandler(Map<String, Object> handlerMap) { this.handlerMap = handlerMap; } @Override public void channelRead0(final ChannelHandlerContext ctx, RpcRequest request) throws Exception { RpcResponse response = new RpcResponse(); response.setRequestId(request.getRequestId()); try { Object result = handle(request); response.setResult(result); } catch (Throwable t) { response.setError(t); } ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); } private Object handle(RpcRequest request) throws Throwable { String className = request.getClassName(); Object serviceBean = handlerMap.get(className); Class<?> serviceClass = serviceBean.getClass(); String methodName = request.getMethodName(); Class<?>[] parameterTypes = request.getParameterTypes(); Object[] parameters = request.getParameters(); /*Method method = serviceClass.getMethod(methodName, parameterTypes); method.setAccessible(true); return method.invoke(serviceBean, parameters);*/ FastClass serviceFastClass = FastClass.create(serviceClass); FastMethod serviceFastMethod = serviceFastClass.getMethod(methodName, parameterTypes); return serviceFastMethod.invoke(serviceBean, parameters); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { LOGGER.error("server caught exception", cause); ctx.close(); } }
public class RpcClient extends SimpleChannelInboundHandler<RpcResponse> { private static final Logger LOGGER = LoggerFactory.getLogger(RpcClient.class); private String host; private int port; private RpcResponse response; private final Object obj = new Object(); public RpcClient(String host, int port) { this.host = host; this.port = port; } @Override public void channelRead0(ChannelHandlerContext ctx, RpcResponse response) throws Exception { this.response = response; synchronized (obj) { obj.notifyAll(); // 收到响应,唤醒线程 } } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { LOGGER.error("client caught exception", cause); ctx.close(); } public RpcResponse send(RpcRequest request) throws Exception { EventLoopGroup group = new NioEventLoopGroup(); try { Bootstrap bootstrap = new Bootstrap(); bootstrap.group(group).channel(NioSocketChannel.class) .handler(new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel channel) throws Exception { channel.pipeline() .addLast(new RpcEncoder(RpcRequest.class)) // 将 RPC 请求进行编码(为了发送请求) .addLast(new RpcDecoder(RpcResponse.class)) // 将 RPC 响应进行解码(为了处理响应) .addLast(RpcClient.this); // 使用 RpcClient 发送 RPC 请求 } }) .option(ChannelOption.SO_KEEPALIVE, true); ChannelFuture future = bootstrap.connect(host, port).sync(); future.channel().writeAndFlush(request).sync(); synchronized (obj) { obj.wait(); // 未收到响应,使线程等待 } if (response != null) { future.channel().closeFuture().sync(); } return response; } finally { group.shutdownGracefully(); } } }
(d)代理开发:
消费者想要调用提供者需要对象来调用其方法,消费者由于是远程调用所以使用一个代理对象,主要是根据想要调用的接口生成代理对象,核心在于代理对象的InvocationHandler方法中利用上面的netty发送rpc请求,获取rpc想要然后生成调用结果返回。所以使用代理对象调用目标方法时就得到了远程调用后的结果。这样就营造了一种远程方法像在本地调用一样的效果
public class RpcProxy { private String serverAddress; private ServiceDiscovery serviceDiscovery; public RpcProxy(String serverAddress) { this.serverAddress = serverAddress; } public RpcProxy(ServiceDiscovery serviceDiscovery) { this.serviceDiscovery = serviceDiscovery; } @SuppressWarnings("unchecked") public <T> T create(Class<?> interfaceClass) { return (T) Proxy.newProxyInstance( interfaceClass.getClassLoader(), new Class<?>[]{interfaceClass}, new InvocationHandler() { @Override public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { RpcRequest request = new RpcRequest(); // 创建并初始化 RPC 请求 request.setRequestId(UUID.randomUUID().toString()); request.setClassName(method.getDeclaringClass().getName()); request.setMethodName(method.getName()); request.setParameterTypes(method.getParameterTypes()); request.setParameters(args); if (serviceDiscovery != null) { serverAddress = serviceDiscovery.discover(); // 发现服务 } String[] array = serverAddress.split(":"); String host = array[0]; int port = Integer.parseInt(array[1]); RpcClient client = new RpcClient(host, port); // 初始化 RPC 客户端 RpcResponse response = client.send(request); // 通过 RPC 客户端发送 RPC 请求并获取 RPC 响应 if (response.isError()) { throw response.getError(); } else { return response.getResult(); } } } ); } }