出處丨AI前線
什麼?TensorFlow 有了替代品?什麼?竟然還是谷歌自己做出來的?先別慌,從各種意義上來說,這個所謂的“替代品”其實是 TensorFlow 的一個簡化庫,名為 JAX,結合 Autograd 和 XLA,可以支持部分 TensorFlow 的功能,但是比 TensorFlow 更加簡潔易用。雖然還不至於替代 TensorFlow,但已經有 Reddit 網友對 JAX 寄予厚望,並表示“早就期待能有一個可以直接調用 Numpy API 接口的庫了!”,“希望它可以取代 TensorFlow!”。JAX 結合了 Autograd 和 XLA,是專為高性能機器學習研究打造的產品。
有了新版本的 Autograd,JAX 能夠自動對 Python 和 NumPy 的自帶函數求導,支持循環、分支、遞歸、閉包函數求導,而且可以求三階導數。它支持自動模式反向求導(也就是反向傳播)和正向求導,且二者可以任意組合成任何順序。
JAX 的創新之處在於,它基於 XLA 在 GPU 和 TPU 上編譯和運行 NumPy 程序。默認情況下,編譯是在底層進行的,庫調用能夠及時編譯和執行。但是 JAX 還允許使用單一函數 API jit將自己的 Python 函數及時編譯成經過 XLA 優化的內核。編譯和自動求導可以任意組合,因此可以在不脫離 Python 環境的情況下實現複雜算法並獲得最優性能。
JAX 最初由 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 發起,他們均任職於谷歌大腦團隊。在 GitHub 的說明文檔中,作者明確表示:JAX 目前還只是一個研究項目,不是谷歌的官方產品,因此可能會有一些 bug。從作者的 GitHub 簡介來看,這應該是谷歌大腦正在嘗試的新項目,在同一個 GitHub 目錄下的開源項目還包括 8 月份在業內引起熱議的強化學習框架 Dopamine。
以下是 JAX 的簡單使用示例。
GitHub 項目傳送門:https://github.com/google/JAX
有關具體的安裝和簡單的入門指導大家可以在 GitHub 中自行查看,在此不做過多贅述。
JAX 庫的實現原理
機器學習中的編程是關於函數的表達和轉換。轉換包括自動微分、加速器編譯和自動批處理。像 Python 這樣的高級語言非常適合表達函數,但是通常使用者只能應用它們。我們無法訪問它們的內部結構,因此無法執行轉換。
JAX 可以用於專門化高級 Python+NumPy 函數,並將其轉換為可轉換的表示形式,然後再提升為 Python 函數。
JAX 通過跟蹤專門處理 Python 函數。跟蹤一個函數意味著:監視應用於其輸入,以產生其輸出的所有基本操作,並在有向無環圖 (DAG) 中記錄這些操作及其之間的數據流。為了執行跟蹤,JAX 包裝了基本的操作,就像基本的數字內核一樣,這樣一來,當調用它們時,它們就會將自己添加到執行的操作列表以及輸入和輸出中。為了跟蹤這些原語之間的數據流,跟蹤的值被包裝在 Tracer 類的實例中。
當 Python 函數被提供給 grad 或 jit 時,它被包裝起來以便跟蹤並返回。當調用包裝的函數時,我們將提供的具體參數抽象到 AbstractValue 類的實例中,將它們框起來用於跟蹤跟蹤器類的實例,並對它們調用函數。
抽象參數表示一組可能的值,而不是特定的值:例如,jit 將 ndarray 參數抽象為抽象值,這些值表示具有相同形狀和數據類型的所有 ndarray。相反,grad 抽象 ndarray 參數來表示底層值的無窮小鄰域。通過在這些抽象值上跟蹤 Python 函數,我們確保它足夠專門化,以便轉換是可處理的,並且它仍然足夠通用,以便轉換後的結果是有用的,並且可能是可重用的。然後將這些轉換後的函數提升回 Python 可調用函數,這樣就可以根據需要跟蹤並再次轉換它們。
JAX 跟蹤的基本函數大多與 XLA HLO 1:1 對應,並在 lax.py 中定義。這種 1:1 的對應關係使得到 XLA 的大多數轉換基本上都很簡單,並且確保我們只有一小組原語來覆蓋其他轉換,比如自動微分。 jax.numpy 層是用純 Python 編寫的,它只是用 LAX 函數 (以及我們已經編寫的其他 numpy 函數) 表示 numpy 函數。這使得 jax.numpy 易於延展。
當你使用 jax.numpy 時,底層 LAX 原語是在後臺進行 jit 編譯的,允許你在加速器上執行每個原語操作的同時編寫不受限制的 Python+ numpy 代碼。
但是 JAX 可以做更多的事情:你可以在越來越大的函數上使用 jit 來進行端到端編譯和優化,而不僅僅是編譯和調度到一組固定的單個原語。例如,可以編譯整個網絡,或者編譯整個梯度計算和優化器更新步驟,而不僅僅是編譯和調度卷積運算。
折衷之處是,jit 函數必須滿足一些額外的專門化需求:因為我們希望編譯專門針對形狀和數據類型的跟蹤,但不是專門針對具體值的跟蹤,所以 jit 裝飾器下的 Python 代碼必須適用於抽象值。如果我們嘗試在一個抽象的 x 上求 x >0 的值,結果是一個抽象的值,表示集合{True, False},所以 Python 分支就像 if x > 0 會引起報錯。
有關使用 jit 的更多要求,請參見:https://github.com/google/jax#whats-supported
好消息是,jit 是可選的:JAX 庫在後臺對單個操作和函數使用 jit,允許編寫不受限制的 Python+Numpy,同時仍然使用硬件加速器。但是,當你希望最大化性能時,通常可以在自己的代碼中使用 jit 編譯和端到端優化更大的函數。
後續計劃
目前項目小組還將對以下幾項做更多嘗試和更新:
- 完善說明文檔
- 支持 Cloud TPU
- 支持多 GPU 和多 TPU
- 支持完整的 NumPy 功能和部分 SciPy 功能
- 全面支持 vmap
- 加速
- 降低 XLA 函數調度開銷
- 線性代數例程(CPU 上的 MKL 和 GPU 上的 MAGMA)
- 高效自動微分原語cond和while
有關 JAX 庫的介紹大致如此,如果你在嘗試了 JAX 之後有一些較好的使用心得,歡迎隨時向我們投稿,AI 前線十分願意將你的經驗傳播給更多開發者。
再次附上 GitHub 鏈接:https://github.com/google/jax
相關資源:
JAX 論文鏈接:https://www.sysml.cc/doc/146.pdf
會議推薦
12 月 20-21,AICon將於北京開幕,在這裡可以學習來自 Google、微軟、BAT、360、京東、美團等 40+AI 落地案例,與國內外一線技術大咖面對面交流。
閱讀更多 InfoQ 的文章