当前位置:首页 > 文章列表 > 文章 > python教程 > PyTorchCrossEntropyLoss错误解决方法

PyTorchCrossEntropyLoss错误解决方法

2025-11-20 14:07:48 0浏览 收藏

来到golang学习网的大家,相信都是编程学习爱好者,希望在这里学习文章相关编程知识。下面本篇文章就来带大家聊聊《PyTorch CrossEntropyLoss类型错误解决方法》,介绍一下,希望对大家的知识积累有所帮助,助力实战开发!

PyTorch CrossEntropyLoss中的数据类型错误解析与最佳实践

本文深入探讨了PyTorch中`CrossEntropyLoss`常见的`RuntimeError: expected scalar type Long but found Float`错误。该错误通常源于目标标签(target)的数据类型不符合`CrossEntropyLoss`的预期。我们将详细解析错误原因,并提供如何在训练循环中正确使用`CrossEntropyLoss`,包括标签类型转换、输入顺序以及避免重复应用Softmax等关键最佳实践,以确保模型训练的稳定性和准确性。

在深度学习的分类任务中,torch.nn.CrossEntropyLoss是一个非常常用的损失函数。它结合了LogSoftmax和负对数似然损失(NLLLoss),能够高效地处理多分类问题。然而,初学者在使用时常会遇到一个特定的运行时错误:RuntimeError: expected scalar type Long but found Float。这个错误明确指出,CrossEntropyLoss在处理其目标标签(target)时,期望的数据类型是torch.Long(即64位整数),但实际接收到的是torch.Float。

理解CrossEntropyLoss的工作原理

CrossEntropyLoss函数在PyTorch中通常接收两个主要参数:

  1. input (或 logits):这是模型的原始输出,通常是未经Softmax激活函数处理的“对数几率”(logits)。它的形状通常是 (N, C),其中 N 是批量大小,C 是类别数量。对于图像任务,如果模型输出是像素级别的分类(如U-Net),则形状可能是 (N, C, H, W)。
  2. target (或 labels):这是真实的类别标签。它应该包含每个样本的类别索引,其数据类型必须是torch.long(或torch.int64)。它的形状通常是 (N),对于像素级别的分类,形状可能是 (N, H, W)。target中的值应介于 0 到 C-1 之间,代表对应的类别索引。

关键点: CrossEntropyLoss内部会自行执行Softmax操作,因此,向其传入经过Softmax处理的概率值是不正确的,这可能导致数值不稳定或不准确的损失计算。

RuntimeError: expected scalar type Long but found Float 错误解析与修正

这个错误的核心在于target张量的数据类型不匹配。在提供的代码片段中,错误发生在以下这行:

loss = criterion(output, labels.float())

尽管labels张量在创建时已经被明确指定为long类型:

labels = Variable(torch.FloatTensor(10).uniform_(0, 120).long())

但在计算损失时,又通过.float()方法将其强制转换回了float类型。这就是导致CrossEntropyLoss抛出错误的原因。

修正方法: 只需移除对labels的.float()调用,确保target张量保持其long类型即可。

# 错误代码
# loss = criterion(output, labels.float())

# 正确代码
loss = criterion(output, labels)

训练循环中的常见误用及修正

除了上述直接的类型转换错误,在提供的train_one_epoch函数中,也存在一些与CrossEntropyLoss使用相关的常见误区。

1. 标签数据类型转换错误

在train_one_epoch函数内部,标签被错误地转换成了float类型:

labels = labels.to(device).float() # 错误:将标签转换为float类型

这会直接导致CrossEntropyLoss接收到float类型的标签,再次触发同样的RuntimeError。

修正方法: 确保标签在送入损失函数前是long类型。

labels = labels.to(device).long() # 正确:将标签转换为long类型

2. CrossEntropyLoss输入参数顺序和类型错误

在train_one_epoch函数中,计算损失的行是:

loss = criterion(labels, torch.argmax(outputs, dim=1)) # 错误:参数顺序和类型不符

这里存在两个问题:

  • 参数顺序错误: criterion(即CrossEntropyLoss)期望的第一个参数是模型的输出(logits),第二个参数是真实标签(target)。这里却反了过来。
  • target参数类型错误: torch.argmax(outputs, dim=1) 已经是一个预测结果的类别索引,它不应该作为CrossEntropyLoss的target参数传入。target参数应是真实的、未经模型处理的类别标签。

