今天,学习一下socket的封装。
类图
首先，我们对需要封装的各个类做如下初步设计（原文此处为类图）：
接下来，我们建立类与类之间的关系（原文此处为类关系图）：
其中，CStream 类可有可无，它用来封装各种读写流。
socket封装
stream类
stream.h:
#ifndef STREAM_H
#define STREAM_H

// Thin wrapper around a descriptor that provides blocking, timed (select-based)
// and plain write helpers. The descriptor is not owned: the destructor does not
// close it.
class CStream
{
public:
    CStream(int fd = -1);                      // -1: no descriptor yet (supply one with SetFd)
    ~CStream();
    void SetFd(int fd);
    int GetFd();
    int Read(char *buf, int count);            // blocking read
    int Read(char *buf, int count, int sec);   // read that waits at most sec seconds
    // int Read(char *buf, int count, CAddress &addr); // UDP read (not implemented)
    bool Write(char *buf,int count);           // plain write
private:
    int m_fd;
};

#endif
stream.cpp:
#include "stream.h"

#include <sys/select.h>
#include <unistd.h>

#include <cerrno>
#include <cstdio>
// Wrap descriptor fd; the default -1 marks "no descriptor yet" (see SetFd).
CStream::CStream( int fd /*= -1*/ )
    : m_fd(fd)
{
}
// Nothing to release: the stream never closes m_fd; the descriptor is closed
// elsewhere (e.g. by CSocketBase::Close).
CStream::~CStream() {}
void CStream::SetFd( int fd )
{
this->m_fd = fd;
}
int CStream::GetFd()
{
return this->m_fd;
}
// Blocking read of up to count bytes from m_fd into buf.
// Returns the number of bytes read, 0 at end of stream, or -1 on error
// (reported with perror). A read interrupted by a signal (EINTR) is retried
// instead of being reported as a failure.
int CStream::Read( char *buf, int count )
{
    if (count < 0)
    {
        // read() takes a size_t: a negative count would become a huge length.
        errno = EINVAL;
        perror("read");
        return -1;
    }
    ssize_t nByte;
    do
    {
        nByte = read(m_fd, buf, count);
    } while (nByte == -1 && errno == EINTR);
    if (nByte == -1)
    {
        perror("read");
    }
    // nByte <= count <= INT_MAX, so the conversion is lossless.
    return static_cast<int>(nByte);
}
// Wait up to sec seconds for m_fd to become readable, then do one blocking Read().
// Returns Read()'s result when data arrives, 0 on timeout (after printing
// "time out."), or -1 if select() fails or m_fd cannot be used with select().
// NOTE: 0 is also what Read() returns at end of stream, so callers cannot tell
// a timeout from EOF by the return value alone.
int CStream::Read( char *buf, int count, int sec )
{
    // FD_SET with a negative descriptor or one >= FD_SETSIZE is undefined behaviour.
    if (m_fd < 0 || m_fd >= FD_SETSIZE)
    {
        errno = EBADF;
        perror("select");
        return -1;
    }

    fd_set set;
    struct timeval timeout = {0};
    timeout.tv_sec = sec;
    int nEvent;
    do
    {
        // select() rewrites the set (and, on Linux, the remaining timeout),
        // so the set is rebuilt before every attempt.
        FD_ZERO(&set);
        FD_SET(m_fd, &set);
        nEvent = select(m_fd + 1, &set, NULL, NULL, &timeout);
    } while (nEvent == -1 && errno == EINTR);

    if (nEvent == -1)
    {
        perror("select");
        return -1;
    }
    if (nEvent == 0)
    {
        printf("time out.\n");
        return 0;
    }
    return Read(buf, count);
}
// Write all count bytes of buf to m_fd.
// write() may transfer fewer bytes than requested (typical for sockets), so
// keep writing until everything is sent; EINTR is retried. Returns true when
// all bytes were written, false (after perror) on error or a negative count.
bool CStream::Write( char *buf,int count )
{
    if (count < 0)
    {
        errno = EINVAL;
        perror("write");
        return false;
    }
    const char *p = buf;
    int nLeft = count;
    while (nLeft > 0)
    {
        ssize_t nByte = write(m_fd, p, nLeft);
        if (nByte == -1)
        {
            if (errno == EINTR)
            {
                continue;   // interrupted before anything was written: try again
            }
            perror("write");
            return false;
        }
        p += nByte;
        nLeft -= static_cast<int>(nByte);
    }
    return true;
}
socket基类
SocketBase.h:
#ifndef SOCKETBASE_H
#define SOCKETBASE_H

