当前位置:首页 > 文章列表 > 科技周边 > 人工智能 > 持续学习常用六种方法总结:使ML模型适应新数据的同时保持旧数据的性能

持续学习常用六种方法总结:使ML模型适应新数据的同时保持旧数据的性能

来源:51CTO.COM 2023-04-29 18:27:25 0浏览 收藏

IT行业相对于一般传统行业,发展更新速度更快,一旦停止了学习,很快就会被行业所淘汰。所以我们需要踏踏实实的不断学习,精进自己的技术,尤其是初学者。今天golang学习网给大家整理了《持续学习常用六种方法总结:使ML模型适应新数据的同时保持旧数据的性能》,聊聊,我们一起来看看吧!

持续学习是指在不忘记从前面的任务中获得的知识的情况下,按顺序学习大量任务的模型。这是一个重要的概念,因为在监督学习的前提下,机器学习模型被训练为针对给定数据集或数据分布的最佳函数。而在现实环境中,数据很少是静态的,可能会发生变化。当面对不可见的数据时,典型的ML模型可能会性能下降。这种现象被称为灾难性遗忘。

持续学习常用六种方法总结:使ML模型适应新数据的同时保持旧数据的性能

解决这类问题的常用方法是在包含新旧数据的新的更大数据集上对整个模型进行再训练。但是这种做法往往代价高昂。所以有一个ML研究领域正在研究这个问题,基于该领域的研究,本文将讨论6种方法,使模型可以在保持旧的性能的同时适应新数据,并避免需要在整个数据集(旧+新)上进行重新训练。

Prompt

Prompt 想法源于对GPT 3的提示(短序列的单词)可以帮助驱动模型更好地推理和回答。所以在本文中将Prompt 翻译为提示。提示调优是指使用小型可学习的提示,并将其与实际输入一起作为模型的输入。这允许我们只在新数据上训练提供提示的小模型,而无需再训练模型权重。

具体来说,我选择了使用提示进行基于文本的密集检索的例子,这个例子改编自Wang的文章《Learning to Prompt for continuous Learning》。

该论文的作者使用下图描述了他们的想法:

持续学习常用六种方法总结:使ML模型适应新数据的同时保持旧数据的性能

实际编码的文本输入用作从提示池中识别最小匹配对的key。在将这些标识的提示输入到模型之前,首先将它们添加到未编码的文本嵌入中。这样做的目的是训练这些提示来表示新的任务,同时保持旧的模型不变,这里提示的很小,大概每个提示只有20个令牌。

class PromptPool(nn.Module):
def __init__(self, M = 100, hidden_size = 768, length = 20, N=5):
super().__init__()
self.pool = nn.Parameter(torch.rand(M, length, hidden_size), requires_grad=True).float()
self.keys = nn.Parameter(torch.rand(M, hidden_size), requires_grad=True).float()
 
self.length = length
self.hidden = hidden_size
self.n = N
 
nn.init.xavier_normal_(self.pool)
nn.init.xavier_normal_(self.keys)
 
def init_weights(self, embedding):
pass
 
# function to select from pool based on index
def concat(self, indices, input_embeds):
subset = self.pool[indices, :] # 2, 2, 20, 768
 
subset = subset.to("cuda:0").reshape(indices.size(0),
self.n*self.length,
self.hidden) # 2, 40, 768
 
return torch.cat((subset, input_embeds), 1)
 
# x is cls output
def query_fn(self, x):
 
# encode input x to same dim as key using cosine
x = x / x.norm(dim=1)[:, None]
k = self.keys / self.keys.norm(dim=1)[:, None]
 
scores = torch.mm(x, k.transpose(0,1).to("cuda:0"))
 
# get argmin
subsets = torch.topk(scores, self.n, 1, False).indices # k smallest
 
return subsets
 
 pool = PromptPool()

然后我们使用的经过训练的旧数据模型,训练新的数据,这里只训练提示部分的权重。

def train():
count = 0
print("*********** Started Training *************")
 
start = time.time()
for epoch in range(40):
model.eval()
pool.train()
 
