手写rpc框架

天大地大妈咪最大 提交于 2020-02-07 03:06:05

手写rpc框架

rpc概念

rpc是什么

远程过程调用(Remote Procedure Call,缩写为 RPC)是一个计算机通信协议。该协议允许运行于一台计算机的程序调用另一台计算机的子程序,而程序员无需额外地为这个交互作用编程。如果涉及的软件采用面向对象编程,那么远程过程调用亦可称作远程调用远程方法调用

为什么要用rpc

RPC框架介于传输层和应用中间,它会帮助处理:

  • 服务化
  • 可重用
  • 系统间交互调用

rpc核心概念术语

  • Client、Server、calls、replies、service、programs、procedures、version、marshalling(编组)、unmarshalling(解组)
  • 一个网络服务由一个或多个远程程序集构成
  • 一个远程程序实现一个或多个远程过程
  • 过程、过程的参数、结果在程序协议说明书中定义说明
  • 为兼容程序协议变更,一个服务端可能支持多个版本的远程程序

rpc的流程

在这里插入图片描述

  1. 客户端处理过程中调用Client Stub(就像调用本地方法一样),传递参数;

  2. Client Stub 将参数编组为消息,然后通过系统调用向服务端发送消息;

  3. 客户端本地操作系统将消息从客户端机器发送到服务端机器;

  4. 服务端操作系统将接收到的数据包传递给Server Stub;

  5. Server Stub解组消息为参数;

  6. Server Stub再调用服务端的过程,过程执行结果以反方向的相同步骤响应给客户端。

开发rpc框架

用户使用rpc框架的步骤如下:

  1. 定义过程定义接口

  2. 服务端实现过程

  3. 客户端使用生成的stub代理对象

所以在开发rpc框架中,需实现客户端和服务端。

设计客户端

代理对象生成

首先考虑客户端如何生成过程接口的代理对象。在设计中,设计客户端代理工厂,用JDK动态代理即可生成接口的代理对象。类图如下图所示。

在这里插入图片描述

发现者

设计客户端的时候,在ClientStubInvocationHandler中需要完成的两件事为编组消息和发送网络请求,而将请求的内容编组为消息这件事就交由客户端的stub代理,它除了消息协议和网络层的事务以外,可能还存在一个服务信息发现。此外消息协议可能也是会存在变化的,我们也需要去支持多种协议。此时我们需要得知某服务用的是什么协议,所以我们需要引入一个服务发现者。

在这里插入图片描述

协议层

想要做到支持多种协议,类该如何设计(面向接口,策略模式,组合)。

在这里插入图片描述

此时又存在一些问题,单纯依靠编组和解组的方法是不够的,编组和解组的操作对象是请求、响应,但是它们的内容是不同的,此时我们又需要定义框架标准的请求、响应类。

在这里插入图片描述

此时协议层扩展为4个方法。将消息协议独立为一层,因为客户端和服务端都需要使用。

在这里插入图片描述

网络层

网络层的工作主要是发送请求和获得响应,此时我们如果需要发起网络请求必定先要知道服务地址,此时我们利用下图中serviceInfo对象作为必须依赖,setRequest()方法里面会存在发送数据,还有发送给谁。

在这里插入图片描述

实现客户端

按照之前的类图设计,进行填码。

代理对象生成

public class ClientStubProxyFactory {

	private ServiceInfoDiscoverer sid;

	private Map<String, MessageProtocol> supportMessageProtocols;

	private NetClient netClient;

	private Map<Class<?>, Object> objectCache = new HashMap<>();

	public <T> T getProxy(Class<T> interf) {
		T obj = (T) this.objectCache.get(interf);
		if (obj == null) {
			obj = (T) Proxy.newProxyInstance(interf.getClassLoader(), new Class<?>[] { interf },
					new ClientStubInvocationHandler(interf));
			this.objectCache.put(interf, obj);
		}

		return obj;
	}

	public ServiceInfoDiscoverer getSid() {
		return sid;
	}

	public void setSid(ServiceInfoDiscoverer sid) {
		this.sid = sid;
	}

	public Map<String, MessageProtocol> getSupportMessageProtocols() {
		return supportMessageProtocols;
	}

	public void setSupportMessageProtocols(Map<String, MessageProtocol> supportMessageProtocols) {
		this.supportMessageProtocols = supportMessageProtocols;
	}

	public NetClient getNetClient() {
		return netClient;
	}

	public void setNetClient(NetClient netClient) {
		this.netClient = netClient;
	}

	private class ClientStubInvocationHandler implements InvocationHandler {
		private Class<?> interf;

