Socket封装之聊天程序(二)

▼魔方 西西 提交于 2019-11-26 03:57:55

  今天,学习一下socket的封装。

类图

  首先,我们把需要封装的各个类初步的设计如下:

(图:各类的初步设计)
  接下来,我们建立类与类之间的关系:
(图:类与类之间的关系)
  其中,CStream类可有可无,这个类是用来封装各种读写流的。

socket封装

stream类

stream.h:

// Thin wrapper around a file descriptor offering a blocking read, a read guarded
// by a select() timeout, and a plain write. The descriptor is only recorded.
class CStream
{
private:
    int m_fd;   // descriptor used by every read/write (constructor default: -1)

public:
    CStream(int fd = -1);
    ~CStream();

    // descriptor accessors
    void SetFd(int fd);
    int GetFd();

    // Blocking read; returns the result of read().
    int Read(char *buf, int count);
    // Waits up to `sec` seconds for the descriptor to become readable, then reads.
    int Read(char *buf, int count, int sec);
//  int Read(char *buf, int count, CAddress &addr);     // UDP read (not implemented)

    // Plain write; true on success.
    bool Write(char *buf, int count);
};

stream.cpp:

#include "stream.h"

#include <sys/select.h>
#include <unistd.h>

#include <cerrno>
#include <cstdio>

CStream::CStream( int fd /*= -1*/ )
{
    this->m_fd = fd;
}

CStream::~CStream()
{

}

void CStream::SetFd( int fd )
{
    this->m_fd = fd;
}

int CStream::GetFd()
{
    return this->m_fd;
}

int CStream::Read( char *buf, int count )
{
    int nByte;
    nByte = read(m_fd,buf,count);
    if (nByte == -1)
    {
        perror("read");
    }
    return nByte;
}

// Read with timeout: waits up to `sec` seconds for the descriptor to become
// readable (via select), then performs a normal blocking Read().
// Returns:
//   -1  invalid descriptor or select() error (reported on stderr),
//    0  timeout ("time out." is printed) -- note this is indistinguishable from
//       end-of-stream returned by Read(),
//   otherwise Read()'s result.
int CStream::Read( char *buf, int count, int sec )
{
    // FD_SET on a descriptor outside [0, FD_SETSIZE) writes outside the fd_set
    // (undefined behaviour) -- e.g. when called before a socket was assigned (-1).
    if (m_fd < 0 || m_fd >= FD_SETSIZE)
    {
        fprintf(stderr, "select: invalid fd %d\n", m_fd);
        return -1;
    }

    fd_set readSet;
    FD_ZERO(&readSet);
    FD_SET(m_fd, &readSet);

    struct timeval timeout = {0};
    timeout.tv_sec = sec;

    int nEvent = select(m_fd + 1, &readSet, NULL, NULL, &timeout);
    if (nEvent == -1)
    {
        perror("select");
        return -1;
    }
    if (nEvent == 0)
    {
        printf("time out.\n");
        return 0;
    }
    return Read(buf, count);
}

// Writes all `count` bytes of `buf` to the descriptor.
// write() may transfer fewer bytes than requested (common on sockets), so the
// remainder is written in a loop; a write interrupted by a signal (EINTR) is retried.
// Returns true once every byte has been written, false on error (reported via perror)
// or when `count` is negative.
bool CStream::Write( char *buf,int count )
{
    if (count < 0)
    {
        return false;
    }

    int nSent = 0;
    while (nSent < count)
    {
        int nByte = write(m_fd, buf + nSent, count - nSent);
        if (nByte == -1)
        {
            if (errno == EINTR)
            {
                continue;
            }
            perror("write");
            return false;
        }
        nSent += nByte;
    }
    return true;
}

socket基类

SocketBase.h:

// Kind of socket CSocketBase::Socket() creates: stream (TCP) or datagram (UDP).
enum socket_type
{
    tcp_sock,   // -> SOCK_STREAM
    udp_sock    // -> SOCK_DGRAM
};
typedef socket_type SOCKET_TYPE;

// Common base for sockets: holds the descriptor, the address used for
// Bind()/reported by GetAddr(), and a CStream that performs read()/write()
// on that descriptor.
class CSocketBase
{
public:
    CSocketBase(int fd = -1);
    CSocketBase(char *ip, unsigned short port, int fd);
    // Read/Write are virtual, so derived objects may be handled through a
    // CSocketBase pointer; a virtual destructor makes delete-through-base
    // well-defined. It does NOT close the descriptor (socket objects are copied
    // by value, e.g. in CTcpServer::Accept) -- call Close() explicitly.
    virtual ~CSocketBase() {}

