PyTorchCrossEntropyLoss错误解决方法
来到golang学习网的大家,相信都是编程学习爱好者,希望在这里学习文章相关编程知识。下面本篇文章就来带大家聊聊《PyTorch CrossEntropyLoss类型错误解决方法》,介绍一下,希望对大家的知识积累有所帮助,助力实战开发!

本文深入探讨了PyTorch中`CrossEntropyLoss`常见的`RuntimeError: expected scalar type Long but found Float`错误。该错误通常源于目标标签(target)的数据类型不符合`CrossEntropyLoss`的预期。我们将详细解析错误原因,并提供如何在训练循环中正确使用`CrossEntropyLoss`,包括标签类型转换、输入顺序以及避免重复应用Softmax等关键最佳实践,以确保模型训练的稳定性和准确性。
在深度学习的分类任务中,torch.nn.CrossEntropyLoss是一个非常常用的损失函数。它结合了LogSoftmax和负对数似然损失(NLLLoss),能够高效地处理多分类问题。然而,初学者在使用时常会遇到一个特定的运行时错误:RuntimeError: expected scalar type Long but found Float。这个错误明确指出,CrossEntropyLoss在处理其目标标签(target)时,期望的数据类型是torch.Long(即64位整数),但实际接收到的是torch.Float。
理解CrossEntropyLoss的工作原理
CrossEntropyLoss函数在PyTorch中通常接收两个主要参数:
- input (或 logits):这是模型的原始输出,通常是未经Softmax激活函数处理的“对数几率”(logits)。它的形状通常是 (N, C),其中 N 是批量大小,C 是类别数量。对于图像任务,如果模型输出是像素级别的分类(如U-Net),则形状可能是 (N, C, H, W)。
- target (或 labels):这是真实的类别标签。它应该包含每个样本的类别索引,其数据类型必须是torch.long(或torch.int64)。它的形状通常是 (N),对于像素级别的分类,形状可能是 (N, H, W)。target中的值应介于 0 到 C-1 之间,代表对应的类别索引。
关键点: CrossEntropyLoss内部会自行执行Softmax操作,因此,向其传入经过Softmax处理的概率值是不正确的,这可能导致数值不稳定或不准确的损失计算。
RuntimeError: expected scalar type Long but found Float 错误解析与修正
这个错误的核心在于target张量的数据类型不匹配。在提供的代码片段中,错误发生在以下这行:
loss = criterion(output, labels.float())
尽管labels张量在创建时已经被明确指定为long类型:
labels = Variable(torch.FloatTensor(10).uniform_(0, 120).long())
但在计算损失时,又通过.float()方法将其强制转换回了float类型。这就是导致CrossEntropyLoss抛出错误的原因。
修正方法: 只需移除对labels的.float()调用,确保target张量保持其long类型即可。
# 错误代码 # loss = criterion(output, labels.float()) # 正确代码 loss = criterion(output, labels)
训练循环中的常见误用及修正
除了上述直接的类型转换错误,在提供的train_one_epoch函数中,也存在一些与CrossEntropyLoss使用相关的常见误区。
1. 标签数据类型转换错误
在train_one_epoch函数内部,标签被错误地转换成了float类型:
labels = labels.to(device).float() # 错误:将标签转换为float类型
这会直接导致CrossEntropyLoss接收到float类型的标签,再次触发同样的RuntimeError。
修正方法: 确保标签在送入损失函数前是long类型。
labels = labels.to(device).long() # 正确:将标签转换为long类型
2. CrossEntropyLoss输入参数顺序和类型错误
在train_one_epoch函数中,计算损失的行是:
loss = criterion(labels, torch.argmax(outputs, dim=1)) # 错误:参数顺序和类型不符
这里存在两个问题:
- 参数顺序错误: criterion(即CrossEntropyLoss)期望的第一个参数是模型的输出(logits),第二个参数是真实标签(target)。这里却反了过来。
- target参数类型错误: torch.argmax(outputs, dim=1) 已经是一个预测结果的类别索引,它不应该作为CrossEntropyLoss的target参数传入。target参数应是真实的、未经模型处理的类别标签。
修正方法: 将模型的原始输出(logits)作为第一个参数,真实的long类型标签作为第二个参数。
3. 预先应用Softmax的错误
在计算outputs时,代码中显式地应用了F.softmax:
outputs = F.softmax(model(inputs.float()), dim=1) # 错误:CrossEntropyLoss内部已包含Softmax
由于CrossEntropyLoss内部已经包含了Softmax操作,再次应用F.softmax会导致:
- 冗余计算: 增加了不必要的计算开销。
- 数值稳定性问题: 两次Softmax操作可能导致数值精度下降,尤其是在处理非常大或非常小的对数几率时。
修正方法: 直接将模型的原始输出(logits)传递给CrossEntropyLoss。
优化后的训练函数示例
综合以上修正,以下是train_one_epoch函数的一个优化版本,遵循了CrossEntropyLoss的最佳实践:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
# 假设 model, optimizer, dataloaders, device 已经定义
def train_one_epoch(model, optimizer, data_loader, device):
model.train()
running_loss = 0.0
start_time = time.time()
total = 0
correct = 0
# 确保 data_loader 是实际的 DataLoader 对象
# 这里假设 dataloaders['train'] 是一个可迭代的 DataLoader
current_data_loader = data_loader # 如果传入的是字符串'train',需要根据实际情况获取
if isinstance(data_loader, str):
current_data_loader = dataloaders[data_loader] # 假设 dataloaders 是一个全局字典
for i, (inputs, labels) in enumerate(current_data_loader):
inputs = inputs.to(device)
# 核心修正:确保标签是long类型
labels = labels.to(device).long()
optimizer.zero_grad()
# 修正:直接使用模型的原始输出(logits),不应用Softmax
# 假设 model(inputs.float()) 返回的是 logits
logits = model(inputs.float())
# 打印形状以调试
# print("Inputs shape:", inputs.shape)
# print("Logits shape:", logits.shape)
# print("Labels shape:", labels.shape)
# 修正:CrossEntropyLoss的正确使用方式是 (logits, target_indices)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
# 计算准确率时,需要对logits应用argmax
_, predicted = torch.max(logits.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
running_loss += loss.item()
if i % 10 == 0: # print every 10 batches
batch_time = time.time()
speed = (i+1)/(batch_time-start_time)
print('[%5d] loss: %.3f, speed: %.2f, accuracy: %.2f %%' %
(i, running_loss, speed, accuracy))
running_loss = 0.0
total = 0
correct = 0验证模型函数 (val_model) 的注意事项
val_model函数在处理标签时使用了labels = labels.to(device).long(),这是正确的。同时,outputs = model(inputs.float()) 假设模型输出的是logits,然后用 torch.max(outputs.data, 1) 来获取预测类别,这也是标准做法。
唯一需要注意的是,model.val() 应该更正为 model.eval(),这会将模型设置为评估模式,禁用Dropout和BatchNorm等层,以确保评估结果的稳定性。
def val_model(model, data_loader, device): # 添加 device 参数
model.eval() # 修正:使用 model.eval()
start_time = time.time()
total = 0
correct = 0
current_data_loader = data_loader
if isinstance(data_loader, str):
current_data_loader = dataloaders[data_loader]
with torch.no_grad():
for i, (inputs, labels) in enumerate(current_data_loader):
inputs = inputs.to(device)
labels = labels.to(device).long() # 正确
outputs = model(inputs.float()) # 假设 model 输出 logits
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
# 修正:(predicted == labels).sum() 返回一个标量,直接 .item() 即可
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print('Finished Testing')
print('Testing accuracy: %.1f %%' %(accuracy))总结与最佳实践
处理PyTorch中的CrossEntropyLoss时,请牢记以下关键点:
- 目标标签的数据类型: CrossEntropyLoss的target参数必须是torch.long类型(即64位整数),且包含类别索引(从0到C-1)。
- 模型输出: CrossEntropyLoss的input参数应是模型的原始输出(logits),即未经Softmax激活函数处理的对数几率。
- 避免重复Softmax: 不要在将模型输出传递给CrossEntropyLoss之前手动应用F.softmax,因为CrossEntropyLoss内部已经包含了此操作。
- 参数顺序: CrossEntropyLoss的调用格式是 loss = criterion(logits, target_labels)。
- 评估模式: 在验证或测试模型时,务必使用model.eval()来设置模型为评估模式,并在torch.no_grad()上下文管理器中执行前向传播,以节省内存和计算。
遵循这些原则,可以有效避免RuntimeError: expected scalar type Long but found Float以及其他与CrossEntropyLoss使用相关的常见问题,确保模型训练的顺利进行。
今天关于《PyTorchCrossEntropyLoss错误解决方法》的内容就介绍到这里了,是不是学起来一目了然!想要了解更多关于的内容请关注golang学习网公众号!
Golang策略模式实现可插拔算法详解
- 上一篇
- Golang策略模式实现可插拔算法详解
- 下一篇
- Win11连接AirPods音质差怎么解决
-
- 文章 · python教程 | 7分钟前 |
- Pandas修改首行数据技巧分享
- 485浏览 收藏
-
- 文章 · python教程 | 1小时前 |
- Python列表创建技巧全解析
- 283浏览 收藏
-
- 文章 · python教程 | 2小时前 |
- Python计算文件实际占用空间技巧
- 349浏览 收藏
-
- 文章 · python教程 | 3小时前 |
- OpenCV中OCR技术应用详解
- 204浏览 收藏
-
- 文章 · python教程 | 4小时前 |
- Pandas读取Django表格:协议关键作用
- 401浏览 收藏
-
- 文章 · python教程 | 4小时前 | 身份验证 断点续传 requests库 PythonAPI下载 urllib库
- Python调用API下载文件方法
- 227浏览 收藏
-
- 文章 · python教程 | 4小时前 |
- Windows7安装RtMidi失败解决办法
- 400浏览 收藏
-
- 文章 · python教程 | 4小时前 |
- Python异步任务优化技巧分享
- 327浏览 收藏
-
- 文章 · python教程 | 4小时前 |
- PyCharm图形界面显示问题解决方法
- 124浏览 收藏
-
- 文章 · python教程 | 5小时前 |
- Python自定义异常类怎么创建
- 450浏览 收藏
-
- 文章 · python教程 | 6小时前 |
- Python抓取赛狗数据:指定日期赛道API教程
- 347浏览 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 543次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 516次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 500次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 485次学习
-
- ChatExcel酷表
- ChatExcel酷表是由北京大学团队打造的Excel聊天机器人,用自然语言操控表格,简化数据处理,告别繁琐操作,提升工作效率!适用于学生、上班族及政府人员。
- 3179次使用
-
- Any绘本
- 探索Any绘本(anypicturebook.com/zh),一款开源免费的AI绘本创作工具,基于Google Gemini与Flux AI模型,让您轻松创作个性化绘本。适用于家庭、教育、创作等多种场景,零门槛,高自由度,技术透明,本地可控。
- 3390次使用
-
- 可赞AI
- 可赞AI,AI驱动的办公可视化智能工具,助您轻松实现文本与可视化元素高效转化。无论是智能文档生成、多格式文本解析,还是一键生成专业图表、脑图、知识卡片,可赞AI都能让信息处理更清晰高效。覆盖数据汇报、会议纪要、内容营销等全场景,大幅提升办公效率,降低专业门槛,是您提升工作效率的得力助手。
- 3419次使用
-
- 星月写作
- 星月写作是国内首款聚焦中文网络小说创作的AI辅助工具,解决网文作者从构思到变现的全流程痛点。AI扫榜、专属模板、全链路适配,助力新人快速上手,资深作者效率倍增。
- 4525次使用
-
- MagicLight
- MagicLight.ai是全球首款叙事驱动型AI动画视频创作平台,专注于解决从故事想法到完整动画的全流程痛点。它通过自研AI模型,保障角色、风格、场景高度一致性,让零动画经验者也能高效产出专业级叙事内容。广泛适用于独立创作者、动画工作室、教育机构及企业营销,助您轻松实现创意落地与商业化。
- 3798次使用
-
- Flask框架安装技巧:让你的开发更高效
- 2024-01-03 501浏览
-
- Django框架中的并发处理技巧
- 2024-01-22 501浏览
-
- 提升Python包下载速度的方法——正确配置pip的国内源
- 2024-01-17 501浏览
-
- Python与C++:哪个编程语言更适合初学者?
- 2024-03-25 501浏览
-
- 品牌建设技巧
- 2024-04-06 501浏览