// Kind of socket created by CSocketBase::Socket().
typedef enum socket_type
{
    tcp_sock,   // SOCK_STREAM
    udp_sock    // SOCK_DGRAM
}SOCKET_TYPE;

// Common state of the socket wrappers: the descriptor, the address used by
// Bind(), and a CStream that Socket() points at the same descriptor.
// NOTE(review): CAddress and CStream must already be declared where this header
// is included (it does not include their headers itself) — confirm.
class CSocketBase
{
public:
    CSocketBase(int fd = -1);
    CSocketBase(char *ip, unsigned short port, int fd);
    // Virtual because Read()/Write() are overridden in derived classes: deleting
    // a derived object through a CSocketBase* must run the derived destructor.
    virtual ~CSocketBase() {}
    void SetFd(int fd);
    int GetFd();
    void SetAddr(char *ip, unsigned short port);
    void SetAddr(CAddress &addr);
    CAddress GetAddr();
    bool Socket(int type = tcp_sock);   // tcp_sock -> SOCK_STREAM, anything else -> SOCK_DGRAM
    bool Bind();                        // bind m_fd to m_addr
    virtual int Read(char *buf,int count);
    virtual bool Write(char *buf,int count);
    bool Close();
protected:
    int m_fd;
    CAddress m_addr;
    CStream m_stream;
private:
};

#endif
SocketBase.cpp:
// Wrap descriptor fd (default -1: none yet).
// m_stream is given the same descriptor; previously it stayed at -1 until
// Socket() was called, so Read()/Write() on a wrapped fd failed.
CSocketBase::CSocketBase( int fd /*= -1*/ )
    : m_fd(fd), m_stream(fd)
{
}
// Wrap descriptor fd and remember ip/port as the socket address (used by Bind()).
// m_stream is kept on the same descriptor so Read()/Write() work without Socket().
CSocketBase::CSocketBase( char *ip, unsigned short port, int fd )
    : m_fd(fd), m_stream(fd)
{
    m_addr.SetIP(ip);
    m_addr.SetPort(port);
}
void CSocketBase::SetFd( int fd )
{
m_fd = fd;
}
int CSocketBase::GetFd()
{
return m_fd;
}
void CSocketBase::SetAddr( char *ip, unsigned short port )
{
m_addr.SetIP(ip);
m_addr.SetPort(port);
}
// Copy the IP and port of addr into this socket's address, field by field.
void CSocketBase::SetAddr( CAddress &addr )
{
    CAddress &target = this->m_addr;
    target.SetIP(addr.GetIP());
    target.SetPort(addr.GetPort());
}
// Return a copy of the socket's address.
CAddress CSocketBase::GetAddr()
{
    CAddress copy = this->m_addr;
    return copy;
}
// Create an IPv4 socket: tcp_sock gives a stream socket, any other value a
// datagram socket. On success the stream helper is pointed at the new
// descriptor and true is returned; on failure perror is printed and false returned.
bool CSocketBase::Socket( int type /*= tcp_sock*/ )
{
    const int sockKind = (type == tcp_sock) ? SOCK_STREAM : SOCK_DGRAM;
    m_fd = socket(PF_INET, sockKind, 0);
    if (m_fd != -1)
    {
        m_stream.SetFd(m_fd);
        return true;
    }
    perror("socket");
    return false;
}
// Bind the descriptor to the configured address; perror + false on failure.
bool CSocketBase::Bind()
{
    if (bind(m_fd, m_addr.GetAddr(), m_addr.GetAddrSize()) != -1)
    {
        return true;
    }
    perror("bind");
    return false;
}
int CSocketBase::Read( char *buf,int count )
{
}
bool CSocketBase::Write( char *buf,int count )
{
}
bool CSocketBase::Close()
{
if (close(this->m_fd) == -1)
{
return false;
}
return true;
}
TcpSocket类
TcpSocket.h:
#ifndef TCPSOCKET_H
#define TCPSOCKET_H

// TCP socket: a CSocketBase whose Read()/Write() go through the CStream member.
// This header previously repeated the CTcpServer declaration (which belongs in
// TcpServer.h) and never declared CTcpSocket; the declaration below matches the
// member definitions in TcpSocket.cpp.
class CTcpSocket:public CSocketBase
{
public:
    CTcpSocket(int fd = -1);
    CTcpSocket(char *ip,unsigned short port,int fd);
    virtual int Read(char *buf, int count);
    virtual bool Write(char *buf, int count);
protected:
private:
};

#endif
TcpSocket.cpp:
// Wrap an existing descriptor (default -1: none yet) as a TCP socket.
CTcpSocket::CTcpSocket( int fd /*= -1*/ ) : CSocketBase(fd) {}
// TCP socket with a preset address (ip/port) and descriptor fd.
CTcpSocket::CTcpSocket( char *ip,unsigned short port, int fd) : CSocketBase(ip, port, fd) {}
int CTcpSocket::Read( char *buf, int count )
{
return m_stream.Read(buf,count);
}
bool CTcpSocket::Write( char *buf, int count )
{
return m_stream.Write(buf,count);
}
客户端、服务器封装
TcpServer类
TcpServer.h:
#ifndef TCPSERVER_H
#define TCPSERVER_H

// Listening TCP socket. Typical use: Socket(), Bind(), Listen(backlog), then
// Accept() once per incoming connection.
class CTcpServer:public CTcpSocket
{
public:
    CTcpServer(char *ip,unsigned short port,int fd);
    bool Listen(int backlog);
    CTcpClient Accept();   // next pending connection, wrapped as a CTcpClient
protected:
private:
};

#endif
TcpServer.cpp:
// Server socket that will bind to ip:port; fd is the (possibly not yet created) descriptor.
CTcpServer::CTcpServer( char *ip,unsigned short port, int fd) : CTcpSocket(ip, port, fd) {}
// Mark the socket as passive with the given backlog; perror + false on failure.
bool CTcpServer::Listen( int backlog )
{
    if (listen(m_fd, backlog) != -1)
    {
        return true;
    }
    perror("listen");
    return false;
}
// Accept the next pending connection and return it as a CTcpClient holding the
// connected descriptor and the peer address.
// On failure the error is printed and an unconnected client is returned
// (GetFd() == -1, from CTcpClient's default constructor); callers must check it.
CTcpClient CTcpServer::Accept()
{
    CTcpClient client;
    CAddress conn_addr;
    int conn_fd = accept(m_fd, conn_addr.GetAddr(), conn_addr.GetAddrSizePtr());
    if (conn_fd == -1)
    {
        perror("accept");
        // Do not fill the client with an address accept() never wrote.
        return client;
    }
    client.SetFd(conn_fd);
    client.SetAddr(conn_addr);
    // NOTE(review): the returned copy shares conn_fd with no ownership tracking
    // (shallow copy); make sure only one copy ever closes it.
    return client;
}
TcpClient类
TcpClient.h:
#ifndef TCPCLIENT_H
#define TCPCLIENT_H

// Connected TCP socket: either created with Socket() + Connect(), or filled in
// by CTcpServer::Accept() for an accepted connection.
class CTcpClient:public CTcpSocket
{
public:
    CTcpClient();
    CTcpClient(char *ip,unsigned short port,int fd);
    bool Connect(CAddress &ser_addr);   // connect the descriptor to ser_addr
protected:
private:
};

