PyTorch Basics: Prediction
Data preparation
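import torch
from PIL import Image
from torchvision import transforms

# Resize to 32x32, convert to a float tensor in [0, 1], then normalize each channel to [-1, 1].
transform = transforms.Compose(
    [transforms.Resize((32, 32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the image, apply the transform, and add a batch dimension: [C, H, W] -> [N, C, H, W].
im = Image.open('1.jpg')
im = transform(im)
im = torch.unsqueeze(im, dim=0)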
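To make the preprocessing concrete, the sketch below reproduces the arithmetic of the Normalize step, (x - mean) / std per channel, followed by the batch-dimension step. It runs on a random tensor, which is an assumption standing in for the output of ToTensor(). The normalize helper is hypothetical and written only for illustration; it is not part of torchvision.

```python
import torch


def normalize(img, mean, std):
    """Channel-wise (x - mean) / std on a [C, H, W] tensor (illustrative helper)."""
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)
    return (img - mean) / std


# Assumption: a random 3x32x32 tensor in [0, 1) stands in for a real image after ToTensor().
img = torch.rand(3, 32, 32)

# With mean 0.5 and std 0.5, values in [0, 1] are mapped into [-1, 1].
img = normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# The network expects a batch, so add a leading dimension of size 1.
batch = torch.unsqueeze(img, dim=0)
```

With these values the result has shape [1, 3, 32, 32], the layout a convolutional network expects for a single RGB image.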
Model definition
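The network itself is not shown here. As a rough sketch only, the block below defines a LeNet-style network for 32x32 RGB input. The layer sizes and the 10-class output are assumptions, not taken from this post, and a real Lenet.pth checkpoint will only load if the class matches the architecture it was trained with.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    """LeNet-style CNN; layer sizes and class count are assumed for illustration."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)    # 3x32x32 -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)     # halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)   # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # -> 6x14x14
        x = self.pool(F.relu(self.conv2(x)))   # -> 16x5x5
        x = torch.flatten(x, 1)                # keep the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                     # raw class scores (logits)


net = LeNet()
```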
Loading parameters
# Load the trained weights from disk into the network.
net.load_state_dict(torch.load('Lenet.pth'))
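Two details are often worth adding when loading weights for prediction: passing map_location so a checkpoint saved on a GPU also loads on a CPU-only machine, and calling eval() so layers such as dropout and batch normalization run in inference mode. The sketch below shows both. The small Sequential model and the temporary demo.pth file are stand-ins invented for this example, so that it runs without the real Lenet.pth.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in model; the network used with Lenet.pth is not shown.
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))


# Write a checkpoint first so the example is self-contained.
trained = make_model()
path = os.path.join(tempfile.mkdtemp(), "demo.pth")
torch.save(trained.state_dict(), path)

# Rebuild the same architecture, then load the saved weights onto the CPU.
net = make_model()
state = torch.load(path, map_location="cpu")
net.load_state_dict(state)

# Switch to inference mode before predicting.
net.eval()
```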
Forward pass
# Disable gradient tracking for inference, then take the index of the highest score.
with torch.no_grad():
    outputs = net(im)
    predict = torch.max(outputs, dim=1)[1].numpy()
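The predicted value is a class index, so a common next step is to map it to a label and, if needed, a probability. The sketch below uses made-up scores and hypothetical class names, neither of which comes from this post. It also shows that torch.argmax returns the same index as the second element of torch.max.

```python
import torch

# Hypothetical raw scores (logits) for one image over three classes.
outputs = torch.tensor([[1.0, 3.0, 0.5]])
classes = ("cat", "dog", "bird")  # hypothetical label names

with torch.no_grad():
    # Softmax turns the scores into probabilities that sum to 1 per row.
    probs = torch.softmax(outputs, dim=1)
    # argmax over dim=1 matches torch.max(outputs, dim=1)[1].
    index = int(torch.argmax(outputs, dim=1)[0])

label = classes[index]
confidence = float(probs[0, index])
```

Softmax does not change which class has the highest score, so it is only needed when the probability itself is of interest.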