当前位置:首页 > 文章列表 > 文章 > python教程 > VisionTransformer多标签分类详解

VisionTransformer多标签分类详解

2025-11-09 08:00:37 0浏览 收藏

本文深入解析了如何将Vision Transformer(ViT)应用于多标签分类任务,有别于传统的单标签多分类,多标签分类允许图像同时属于多个类别。文章强调了核心概念的转变,特别是损失函数的选择,指出`CrossEntropyLoss`不再适用,并详细介绍了`BCEWithLogitsLoss`的优势与使用方法,包括模型输出和标签格式的要求。此外,文章还全面阐述了多标签分类中常用的评估指标,如精确率、召回率、F1分数和mAP,并提供了使用`torchmetrics`库的代码示例,助力读者掌握ViT在多标签环境下的训练与评估,从而提升模型性能。

Vision Transformer多标签分类:损失函数与评估策略深度解析

本文旨在详细阐述如何将Vision Transformer(ViT)从单标签多分类任务转换为多标签分类任务,并重点介绍损失函数的选择与评估策略的调整。我们将探讨为何`CrossEntropyLoss`不适用于多标签场景,并深入讲解`BCEWithLogitsLoss`的使用方法,包括标签格式要求。此外,文章还将介绍多标签分类任务中常用的评估指标,如精确率、召回率、F1分数和mAP,并提供代码示例,确保读者能够顺利实现ViT在多标签环境下的训练与评估。

从单标签到多标签:核心概念转变

在深度学习的图像分类任务中,单标签多分类(Single-label Multi-class Classification)是指每张图片只属于一个类别,模型需要从多个互斥的类别中预测出唯一正确的那个。而多标签分类(Multi-label Classification)则允许每张图片同时属于一个或多个类别,模型需要为每个类别独立地判断其是否存在于图片中。

这种任务性质的转变,要求我们对模型的输出层、损失函数以及评估策略进行相应的调整。对于Vision Transformer(ViT)而言,其特征提取部分通常保持不变,但最终的分类头和训练流程需要进行适配。

损失函数的选择与实现

在单标签多分类任务中,我们通常使用torch.nn.CrossEntropyLoss作为损失函数。它内部包含了Softmax激活函数和负对数似然损失,期望模型的输出是每个类别的Logits,并且这些Logits经过Softmax后会转化为概率分布,所有类别的概率和为1。

然而,在多标签分类任务中,由于图片可能同时属于多个类别,各个类别之间不再是互斥关系。因此,CrossEntropyLoss不再适用,因为它强制了类别之间的互斥性。

推荐的损失函数:torch.nn.BCEWithLogitsLoss

对于多标签分类任务,最常用且推荐的损失函数是torch.nn.BCEWithLogitsLoss。这个损失函数结合了Sigmoid激活函数和二元交叉熵损失(Binary Cross Entropy Loss)。

其主要优点包括:

  1. 独立处理每个类别: BCEWithLogitsLoss会对模型输出的每个Logit独立地计算二元交叉熵,这与多标签任务中各类别独立存在的特性相符。
  2. 数值稳定性: 它直接作用于模型的原始Logits输出,内部处理了Sigmoid激活,避免了先手动计算Sigmoid再计算交叉熵可能导致的数值溢出或下溢问题。

使用BCEWithLogitsLoss的注意事项:

  1. 模型输出: 模型的最终输出层应该是一个全连接层,输出维度等于类别的总数,且不应在其后接Softmax激活函数。例如,如果你的模型有7个类别,最终输出应为形状(batch_size, 7)的Logits张量。
  2. 标签格式: 标签(target)必须是与模型输出Logits形状相同的浮点型(torch.float)张量。它通常是一个“多热编码”(multi-hot encoding)向量,其中1表示该类别存在,0表示该类别不存在。例如,[0, 1, 1, 0, 0, 1, 0]表示第二个、第三个和第六个类别存在。

代码示例:替换损失函数

假设我们有一个ViT模型,其输出为pred(Logits),标签为labels(多热编码)。

import torch
import torch.nn as nn

# 假设模型输出的Logits,形状为 (batch_size, num_classes)
# 这里以 batch_size = 2, num_classes = 7 为例
logits = torch.randn(2, 7) # 模拟模型输出的原始Logits

# 假设对应的多标签,形状也为 (batch_size, num_classes)
# 注意:标签必须是浮点型 (torch.float)
labels = torch.tensor([
    [0, 1, 1, 0, 0, 1, 0], # 第一个样本的标签
    [1, 0, 1, 1, 0, 0, 0]  # 第二个样本的标签
]).float()

# 实例化 BCEWithLogitsLoss
loss_function = nn.BCEWithLogitsLoss()

# 计算损失
loss = loss_function(logits, labels)

print(f"Logits:\n{logits}")
print(f"Labels:\n{labels}")
print(f"Calculated Loss: {loss.item()}")

# 原始训练循环中的应用
# pred = model(images.to(device))
# loss = loss_function(pred, labels.to(device))
# loss.backward()
# optimizer.step()

多标签分类的评估策略

