拯救Transformer推理能力!DeepMind新研究TransNAR:给模型嵌入「算法推理大脑」
发布时间:2024-06-19DeepMind最近发表的一篇论文提出用混合架构的方法解决Transformer模型的推理缺陷。将Transformer的NLU技能与基于GNN的神经算法推理器(NAR)的强大算法推理能力相结合,可以实现更加泛化、稳健、准确的LLM推理。
如今的NLP领域,已然是Transformer架构的天下。
从Bert到GPT,再到Llama、Claude,LLM模型使用Transformer已经是再正常不过的事情。
Transformer的「大一统」局面正是由于其简单、高效的架构,以及在理解自然语言方面无与伦比的泛化能力。
然而,随着研究的逐渐深入,Transformer的一个致命缺陷也逐渐暴露出来——无法胜任算法推理任务,尤其是不能进行精确、稳健的推理。
这严重限制了模型在数学、代码等领域下游任务的应用,近年来对Transformer的各种调优、修改似乎也收效甚微。
于是DeepMind的研究人员想到了混合架构——将Transformers的语言理解能力与基于图神经网络(GNN)的神经算法推理器(NAR)的稳健性结合起来,提升其算法推理能力。
他们最近在arxiv上的一篇论文就提出了这个名为TransNAR的架构,但遗憾的是,目前还没有公布源代码。
论文地址:https://arxiv.org/abs/2406.09308
神经算法推理(NAR)由本文作者之一Petar Veleckovic在2021年与人合著的一篇论文中提出,并被接收为Patterns期刊的opinion paper。
论文地址:https://arxiv.org/abs/2105.02761
NAR被称为「构建能执行算法的神经网络的艺术」。作者提出,算法与深度学习的本质不同,但如果神经网络能够更好地模仿算法,它甚至可能具备算法的强泛化性。
更进一步,神经网络若能表示出算法中连续空间内的元素,就会使已知算法更接近现实世界的问题,提出的解决方案可能超过人类科学家。
如上图所示,NAR的整体想法是训练出一个高维隐空间中的处理器网络P(processor network),旨在不断逼近算法的运行结果A(x)。
但由于算法的输入和输出一般是图、树、矩阵等抽象、结构化的形式,这与深度学习模型高维、嘈杂且多变的输入很不兼容,因此还需要训练编码器f和解码器g,将抽象形式转换为自然形式。
NAR发布后,有多项研究证实了它有同时执行多种算法的能力,也能部署在各种下游任务中。更重要的是,它的泛化能力似乎远远优于Transformer架构。
原则上,NAR可以扩展到比训练数据的分布大几个数量级的系统上,有时这个数量级能达到1.8万倍。
在使用适当的归纳偏差(inductive biases)时,即使输入比训练集大6倍,NAR也能在高度复杂的算法任务中保持完美的泛化能力。
找到了Transformer和NAR这两种十分强大且各有所长的架构,下面最关键的问题就是如何进行相应的调整和修改,使这两个似乎完全不相容的模型真正实现沟通和Embedding交换。
TransNAR:用预训练NAR增强Transformer
如何实现NAR+Transformer的有效沟通?作者从多模态LLM中找到了灵感。
多模态LLM可以同时接收文本和图像两种模态的输入,TransNAR也是如此。一边是算法运行需要的图结构,一边是描述问题的自然语言。
作者的设想是,将预训练的NAR作为Transformer中编码的调制器(modulator),二者通过embedding沟通,同时借鉴VLM和Flamingo模型中所用的交叉注意算子,融合不同模态的信息。
TransNAR接受双重输入,包括文本形式的算法问题规范(T个token)及其对应的图表征(N个节点),并输出问题的文本答案。其中输入的图表征遵循算法推理基准CLRS-30的格式。
我们可以假设,编码完成后,文本输入存储在T ∈ R^(T×k)中,图输入存储在G ∈ R^(N×l)中。
TransNAR的前向传播过程如下:
首先,我们通过设置T^(0) = T和G^(0) = G来正确初始化输入。
接下来,为了计算第(t+1)步的表征,文本(token)表征被输入到Transformer的当前层:
其中,Qt,Kt ∈ Rk×d_k,Vt ∈ Rk×k分别是键、查询和值矩阵的变换,FFN是一个前馈神经网络。
以类似的方式,图表征被输入到NAR层,例如实现一个标准的max-MPNN:
其中,ψ,ϕ : Rk × Rk → Rk分别是可学习的消息函数和更新函数,max是逐元素最大值聚合。
需要注意的是,方程2仅简要提供了节点之间的成对交互——实际上,这里的NAR是一个Triplet-GMPNN,它还包含三元组交互和一个门控机制。
此外,还需注意,NAR的可学习部分没有时间步索引——每一步都应用相同的共享函数。这很好地契合了图算法计算的迭代和重复性质。
一旦两个流都准备好它们的表征Θt+1和Gt+1,图中的节点嵌入将对Transformer的token嵌入进行条件设置,从而产生Transformer流中TransNAR块的最终结果:
其中,Qt×,Kt× ∈ Rk×d_k, Vtx ∈ Rk×k分别是交叉注意力的键、查询和值变换。在结束这一层之前,对Gt+1不进行额外的变换。
这个过程会一直重复,直到最后的第Nl层,在这一层中,从TN_l读取最终的文本输出。
最终输出通过最后一层生成的预测头转换为token logits,并通过标准的下一个token预测来监督训练。
在开始TransNAR微调之前,首先预训练NAR,使其能够稳健地执行CLRS-30覆盖的三十个算法。这种方法已知可以在图空间中实现高达4倍输入规模的分布外泛化。
在微调过程中,NAR的参数通常保持冻结状态,因为额外的梯度会削弱模型的原有稳健性特性。同样的原因,图嵌入不会执行交叉注意力。
LLM本身可以在大规模数据集上进行预训练,以建立其一般语言先验,即使在开始时随机初始化LM,也能获得相同的实验结果。
实验设置
在实验中,作者展示了TransNAR为大语言模型架构中的分布外推理带来的显著优势。
Transformer架构和初始化
论文使用Chinchilla家族的一个decoder-only架构、6层的Transformer模型,首先在MassiveText上进行了预训练,参数量有70M,上下文大小为2048。
为了探究初始化设置的影响,作者设计了两个变体进行消融实验。
第一个变体中,Transformer权重用预训练的结果初始化,模拟微调场景;第二个变体则是完全随机的初始化。这两个模型分别被标记为「预训练」和「未训练」。
随机位置编码
之前DeepMind的一篇论文论证过,随机位置编码可以增强Transformer的长度泛化与推理稳健性。
论文地址:https://arxiv.org/abs/2305.16843
作者也提到,随机位置嵌入确实在基线模型和TransNAR上都带来了显著增益,因此本文中的所有实验也都使用随机位置嵌入。
预训练NAR
论文使用CLRS-30基准中的问题预训练了一个多任务、基于MPNN的NAR,输入问题规模最多达16个。
由于CLRS-30的标准图结构表达,这样训练出来的NAR有很强的分布外(OOD)泛化能力,有时在4倍大小的图上仍保持竞争力,这种丰富的知识表达正是文本模型可资利用的。
结合节点和边缘的跨注意力贡献
在上述的算法描述中,我们将NAR模型的图输入限于N个节点,但作者注意到了之前的研究曾尝试过,同时对图的节点和边生成隐变量表达,也许可以添加有用的互补信息。
于是实验中引入图中边的特征E(t) ∈ RN×N×k,并再次应用公式3让Θ(t)对E(t)进行交叉注意力。
作者也尝试其他方法,希望将E(t)和G(t)结合起来,比如拼接后加线性层组合、向量求和、2层MLP,或者用Gram-Schmidt过程使二者的贡献正交化,但这些都没有给原始方法带来提升。
数据集
训练数据使用CLRS-Text基准,即CLRS-30基准的文本版本,以确定性的方式直接从基于图的CLRS-30中派生,因此这两个数据集传达的是完全相同的信息。
表1展示了该数据集的几个样本,以及它们的输入大小和token数量。
由于语言模型上下文长度的限制,实验选择用规模为4、8、12的问题训练,并在规模为110、12、14的问题上评估。
值得注意的是,与当前的评估环境相比,CLRS-Text是对LM最具挑战性的长程推理任务之一——相比小学数学,复杂度显著提高。
CLRS-Text的挑战性主要源于它允许显式控制分布外泛化。然而,每个问题都有清晰的多项式时间解法,这意味当今典型LLM的参数量应该足以解决这些问题。
该数据集每种算法的每种输入规模包含一万个样本,总共240万个数据点,其中70%用于训练、30%用于验证。
训练细节
实验将batch大小设置为256训练了7个epoch,并使用Adam优化器,学习率为10-4。
如前所述,在所有Chinchilla Transformer的旋转位置编码(RoPE)之上应用随机位置编码,最大长度为8192,且训练期间保持NAR冻结。
评估指标
作者提出,合适的评估指标应该反映模型在特定样本上失败的原因,且需要度量型输出与正确答案的接近程度。因此,使用精确字符串匹配来计算模型准确性是绝对不可行的。
论文选择的性能指标包括以下三个:
1. 形状分数:一个二元指标,用于判断输出是否具有正确的形状。例如,在排序任务中,输出应与输入有完全相同的元素数量。或者,如果输出是一个矩阵,我们需要确保其形状与输入和任务一致。
2. 解析分数:一个二元指标,用于判断输出是否不含任何非法字符。例如,在对数字列表进行排序的任务中,输出不应包含任何字母。
3. CLRS分数:输出中与真实答案匹配的元素百分比,也常用于CLRS-30测试。形状分数为0时,CLRS分数也会自动置零。
这种多方面的指标设计能够捕捉到LLM在文本上进行推理任务的各种失败模式。
比如在某个问题规模上过度专门化训练(导致输出的形状不正确)、无法处理看不见的数字组合(导致解析错误),由于推理错误造成的答案不一致则由CLRS分数反映。
结果
实验结果显示,TransNAR整体上显著优于Transformer模型,在动态规划、几何、图、贪心算法、排序、字符串等任务上的OOD推理能力都有大幅提升。
并且在大多数单个算法上,无论是在分布内还是分布外都表现更佳。
特别值得注意的是,这种方法不仅增强了Transformer原有的OOD泛化能力,还激发了一些模型先前完全不具备的能力。
比如Graham扫描(graham_scan)、最长公子串长度(lcs_length)、强连通分量(scc)这些经典问题中,基线模型得分为零或接近零,但TransNAR却实现了突破。
分析形状分数可以进一步解释,为什么TransNAR表现如此出色。
首先,回顾一下,如果形状不匹配,CLRS得分必然为零。
从形状得分来看,将Transformer的输出建立在NAR嵌入基础上显著提高了答案中形状正确的比例——这表明TransNAR缓解了一种特定的LLM故障模式。
此外,通过对比「预训练」和「未训练」两种初始化方式的分数,可以看到模型较好的稳定性和可用性。在随机初始化时,也能训练到与微调相当的水准。
然而,在一些算法中,TransNAR仍未能超越基线,且在分布内和分布外都是如此。
这些算法包括二分搜索、寻找最大子数组、最小值和快速选择等,都涉及在输入列表中按照索引搜索特定元素。
这暗示了TransNAR的一种故障模式:模型无法泛化到训练数据中未见过的新索引边界。因此,使用索引提示或许是一条有前景的改进途径。
另一种可能的解释是,NAR最终计算出的隐藏状态难以在交叉注意力层以可泛化的方式被解码。如果原因在此,解决途径可以是增加交叉注意力的容量,或者采用渐进式解码。
此外,TransNAR在架构上有一个本质的局限性,就是必需一个能得出ground truth的模拟器或者数据标签,用于将输入的文本转换为图结构,再作为模型输入。
但是作者强调,TransNAR的概念对于未来研究是有借鉴意义的。可以考虑将这种混合架构的想法移植到单模态LLM,或者将TransNAR训练后获得的知识提炼出来注入到普通的Transformer中。
参考资料:https://arxiv.org/abs/2406.09308
来源:新智元