PyTorch入门学习:6-Logistic Regression
6 Logistic Regression
6.1 Introduction
6.1.1 Classification - The MNIST Dataset
MNIST是一个手写数字数据集,包含60000个训练样本和10000个测试样本。每个样本是28x28的灰度图像,标签是0-9的数字。 在这个模型中,输出值是一个10维的向量,每个元素表示对应数字的概率。
- torchvision包含MNIST数据集的下载和处理函数
1 | |
6.1.2 Classification - The CIFAR-10 Dataset
CIFAR-10是一个包含60000个32x32的彩色图像的数据集,分为10个类别。每个类别有6000个图像。
6.1.3 Regression vs Classification
在之前的例子中,我们使用的是回归模型,输入的是学习时长,输出的是学习时长对应的分数。而现在我们将这个问题转换为分类问题,输入的是学习时长,输出的是是否通过考试。
6.2 How to map: R -> [0, 1]
6.2.1 Logistic Function
在分类问题中,我们需要将输出值(实数)映射到[0, 1]的区间,以便于表示概率。
- Logistic Function:
- Why Logistic Function?

- Logistic函数是最具代表性的sigmoid函数,因此sigmoid函数通常也可代指Logistic函数。
6.2.2 Logistic Regression Model
如图,我们的线性模型被改为
6.2.3 Loss function for Binary Classification
- loss function for Linear Regression:
- loss function for Binary Classification:
计算分布的差异:cross entropy 交叉熵 - A simple explanation:

6.2.4 Implementation of Logistic Regression
1 | |
PyTorch入门学习:6-Logistic Regression
https://eleco.top/2026/02/24/learn-torch-6-Logistic-Regression/