莫烦pytorch-CNN分类

  1. 导入一些必要的库
import torch
import torch.utils.dada as Data
import torch.nn as nn
import torchvision

2.准备数据,这次需要MNIST手写数字数据集,利用torchvision来获得数据集。,torchvisino除了MNIST,还有ciffar10等等许多数据集。有了torchvison模块,很方便对数据进行下载,处理.

train_data=
data_loader=

3.定义网络

class Net():
	def __init__(self,n_input):
		super(Net,self).__init__()
		self.conv1=nn.Sequecial(
			)
	def forward():
		

4.训练

loss_func=
optimizer=torch.optim.ENTROPYloss()

for epoch in EPOCH:
	for  step,(b_x,b_y) in enumerate (data_loader):
		output =net(b_x)
		loss=loss_func(output,b_y)
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		
		

猜你喜欢

转载自blog.csdn.net/eefresher/article/details/88561592
今日推荐