編輯:peter東 喬楊
【新智元導讀】 大模型如今已具有越來越長的上下文,而與之相伴的是推理成本的上升。輝達 最新提出的Star Attention,能夠在不損失精度的同時,顯著減少推理計算量,從而助力邊緣計算。
當下的手機及 AIPC 中都會安裝本地大模型,然而上下文長度增加,推理時的計算成本也會顯著增長。最明顯的一個後果就是,使用者輸入問題後需要等待很久才能看到結果。
為此,已有多種最佳化方案提出,例如 Flash Attention ,而11月26日輝達提出的Star Attention機制,可用於提升Transformer模型在處理長序列時的效率和準確性。
值得一提的是,這篇文章受到了廣泛的關註,登頂H ug ging F ace每日論文榜首。
論文地址:https://arxiv.org/abs/2411.17116
Star Attention如何降低推理成本
在了解Star Attention如何改進大模型推理前,讓我們先看看當前大模型的推理過程涉及的兩個步驟:
1)prompt編碼,即模型處理輸入並在緩存中儲存KV(鍵值)向量;
2)token生成,即模型關註KV緩存並自回歸生成新令牌,同時用新 的KV向量更新緩存。
在許多長上下文任務中,輸入由一個長上下文後跟一個短查詢和一個短答案組成。當大模型的上下文變得越來越長之後,回答查詢所需的資訊通常局限在上下文的小部份內,意味著上下文只需關註附近的token,而查詢token需要關註所有之前上下文涉及的內容。
Star Attention下的兩階段推理
系統中所有器材被分組為多個主機(host),其中一個主機被標記 為「查詢」主機。輸入序列分為兩個階段處理。
階段一:上下文編碼
輸入的上下文部份被分割成較小的塊,並分配到各個主機。除了第一個塊之外,所有塊的前面都加上一個初始塊,稱為「 錨點 」塊(anchor block)。每個主機處理其分配的塊,並儲存非錨點部份的KV緩存。
階段二:查詢編碼和token生成
輸入查詢被廣播到所有主機,在每個主機中,它首先存取在第一階段計算出的本地KV緩存。然後「 查詢 」主機透過聚合所有主機的softmax歸一化統計數據來計算全域註意力。這個過程對於每個生成的token都會重復。
用一個不那麽嚴謹的例子來概述上面的過程:想象一場烹飪比賽(上下文token),每個廚師(主機)負責準備一道菜的一部份(塊)。
為了確保味道一致,每個廚師除了準備自己的部份,還在前面加了一點「 錨點 」調料(錨點塊)。每個廚師準備好自己的部份後,記住自己部份的口味(KV緩存)。
階段二的查詢編碼和token生成可視為:評委(查詢token)來品嘗菜肴,並決定下一道菜的口味(生成新token)。評委先品嘗每個廚師的部份,看看哪個部份最符合他們的口味。
最後,評委匯總所有廚師的意見,確定下一道菜的口味,並告訴廚師們。
Star Attention的效能提升
Star Attention帶來的效能提升,主要體現在以下兩個方面:
1)高達11倍的加速
在多個長上下文基準測試上,Star Attention所加持的8B Llama3的推理速度顯著提升,隨著序列長度增加, 加速比 從1.1x提升到2.7x。
而在參數量更大的Llama3.1-70B上,推理的加速比提升更為顯著。
與此同時,對比采用全域註意力的基準, S tar Atte ntion相對準確率的降低只在 0~3%範圍內。
隨著上下文長度的增加,star attention推理的準確性相比全域註意力幾乎相同,但推理計算成本顯著下降
在更長的上下文尺度(128K)中,上下文編碼過程中不同塊的大小,也會影響推理的準確性和速度。塊尺寸越大,Star Attention 的準確性越高。
在 RULER 基準測試上,不同塊大小對Star Attention準確性的影響,塊大小範圍從4K到32K,適用於序列長度為128K的Llama-3.1-8B instruct 模型
用於評估的RULER,包含了13個任務,分為4個領域:大海撈針 (檢索)、多跳追蹤、聚合和問答,
不同任務中,全域註意力和 Star Attention 的準確性差異對比
而在上下文長度更大,達到1048K時,Star Attention的推理準確性依舊保持在原基準90%,推理加速比達到了10.8 × ~16.9 × 。
而在更大的Llama3.1-70B中,Star Attention能實作更大的加速比,同時保持相似水平的準確率下降。
由於其執行機制 不涉及具體模型,Star Attention可以無縫整合到大多數透過全域註意力訓練的基於Transformer的LLMs中,無需額外的模型微調。
由於減少了推理的計算成本,Star Attention顯著減少了記憶體需求,使得在本地器材(如手機,筆記本中)用LLM處理更長的序列成為可能。
實驗發現,將塊大小設定為總序列長度的約四分之一,可以在精度和速度之間取得最佳平衡。而使用者也可以根據需求調整塊大小,以在計算效率和精度之間進行權衡。
結論
未來的研究,會嘗試將Star Attention擴充套件到更長的序列(最長可達1M)和更大的模型,並希望能觀察到甚至更的加速,同時保持相似水平的準確率。同時專註於最佳化「 錨塊 」機制,並在更復雜的長上下文任務上提高效能,以增強Star Attention的可延伸性和穩健性。
總的來看,對於想要開發部署本地大模型的廠商,Star Attention是一項不容錯過的技術。使用Star Attention後,本地LLM能夠更快地回復使用者,還可在有限的記憶體中相容更長的上下文序列,從而在 RAG 任務中閱讀更長的文本。
而對於雲端大模型的提供商,Star Attention能夠在幾乎不影響使用者體現的前提下,顯著提升推理成本,實作「 降本增效 」,同時減少能源消費(碳足跡)。
透過在多個主機間分配上下文處理,Star Attention使上下文長度能夠隨主機數量線性擴充套件。
參考資料:
https://arxiv.org/abs/2411.17116