    void SetFd(int fd);
    int GetFd();
    void SetAddr(char *ip, unsigned short port);
    void SetAddr(CAddress &addr);
    CAddress GetAddr();
    bool Socket(int type = tcp_sock);   // tcp_sock -> SOCK_STREAM, anything else -> SOCK_DGRAM
    bool Bind();
    virtual int Read(char *buf,int count);
    virtual bool Write(char *buf,int count);
    bool Close();

protected:
    int m_fd;           // socket descriptor (constructor default: -1)
    CAddress m_addr;    // address passed to bind() and returned by GetAddr()
    CStream m_stream;   // performs the actual read()/write() on the descriptor
};

SocketBase.cpp:

CSocketBase::CSocketBase( int fd /*= -1*/ )
{
    m_fd = fd;
}

CSocketBase::CSocketBase( char *ip, unsigned short port, int fd )
{
    m_addr.SetIP(ip);
    m_addr.SetPort(port);
    m_fd = fd;
}

void CSocketBase::SetFd( int fd )
{
    m_fd = fd;
}

int CSocketBase::GetFd()
{
    return m_fd;
}

void CSocketBase::SetAddr( char *ip, unsigned short port )
{
    m_addr.SetIP(ip);
    m_addr.SetPort(port);
}

void CSocketBase::SetAddr( CAddress &addr )
{
    m_addr.SetIP(addr.GetIP());
    m_addr.SetPort(addr.GetPort());
}

CAddress CSocketBase::GetAddr()
{
    return m_addr;
}

// Creates an IPv4 socket: SOCK_STREAM for tcp_sock, SOCK_DGRAM for any other value.
// On success the new descriptor is also handed to m_stream; on failure m_fd is -1
// and the error is printed with perror.
bool CSocketBase::Socket( int type /*= tcp_sock*/ )
{
    const int sockType = (type == tcp_sock) ? SOCK_STREAM : SOCK_DGRAM;
    m_fd = socket(PF_INET, sockType, 0);

    if (m_fd != -1)
    {
        m_stream.SetFd(m_fd);
        return true;
    }

    perror("socket");
    return false;
}

// Binds the descriptor to the stored address; perror + false on failure.
bool CSocketBase::Bind()
{
    if (bind(m_fd, m_addr.GetAddr(), m_addr.GetAddrSize()) != -1)
    {
        return true;
    }
    perror("bind");
    return false;
}

int CSocketBase::Read( char *buf,int count )
{

}

bool CSocketBase::Write( char *buf,int count )
{

}

bool CSocketBase::Close()
{
    if (close(this->m_fd) == -1)
    {
        return false;
    }
    return true;
}

TcpSocket类

TcpSocket.h:

// TCP socket: descriptor/address handling comes from CSocketBase; Read/Write go
// through the descriptor stream. Base class of CTcpServer and CTcpClient.
// (This header previously repeated the CTcpServer declaration, which both
// redefined CTcpServer and left CTcpSocket undeclared.)
class CTcpSocket:public CSocketBase
{
public:
    CTcpSocket(int fd = -1);    // default used by CTcpClient()
    CTcpSocket(char *ip,unsigned short port,int fd);

    virtual int Read(char *buf, int count);
    virtual bool Write(char *buf, int count);
};

TcpSocket.cpp:

// Wraps an existing descriptor (or -1); everything is forwarded to CSocketBase.
CTcpSocket::CTcpSocket( int fd /*= -1*/ ) : CSocketBase(fd) {}

// Address constructor: IP, port and descriptor are all forwarded to CSocketBase.
CTcpSocket::CTcpSocket( char *ip,unsigned short port, int fd) : CSocketBase(ip,port,fd) {}

int CTcpSocket::Read( char *buf, int count )
{
    return m_stream.Read(buf,count);
}

bool CTcpSocket::Write( char *buf, int count )
{
    return m_stream.Write(buf,count);
}

客户端、服务器封装

TcpServer类

TcpServer.h:

// Listening TCP socket: Socket()/Bind() come from the base classes; Accept()
// returns one CTcpClient per incoming connection.
class CTcpServer : public CTcpSocket
{
public:
    CTcpServer(char *ip, unsigned short port, int fd);