optimizer.zero_grad(set_to_none=True)
lap = time.time()
 
for batch in iter(train_dataloader):
count += 1
q, p, train_labels = batch
 
queries_emb = model(input_ids=q['input_ids'].to("cuda:0"),
attention_mask=q['attention_mask'].to("cuda:0"))
passage_emb = model(input_ids=p['input_ids'].to("cuda:0"),
attention_mask=p['attention_mask'].to("cuda:0"))
 
# pool
q_idx = pool.query_fn(queries_emb)
raw_qembedding = model.model.embeddings(input_ids=q['input_ids'].to("cuda:0"))
q = pool.concat(indices=q_idx, input_embeds=raw_qembedding)
 
p_idx = pool.query_fn(passage_emb)
raw_pembedding = model.model.embeddings(input_ids=p['input_ids'].to("cuda:0"))
p = pool.concat(indices=p_idx, input_embeds=raw_pembedding)
 
qattention_mask = torch.ones(batch_size, q.size(1))
pattention_mask = torch.ones(batch_size, p.size(1))
 
queries_emb = model.model(inputs_embeds=q,
attention_mask=qattention_mask.to("cuda:0")).last_hidden_state
passage_emb = model.model(inputs_embeds=p,
attention_mask=pattention_mask.to("cuda:0")).last_hidden_state
 
q_cls = queries_emb[:, pool.n*pool.length+1, :]
p_cls = passage_emb[:, pool.n*pool.length+1, :]
 
loss, ql, pl = calc_loss(q_cls, p_cls)
loss.backward()
 
optimizer.step()
optimizer.zero_grad(set_to_none=True)
 
if count % 10 == 0:
print("Model Loss:", round(loss.item(),4), 
"| QL:", round(ql.item(),4), "| PL:", round(pl.item(),4), 
"| Took:", round(time.time() - lap), "secondsn")
 
lap = time.time()
 
if count % 40 == 0 and count > 0:
print("model saved")
torch.save(model.state_dict(), model_PATH)
torch.save(pool.state_dict(), pool_PATH)
 
if count == 4600: return
 
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")

训练完成后,后续的推理过程需要将输入与检索到的提示结合起来。例如这个例子得到了性能—93%的新数据提示池,而完全(旧+新)训练为—94%。这与原论文中提到的表现类似。但是需要说明的一点是结果可能会因任务而不同,你应该尝试实验来知道什么是最好的。

要使此方法成为值得考虑的方法,它必须能够在旧数据上保留老模型> 80%的性能,同时提示也应该帮助模型在新数据上获得良好的性能。

这种方法的缺点是需要使用提示池,这会增加额外的时间。这也不是一个永久的解决方案,但是目前来说是可行的,也或许以后还会有新的方法出现。

Data Distillation

你可能听说过知识蒸馏一词,这是一种使用来自教师模型的权重来指导和训练较小规模模型的技术。数据蒸馏(Data Distillation)的工作原理也类似,它是使用来自真实数据的权重来训练更小的数据子集。因为数据集的关键信号被提炼并浓缩为更小的数据集,我们对新数据的训练只需要提供一些提炼的数据以保持旧的性能。

在此示例中,我将数据蒸馏应用于密集检索(文本)任务。目前看没有其他人在这个领域使用这种方法,所以结果可能不是最好的,但如果你在文本分类上使用这种方法应该会得到不错的结果。

本质上,文本数据蒸馏的想法源于 Li 的一篇题为 Data Distillation for Text Classification 的论文,该论文的灵感来自 Wang 的 Dataset Distillation,他对图像数据进行了蒸馏。Li 用下图描述了文本数据蒸馏的任务:

持续学习常用六种方法总结:使ML模型适应新数据的同时保持旧数据的性能

根据论文,首先将一批蒸馏数据输入到模型以更新其权重。然后使用真实数据评估更新后的模型,并将信号反向传播到蒸馏数据集。该论文在 8 个公共基准数据集上报告了良好的分类结果(> 80% 准确率)。