#endif
TcpClient.cpp:
// Client with no descriptor yet (CTcpSocket's default fd).
CTcpClient::CTcpClient() : CTcpSocket() {}
// Client with a preset address (ip/port) and descriptor fd.
CTcpClient::CTcpClient( char *ip,unsigned short port,int fd ) : CTcpSocket(ip, port, fd) {}
// Connect the descriptor to the server at ser_addr; perror + false on failure.
bool CTcpClient::Connect( CAddress &ser_addr )
{
    if (connect(m_fd, ser_addr.GetAddr(), ser_addr.GetAddrSize()) != -1)
    {
        return true;
    }
    perror("connect");
    return false;
}
主函数
接下来,只要修改我们之前的server.cpp和client.cpp就可以了。
server.cpp:
#include "common.h"
#define MAX_LISTEN_SIZE 10
#define MAX_EPOLL_SIZE 1000
#define MAX_EVENTS 20
int main()
{
int sockfd;
int connfd;
int reuse = 0;
int epfd;
int nEvent = 0;
struct epoll_event event = {0};
struct epoll_event rtlEvents[MAX_EVENTS] = {0};
char acbuf[100] = "";
int ret;
PK_HEAD head = {0}; //包头
PK_LOGIN login ={0}; //登录包
PK_CHAT chat = {0}; //聊天包
int reply; //登录应答包。 1-成功 0-失败
//1.socket()
char ip[20] = "192.168.159.6";
unsigned short port = 1234;
CTcpServer server(ip,port,sockfd);
SOCKET_TYPE type = tcp_sock;
server.Socket(type);
//2.bind()
server.Bind();
//3.listen()
server.Listen(MAX_LISTEN_SIZE);
//4.epoll初始化
epfd = epoll_create(MAX_EPOLL_SIZE); //创建
event.data.fd = server.GetFd();
event.events = EPOLLIN ;
epoll_ctl(epfd,EPOLL_CTL_ADD,server.GetFd(),&event); //添加sockfd
CTcpClient conn;
//5.通信
while(1)
{
nEvent = epoll_wait(epfd,rtlEvents,MAX_EVENTS,-1); //阻塞
if(nEvent == -1)
{
perror("epoll_wait");
return -1;
}
else if(nEvent == 0)
{
printf("time out.");
}
else
{
//有事件发生,立即处理
for(int i = 0; i < nEvent; i++)
{
//如果是 sockfd
if( rtlEvents[i].data.fd == server.GetFd() )
{
conn = server.Accept();
//添加到事件集合
event.data.fd = conn.GetFd();
event.events = EPOLLIN;
epoll_ctl(epfd,EPOLL_CTL_ADD,conn.GetFd(),&event);
printf("client ip:%s ,port:%u connect.\n",conn.GetAddr().GetIP(),conn.GetAddr().GetPort());
}
else //否则 connfd
{
ret = read(rtlEvents[i].data.fd,acbuf,100);
if( ret == 0) //客户端退出
{
close(rtlEvents[i].data.fd);
//从集合里删除
epoll_ctl(epfd,EPOLL_CTL_DEL,rtlEvents[i].data.fd,rtlEvents);
//从用户列表删除
string username;
for (it = userMap.begin(); it != userMap.end(); it++)
{
if (it->second == rtlEvents[i].data.fd)
{
username = it->first;
userMap.erase(it);
break;
}
}
printf("client ip:%s ,port:%u disconnect.\n",conn.GetAddr().GetIP(),conn.GetAddr().GetPort());
cout<<"client "<<username<<" exit."<<endl;
}
else
{
//解包
memset(&head,0,sizeof(head));
memcpy(&head,acbuf,HEAD_SIZE);
switch(head.type)
{
case 1:
memset(&login,0,sizeof(login));
memcpy(&login,acbuf + HEAD_SIZE,LOGIN_SIZE);
//通过connfd,区分不同客户端
//如果重复登录,失败,让前一个账号下线 ; 如果登录成功,服务器要发送一个应答包给客户端。
if ( (it = userMap.find(login.name)) != userMap.end())
{
reply = LOGIN_FAIL;
memset(acbuf,0,100);
head.size = 4;
memcpy(acbuf,&head,HEAD_SIZE);
memcpy(acbuf + HEAD_SIZE , &reply , 4);
write(it->second,acbuf,HEAD_SIZE + 4); //登录失败应答包
printf("client %s relogin.\n",login.name);
}
else
{
printf("client %s login.\n",login.name);
}
reply = LOGIN_OK;
memcpy(acbuf + HEAD_SIZE , &reply , 4);
write(rtlEvents[i].data.fd,acbuf,HEAD_SIZE + 4); //登录成功应答包
userMap.insert(pair<string,int>(login.name,rtlEvents[i].data.fd));
break;
case 2:
memset(&chat,0,CHAT_SIZE);
memcpy(&chat,acbuf + HEAD_SIZE,CHAT_SIZE);
if(strcmp(chat.toName,"all") == 0)
{
//群聊
for (it = userMap.begin(); it != userMap.end(); it++)
{
//转发消息
if (it->second != rtlEvents[i].data.fd)
{
write(it->second, acbuf, HEAD_SIZE + CHAT_SIZE);
}
}
}
else
{
//私聊
if ( (it = userMap.find(chat.toName)) != userMap.end()) //找到了
{
//转发消息
write(it->second, acbuf, HEAD_SIZE + CHAT_SIZE);
}
else //用户不存在
{
memset(&chat.msg,0,100);
strcpy(chat.msg,"the acccount is not exist.");
memcpy(acbuf + HEAD_SIZE, &chat, CHAT_SIZE);
write(rtlEvents[i].data.fd, acbuf, HEAD_SIZE + CHAT_SIZE);
}
}
break;
case 3:
memcpy(acbuf + HEAD_SIZE,&userMap,USERS_SIZE);
write(rtlEvents[i].data.fd, acbuf, HEAD_SIZE + USERS_SIZE);
break;
}
}
}
}
}
}
return 0;
}
client.cpp:
#include "common.h"
// SIGINT handler installed by the sending (parent) process. The receiving child
// raises SIGINT on the parent to end the program (e.g. when the login is rejected).
// _exit() is used because exit() is not async-signal-safe: it runs atexit
// handlers and flushes stdio from inside the signal handler.
void handler(int /*no*/)
{
    _exit(0);
}
// Print prompt (when non-NULL), then read one whitespace-delimited token from
// stdin into dst. dst receives at most dstSize - 1 characters and is always
// NUL-terminated (longer tokens are truncated). Returns false on EOF/input error.
static bool ReadToken(const char *prompt, char *dst, size_t dstSize)
{
    char token[100] = "";
    if (prompt != NULL)
    {
        printf("%s", prompt);
        fflush(stdout);
    }
    if (scanf("%99s", token) != 1)
    {
        return false;
    }
    strncpy(dst, token, dstSize - 1);
    dst[dstSize - 1] = '\0';
    return true;
}