    // listen() with the given backlog; true on success.
    bool Listen(int backlog);
    // Accepts one connection; the returned client carries the new fd and peer address.
    CTcpClient Accept();
};

TcpServer.cpp:

// Records the listening address and descriptor; binding happens later via Bind().
CTcpServer::CTcpServer( char *ip,unsigned short port, int fd)
    : CTcpSocket(ip,port,fd)
{
}

// Marks the socket as passive with the given backlog; perror + false on failure.
bool CTcpServer::Listen( int backlog )
{
    if (listen(m_fd, backlog) != -1)
    {
        return true;
    }
    perror("listen");
    return false;
}

// Accepts one pending connection.
// On success the returned client holds the connected descriptor and peer address.
// On failure the error is printed and the returned client is still unconnected
// (GetFd() == -1, from CTcpClient's default constructor); callers must check it.
CTcpClient CTcpServer::Accept()
{
    CTcpClient client;
    CAddress conn_addr;
    int conn_fd = accept(m_fd, conn_addr.GetAddr(), conn_addr.GetAddrSizePtr());
    if (conn_fd == -1)
    {
        perror("accept");
        // Do not copy conn_addr: accept() did not fill it in.
        return client;
    }
    client.SetFd(conn_fd);
    client.SetAddr(conn_addr);

    // Returning by value is fine: socket objects only record the descriptor number
    // and never close it in a destructor, so no deep copy is needed.
    return client;
}

TcpClient类

TcpClient.h:

// One end of a TCP connection: either created by the user and Connect()ed, or
// produced by CTcpServer::Accept().
class CTcpClient : public CTcpSocket
{
public:
    CTcpClient();
    CTcpClient(char *ip, unsigned short port, int fd);

    // connect() to ser_addr; true on success.
    bool Connect(CAddress &ser_addr);
};

TcpClient.cpp:

// Unconnected client: uses CTcpSocket's default descriptor until Socket()/SetFd().
CTcpClient::CTcpClient() : CTcpSocket() {}

// Client with a preset address and descriptor.
CTcpClient::CTcpClient( char *ip,unsigned short port,int fd ) : CTcpSocket(ip,port,fd) {}

// Connects this socket's descriptor to ser_addr; perror + false on failure.
bool CTcpClient::Connect( CAddress &ser_addr )
{
    const bool connected =
        connect(m_fd, ser_addr.GetAddr(), ser_addr.GetAddrSize()) != -1;
    if (!connected)
    {
        perror("connect");
    }
    return connected;
}

主函数

  接下来,只要修改我们之前的server.cpp和client.cpp就可以了。
server.cpp:

#include "common.h"

// Server tuning knobs (typed constants instead of macros).
static const int MAX_LISTEN_SIZE = 10;    // backlog passed to Listen()
static const int MAX_EPOLL_SIZE = 1000;   // size hint passed to epoll_create()
static const int MAX_EVENTS = 20;         // max events returned by one epoll_wait()

