pytorch基础用法-预测

数据准备

1
2
3
4
5
6
7
8
transform = transforms.Compose(
[transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

im = Image.open('1.jpg')
im = transform(im) # [C, H, W]
im = torch.unsqueeze(im, dim=0) # [N, C, H, W]

模型定义

1
2
# 不需要to(device)吗?mark
net = LeNet()

参数加载

1
2
# 优化器也有同样的操作,可用于恢复训练
net.load_state_dict(torch.load('Lenet.pth'))

前向计算

1
2
3
with torch.no_grad():
outputs = net(im)
predict = torch.max(outputs, dim=1)[1].data.numpy()

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!