Netty实现Rpc调用

在这里插入图片描述

定义接口

public interface IRpcHelloService {
    
    
    String hello(String name);  
} 
public interface IRpcService {
    
    

	/** 加 */
	public int add(int a,int b);

	/** 减 */
	public int sub(int a,int b);

	/** 乘 */
	public int mult(int a,int b);

	/** 除 */
	public int div(int a,int b);

}

接口实现类

public class RpcHelloServiceImpl implements IRpcHelloService {
    
    

    public String hello(String name) {
    
      
        return "Hello " + name + "!";  
    }  
  
} 
public class RpcServiceImpl implements IRpcService {
    
    

	public int add(int a, int b) {
    
    
		return a + b;
	}

	public int sub(int a, int b) {
    
    
		return a - b;
	}

	public int mult(int a, int b) {
    
    
		return a * b;
	}

	public int div(int a, int b) {
    
    
		return a / b;
	}

}

启动注册中心,provider进行注册

  • 注册:为给个服务指定自己的服务名称
  • 对其所在的位置做一个标记
public class RpcRegistry {
    
      
    private int port;  
    public RpcRegistry(int port){
    
      
        this.port = port;  
    }  
    public void start(){
    
      
        EventLoopGroup bossGroup = new NioEventLoopGroup();  
        EventLoopGroup workerGroup = new NioEventLoopGroup();  
          
        try {
    
      
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
            		.channel(NioServerSocketChannel.class)  
                    .childHandler(new ChannelInitializer<SocketChannel>() {
    
    
  
                        @Override  
                        protected void initChannel(SocketChannel ch) throws Exception {
    
      
                            //在Netty中,客户端的请求都会放到一个队列中
                            
                            ChannelPipeline pipeline = ch.pipeline();
                            //自定义协议解码器
                            /** 入参有5个,分别解释如下
                             maxFrameLength:框架的最大长度。如果帧的长度大于此值,则将抛出TooLongFrameException。
                             lengthFieldOffset:长度字段的偏移量:即对应的长度字段在整个消息数据中得位置
                             lengthFieldLength:长度字段的长度。如:长度字段是int型表示,那么这个值就是4(long型就是8)
                             lengthAdjustment:要添加到长度字段值的补偿值
                             initialBytesToStrip:从解码帧中去除的第一个字节数
                             */
                            pipeline.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
                            //自定义协议编码器
                            pipeline.addLast(new LengthFieldPrepender(4));
                            //对象参数类型编码器
                            pipeline.addLast("encoder",new ObjectEncoder());
                            //对象参数类型解码器
                            pipeline.addLast("decoder",new ObjectDecoder(Integer.MAX_VALUE,ClassResolvers.cacheDisabled(null)));
                            //以上完成对象的解析
                            //以下处理自己的业务逻辑
                            pipeline.addLast(new RegistryHandler());
                        }  
                    })
                    .option(ChannelOption.SO_BACKLOG, 128)       
                    .childOption(ChannelOption.SO_KEEPALIVE, true);  
            //服务启动,相当于死循环
            ChannelFuture future = b.bind(port).sync();      
            System.out.println("RPC Registry start listen at " + port );
            future.channel().closeFuture().sync();    
        } catch (Exception e) {
    
      
             bossGroup.shutdownGracefully();    
             workerGroup.shutdownGracefully();  
        }  
    }
    
    
    public static void main(String[] args) throws Exception {
    
        
        new RpcRegistry(8080).start();    
    }    
}  
自处理逻辑
  • 进行包扫描,保存可使用的provider服务map
  • 针对每次的网络调用,获取指定的provider服务
  • 由provider服务完成对指定方法的调用
import java.io.File;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

public class RegistryHandler  extends ChannelInboundHandlerAdapter {
    
    

	//用保存所有可用的服务
    public static ConcurrentHashMap<String, Object> registryMap = new ConcurrentHashMap<String,Object>();

    //保存所有相关的服务类
    private List<String> classNames = new ArrayList<String>();
    
    public RegistryHandler(){
    
    
    	//完成递归扫描,此处扫描的是本地
    	scannerClass("com.gupaoedu.vip.netty.rpc.provider");
    	doRegister();
    }
    
    //有客户端连接,发生回调
    @Override    
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
    
    
    	//每进行一次网络调用,此方法执行一次
    	Object result = new Object();
        InvokerProtocol request = (InvokerProtocol)msg;

        //当客户端建立连接时,需要从自定义协议中获取信息,拿到具体的服务和实参
		//使用反射调用
        if(registryMap.containsKey(request.getClassName())){
    
     
        	Object clazz = registryMap.get(request.getClassName());
        	Method method = clazz.getClass().getMethod(request.getMethodName(), request.getParames());    
        	result = method.invoke(clazz, request.getValues());   
        }
        ctx.write(result);  
        ctx.flush();    
        ctx.close();  
    }
    
    //客户端连接出错时,发生回调
    @Override    
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
    
        
         cause.printStackTrace();    
         ctx.close();    
    }
    

    /*
     * 递归扫描
     */
	private void scannerClass(String packageName){
    
    
	//替换包路径(含.)为文件夹路径(含/)
		URL url = this.getClass().getClassLoader().getResource(packageName.replaceAll("\\.", "/"));
		File dir = new File(url.getFile());
		for (File file : dir.listFiles()) {
    
    
			//如果是一个文件夹,继续递归
			if(file.isDirectory()){
    
    
				scannerClass(packageName + "." + file.getName());
			}else{
    
    
				classNames.add(packageName + "." + file.getName().replace(".class", "").trim());
			}
		}
	}

	/**
	 * 完成注册
	 */
	private void doRegister(){
    
    
		if(classNames.size() == 0){
    
     return; }
		for (String className : classNames) {
    
    
			try {
    
    
				Class<?> clazz = Class.forName(className);
				Class<?> i = clazz.getInterfaces()[0];
				//实际这里应该保存的是提供端所在地址IP
				registryMap.put(i.getName(), clazz.newInstance());
			} catch (Exception e) {
    
    
				e.printStackTrace();
			}
		}
	}
  
}

