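Writing an RPC Framework by Hand
RPC concepts
What is RPC
Remote Procedure Call (RPC) is a computer communication protocol. It allows a program running on one computer to call a subroutine on another computer without the programmer having to write extra code for the interaction. When the software involved is object-oriented, a remote procedure call may also be called remote invocation or remote method invocation.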
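Why use RPC
An RPC framework sits between the transport layer and the application, and helps with:
- Exposing functionality as services
- Reuse
- Calls between systems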
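Core RPC concepts and terms
- Client, Server, calls, replies, service, programs, procedures, version, marshalling, unmarshalling
- A network service consists of one or more remote programs
- A remote program implements one or more remote procedures
- Procedures, their parameters, and their results are defined in the program's protocol specification
- To stay compatible when the program protocol changes, a server may support multiple versions of a remote program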
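The RPC flow
- The client code calls the Client Stub (just as it would call a local method) and passes the arguments;
- The Client Stub marshals the arguments into a message, then sends the message to the server through a system call;
- The client's operating system sends the message from the client machine to the server machine;
- The server's operating system passes the received packets to the Server Stub;
- The Server Stub unmarshals the message into arguments;
- The Server Stub calls the server-side procedure; the result travels back to the client through the same steps in reverse.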
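The steps above can be sketched inside a single process. This is only an illustration under stated assumptions, not the framework built later in this article: `Greeter`, `RpcFlowSketch`, and the in-memory `transport` function are hypothetical names, Java serialization stands in for the message protocol, and a plain function call stands in for the two operating systems and the network.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.function.Function;

// Hypothetical remote interface used only for this illustration.
interface Greeter {
    String greet(String name);
}

class RpcFlowSketch {
    // Marshalling: turn a serializable value into bytes.
    static byte[] marshal(Object value) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(value);
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
        return bytes.toByteArray();
    }

    // Unmarshalling: turn bytes back into a value.
    static Object unmarshal(byte[] data) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    // Steps 1-2: the client stub looks like a local Greeter but marshals the call.
    static Greeter clientStub(Function<byte[], byte[]> transport) {
        return name -> {
            byte[] request = marshal(new String[] {"greet", name});
            byte[] reply = transport.apply(request); // steps 3-4 happen inside the transport
            return (String) unmarshal(reply);
        };
    }

    // Steps 5-6: the server stub unmarshals, calls the procedure, and marshals the result.
    static byte[] serverStub(Greeter implementation, byte[] request) {
        String[] call = (String[]) unmarshal(request);
        if (!"greet".equals(call[0])) {
            throw new IllegalArgumentException("unknown procedure: " + call[0]);
        }
        return marshal(implementation.greet(call[1]));
    }

    static String demo() {
        Greeter implementation = name -> "hello, " + name;
        // The "network" here is just a function that hands the bytes to the server stub.
        Greeter stub = clientStub(request -> serverStub(implementation, request));
        return stub.greet("rpc");
    }
}
```

The caller of `stub.greet(...)` never sees the byte arrays, which is the point of the stub: the marshalling and transport steps stay hidden behind an ordinary interface.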
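Developing the RPC framework
A user of the RPC framework follows these steps:
- Define the procedure interface
- Implement the procedure on the server
- Use the generated stub proxy object on the client
So developing the RPC framework means implementing both a client side and a server side.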
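Designing the client
Generating proxy objects
First, consider how the client generates a proxy object for the procedure interface. The design introduces a client proxy factory that uses JDK dynamic proxies to generate the proxy for an interface. The class diagram is shown in the figure below.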
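A minimal sketch of such a proxy factory, assuming nothing beyond the JDK, is given here. `EchoService` and `ProxyFactorySketch` are hypothetical names, and the handler returns a canned string where a real stub would marshal the call and send it over the network.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical procedure interface used only for this illustration.
interface EchoService {
    String echo(String text);
}

class ProxyFactorySketch {
    // One proxy per interface; ConcurrentHashMap keeps the cache safe across threads.
    private final Map<Class<?>, Object> cache = new ConcurrentHashMap<>();

    <T> T getProxy(Class<T> interf) {
        Object proxy = cache.computeIfAbsent(interf, ProxyFactorySketch::createProxy);
        return interf.cast(proxy);
    }

    private static Object createProxy(Class<?> interf) {
        InvocationHandler handler = (self, method, args) -> {
            if (method.getDeclaringClass() == Object.class) {
                // Answer toString/hashCode/equals locally instead of treating them as remote calls.
                switch (method.getName()) {
                    case "toString": return "stub for " + interf.getName();
                    case "hashCode": return System.identityHashCode(self);
                    default: return self == args[0];
                }
            }
            // A real stub would marshal the call and send it over the network here.
            return "stubbed " + method.getName() + Arrays.toString(args);
        };
        return Proxy.newProxyInstance(interf.getClassLoader(), new Class<?>[] {interf}, handler);
    }
}
```

Calling `getProxy(EchoService.class)` twice returns the same instance, and every interface method lands in the single handler, which is where the marshalling and network calls belong.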
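Discoverer
On the client, ClientStubInvocationHandler has two jobs: marshalling the request into a message and sending it over the network. Marshalling is the responsibility of the client stub proxy. Besides the message protocol and the network layer, the stub also needs information about the service. The message protocol may vary, and the framework should support several protocols, so the client has to find out which protocol a given service uses. For that we introduce a service discoverer.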
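The role of the discoverer can be sketched with an in-memory registry. All names here (`ServiceEndpoint`, `EndpointDiscoverer`, `InMemoryDiscoverer`, `EndpointSelector`) are hypothetical stand-ins chosen for illustration, and the random pick is only one possible selection policy.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

// Hypothetical, trimmed-down description of one service provider.
class ServiceEndpoint {
    final String name;
    final String protocol;
    final String address;

    ServiceEndpoint(String name, String protocol, String address) {
        this.name = name;
        this.protocol = protocol;
        this.address = address;
    }
}

interface EndpointDiscoverer {
    List<ServiceEndpoint> find(String serviceName);
}

// In-memory stand-in for a registry such as ZooKeeper.
class InMemoryDiscoverer implements EndpointDiscoverer {
    private final Map<String, List<ServiceEndpoint>> providers = new HashMap<>();

    void add(ServiceEndpoint endpoint) {
        providers.computeIfAbsent(endpoint.name, key -> new ArrayList<>()).add(endpoint);
    }

    @Override
    public List<ServiceEndpoint> find(String serviceName) {
        return providers.getOrDefault(serviceName, Collections.emptyList());
    }
}

class EndpointSelector {
    private final Random random = new Random();

    // Pick one provider at random; the chosen endpoint also tells the stub which protocol to use.
    ServiceEndpoint pick(EndpointDiscoverer discoverer, String serviceName) {
        List<ServiceEndpoint> candidates = discoverer.find(serviceName);
        if (candidates.isEmpty()) {
            throw new IllegalStateException("no provider found for " + serviceName);
        }
        return candidates.get(random.nextInt(candidates.size()));
    }
}
```

Because the stub depends only on the `EndpointDiscoverer` interface, the in-memory registry can be swapped for a registry-backed one without touching the stub.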
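Protocol layer
To support multiple protocols, how should the classes be designed? Program to an interface, apply the strategy pattern, and use composition.
A pair of marshal and unmarshal methods is not enough, though. Marshalling and unmarshalling operate on requests and responses, whose contents differ, so the framework also needs its own standard request and response classes.
The protocol layer therefore grows to four methods. The message protocol is split out into its own layer because both the client and the server use it.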
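As a rough sketch of this design, the following example shows a four-method strategy interface, one concrete strategy, and a registry that selects a strategy by name. The types `Req`, `Rsp`, `WireProtocol`, `DelimitedTextProtocol`, and `ProtocolRegistry` are hypothetical, and the delimited text format is invented purely to keep the example short.

```java
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;

// Hypothetical, simplified request and response types.
class Req {
    final String method;
    final String arg;
    Req(String method, String arg) { this.method = method; this.arg = arg; }
}

class Rsp {
    final String value;
    Rsp(String value) { this.value = value; }
}

// The strategy interface: two methods per message type, four in total.
interface WireProtocol {
    byte[] marshalRequest(Req req);
    Req unmarshalRequest(byte[] data);
    byte[] marshalResponse(Rsp rsp);
    Rsp unmarshalResponse(byte[] data);
}

// One concrete strategy: "method<delimiter>arg" as UTF-8 text.
// Assumes the method name never contains the delimiter.
class DelimitedTextProtocol implements WireProtocol {
    private final String delimiter;

    DelimitedTextProtocol(String delimiter) { this.delimiter = delimiter; }

    @Override
    public byte[] marshalRequest(Req req) {
        return (req.method + delimiter + req.arg).getBytes(StandardCharsets.UTF_8);
    }

    @Override
    public Req unmarshalRequest(byte[] data) {
        String text = new String(data, StandardCharsets.UTF_8);
        String[] parts = text.split(Pattern.quote(delimiter), 2); // split on the first delimiter only
        return new Req(parts[0], parts[1]);
    }

    @Override
    public byte[] marshalResponse(Rsp rsp) {
        return rsp.value.getBytes(StandardCharsets.UTF_8);
    }

    @Override
    public Rsp unmarshalResponse(byte[] data) {
        return new Rsp(new String(data, StandardCharsets.UTF_8));
    }
}

// Composition: callers hold a registry and select a strategy by protocol name.
class ProtocolRegistry {
    private final Map<String, WireProtocol> byName = new HashMap<>();

    void register(String name, WireProtocol protocol) { byName.put(name, protocol); }

    WireProtocol lookup(String name) {
        WireProtocol protocol = byName.get(name);
        if (protocol == null) {
            throw new IllegalArgumentException("unsupported protocol: " + name);
        }
        return protocol;
    }
}
```

Adding a protocol then means writing one more `WireProtocol` implementation and registering it under a new name; the client and server code that calls `lookup` stays unchanged.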
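Network layer
The network layer's job is to send requests and receive responses. To send a request we must first know the service address, so the ServiceInfo object shown in the figure below is a required dependency: the sendRequest() method takes both the data to send and the target to send it to.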
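To make the "what to send and to whom" contract concrete, here is a small sketch using plain blocking sockets from the JDK rather than the Netty client implemented later in this article. `FramingSketch` and its methods are hypothetical names, and the 4-byte length prefix is an assumption added here so that the receiver knows where a message ends; the article's own code does not define a framing scheme.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;

class FramingSketch {
    // Split "host:port" into its parts without triggering a DNS lookup.
    static InetSocketAddress parseAddress(String address) {
        int colon = address.lastIndexOf(':');
        if (colon < 0) {
            throw new IllegalArgumentException("expected host:port but got " + address);
        }
        String host = address.substring(0, colon);
        int port = Integer.parseInt(address.substring(colon + 1));
        return InetSocketAddress.createUnresolved(host, port);
    }

    // Write a 4-byte length followed by the payload.
    static void writeFrame(OutputStream out, byte[] payload) throws IOException {
        DataOutputStream data = new DataOutputStream(out);
        data.writeInt(payload.length);
        data.write(payload);
        data.flush();
    }

    // Read exactly one frame, however the bytes were split in transit.
    static byte[] readFrame(InputStream in) throws IOException {
        DataInputStream data = new DataInputStream(in);
        byte[] payload = new byte[data.readInt()];
        data.readFully(payload);
        return payload;
    }

    // Blocking send and receive over a plain socket; it needs a server that speaks the same framing.
    static byte[] sendRequest(byte[] request, String address) throws IOException {
        InetSocketAddress target = parseAddress(address);
        try (Socket socket = new Socket(target.getHostString(), target.getPort())) {
            writeFrame(socket.getOutputStream(), request);
            return readFrame(socket.getInputStream());
        }
    }

    // In-memory round trip, so the framing can be checked without a server.
    static byte[] roundTrip(byte[] payload) {
        try {
            ByteArrayOutputStream wire = new ByteArrayOutputStream();
            writeFrame(wire, payload);
            return readFrame(new ByteArrayInputStream(wire.toByteArray()));
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

TCP delivers a byte stream rather than discrete messages, so some framing rule of this kind is usually needed once requests grow beyond what a single read returns.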
Implementing the client
Following the class diagrams above, fill in the code.
Generating proxy objects
public class ClientStubProxyFactory {
private ServiceInfoDiscoverer sid;
private Map<String, MessageProtocol> supportMessageProtocols;
private NetClient netClient;
private Map<Class<?>, Object> objectCache = new HashMap<>();
public <T> T getProxy(Class<T> interf) {
T obj = (T) this.objectCache.get(interf);
if (obj == null) {
obj = (T) Proxy.newProxyInstance(interf.getClassLoader(), new Class<?>[] { interf },
new ClientStubInvocationHandler(interf));
this.objectCache.put(interf, obj);
}
return obj;
}
public ServiceInfoDiscoverer getSid() {
return sid;
}
public void setSid(ServiceInfoDiscoverer sid) {
this.sid = sid;
}
public Map<String, MessageProtocol> getSupportMessageProtocols() {
return supportMessageProtocols;
}
public void setSupportMessageProtocols(Map<String, MessageProtocol> supportMessageProtocols) {
this.supportMessageProtocols = supportMessageProtocols;
}
public NetClient getNetClient() {
return netClient;
}
public void setNetClient(NetClient netClient) {
this.netClient = netClient;
}
private class ClientStubInvocationHandler implements InvocationHandler {
private Class<?> interf;
private Random random = new Random();
public ClientStubInvocationHandler(Class<?> interf) {
super();
this.interf = interf;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
if (method.getName().equals("toString")) {
return proxy.getClass().toString();
}
if (method.getName().equals("hashCode")) {
return 0;
}
// 1. Look up the service information
String serviceName = this.interf.getName();
List<ServiceInfo> sinfos = sid.getServiceInfo(serviceName);
if (sinfos == null || sinfos.size() == 0) {
throw new Exception("Remote service does not exist!");
}
// Pick a service provider at random (soft load balancing)
ServiceInfo sinfo = sinfos.get(random.nextInt(sinfos.size()));
// 2. Build the Request object
Request req = new Request();
req.setServiceName(sinfo.getName());
req.setMethod(method.getName());
req.setPrameterTypes(method.getParameterTypes());
req.setParameters(args);
// 3. Marshal through the protocol layer
// Look up the protocol this service uses
MessageProtocol protocol = supportMessageProtocols.get(sinfo.getProtocol());
if (protocol == null) {
throw new Exception("Unsupported message protocol: " + sinfo.getProtocol());
}
// Marshal the request
byte[] data = protocol.marshallingRequest(req);
// 4. Send the request through the network layer
byte[] repData = netClient.sendRequest(data, sinfo);
// 5. Unmarshal the response message
Response rsp = protocol.unmarshallingResponse(repData);
// 6. Handle the result
if (rsp.getException() != null) {
throw rsp.getException();
}
return rsp.getReturnValue();
}
}
}
Discoverer
public class ServiceInfo {
private String name;
private String protocol;
private String address;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getProtocol() {
return protocol;
}
public void setProtocol(String protocol) {
this.protocol = protocol;
}
public String getAddress() {
return address;
}
public void setAddress(String address) {
this.address = address;
}
}
public interface ServiceInfoDiscoverer {
List<ServiceInfo> getServiceInfo(String name);
}
The ZooKeeper-based service discovery implementation is as follows:
public class ZookeeperServiceInfoDiscoverer implements ServiceInfoDiscoverer {
ZkClient client;
private String centerRootPath = "/Rpc-framework";
public ZookeeperServiceInfoDiscoverer() {
String addr = PropertiesUtils.getProperties("zk.address");
client = new ZkClient(addr);
client.setZkSerializer(new MyZkSerializer());
}
@Override
public List<ServiceInfo> getServiceInfo(String name) {
String servicePath = centerRootPath + "/" + name + "/service";
List<String> children = client.getChildren(servicePath);
List<ServiceInfo> resources = new ArrayList<ServiceInfo>();
for (String ch : children) {
try {
String deCh = URLDecoder.decode(ch, "UTF-8");
ServiceInfo r = JSON.parseObject(deCh, ServiceInfo.class);
resources.add(r);
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
}
}
return resources;
}
}
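The discoverer above lists the child nodes under /Rpc-framework/<service name>/service and URL-decodes each child name before parsing it as JSON. The path layout and the encode/decode round trip can be tried without ZooKeeper. In this sketch `NodeNameSketch` is a hypothetical helper, and the JSON passed to it is expected to be a hand-written string rather than the output of a JSON library.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.net.URLEncoder;

class NodeNameSketch {
    static final String ROOT = "/Rpc-framework";

    // Parent node under which every provider of a service is listed.
    static String servicePath(String serviceName) {
        return ROOT + "/" + serviceName + "/service";
    }

    // URL-encode the JSON so that it can be used as a single node name.
    static String encode(String json) {
        try {
            return URLEncoder.encode(json, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException(e);
        }
    }

    // Reverse of encode(): recover the JSON from a child node name.
    static String decode(String nodeName) {
        try {
            return URLDecoder.decode(nodeName, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException(e);
        }
    }

    // Full path of one provider node.
    static String providerPath(String serviceName, String json) {
        return servicePath(serviceName) + "/" + encode(json);
    }
}
```

Encoding matters because "/" separates path components in ZooKeeper, so a raw JSON value containing a slash could not be used as a node name.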
Protocol layer
public interface MessageProtocol {
byte[] marshallingRequest(Request req) throws Exception;
Request unmarshallingRequest(byte[] data) throws Exception;
byte[] marshallingResponse(Response rsp) throws Exception;
Response unmarshallingResponse(byte[] data) throws Exception;
}
public class JavaSerializeMessageProtocol implements MessageProtocol {
private byte[] serialize(Object obj) throws Exception {
ByteArrayOutputStream bout = new ByteArrayOutputStream();
// try-with-resources closes (and therefore flushes) the object stream before the bytes are read
try (ObjectOutputStream out = new ObjectOutputStream(bout)) {
out.writeObject(obj);
}
return bout.toByteArray();
}
@Override
public byte[] marshallingRequest(Request req) throws Exception {
return this.serialize(req);
}
@Override
public Request unmarshallingRequest(byte[] data) throws Exception {
ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data));
return (Request) in.readObject();
}
@Override
public byte[] marshallingResponse(Response rsp) throws Exception {
return this.serialize(rsp);
}
@Override
public Response unmarshallingResponse(byte[] data) throws Exception {
ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data));
return (Response) in.readObject();
}
}
public class Request implements Serializable {
private static final long serialVersionUID = -5200571424236772650L;
private String serviceName;
private String method;
private Map<String, String> headers = new HashMap<String, String>();
private Class<?>[] prameterTypes;
private Object[] parameters;
public String getServiceName() {
return serviceName;
}
public void setServiceName(String serviceName) {
this.serviceName = serviceName;
}
public String getMethod() {
return method;
}
public void setMethod(String method) {
this.method = method;
}
public Map<String, String> getHeaders() {
return headers;
}
public void setHeaders(Map<String, String> headers) {
this.headers = headers;
}
public Class<?>[] getPrameterTypes() {
return prameterTypes;
}
public void setPrameterTypes(Class<?>[] prameterTypes) {
this.prameterTypes = prameterTypes;
}
public void setParameters(Object[] parameters) {
this.parameters = parameters;
}
public String getHeader(String name) {
return this.headers == null ? null : this.headers.get(name);
}
public Object[] getParameters() {
return this.parameters;
}
}
public class Response implements Serializable {
private static final long serialVersionUID = -4317845782629589997L;
private Status status;
private Map<String, String> headers = new HashMap<>();
private Object returnValue;
private Exception exception;
public Response() {
}
public Response(Status status) {
this.status = status;
}
public void setStatus(Status status) {
this.status = status;
}
public void setHeaders(Map<String, String> headers) {
this.headers = headers;
}
public void setReturnValue(Object returnValue) {
this.returnValue = returnValue;
}
public void setException(Exception exception) {
this.exception = exception;
}
public Status getStatus() {
return status;
}
public Map<String, String> getHeaders() {
return headers;
}
public Object getReturnValue() {
return returnValue;
}
public Exception getException() {
return exception;
}
public String getHeader(String name) {
return this.headers == null ? null : this.headers.get(name);
}
public void setHeader(String name, String value) {
this.headers.put(name, value);
}
}
public enum Status {
SUCCESS(200, "SUCCESS"), ERROR(500, "ERROR"), NOT_FOUND(404, "NOT FOUND");
private int code;
private String message;
private Status(int code, String message) {
this.code = code;
this.message = message;
}
public int getCode() {
return code;
}
public String getMessage() {
return message;
}
}
Network layer
public interface NetClient {
byte[] sendRequest(byte[] data, ServiceInfo sinfo) throws Throwable;
}
public class NettyNetClient implements NetClient {
private static Logger logger = LoggerFactory.getLogger(NettyNetClient.class);
@Override
public byte[] sendRequest(byte[] data, ServiceInfo sinfo) throws Throwable {
String[] addInfoArray = sinfo.getAddress().split(":");
SendHandler sendHandler = new SendHandler(data);
byte[] respData = null;
// Configure the client
EventLoopGroup group = new NioEventLoopGroup();
try {
Bootstrap b = new Bootstrap();
b.group(group).channel(NioSocketChannel.class).option(ChannelOption.TCP_NODELAY, true)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(sendHandler);
}
});
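// Connect to the server
b.connect(addInfoArray[0], Integer.parseInt(addInfoArray[1])).sync();
// Block until the handler has received the response
respData = (byte[]) sendHandler.rspData();
logger.info("sendRequest got reply: " + (respData == null ? 0 : respData.length) + " bytes");
} finally {
// Release the event loop threads
group.shutdownGracefully();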
}
return respData;
}
private class SendHandler extends ChannelInboundHandlerAdapter {
private CountDownLatch cdl = null;
private Object readMsg = null;
private byte[] data;
public SendHandler(byte[] data) {
cdl = new CountDownLatch(1);
this.data = data;
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
logger.info("Connected to server: " + ctx);
ByteBuf reqBuf = Unpooled.buffer(data.length);
reqBuf.writeBytes(data);
logger.info("Client sending message: " + reqBuf);
ctx.writeAndFlush(reqBuf);
}
public Object rspData() {
try {
cdl.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
return readMsg;
}
@Override
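public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
logger.info("client read msg: " + msg);
// Note: this assumes the whole response arrives in a single read
ByteBuf msgBuf = (ByteBuf) msg;
try {
byte[] resp = new byte[msgBuf.readableBytes()];
msgBuf.readBytes(resp);
readMsg = resp;
} finally {
// Release the inbound buffer (this is the last handler in the pipeline), then wake the waiting caller
msgBuf.release();
cdl.countDown();
}
}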
@Override
public void channelReadComplete(ChannelHandlerContext ctx) {
ctx.flush();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
// Close the connection when an exception is raised.
cause.printStackTrace();
logger.error("Exception caught: " + cause.getMessage());
ctx.close();
// Wake the caller blocked in rspData() so that it does not wait forever
cdl.countDown();
}
}
}
Designing the server
RPCServer
When a client request arrives, the server first receives it through RpcServer, which is responsible for starting the network service.
RequestHandler
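After RpcServer receives a request, it hands the request to RequestHandler. RequestHandler calls the protocol layer to unmarshal the message into a Request object and then invokes the procedure. The message protocol layer is reused from the client design.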
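The "invoke the procedure" step is typically done with reflection. Here is a minimal sketch of that step alone, with marshalling left out; `Calculator` and `DispatchSketch` are hypothetical names, and keying the services by interface name is an assumption made for the example.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

// Hypothetical procedure interface used only for this illustration.
interface Calculator {
    int add(int a, int b);
}

class DispatchSketch {
    private final Map<String, Class<?>> interfaces = new HashMap<>();
    private final Map<String, Object> implementations = new HashMap<>();

    // Register an implementation under the name of the interface it serves.
    <T> void register(Class<T> interf, T implementation) {
        interfaces.put(interf.getName(), interf);
        implementations.put(interf.getName(), implementation);
    }

    // Find the target object, resolve the method on the interface, and invoke it.
    Object dispatch(String serviceName, String methodName, Class<?>[] parameterTypes, Object[] args)
            throws Exception {
        Object target = implementations.get(serviceName);
        if (target == null) {
            throw new IllegalArgumentException("service not found: " + serviceName);
        }
        Method method = interfaces.get(serviceName).getMethod(methodName, parameterTypes);
        try {
            return method.invoke(target, args);
        } catch (InvocationTargetException e) {
            // Surface the procedure's own exception instead of the reflection wrapper.
            Throwable cause = e.getCause();
            throw cause instanceof Exception ? (Exception) cause : e;
        }
    }

    static Object demo() {
        DispatchSketch dispatcher = new DispatchSketch();
        dispatcher.register(Calculator.class, (a, b) -> a + b);
        try {
            return dispatcher.dispatch(Calculator.class.getName(), "add",
                    new Class<?>[] {int.class, int.class}, new Object[] {2, 3});
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Resolving the method on the interface rather than on the implementation class keeps the lookup tied to the published contract, which is all the client knows about.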
ServiceRegister
The ServiceRegister module implements service registration and publishing.
Implementing the server
RPCServer
public abstract class RpcServer {
protected int port;
protected String protocol;
protected RequestHandler handler;
public RpcServer(int port, String protocol, RequestHandler handler) {
super();
this.port = port;
this.protocol = protocol;
this.handler = handler;
}
/**
* Start the server
*/
public abstract void start();
/**
* Stop the server
*/
public abstract void stop();
public int getPort() {
return port;
}
public void setPort(int port) {
this.port = port;
}
public String getProtocol() {
return protocol;
}
public void setProtocol(String protocol) {
this.protocol = protocol;
}
public RequestHandler getHandler() {
return handler;
}
public void setHandler(RequestHandler handler) {
this.handler = handler;
}
}
public class NettyRpcServer extends RpcServer {
private static Logger logger = LoggerFactory.getLogger(NettyRpcServer.class);
private Channel channel;
public NettyRpcServer(int port, String protocol, RequestHandler handler) {
super(port, protocol, handler);
}
@Override
public void start() {
// Configure the server
EventLoopGroup bossGroup = new NioEventLoopGroup(1);
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
ServerBootstrap b = new ServerBootstrap();
b.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class).option(ChannelOption.SO_BACKLOG, 100)
.handler(new LoggingHandler(LogLevel.INFO)).childHandler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(new ChannelRequestHandler());
}
});
// Bind the port and start the server
ChannelFuture f = b.bind(port).sync();
logger.info("Server port bound and server started");
channel = f.channel();
// Wait until the server channel is closed
f.channel().closeFuture().sync();
} catch (Exception e) {
e.printStackTrace();
} finally {
// Release the event loop threads
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}
}
@Override
public void stop() {
this.channel.close();
}
private class ChannelRequestHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
logger.info("Channel active");
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
logger.info("Server received message: " + msg);
ByteBuf msgBuf = (ByteBuf) msg;
byte[] req;
try {
req = new byte[msgBuf.readableBytes()];
msgBuf.readBytes(req);
} finally {
// Release the inbound buffer once its bytes have been copied out
msgBuf.release();
}
byte[] res = handler.handleRequest(req);
logger.info("Sending response: " + res.length + " bytes");
ByteBuf respBuf = Unpooled.buffer(res.length);
respBuf.writeBytes(res);
ctx.write(respBuf);
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) {
ctx.flush();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
// Close the connection when an exception is raised.
cause.printStackTrace();
logger.error("Exception caught: " + cause.getMessage());
ctx.close();
}
}
}
RequestHandler
public class RequestHandler {
private MessageProtocol protocol;
private ServiceRegister serviceRegister;
public RequestHandler(MessageProtocol protocol, ServiceRegister serviceRegister) {
super();
this.protocol = protocol;
this.serviceRegister = serviceRegister;
}
public byte[] handleRequest(byte[] data) throws Exception {
// 1. Unmarshal the message
Request req = this.protocol.unmarshallingRequest(data);
// 2. Look up the service object
ServiceObject so = this.serviceRegister.getServiceObject(req.getServiceName());
Response rsp = null;
if (so == null) {
rsp = new Response(Status.NOT_FOUND);
} else {
// 3. Invoke the target procedure through reflection
try {
Method m = so.getInterf().getMethod(req.getMethod(), req.getPrameterTypes());
Object returnValue = m.invoke(so.getObj(), req.getParameters());
rsp = new Response(Status.SUCCESS);
rsp.setReturnValue(returnValue);
} catch (InvocationTargetException e) {
// The procedure itself threw: report its own exception rather than the reflection wrapper
rsp = new Response(Status.ERROR);
Throwable cause = e.getCause();
rsp.setException(cause instanceof Exception ? (Exception) cause : e);
} catch (NoSuchMethodException | SecurityException | IllegalAccessException | IllegalArgumentException e) {
rsp = new Response(Status.ERROR);
rsp.setException(e);
}
}
// 4. Marshal the response message
return this.protocol.marshallingResponse(rsp);
}
public MessageProtocol getProtocol() {
return protocol;
}
public void setProtocol(MessageProtocol protocol) {
this.protocol = protocol;
}
public ServiceRegister getServiceRegister() {
return serviceRegister;
}
public void setServiceRegister(ServiceRegister serviceRegister) {
this.serviceRegister = serviceRegister;
}
}
ServiceRegister
public interface ServiceRegister {
void register(ServiceObject so, String protocol, int port) throws Exception;
ServiceObject getServiceObject(String name) throws Exception;
}
public class ServiceObject {
private String name;
private Class<?> interf;
private Object obj;
public ServiceObject(String name, Class<?> interf, Object obj) {
super();
this.name = name;
this.interf = interf;
this.obj = obj;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public Class<?> getInterf() {
return interf;
}
public void setInterf(Class<?> interf) {
this.interf = interf;
}
public Object getObj() {
return obj;
}
public void setObj(Object obj) {
this.obj = obj;
}
}
public class DefaultServiceRegister implements ServiceRegister {
private Map<String, ServiceObject> serviceMap = new HashMap<>();
@Override
public void register(ServiceObject so, String protocolName, int port) throws Exception {
if (so == null) {
throw new IllegalArgumentException("Argument must not be null");
}
this.serviceMap.put(so.getName(), so);
}
@Override
public ServiceObject getServiceObject(String name) {
return this.serviceMap.get(name);
}
}
/**
* Service register that also publishes service information to ZooKeeper,
* where ZookeeperServiceInfoDiscoverer can find it.
*/
public class ZookeeperExportServiceRegister extends DefaultServiceRegister implements ServiceRegister {
private ZkClient client;
private String centerRootPath = "/Rpc-framework";
public ZookeeperExportServiceRegister() {
String addr = PropertiesUtils.getProperties("zk.address");
client = new ZkClient(addr);
client.setZkSerializer(new MyZkSerializer());
}
@Override
public void register(ServiceObject so, String protocolName, int port) throws Exception {
super.register(so, protocolName, port);
ServiceInfo soInf = new ServiceInfo();
String host = InetAddress.getLocalHost().getHostAddress();
String address = host + ":" + port;
soInf.setAddress(address);
soInf.setName(so.getInterf().getName());
soInf.setProtocol(protocolName);
this.exportService(soInf);
}
private void exportService(ServiceInfo serviceResource) {
String serviceName = serviceResource.getName();
String uri = JSON.toJSONString(serviceResource);
try {
uri = URLEncoder.encode(uri, "UTF-8");
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
}
String servicePath = centerRootPath + "/" + serviceName + "/service";
if (!client.exists(servicePath)) {
client.createPersistent(servicePath, true);
}
String uriPath = servicePath + "/" + uri;
if (client.exists(uriPath)) {
client.delete(uriPath);
}
client.createEphemeral(uriPath);
}
}
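Elements of a high-concurrency RPC framework
The elements of a high-concurrency RPC framework can be summed up in three points:
- Choose a high-performance I/O model; synchronous I/O multiplexing is recommended here;
- Tune the network parameters. Some have well-established recommended values, such as setting tcp_nodelay to true; others need tuning at runtime, such as the sizes of the receive and send buffers and the size of the queue for pending client connection requests (backlog);
- Choose the serialization protocol according to the business requirements. If performance demands are modest, JSON is fine; otherwise pick either Thrift or Protobuf.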
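The network parameters named in the second point map onto standard socket options. The sketch below uses the JDK's `java.net` classes rather than Netty; `SocketTuningSketch` is a hypothetical helper, and the buffer sizes are placeholders for illustration, not recommended values.

```java
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;

class SocketTuningSketch {
    // Server side: the backlog bounds the queue of pending connection requests.
    static ServerSocket openServer(int port, int backlog, int receiveBufferBytes) throws IOException {
        ServerSocket server = new ServerSocket();
        // Set before bind so that accepted sockets start from this buffer size.
        server.setReceiveBufferSize(receiveBufferBytes);
        server.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), port), backlog);
        return server;
    }

    // Client side: disable Nagle's algorithm and request buffer sizes (the OS treats them as hints).
    static void tune(Socket socket, int sendBufferBytes, int receiveBufferBytes) throws IOException {
        socket.setTcpNoDelay(true);
        socket.setSendBufferSize(sendBufferBytes);
        socket.setReceiveBufferSize(receiveBufferBytes);
    }

    // Apply the options to an unconnected socket and report whether TCP_NODELAY took effect.
    static boolean demo() {
        try (Socket socket = new Socket()) {
            tune(socket, 64 * 1024, 64 * 1024);
            return socket.getTcpNoDelay();
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Setting TCP_NODELAY sends small request packets immediately instead of waiting to coalesce them, which suits request/response traffic; buffer and backlog sizes are best chosen by measuring under realistic load.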
Source: CSDN
Author: xiao_fo
Link: https://blog.csdn.net/u010397795/article/details/104198363