使用Pytorch Geometric 进行链接预测代码示例
对于一个科技周边开发者来说,牢固扎实的基础是十分重要的,golang学习网就来带大家一点点的掌握基础知识点。今天本篇文章带大家了解《使用Pytorch Geometric 进行链接预测代码示例》,主要介绍了,希望对大家的知识积累有所帮助,快点收藏起来吧,否则需要时就找不到了!
PyTorch Geometric (PyG)是构建图神经网络模型和实验各种图卷积的主要工具。在本文中我们将通过链接预测来对其进行介绍。
链接预测答了一个问题:哪两个节点应该相互链接?我们将通过执行“转换分割”,为建模准备数据。为批处理准备专用的图数据加载器。在Torch Geometric中构建一个模型,使用PyTorch Lightning进行训练,并检查模型的性能。
库准备
- Torch 这个就不用多介绍了
- Torch Geometric图形神经网络的主要库,也是本文介绍的重点
- PyTorch Lightning 用于训练、调优和验证模型。它简化了训练的操作
- Sklearn Metrics和Torchmetrics 用于检查模型的性能。
- PyTorch Geometric有一些特定的依赖关系,如果你安装有问题,请参阅其官方文档。
数据准备
我们将使用Cora ML引文数据集。数据集可以通过Torch Geometric访问。
data = tg.datasets.CitationFull(root="data", name="Cora_ML")
默认情况下,Torch Geometric数据集可以返回多个图形。我们看看单个图是什么样子的
data[0] > Data(x=[2995, 2879], edge_index=[2, 16316], y=[2995])
这里的 X是节点的特征。edge_index是2 x (n条边)矩阵(第一维= 2,被解释为:第0行-源节点/“发送方”,第1行-目标节点/“接收方”)。
链接拆分
我们将从拆分数据集中的链接开始。使用20%的图链接作为验证集,10%作为测试集。这里不会向训练数据集中添加负样本,因为这样的负链接将由批处理数据加载器实时创建。
一般来说,负采样会创建“假”样本(在我们的例子中是节点之间的链接),因此模型学习如何区分真实和虚假的链接。负抽样基于抽样的理论和数学,具有一些很好的统计性质。
首先:让我们创建一个链接拆分对象。
link_splitter = tg.transforms.RandomLinkSplit(num_val=0.2, num_test=0.1, add_negative_train_samples=False,disjoint_train_ratio=0.8)
disjoint_train_ratio调节在“监督”阶段将使用多少条边作为训练信息。剩余的边将用于消息传递(网络中的信息传输阶段)。
图神经网络中至少有两种分割边的方法:归纳分割和传导分割。转换方法假设GNN需要从图结构中学习结构模式。在归纳设置中,可以使用节点/边缘标签进行学习。本文最后有两篇论文详细讨论了这些概念,并进行了额外的形式化:([1],[3])。
train_g, val_g, test_g = link_splitter(data[0]) > Data(x=[2995, 2879], edge_index=[2, 2285], y=[2995], edge_label=[9137], edge_label_index=[2, 9137])
在这个操作之后,我们有了一些新的属性:
edge_label :描述边缘是否为真/假。这是我们想要预测的。
edge_label_index 是一个2 x NUM EDGES矩阵,用于存储节点链接。
让我们看看样本的分布
th.unique(train_g.edge_label, return_counts=True) > (tensor([1.]), tensor([9137])) th.unique(val_g.edge_label, return_counts=True) > (tensor([0., 1.]), tensor([3263, 3263])) th.unique(val_g.edge_label, return_counts=True) > (tensor([0., 1.]), tensor([3263, 3263]))
对于训练数据没有负边(我们将训练时创建它们),对于val/测试集——已经以50:50的比例有了一些“假”链接。
模型
现在我们可以在使用GNN进行模型的构建了一个
class GNN(nn.Module):
def __init__(self, dim_in: int, conv_sizes: Tuple[int, ...], act_f: nn.Module = th.relu, dropout: float = 0.1,*args, **kwargs):super().__init__()self.dim_in = dim_inself.dim_out = conv_sizes[-1]self.dropout = dropoutself.act_f = act_flast_in = dim_inlayers = [] # Here we build subsequent graph convolutions.for conv_sz in conv_sizes:# Single graph convolution layerconv = tgnn.SAGEConv(in_channels=last_in, out_channels=conv_sz, *args, **kwargs)last_in = conv_szlayers.append(conv)self.layers = nn.ModuleList(layers) def forward(self, x: th.Tensor, edge_index: th.Tensor) -> th.Tensor:h = x# For every graph convolution in the network...for conv in self.layers:# ... perform node embedding via message passingh = conv(h, edge_index)h = self.act_f(h)if self.dropout:h = nn.functional.dropout(h, p=self.dropout, training=self.training)return h
这个模型中值得注意的部分是一组图卷积——在我们的例子中是SAGEConv。SAGE卷积的正式定义为:
å¾ç
v是当前节点,节点v的N(v)个邻居。要了解更多关于这种卷积类型的信息,请查看GraphSAGE[1]的原始论文
让我们检查一下模型是否可以使用准备好的数据进行预测。这里PyG模型的输入是节点特征X的矩阵和定义edge_index的链接。
gnn = GNN(train_g.x.size()[1], conv_sizes=[512, 256, 128]) with th.no_grad():out = gnn(train_g.x, train_g.edge_index) out > tensor([[0.0000, 0.0000, 0.0051, ..., 0.0997, 0.0000, 0.0000],[0.0107, 0.0000, 0.0576, ..., 0.0651, 0.0000, 0.0000],[0.0000, 0.0000, 0.0102, ..., 0.0973, 0.0000, 0.0000],...,[0.0000, 0.0000, 0.0549, ..., 0.0671, 0.0000, 0.0000],[0.0000, 0.0000, 0.0166, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0034, ..., 0.1111, 0.0000, 0.0000]])
我们模型的输出是一个维度为:N个节点x嵌入大小的节点嵌入矩阵。
PyTorch Lightning
PyTorch Lightning主要用作训练,但是这里我们在GNN的输出后面增加了一个Linear层做为预测是否链接的输出头。
class LinkPredModel(pl.LightningModule):
def __init__(self,dim_in: int,conv_sizes: Tuple[int, ...], act_f: nn.Module = th.relu, dropout: float = 0.1,lr: float = 0.01,*args, **kwargs):super().__init__() # Our inner GNN modelself.gnn = GNN(dim_in, conv_sizes=conv_sizes, act_f=act_f, dropout=dropout) # Final prediction model on links.self.lin_pred = nn.Linear(self.gnn.dim_out, 1)self.lr = lr def forward(self, x: th.Tensor, edge_index: th.Tensor) -> th.Tensor:# Step 1: make node embeddings using GNN.h = self.gnn(x, edge_index) # Take source nodes embeddings- sendersh_src = h[edge_index[0, :]]# Take target node embeddings - receiversh_dst = h[edge_index[1, :]] # Calculate the product between themsrc_dst_mult = h_src * h_dst# Apply non-linearityout = self.lin_pred(src_dst_mult)return out def _step(self, batch: th.Tensor, phase: str='train') -> th.Tensor:yhat_edge = self(batch.x, batch.edge_label_index).squeeze()y = batch.edge_labelloss = nn.functional.binary_cross_entropy_with_logits(input=yhat_edge, target=y)f1 = tm.functional.f1_score(preds=yhat_edge, target=y, task='binary')prec = tm.functional.precision(preds=yhat_edge, target=y, task='binary')recall = tm.functional.recall(preds=yhat_edge, target=y, task='binary') # Watch for logging here - we need to provide batch_size, as (at the time of this implementation)# PL cannot understand the batch size.self.log(f"{phase}_f1", f1, batch_size=batch.edge_label_index.shape[1])self.log(f"{phase}_loss", loss, batch_size=batch.edge_label_index.shape[1])self.log(f"{phase}_precision", prec, batch_size=batch.edge_label_index.shape[1])self.log(f"{phase}_recall", recall, batch_size=batch.edge_label_index.shape[1])return loss def training_step(self, batch, batch_idx):return self._step(batch) def validation_step(self, batch, batch_idx):return self._step(batch, "val") def test_step(self, batch, batch_idx):return self._step(batch, "test") def predict_step(self, batch):x, edge_index = batchreturn self(x, edge_index) def configure_optimizers(self):return th.optim.Adam(self.parameters(), lr=self.lr)
PyTorch Lightning的作用是帮我们简化了训练的步骤,我们只需要配置一些函数即可,我们可以使用以下命令测试模型是否可用
model = LinkPredModel(val_g.x.size()[1], conv_sizes=[512, 256, 128]) with th.no_grad():out = model.predict_step((val_g.x, val_g.edge_label_index))
训练
对于训练的步骤,需要特殊处理的是数据加载器。
图数据需要特殊处理——尤其是链接预测。PyG有一些专门的数据加载器类,它们负责正确地生成批处理。我们将使用:tg.loader.LinkNeighborLoader,它接受以下输入:
要批量加载的数据(图)。num_neighbors 每个节点在一次“跳”期间加载的最大邻居数量。指定邻居数目的列表1 - 2 - 3 -…-K。对于非常大的图形特别有用。
edge_label_index 哪个属性已经指示了真/假链接。
neg_sampling_ratio -负样本与真实样本的比例。
train_loader = tg.loader.LinkNeighborLoader(train_g,num_neighbors=[-1, 10, 5],batch_size=128,edge_label_index=train_g.edge_label_index, # "on the fly" negative sampling creation for batchneg_sampling_ratio=0.5 ) val_loader = tg.loader.LinkNeighborLoader(val_g,num_neighbors=[-1, 10, 5],batch_size=128,edge_label_index=val_g.edge_label_index,edge_label=val_g.edge_label, # negative samples for val set are done already as ground-truthneg_sampling_ratio=0.0 ) test_loader = tg.loader.LinkNeighborLoader(test_g,num_neighbors=[-1, 10, 5],batch_size=128,edge_label_index=test_g.edge_label_index,edge_label=test_g.edge_label, # negative samples for test set are done already as ground-truthneg_sampling_ratio=0.0 )
下面就是训练模型
model = LinkPredModel(val_g.x.size()[1], conv_sizes=[512, 256, 128]) trainer = pl.Trainer(max_epochs=20, log_every_n_steps=5) # Validate before training - we will see results of untrained model. trainer.validate(model, val_loader) # Train the model trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
试验数据核对,查看分类报告和ROC曲线。
with th.no_grad():yhat_test_proba = th.sigmoid(model(test_g.x, test_g.edge_label_index)).squeeze()yhat_test_cls = yhat_test_proba >= 0.5 print(classification_report(y_true=test_g.edge_label, y_pred=yhat_test_cls))
结果看起来还不错:
precision recall f1-score support0.0 0.68 0.70 0.69 16311.0 0.69 0.66 0.68 1631accuracy 0.68 3262macro avg 0.68 0.68 0.68 3262
ROC曲线也不错
我们训练的模型并不特别复杂,也没有经过精心调整,但它完成了工作。当然这只是一个为了演示使用的小型数据集。
总结
图神经网络尽管看起来很复杂,但是PyTorch Geometric为我们提供了一个很好的解决方案。我们可以直接使用其中内置的模型实现,这方便了我们使用和简化了入门的门槛。
本文代码:https://github.com/maddataanalyst/blogposts_code/blob/main/graph_nns_series/pyg_pyl_perfect_match/pytorch-geometric-lightning-perfect-match.ipynb
终于介绍完啦!小伙伴们,这篇关于《使用Pytorch Geometric 进行链接预测代码示例》的介绍应该让你收获多多了吧!欢迎大家收藏或分享给更多需要学习的朋友吧~golang学习网公众号也会发布科技周边相关知识,快来关注吧!

- 上一篇
- Gartner:2024年优秀战略技术趋势

- 下一篇
- 生成式人工智能在软件开发过程现代化中的作用
-
- 科技周边 · 人工智能 | 5小时前 | 智能辅助驾驶 firefly萤火虫 地平线征程 高端智能电动小车 全球市场
- 地平线与蔚来合作车型firefly萤火虫正式上市
- 245浏览 收藏
-
- 科技周边 · 人工智能 | 6小时前 |
- 即梦ai添加时间戳教程即梦ai日期水印设置攻略
- 369浏览 收藏
-
- 科技周边 · 人工智能 | 6小时前 |
- 小米汽车上险量下降:YU7投产惹的祸
- 499浏览 收藏
-
- 科技周边 · 人工智能 | 15小时前 |
- MistralAI发布多模态模型MistralMedium3
- 446浏览 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 542次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 508次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 497次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 484次学习
-
- PPTFake答辩PPT生成器
- PPTFake答辩PPT生成器,专为答辩准备设计,极致高效生成PPT与自述稿。智能解析内容,提供多样模板,数据可视化,贴心配套服务,灵活自主编辑,降低制作门槛,适用于各类答辩场景。
- 14次使用
-
- Lovart
- SEO摘要探索Lovart AI,这款专注于设计领域的AI智能体,通过多模态模型集成和智能任务拆解,实现全链路设计自动化。无论是品牌全案设计、广告与视频制作,还是文创内容创作,Lovart AI都能满足您的需求,提升设计效率,降低成本。
- 14次使用
-
- 美图AI抠图
- 美图AI抠图,依托CVPR 2024竞赛亚军技术,提供顶尖的图像处理解决方案。适用于证件照、商品、毛发等多场景,支持批量处理,3秒出图,零PS基础也能轻松操作,满足个人与商业需求。
- 27次使用
-
- PetGPT
- SEO摘要PetGPT 是一款基于 Python 和 PyQt 开发的智能桌面宠物程序,集成了 OpenAI 的 GPT 模型,提供上下文感知对话和主动聊天功能。用户可高度自定义宠物的外观和行为,支持插件热更新和二次开发。适用于需要陪伴和效率辅助的办公族、学生及 AI 技术爱好者。
- 26次使用
-
- 可图AI图片生成
- 探索快手旗下可灵AI2.0发布的可图AI2.0图像生成大模型,体验从文本生成图像、图像编辑到风格转绘的全链路创作。了解其技术突破、功能创新及在广告、影视、非遗等领域的应用,领先于Midjourney、DALL-E等竞品。
- 53次使用
-
- GPT-4王者加冕!读图做题性能炸天,凭自己就能考上斯坦福
- 2023-04-25 501浏览
-
- 单块V100训练模型提速72倍!尤洋团队新成果获AAAI 2023杰出论文奖
- 2023-04-24 501浏览
-
- ChatGPT 真的会接管世界吗?
- 2023-04-13 501浏览
-
- VR的终极形态是「假眼」?Neuralink前联合创始人掏出新产品:科学之眼!
- 2023-04-30 501浏览
-
- 实现实时制造可视性优势有哪些?
- 2023-04-15 501浏览