定义传输协议类

 * 自定义传输协议
 */
@Data
public class InvokerProtocol implements Serializable {
    
    

    private String className;//类名
    private String methodName;//函数名称 
    private Class<?>[] parames;//形参列表
    private Object[] values;//实参列表

}

开始调用

	
    public static void main(String [] args){
    
      
    		
    	//使用代理模式,判断接口为本地接口还是远程接口
    	//如果为远程接口,则手动包装自定义协议,发送网络请求
        IRpcHelloService rpcHello = RpcProxy.create(IRpcHelloService.class);
        
        System.out.println(rpcHello.hello("Tom老师"));

        IRpcService service = RpcProxy.create(IRpcService.class);
        
        System.out.println("8 + 2 = " + service.add(8, 2));
        System.out.println("8 - 2 = " + service.sub(8, 2));
        System.out.println("8 * 2 = " + service.mult(8, 2));
        System.out.println("8 / 2 = " + service.div(8, 2));
    }
    
}

远程调用发送

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;

public class RpcProxy {
    
      
	
	public static <T> T create(Class<?> clazz){
    
    
        //clazz传进来本身就是interface
        MethodProxy proxy = new MethodProxy(clazz);
        Class<?> [] interfaces = clazz.isInterface() ?
                                new Class[]{
    
    clazz} :
                                clazz.getInterfaces();
        T result = (T) Proxy.newProxyInstance(clazz.getClassLoader(),interfaces,proxy);
        return result;
    }

	private static class MethodProxy implements InvocationHandler {
    
    
		private Class<?> clazz;
		public MethodProxy(Class<?> clazz){
    
    
			this.clazz = clazz;
		}

		public Object invoke(Object proxy, Method method, Object[] args)  throws Throwable {
    
    
			//如果传进来是一个已实现的具体类(本次演示略过此逻辑)
			if (Object.class.equals(method.getDeclaringClass())) {
    
    
				try {
    
    
					return method.invoke(this, args);
				} catch (Throwable t) {
    
    
					t.printStackTrace();
				}
				//如果传进来的是一个接口(核心)
			} else {
    
    
				return rpcInvoke(proxy,method, args);
			}
			return null;
		}


		/**
		 * 实现接口的核心方法
		 * @param method
		 * @param args
		 * @return
		 */
		public Object rpcInvoke(Object proxy,Method method,Object[] args){
    
    

			//传输协议封装
			InvokerProtocol msg = new InvokerProtocol();
			msg.setClassName(this.clazz.getName());
			msg.setMethodName(method.getName());
			msg.setValues(args);
			msg.setParames(method.getParameterTypes());

			final RpcProxyHandler consumerHandler = new RpcProxyHandler();
			EventLoopGroup group = new NioEventLoopGroup();
			try {
    
    
				Bootstrap b = new Bootstrap();
				b.group(group)
						.channel(NioSocketChannel.class)
						.option(ChannelOption.TCP_NODELAY, true)
						.handler(new ChannelInitializer<SocketChannel>() {
    
    
							@Override
							public void initChannel(SocketChannel ch) throws Exception {
    
    
								ChannelPipeline pipeline = ch.pipeline();
								//自定义协议解码器
								/** 入参有5个,分别解释如下
								 maxFrameLength:框架的最大长度。如果帧的长度大于此值,则将抛出TooLongFrameException。
								 lengthFieldOffset:长度字段的偏移量:即对应的长度字段在整个消息数据中得位置
								 lengthFieldLength:长度字段的长度:如:长度字段是int型表示,那么这个值就是4(long型就是8)
								 lengthAdjustment:要添加到长度字段值的补偿值
								 initialBytesToStrip:从解码帧中去除的第一个字节数
								 */
								pipeline.addLast("frameDecoder", new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
								//自定义协议编码器
								pipeline.addLast("frameEncoder", new LengthFieldPrepender(4));
								//对象参数类型编码器
								pipeline.addLast("encoder", new ObjectEncoder());
								//对象参数类型解码器
								pipeline.addLast("decoder", new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));
								pipeline.addLast("handler",consumerHandler);
							}
						});

				ChannelFuture future = b.connect("localhost", 8080).sync();
				future.channel().writeAndFlush(msg).sync();
				future.channel().closeFuture().sync();
			} catch(Exception e){
    
    
				e.printStackTrace();
			}finally {
    
    
				group.shutdownGracefully();
			}
			return consumerHandler.getResponse();
		}

	}
}
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

public class RpcProxyHandler extends ChannelInboundHandlerAdapter {
    
      
	  
    private Object response;    
      
    public Object getResponse() {
    
        
	    return response;    
	}    
  
    @Override    
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
    
        
        response = msg;
    }    
        
    @Override    
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
    
        
        System.out.println("client exception is general");    
    }    
} 

猜你喜欢

转载自blog.csdn.net/weixin_44971379/article/details/120604735