按照提出的想法,我做了一些小的改动,使用了一批蒸馏数据和多个真实数据。以下是为密集检索训练创建蒸馏数据的代码:

class DistilledData(nn.Module):
def __init__(self, num_labels, M, q_len=64, hidden_size=768):
super().__init__()
self.num_samples = M
self.q_len = q_len
self.num_labels = num_labels
self.data = nn.Parameter(torch.rand(num_labels, M, q_len, hidden_size), requires_grad=True) # i.e. shape: 1000, 4, 64, 768
 
# init using model embedding, xavier, or load from state dict
def init_weights(self, model, path=None):
if model:
self.data.requires_grad = False
print("Init weights using model embedding")
raw_embedding = model.model.get_input_embeddings()
soft_embeds = raw_embedding.weight[:, :].clone().detach()
nums = soft_embeds.size(0)
for i1 in range(self.num_labels):
for i2 in range(self.num_samples):
for i3 in range(self.q_len):
random_idx = random.randint(0, nums-1)
self.data[i1, i2, i3, :] = soft_embeds[random_idx, :]
print(self.data.shape)
self.data.requires_grad = True
 
if not path:
nn.init.xavier_normal_(self.data)
else:
distilled_data.load_state_dict(torch.load(path), strict=False)
 
# function to sample a passage and positive sample as in the article, i am doing dense retrieval
def get_sample(self, label):
q_idx = random.randint(0, self.num_samples-1)
sampled_dist_q = self.data[label, q_idx, :, :]
 
p_idx = random.randint(0, self.num_samples-1)
while q_idx == p_idx:
p_idx = random.randint(0, self.num_samples-1)
sampled_dist_p = self.data[label, p_idx, :, :]
 
return sampled_dist_q, sampled_dist_p, q_idx, p_idx

这是将信号提取到蒸馏数据上的代码

def distll_train(chunk_size=32):
count, times = 0, 0
print("*********** Started Training *************")
start = time.time()
lap = time.time()
 
for epoch in range(40):
distilled_data.train()
 
for batch in iter(train_dataloader):
count += 1
# get real query, pos, label, distilled data query, distilled data pos, ... from batch
q, p, train_labels, dq, dp, q_indexes, p_indexes = batch
 
for idx in range(0, dq['input_ids'].size(0), chunk_size):
model.train()
 
with torch.enable_grad():
# train on distiled data first
x1 = dq['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
x2 = dp['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
q_emb = model(inputs_embeds=x1.to("cuda:0"),
attention_mask=dq['attention_mask'][idx:idx+chunk_size].to("cuda:0")).cpu()
p_emb = model(inputs_embeds=x2.to("cuda:0"),
attention_mask=dp['attention_mask'][idx:idx+chunk_size].to("cuda:0"))
loss = default_loss(q_emb.to("cuda:0"), p_emb)
del q_emb, p_emb
 
loss.backward(retain_graph=True, create_graph=False)
state_dict = model.state_dict()
 
# update model weights
with torch.no_grad():
for idx, param in enumerate(model.parameters()):
if param.requires_grad and not param.grad is None:
param.data -= (param.grad*3e-5)
 
# real data
model.eval()
q_embs = []
p_embs = []
for k in range(0, len(q['input_ids']), chunk_size):
with torch.no_grad():
q_emb = model(input_ids=q['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
p_emb = model(input_ids=p['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
q_embs.append(q_emb)
p_embs.append(p_emb)
q_embs = torch.cat(q_embs, 0)
p_embs = torch.cat(p_embs, 0)
r_loss = default_loss(q_embs.to("cuda:0"), p_embs.to("cuda:0"))
del q_embs, p_embs
 
# distill backward
if count % 2 == 0:
d_grad = torch.autograd.grad(inputs=[x1.to("cuda:0")],#, x2.to("cuda:0")],
outputs=loss,
grad_outputs=r_loss)
indexes = q_indexes
else:
d_grad = torch.autograd.grad(inputs=[x2.to("cuda:0")],
outputs=loss,
grad_outputs=r_loss)
indexes = p_indexes
loss.detach()
r_loss.detach()
 
grads = torch.zeros(distilled_data.data.shape) # lbl, 10, 100, 768
for i, k in enumerate(indexes):
grads[train_labels[i], k, :, :] = grads[train_labels[i], k, :, :].to("cuda:0") 
+ d_grad[0][i, :, :]
distilled_data.data.grad = grads
data_optimizer.step()
data_optimizer.zero_grad(set_to_none=True)
 
model.load_state_dict(state_dict)
model_optimizer.step()
model_optimizer.zero_grad(set_to_none=True)
 
if count % 10 == 0:
print("Count:", count ,"| Data:", round(loss.item(), 4), "| Model:", 
round(r_loss.item(),4), "| Time:", round(time.time() - lap, 4))
# print()
lap = time.time()
 
if count % 100 == 0:
torch.save(model.state_dict(), model_PATH)
torch.save(distilled_data.state_dict(), distill_PATH)
 
if loss  100:
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")
return
del loss, r_loss, grads, q, p, train_labels, dq, dp, x1, x2, state_dict
 
print("Training Took:", round(time.time() - start), "seconds")
print("n*********** Training Complete *************")

这里省略了数据加载等代码,训练完蒸馏的数据后,我们可以通过在其上训练新模型来使用它,例如将其与新数据合并一起训练。

根据我的实验,一个在蒸馏数据上训练的模型(每个标签只包含4个样本)获得了66%的最佳性能,而一个完全在原始数据上训练的模型也是得到了66%的最佳性能。而未经训练的普通模型得到45%的性能。就像上面提到的这些数字对于密集检索任务可能不太好,分类数据上会好很多。

要使此方法成为在调整模型以适应新数据时值是一个有用的方法,需要能够提取出比原始数据小得多的数据集(即~ 1%)。经过提炼的数据也能够给你一个略低于或等于主动学习方法的表现。

这个方法的优点是可以创建用于永久使用的蒸馏数据。缺点是提取的数据没有可解释性,并且需要额外的训练时间。

Curriculum/Active training

Curriculum training是一种方法,训练时向模型提供训练样本的难度逐渐变大。在对新数据进行训练时,此方法需要人工的对任务进行标注,将任务分为简单、中等或困难,然后对数据进行采样。为了理解模型的简单、中等或困难意味着什么,我以这张图片为例:

持续学习常用六种方法总结:使ML模型适应新数据的同时保持旧数据的性能

这是在分类任务中的混淆矩阵,困难样本是假阳性(False Positive),是指模型预测为True的可能性很高,但实际上不是True的样本。中等样本是那些具有中到高的正确性可能性但低于预测阈值的True Negative。而简单样本则是那些可能性较低的True Positive/Negative。

Maximally Interfered Retrieval

这是 Rahaf 在题为“Online Continual Learning with Maximally Interfered Retrieval”的论文(1908.04742)中介绍的一种方法。主要思想是,对于正在训练的每个新数据批次,如果针对较新数据更新模型权重,将需要识别在损失值方面受影响最大的旧样本。保留由旧数据组成的有限大小的内存,并检索最大干扰的样本以及每个新数据批次以一起训练。

这篇论文在持续学习领域是一篇成熟的论文,并且有很多引用,因此可能适用于您的案例。

Retrieval Augmentation

检索增强(Retrieval Augmentation)是指通过从集合中检索项目来扩充输入、样本等的技术。这是一个普遍的概念而不是一个特定的技术。我们到目前为止所讨论的方法,大多数都在一定程度都是检索相关的操作。Izacard 的题为 Few-shot Learning with Retrieval Augmented Language Models 的论文使用更小的模型获得了出色的少样本 学习的性能。检索增强也用于许多其他情况,例如单词生成或回答事实问题。

扩展模型在训练时使用附加层是最常见也最简单的方法,但是不一定有效,所以在这里不进行详细的讨论,这里的一个例子是 Lewis 的 Efficient Few-Shot Learning without Prompts。使用附加层通常是在新旧数据上获得良好性能的最简单但经过尝试和测试的方法。主要思想是保持模型权重固定,并通过分类损失在新数据上训练一层或几层。

总结在本文中,我介绍了在新数据上训练模型时可以使用的 6 种方法。与往常一样应该进行实验并决定哪种方法最适合,但是需要注意的是,除了我上面的方法外还有很多方法,例如数据蒸馏是计算机视觉中的一个活跃领域,你可以找到很多关于它的论文。最后说明的一点是:要使这些方法有价值,它们应该在旧数据和新数据上同时获得良好的性能 。

本篇关于《持续学习常用六种方法总结:使ML模型适应新数据的同时保持旧数据的性能》的介绍就到此结束啦,但是学无止境,想要了解学习更多关于科技周边的相关知识,请关注golang学习网公众号!

版本声明
本文转载于:51CTO.COM 如有侵犯,请联系study_golang@163.com删除
因果推断主要技术思想与方法总结因果推断主要技术思想与方法总结
上一篇
因果推断主要技术思想与方法总结
我应该在我的 iPad 上安装 iPadOS 测试版吗?
下一篇
我应该在我的 iPad 上安装 iPadOS 测试版吗?
查看更多
最新文章
查看更多
课程推荐
  • 前端进阶之JavaScript设计模式
    前端进阶之JavaScript设计模式
    设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
    542次学习
  • GO语言核心编程课程
    GO语言核心编程课程
    本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
    508次学习
  • 简单聊聊mysql8与网络通信
    简单聊聊mysql8与网络通信
    如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
    497次学习
  • JavaScript正则表达式基础与实战
    JavaScript正则表达式基础与实战
    在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
    487次学习
  • 从零制作响应式网站—Grid布局
    从零制作响应式网站—Grid布局
    本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
    484次学习
查看更多
AI推荐
  • SEO标题魔匠AI:高质量学术写作平台,毕业论文生成与优化专家
    魔匠AI
    SEO摘要魔匠AI专注于高质量AI学术写作,已稳定运行6年。提供无限改稿、选题优化、大纲生成、多语言支持、真实参考文献、数据图表生成、查重降重等全流程服务,确保论文质量与隐私安全。适用于专科、本科、硕士学生及研究者,满足多语言学术需求。
    9次使用
  • PPTFake答辩PPT生成器:一键生成高效专业的答辩PPT
    PPTFake答辩PPT生成器
    PPTFake答辩PPT生成器,专为答辩准备设计,极致高效生成PPT与自述稿。智能解析内容,提供多样模板,数据可视化,贴心配套服务,灵活自主编辑,降低制作门槛,适用于各类答辩场景。
    25次使用
  • SEO标题Lovart AI:全球首个设计领域AI智能体,实现全链路设计自动化
    Lovart
    SEO摘要探索Lovart AI,这款专注于设计领域的AI智能体,通过多模态模型集成和智能任务拆解,实现全链路设计自动化。无论是品牌全案设计、广告与视频制作,还是文创内容创作,Lovart AI都能满足您的需求,提升设计效率,降低成本。
    25次使用
  • 美图AI抠图:行业领先的智能图像处理技术,3秒出图,精准无误
    美图AI抠图
    美图AI抠图,依托CVPR 2024竞赛亚军技术,提供顶尖的图像处理解决方案。适用于证件照、商品、毛发等多场景,支持批量处理,3秒出图,零PS基础也能轻松操作,满足个人与商业需求。
    34次使用
  • SEO标题PetGPT:智能桌面宠物程序,结合AI对话的个性化陪伴工具
    PetGPT
    SEO摘要PetGPT 是一款基于 Python 和 PyQt 开发的智能桌面宠物程序,集成了 OpenAI 的 GPT 模型,提供上下文感知对话和主动聊天功能。用户可高度自定义宠物的外观和行为,支持插件热更新和二次开发。适用于需要陪伴和效率辅助的办公族、学生及 AI 技术爱好者。
    35次使用
微信登录更方便
  • 密码登录
  • 注册账号
登录即同意 用户协议隐私政策
返回登录
  • 重置密码