图神经网络 | (6) 图分类(SAGPool)实战

落爺英雄遲暮 提交于 2020-02-04 02:28:00

近期买了一本图神经网络的入门书,最近几篇博客对书中的一些实战案例进行整理,具体的理论和原理部分可以自行查阅该书,该书购买链接:《深入浅出的图神经网络》

该书配套代码

本节我们通过代码来实现基于自注意力的池化机制(Self-Attention Pooling)。来对图整体进行分类,之前我们是对节点分类,每个节点表示一条数据,学习节点的表示,进而完成分类,本节我们通过自注意力池化机制,得到整个图的表示,进而对全图进行分类(每个图表示一条数据)。

  • 导入必要的包
import os
import urllib
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import scipy.sparse as sp
from zipfile import ZipFile
from sklearn.model_selection import train_test_split
import pickle
import pandas as pd
import torch_scatter #注意:torch_scatter 安装时编译需要用到cuda
from collections import Counter

目录

1. D&D数据

2. 图卷积

3. Self-Attention Pooling

4. ReadOut

5. 图分类模型

6. 程序流程


 

1. D&D数据

D&D是一个包含1178个蛋白质结构的数据集。每个蛋白质被表示为一个图,其中节点是氨基酸,如果两个节点之间的距离小于6埃,那么他们之间会有一条边相连。预测任务是将蛋白质分类为酶和非酶。

  • 数据集下载和预处理
class DDDataset(object):
    #数据集下载链接
    url = "https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/DD.zip"
    
    def __init__(self, data_root="data", train_size=0.8):
        self.data_root = data_root
        self.maybe_download() #下载 并解压
        sparse_adjacency, node_labels, graph_indicator, graph_labels = self.read_data()
        #把coo格式转换为csr 进行稀疏矩阵运算
        self.sparse_adjacency = sparse_adjacency.tocsr()
        self.node_labels = node_labels
        self.graph_indicator = graph_indicator
        self.graph_labels = graph_labels
        
        self.train_index, self.test_index = self.split_data(train_size)
        self.train_label = graph_labels[self.train_index] #得到训练集中所有图对应的类别标签
        self.test_label = graph_labels[self.test_index] #得到测试集中所有图对应的类别标签

    def split_data(self, train_size):
        unique_indicator = np.asarray(list(set(self.graph_indicator)))
        #随机划分训练集和测试集 得到各自对应的图索引   (一个图代表一条数据)
        train_index, test_index = train_test_split(unique_indicator,
                                                   train_size=train_size,
                                                   random_state=1234)
        return train_index, test_index
    
    def __getitem__(self, index):
       
        mask = self.graph_indicator == index  
        #得到图索引为index的图对应的所有节点(索引)
        graph_indicator = self.graph_indicator[mask]
        #每个节点对应的特征标签
        node_labels = self.node_labels[mask]
        #该图对应的类别标签
        graph_labels = self.graph_labels[index]
        #该图对应的邻接矩阵
        adjacency = self.sparse_adjacency[mask, :][:, mask]
        return adjacency, node_labels, graph_indicator, graph_labels
    
    def __len__(self):
        return len(self.graph_labels)
    
    def read_data(self):
        #解压后的路径
        data_dir = os.path.join(self.data_root, "DD")
        print("Loading DD_A.txt")
        #从txt文件中读取邻接表(每一行可以看作一个坐标,即邻接矩阵中非0值的位置)  包含所有图的节点
        adjacency_list = np.genfromtxt(os.path.join(data_dir, "DD_A.txt"),
                                       dtype=np.int64, delimiter=',') - 1
        print("Loading DD_node_labels.txt")
        #读取节点的特征标签(包含所有图) 每个节点代表一种氨基酸 氨基酸有20多种,所以每个节点会有一个类型标签 表示是哪一种氨基酸
        node_labels = np.genfromtxt(os.path.join(data_dir, "DD_node_labels.txt"), 
                                    dtype=np.int64) - 1
        print("Loading DD_graph_indicator.txt")
        #每个节点属于哪个图
        graph_indicator = np.genfromtxt(os.path.join(data_dir, "DD_graph_indicator.txt"), 
                                        dtype=np.int64) - 1
        print("Loading DD_graph_labels.txt")
        #每个图的标签 (2分类 0,1)
        graph_labels = np.genfromtxt(os.path.join(data_dir, "DD_graph_labels.txt"), 
                                     dtype=np.int64) - 1
        num_nodes = len(node_labels) #节点数 (包含所有图的节点)
        #通过邻接表生成邻接矩阵  (包含所有的图)稀疏存储节省内存(coo格式 只存储非0值的行索引、列索引和非0值)
        #coo格式无法进行稀疏矩阵运算
        sparse_adjacency = sp.coo_matrix((np.ones(len(adjacency_list)), 
                                          (adjacency_list[:, 0], adjacency_list[:, 1])),
                                         shape=(num_nodes, num_nodes), dtype=np.float32)
        print("Number of nodes: ", num_nodes)
        return sparse_adjacency, node_labels, graph_indicator, graph_labels
    
    def maybe_download(self):
        save_path = os.path.join(self.data_root)
        #本地不存在 则下载
        if not os.path.exists(save_path):
            self.download_data(self.url, save_path)
        #对数据集压缩包进行解压
        if not os.path.exists(os.path.join(self.data_root, "DD")):
            zipfilename = os.path.join(self.data_root, "DD.zip")
            with ZipFile(zipfilename, "r") as zipobj:
                zipobj.extractall(os.path.join(self.data_root))
                print("Extracting data from {}".format(zipfilename))
    
    @staticmethod
    def download_data(url, save_path):
        """数据下载工具,当原始数据不存在时将会进行下载"""
        print("Downloading data from {}".format(url))
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        #下载数据集压缩包 保存在本地
        data = urllib.request.urlopen(url)
        filename = "DD.zip"
        with open(os.path.join(save_path, filename), 'wb') as f:
            f.write(data.read())
        return True

 

