在少样本学习中,用SetFit进行文本分类
大家好,今天本人给大家带来文章《在少样本学习中,用SetFit进行文本分类》,文中内容主要涉及到,如果你对科技周边方面的知识点感兴趣,那就请各位朋友继续看下去吧~希望能真正帮到你们,谢谢!
译者 | 陈峻
审校 | 重楼
在本文中,我将向您介绍“少样本(Few-shot)学习”的相关概念,并重点讨论被广泛应用于文本分类的SetFit方法。
传统的机器学习(ML)
在监督(Supervised)机器学习中,大量数据集被用于模型训练,以便磨练模型能够做出精确预测的能力。在完成训练过程之后,我们便可以利用测试数据,来获得模型的预测结果。然而,这种传统的监督学习方法存在着一个显著缺点:它需要大量无差错的训练数据集。但是并非所有领域都能够提供此类无差错数据集。因此,“少样本学习”的概念应运而生。
在深入研究Sentence Transformer fine-tuning(SetFit)之前,我们有必要简要地回顾一下自然语言处理(Natural Language Processing,NLP)的一个重要方面,也就是:“少样本学习”。
少样本学习
少样本学习是指:使用有限的训练数据集,来训练模型。模型可以从这些被称为支持集的小集合中获取知识。此类学习旨在教会少样本模型,辨别出训练数据中的相同与相异之处。例如,我们并非要指示模型将所给图像分类为猫或狗,而是指示它掌握各种动物之间的共性和区别。可见,这种方法侧重于理解输入数据中的相似点和不同点。因此,它通常也被称为元学习(meta-learning)、或是从学习到学习(learning-to-learn)。
值得一提的是,少样本学习的支持集,也被称为k向(k-way)n样本(n-shot)学习。其中“k”代表支持集里的类别数。例如,在二分类(binary classification)中,k 等于 2。而“n”表示支持集中每个类别的可用样本数。例如,如果正分类有10个数据点,而负分类也有10个数据点,那么 n就等于10。总之,这个支持集可以被描述为双向10样本学习。
既然我们已经对少样本学习有了基本的了解,下面让我们通过使用SetFit进行快速学习,并在实际应用中对电商数据集进行文本分类。
SetFit架构
由Hugging Face和英特尔实验室的团队联合开发的SetFit,是一款用于少样本照片分类的开源工具。你可以在项目库链接--https://github.com/huggingface/setfit?ref=hackernoon.com中,找到关于SetFit的全面信息。
就输出而言,SetFit仅用到了客户评论(Customer Reviews,CR)情感分析数据集里、每个类别的八个标注示例。其结果就能够与由三千个示例组成的完整训练集上,经调优的RoBERTa Large的结果相同。值得强调的是,就体积而言,经微优的RoBERTa模型比SetFit模型大三倍。下图展示的是SetFit架构:
图片来源:https://www.sbert.net/docs/training/overview.html?ref=hackernoon.com
用SetFit实现快速学习
SetFit的训练速度非常快,效率也极高。与GPT-3和T-FEW等大模型相比,其性能极具竞争力。请参见下图:
SetFit与T-Few 3B模型的比较
如下图所示,SetFit在少样本学习方面的表现优于RoBERTa。
SetFit与RoBERT的比较,图片来源:https://huggingface.co/blog/setfit?ref=hackernoon.com
数据集
下面,我们将用到由四个不同类别组成的独特电商数据集,它们分别是:书籍、服装与配件、电子产品、以及家居用品。该数据集的主要目的是将来自电商网站的产品描述归类到指定的标签下。
为了便于采用少样本的训练方法,我们将从四个类别中各选择八个样本,从而得到总共32个训练样本。而其余样本则将留作测试之用。简言之,我们在此使用的支持集是4向8样本学习。下图展示的是自定义电商数据集的示例:
自定义电商数据集样本
我们采用名为“all-mpnet-base-v2”的Sentence Transformers预训练模型,将文本数据转换为各种向量嵌入。该模型可以为输入文本,生成维度为768的向量嵌入。
如下命令所示,我们将通过在conda环境(是一个开源的软件包管理系统和环境管理系统)中安装所需的软件包,来开始SetFit的实施。
!pip3 install SetFit !pip3 install sklearn !pip3 install transformers !pip3 install sentence-transformers
安装完软件包后,我们便可以通过如下代码加载数据集了。
from datasets import load_datasetdataset = load_dataset('csv', data_files={"train": 'E_Commerce_Dataset_Train.csv',"test": 'E_Commerce_Dataset_Test.csv'})
我们来参照下图,看看训练样本和测试样本数。
训练和测试数据
我们使用sklearn软件包中的LabelEncoder,将文本标签转换为编码标签。
from sklearn.preprocessing import LabelEncoder le = LabelEncoder()
通过LabelEncoder,我们将对训练和测试数据集进行编码,并将编码后的标签添加到数据集的“标签”列中。请参见如下代码:
Encoded_Product = le.fit_transform(dataset["train"]['Label']) dataset["train"] = dataset["train"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["train"].features)Encoded_Product = le.fit_transform(dataset["test"]['Label']) dataset["test"] = dataset["test"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["test"].features)
下面,我们将初始化SetFit模型和句子转换器(sentence-transformers)模型。
from setfit import SetFitModel, SetFitTrainer from sentence_transformers.losses import CosineSimilarityLossmodel_id = "sentence-transformers/all-mpnet-base-v2" model = SetFitModel.from_pretrained(model_id)trainer = SetFitTrainer( model=model, train_dataset=dataset["train"], eval_dataset=dataset["test"], loss_class=CosineSimilarityLoss, metric="accuracy", batch_size=64, num_iteratinotallow=20, num_epochs=2, column_mapping={"Text": "text", "Label": "label"})
初始化完成两个模型后,我们现在便可以调用训练程序了。
trainer.train()
在完成了2个训练轮数(epoch)后,我们将在eval_dataset上,对训练好的模型进行评估。
trainer.evaluate()
经测试,我们的训练模型的最高准确率为87.5%。虽然87.5%的准确率并不算高,但是毕竟我们的模型只用了32个样本进行训练。也就是说,考虑到数据集规模的有限性,在测试数据集上取得87.5%的准确率,实际上是相当可观的。
此外,SetFit还能够将训练好的模型,保存到本地存储器中,以便后续从磁盘加载,用于将来的预测。
trainer.model._save_pretrained(save_directory="SetFit_ECommerce_Output/")model=SetFitModel.from_pretrained("SetFit_ECommerce_Output/", local_files_notallow=True)
如下代码展示了根据新的数据进行的预测结果:
input = ["Campus Sutra Men's Sports Jersey T-Shirt Cool-Gear: Our Proprietary Moisture Management technology. Helps to absorb and evaporate sweat quickly. Keeps you Cool & Dry. Ultra-Fresh: Fabrics treated with Ultra-Fresh Antimicrobial Technology. Ultra-Fresh is a trademark of (TRA) Inc, Ontario, Canada. Keeps you odour free."]output = model(input)
可见,其预测输出为1,而标签的LabelEncoded值为“服装与配件”。由于传统的AI模型需要大量的训练资源(包括时间和数据),才能有稳定水准的输出。而我们的模型与之相比,既准确又高效。
至此,相信您已经基本掌握了“少样本学习”的概念,以及如何使用SetFit来进行文本分类等应用。当然,为了获得更深刻的理解,我强烈建议您选择一个实际场景,创建一个数据集,编写对应的代码,并将该过程延展到零样本学习、以及单样本学习上。
译者介绍
陈峻(Julian Chen)是51CTO社区的编辑,他在IT项目实施方面有十多年的经验,擅长管理内外部资源和风险,并专注于传播网络和信息安全的知识和经验
原文标题:Mastering Few-Shot Learning with SetFit for Text Classification,作者:Shyam Ganesh S)
理论要掌握,实操不能落!以上关于《在少样本学习中,用SetFit进行文本分类》的详细介绍,大家都掌握了吧!如果想要继续提升自己的能力,那么就来关注golang学习网公众号吧!