int main()
{   
    int sockfd;
    int connfd;

    int reuse = 0;
    int epfd; 
    int nEvent = 0;
    struct epoll_event event = {0};
    struct epoll_event rtlEvents[MAX_EVENTS] = {0};
    char acbuf[100] = "";
    int ret;

    PK_HEAD head = {0};     //包头
    PK_LOGIN login ={0};    //登录包
    PK_CHAT chat = {0};     //聊天包
    int reply;              //登录应答包。 1-成功 0-失败

    //1.socket()
    char ip[20] = "192.168.159.6";
    unsigned short port = 1234;
    CTcpServer server(ip,port,sockfd);
    SOCKET_TYPE type = tcp_sock;
    server.Socket(type);

    //2.bind()
    server.Bind();

    //3.listen()
    server.Listen(MAX_LISTEN_SIZE);

    //4.epoll初始化
    epfd = epoll_create(MAX_EPOLL_SIZE);    //创建
    event.data.fd = server.GetFd();
    event.events = EPOLLIN ;    
    epoll_ctl(epfd,EPOLL_CTL_ADD,server.GetFd(),&event);    //添加sockfd

    CTcpClient conn;
    //5.通信
    while(1)
    {
        nEvent = epoll_wait(epfd,rtlEvents,MAX_EVENTS,-1);   //阻塞
        if(nEvent == -1)
        {
            perror("epoll_wait");
            return -1;
        }
        else if(nEvent == 0)
        {
            printf("time out.");
        }
        else
        {
            //有事件发生,立即处理
            for(int i = 0; i < nEvent; i++)
            {
                //如果是 sockfd
                if( rtlEvents[i].data.fd == server.GetFd() )
                {
                    conn = server.Accept();
                    //添加到事件集合
                    event.data.fd = conn.GetFd();
                    event.events = EPOLLIN;  
                    epoll_ctl(epfd,EPOLL_CTL_ADD,conn.GetFd(),&event);
                    printf("client ip:%s ,port:%u connect.\n",conn.GetAddr().GetIP(),conn.GetAddr().GetPort());
                }
                else    //否则 connfd 
                {
                    ret = read(rtlEvents[i].data.fd,acbuf,100);
                    if( ret == 0) //客户端退出
                    {
                        close(rtlEvents[i].data.fd);
                        //从集合里删除
                        epoll_ctl(epfd,EPOLL_CTL_DEL,rtlEvents[i].data.fd,rtlEvents);
                        //从用户列表删除
                        string username;
                        for (it = userMap.begin(); it != userMap.end(); it++)
                        {
                            if (it->second == rtlEvents[i].data.fd)
                            {
                                username = it->first;
                                userMap.erase(it);
                                break;
                            }
                        }
                        printf("client ip:%s ,port:%u disconnect.\n",conn.GetAddr().GetIP(),conn.GetAddr().GetPort());
                        cout<<"client "<<username<<" exit."<<endl;
                    }
                    else
                    {   
                        //解包
                        memset(&head,0,sizeof(head));
                        memcpy(&head,acbuf,HEAD_SIZE);

                        switch(head.type)
                        {
                        case 1:
                            memset(&login,0,sizeof(login));
                            memcpy(&login,acbuf + HEAD_SIZE,LOGIN_SIZE);
                            //通过connfd,区分不同客户端
                            //如果重复登录,失败,让前一个账号下线 ; 如果登录成功,服务器要发送一个应答包给客户端。
                            if ( (it = userMap.find(login.name)) != userMap.end())
                            {
                                reply = LOGIN_FAIL;
                                memset(acbuf,0,100);
                                head.size = 4;
                                memcpy(acbuf,&head,HEAD_SIZE);
                                memcpy(acbuf + HEAD_SIZE , &reply , 4);
                                write(it->second,acbuf,HEAD_SIZE + 4);  //登录失败应答包
                                printf("client %s relogin.\n",login.name);
                            }
                            else
                            {
                                printf("client %s login.\n",login.name);
                            }
                            reply = LOGIN_OK;
                            memcpy(acbuf + HEAD_SIZE , &reply , 4);
                            write(rtlEvents[i].data.fd,acbuf,HEAD_SIZE + 4);    //登录成功应答包
                            userMap.insert(pair<string,int>(login.name,rtlEvents[i].data.fd));

                            break;

                        case 2: 
                            memset(&chat,0,CHAT_SIZE);
                            memcpy(&chat,acbuf + HEAD_SIZE,CHAT_SIZE);
                            if(strcmp(chat.toName,"all") == 0)
                            {
                                //群聊
                                for (it = userMap.begin(); it != userMap.end(); it++)
                                {
                                    //转发消息
                                    if (it->second != rtlEvents[i].data.fd)
                                    {
                                        write(it->second, acbuf, HEAD_SIZE + CHAT_SIZE);
                                    }

                                }
                            }
                            else
                            {
                                //私聊
                                if ( (it = userMap.find(chat.toName)) != userMap.end()) //找到了
                                {
                                    //转发消息
                                    write(it->second, acbuf, HEAD_SIZE + CHAT_SIZE);
                                }
                                else    //用户不存在
                                {
                                    memset(&chat.msg,0,100);
                                    strcpy(chat.msg,"the acccount is not exist.");
                                    memcpy(acbuf + HEAD_SIZE, &chat, CHAT_SIZE);
                                    write(rtlEvents[i].data.fd, acbuf, HEAD_SIZE + CHAT_SIZE);
                                }
                            }
                            break;

                        case 3:
                            memcpy(acbuf + HEAD_SIZE,&userMap,USERS_SIZE);
                            write(rtlEvents[i].data.fd, acbuf, HEAD_SIZE + USERS_SIZE);
                            break;

                        }

                    }

                }

            }
        }

    }

    return 0;
}

client.cpp:

#include "common.h"

