DeepMind開源2款基於JAX庫,針對神經網絡和強化學習,易用性更強

十三 發自 凹非寺
量子位 報道 | 公眾號 QbitAI

DeepMind今日發佈了HaikuRLax兩個庫,都是基於JAX。

JAX由谷歌提出,是TensorFlow的簡化庫。結合了針對線性代數的編譯器XLA,和自動區分本地 Python 和 Numpy 代碼的庫Autograd,在高性能的機器學習研究中使用。

而此次發佈的兩個庫,分別針對神經網絡強化學習,大幅簡化了JAX的使用。

Haiku是基於JAX的神經網絡庫,允許用戶使用熟悉的面向對象程序設計模型,可完全訪問 JAX 的純函數變換。

RLax是JAX頂層的庫,它提供了用於實現增強學習代理的有用構件。

有意思的是,Reddit網友驚奇的發現Haiku這個庫的名字,竟然不以“ax”結尾。

DeepMind開源2款基於JAX庫,針對神經網絡和強化學習,易用性更強

當然,也有網友對這兩個庫表示了肯定:

毫無疑問,對JAX起到了推動作用。

DeepMind開源2款基於JAX庫,針對神經網絡和強化學習,易用性更強

那麼,我們就來看下Haiku和RLex的廬山真面目吧。

Haiku

Haiku是JAX的神經網絡庫,它允許用戶使用熟悉的面向對象編程模型,同時允許完全訪問JAX的純函數轉換。

它提供了兩個核心工具:模塊抽象hk.Module,和一個簡單的函數轉換hk.transform。

hk.Module是Python對象,包含對其自身參數、其他模塊和對用戶輸入應用函數方法的引用。

hk.transform允許完全訪問JAX的純函數轉換。

其實,在JAX中有許多神經網絡庫,那麼Haiku有什麼特別之處呢?有5點。

1、Haiku已經由DeepMind的研究人員進行了大規模測試

DeepMind相對容易地在Haiku和JAX中複製了許多實驗。其中包括圖像和語言處理的大規模結果、生成模型和強化學習。

2、Haiku是一個庫,而不是一個框架

它的設計是為了簡化一些具體的事情,包括管理模型參數和其他模型狀態。可以與其他庫一起編寫,並與JAX的其他部分一起工作。

3、Haiku並不是另起爐灶

它建立在Sonnet的編程模型和API之上,Sonnet是DeepMind幾乎普遍採用的神經網絡庫。它保留了Sonnet用於狀態管理的基於模塊的編程模型,同時保留了對JAX函數轉換的訪問。

4、過渡到Haiku是比較容易的

通過精心的設計,從TensorFlow和Sonnet,過渡到JAX和Haiku是比較容易的。除了新的函數(如hk.transform),Haiku的目的是Sonnet 2的API。

5、Haiku簡化了JAX

它提供了一個處理隨機數的簡單模型。在轉換後的函數中,hk.next_rng_key()返回一個唯一的rng鍵。

那麼,該如何安裝Haiku呢?

Haiku是用純Python編寫的,但是通過JAX依賴於c++代碼。

首先,按照下方鏈接中的說明,安裝帶有相關加速器支持的JAX。https://github.com/google/jax#installation

然後,只需要一句簡單的pip命令就可以完成安裝。

<code>$ pip install git+https://github.com/deepmind/haiku/<code>

接下來,是一個神經網絡和損失函數的例子。

<code>import haiku as hkimport jax.numpy as jnpdef softmax_cross_entropy(logits, labels):  one_hot = hk.one_hot(labels, logits.shape[-1])  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)def loss_fn(images, labels):  model = hk.Sequential([      hk.Linear(1000),      jax.nn.relu,      hk.Linear(100),      jax.nn.relu,      hk.Linear(10),  ])  logits = model(images)  return jnp.mean(softmax_cross_entropy(logits, labels))loss_obj = hk.transform(loss_fn)/<code>

RLax

RLax是JAX頂層的庫,它提供了用於實現增強學習代理的有用構件。

它所提供的操作和函數不是完整的算法,而是強化學習特定數學操作的實現。

RLax的安裝也非常簡單,一個pip命令就可以搞定。

<code>pip install git+git://github.com/deepmind/rlax.git/<code>

使用JAX的jax.jit函數,所有的RLax代碼可以不同的硬件上編譯。

RLax需要注意的是它的命名規則。

許多函數在連續的時間步長中考慮策略、操作、獎勵和值,以便計算它們的輸出。在這種情況下,後綴_t和tm1通常是為了說明每個輸入是在哪個步驟上生成的,例如:

q_tm1:轉換的源狀態中的操作值。a_tm1:在源狀態下選擇的操作。r_t:在目標狀態下收集的結果獎勵。q_t:目標狀態下的操作值。

Haiku和RLax都已在GitHub上開源,有興趣的讀者可從“傳送門”的鏈接訪問。

傳送門

Haiku:https://github.com/deepmind/haiku

RLax:https://github.com/deepmind/rlax


— 完 —

量子位 QbitAI · 頭條號簽約

關注我們,第一時間獲知前沿科技動態


分享到:


相關文章: