PyTorch入门学习:6-Logistic Regression

6 Logistic Regression

6.1 Introduction

6.1.1 Classification - The MNIST Dataset

MNIST是一个手写数字数据集,包含60000个训练样本和10000个测试样本。每个样本是28x28的灰度图像,标签是0-9的数字。 在这个模型中,输出值是一个10维的向量,每个元素表示对应数字的概率。

  • torchvision包含MNIST数据集的下载和处理函数
1
2
3
4
import torchvision

train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True)

6.1.2 Classification - The CIFAR-10 Dataset

CIFAR-10是一个包含60000个32x32的彩色图像的数据集,分为10个类别。每个类别有6000个图像。

6.1.3 Regression vs Classification

现在将回归问题转换为分类问题 在之前的例子中,我们使用的是回归模型,输入的是学习时长,输出的是学习时长对应的分数。而现在我们将这个问题转换为分类问题,输入的是学习时长,输出的是是否通过考试。

6.2 How to map: R -> [0, 1]

6.2.1 Logistic Function

在分类问题中,我们需要将输出值(实数)映射到[0, 1]的区间,以便于表示概率。

  • Logistic Function:
  • Why Logistic Function?

一些激活函数的图象

  • Logistic函数是最具代表性的sigmoid函数,因此sigmoid函数通常也可代指Logistic函数。

6.2.2 Logistic Regression Model

添加sigmoid函数将实数映射为概率 如图,我们的线性模型被改为常常指Logistic函数。

6.2.3 Loss function for Binary Classification

  • loss function for Linear Regression:
  • loss function for Binary Classification: 计算分布的差异:cross entropy 交叉熵
  • A simple explanation:

使用趋势和值域简单解释交叉熵

6.2.4 Implementation of Logistic Regression

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch.nn.functional as F
class LogisticRegression(torch.nn.Module):
def __init__(self):
super(LogisticRegression, self).__init__()
self.linear = torch.nn.Linear(1, 1)

def forward(self, x):
y_pred = F.sigmoid(self.linear(x))
return y_pred

criterion = torch.nn.BCELoss(size_average=False)
# 是否平均,会影响学习率的设置
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(100):
y_pred = model(x_train)
loss = criterion(y_pred, y_data)

print(epoch, loss.item())

optimizer.zero_grad()
loss.backward()
optimizer.step()

# test
import numpy as np
import matplotlib.pyplot as plt

x=np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

PyTorch入门学习:6-Logistic Regression
https://eleco.top/2026/02/24/learn-torch-6-Logistic-Regression/
作者
Eleco
发布于
2026年2月24日
许可协议