		private Random random = new Random();

		public ClientStubInvocationHandler(Class<?> interf) {
			super();
			this.interf = interf;
		}

		@Override
		public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {

			if (method.getName().equals("toString")) {
				return proxy.getClass().toString();
			}

			if (method.getName().equals("hashCode")) {
				return 0;
			}

			// 1、获得服务信息
			String serviceName = this.interf.getName();
			List<ServiceInfo> sinfos = sid.getServiceInfo(serviceName);

			if (sinfos == null || sinfos.size() == 0) {
				throw new Exception("远程服务不存在!");
			}

			// 随机选择一个服务提供者(软负载均衡)
			ServiceInfo sinfo = sinfos.get(random.nextInt(sinfos.size()));

			// 2、构造request对象
			Request req = new Request();
			req.setServiceName(sinfo.getName());
			req.setMethod(method.getName());
			req.setPrameterTypes(method.getParameterTypes());
			req.setParameters(args);

			// 3、协议层编组
			// 获得该方法对应的协议
			MessageProtocol protocol = supportMessageProtocols.get(sinfo.getProtocol());
			// 编组请求
			byte[] data = protocol.marshallingRequest(req);

			// 4、调用网络层发送请求
			byte[] repData = netClient.sendRequest(data, sinfo);

			// 5解组响应消息
			Response rsp = protocol.unmarshallingResponse(repData);

			// 6、结果处理
			if (rsp.getException() != null) {
				throw rsp.getException();
			}

			return rsp.getReturnValue();
		}
	}
}

发现者

public class ServiceInfo {

	private String name;

	private String protocol;

	private String address;

	public String getName() {
		return name;
	}

	public void setName(String name) {
		this.name = name;
	}

	public String getProtocol() {
		return protocol;
	}

	public void setProtocol(String protocol) {
		this.protocol = protocol;
	}

	public String getAddress() {
		return address;
	}

	public void setAddress(String address) {
		this.address = address;
	}

}

public interface ServiceInfoDiscoverer {
	List<ServiceInfo> getServiceInfo(String name);
}

zookeeper的服务发现实现如下:

public class ZookeeperServiceInfoDiscoverer implements ServiceInfoDiscoverer {

	ZkClient client;

	private String centerRootPath = "/Rpc-framework";

	public ZookeeperServiceInfoDiscoverer() {
		String addr = PropertiesUtils.getProperties("zk.address");
		client = new ZkClient(addr);
		client.setZkSerializer(new MyZkSerializer());
	}

	@Override
	public List<ServiceInfo> getServiceInfo(String name) {
		String servicePath = centerRootPath + "/" + name + "/service";
		List<String> children = client.getChildren(servicePath);
		List<ServiceInfo> resources = new ArrayList<ServiceInfo>();
		for (String ch : children) {
			try {
				String deCh = URLDecoder.decode(ch, "UTF-8");
				ServiceInfo r = JSON.parseObject(deCh, ServiceInfo.class);
				resources.add(r);
			} catch (UnsupportedEncodingException e) {
				e.printStackTrace();
			}
		}
		return resources;
	}

}

协议层

public interface MessageProtocol {

	byte[] marshallingRequest(Request req) throws Exception;

	Request unmarshallingRequest(byte[] data) throws Exception;

	byte[] marshallingResponse(Response rsp) throws Exception;

	Response unmarshallingResponse(byte[] data) throws Exception;
}
public class JavaSerializeMessageProtocol implements MessageProtocol {

	private byte[] serialize(Object obj) throws Exception {
		ByteArrayOutputStream bout = new ByteArrayOutputStream();
		ObjectOutputStream out = new ObjectOutputStream(bout);
		out.writeObject(obj);

		return bout.toByteArray();
	}

	@Override
	public byte[] marshallingRequest(Request req) throws Exception {

		return this.serialize(req);
	}

	@Override
	public Request unmarshallingRequest(byte[] data) throws Exception {
		ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data));
		return (Request) in.readObject();
	}

	@Override
	public byte[] marshallingResponse(Response rsp) throws Exception {
		return this.serialize(rsp);
	}

	@Override
	public Response unmarshallingResponse(byte[] data) throws Exception {
		ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data));
		return (Response) in.readObject();
	}

}
public class Request implements Serializable {

	/**
	 * 
	 */
	private static final long serialVersionUID = -5200571424236772650L;

	private String serviceName;

	private String method;

	private Map<String, String> headers = new HashMap<String, String>();

	private Class<?>[] prameterTypes;

	private Object[] parameters;

	public String getServiceName() {
		return serviceName;
	}

