上QQ阅读APP看书,第一时间看更新
1.3.1 第一步:准备数据集
程序设计的第一步是准备数据,我们使用tensorf l ow_datasets自带的框架解决MNIST数据集下载的问题。打开WSL终端,输入如下命令:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensor flow_datasets
注意
进度条读取完毕后还不能使用,PyCharm对于WSL的支持需要重新加载在WSL中的Python程序,这里只需要重启计算机即可。
MNIST数据集下载好了之后,只需要直接使用给定的代码完成MNIST数据集的载入即可。代码如下:
import tensor flow as tf import tensor flow_datasets as tfds x_train = jnp.load("mnist_train_x.npy") y_train = jnp.load("mnist_train_y.npy")
注意,由于MNIST给出的label是一个以当前图像值为结果的数据,需要转换成one_hot格式,代码如下:
def one_hot_nojit(x, k=10, dtype=jnp. float32): """ Create a one-hot encoding of x of size k. """ return jnp.array(x[:, None] == jnp.arange(k), dtype)