近期买了一本图神经网络的入门书,最近几篇博客对书中的一些实战案例进行整理,具体的理论和原理部分可以自行查阅该书,该书购买链接:《深入浅出的图神经网络》。
本节我们通过代码来实现基于自注意力的池化机制(Self-Attention Pooling)。来对图整体进行分类,之前我们是对节点分类,每个节点表示一条数据,学习节点的表示,进而完成分类,本节我们通过自注意力池化机制,得到整个图的表示,进而对全图进行分类(每个图表示一条数据)。
- 导入必要的包
import os
import urllib
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import scipy.sparse as sp
from zipfile import ZipFile
from sklearn.model_selection import train_test_split
import pickle
import pandas as pd
import torch_scatter #注意:torch_scatter 安装时编译需要用到cuda
from collections import Counter
目录
1. D&D数据
D&D是一个包含1178个蛋白质结构的数据集。每个蛋白质被表示为一个图,其中节点是氨基酸,如果两个节点之间的距离小于6埃,那么他们之间会有一条边相连。预测任务是将蛋白质分类为酶和非酶。
- 数据集下载和预处理
class DDDataset(object):
#数据集下载链接
url = "https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/DD.zip"
def __init__(self, data_root="data", train_size=0.8):
self.data_root = data_root
self.maybe_download() #下载 并解压
sparse_adjacency, node_labels, graph_indicator, graph_labels = self.read_data()
#把coo格式转换为csr 进行稀疏矩阵运算
self.sparse_adjacency = sparse_adjacency.tocsr()
self.node_labels = node_labels
self.graph_indicator = graph_indicator
self.graph_labels = graph_labels
self.train_index, self.test_index = self.split_data(train_size)
self.train_label = graph_labels[self.train_index] #得到训练集中所有图对应的类别标签
self.test_label = graph_labels[self.test_index] #得到测试集中所有图对应的类别标签
def split_data(self, train_size):
unique_indicator = np.asarray(list(set(self.graph_indicator)))
#随机划分训练集和测试集 得到各自对应的图索引 (一个图代表一条数据)
train_index, test_index = train_test_split(unique_indicator,
train_size=train_size,
random_state=1234)
return train_index, test_index
def __getitem__(self, index):
mask = self.graph_indicator == index
#得到图索引为index的图对应的所有节点(索引)
graph_indicator = self.graph_indicator[mask]
#每个节点对应的特征标签
node_labels = self.node_labels[mask]
#该图对应的类别标签
graph_labels = self.graph_labels[index]
#该图对应的邻接矩阵
adjacency = self.sparse_adjacency[mask, :][:, mask]
return adjacency, node_labels, graph_indicator, graph_labels
def __len__(self):
return len(self.graph_labels)
def read_data(self):
#解压后的路径
data_dir = os.path.join(self.data_root, "DD")
print("Loading DD_A.txt")
#从txt文件中读取邻接表(每一行可以看作一个坐标,即邻接矩阵中非0值的位置) 包含所有图的节点
adjacency_list = np.genfromtxt(os.path.join(data_dir, "DD_A.txt"),
dtype=np.int64, delimiter=',') - 1
print("Loading DD_node_labels.txt")
#读取节点的特征标签(包含所有图) 每个节点代表一种氨基酸 氨基酸有20多种,所以每个节点会有一个类型标签 表示是哪一种氨基酸
node_labels = np.genfromtxt(os.path.join(data_dir, "DD_node_labels.txt"),
dtype=np.int64) - 1
print("Loading DD_graph_indicator.txt")
#每个节点属于哪个图
graph_indicator = np.genfromtxt(os.path.join(data_dir, "DD_graph_indicator.txt"),
dtype=np.int64) - 1
print("Loading DD_graph_labels.txt")
#每个图的标签 (2分类 0,1)
graph_labels = np.genfromtxt(os.path.join(data_dir, "DD_graph_labels.txt"),
dtype=np.int64) - 1
num_nodes = len(node_labels) #节点数 (包含所有图的节点)
#通过邻接表生成邻接矩阵 (包含所有的图)稀疏存储节省内存(coo格式 只存储非0值的行索引、列索引和非0值)
#coo格式无法进行稀疏矩阵运算
sparse_adjacency = sp.coo_matrix((np.ones(len(adjacency_list)),
(adjacency_list[:, 0], adjacency_list[:, 1])),
shape=(num_nodes, num_nodes), dtype=np.float32)
print("Number of nodes: ", num_nodes)
return sparse_adjacency, node_labels, graph_indicator, graph_labels
def maybe_download(self):
save_path = os.path.join(self.data_root)
#本地不存在 则下载
if not os.path.exists(save_path):
self.download_data(self.url, save_path)
#对数据集压缩包进行解压
if not os.path.exists(os.path.join(self.data_root, "DD")):
zipfilename = os.path.join(self.data_root, "DD.zip")
with ZipFile(zipfilename, "r") as zipobj:
zipobj.extractall(os.path.join(self.data_root))
print("Extracting data from {}".format(zipfilename))
@staticmethod
def download_data(url, save_path):
"""数据下载工具,当原始数据不存在时将会进行下载"""
print("Downloading data from {}".format(url))
if not os.path.exists(save_path):
os.makedirs(save_path)
#下载数据集压缩包 保存在本地
data = urllib.request.urlopen(url)
filename = "DD.zip"
with open(os.path.join(save_path, filename), 'wb') as f:
f.write(data.read())
return True
2. 图卷积
X (N,input_dim) -> X' (N,output_dim)。所有图的节点放在一块计算(N为所有节点的个数)。
class GraphConvolution(nn.Module):
def __init__(self, input_dim, output_dim, use_bias=True):
"""图卷积:L*X*\theta
Args:
----------
input_dim: int
节点输入特征的维度
output_dim: int
输出特征维度
use_bias : bool, optional
是否使用偏置
"""
super(GraphConvolution, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.use_bias = use_bias
#权重矩阵
self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
if self.use_bias:
self.bias = nn.Parameter(torch.Tensor(output_dim))
else:
self.register_parameter('bias', None)
self.reset_parameters() #使用自定义参数初始化方式
def reset_parameters(self): #自定义权重和偏置的初始化方式
init.kaiming_uniform_(self.weight)
if self.use_bias:
init.zeros_(self.bias)
def forward(self, adjacency, input_feature):
"""邻接矩阵是稀疏矩阵,因此在计算时使用稀疏矩阵乘法"""
#adjacency (N,N) 归一化的拉普拉斯矩阵
#input_feature(N,input_dim) N为所有节点个数 (包含所有图)
support = torch.mm(input_feature, self.weight) #XW (N,output_dim=hidden_dim)
output = torch.sparse.mm(adjacency, support) #L(XW) (N,output_dim=hidden_dim)
if self.use_bias:
output += self.bias
return output #(N,output_dim=hidden_dim)
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.input_dim) + ' -> ' \
+ str(self.output_dim) + ')'
3. Self-Attention Pooling
基于自注意力的池化机制的思想,是通过图卷积gcn从图中自适应地学习每个节点的重要性。就是使用下面的图卷积公式,为每个节点赋予一个重要性得分:
其中,
表示激活函数,是邻接矩阵A加入自连接并进行归一化得到的拉普拉斯矩阵,X表示节点的特征(N,input_dim),是权重参数,这也是自注意力池化层唯一引入的参数,gcn层的定义在上一小结。(本实验中,我们把所有图的节点放在一起,统一计算,并标明哪些节点属于哪些图)。根据节点的重要性分数和拓扑结构可以进行池化操作,舍弃一些不太重要的节点,形成一个新的图结构,所以要对邻接矩阵和节点特征进行更新,得到池化结果。首先根据下式进行节点选择:
- top_rank
def top_rank(attention_score, graph_indicator, keep_ratio):
"""基于给定的attention_score, 对每个图进行pooling操作.
为了直观体现pooling过程,我们将每个图单独进行池化,最后再将它们级联起来进行下一步计算
Arguments:
----------
attention_score:torch.Tensor
使用GCN计算出的注意力分数,Z = GCN(A, X)
graph_indicator:torch.Tensor
指示每个节点属于哪个图
keep_ratio: float
要保留的节点比例,保留的节点数量为int(N * keep_ratio)
"""
# TODO: 确认是否是有序的, 必须是有序的 从第1个图到第1178个图
graph_id_list = list(set(graph_indicator.cpu().numpy()))
#创建一个空tensor 类型为bool
mask = attention_score.new_empty((0,), dtype=torch.bool)
for graph_id in graph_id_list:
#取出图索引为graph_id的图的节点 对应的注意力分数
graph_attn_score = attention_score[graph_indicator == graph_id]
#该图的节点数
graph_node_num = len(graph_attn_score)
#创建一个大小为graph_node_num的 值全为false 的tensor
graph_mask = attention_score.new_zeros((graph_node_num,),
dtype=torch.bool)
#该图需要保留的节点数
keep_graph_node_num = int(keep_ratio * graph_node_num)
#对该图节点对应的注意力分数降序排列 得到排序后的索引
_, sorted_index = graph_attn_score.sort(descending=True)
#把保留的节点索引设置为True
graph_mask[sorted_index[:keep_graph_node_num]] = True
mask = torch.cat((mask, graph_mask))
return mask #其中每个图需要保留的节点索引位置为True 从第1个图到第1178个图
函数top_rank接受三个参数,一是使用GCN得到的节点重要度分数attention_score (N,);二是指示每个节点属于哪个图的参数graph_indicator(N,),这里我们将多个需要分类的图放在一起进行批处理,以提高运算速度,graph_indicator里面包含的数据为[0,0,...,0,1,1,...,1,2,2,...,2...](按顺序,第一个图中的所有节点,第二个图...)。graph_indicator的标识值需要进行升序排序,属于同一个图中的节点需要连续排列在一起;三是超参数keep_ratio,表示每次池化需要保留的节点比例,这是 针对单个图而言,即每个图都保留这些节点,不是整个批处理中所有的数据。实现逻辑上根据graph_indicator依次遍历每个图,取出该图节点对应的注意力分数,并进行降序排序得到要保留的节点的索引,将这些位置的索引设置为True,得到每个子图节点的掩码向量。将所有图的掩码拼接在一起得到批处理中所有节点的掩码,作为函数返回值。
接下来,根据得到的节点掩码对图结构和特征进行更新。图结构的更新是根据掩码向量对邻接矩阵(
)进行索引,得到保留节点之间的邻接矩阵,再重新进行归一化,作为后续GCN层的输入。因此定义两个功能函数normalization(adjacency)和filter_adjacency(adjacency,mask)。其中normalization(adjacency)接收一个scipy.sparse.csr_matrix,对她进行规范化并转换为torch.sparse.FloatTensor。另一个函数函数filter_adjacency(adjacency,mask)接收两个参数,一个是池化之前的adjacency,它的类型是torch.sparse.FloatTensor,另一个是函数top_rank返回的节点的掩码向量mask。为了利用scipy.sparse提供的切片索引,这里将池化之前的adjacency转换为scipy.sparse.csr_matrix,然后通过掩码向量mask进行切片,得到池化后的节点之间的邻接关系,随后再使用函数normalization进行规范化,作为下一个GCN的输入:- 功能函数
def tensor_from_numpy(x, device): #numpy数组转换为tensor 并转移到所用设备上
return torch.from_numpy(x).to(device)
def normalization(adjacency):
"""计算 L=D^-0.5 * (A+I) * D^-0.5,
Args:
adjacency: sp.csr_matrix.
Returns:
归一化后的邻接矩阵,类型为 torch.sparse.FloatTensor
"""
adjacency += sp.eye(adjacency.shape[0]) # 增加自连接 A+I
degree = np.array(adjacency.sum(1)) #得到此时的度矩阵对角线 对增加自连接的邻接矩阵按行求和
d_hat = sp.diags(np.power(degree, -0.5).flatten()) #开-0.5次方 转换为度矩阵(对角矩阵)
L = d_hat.dot(adjacency).dot(d_hat).tocoo() #得到归一化、并引入自连接的拉普拉斯矩阵 转换为coo稀疏格式
# 转换为 torch.sparse.FloatTensor
#稀疏矩阵非0值 的坐标(行索引,列索引)
indices = torch.from_numpy(np.asarray([L.row, L.col])).long()
#非0值
values = torch.from_numpy(L.data.astype(np.float32))
#存储为tensor稀疏格式
tensor_adjacency = torch.sparse.FloatTensor(indices, values, L.shape)
return tensor_adjacency
def filter_adjacency(adjacency, mask):
"""根据掩码mask对图结构进行更新
Args:
adjacency: torch.sparse.FloatTensor, 池化之前的邻接矩阵
mask: torch.Tensor(dtype=torch.bool), 节点掩码向量
Returns:
torch.sparse.FloatTensor, 池化之后归一化邻接矩阵
"""
device = adjacency.device
mask = mask.cpu().numpy()
indices = adjacency.coalesce().indices().cpu().numpy()
num_nodes = adjacency.size(0)
row, col = indices
maskout_self_loop = row != col
row = row[maskout_self_loop]
col = col[maskout_self_loop]
sparse_adjacency = sp.csr_matrix((np.ones(len(row)), (row, col)),
shape=(num_nodes, num_nodes), dtype=np.float32)
filtered_adjacency = sparse_adjacency[mask, :][:, mask]
return normalization(filtered_adjacency).to(device)
利用上面介绍的这些功能函数,就可以实现自注意力池化层,该层输出为池化之后的特征、节点属于哪个子图的标识以及规范化的新图结构的邻接矩阵:
class SelfAttentionPooling(nn.Module):
def __init__(self, input_dim, keep_ratio, activation=torch.tanh):
super(SelfAttentionPooling, self).__init__()
self.input_dim = input_dim #(hidden_dim*3)
self.keep_ratio = keep_ratio
self.activation = activation
#attention gcn层 (N,hidden_dim*3) -> (N,1)
self.attn_gcn = GraphConvolution(input_dim, 1)
def forward(self, adjacency, input_feature, graph_indicator):
#通过attention gcn层计算注意力分数
#adjacency拉普拉斯矩阵(N*N)
#input_feature三个gcn层计算结果通过relu后再拼接 (N,hidden_dim*3)
#attn_score (N,)
attn_score = self.attn_gcn(adjacency, input_feature).squeeze()
#通过tanh激活函数 (N,)
attn_score = self.activation(attn_score)
#节点掩码向量 (N,)
mask = top_rank(attn_score, graph_indicator, self.keep_ratio)
#保留节点的状态向量 乘以对应的注意力分数
hidden = input_feature[mask] * attn_score[mask].view(-1, 1)
#保留的节点属于哪个图
mask_graph_indicator = graph_indicator[mask]
#得到新的图结构 的邻接矩阵
mask_adjacency = filter_adjacency(adjacency, mask)
return hidden, mask_graph_indicator, mask_adjacency
4. ReadOut
通过自注意力池化层,对每个图的节点做了筛选。要进行图分类,还需要全局池化操作,将节点数不同的图降到同一个维度。常见的全局池化方式包括取最大值或均值。
def global_max_pool(x, graph_indicator):
#对于每个图保留节点的状态向量 按位置取最大值 最后一个图对应一个状态向量
num = graph_indicator.max().item() + 1
return torch_scatter.scatter_max(x, graph_indicator, dim=0, dim_size=num)[0]
def global_avg_pool(x, graph_indicator):
#每个图保留节点的状态向量 按位置取平均值 最后一个图对应一个状态向量
num = graph_indicator.max().item() + 1
return torch_scatter.scatter_mean(x, graph_indicator, dim=0, dim_size=num)
这里我们使用torch_scatter包来简化上述实现过程,其中用到了两个函数scatter_mean和scatter_max的原理如下图所示:
5. 图分类模型
至此,我们可以定义图分类模型了。接下来定义两套SAGPool模型,如下图所示。其中a图只用了一个池化层,称为
,"g"代表global;b图使用了多个池化层,称为,"h"表示hierarchical。在论文的实验部分,可以发现比较适合小图分类,更适合大图分类。
class ModelA(nn.Module):
def __init__(self, input_dim, hidden_dim, num_classes=2):
"""图分类模型结构A
Args:
----
input_dim: int, 输入特征的维度
hidden_dim: int, 隐藏层单元数
num_classes: 分类类别数 (default: 2)
"""
super(ModelA, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.num_classes = num_classes
#三个gcn层 (N,input_dim) -> (N,hidden_dim)
self.gcn1 = GraphConvolution(input_dim, hidden_dim)
self.gcn2 = GraphConvolution(hidden_dim, hidden_dim)
self.gcn3 = GraphConvolution(hidden_dim, hidden_dim)
self.pool = SelfAttentionPooling(hidden_dim * 3, 0.5)
self.fc1 = nn.Linear(hidden_dim * 3 * 2, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
self.fc3 = nn.Linear(hidden_dim // 2, num_classes)
def forward(self, adjacency, input_feature, graph_indicator):
#adjacency (N,N) 拉普拉斯矩阵
#input_feature (N,input_dim)
gcn1 = F.relu(self.gcn1(adjacency, input_feature)) #(N,hidden_dim)
gcn2 = F.relu(self.gcn2(adjacency, gcn1))#(N,hidden_dim)
gcn3 = F.relu(self.gcn3(adjacency, gcn2))#(N,hidden_dim)
gcn_feature = torch.cat((gcn1, gcn2, gcn3), dim=1) #(N,hidden_dim*3)
#pool (N',hidden_dim*3)
#pool_graph_indicator (N',)
pool, pool_graph_indicator, pool_adjacency = self.pool(adjacency, gcn_feature,
graph_indicator)
#readout(G,hidden_dim*3*2) G为图数
readout = torch.cat((global_avg_pool(pool, pool_graph_indicator),
global_max_pool(pool, pool_graph_indicator)), dim=1)
#fc1(G,hidden_dim)
fc1 = F.relu(self.fc1(readout))
#fc2(G,hidden_dim//2)
fc2 = F.relu(self.fc2(fc1))
#fc3(G,num_classes=2)
logits = self.fc3(fc2)
return logits
class ModelB(nn.Module):
def __init__(self, input_dim, hidden_dim, num_classes=2):
"""图分类模型结构
Args:
-----
input_dim: int, 输入特征的维度
hidden_dim: int, 隐藏层单元数
num_classes: int, 分类类别数 (default: 2)
"""
super(ModelB, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.num_classes = num_classes
#第一个gcn层 (N,input_dim) ->(N,hidden_dim) N为所有节点数
self.gcn1 = GraphConvolution(input_dim, hidden_dim)
#第一个池化层
self.pool1 = SelfAttentionPooling(hidden_dim, 0.5)
self.gcn2 = GraphConvolution(hidden_dim, hidden_dim)
self.pool2 = SelfAttentionPooling(hidden_dim, 0.5)
self.gcn3 = GraphConvolution(hidden_dim, hidden_dim)
self.pool3 = SelfAttentionPooling(hidden_dim, 0.5)
#把最后的几个全连接层和激活函数 封装在一起
self.mlp = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, num_classes))
def forward(self, adjacency, input_feature, graph_indicator):
#adjacency 拉普拉斯矩阵 (N,N)N为所有节点数
#input_feature 所有节点的特征矩阵 (N,input_dim)
#(N,input_dim) -> (N,hidden_dim)
gcn1 = F.relu(self.gcn1(adjacency, input_feature))
#gcn1 (N,hidden_dim) adjacency(N,N) graph_indicator(N,)每个节点属于哪个图
#pool1 (N',hidden_dim) N'保留的节点数
#pool1_graph_indicator (N',)保留的节点属于哪个图
#pool1_adjacency 保留节点的邻接矩阵(归一化)
pool1, pool1_graph_indicator, pool1_adjacency = \
self.pool1(adjacency, gcn1, graph_indicator)
#global_pool1 (G,hidden_dim*2) G为图数
global_pool1 = torch.cat(
[global_avg_pool(pool1, pool1_graph_indicator),
global_max_pool(pool1, pool1_graph_indicator)],
dim=1)
#pool1_adjacency (N',N') N'保留的节点数 新的图结构对应的拉普拉斯矩阵
#pool1 (N',hiddem_dim)
#(N',hiddem_dim) -> (N',hiddem_dim)
gcn2 = F.relu(self.gcn2(pool1_adjacency, pool1))
#gcn2 (N',hiddem_dim) pool1_adjacency(N',N') pool1_graph_indicator(N',)保留的每个节点属于哪个图
#pool2 (N'',hidden_dim) N''保留的节点数
#pool2_graph_indicator (N'',)保留的节点属于哪个图
#pool2_adjacency 保留节点的邻接矩阵(归一化)
pool2, pool2_graph_indicator, pool2_adjacency = \
self.pool2(pool1_adjacency, gcn2, pool1_graph_indicator)
#global_pool2 (G,hidden_dim*2) G为图数
global_pool2 = torch.cat(
[global_avg_pool(pool2, pool2_graph_indicator),
global_max_pool(pool2, pool2_graph_indicator)],
dim=1)
#pool2_adjacency (N'',N'') N''保留的节点数 新的图结构对应的拉普拉斯矩阵
#pool2 (N'',hiddem_dim)
#(N'',hiddem_dim) -> (N'',hiddem_dim)
gcn3 = F.relu(self.gcn3(pool2_adjacency, pool2))
#gcn3 (N'',hiddem_dim) pool2_adjacency(N'',N'') pool2_graph_indicator(N'',)保留的每个节点属于哪个图
#pool3 (N''',hidden_dim) N'''保留的节点数
#pool3_graph_indicator (N''',)保留的节点属于哪个图
#pool3_adjacency 保留节点的邻接矩阵(归一化)
pool3, pool3_graph_indicator, pool3_adjacency = \
self.pool3(pool2_adjacency, gcn3, pool2_graph_indicator)
#global_pool3 (G,hidden_dim*2) G为图数
global_pool3 = torch.cat(
[global_avg_pool(pool3, pool3_graph_indicator),
global_max_pool(pool3, pool3_graph_indicator)],
dim=1)
#readout (G,hidden_dim*2)
readout = global_pool1 + global_pool2 + global_pool3
#logits (G,num_classes=2)
logits = self.mlp(readout)
return logits
6. 程序流程
- 加载数据
dataset = DDDataset()
- 数据准备
# 模型输入数据准备
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#所有图对应的大邻接矩阵
adjacency = dataset.sparse_adjacency
#归一化、引入自连接的拉普拉斯矩阵
normalize_adjacency = normalization(adjacency).to(DEVICE)
#所有节点的特征标签
node_labels = tensor_from_numpy(dataset.node_labels, DEVICE)
#把特征标签转换为one-hot特征向量
node_features = F.one_hot(node_labels, node_labels.max().item() + 1).float()
#每个节点对应哪个图
graph_indicator = tensor_from_numpy(dataset.graph_indicator, DEVICE)
#每个图的类别标签
graph_labels = tensor_from_numpy(dataset.graph_labels, DEVICE)
#训练集对应的图索引
train_index = tensor_from_numpy(dataset.train_index, DEVICE)
#测试集对应的图索引
test_index = tensor_from_numpy(dataset.test_index, DEVICE)
#训练集和测试集中的图对应的类别标签
train_label = tensor_from_numpy(dataset.train_label, DEVICE)
test_label = tensor_from_numpy(dataset.test_label, DEVICE)
- 超参数设置
# 超参数设置
INPUT_DIM = node_features.size(1) #特征向量维度
NUM_CLASSES = 2
EPOCHS = 200 # @param {type: "integer"}
HIDDEN_DIM = 32# @param {type: "integer"}
LEARNING_RATE = 0.01 # @param
WEIGHT_DECAY = 0.0001 # @param
- 模型初始化
# 模型初始化
model_g = ModelA(INPUT_DIM, HIDDEN_DIM, NUM_CLASSES).to(DEVICE)
model_h = ModelB(INPUT_DIM, HIDDEN_DIM, NUM_CLASSES).to(DEVICE)
model = model_h #@param ['model_g', 'model_h'] {type: 'raw'}
- 训练
criterion = nn.CrossEntropyLoss().to(DEVICE) #交叉熵损失函数
#Adam优化器
optimizer = optim.Adam(model.parameters(), LEARNING_RATE, weight_decay=WEIGHT_DECAY)
model.train() #训练模式
for epoch in range(EPOCHS):
logits = model(normalize_adjacency, node_features, graph_indicator) #对所有数据(图)前向传播 得到输出
loss = criterion(logits[train_index], train_label) # 只对训练的数据计算损失值
optimizer.zero_grad()
loss.backward() # 反向传播计算参数的梯度
optimizer.step() # 使用优化方法进行梯度更新
#训练集准确率
train_acc = torch.eq(
logits[train_index].max(1)[1], train_label).float().mean()
print("Epoch {:03d}: Loss {:.4f}, TrainAcc {:.4}".format(
epoch, loss.item(), train_acc.item()))
- 测试
model.eval() #测试模式
with torch.no_grad(): #关闭求导
logits = model(normalize_adjacency, node_features, graph_indicator)#所有数据前向传播
test_logits = logits[test_index] #取出测试数据对应的输出
#计算测试数据准确率
test_acc = torch.eq(
test_logits.max(1)[1], test_label
).float().mean()
print(test_acc.item())
来源:CSDN
作者:CoreJT
链接:https://blog.csdn.net/sdu_hao/article/details/104154115