2. 图卷积

X (N,input_dim) -> X' (N,output_dim)。所有图的节点放在一块计算(N为所有节点的个数)。

class GraphConvolution(nn.Module):
    def __init__(self, input_dim, output_dim, use_bias=True):
        """图卷积:L*X*\theta

        Args:
        ----------
            input_dim: int
                节点输入特征的维度
            output_dim: int
                输出特征维度
            use_bias : bool, optional
                是否使用偏置
        """
        super(GraphConvolution, self).__init__()
        self.input_dim = input_dim 
        self.output_dim = output_dim 
        self.use_bias = use_bias
        #权重矩阵
        self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(output_dim))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters() #使用自定义参数初始化方式

    def reset_parameters(self): #自定义权重和偏置的初始化方式
        init.kaiming_uniform_(self.weight)
        if self.use_bias:
            init.zeros_(self.bias)

    def forward(self, adjacency, input_feature):
        """邻接矩阵是稀疏矩阵,因此在计算时使用稀疏矩阵乘法"""
        #adjacency (N,N) 归一化的拉普拉斯矩阵
        #input_feature(N,input_dim) N为所有节点个数 (包含所有图)
        support = torch.mm(input_feature, self.weight) #XW (N,output_dim=hidden_dim)
        output = torch.sparse.mm(adjacency, support) #L(XW)  (N,output_dim=hidden_dim)
        if self.use_bias:
            output += self.bias
        return output #(N,output_dim=hidden_dim)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.input_dim) + ' -> ' \
            + str(self.output_dim) + ')'

3. Self-Attention Pooling

基于自注意力的池化机制的思想,是通过图卷积gcn从图中自适应地学习每个节点的重要性。就是使用下面的图卷积公式,为每个节点赋予一个重要性得分:

                                                      