	public void setServiceName(String serviceName) {
		this.serviceName = serviceName;
	}

	public String getMethod() {
		return method;
	}

	public void setMethod(String method) {
		this.method = method;
	}

	public Map<String, String> getHeaders() {
		return headers;
	}

	public void setHeaders(Map<String, String> headers) {
		this.headers = headers;
	}

	public Class<?>[] getPrameterTypes() {
		return prameterTypes;
	}

	public void setPrameterTypes(Class<?>[] prameterTypes) {
		this.prameterTypes = prameterTypes;
	}

	public void setParameters(Object[] prameters) {
		this.parameters = prameters;
	}

	public String getHeader(String name) {
		return this.headers == null ? null : this.headers.get(name);
	}

	public Object[] getParameters() {
		return this.parameters;
	}

}

public class Response implements Serializable {

	/**
	 * 
	 */
	private static final long serialVersionUID = -4317845782629589997L;

	private Status status;

	private Map<String, String> headers = new HashMap<>();

	private Object returnValue;

	private Exception exception;

	public Response() {
	};

	public Response(Status status) {
		this.status = status;
	}

	public void setStatus(Status status) {
		this.status = status;
	}

	public void setHeaders(Map<String, String> headers) {
		this.headers = headers;
	}

	public void setReturnValue(Object returnValue) {
		this.returnValue = returnValue;
	}

	public void setException(Exception exception) {
		this.exception = exception;
	}

	public Status getStatus() {
		return status;
	}

	public Map<String, String> getHeaders() {
		return headers;
	}

	public Object getReturnValue() {
		return returnValue;
	}

	public Exception getException() {
		return exception;
	}

	public String getHeader(String name) {
		return this.headers == null ? null : this.headers.get(name);
	}

	public void setHaader(String name, String value) {
		this.headers.put(name, value);

	}
}

public enum Status {
	SUCCESS(200, "SUCCESS"), ERROR(500, "ERROR"), NOT_FOUND(404, "NOT FOUND");

	private int code;

	private String message;

	private Status(int code, String message) {
		this.code = code;
		this.message = message;
	}

	public int getCode() {
		return code;
	}

	public String getMessage() {
		return message;
	}
}

网络层

public interface NetClient {
	byte[] sendRequest(byte[] data, ServiceInfo sinfo) throws Throwable;
}
public class NettyNetClient implements NetClient {

	private static Logger logger = LoggerFactory.getLogger(NettyNetClient.class);

	@Override
	public byte[] sendRequest(byte[] data, ServiceInfo sinfo) throws Throwable {

		String[] addInfoArray = sinfo.getAddress().split(":");

		SendHandler sendHandler = new SendHandler(data);
		byte[] respData = null;
		// 配置客户端
		EventLoopGroup group = new NioEventLoopGroup();
		try {
			Bootstrap b = new Bootstrap();

			b.group(group).channel(NioSocketChannel.class).option(ChannelOption.TCP_NODELAY, true)
					.handler(new ChannelInitializer<SocketChannel>() {
						@Override
						public void initChannel(SocketChannel ch) throws Exception {
							ChannelPipeline p = ch.pipeline();
							p.addLast(sendHandler);
						}
					});

			// 启动客户端连接
			b.connect(addInfoArray[0], Integer.valueOf(addInfoArray[1])).sync();
			respData = (byte[]) sendHandler.rspData();
			logger.info("sendRequest get reply: " + respData);

		} finally {
			// 释放线程组资源
			group.shutdownGracefully();
		}

		return respData;
	}

	private class SendHandler extends ChannelInboundHandlerAdapter {

		private CountDownLatch cdl = null;
		private Object readMsg = null;
		private byte[] data;

		public SendHandler(byte[] data) {
			cdl = new CountDownLatch(1);
			this.data = data;
		}

		@Override
		public void channelActive(ChannelHandlerContext ctx) throws Exception {
			logger.info("连接服务端成功:" + ctx);
			ByteBuf reqBuf = Unpooled.buffer(data.length);
			reqBuf.writeBytes(data);
			logger.info("客户端发送消息:" + reqBuf);
			ctx.writeAndFlush(reqBuf);
		}

		public Object rspData() {

			try {
				cdl.await();
			} catch (InterruptedException e) {
				e.printStackTrace();
			}

			return readMsg;
		}

		@Override
		public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
			logger.info("client read msg: " + msg);
			ByteBuf msgBuf = (ByteBuf) msg;
			byte[] resp = new byte[msgBuf.readableBytes()];
			msgBuf.readBytes(resp);
			readMsg = resp;
			cdl.countDown();
		}

