谷歌JAX深度学习从零开始学
上QQ阅读APP看书,第一时间看更新

1.3.1 第一步:准备数据集

程序设计的第一步是准备数据,我们使用tensorf l ow_datasets自带的框架解决MNIST数据集下载的问题。打开WSL终端,输入如下命令:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensor flow_datasets

注意

进度条读取完毕后还不能使用,PyCharm对于WSL的支持需要重新加载在WSL中的Python程序,这里只需要重启计算机即可。

MNIST数据集下载好了之后,只需要直接使用给定的代码完成MNIST数据集的载入即可。代码如下:

import tensor flow as tf
import tensor flow_datasets as tfds
x_train = jnp.load("mnist_train_x.npy")
y_train = jnp.load("mnist_train_y.npy")

注意,由于MNIST给出的label是一个以当前图像值为结果的数据,需要转换成one_hot格式,代码如下:

def one_hot_nojit(x, k=10, dtype=jnp. float32):
    """ Create a one-hot encoding of x of size k. """
    return jnp.array(x[:, None] == jnp.arange(k), dtype)