其中,

表示激活函数,

是邻接矩阵A加入自连接并进行归一化得到的拉普拉斯矩阵,X表示节点的特征(N,input_dim),

是权重参数,这也是自注意力池化层唯一引入的参数,gcn层的定义在上一小结。(本实验中,我们把所有图的节点放在一起,统一计算,并标明哪些节点属于哪些图)。

根据节点的重要性分数和拓扑结构可以进行池化操作,舍弃一些不太重要的节点,形成一个新的图结构,所以要对邻接矩阵和节点特征进行更新,得到池化结果。首先根据下式进行节点选择:

                                              

  • top_rank
def top_rank(attention_score, graph_indicator, keep_ratio):
    """基于给定的attention_score, 对每个图进行pooling操作.
    为了直观体现pooling过程,我们将每个图单独进行池化,最后再将它们级联起来进行下一步计算
    
    Arguments:
    ----------
        attention_score:torch.Tensor
            使用GCN计算出的注意力分数,Z = GCN(A, X)
        graph_indicator:torch.Tensor
            指示每个节点属于哪个图
        keep_ratio: float
            要保留的节点比例,保留的节点数量为int(N * keep_ratio)
    """
    # TODO: 确认是否是有序的, 必须是有序的 从第1个图到第1178个图
    graph_id_list = list(set(graph_indicator.cpu().numpy()))
    
    #创建一个空tensor 类型为bool
    mask = attention_score.new_empty((0,), dtype=torch.bool)
    for graph_id in graph_id_list:
        #取出图索引为graph_id的图的节点 对应的注意力分数
        graph_attn_score = attention_score[graph_indicator == graph_id]
        #该图的节点数
        graph_node_num = len(graph_attn_score)
        #创建一个大小为graph_node_num的 值全为false 的tensor
        graph_mask = attention_score.new_zeros((graph_node_num,),
                                                dtype=torch.bool)
        #该图需要保留的节点数
        keep_graph_node_num = int(keep_ratio * graph_node_num)
        #对该图节点对应的注意力分数降序排列 得到排序后的索引
        _, sorted_index = graph_attn_score.sort(descending=True)
        #把保留的节点索引设置为True
        graph_mask[sorted_index[:keep_graph_node_num]] = True
        mask = torch.cat((mask, graph_mask))
    
    return mask #其中每个图需要保留的节点索引位置为True 从第1个图到第1178个图

函数top_rank接受三个参数,一是使用GCN得到的节点重要度分数attention_score (N,);二是指示每个节点属于哪个图的参数graph_indicator(N,),这里我们将多个需要分类的图放在一起进行批处理,以提高运算速度,graph_indicator里面包含的数据为[0,0,...,0,1,1,...,1,2,2,...,2...](按顺序,第一个图中的所有节点,第二个图...)。graph_indicator的标识值需要进行升序排序,属于同一个图中的节点需要连续排列在一起;三是超参数keep_ratio,表示每次池化需要保留的节点比例,这是 针对单个图而言,即每个图都保留这些节点,不是整个批处理中所有的数据。实现逻辑上根据graph_indicator依次遍历每个图,取出该图节点对应的注意力分数,并进行降序排序得到要保留的节点的索引,将这些位置的索引设置为True,得到每个子图节点的掩码向量。将所有图的掩码拼接在一起得到批处理中所有节点的掩码,作为函数返回值。

接下来,根据得到的节点掩码对图结构和特征进行更新。图结构的更新是根据掩码向量对邻接矩阵(

)进行索引,得到保留节点之间的邻接矩阵,再重新进行归一化,作为后续GCN层的输入。因此定义两个功能函数normalization(adjacency)和filter_adjacency(adjacency,mask)。其中normalization(adjacency)接收一个scipy.sparse.csr_matrix,对她进行规范化并转换为torch.sparse.FloatTensor。另一个函数函数filter_adjacency(adjacency,mask)接收两个参数,一个是池化之前的adjacency,它的类型是torch.sparse.FloatTensor,另一个是函数top_rank返回的节点的掩码向量mask。为了利用scipy.sparse提供的切片索引,这里将池化之前的adjacency转换为scipy.sparse.csr_matrix,然后通过掩码向量mask进行切片,得到池化后的节点之间的邻接关系,随后再使用函数normalization进行规范化,作为下一个GCN的输入:
  • 功能函数