修正方法: 将模型的原始输出(logits)作为第一个参数,真实的long类型标签作为第二个参数。

3. 预先应用Softmax的错误

在计算outputs时,代码中显式地应用了F.softmax:

outputs = F.softmax(model(inputs.float()), dim=1) # 错误:CrossEntropyLoss内部已包含Softmax

由于CrossEntropyLoss内部已经包含了Softmax操作,再次应用F.softmax会导致:

  • 冗余计算: 增加了不必要的计算开销。
  • 数值稳定性问题: 两次Softmax操作可能导致数值精度下降,尤其是在处理非常大或非常小的对数几率时。

修正方法: 直接将模型的原始输出(logits)传递给CrossEntropyLoss。

优化后的训练函数示例

综合以上修正,以下是train_one_epoch函数的一个优化版本,遵循了CrossEntropyLoss的最佳实践:

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

# 假设 model, optimizer, dataloaders, device 已经定义

def train_one_epoch(model, optimizer, data_loader, device):
    model.train()
    running_loss = 0.0
    start_time = time.time()
    total = 0
    correct = 0

    # 确保 data_loader 是实际的 DataLoader 对象
    # 这里假设 dataloaders['train'] 是一个可迭代的 DataLoader
    current_data_loader = data_loader # 如果传入的是字符串'train',需要根据实际情况获取
    if isinstance(data_loader, str):
        current_data_loader = dataloaders[data_loader] # 假设 dataloaders 是一个全局字典

    for i, (inputs, labels) in enumerate(current_data_loader):
        inputs = inputs.to(device)
        # 核心修正:确保标签是long类型
        labels = labels.to(device).long() 

        optimizer.zero_grad()

        # 修正:直接使用模型的原始输出(logits),不应用Softmax
        # 假设 model(inputs.float()) 返回的是 logits
        logits = model(inputs.float()) 

        # 打印形状以调试
        # print("Inputs shape:", inputs.shape)
        # print("Logits shape:", logits.shape)
        # print("Labels shape:", labels.shape)

        # 修正:CrossEntropyLoss的正确使用方式是 (logits, target_indices)
        loss = criterion(logits, labels) 

        loss.backward()
        optimizer.step()

        # 计算准确率时,需要对logits应用argmax
        _, predicted = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total

        running_loss += loss.item()
        if i % 10 == 0:    # print every 10 batches
            batch_time = time.time()
            speed = (i+1)/(batch_time-start_time)
            print('[%5d] loss: %.3f, speed: %.2f, accuracy: %.2f %%' %
                  (i, running_loss, speed, accuracy))
            running_loss = 0.0
            total = 0
            correct = 0

验证模型函数 (val_model) 的注意事项

val_model函数在处理标签时使用了labels = labels.to(device).long(),这是正确的。同时,outputs = model(inputs.float()) 假设模型输出的是logits,然后用 torch.max(outputs.data, 1) 来获取预测类别,这也是标准做法。

唯一需要注意的是,model.val() 应该更正为 model.eval(),这会将模型设置为评估模式,禁用Dropout和BatchNorm等层,以确保评估结果的稳定性。

def val_model(model, data_loader, device): # 添加 device 参数
    model.eval() # 修正:使用 model.eval()
    start_time = time.time()
    total = 0
    correct = 0

    current_data_loader = data_loader
    if isinstance(data_loader, str):
        current_data_loader = dataloaders[data_loader]

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(current_data_loader):
            inputs = inputs.to(device)
            labels = labels.to(device).long() # 正确

            outputs = model(inputs.float()) # 假设 model 输出 logits

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            # 修正:(predicted == labels).sum() 返回一个标量,直接 .item() 即可
            correct += (predicted == labels).sum().item() 
        accuracy = 100 * correct / total

        print('Finished Testing')
        print('Testing accuracy: %.1f %%' %(accuracy))

总结与最佳实践

