Pytorch创建多任务学习模型
今天golang学习网给大家带来了《Pytorch创建多任务学习模型》,其中涉及到的知识点包括等等,无论你是小白还是老手,都适合看一看哦~有好的建议也欢迎大家在评论留言,若是看完有所收获,也希望大家能多多点赞支持呀!一起加油学习~
MTL最著名的例子可能是特斯拉的自动驾驶系统。在自动驾驶中需要同时处理大量任务,如物体检测、深度估计、3D重建、视频分析、跟踪等,你可能认为需要10个以上的深度学习模型,但事实并非如此。
HydraNet介绍
一般来说多任务学的模型架构非常简单:一个骨干网络作为特征的提取,然后针对不同的任务创建多个头。利用单一模型解决多个任务。
上图可以看到,特征提取模型提取图像特征。输出最后被分割成多个头,每个头负责一个特定的情况,由于它们彼此独立可以单独进行微调!
特斯拉的讲演中详细的说明这个模型(youtube:v=3SypMvnQT_s)
多任务学习项目
在本文中,我们将介绍如何在Pytorch中实现一个更简单的HydraNet。这里将使用UTK Face数据集,这是一个带有3个标签(性别、种族、年龄)的分类数据集。
我们的HydraNet将有三个独立的头,它们都是不同的,因为年龄的预测是一个回归任务,种族的预测是一个多类分类问题,性别的预测是一个二元分类任务。
每一个Pytorch 的深度学习的项目都应该从定义Dataset和DataLoader开始。
在这个数据集中,通过图像的名称定义了这些标签,例如UTKFace/30_0_3_20170117145159065.jpg.chip.jpg
- 30岁是年龄
- 0为性别(0:男性,1:女性)
- 3是种族(0:白人,1:黑人,2:亚洲人,3:印度人,4:其他)
所以我们的自定义Dataset可以这样写:
class UTKFace(Dataset):
def __init__(self, image_paths):
self.transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
self.image_paths = image_paths
self.images = []
self.ages = []
self.genders = []
self.races = []
for path in image_paths:
filename = path[8:].split("_")
if len(filename)==4:
self.images.append(path)
self.ages.append(int(filename[0]))
self.genders.append(int(filename[1]))
self.races.append(int(filename[2]))
def __len__(self):
return len(self.images)
def __getitem__(self, index):
img = Image.open(self.images[index]).convert('RGB')
img = self.transform(img)
age = self.ages[index]
gender = self.genders[index]
eth = self.races[index]
sample = {'image':img, 'age': age, 'gender': gender, 'ethnicity':eth}
return sample
简单的做个介绍:
__init__方法初始化我们的自定义数据集,负责初始化各种转换和从图像路径中提取标签。
__get_item__将:它将加载一张图像,应用必要的转换,获取标签,并返回数据集的一个元素,也就是说这个方法会返回数据集中的单条数据(单个样本)
然后我们定义dataloader
train_dataloader = DataLoader(UTKFace(train_dataset), shuffle=True, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(UTKFace(valid_dataset), shuffle=False, batch_size=BATCH_SIZE)
下面我们定义模型,这里使用一个预训练的模型作为骨干,然后创建3个头。分别代表年龄,性别和种族。
class HydraNet(nn.Module):
def __init__(self):
super().__init__()
self.net = models.resnet18(pretrained=True)
self.n_features = self.net.fc.in_features
self.net.fc = nn.Identity()
self.net.fc1 = nn.Sequential(OrderedDict(
[('linear', nn.Linear(self.n_features,self.n_features)),
('relu1', nn.ReLU()),
('final', nn.Linear(self.n_features, 1))]))
self.net.fc2 = nn.Sequential(OrderedDict(
[('linear', nn.Linear(self.n_features,self.n_features)),
('relu1', nn.ReLU()),
('final', nn.Linear(self.n_features, 1))]))
self.net.fc3 = nn.Sequential(OrderedDict(
[('linear', nn.Linear(self.n_features,self.n_features)),
('relu1', nn.ReLU()),
('final', nn.Linear(self.n_features, 5))]))
def forward(self, x):
age_head = self.net.fc1(self.net(x))
gender_head = self.net.fc2(self.net(x))
ethnicity_head = self.net.fc3(self.net(x))
return age_head, gender_head, ethnicity_head
forward方法返回每个头的结果。
损失作为优化的基础时十分重要的,因为它将会影响到模型的性能,我们能想到的最简单的事就是地把损失相加:
L = L1 + L2 + L3
但是我们的模型中
- L1:与年龄相关的损失,如平均绝对误差,因为它是回归损失。
- L2:与种族相关的交叉熵,它是一个多类别的分类损失。
- L3:性别有关的损失,例如二元交叉熵。
这里损失的计算最大问题是损失的量级是不一样的,并且损失的权重也是不相同的,这是一个一直在被深入研究的问题,我们这里暂不做讨论,我们只使用简单的相加,所以我们的一些超参数如下:
model = HydraNet().to(device=device)
ethnicity_loss = nn.CrossEntropyLoss()
gender_loss = nn.BCELoss()
age_loss = nn.L1Loss()
sig = nn.Sigmoid()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.09)
然后我们训练的循环如下:
for epoch in range(n_epochs):
model.train()
total_training_loss = 0
for i, data in enumerate(tqdm(train_dataloader)):
inputs = data["image"].to(device=device)
age_label = data["age"].to(device=device)
gender_label = data["gender"].to(device=device)
eth_label = data["ethnicity"].to(device=device)
optimizer.zero_grad()
age_output, gender_output, eth_output = model(inputs)
loss_1 = ethnicity_loss(eth_output, eth_label)
loss_2 = gender_loss(sig(gender_output), gender_label.unsqueeze(1).float())
loss_3 = age_loss(age_output, age_label.unsqueeze(1).float())
loss = loss_1 + loss_2 + loss_3
loss.backward()
optimizer.step()
total_training_loss += loss
这样我们最简单的多任务学习的流程就完成了
关于损失的优化
多任务学习的损失函数,对每个任务的损失进行权重分配,在这个过程中,必须保证所有任务同等重要,而不能让简单任务主导整个训练过程。手动的设置权重是低效而且不是最优的,因此,自动的学习这些权重是十分必要的,
Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics cvpr_2018
这篇论文提出,将不同的loss拉到统一尺度下,这样就容易统一,具体的办法就是利用同方差的不确定性,将不确定性作为噪声,进行训练。
End-to-End Multi-Task Learning with Attention cvpr_2019
这篇论文提出了一种可以自动调节权重的机制( Dynamic Weight Average),使得权重分配更加合理,大概的意思是每个任务首先计算前个epoch对应损失的比值,然后除以一个固定的值T,进行exp映射后,计算各个损失所占比
最后如果你对多任务学习感兴趣,可以先看看这篇论文:
A Survey on Multi-Task LearningarXiv 1707.08114
从算法建模、应用和理论分析的角度对MTL进行了调查,是入门的最好的资料。
文中关于机器学习,语音识别,PyTorch的知识介绍,希望对你的学习有所帮助!若是受益匪浅,那就动动鼠标收藏这篇《Pytorch创建多任务学习模型》文章吧,也可关注golang学习网公众号了解相关技术文章。