在单标签分类中,准确率(Accuracy)是最常用的评估指标。然而,在多标签分类中,仅仅计算准确率是不足够的,甚至可能产生误导。例如,如果一个模型总是预测所有类别都不存在,而实际只有少数类别存在,那么它的准确率可能很高(因为它正确预测了大量不存在的类别),但它对存在类别的识别能力却很差。

因此,我们需要采用更全面的指标来评估多标签分类模型的性能。

1. 从Logits到预测结果

在计算评估指标之前,我们需要将模型的Logits输出转换为具体的类别预测。这通常通过对Logits应用Sigmoid函数,然后设定一个阈值(例如0.5)来完成。

# 假设 logits 是模型输出的Logits
# 例如:logits = torch.randn(batch_size, num_classes)

# 1. 应用Sigmoid函数将Logits转换为概率
probabilities = torch.sigmoid(logits)

# 2. 设定阈值,将概率转换为二元预测 (0或1)
threshold = 0.5
predictions = (probabilities > threshold).float()

print(f"Probabilities:\n{probabilities}")
print(f"Predictions (threshold={threshold}):\n{predictions}")

2. 常用评估指标

以下是多标签分类中常用的评估指标:

  • 精确率(Precision)、召回率(Recall)、F1分数(F1-score):

    • 精确率: 预测为正例的样本中,有多少是真正的正例。
    • 召回率: 实际为正例的样本中,有多少被模型预测为正例。
    • F1分数: 精确率和召回率的调和平均值,综合衡量模型的性能。
    • 这些指标可以针对每个类别独立计算(Per-class),也可以通过微平均(Micro-average)或宏平均(Macro-average)来汇总所有类别的结果。
      • Micro-average: 汇总所有类别的TP、FP、FN后再计算总体的Precision、Recall、F1。它更侧重于样本级别的性能,受样本数量较多的类别影响较大。
      • Macro-average: 先计算每个类别的Precision、Recall、F1,然后取这些值的平均。它给予每个类别相同的权重,不受类别样本数量不平衡的影响。
  • 平均精确率(Average Precision, AP)与平均精确率均值(mean Average Precision, mAP):

    • AP: 衡量单个类别在不同召回率下的精确率表现,通常通过计算PR曲线下面积获得。AP值越高,说明模型在该类别上的性能越好。
    • mAP: 对所有类别的AP值取平均,是衡量多标签分类模型整体性能的一个非常重要的指标,尤其在目标检测等领域广泛使用。
  • Jaccard Index (IoU) / Jaccard Similarity Score:

    • 衡量预测集合与真实标签集合的相似度,计算公式为交集大小除以并集大小。对于多标签分类,可以计算每个样本的预测标签集合与真实标签集合的Jaccard相似度,然后取平均。
  • Hamming Loss:

    • 衡量预测结果与真实标签不一致的标签比例。Hamming Loss越低越好。

3. 使用torchmetrics或scikit-learn进行评估

在PyTorch生态中,torchmetrics库提供了丰富的多标签评估指标。scikit-learn也是一个非常强大的工具,可以在CPU上方便地进行评估。

torchmetrics示例 (推荐用于PyTorch训练循环中):

import torch
from torchmetrics.classification import MultilabelF1Score, MultilabelAveragePrecision

# 假设真实标签和预测概率
# num_classes = 7
num_labels = 7
num_samples = 10
target_labels = torch.randint(0, 2, (num_samples, num_labels)).float() # 真实标签 (0或1)
predicted_probs = torch.rand(num_samples, num_labels) # 模型输出的概率 (经过Sigmoid)

# 或者直接使用Logits,让metrics内部处理Sigmoid
predicted_logits = torch.randn(num_samples, num_labels)


# 实例化F1分数,可以指定 average 方式 (e.g., 'micro', 'macro', 'weighted', 'none')
# MultilabelF1Score 期望输入是 (preds, target)
# preds: 概率 (float) 或 原始logits (float)
# target: 真实标签 (int 或 float, 0/1)
f1_score_micro = MultilabelF1Score(num_labels=num_labels, average='micro', validate_args=False)
f1_score_macro = MultilabelF1Score(num_labels=num_labels, average='macro', validate_args=False)

# 计算F1分数
# 注意:MultilabelF1Score 可以直接接收概率或logits,但通常建议给概率
f1_micro_val = f1_score_micro(predicted_probs, target_labels.long()) # target_labels需要是long类型对于F1Score
f1_macro_val = f1_score_macro(predicted_probs, target_labels.long())


print(f"Micro F1 Score: {f1_micro_val.item()}")
print(f"Macro F1 Score: {f1_macro_val.item()}")

# 实例化mAP
# MultilabelAveragePrecision 期望输入是 (preds, target)
# preds: 概率 (float)
# target: 真实标签 (int 或 float, 0/1)
map_metric = MultilabelAveragePrecision(num_labels=num_labels, validate_args=False)

# 计算mAP
map_val = map_metric(predicted_probs, target_labels.long()) # target_labels需要是long类型对于mAP

print(f"mAP: {map_val.item()}")