处理PyTorch中的CrossEntropyLoss时,请牢记以下关键点:

  1. 目标标签的数据类型: CrossEntropyLoss的target参数必须是torch.long类型(即64位整数),且包含类别索引(从0到C-1)。
  2. 模型输出: CrossEntropyLoss的input参数应是模型的原始输出(logits),即未经Softmax激活函数处理的对数几率。
  3. 避免重复Softmax: 不要在将模型输出传递给CrossEntropyLoss之前手动应用F.softmax,因为CrossEntropyLoss内部已经包含了此操作。
  4. 参数顺序: CrossEntropyLoss的调用格式是 loss = criterion(logits, target_labels)。
  5. 评估模式: 在验证或测试模型时,务必使用model.eval()来设置模型为评估模式,并在torch.no_grad()上下文管理器中执行前向传播,以节省内存和计算。

遵循这些原则,可以有效避免RuntimeError: expected scalar type Long but found Float以及其他与CrossEntropyLoss使用相关的常见问题,确保模型训练的顺利进行。

今天关于《PyTorchCrossEntropyLoss错误解决方法》的内容就介绍到这里了,是不是学起来一目了然!想要了解更多关于的内容请关注golang学习网公众号!

Golang策略模式实现可插拔算法详解Golang策略模式实现可插拔算法详解
上一篇
Golang策略模式实现可插拔算法详解
Win11连接AirPods音质差怎么解决
下一篇
Win11连接AirPods音质差怎么解决
查看更多
最新文章
查看更多
课程推荐
  • 前端进阶之JavaScript设计模式
    前端进阶之JavaScript设计模式
    设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
    543次学习
  • GO语言核心编程课程
    GO语言核心编程课程
    本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
    516次学习
  • 简单聊聊mysql8与网络通信
    简单聊聊mysql8与网络通信
    如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
    500次学习
  • JavaScript正则表达式基础与实战
    JavaScript正则表达式基础与实战
    在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
    487次学习
  • 从零制作响应式网站—Grid布局
    从零制作响应式网站—Grid布局
    本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
    485次学习
查看更多
AI推荐
  • ChatExcel酷表:告别Excel难题,北大团队AI助手助您轻松处理数据
    ChatExcel酷表
    ChatExcel酷表是由北京大学团队打造的Excel聊天机器人,用自然语言操控表格,简化数据处理,告别繁琐操作,提升工作效率!适用于学生、上班族及政府人员。
    3179次使用
  • Any绘本:开源免费AI绘本创作工具深度解析
    Any绘本
    探索Any绘本(anypicturebook.com/zh),一款开源免费的AI绘本创作工具,基于Google Gemini与Flux AI模型,让您轻松创作个性化绘本。适用于家庭、教育、创作等多种场景,零门槛,高自由度,技术透明,本地可控。
    3390次使用
  • 可赞AI:AI驱动办公可视化智能工具,一键高效生成文档图表脑图
    可赞AI
    可赞AI,AI驱动的办公可视化智能工具,助您轻松实现文本与可视化元素高效转化。无论是智能文档生成、多格式文本解析,还是一键生成专业图表、脑图、知识卡片,可赞AI都能让信息处理更清晰高效。覆盖数据汇报、会议纪要、内容营销等全场景,大幅提升办公效率,降低专业门槛,是您提升工作效率的得力助手。
    3419次使用
  • 星月写作:AI网文创作神器,助力爆款小说速成
    星月写作
    星月写作是国内首款聚焦中文网络小说创作的AI辅助工具,解决网文作者从构思到变现的全流程痛点。AI扫榜、专属模板、全链路适配,助力新人快速上手,资深作者效率倍增。
    4525次使用
  • MagicLight.ai:叙事驱动AI动画视频创作平台 | 高效生成专业级故事动画
    MagicLight
    MagicLight.ai是全球首款叙事驱动型AI动画视频创作平台,专注于解决从故事想法到完整动画的全流程痛点。它通过自研AI模型,保障角色、风格、场景高度一致性,让零动画经验者也能高效产出专业级叙事内容。广泛适用于独立创作者、动画工作室、教育机构及企业营销,助您轻松实现创意落地与商业化。
    3798次使用
微信登录更方便
  • 密码登录
  • 注册账号
登录即同意 用户协议隐私政策
返回登录
  • 重置密码