- 上一篇
- 虚拟现实将如何改变建筑业

- 下一篇
- 机器学习创造新的攻击面,需要专门的防御
-
- 科技周边 · 人工智能 | 9分钟前 |
- 惊爆!尊界S800起售价或降至80万
- 183浏览 收藏
-
- 科技周边 · 人工智能 | 1小时前 |
- X9电池满足新国标,小鹏高管确认超出
- 379浏览 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 542次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 508次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 497次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 484次学习
-
- 笔灵AI生成答辩PPT
- 探索笔灵AI生成答辩PPT的强大功能,快速制作高质量答辩PPT。精准内容提取、多样模板匹配、数据可视化、配套自述稿生成,让您的学术和职场展示更加专业与高效。
- 16次使用
-
- 知网AIGC检测服务系统
- 知网AIGC检测服务系统,专注于检测学术文本中的疑似AI生成内容。依托知网海量高质量文献资源,结合先进的“知识增强AIGC检测技术”,系统能够从语言模式和语义逻辑两方面精准识别AI生成内容,适用于学术研究、教育和企业领域,确保文本的真实性和原创性。
- 24次使用
-
- AIGC检测-Aibiye
- AIbiye官网推出的AIGC检测服务,专注于检测ChatGPT、Gemini、Claude等AIGC工具生成的文本,帮助用户确保论文的原创性和学术规范。支持txt和doc(x)格式,检测范围为论文正文,提供高准确性和便捷的用户体验。
- 30次使用
-
- 易笔AI论文
- 易笔AI论文平台提供自动写作、格式校对、查重检测等功能,支持多种学术领域的论文生成。价格优惠,界面友好,操作简便,适用于学术研究者、学生及论文辅导机构。
- 42次使用
-
- 笔启AI论文写作平台
- 笔启AI论文写作平台提供多类型论文生成服务,支持多语言写作,满足学术研究者、学生和职场人士的需求。平台采用AI 4.0版本,确保论文质量和原创性,并提供查重保障和隐私保护。
- 35次使用
-
- GPT-4王者加冕!读图做题性能炸天,凭自己就能考上斯坦福
- 2023-04-25 501浏览
-
- 单块V100训练模型提速72倍!尤洋团队新成果获AAAI 2023杰出论文奖
- 2023-04-24 501浏览
-
- ChatGPT 真的会接管世界吗?
- 2023-04-13 501浏览
-
- VR的终极形态是「假眼」?Neuralink前联合创始人掏出新产品:科学之眼!
- 2023-04-30 501浏览
-
- 实现实时制造可视性优势有哪些?
- 2023-04-15 501浏览