Hello MNIST
0. 前言
MNIST是一个被很多人玩过的数据集,就好比是深度学习领域的Hello World。Kaggle上有一堆Notebook,抄过来跑一跑,就能快速入门。
对于调包侠来说,很多时候不用去知道NN/DNN/CNN里面的具体实现。再加上这几年LLM大火,影响之一是如今的热点潮流变成了LLM而不是逐渐趋于传统的NN;影响之二是我们更加不需要再去花大量时间从零实现一个神经网络,甚至AI都能随手写出一个基于pytorch的MNIST识别器。
因此,本文只是留作记录,我不去分析NN的基础原理,而只是贴几个我曾经用来入门的Notebook,在这个AI时代留下一点印记。
1. 手撸的MNIST,从下载到从零推矩阵 [1]
先把MNIST数据集下载下来,pytorch, tensorflow都有自带的函数可以加载MNIST,第一次会先下载到本地,之后会直接从保存的地方读取。
train_dataset = torchvision.datasets.MNIST(
root='../data',
train=True,
download=True,
transform=None
)
test_dataset = torchvision.datasets.MNIST(
root='../data',
train=False,
download=True,
transform=None
)
之后这位大神[1]不靠任何框架,从零开始,加载数据集,写softmax函数,定义连接层,算loss,训练模型。跟着做一遍基本原理也就懂了。
2. 从DNN到CNN [2] [3]
精髓在这里,在pytorch框架下定义网络里各层的连接。这是DNN里最基础、最经典的一种,每一层的每一个神经元都与下一层的每一个神经元相连接,称为多层感知机(Multilayer Perceptron, MLP),这个例子里除了输入层和输出层以外还有两个隐藏层。每一层都是nn.Linear()
,是线性的。
import torch.nn as nn
import torch.nn.functional as F
# define the NN architecture
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# number of hidden nodes in each layer (512)
hidden_1 = 512
hidden_2 = 512
# linear layer (784 -> hidden_1)
self.fc1 = nn.Linear(28 * 28, hidden_1)
# linear layer (n_hidden -> hidden_2)
self.fc2 = nn.Linear(hidden_1, hidden_2)
# linear layer (n_hidden -> 10)
self.fc3 = nn.Linear(hidden_2, 10)
# dropout layer (p=0.2)
# dropout prevents overfitting of data
self.dropout = nn.Dropout(0.2)
def forward(self, x):
# flatten image input
x = x.view(-1, 28 * 28)
# add hidden layer, with relu activation function
x = F.relu(self.fc1(x))
# add dropout layer
x = self.dropout(x)
# add hidden layer, with relu activation function
x = F.relu(self.fc2(x))
# add dropout layer
x = self.dropout(x)
# add output layer
x = self.fc3(x)
return x
# initialize the NN
model = Net()
print(model)
而这是CNN,区别只在于定义层连接的时候,nn.Linear()
变成了nn.Conv2d()
,即线性层变成了卷积层。
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
self.conv3 = nn.Conv2d(32,64, kernel_size=5)
self.fc1 = nn.Linear(3*3*64, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
#x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(F.max_pool2d(self.conv3(x),2))
x = F.dropout(x, p=0.5, training=self.training)
x = x.view(-1,3*3*64 )
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
cnn = CNN()
print(cnn)
然而,最简单的方法其实是对Amazon Q说:“帮我写一个基于pytorch的可以识别MNIST数据集的CNN”,或者“我现在有一个DNN,帮我把它改成CNN”。