		@Override
		public void channelReadComplete(ChannelHandlerContext ctx) {
			ctx.flush();
		}

		@Override
		public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
			// Close the connection when an exception is raised.
			cause.printStackTrace();
			logger.error("发生异常:" + cause.getMessage());
			ctx.close();
		}
	}
}

设计服务端

RPCServer

客户端请求过来了,服务端首先通过RPCServer接收请求。在RPCServer中需开启网络服务。

在这里插入图片描述

RequestHandler

RPCServer接收到请求以后,将请求交给RequestHandler处理,RequestHandler调用协议层来解组请求消息为Request对象,然后调用过程。消息协议层是复用客户端设计的。

在这里插入图片描述

ServiceRegister

ServiceRegister模块实现服务注册、发布。

在这里插入图片描述

实现服务端

RPCServer

public abstract class RpcServer {

	protected int port;

	protected String protocol;

	protected RequestHandler handler;

	public RpcServer(int port, String protocol, RequestHandler handler) {
		super();
		this.port = port;
		this.protocol = protocol;
		this.handler = handler;
	}

	/**
	 * 开启服务
	 */
	public abstract void start();

	/**
	 * 停止服务
	 */
	public abstract void stop();

	public int getPort() {
		return port;
	}

	public void setPort(int port) {
		this.port = port;
	}

	public String getProtocol() {
		return protocol;
	}

	public void setProtocol(String protocol) {
		this.protocol = protocol;
	}

	public RequestHandler getHandler() {
		return handler;
	}

	public void setHandler(RequestHandler handler) {
		this.handler = handler;
	}

}
public class NettyRpcServer extends RpcServer {
	private static Logger logger = LoggerFactory.getLogger(NettyRpcServer.class);

	private Channel channel;

	public NettyRpcServer(int port, String protocol, RequestHandler handler) {
		super(port, protocol, handler);
	}

	@Override
	public void start() {
		// 配置服务器
		EventLoopGroup bossGroup = new NioEventLoopGroup(1);
		EventLoopGroup workerGroup = new NioEventLoopGroup();
		try {
			ServerBootstrap b = new ServerBootstrap();
			b.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class).option(ChannelOption.SO_BACKLOG, 100)
					.handler(new LoggingHandler(LogLevel.INFO)).childHandler(new ChannelInitializer<SocketChannel>() {
						@Override
						public void initChannel(SocketChannel ch) throws Exception {
							ChannelPipeline p = ch.pipeline();
							p.addLast(new ChannelRequestHandler());
						}
					});

			// 启动服务
			ChannelFuture f = b.bind(port).sync();
			logger.info("完成服务端端口绑定与启动");
			channel = f.channel();
			// 等待服务通道关闭
			f.channel().closeFuture().sync();
		} catch (Exception e) {
			e.printStackTrace();
		} finally {
			// 释放线程组资源
			bossGroup.shutdownGracefully();
			workerGroup.shutdownGracefully();
		}
	}

	@Override
	public void stop() {
		this.channel.close();
	}

	private class ChannelRequestHandler extends ChannelInboundHandlerAdapter {

		@Override
		public void channelActive(ChannelHandlerContext ctx) throws Exception {
			logger.info("激活");
		}

		@Override
		public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
			logger.info("服务端收到消息:" + msg);
			ByteBuf msgBuf = (ByteBuf) msg;
			byte[] req = new byte[msgBuf.readableBytes()];
			msgBuf.readBytes(req);
			byte[] res = handler.handleRequest(req);
			logger.info("发送响应:" + msg);
			ByteBuf respBuf = Unpooled.buffer(res.length);
			respBuf.writeBytes(res);
			ctx.write(respBuf);
		}

		@Override
		public void channelReadComplete(ChannelHandlerContext ctx) {
			ctx.flush();
		}

		@Override
		public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
			// Close the connection when an exception is raised.
			cause.printStackTrace();
			logger.error("发生异常:" + cause.getMessage());
			ctx.close();
		}
	}

}

RequestHandler

public class RequestHandler {
	private MessageProtocol protocol;

	private ServiceRegister serviceRegister;

	public RequestHandler(MessageProtocol protocol, ServiceRegister serviceRegister) {
		super();
		this.protocol = protocol;
		this.serviceRegister = serviceRegister;
	}

