谷歌JAX深度学习从零开始学
上QQ阅读APP看书,第一时间看更新

1.2.5 JAX的Python代码小练习:计算SeLU函数

对于科学计算来说,最简单的想法就是可以将数学公式直接表达成程序语言,可以说,Python满足了这个想法。本小节将使用Python实现和计算一个深度学习中最为常见的函数—SeLU激活函数。至于这个函数的作用,现在不加以说明,这里只是带领读者尝试实现其程序的编写。

首先SeLU激活函数计算公式如下所示:

α=α×(ex-1)×θ α=1.67

θ=1.05

e是自然常数

其中α和θ是预定义的参数,e是自然常数,以上3个数在这里直接使用即可。SeLU激活函数的图形如图1.20所示。

SeLU激活函数的代码如下所示。

图1.20 SeLU激活函数图形

【程序1-1】

import jax.numpy as jnp                            #导入NumPy计算包
from jax import random                             #导入random随机数包
#完成的seLU函数
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
key = random.PRNGKey(17)                           #产生了一个固定数17作为key
x = random.normal(key, (5,))                       #随机生成一个大小为[1,5]的矩阵
print(selu(x))                                     #打印结果

可以看到,当传入一个随机数列后,分别计算每个数值所对应的函数值,结果如下: