Gemma 4 - Drafter 解析
為了提升 Gemma 4 模型的推論速度,官方在發布主系列模型的同時,也推出了一系列自動回歸的「drafter」模型。這些 draft 模型不再單純依賴 Gemma 4 主模型(稱為「目標 (target)」模型),而是能在目標模型處理一個 token 的時間內,預測出多個 token。這項技術也被稱為推測解碼 (speculative decoding)。
在 drafter 預測出多個 draft token 後,目標模型現在只需要驗證這些建議的 token 即可。驗證過程是並行執行的,因此能大幅提升推論速度,並減少目標模型針對每個 token 所需執行的 forward pass 次數。由於我們的 drafter 會生成一串 token 供驗證,我們稱其為 Multi-Token Prediction (MTP) head。
針對 Gemma 4 系列發布的 draft 模型體積輕量,並引入了多項增強功能以提升 draft token 的品質並進一步加快推論速度,例如利用目標模型的啟用 (activations) 和 KV-cache 來獲得更精準的預測。
這些增強功能在保證品質相當的前提下,帶來了顯著的解碼加速,使得這些檢查點 (checkpoints) 非常適合低延遲和行動裝置端的應用。
這裡有很多細節值得探討,讓我們深入了解推測解碼、MTP 以及這些 drafter!
什麼是推測解碼 (Speculative Decoding)?
Gemma 4 模型以自動回歸方式生成文字,一次產生一個 token。無論預測特定 token 的難度如何,每個 token 所需的運算量大致相同。因此,當 token 非常容易預測時,這可能是一個不必要的緩慢過程。
想像一下,大型模型正在生成文字,並且已經產生了「Actions speak」。對於那些認出這句話開頭的人來說,這是一句常見的英文諺語,完整句子是「Actions speak louder than words.」。由於這句話很常見,較小的模型很有可能生成與大型模型完全相同的補全內容(即「louder than words」)。因此,讓大型模型一次一個 token 地預測「louder than words」簡直是浪費時間和運算資源。
透過推測解碼,我們可以使用較小的 draft 模型提前預測多個 token。draft 模型會接收相同的輸入「Actions speak」,並同樣以自動回歸方式預測多個 token,假設是四個 token。由於 draft 模型的大小僅為大型模型的一小部分,這些 draft token 的生成速度會比大型模型快得多。
什麼是 Multi-Token Prediction (MTP)?
然而,draft token 不一定正確,否則我們直接使用較小的模型即可。相反地,這些 token 會被傳遞給目標模型進行並行驗證。由於目標模型可以在一次 forward pass 中完成此操作,因此它不必為每個 token 都執行一次 forward pass。我們所說的 drafter 就是 Multi-Token Prediction (MTP) head。目標模型的每次 forward pass 都會執行常規的 next-token prediction (NTP) 並產生中間隱藏狀態 (hidden states)。drafter (MTP Head) 會使用這些隱藏狀態並執行多次自動回歸的 forward pass 來生成多個 token。因此,目標模型的一次 forward pass 會產生多個 token,而不是一個。其中一個來自目標模型的 next-token prediction,其餘多個則來自 drafter (MTP head)。
如果目標模型同意 draft 模型的建議,那麼所有 token 都會被接受。較小的模型在極短的時間內就完成了原本需要生成四個 token 的工作。目標模型只需要花費生成一個 token 的時間來驗證它們。此外,如果所有 draft token 都被接受,目標模型本身仍會額外生成一個 token。
如果目標模型僅不同意部分 draft token,它會接受直到出現分歧為止,隨後目標模型會用自己的 token 取代被拒絕的 token。
考慮到模型可以一次性驗證所有 draft token 的品質,而不必逐一驗證,這個過程實際上非常快。由於 draft 模型非常小,與目標模型相比,預測單個 token 所需的時間要少得多。這意味著目標模型可以在幾乎與生成單個 token 相同的時間內驗證多個 token!請注意,draft 模型像大多數語言模型一樣,是順序生成這些 token 的,但由於其體積小,速度快得多。
目標模型認為足夠好的所有 token 都會被選中。第一個被拒絕的 token 以及隨後的所有 token 都不會被包含在內,並被丟棄。然而,由於目標模型已經執行了一次 forward pass,它仍然可以執行 next token prediction。因此,即使像「pens」這樣的 token 被拒絕,目標模型仍然會提供該被拒絕 token 的替代方案。
結果就是,目標模型可能會選中任意數量的 draft token。考慮到 draft 模型以自動回歸方式執行處理並逐序列生成 token,而目標模型可以並行驗證所有 draft token,整個過程的可視化非常有趣。目標模型仍然是自動回歸的,但現在它不必逐一生成那些 draft token,而是可以一次性驗證它們。
Gemma 4 的 MTP
為 Gemma 4 系列發布的 draft 模型與稠密 (dense) Gemma 4 模型最為相似,但體積小得多。事實上,Gemma 4 E2B 的 draft 模型僅擁有約 76M 個參數、四個層,以及較小的輸入 embedding 大小(256,相較於主模型的 1536)。
請注意 decoder 本身與稠密 Gemma 4 模型相似。然而,在 decoder 之前和之後發生了很多事情!
這些 draft 模型具備多項增強功能,專門用於提高效率並進一步加快推論速度。同樣地,也有一些有趣的技術被用來提升 draft token 的品質並降低 drafter 的延遲。畢竟,我們希望 draft token 盡可能準確,且生成速度盡可能快。
這些變更可總結如下:
目標啟用 (Target Activations):draft 模型使用目標模型最後一層的啟用,將其與 token embedding 連接起來,並向下投影 (down-project) 到 drafter 模型的維度。
KV Cache 共享:draft 模型會 cross-attend 到目標模型的 KV cache,而不是建立自己的 cache。
高效 Embedder:LM Head 執行一種稀疏解碼技術,用以識別最有可能預測的 token 叢集(僅限 E2B 和 E4B)。
讓我們更詳細地探討其中每一項!
目標啟用 (Target Activations)
為了提升 draft 模型生成 token 的品質,目標模型(例如 E2B)的最終啟用會被輸入到 draft 模型中。這些啟用會與 draft 模型的 token embedding 連接,假設是 E2B 模型,兩者皆有 1,536 個值。連接後的 embedding 非常大,為了效率考量,它們被投影縮減至僅 256 個值。這本質上是對大型 draft 模型處理後的狀態與 draft 模型新初始化的 embedding 進行壓縮。為什麼要浪費之前的啟用呢?
這些啟用僅在處理的第一輪提供給 draft 模型。請記住,在 draft 開始初始輪次後,它可能會透過將 token 回傳給自身進行自動回歸來產生多個 token。因此,在第 2 輪中,由於產生了新的 token,會使用 draft 模型在第 1 輪中生成的啟用。由於小型 draft 模型的中間啟用僅有 256 個值,它們會被向上投影以匹配其輸入 embedding 表的維度(即 1,536 個值)。請注意,為了進一步提高效率,輸入 embedding 表是在目標模型和 draft 模型之間共享的。
然後,在第 3 輪中,draft 模型使用第 2 輪生成的啟用,以此類推!
KV Cache 共享
KV cache 可能會佔用相當大的空間,因為它包含了序列中每個 token 在每一層的 key 和 value 表示。儘管 Gemma 4 已經採取了許多措施來減少這種佔用(例如在全域注意力層中設定 K=V),但 draft 模型更進一步。
draft 模型不需要處理完整的 prompt 並建立自己的 KV cache,而是 cross-attend 到目標模型已經計算好的 KV cache。對於其局部注意力層,draft 模型直接重複使用目標模型最後計算出的局部 KV cache。由於任何 Gemma 4 模型的最後一層總是全域的,該全域 KV cache 會被重複用於 draft 模型的全域注意力層。
如前所述,既然目標模型已經完成了大部分繁重的工作,浪費掉這…