	public byte[] handleRequest(byte[] data) throws Exception {
		// 1、解组消息
		Request req = this.protocol.unmarshallingRequest(data);

		// 2、查找服务对象
		ServiceObject so = this.serviceRegister.getServiceObject(req.getServiceName());

		Response rsp = null;

		if (so == null) {
			rsp = new Response(Status.NOT_FOUND);
		} else {
			// 3、反射调用对应的过程方法
			try {
				Method m = so.getInterf().getMethod(req.getMethod(), req.getPrameterTypes());
				Object returnValue = m.invoke(so.getObj(), req.getParameters());
				rsp = new Response(Status.SUCCESS);
				rsp.setReturnValue(returnValue);
			} catch (NoSuchMethodException | SecurityException | IllegalAccessException | IllegalArgumentException
					| InvocationTargetException e) {
				rsp = new Response(Status.ERROR);
				rsp.setException(e);
			}
		}

		// 4、编组响应消息
		return this.protocol.marshallingResponse(rsp);
	}

	public MessageProtocol getProtocol() {
		return protocol;
	}

	public void setProtocol(MessageProtocol protocol) {
		this.protocol = protocol;
	}

	public ServiceRegister getServiceRegister() {
		return serviceRegister;
	}

	public void setServiceRegister(ServiceRegister serviceRegister) {
		this.serviceRegister = serviceRegister;
	}

}

ServiceRegister

public interface ServiceRegister {

	void register(ServiceObject so, String protocol, int port) throws Exception;

	ServiceObject getServiceObject(String name) throws Exception;
}
public class ServiceObject {

	private String name;

	private Class<?> interf;

	private Object obj;

	public ServiceObject(String name, Class<?> interf, Object obj) {
		super();
		this.name = name;
		this.interf = interf;
		this.obj = obj;
	}

	public String getName() {
		return name;
	}

	public void setName(String name) {
		this.name = name;
	}

	public Class<?> getInterf() {
		return interf;
	}

	public void setInterf(Class<?> interf) {
		this.interf = interf;
	}

	public Object getObj() {
		return obj;
	}

	public void setObj(Object obj) {
		this.obj = obj;
	}

}
public class DefaultServiceRegister implements ServiceRegister {

	private Map<String, ServiceObject> serviceMap = new HashMap<>();

	@Override
	public void register(ServiceObject so, String protocolName, int port) throws Exception {
		if (so == null) {
			throw new IllegalArgumentException("参数不能为空");
		}

		this.serviceMap.put(so.getName(), so);
	}

	@Override
	public ServiceObject getServiceObject(String name) {
		return this.serviceMap.get(name);
	}

}
/**
 * Zookeeper方式获取远程服务信息类。
 * 
 * ZookeeperServiceInfoDiscoverer
 */
public class ZookeeperExportServiceRegister extends DefaultServiceRegister implements ServiceRegister {

	private ZkClient client;

	private String centerRootPath = "/Rpc-framework";

	public ZookeeperExportServiceRegister() {
		String addr = PropertiesUtils.getProperties("zk.address");
		client = new ZkClient(addr);
		client.setZkSerializer(new MyZkSerializer());
	}

	@Override
	public void register(ServiceObject so, String protocolName, int port) throws Exception {
		super.register(so, protocolName, port);
		ServiceInfo soInf = new ServiceInfo();

		String host = InetAddress.getLocalHost().getHostAddress();
		String address = host + ":" + port;
		soInf.setAddress(address);
		soInf.setName(so.getInterf().getName());
		soInf.setProtocol(protocolName);
		this.exportService(soInf);

	}

	private void exportService(ServiceInfo serviceResource) {
		String serviceName = serviceResource.getName();
		String uri = JSON.toJSONString(serviceResource);
		try {
			uri = URLEncoder.encode(uri, "UTF-8");
		} catch (UnsupportedEncodingException e) {
			e.printStackTrace();
		}
		String servicePath = centerRootPath + "/" + serviceName + "/service";
		if (!client.exists(servicePath)) {
			client.createPersistent(servicePath, true);
		}
		String uriPath = servicePath + "/" + uri;
		if (client.exists(uriPath)) {
			client.delete(uriPath);
		}
		client.createEphemeral(uriPath);
	}
}

实现高并发 RPC 框架的要素

实现高并发 RPC 框架的要素,总结起来有三个要点:

  1. 选择高性能的 I/O 模型,这里推荐使用同步多路 I/O 复用模型;
  2. 调试网络参数,这里面有一些经验值的推荐。比如将 tcp_nodelay 设置为 true,也有一些参数需要在运行中来调试,比如接受缓冲区和发送缓冲区的大小,客户端连接请求缓冲队列的大小(back log)等等;
  3. 序列化协议依据具体业务来选择。如果对性能要求不高可以选择 JSON,否则可以从 Thrift 和 Protobuf 中选择其一。
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!