# 如果输入是logits,可以这样处理 (MultilabelF1Score 和 MultilabelAveragePrecision 默认不带sigmoid,需要手动处理或确保其内部处理了)
# 对于MultilabelF1Score和MultilabelAveragePrecision,当输入是概率时,通常需要手动将target转换为long
# 如果输入是logits,则需要确保metrics内部会执行sigmoid
# 更好的做法是,统一将模型输出转换为概率再传入metrics
probs_from_logits = torch.sigmoid(predicted_logits)
f1_micro_val_logits = f1_score_micro(probs_from_logits, target_labels.long())
map_val_logits = map_metric(probs_from_logits, target_labels.long())
print(f"Micro F1 Score (from logits): {f1_micro_val_logits.item()}")
print(f"mAP (from logits): {map_val_logits.item()}")

总结与注意事项

将ViT从单标签多分类转换为多标签分类,关键在于以下几点:

  1. 模型输出层: 确保模型的最终全连接层输出与类别数量相等的Logits,并且不带Softmax激活。
  2. 损失函数: 使用torch.nn.BCEWithLogitsLoss作为损失函数,它能独立处理每个类别的预测。
  3. 标签格式: 真实标签应为多热编码的浮点型张量,形状与模型输出的Logits相同。
  4. 评估指标: 采用适合多标签任务的评估指标,如Micro/Macro F1分数、mAP、Jaccard Index等,并结合torchmetrics或scikit-learn等库进行高效计算。
  5. 阈值选择: 在将概率转换为二元预测时,阈值的选择对最终的精确率和召回率有显著影响,可能需要通过验证集进行调优。
  6. 类别不平衡: 在多标签任务中,类别不平衡问题可能更复杂(例如,某些标签总是同时出现,某些标签非常稀有)。可以考虑使用加权BCE损失、Focal Loss或采样策略来缓解。

通过以上调整,您的Vision Transformer模型将能够有效地处理多标签图像分类任务。

好了,本文到此结束,带大家了解了《VisionTransformer多标签分类详解》,希望本文对你有所帮助!关注golang学习网公众号,给大家分享更多文章知识!

PHP调试指南:Xdebug安装与断点设置PHP调试指南:Xdebug安装与断点设置
上一篇
PHP调试指南:Xdebug安装与断点设置
正规机票预订官网App推荐清单
下一篇
正规机票预订官网App推荐清单
查看更多
最新文章
查看更多
课程推荐
  • 前端进阶之JavaScript设计模式
    前端进阶之JavaScript设计模式
    设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
    543次学习
  • GO语言核心编程课程
    GO语言核心编程课程
    本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
    516次学习
  • 简单聊聊mysql8与网络通信
    简单聊聊mysql8与网络通信
    如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
    500次学习
  • JavaScript正则表达式基础与实战
    JavaScript正则表达式基础与实战
    在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
    487次学习
  • 从零制作响应式网站—Grid布局
    从零制作响应式网站—Grid布局
    本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
    485次学习
查看更多
AI推荐
  • ChatExcel酷表:告别Excel难题,北大团队AI助手助您轻松处理数据
    ChatExcel酷表
    ChatExcel酷表是由北京大学团队打造的Excel聊天机器人,用自然语言操控表格,简化数据处理,告别繁琐操作,提升工作效率!适用于学生、上班族及政府人员。
    3178次使用
  • Any绘本:开源免费AI绘本创作工具深度解析
    Any绘本
    探索Any绘本(anypicturebook.com/zh),一款开源免费的AI绘本创作工具,基于Google Gemini与Flux AI模型,让您轻松创作个性化绘本。适用于家庭、教育、创作等多种场景,零门槛,高自由度,技术透明,本地可控。
    3390次使用
  • 可赞AI:AI驱动办公可视化智能工具,一键高效生成文档图表脑图
    可赞AI
    可赞AI,AI驱动的办公可视化智能工具,助您轻松实现文本与可视化元素高效转化。无论是智能文档生成、多格式文本解析,还是一键生成专业图表、脑图、知识卡片,可赞AI都能让信息处理更清晰高效。覆盖数据汇报、会议纪要、内容营销等全场景,大幅提升办公效率,降低专业门槛,是您提升工作效率的得力助手。
    3418次使用
  • 星月写作:AI网文创作神器,助力爆款小说速成
    星月写作
    星月写作是国内首款聚焦中文网络小说创作的AI辅助工具,解决网文作者从构思到变现的全流程痛点。AI扫榜、专属模板、全链路适配,助力新人快速上手,资深作者效率倍增。
    4523次使用
  • MagicLight.ai:叙事驱动AI动画视频创作平台 | 高效生成专业级故事动画
    MagicLight
    MagicLight.ai是全球首款叙事驱动型AI动画视频创作平台,专注于解决从故事想法到完整动画的全流程痛点。它通过自研AI模型,保障角色、风格、场景高度一致性,让零动画经验者也能高效产出专业级叙事内容。广泛适用于独立创作者、动画工作室、教育机构及企业营销,助您轻松实现创意落地与商业化。
    3797次使用
微信登录更方便
  • 密码登录
  • 注册账号
登录即同意 用户协议隐私政策
返回登录
  • 重置密码