def tensor_from_numpy(x, device): #numpy数组转换为tensor 并转移到所用设备上
    return torch.from_numpy(x).to(device)


def normalization(adjacency):
    """计算 L=D^-0.5 * (A+I) * D^-0.5,

    Args:
        adjacency: sp.csr_matrix.

    Returns:
        归一化后的邻接矩阵,类型为 torch.sparse.FloatTensor
    """
    adjacency += sp.eye(adjacency.shape[0])    # 增加自连接 A+I
    degree = np.array(adjacency.sum(1)) #得到此时的度矩阵对角线 对增加自连接的邻接矩阵按行求和
    d_hat = sp.diags(np.power(degree, -0.5).flatten()) #开-0.5次方 转换为度矩阵(对角矩阵)
    L = d_hat.dot(adjacency).dot(d_hat).tocoo() #得到归一化、并引入自连接的拉普拉斯矩阵 转换为coo稀疏格式
    # 转换为 torch.sparse.FloatTensor
    #稀疏矩阵非0值 的坐标(行索引,列索引)
    indices = torch.from_numpy(np.asarray([L.row, L.col])).long()
    #非0值
    values = torch.from_numpy(L.data.astype(np.float32))
    #存储为tensor稀疏格式
    tensor_adjacency = torch.sparse.FloatTensor(indices, values, L.shape)
    return tensor_adjacency


def filter_adjacency(adjacency, mask):
    """根据掩码mask对图结构进行更新
    
    Args:
        adjacency: torch.sparse.FloatTensor, 池化之前的邻接矩阵
        mask: torch.Tensor(dtype=torch.bool), 节点掩码向量
    
    Returns:
        torch.sparse.FloatTensor, 池化之后归一化邻接矩阵
    """
    device = adjacency.device
    mask = mask.cpu().numpy()
    indices = adjacency.coalesce().indices().cpu().numpy()
    num_nodes = adjacency.size(0)
    row, col = indices
    maskout_self_loop = row != col
    row = row[maskout_self_loop]
    col = col[maskout_self_loop]
    sparse_adjacency = sp.csr_matrix((np.ones(len(row)), (row, col)),
                                     shape=(num_nodes, num_nodes), dtype=np.float32)
    filtered_adjacency = sparse_adjacency[mask, :][:, mask]
    return normalization(filtered_adjacency).to(device)

 

利用上面介绍的这些功能函数,就可以实现自注意力池化层,该层输出为池化之后的特征、节点属于哪个子图的标识以及规范化的新图结构的邻接矩阵:

class SelfAttentionPooling(nn.Module):
    
    def __init__(self, input_dim, keep_ratio, activation=torch.tanh):
        super(SelfAttentionPooling, self).__init__()
        self.input_dim = input_dim #(hidden_dim*3)
        self.keep_ratio = keep_ratio
        self.activation = activation
        #attention gcn层  (N,hidden_dim*3) -> (N,1)
        self.attn_gcn = GraphConvolution(input_dim, 1)
    
    def forward(self, adjacency, input_feature, graph_indicator):
        #通过attention gcn层计算注意力分数 
        #adjacency拉普拉斯矩阵(N*N)
        #input_feature三个gcn层计算结果通过relu后再拼接 (N,hidden_dim*3)
        #attn_score (N,)
        attn_score = self.attn_gcn(adjacency, input_feature).squeeze()
        #通过tanh激活函数 (N,)
        attn_score = self.activation(attn_score)
        
        #节点掩码向量 (N,)
        mask = top_rank(attn_score, graph_indicator, self.keep_ratio)
        #保留节点的状态向量 乘以对应的注意力分数 
        hidden = input_feature[mask] * attn_score[mask].view(-1, 1)
        #保留的节点属于哪个图
        mask_graph_indicator = graph_indicator[mask]
        #得到新的图结构  的邻接矩阵 
        mask_adjacency = filter_adjacency(adjacency, mask)
        return hidden, mask_graph_indicator, mask_adjacency

