对比学习如何实现异常检测?
本文深入探讨了如何利用对比学习在Python中实现高效的异常表示学习,这对于提升异常检测的准确性和效率至关重要。核心方法在于构建一个深度学习模型,该模型通过自监督学习,学习数据的低维、高信息量表示。模型通过对比损失函数(如InfoNCE Loss)来区分样本的不同增强版本(正样本对)与其他随机样本(负样本对),从而使正常数据在嵌入空间中紧密聚集,而异常数据则被推离。实现这一目标的关键步骤包括:精心设计的数据增强策略,构建包含编码器和投影头的模型架构,定义合适的对比损失函数,以及使用无标签的正常数据进行训练。训练完成后,通过距离度量、K近邻距离或单类分类器等方法对新样本进行异常评分,距离越大或密度越低则越可能是异常。
在Python中实现基于对比学习的异常表示学习,核心步骤包括数据增强、模型构建、对比损失定义、训练以及异常评分。1. 数据增强:通过生成每个样本的多个增强版本,保留语义信息并引入扰动,例如图像数据使用随机裁剪、颜色抖动等方法。2. 模型构建:模型由编码器和投影头组成,编码器提取高维特征,投影头将特征映射到低维嵌入空间。3. 对比损失定义:使用InfoNCE Loss(NT-Xent Loss),最大化正样本对之间一致性,最小化正样本对与负样本对之间一致性。4. 训练:使用无标签的正常数据进行训练,优化模型参数,使正常数据在嵌入空间中聚集。5. 异常评分:通过计算新样本与正常样本集群的距离、K-近邻距离或使用单类分类器进行异常检测,距离越大或密度越低则越可能是异常。
在Python中实现基于对比学习的异常表示学习,核心在于构建一个深度学习模型,通过对比损失函数(如InfoNCE Loss)来学习数据的低维、高信息量表示。这个过程通常是自监督的,模型会学习如何区分一个样本的不同增强版本(正样本对)与其他随机样本(负样本对),从而使得正常数据点在嵌入空间中紧密聚集,而异常点则因为其独特性而被推离这个“正常”集群,变得容易识别。

解决方案
要实现基于对比学习的异常表示学习,我们通常会遵循以下步骤:数据增强、模型构建、对比损失定义、训练以及异常评分。
1. 数据增强 (Augmentation) 这是对比学习的基石。对于每个原始数据点,我们生成至少两个不同的“视图”或增强版本。这些增强应能保留原始样本的语义信息,同时引入一定的扰动。

- 文本数据: 随机删除、替换、插入词语,同义词替换,句子打乱等。
- 图像数据: 随机裁剪、翻转、颜色抖动、高斯模糊等。
- 时间序列数据: 随机裁剪、加噪声、缩放、时间扭曲等。
import torch import torch.nn as nn import torch.nn.functional as F from torchvision import transforms from PIL import Image import numpy as np # 假设是图像数据,定义一个简单的增强策略 class TwoCropsTransform: def __init__(self, base_transform): self.base_transform = base_transform def __call__(self, x): q = self.base_transform(x) k = self.base_transform(x) return q, k # 基础图像增强,可以根据实际任务调整 base_transform = transforms.Compose([ transforms.RandomResizedCrop(224, scale=(0.2, 1.)), transforms.RandomApply([ transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened ], p=0.8), transforms.RandomGrayscale(p=0.2), transforms.RandomApply([transforms.GaussianBlur((23, 23), sigma=(0.1, 2.0))], p=0.5), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 对于非图像数据,需要自定义相应的增强函数 def text_augment(text): # 示例:简单地复制两次,实际应有更复杂的NLP增强 return text + "_aug1", text + "_aug2"
2. 模型构建 (Model Architecture) 通常包含一个编码器(Encoder)和一个投影头(Projection Head)。
- 编码器: 负责从原始数据中提取高维特征。可以是CNN(图像)、Transformer(文本)、MLP(表格数据)等。
- 投影头: 一个或多个全连接层,将编码器输出的特征映射到一个低维的、用于对比学习的嵌入空间。
class Encoder(nn.Module): def __init__(self, in_channels=3, feature_dim=128): super(Encoder, self).__init__() # 示例:一个简单的CNN编码器,可以替换为ResNet, BERT等 self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1) self.relu = nn.ReLU() self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) self.flatten = nn.Flatten() # 假设输入是224x224,经过两层conv+pool后,特征维度需要计算 # (224/2/2) * (224/2/2) * 128 = 56*56*128 = 401408 self.fc = nn.Linear(128 * (224 // 4 // 2) * (224 // 4 // 2), feature_dim) # 调整以匹配实际输出维度 def forward(self, x): x = self.pool(self.relu(self.conv1(x))) x = self.pool(self.relu(self.conv2(x))) x = self.flatten(x) x = self.fc(x) return x class ProjectionHead(nn.Module): def __init__(self, in_dim, out_dim): super(ProjectionHead, self).__init__() self.fc1 = nn.Linear(in_dim, in_dim) self.relu = nn.ReLU() self.fc2 = nn.Linear(in_dim, out_dim) def forward(self, x): x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x class ContrastiveModel(nn.Module): def __init__(self, encoder, projection_head_dim=128): super(ContrastiveModel, self).__init__() self.encoder = encoder # 假设encoder的输出维度是feature_dim self.projection_head = ProjectionHead(encoder.fc.out_features, projection_head_dim) def forward(self, x): features = self.encoder(x) projections = self.projection_head(features) # L2归一化,这对对比学习非常重要 projections = F.normalize(projections, dim=1) return features, projections
3. 对比损失定义 (Contrastive Loss) 最常用的是InfoNCE Loss(也称为NT-Xent Loss)。它旨在最大化正样本对之间的一致性,同时最小化正样本对与负样本对之间的一致性。

