绑定完请刷新页面
取消
刷新

分享好友

×
取消 复制
短教学:如何在PyTorch中编写逻辑回归模型?
2019-08-14 09:58:04

下图显示了逻辑和线性回归之间的差异。

在这篇文章中,我将展示如何在PyTorch中编写逻辑回归模型。

我们将尝试解决MNIST数据集的分类问题。首先,让我们导入所需的所有库。

#Step 1.加载数据集

#Step 2.使数据集可以迭代

#Step 3.创建模型类

#Step 4.实例化模型类

#Step 5.实例化损失类

#Step 6.实例化优化器类

#Step 7.训练模型

让我们一个接一个地完成这些步骤。

加载数据集

为了加载数据集,我们使用了torchvision.datasets,这个库几乎包含了机器学习中使用的所有流行数据集。您可以查看完整的数据集列表(https://pytorch.org/docs/stable/torchvision/datasets.html)

使数据集可重复使用

我们将使用DataLoader类使用以下代码使数据集可迭代。

创建模型类

现在,我们将创建一个定义逻辑回归体系结构的类。

实例化模型类

在实例化之前,我们将初始化一些参数,如下所示。

现在,我们初始化逻辑回归模型。

model = LogisticRegression(input_dim, output_dim)

实例化损失类

我们使用交叉熵来计算损失。

criterion = torch.nn.CrossEntropyLoss() # computes softmax and then the cross entropy

初始化优化器类

优化器将是我们使用的学习算法。在这种情况下,我们将使用随机梯度下降。

optimizer = torch.optim.SGD(model.parameters(), lr=lr_rate)

训练模型

现在,在后一步中,我们将使用以下代码训练模型。

iter = 0

for epoch in range(int(epochs)):

for i, (images, labels) in enumerate(train_loader):

images = Variable(images.view(-1, 28 * 28))

labels = Variable(labels)

optimizer.zero_grad()

outputs = model(images)

loss = criterion(outputs, labels)

loss.backward()

optimizer.step()

iter+=1

if iter%500==0:

# calculate Accuracy

correct = 0

total = 0

for images, labels in test_loader:

images = Variable(images.view(-1, 28*28))

outputs = model(images)

_, predicted = torch.max(outputs.data, 1)

total+= labels.size(0)

# for gpu, bring the predicted and labels back to cpu fro python operations to work

correct+= (predicted == labels).sum()

accuracy = 100 * correct/total

print("Iteration: {}. Loss: {}. Accuracy: {}.".format(iter, loss.item(), accuracy))

经过训练,这个模型只需3000次迭代,精度达到82%。您可以继续调整一下参数,看看准确度是否增加。

在PyTorch中更好地理解逻辑回归模型的一个很好的练习是将它应用于您能想到的任何分类问题。例如,你可以训练一个逻辑回归模型来对你喜欢的漫威超级英雄的图像进行分类(应该不会很难,因为其中有一半已经消失)。

分享好友

分享这个小栈给你的朋友们,一起进步吧。

通俗易懂--机器学习
创建时间:2019-08-02 11:00:07
这里汇集了机器学习、NLP面试中常考到的知识点和代码实现,也是作为一个算法工程师必会的理论基础知识。 以各个模块为切入点,让大家有一个清晰的知识体系。 亦可拿来常读、常记以及面试时复习之用。 每一章里的问题都是面试时有可能问到的知识点,如有遗漏可联系我进行补充,结尾处都有算法的实战代码案例。
展开
订阅须知

• 所有用户可根据关注领域订阅专区或所有专区

• 付费订阅:虚拟交易,一经交易不退款;若特殊情况,可3日内客服咨询

• 专区发布评论属默认订阅所评论专区(除付费小栈外)

栈主、嘉宾

查看更多
  • mantch
    栈主

小栈成员

查看更多
  • 栈栈
  • Jack2k
  • hwayw
  • 天上飘下来的人
戳我,来吐槽~