- 上一篇
- AI 换装新突破:经 100 万张照片训练,解构重建服装准确率 95.7%

- 下一篇
- Varjo推出最新XR-4 VR/MR头戴式显示器,两个2000万像素摄像头实现实时逼真MR透视
-
- 科技周边 · 人工智能 | 3小时前 |
- 文心一言职场励志文案怎么写?
- 208浏览 收藏
-
- 科技周边 · 人工智能 | 3小时前 |
- 豆包AI能识图吗?多模态使用教程分享
- 309浏览 收藏
-
- 科技周边 · 人工智能 | 4小时前 |
- GeminiAPI限速设置与调用方法
- 272浏览 收藏
-
- 科技周边 · 人工智能 | 4小时前 |
- 笔尖AIAPI接入与安全使用指南
- 490浏览 收藏
-
- 科技周边 · 人工智能 | 5小时前 |
- Deepseek满血版联手Copy.ai,文案模板秒用
- 104浏览 收藏
-
- 科技周边 · 人工智能 | 5小时前 |
- 苹果用户如何安装DeepSeek详解
- 254浏览 收藏
-
- 科技周边 · 人工智能 | 5小时前 |
- 用ChatGPT写评论区文案的步骤与技巧
- 228浏览 收藏
-
- 科技周边 · 人工智能 | 5小时前 |
- Deepseek满血版+Kapwing,轻松剪辑创意视频
- 395浏览 收藏
-
- 科技周边 · 人工智能 | 5小时前 |
- Diffusers图像生成教程:扩散模型推理详解
- 482浏览 收藏
-
- 科技周边 · 人工智能 | 6小时前 |
- 文心一言生成图片步骤详解
- 150浏览 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 542次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 511次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 498次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 484次学习
-
- 千音漫语
- 千音漫语,北京熠声科技倾力打造的智能声音创作助手,提供AI配音、音视频翻译、语音识别、声音克隆等强大功能,助力有声书制作、视频创作、教育培训等领域,官网:https://qianyin123.com
- 231次使用
-
- MiniWork
- MiniWork是一款智能高效的AI工具平台,专为提升工作与学习效率而设计。整合文本处理、图像生成、营销策划及运营管理等多元AI工具,提供精准智能解决方案,让复杂工作简单高效。
- 227次使用
-
- NoCode
- NoCode (nocode.cn)是领先的无代码开发平台,通过拖放、AI对话等简单操作,助您快速创建各类应用、网站与管理系统。无需编程知识,轻松实现个人生活、商业经营、企业管理多场景需求,大幅降低开发门槛,高效低成本。
- 226次使用
-
- 达医智影
- 达医智影,阿里巴巴达摩院医疗AI创新力作。全球率先利用平扫CT实现“一扫多筛”,仅一次CT扫描即可高效识别多种癌症、急症及慢病,为疾病早期发现提供智能、精准的AI影像早筛解决方案。
- 231次使用
-
- 智慧芽Eureka
- 智慧芽Eureka,专为技术创新打造的AI Agent平台。深度理解专利、研发、生物医药、材料、科创等复杂场景,通过专家级AI Agent精准执行任务,智能化工作流解放70%生产力,让您专注核心创新。
- 254次使用
-
- GPT-4王者加冕!读图做题性能炸天,凭自己就能考上斯坦福
- 2023-04-25 501浏览
-
- 单块V100训练模型提速72倍!尤洋团队新成果获AAAI 2023杰出论文奖
- 2023-04-24 501浏览
-
- ChatGPT 真的会接管世界吗?
- 2023-04-13 501浏览
-
- VR的终极形态是「假眼」?Neuralink前联合创始人掏出新产品:科学之眼!
- 2023-04-30 501浏览
-
- 实现实时制造可视性优势有哪些?
- 2023-04-15 501浏览