class NTXentLoss(nn.Module): def __init__(self, temperature=0.07): super(NTXentLoss, self).__init__() self.temperature = temperature def forward(self, z_i, z_j): """ z_i, z_j: 两个增强视图的投影特征,形状为 (batch_size, projection_dim) """ batch_size = z_i.size(0) # 合并所有样本的投影,用于计算所有对之间的相似度 z = torch.cat([z_i, z_j], dim=0) # shape (2*batch_size, projection_dim) # 计算余弦相似度矩阵 sim_matrix = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2) # shape (2*batch_size, 2*batch_size) # 移除对角线上的自相似度(因为一个样本和它自己不是正样本对) logits = sim_matrix / self.temperature # 构建标签:对角线上的元素是正样本对,其余是负样本对 # 例如,对于 batch_size=2, z_i=[z_i1, z_i2], z_j=[z_j1, z_j2] # z = [z_i1, z_i2, z_j1, z_j2] # 那么正样本对是 (z_i1, z_j1) 和 (z_i2, z_j2) # 它们在 sim_matrix 中的索引是 (0, 2) 和 (1, 3) # 还有 (z_j1, z_i1) 和 (z_j2, z_i2) # 它们在 sim_matrix 中的索引是 (2, 0) 和 (3, 1) labels = torch.arange(batch_size, device=z_i.device) labels = torch.cat([labels, labels + batch_size], dim=0) # [0,1,2,3] for batch_size=2 # 排除自相似项 mask = torch.eye(labels.shape[0], dtype=torch.bool, device=z_i.device) logits = logits[~mask].view(labels.shape[0], -1) # 移除对角线元素 # 调整标签以匹配新的logits形状 # 对于每个 z_i 的正样本是 z_j,对于每个 z_j 的正样本是 z_i # 在新的 logits 矩阵中,每个样本的“正样本”位于特定的位置 # 例如,z_i[0] 的正样本是 z_j[0],它在 sim_matrix 中是 sim_matrix[0, batch_size] # 在 logits_mask_self 中,它会是第 batch_size - 1 个元素 pos_mask = torch.zeros_like(sim_matrix, dtype=torch.bool, device=z_i.device) pos_mask[torch.arange(batch_size), torch.arange(batch_size) + batch_size] = True pos_mask[torch.arange(batch_size) + batch_size, torch.arange(batch_size)] = True pos_mask = pos_mask[~mask].view(labels.shape[0], -1) # 移除自相似后,正样本的位置 # 收集正样本的logits positive_logits = logits[pos_mask].view(labels.shape[0], 1) # 计算LogSumExp,用于分母 log_prob = positive_logits - torch.logsumexp(logits, dim=1, keepdim=True) loss = -log_prob.mean() return loss
4. 训练 (Training Loop) 使用大量无标签的正常数据进行训练。目标是让模型学习到正常数据的紧凑表示。
from torch.utils.data import DataLoader, Dataset import random # 假设你有一个包含正常数据的Dataset class NormalDataset(Dataset): def __init__(self, data_list, transform=None): self.data_list = data_list # list of image paths or actual data self.transform = transform def __len__(self): return len(self.data_list) def __getitem__(self, idx): # 假设data_list是PIL Image对象列表 img = self.data_list[idx] if self.transform: return self.transform(img) return img # 模拟一些正常数据 dummy_normal_data = [Image.new('RGB', (224, 224), color = (i, i, i)) for i in range(100)] normal_dataset = NormalDataset(dummy_normal_data, transform=TwoCropsTransform(base_transform)) normal_dataloader = DataLoader(normal_dataset, batch_size=32, shuffle=True) # 初始化模型和优化器 encoder = Encoder() model = ContrastiveModel(encoder) criterion = NTXentLoss(temperature=0.5) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # 训练过程 num_epochs = 10 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) print("开始训练对比学习模型...") for epoch in range(num_epochs): model.train() total_loss = 0 for (img1, img2) in normal_dataloader: img1 = img1.to(device) img2 = img2.to(device) _, proj1 = model(img1) _, proj2 = model(img2) loss = criterion(proj1, proj2) optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(normal_dataloader):.4f}") print("训练完成。")
5. 异常评分 (Anomaly Scoring) 训练完成后,模型学会了将正常样本映射到嵌入空间的一个特定区域。对于新的、未见过的样本,我们可以计算其嵌入与正常样本集群的距离,或者利用其在嵌入空间中的密度。
- 方法一:距离度量
- 计算所有训练集中正常样本的嵌入,求其均值或中心点。
- 对于新样本,计算其嵌入与该中心点的欧氏距离或余弦距离。距离越大,越可能是异常。
- 方法二:K-近邻距离
- 对于新样本的嵌入,计算它到训练集中K个最近邻正常样本嵌入的平均距离。距离越大,异常可能性越高。
- 方法三:密度估计
- 在嵌入空间中拟合一个单类分类器(如One-Class SVM, Isolation Forest)或密度估计器(如高斯混合模型)。
- 新样本的得分越低(在正常区域密度越低),越可能是异常。
from sklearn.metrics import roc_auc_score from sklearn.ensemble import IsolationForest # 假设模型已经训练好 model.eval() # 1. 提取所有正常训练样本的嵌入 normal_embeddings = [] with torch.no_grad(): for (img1, img2) in normal_dataloader: # 只需要一个视图来提取特征 img1 = img1.to(device) features, _ = model(img1) normal_embeddings.append(features.cpu().numpy()) normal_embeddings = np.concatenate(normal_embeddings, axis=0) # 2. 训练一个异常检测模型在这些嵌入上 # 这里我们使用Isolation Forest作为示例,它对高维数据效果不错 iso_forest = IsolationForest(contamination='auto', random_state=42) iso_forest.fit(normal_embeddings) # 3. 评估或预测新样本 # 假设你有新的测试数据,其中包含正常和异常样本 # dummy_test_normal_data = [Image.new('RGB', (224, 224), color = (i, i, i)) for i in range(10)] # dummy_test_anomaly_data = [Image.new('RGB', (224, 224), color = (255-i, 0, 0)) for i in range(5)] # 模拟异常 # test_data = dummy_test_normal_data + dummy_test_anomaly_data # test_labels = [0]*len(dummy_test_normal_data) + [1]*len(dummy_test_anomaly_data) # 0 for normal, 1 for anomaly # test_dataset = NormalDataset(test_data, transform=base_transform) # 只用一个视图 # test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False) # test_embeddings = [] # with torch.no_grad(): # for img in test_dataloader: # img = img.to(device) # features, _ = model(img) # test_embeddings.append(features.cpu().numpy()) # test_embeddings = np.concatenate(test_embeddings, axis=0) # anomaly_scores = -iso_forest.decision_function(test_embeddings) # Isolation Forest的decision_function越小越异常,所以取负号 # # ROC AUC 是一个常用的评估指标 # # roc_auc = roc_auc_score(test_labels, anomaly_scores) # # print(f"测试集 ROC AUC: {roc_auc:.4f}") print("\n异常检测模型已基于学习到的嵌入训练完成。") print("新的样本
本篇关于《对比学习如何实现异常检测?》的介绍就到此结束啦,但是学无止境,想要了解学习更多关于文章的相关知识,请关注golang学习网公众号!