// SIGINT handler for the client's parent process: terminate immediately.
// _exit() is async-signal-safe; exit() is not. exit() runs atexit handlers and
// flushes stdio, which can deadlock if the signal interrupted printf/scanf.
void handler(int no)
{
    (void)no;
    _exit(0);
}

int main()
{

    int ret;
    int sockfd;
    short int port = 0;
    char ip[20] = "";
    char acbuf[100] = "";

    PK_HEAD head = {0};     //包头
    PK_LOGIN login ={0};    //登录包
    PK_CHAT chat = {0};     //聊天包
    int reply;              //登录应答包
    pid_t pid;

    printf("input ip: ");
    fflush(stdout);
    scanf("%s",ip);
    printf("input port: ");
    fflush(stdout);
    scanf("%d",&port);

    //1.socket();
    CTcpClient client;
    client.Socket();
    sockfd = client.GetFd();

    //2.连接connect() 服务器的地址
    CAddress serAddr(ip,port);
    if (client.Connect(serAddr))
    {
        printf("connect OK.\n");
    }

    //登录
    printf("input username: ");
    fflush(stdout);
    scanf("%s",login.name);

    head.type = 1;
    head.size = LOGIN_SIZE;
    memcpy(acbuf,&head,HEAD_SIZE);
    memcpy(acbuf + HEAD_SIZE,&login,LOGIN_SIZE);
    client.Write(acbuf, HEAD_SIZE + LOGIN_SIZE);

    //登录成功之后,主进程负责发包,子进程负责收包

    pid  = fork();
    if (pid == -1)
    {
        perror("fork");
        return -1;
    }
    else if (pid == 0) //子进程
    {
        while(1)
        {
            //一直读
            memset(acbuf,0,100);
            client.Read(acbuf,100);
            //解包
            memset(&head,0,HEAD_SIZE);
            memcpy(&head,acbuf,HEAD_SIZE);
            switch(head.type)
            {
            case 1: 
                //登录
                memcpy(&reply,acbuf + HEAD_SIZE,4);
                if (reply == LOGIN_FAIL)
                {
                    printf("your account has been logged in elsewhere.\n");
                    kill(getppid(),SIGINT); //结束程序
                    return 0;
                }
                else if (reply == LOGIN_OK)
                {
                    printf("login OK.\n");
                }
                else
                {
                    printf("login failed.\n");
                    kill(getppid(),SIGINT); //结束程序
                    return 0;
                }
                break;
            case 2:
                //聊天
                memset(&chat,0,CHAT_SIZE);
                memcpy(&chat,acbuf + HEAD_SIZE,CHAT_SIZE);
                printf("from %s: %s\n",chat.fromName,chat.msg); 
                break;

            case 3:
                //用户列表
//              memset(&userMap,0,USERS_SIZE);
//              memcpy(&userMap,acbuf + HEAD_SIZE,USERS_SIZE);
//              printf("========= online users =========\n");
//              for (it = userMap.begin(); it != userMap.end() ; it++)
//              {
//                  cout<<it->first<<endl;
//              }
//              printf("================================\n");

            }

        }

    }
    else
    {
        signal(SIGINT,handler);
        //3.通信
        while(1)
        {
            printf("#");
            fflush(stdout);
            scanf("%s",acbuf);
            if(strcmp(acbuf,"exit") == 0)
            {
                kill(pid,SIGINT);
                break;
            }
            else if(strcmp(acbuf,"chat") == 0)
            {
                printf("to who: ");
                fflush(stdout);
                scanf("%s",chat.toName);
                printf("say: ");
                fflush(stdout);
                scanf("%s",chat.msg);
                strcpy(chat.fromName ,login.name);
                memset(acbuf,0,100);

                head.type = 2;
                memcpy(acbuf,&head,HEAD_SIZE);
                memcpy(acbuf + HEAD_SIZE ,&chat, CHAT_SIZE);
                client.Write(acbuf, HEAD_SIZE + CHAT_SIZE);
            }
            else if (strcmp(acbuf,"users") == 0)
            {
                head.type = 3;
                head.size = USERS_SIZE;
                memset(acbuf,0,100);
                memcpy(acbuf,&head,HEAD_SIZE);
                client.Write(acbuf,HEAD_SIZE + USERS_SIZE);

            }
        }

    }

    //4.close()
    if (!client.Close())
    {
        perror("close");
    }
    return 0;
}

  搞定!目前程序可以正常运行，不过功能上还有问题（例如客户端的用户列表功能尚未实现）。下次封装 epoll~

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!