VisionTransformer多标签分类详解
本文深入解析了如何将Vision Transformer(ViT)应用于多标签分类任务,有别于传统的单标签多分类,多标签分类允许图像同时属于多个类别。文章强调了核心概念的转变,特别是损失函数的选择,指出`CrossEntropyLoss`不再适用,并详细介绍了`BCEWithLogitsLoss`的优势与使用方法,包括模型输出和标签格式的要求。此外,文章还全面阐述了多标签分类中常用的评估指标,如精确率、召回率、F1分数和mAP,并提供了使用`torchmetrics`库的代码示例,助力读者掌握ViT在多标签环境下的训练与评估,从而提升模型性能。

本文旨在详细阐述如何将Vision Transformer(ViT)从单标签多分类任务转换为多标签分类任务,并重点介绍损失函数的选择与评估策略的调整。我们将探讨为何`CrossEntropyLoss`不适用于多标签场景,并深入讲解`BCEWithLogitsLoss`的使用方法,包括标签格式要求。此外,文章还将介绍多标签分类任务中常用的评估指标,如精确率、召回率、F1分数和mAP,并提供代码示例,确保读者能够顺利实现ViT在多标签环境下的训练与评估。
从单标签到多标签:核心概念转变
在深度学习的图像分类任务中,单标签多分类(Single-label Multi-class Classification)是指每张图片只属于一个类别,模型需要从多个互斥的类别中预测出唯一正确的那个。而多标签分类(Multi-label Classification)则允许每张图片同时属于一个或多个类别,模型需要为每个类别独立地判断其是否存在于图片中。
这种任务性质的转变,要求我们对模型的输出层、损失函数以及评估策略进行相应的调整。对于Vision Transformer(ViT)而言,其特征提取部分通常保持不变,但最终的分类头和训练流程需要进行适配。
损失函数的选择与实现
在单标签多分类任务中,我们通常使用torch.nn.CrossEntropyLoss作为损失函数。它内部包含了Softmax激活函数和负对数似然损失,期望模型的输出是每个类别的Logits,并且这些Logits经过Softmax后会转化为概率分布,所有类别的概率和为1。
然而,在多标签分类任务中,由于图片可能同时属于多个类别,各个类别之间不再是互斥关系。因此,CrossEntropyLoss不再适用,因为它强制了类别之间的互斥性。
推荐的损失函数:torch.nn.BCEWithLogitsLoss
对于多标签分类任务,最常用且推荐的损失函数是torch.nn.BCEWithLogitsLoss。这个损失函数结合了Sigmoid激活函数和二元交叉熵损失(Binary Cross Entropy Loss)。
其主要优点包括:
- 独立处理每个类别: BCEWithLogitsLoss会对模型输出的每个Logit独立地计算二元交叉熵,这与多标签任务中各类别独立存在的特性相符。
- 数值稳定性: 它直接作用于模型的原始Logits输出,内部处理了Sigmoid激活,避免了先手动计算Sigmoid再计算交叉熵可能导致的数值溢出或下溢问题。
使用BCEWithLogitsLoss的注意事项:
- 模型输出: 模型的最终输出层应该是一个全连接层,输出维度等于类别的总数,且不应在其后接Softmax激活函数。例如,如果你的模型有7个类别,最终输出应为形状(batch_size, 7)的Logits张量。
- 标签格式: 标签(target)必须是与模型输出Logits形状相同的浮点型(torch.float)张量。它通常是一个“多热编码”(multi-hot encoding)向量,其中1表示该类别存在,0表示该类别不存在。例如,[0, 1, 1, 0, 0, 1, 0]表示第二个、第三个和第六个类别存在。
代码示例:替换损失函数
假设我们有一个ViT模型,其输出为pred(Logits),标签为labels(多热编码)。
import torch
import torch.nn as nn
# 假设模型输出的Logits,形状为 (batch_size, num_classes)
# 这里以 batch_size = 2, num_classes = 7 为例
logits = torch.randn(2, 7) # 模拟模型输出的原始Logits
# 假设对应的多标签,形状也为 (batch_size, num_classes)
# 注意:标签必须是浮点型 (torch.float)
labels = torch.tensor([
[0, 1, 1, 0, 0, 1, 0], # 第一个样本的标签
[1, 0, 1, 1, 0, 0, 0] # 第二个样本的标签
]).float()
# 实例化 BCEWithLogitsLoss
loss_function = nn.BCEWithLogitsLoss()
# 计算损失
loss = loss_function(logits, labels)
print(f"Logits:\n{logits}")
print(f"Labels:\n{labels}")
print(f"Calculated Loss: {loss.item()}")
# 原始训练循环中的应用
# pred = model(images.to(device))
# loss = loss_function(pred, labels.to(device))
# loss.backward()
# optimizer.step()多标签分类的评估策略
在单标签分类中,准确率(Accuracy)是最常用的评估指标。然而,在多标签分类中,仅仅计算准确率是不足够的,甚至可能产生误导。例如,如果一个模型总是预测所有类别都不存在,而实际只有少数类别存在,那么它的准确率可能很高(因为它正确预测了大量不存在的类别),但它对存在类别的识别能力却很差。
因此,我们需要采用更全面的指标来评估多标签分类模型的性能。
1. 从Logits到预测结果
在计算评估指标之前,我们需要将模型的Logits输出转换为具体的类别预测。这通常通过对Logits应用Sigmoid函数,然后设定一个阈值(例如0.5)来完成。
# 假设 logits 是模型输出的Logits
# 例如:logits = torch.randn(batch_size, num_classes)
# 1. 应用Sigmoid函数将Logits转换为概率
probabilities = torch.sigmoid(logits)
# 2. 设定阈值,将概率转换为二元预测 (0或1)
threshold = 0.5
predictions = (probabilities > threshold).float()
print(f"Probabilities:\n{probabilities}")
print(f"Predictions (threshold={threshold}):\n{predictions}")2. 常用评估指标
以下是多标签分类中常用的评估指标:
精确率(Precision)、召回率(Recall)、F1分数(F1-score):
- 精确率: 预测为正例的样本中,有多少是真正的正例。
- 召回率: 实际为正例的样本中,有多少被模型预测为正例。
- F1分数: 精确率和召回率的调和平均值,综合衡量模型的性能。
- 这些指标可以针对每个类别独立计算(Per-class),也可以通过微平均(Micro-average)或宏平均(Macro-average)来汇总所有类别的结果。
- Micro-average: 汇总所有类别的TP、FP、FN后再计算总体的Precision、Recall、F1。它更侧重于样本级别的性能,受样本数量较多的类别影响较大。
- Macro-average: 先计算每个类别的Precision、Recall、F1,然后取这些值的平均。它给予每个类别相同的权重,不受类别样本数量不平衡的影响。
平均精确率(Average Precision, AP)与平均精确率均值(mean Average Precision, mAP):
- AP: 衡量单个类别在不同召回率下的精确率表现,通常通过计算PR曲线下面积获得。AP值越高,说明模型在该类别上的性能越好。
- mAP: 对所有类别的AP值取平均,是衡量多标签分类模型整体性能的一个非常重要的指标,尤其在目标检测等领域广泛使用。
Jaccard Index (IoU) / Jaccard Similarity Score:
- 衡量预测集合与真实标签集合的相似度,计算公式为交集大小除以并集大小。对于多标签分类,可以计算每个样本的预测标签集合与真实标签集合的Jaccard相似度,然后取平均。
Hamming Loss:
- 衡量预测结果与真实标签不一致的标签比例。Hamming Loss越低越好。
3. 使用torchmetrics或scikit-learn进行评估
在PyTorch生态中,torchmetrics库提供了丰富的多标签评估指标。scikit-learn也是一个非常强大的工具,可以在CPU上方便地进行评估。
torchmetrics示例 (推荐用于PyTorch训练循环中):
import torch
from torchmetrics.classification import MultilabelF1Score, MultilabelAveragePrecision
# 假设真实标签和预测概率
# num_classes = 7
num_labels = 7
num_samples = 10
target_labels = torch.randint(0, 2, (num_samples, num_labels)).float() # 真实标签 (0或1)
predicted_probs = torch.rand(num_samples, num_labels) # 模型输出的概率 (经过Sigmoid)
# 或者直接使用Logits,让metrics内部处理Sigmoid
predicted_logits = torch.randn(num_samples, num_labels)
# 实例化F1分数,可以指定 average 方式 (e.g., 'micro', 'macro', 'weighted', 'none')
# MultilabelF1Score 期望输入是 (preds, target)
# preds: 概率 (float) 或 原始logits (float)
# target: 真实标签 (int 或 float, 0/1)
f1_score_micro = MultilabelF1Score(num_labels=num_labels, average='micro', validate_args=False)
f1_score_macro = MultilabelF1Score(num_labels=num_labels, average='macro', validate_args=False)
# 计算F1分数
# 注意:MultilabelF1Score 可以直接接收概率或logits,但通常建议给概率
f1_micro_val = f1_score_micro(predicted_probs, target_labels.long()) # target_labels需要是long类型对于F1Score
f1_macro_val = f1_score_macro(predicted_probs, target_labels.long())
print(f"Micro F1 Score: {f1_micro_val.item()}")
print(f"Macro F1 Score: {f1_macro_val.item()}")
# 实例化mAP
# MultilabelAveragePrecision 期望输入是 (preds, target)
# preds: 概率 (float)
# target: 真实标签 (int 或 float, 0/1)
map_metric = MultilabelAveragePrecision(num_labels=num_labels, validate_args=False)
# 计算mAP
map_val = map_metric(predicted_probs, target_labels.long()) # target_labels需要是long类型对于mAP
print(f"mAP: {map_val.item()}")
# 如果输入是logits,可以这样处理 (MultilabelF1Score 和 MultilabelAveragePrecision 默认不带sigmoid,需要手动处理或确保其内部处理了)
# 对于MultilabelF1Score和MultilabelAveragePrecision,当输入是概率时,通常需要手动将target转换为long
# 如果输入是logits,则需要确保metrics内部会执行sigmoid
# 更好的做法是,统一将模型输出转换为概率再传入metrics
probs_from_logits = torch.sigmoid(predicted_logits)
f1_micro_val_logits = f1_score_micro(probs_from_logits, target_labels.long())
map_val_logits = map_metric(probs_from_logits, target_labels.long())
print(f"Micro F1 Score (from logits): {f1_micro_val_logits.item()}")
print(f"mAP (from logits): {map_val_logits.item()}")总结与注意事项
将ViT从单标签多分类转换为多标签分类,关键在于以下几点:
- 模型输出层: 确保模型的最终全连接层输出与类别数量相等的Logits,并且不带Softmax激活。
- 损失函数: 使用torch.nn.BCEWithLogitsLoss作为损失函数,它能独立处理每个类别的预测。
- 标签格式: 真实标签应为多热编码的浮点型张量,形状与模型输出的Logits相同。
- 评估指标: 采用适合多标签任务的评估指标,如Micro/Macro F1分数、mAP、Jaccard Index等,并结合torchmetrics或scikit-learn等库进行高效计算。
- 阈值选择: 在将概率转换为二元预测时,阈值的选择对最终的精确率和召回率有显著影响,可能需要通过验证集进行调优。
- 类别不平衡: 在多标签任务中,类别不平衡问题可能更复杂(例如,某些标签总是同时出现,某些标签非常稀有)。可以考虑使用加权BCE损失、Focal Loss或采样策略来缓解。
通过以上调整,您的Vision Transformer模型将能够有效地处理多标签图像分类任务。
好了,本文到此结束,带大家了解了《VisionTransformer多标签分类详解》,希望本文对你有所帮助!关注golang学习网公众号,给大家分享更多文章知识!
PHP调试指南:Xdebug安装与断点设置
- 上一篇
- PHP调试指南:Xdebug安装与断点设置
- 下一篇
- 正规机票预订官网App推荐清单
-
- 文章 · python教程 | 38分钟前 |
- Pandas读取Django表格:协议关键作用
- 401浏览 收藏
-
- 文章 · python教程 | 45分钟前 | 身份验证 断点续传 requests库 PythonAPI下载 urllib库
- Python调用API下载文件方法
- 227浏览 收藏
-
- 文章 · python教程 | 47分钟前 |
- Windows7安装RtMidi失败解决办法
- 400浏览 收藏
-
- 文章 · python教程 | 53分钟前 |
- Python异步任务优化技巧分享
- 327浏览 收藏
-
- 文章 · python教程 | 1小时前 |
- PyCharm图形界面显示问题解决方法
- 124浏览 收藏
-
- 文章 · python教程 | 2小时前 |
- Python自定义异常类怎么创建
- 450浏览 收藏
-
- 文章 · python教程 | 2小时前 |
- Python抓取赛狗数据:指定日期赛道API教程
- 347浏览 收藏
-
- 文章 · python教程 | 2小时前 |
- Python3中datetime常用转换方式有哪些?
- 464浏览 收藏
-
- 文章 · python教程 | 2小时前 |
- PyCharm无解释器问题解决方法
- 290浏览 收藏
-
- 文章 · python教程 | 3小时前 | 性能优化 Python正则表达式 re模块 匹配结果 正则模式
- Python正则表达式入门与使用技巧
- 112浏览 收藏
-
- 文章 · python教程 | 3小时前 |
- MacPython兼容LibreSSL的解决方法
- 324浏览 收藏
-
- 文章 · python教程 | 3小时前 |
- OdooQWeb浮点转整数技巧
- 429浏览 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 543次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 516次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 500次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 485次学习
-
- ChatExcel酷表
- ChatExcel酷表是由北京大学团队打造的Excel聊天机器人,用自然语言操控表格,简化数据处理,告别繁琐操作,提升工作效率!适用于学生、上班族及政府人员。
- 3178次使用
-
- Any绘本
- 探索Any绘本(anypicturebook.com/zh),一款开源免费的AI绘本创作工具,基于Google Gemini与Flux AI模型,让您轻松创作个性化绘本。适用于家庭、教育、创作等多种场景,零门槛,高自由度,技术透明,本地可控。
- 3390次使用
-
- 可赞AI
- 可赞AI,AI驱动的办公可视化智能工具,助您轻松实现文本与可视化元素高效转化。无论是智能文档生成、多格式文本解析,还是一键生成专业图表、脑图、知识卡片,可赞AI都能让信息处理更清晰高效。覆盖数据汇报、会议纪要、内容营销等全场景,大幅提升办公效率,降低专业门槛,是您提升工作效率的得力助手。
- 3418次使用
-
- 星月写作
- 星月写作是国内首款聚焦中文网络小说创作的AI辅助工具,解决网文作者从构思到变现的全流程痛点。AI扫榜、专属模板、全链路适配,助力新人快速上手,资深作者效率倍增。
- 4523次使用
-
- MagicLight
- MagicLight.ai是全球首款叙事驱动型AI动画视频创作平台,专注于解决从故事想法到完整动画的全流程痛点。它通过自研AI模型,保障角色、风格、场景高度一致性,让零动画经验者也能高效产出专业级叙事内容。广泛适用于独立创作者、动画工作室、教育机构及企业营销,助您轻松实现创意落地与商业化。
- 3797次使用
-
- Flask框架安装技巧:让你的开发更高效
- 2024-01-03 501浏览
-
- Django框架中的并发处理技巧
- 2024-01-22 501浏览
-
- 提升Python包下载速度的方法——正确配置pip的国内源
- 2024-01-17 501浏览
-
- Python与C++:哪个编程语言更适合初学者?
- 2024-03-25 501浏览
-
- 品牌建设技巧
- 2024-04-06 501浏览