4. ReadOut

通过自注意力池化层,对每个图的节点做了筛选。要进行图分类,还需要全局池化操作,将节点数不同的图降到同一个维度。常见的全局池化方式包括取最大值或均值。

def global_max_pool(x, graph_indicator):
    #对于每个图保留节点的状态向量 按位置取最大值 最后一个图对应一个状态向量
    num = graph_indicator.max().item() + 1 
    return torch_scatter.scatter_max(x, graph_indicator, dim=0, dim_size=num)[0]


def global_avg_pool(x, graph_indicator):
    #每个图保留节点的状态向量 按位置取平均值 最后一个图对应一个状态向量
    num = graph_indicator.max().item() + 1
    return torch_scatter.scatter_mean(x, graph_indicator, dim=0, dim_size=num)

这里我们使用torch_scatter包来简化上述实现过程,其中用到了两个函数scatter_mean和scatter_max的原理如下图所示:

5. 图分类模型

至此,我们可以定义图分类模型了。接下来定义两套SAGPool模型,如下图所示。其中a图只用了一个池化层,称为

,"g"代表global;b图使用了多个池化层,称为

,"h"表示hierarchical。在论文的实验部分,可以发现

比较适合小图分类,

更适合大图分类。

class ModelA(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes=2):
        """图分类模型结构A
        
        Args:
        ----
            input_dim: int, 输入特征的维度
            hidden_dim: int, 隐藏层单元数
            num_classes: 分类类别数 (default: 2)
        """
        super(ModelA, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        
        #三个gcn层 (N,input_dim) -> (N,hidden_dim)
        self.gcn1 = GraphConvolution(input_dim, hidden_dim)
        self.gcn2 = GraphConvolution(hidden_dim, hidden_dim)
        self.gcn3 = GraphConvolution(hidden_dim, hidden_dim)
        
        self.pool = SelfAttentionPooling(hidden_dim * 3, 0.5)
        
        self.fc1 = nn.Linear(hidden_dim * 3 * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, num_classes)

    def forward(self, adjacency, input_feature, graph_indicator):
        #adjacency (N,N) 拉普拉斯矩阵
        #input_feature (N,input_dim)
        gcn1 = F.relu(self.gcn1(adjacency, input_feature)) #(N,hidden_dim)
        gcn2 = F.relu(self.gcn2(adjacency, gcn1))#(N,hidden_dim)
        gcn3 = F.relu(self.gcn3(adjacency, gcn2))#(N,hidden_dim)
        
        gcn_feature = torch.cat((gcn1, gcn2, gcn3), dim=1) #(N,hidden_dim*3)
        
        #pool (N',hidden_dim*3) 
        #pool_graph_indicator (N',)
        pool, pool_graph_indicator, pool_adjacency = self.pool(adjacency, gcn_feature,
                                                               graph_indicator)
        
        #readout(G,hidden_dim*3*2) G为图数
        readout = torch.cat((global_avg_pool(pool, pool_graph_indicator),
                             global_max_pool(pool, pool_graph_indicator)), dim=1)
        
        #fc1(G,hidden_dim) 
        fc1 = F.relu(self.fc1(readout))
        #fc2(G,hidden_dim//2)
        fc2 = F.relu(self.fc2(fc1))
        #fc3(G,num_classes=2)
        logits = self.fc3(fc2)
        
        return logits
class ModelB(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes=2):
        """图分类模型结构
        
        Args:
        -----
            input_dim: int, 输入特征的维度
            hidden_dim: int, 隐藏层单元数
            num_classes: int, 分类类别数 (default: 2)
        """
        super(ModelB, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        
        #第一个gcn层 (N,input_dim) ->(N,hidden_dim) N为所有节点数
        self.gcn1 = GraphConvolution(input_dim, hidden_dim)
        #第一个池化层 
        self.pool1 = SelfAttentionPooling(hidden_dim, 0.5)
        self.gcn2 = GraphConvolution(hidden_dim, hidden_dim)
        self.pool2 = SelfAttentionPooling(hidden_dim, 0.5)
        self.gcn3 = GraphConvolution(hidden_dim, hidden_dim)
        self.pool3 = SelfAttentionPooling(hidden_dim, 0.5)
        
        #把最后的几个全连接层和激活函数 封装在一起
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(), 
            nn.Linear(hidden_dim // 2, num_classes))
    
    def forward(self, adjacency, input_feature, graph_indicator):
        #adjacency 拉普拉斯矩阵 (N,N)N为所有节点数
        #input_feature 所有节点的特征矩阵 (N,input_dim)
        #(N,input_dim) -> (N,hidden_dim)
        gcn1 = F.relu(self.gcn1(adjacency, input_feature))
        
        #gcn1 (N,hidden_dim) adjacency(N,N)  graph_indicator(N,)每个节点属于哪个图
        #pool1 (N',hidden_dim) N'保留的节点数
        #pool1_graph_indicator (N',)保留的节点属于哪个图
        #pool1_adjacency 保留节点的邻接矩阵(归一化)
        pool1, pool1_graph_indicator, pool1_adjacency = \
            self.pool1(adjacency, gcn1, graph_indicator)
        
        #global_pool1 (G,hidden_dim*2)   G为图数
        global_pool1 = torch.cat(
            [global_avg_pool(pool1, pool1_graph_indicator),
             global_max_pool(pool1, pool1_graph_indicator)],
            dim=1)
        
        #pool1_adjacency (N',N') N'保留的节点数 新的图结构对应的拉普拉斯矩阵
        #pool1 (N',hiddem_dim) 
        #(N',hiddem_dim) -> (N',hiddem_dim) 
        gcn2 = F.relu(self.gcn2(pool1_adjacency, pool1))
        
        #gcn2 (N',hiddem_dim)  pool1_adjacency(N',N')  pool1_graph_indicator(N',)保留的每个节点属于哪个图
        #pool2 (N'',hidden_dim) N''保留的节点数
        #pool2_graph_indicator (N'',)保留的节点属于哪个图
        #pool2_adjacency 保留节点的邻接矩阵(归一化)
        pool2, pool2_graph_indicator, pool2_adjacency = \
            self.pool2(pool1_adjacency, gcn2, pool1_graph_indicator)
        
        #global_pool2 (G,hidden_dim*2)   G为图数
        global_pool2 = torch.cat(
            [global_avg_pool(pool2, pool2_graph_indicator),
             global_max_pool(pool2, pool2_graph_indicator)],
            dim=1)

        #pool2_adjacency (N'',N'') N''保留的节点数 新的图结构对应的拉普拉斯矩阵
        #pool2 (N'',hiddem_dim) 
        #(N'',hiddem_dim) -> (N'',hiddem_dim) 
        gcn3 = F.relu(self.gcn3(pool2_adjacency, pool2))
        
        #gcn3 (N'',hiddem_dim)  pool2_adjacency(N'',N'')  pool2_graph_indicator(N'',)保留的每个节点属于哪个图
        #pool3 (N''',hidden_dim) N'''保留的节点数
        #pool3_graph_indicator (N''',)保留的节点属于哪个图
        #pool3_adjacency 保留节点的邻接矩阵(归一化)
        pool3, pool3_graph_indicator, pool3_adjacency = \
            self.pool3(pool2_adjacency, gcn3, pool2_graph_indicator)
        
        #global_pool3 (G,hidden_dim*2)   G为图数
        global_pool3 = torch.cat(
            [global_avg_pool(pool3, pool3_graph_indicator),
             global_max_pool(pool3, pool3_graph_indicator)],
            dim=1)
        
        #readout (G,hidden_dim*2)
        readout = global_pool1 + global_pool2 + global_pool3
        
        #logits (G,num_classes=2)
        logits = self.mlp(readout)
        return logits

6. 程序流程

  • 加载数据
dataset = DDDataset()
  • 数据准备
# 模型输入数据准备
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#所有图对应的大邻接矩阵
adjacency = dataset.sparse_adjacency
#归一化、引入自连接的拉普拉斯矩阵 
normalize_adjacency = normalization(adjacency).to(DEVICE)
#所有节点的特征标签 
node_labels = tensor_from_numpy(dataset.node_labels, DEVICE)
#把特征标签转换为one-hot特征向量
node_features = F.one_hot(node_labels, node_labels.max().item() + 1).float()
#每个节点对应哪个图
graph_indicator = tensor_from_numpy(dataset.graph_indicator, DEVICE)
#每个图的类别标签
graph_labels = tensor_from_numpy(dataset.graph_labels, DEVICE)
#训练集对应的图索引
train_index = tensor_from_numpy(dataset.train_index, DEVICE)
#测试集对应的图索引
test_index = tensor_from_numpy(dataset.test_index, DEVICE)
#训练集和测试集中的图对应的类别标签
train_label = tensor_from_numpy(dataset.train_label, DEVICE)
test_label = tensor_from_numpy(dataset.test_label, DEVICE)
  • 超参数设置
# 超参数设置
INPUT_DIM = node_features.size(1) #特征向量维度
NUM_CLASSES = 2
EPOCHS = 200    # @param {type: "integer"}
HIDDEN_DIM =    32# @param {type: "integer"}
LEARNING_RATE = 0.01 # @param
WEIGHT_DECAY = 0.0001 # @param
  • 模型初始化
# 模型初始化
model_g = ModelA(INPUT_DIM, HIDDEN_DIM, NUM_CLASSES).to(DEVICE)
model_h = ModelB(INPUT_DIM, HIDDEN_DIM, NUM_CLASSES).to(DEVICE)
model = model_h #@param ['model_g', 'model_h'] {type: 'raw'}
  • 训练
criterion = nn.CrossEntropyLoss().to(DEVICE) #交叉熵损失函数
#Adam优化器
optimizer = optim.Adam(model.parameters(), LEARNING_RATE, weight_decay=WEIGHT_DECAY)

model.train() #训练模式
for epoch in range(EPOCHS):
    logits = model(normalize_adjacency, node_features, graph_indicator) #对所有数据(图)前向传播 得到输出
    loss = criterion(logits[train_index], train_label)  # 只对训练的数据计算损失值
    optimizer.zero_grad()
    loss.backward()  # 反向传播计算参数的梯度
    optimizer.step()  # 使用优化方法进行梯度更新
    #训练集准确率
    train_acc = torch.eq(
        logits[train_index].max(1)[1], train_label).float().mean()
    print("Epoch {:03d}: Loss {:.4f}, TrainAcc {:.4}".format(
        epoch, loss.item(), train_acc.item()))
  • 测试
model.eval() #测试模式
with torch.no_grad(): #关闭求导
    logits = model(normalize_adjacency, node_features, graph_indicator)#所有数据前向传播
    test_logits = logits[test_index] #取出测试数据对应的输出
    #计算测试数据准确率
    test_acc = torch.eq(
        test_logits.max(1)[1], test_label
    ).float().mean()

print(test_acc.item())

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!