// Chat client entry point.
// Asks for the server ip/port and a user name, sends a login packet, then forks:
// the child prints every packet received from the server, the parent reads
// commands from stdin ("chat", "users", "exit") and sends packets.
int main()
{
    unsigned short port = 0;
    char ip[20] = "";
    char acbuf[100] = "";
    PK_HEAD head = {0};     // packet header
    PK_LOGIN login = {0};   // login payload
    PK_CHAT chat = {0};     // chat payload
    int reply;              // login reply payload
    pid_t pid;

    if (!ReadToken("input ip: ", ip, sizeof(ip)))
    {
        return -1;
    }
    printf("input port: ");
    fflush(stdout);
    // %hu matches unsigned short (the old "%d" into a short int was undefined behaviour).
    if (scanf("%hu", &port) != 1)
    {
        fprintf(stderr, "invalid port.\n");
        return -1;
    }

    //1.socket()
    CTcpClient client;
    if (!client.Socket())
    {
        return -1;   // Socket() already printed the error
    }
    //2.connect() to the server address
    CAddress serAddr(ip, port);
    if (!client.Connect(serAddr))
    {
        client.Close();
        return -1;   // Connect() already printed the error
    }
    printf("connect OK.\n");

    // Log in.
    if (!ReadToken("input username: ", login.name, sizeof(login.name)))
    {
        client.Close();
        return -1;
    }
    head.type = 1;
    head.size = LOGIN_SIZE;
    memcpy(acbuf, &head, HEAD_SIZE);
    memcpy(acbuf + HEAD_SIZE, &login, LOGIN_SIZE);
    client.Write(acbuf, HEAD_SIZE + LOGIN_SIZE);

    // After logging in, the parent sends packets and the child receives them.
    pid = fork();
    if (pid == -1)
    {
        perror("fork");
        client.Close();
        return -1;
    }
    else if (pid == 0)   // child: receiver
    {
        while (1)
        {
            memset(acbuf, 0, sizeof(acbuf));
            if (client.Read(acbuf, sizeof(acbuf)) <= 0)
            {
                // 0 = server closed the connection, -1 = read error. Without this
                // check the loop would spin forever on an empty buffer.
                printf("server disconnected.\n");
                kill(getppid(), SIGINT);   // end the program
                return 0;
            }
            // Unpack the header.
            memset(&head, 0, HEAD_SIZE);
            memcpy(&head, acbuf, HEAD_SIZE);
            switch (head.type)
            {
            case 1:
                // Login reply. The server sends LOGIN_FAIL to the previous
                // connection when the same name logs in again.
                memcpy(&reply, acbuf + HEAD_SIZE, 4);
                if (reply == LOGIN_FAIL)
                {
                    printf("your account has been logged in elsewhere.\n");
                    kill(getppid(), SIGINT);   // end the program
                    return 0;
                }
                else if (reply == LOGIN_OK)
                {
                    printf("login OK.\n");
                }
                else
                {
                    printf("login failed.\n");
                    kill(getppid(), SIGINT);   // end the program
                    return 0;
                }
                break;
            case 2:
                // Chat message; the strings came from the network, so terminate them.
                memset(&chat, 0, CHAT_SIZE);
                memcpy(&chat, acbuf + HEAD_SIZE, CHAT_SIZE);
                chat.fromName[sizeof(chat.fromName) - 1] = '\0';
                chat.msg[sizeof(chat.msg) - 1] = '\0';
                printf("from %s: %s\n", chat.fromName, chat.msg);
                break;
            case 3:
                // Online-user list: not implemented yet (draft kept below).
                // memset(&userMap,0,USERS_SIZE);
                // memcpy(&userMap,acbuf + HEAD_SIZE,USERS_SIZE);
                // printf("========= online users =========\n");
                // for (it = userMap.begin(); it != userMap.end() ; it++)
                // {
                //     cout<<it->first<<endl;
                // }
                // printf("================================\n");
                break;
            }
        }
    }
    else   // parent: sender
    {
        signal(SIGINT, handler);
        //3.communicate
        while (1)
        {
            if (!ReadToken("#", acbuf, sizeof(acbuf)))
            {
                // stdin closed: stop the receiver too instead of looping forever.
                kill(pid, SIGINT);
                break;
            }
            if (strcmp(acbuf, "exit") == 0)
            {
                kill(pid, SIGINT);
                break;
            }
            else if (strcmp(acbuf, "chat") == 0)
            {
                memset(&chat, 0, sizeof(chat));
                if (!ReadToken("to who: ", chat.toName, sizeof(chat.toName)) ||
                    !ReadToken("say: ", chat.msg, sizeof(chat.msg)))
                {
                    kill(pid, SIGINT);
                    break;
                }
                strncpy(chat.fromName, login.name, sizeof(chat.fromName) - 1);
                chat.fromName[sizeof(chat.fromName) - 1] = '\0';
                memset(acbuf, 0, sizeof(acbuf));
                head.type = 2;
                head.size = CHAT_SIZE;   // previously left at LOGIN_SIZE from the login packet
                memcpy(acbuf, &head, HEAD_SIZE);
                memcpy(acbuf + HEAD_SIZE, &chat, CHAT_SIZE);
                client.Write(acbuf, HEAD_SIZE + CHAT_SIZE);
            }
            else if (strcmp(acbuf, "users") == 0)
            {
                head.type = 3;
                head.size = USERS_SIZE;
                memset(acbuf, 0, sizeof(acbuf));
                memcpy(acbuf, &head, HEAD_SIZE);
                client.Write(acbuf, HEAD_SIZE + USERS_SIZE);
            }
        }
    }
    //4.close()
    if (!client.Close())
    {
        perror("close");
    }
    return 0;
}
搞定！（能保证正常运行就可以了，不过这里的功能还有一些问题。）下次来封装 epoll～
来源:51CTO
作者:SherryX
链接:https://blog.51cto.com/13097817/2066440