- 上一篇
- JavaScript数组at方法获取最后元素技巧

- 下一篇
- RESTfulAPI开发教程:PHP接口设计详解
-
- 文章 · python教程 | 37秒前 |
- Python连接Redis指南:redis-py配置全解析
- 387浏览 收藏
-
- 文章 · python教程 | 10分钟前 |
- PythonOCR教程:Tesseract配置全解析
- 499浏览 收藏
-
- 文章 · python教程 | 50分钟前 |
- Python文本分类教程:Scikit-learn实战指南
- 305浏览 收藏
-
- 文章 · python教程 | 1小时前 |
- Intake教程:多CSV数据源构建技巧
- 245浏览 收藏
-
- 文章 · python教程 | 1小时前 |
- Python宽表转长表:melt方法全解析
- 232浏览 收藏
-
- 文章 · python教程 | 1小时前 |
- Python用户行为分析:漏斗模型怎么实现
- 282浏览 收藏
-
- 文章 · python教程 | 1小时前 |
- 递归分层计算如何实现
- 467浏览 收藏
-
- 文章 · python教程 | 1小时前 |
- Python连接PostgreSQL:psycopg2使用教程
- 188浏览 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 542次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 511次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 498次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 484次学习
-
- 千音漫语
- 千音漫语,北京熠声科技倾力打造的智能声音创作助手,提供AI配音、音视频翻译、语音识别、声音克隆等强大功能,助力有声书制作、视频创作、教育培训等领域,官网:https://qianyin123.com
- 98次使用
-
- MiniWork
- MiniWork是一款智能高效的AI工具平台,专为提升工作与学习效率而设计。整合文本处理、图像生成、营销策划及运营管理等多元AI工具,提供精准智能解决方案,让复杂工作简单高效。
- 89次使用
-
- NoCode
- NoCode (nocode.cn)是领先的无代码开发平台,通过拖放、AI对话等简单操作,助您快速创建各类应用、网站与管理系统。无需编程知识,轻松实现个人生活、商业经营、企业管理多场景需求,大幅降低开发门槛,高效低成本。
- 109次使用
-
- 达医智影
- 达医智影,阿里巴巴达摩院医疗AI创新力作。全球率先利用平扫CT实现“一扫多筛”,仅一次CT扫描即可高效识别多种癌症、急症及慢病,为疾病早期发现提供智能、精准的AI影像早筛解决方案。
- 99次使用
-
- 智慧芽Eureka
- 智慧芽Eureka,专为技术创新打造的AI Agent平台。深度理解专利、研发、生物医药、材料、科创等复杂场景,通过专家级AI Agent精准执行任务,智能化工作流解放70%生产力,让您专注核心创新。
- 100次使用
-
- Flask框架安装技巧:让你的开发更高效
- 2024-01-03 501浏览
-
- Django框架中的并发处理技巧
- 2024-01-22 501浏览
-
- 提升Python包下载速度的方法——正确配置pip的国内源
- 2024-01-17 501浏览
-
- Python与C++:哪个编程语言更适合初学者?
- 2024-03-25 501浏览
-
- 品牌建设技巧
- 2024-04-06 501浏览