旋转位置编码与SwiGLU解析详解
对于一个科技周边开发者来说,牢固扎实的基础是十分重要的,golang学习网就来带大家一点点的掌握基础知识点。今天本篇文章带大家了解《大模型架构:旋转位置编码与SwiGLU解析》,主要介绍了,希望对大家的知识积累有所帮助,快点收藏起来吧,否则需要时就找不到了!
RoPE / SwiGLU
前言
✍ 上一篇我们把现代大模型的两件“基础设施”——GQA 注意力 和 RMSNorm + Pre-Norm 细讲了一遍,从多头注意力的演化一路讲到归一化的升级。这一篇,我们就顺势把剩下的两件标配武器补上:
RoPE(Rotary Positional Embedding):解决“长上下文 + 相对位置建模”的问题;SwiGLU 前馈网络:解决“FFN 表达力与训练稳定性”的问题。一、位置编码
1.1 绝对位置编码——三角函数编码
在最早的 Transformer 里,模型本身对“顺序”是没有感觉的,它只看到一串向量 : x_1, x_2, \dots, x_L \in \mathbb{R}^{d_{\text{model}}} 。他并不像RNN、LSTM一样具备循环机制,因此对于位置信息是不敏感。为了让模型知道“谁在前谁在后”,Transformer 直接给每个位置加了一个位置向量 PE_{pos}
\tilde{x}_{pos} = x_{pos} + PE_{pos}
Transformer 原始论文里的做法采用了三角函数位置编码:
PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),\quad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right))
三角函数编码是绝对位置编码其中一种经典实现方式,用固定的 sin/cos 函数给每个绝对位置生成向量,方便模型外推到更长序列。
绝对 PE 的好处是实现很简单,但也有两点局限:
它更偏向“绝对位置”:第 10 个 token 和第 20 个 token 的位置向量完全不同;对于超长上下文,learned pos embedding 很难直接外推,sin-cos 虽然能算,但模型未必学会用。? 读到这里,读者可能会有点疑惑,为什么在前面说了 ① 三角函数编码方便模型外推到更长序列, 但是后面又说了② 对于超长上下文,learned pos embedding 很难直接外推,sin-cos 虽然能算,但模型未必学会用。
① 对 learned pos embedding 来说,当我们只训练了max_len = 2048,那 embedding table 里就只有 0~2047 这些 index;想用到 4096-long 序列时,根本没有 PE[3000] 这一行可用,得重新插值/扩表。但是,当我们采用了三角函数的位置编码时, 想算 pos=4096、pos=10000 都随时能算,从“函数定义”角度确实更“可外推”。② 明确表示了sin-cos 虽然能算,但模型未必学会用,对远超训练长度的位置(比如 8192)对应的正弦相位组合,模型可能根本没“学会如何解读”;因此这里根本不会自相矛盾,用一句土话讲就是“可以但没用的外推”。
代码实现import mathimport torchimport torch.nn as nnclass RotaryEmbedding(nn.Module): """ RoPE 位置编码模块: - 只负责根据 head_dim + seq_len 生成 cos/sin - 不直接改 Q/K,在外面用 apply_rotary_pos_emb 处理 """ def __init__(self, head_dim: int, max_position_embeddings: int = 4096, base: float = 10000.0): super().__init__() assert head_dim % 2 == 0, "head_dim 必须是偶数,才能两两配对旋转" self.head_dim = head_dim self.max_position_embeddings = max_position_embeddings # inv_freq: [head_dim/2] # 对应论文里的 1 / base^{2i/d} inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) self.register_buffer("inv_freq", inv_freq) # 不参与训练 # 预先算好最大长度的 cos/sin,后面按 seq_len 切片 self._build_cache(max_position_embeddings) def _build_cache(self, max_seq_len: int): # t: [max_seq_len] t = torch.arange(max_seq_len, dtype=torch.float32, device=self.inv_freq.device) # freqs: [max_seq_len, head_dim/2] freqs = torch.einsum("i,j->ij", t, self.inv_freq) # 扩成 [max_seq_len, head_dim] emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos()[None, None, :, :]) # [1,1,L,D] self.register_buffer("sin_cached", emb.sin()[None, None, :, :]) # [1,1,L,D] def forward(self, seq_len: int, device=None): """ 返回: cos, sin: [1, 1, seq_len, head_dim] """ if seq_len > self.max_position_embeddings: # 超过预设长度就重建缓存(简单写法,够用) self.max_position_embeddings = seq_len self._build_cache(seq_len) cos = self.cos_cached[:, :, :seq_len, :] # [1,1,L,D] sin = self.sin_cached[:, :, :seq_len, :] # [1,1,L,D] if device is not None: cos = cos.to(device) sin = sin.to(device) return cos, sindef rotate_half(x: torch.Tensor) -> torch.Tensor: """ 将最后一维两两配对做 (x1, x2) -> (-x2, x1) x: [..., D] 且 D 为偶数 """ x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1)def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: """ RoPE 旋转操作: x: [B, H, L, D] cos: [1, 1, L, D] sin: [1, 1, L, D] """ # 广播到 [B,H,L,D] return x * cos + rotate_half(x) * sinclass RoPEMultiHeadAttention(nn.Module): """ 带 RoPE 的多头注意力: - 输入 / 输出: [B, L, d_model] - 内部: 拆成 [B, H, L, Dh],对 Q/K 应用 RoPE """ def __init__(self, d_model, num_heads, dropout=0.0, max_position_embeddings: int = 4096): super().__init__() assert d_model % num_heads == 0 self.d_model = d_model self.num_heads = num_heads self.head_dim = d_model // num_heads self.w_q = nn.Linear(d_model, d_model) self.w_k = nn.Linear(d_model, d_model) self.w_v = nn.Linear(d_model, d_model) self.w_o = nn.Linear(d_model, d_model) self.dropout = nn.Dropout(dropout) # RoPE 模块,专门生成 cos/sin self.rotary_emb = RotaryEmbedding( head_dim=self.head_dim, max_position_embeddings=max_position_embeddings ) def forward(self, x, attn_mask=None): """ x: [B, L, d_model] attn_mask: [B, 1, L, L] 或 [B, L, L],为 0 的位置会被 mask 掉 """ B, L, _ = x.size() device = x.device # 1. 线性投影 Q = self.w_q(x) # [B, L, d_model] K = self.w_k(x) V = self.w_v(x) # 2. 拆成多头 [B, H, L, Dh] def split_heads(t): return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2) Q = split_heads(Q) # [B, H, L, Dh] K = split_heads(K) V = split_heads(V) # 3. 生成 RoPE 的 cos/sin,并作用在 Q/K 上 cos, sin = self.rotary_emb(seq_len=L, device=device) # [1,1,L,Dh] Q = apply_rotary_pos_emb(Q, cos, sin) # [B,H,L,Dh] K = apply_rotary_pos_emb(K, cos, sin) # 4. 缩放点积注意力 scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5) # [B,H,L,L] if attn_mask is not None: # 根据你项目里 attn_mask 的形状调整,这里假设 0 的地方是 mask 掉 if attn_mask.dim() == 3: attn_mask = attn_mask.unsqueeze(1) # [B,1,L,L] scores = scores.masked_fill(attn_mask == 0, float('-inf')) attn = torch.softmax(scores, dim=-1) attn = self.dropout(attn) out = attn @ V # [B,H,L,Dh] # 5. 合并多头 out = out.transpose(1, 2).contiguous().view(B, L, self.d_model) out = self.w_o(out) # [B,L,d_model] return outclass SwiGLUFFN(nn.Module): def __init__(self, d_model, d_ff=4096, dropout=0.1): super().__init__() self.w1 = nn.Linear(d_model, 2 * d_ff) # gate + value self.w2 = nn.Linear(d_ff, d_model) self.dropout = nn.Dropout(dropout) def forward(self, x): x_proj = self.w1(x) # [B, L, 2*d_ff] gate, value = x_proj.chunk(2, dim=-1) # [B,L,d_ff] x2 x = torch.nn.functional.silu(gate) * value # SwiGLU x = self.w2(self.dropout(x)) return xclass RMSNorm(nn.Module): def __init__(self, d_model, eps=1e-8): super().__init__() self.weight = nn.Parameter(torch.ones(d_model)) self.eps = eps def forward(self, x): # x: [B,L,d_model] rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt() x_norm = x / rms return self.weight * x_normclass DecoderBlockWithRoPE(nn.Module): """ 现代 LLM 风格的 Decoder Block: - RoPE + MHA - RMSNorm + Pre-Norm - SwiGLU FFN """ def __init__(self, d_model, num_heads, d_ff=4096, dropout=0.1, max_position_embeddings: int = 4096): super().__init__() self.self_attn = RoPEMultiHeadAttention( d_model=d_model, num_heads=num_heads, dropout=dropout, max_position_embeddings=max_position_embeddings, ) self.ffn = SwiGLUFFN(d_model, d_ff, dropout) self.norm1 = RMSNorm(d_model) self.norm2 = RMSNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x, attn_mask=None): """ x: [B, L, d_model] """ # 1) Pre-Norm + RoPE Self-Attention h = self.norm1(x) attn_out = self.self_attn(h, attn_mask=attn_mask) x = x + self.dropout(attn_out) # 2) Pre-Norm + SwiGLU FFN h = self.norm2(x) ffn_out = self.ffn(h) x = x + self.dropout(ffn_out) return x四、总结
这一篇我们把另外两件标配武器补齐了:
RoPE:不再给输入加位置向量,而是在 Q/K 空间对每对维度做“旋转”,让注意力点积天然依赖相对位置差,更适合长上下文与外推;SwiGLU:给 FFN 加上一扇“门”,用 gate × value 的方式在通道维度上做细粒度控制,在相似参数量下比普通 GELU/ReLU FFN 更有表达力、训练更稳定。到这里,已经把“现代 LLM 架构四件套:GQA / RoPE / SwiGLU / RMSNorm + Pre-Norm”串成一个整体故事了。
本篇关于《旋转位置编码与SwiGLU解析详解》的介绍就到此结束啦,但是学无止境,想要了解学习更多关于科技周边的相关知识,请关注golang学习网公众号!
豆包连接失败怎么解决
- 上一篇
- 豆包连接失败怎么解决
- 下一篇
- 悟空浏览器最新版官网下载入口
-
- 科技周边 · 人工智能 | 26分钟前 | MySQL索引
- 单字段还是组合索引?MySQL优化实战解析
- 175浏览 收藏
-
- 科技周边 · 人工智能 | 48分钟前 |
- DeepSeek整合印象笔记打造个人AI知识库
- 416浏览 收藏
-
- 科技周边 · 人工智能 | 49分钟前 |
- MagicBooksAI评测:赋能互动电子书创作
- 417浏览 收藏
-
- 科技周边 · 人工智能 | 50分钟前 | Notion
- Notion任务提醒设置教程详解
- 386浏览 收藏
-
- 科技周边 · 人工智能 | 56分钟前 |
- AzureOpenAI打造私有数据语音助手
- 363浏览 收藏
-
- 科技周边 · 人工智能 | 59分钟前 |
- 高效工具推荐:提升效率的五大神器
- 251浏览 收藏
-
- 科技周边 · 人工智能 | 1小时前 |
- EnsoAI平台测评与功能详解
- 401浏览 收藏
-
- 科技周边 · 人工智能 | 1小时前 | ChatGPT
- 让ChatGPT写出符合PEP8的Python代码方法
- 477浏览 收藏
-
- 科技周边 · 人工智能 | 1小时前 |
- 魔法家庭作业助手,趣味数学启蒙方式
- 206浏览 收藏
-
- 科技周边 · 人工智能 | 1小时前 |
- Gemini与PixverseAI,轻松生成创意影像
- 142浏览 收藏
-
- 科技周边 · 人工智能 | 1小时前 |
- 即梦生成界面卡顿怎么解决
- 377浏览 收藏
-
- 科技周边 · 人工智能 | 1小时前 |
- Yahoo问答攻略:轻松上手技巧分享
- 216浏览 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 543次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 516次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 500次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 485次学习
-
- ChatExcel酷表
- ChatExcel酷表是由北京大学团队打造的Excel聊天机器人,用自然语言操控表格,简化数据处理,告别繁琐操作,提升工作效率!适用于学生、上班族及政府人员。
- 3353次使用
-
- Any绘本
- 探索Any绘本(anypicturebook.com/zh),一款开源免费的AI绘本创作工具,基于Google Gemini与Flux AI模型,让您轻松创作个性化绘本。适用于家庭、教育、创作等多种场景,零门槛,高自由度,技术透明,本地可控。
- 3565次使用
-
- 可赞AI
- 可赞AI,AI驱动的办公可视化智能工具,助您轻松实现文本与可视化元素高效转化。无论是智能文档生成、多格式文本解析,还是一键生成专业图表、脑图、知识卡片,可赞AI都能让信息处理更清晰高效。覆盖数据汇报、会议纪要、内容营销等全场景,大幅提升办公效率,降低专业门槛,是您提升工作效率的得力助手。
- 3595次使用
-
- 星月写作
- 星月写作是国内首款聚焦中文网络小说创作的AI辅助工具,解决网文作者从构思到变现的全流程痛点。AI扫榜、专属模板、全链路适配,助力新人快速上手,资深作者效率倍增。
- 4721次使用
-
- MagicLight
- MagicLight.ai是全球首款叙事驱动型AI动画视频创作平台,专注于解决从故事想法到完整动画的全流程痛点。它通过自研AI模型,保障角色、风格、场景高度一致性,让零动画经验者也能高效产出专业级叙事内容。广泛适用于独立创作者、动画工作室、教育机构及企业营销,助您轻松实现创意落地与商业化。
- 3970次使用
-
- GPT-4王者加冕!读图做题性能炸天,凭自己就能考上斯坦福
- 2023-04-25 501浏览
-
- 单块V100训练模型提速72倍!尤洋团队新成果获AAAI 2023杰出论文奖
- 2023-04-24 501浏览
-
- ChatGPT 真的会接管世界吗?
- 2023-04-13 501浏览
-
- VR的终极形态是「假眼」?Neuralink前联合创始人掏出新产品:科学之眼!
- 2023-04-30 501浏览
-
- 实现实时制造可视性优势有哪些?
- 2023-04-15 501浏览

