世界即时看!pytorch使用nn.Moudle实现逻辑回归
【资料图】
本文实例为大家分享了pytorch使用nn.Moudle实现逻辑回归的具体代码,供大家参考,具体内容如下
内容
pytorch使用nn.Moudle实现逻辑回归
问题
loss下降不明显
解决方法
#源代码 out的数据接收方式 if torch.cuda.is_available(): x_data=Variable(x).cuda() y_data=Variable(y).cuda() else: x_data=Variable(x) y_data=Variable(y) out=logistic_model(x_data) #根据逻辑回归模型拟合出的y值 loss=criterion(out.squeeze(),y_data) #计算损失函数
#源代码 out的数据有拼装数据直接输入 # if torch.cuda.is_available(): # x_data=Variable(x).cuda() # y_data=Variable(y).cuda() # else: # x_data=Variable(x) # y_data=Variable(y) out=logistic_model(x_data) #根据逻辑回归模型拟合出的y值 loss=criterion(out.squeeze(),y_data) #计算损失函数 print_loss=loss.data.item() #得出损失函数值
源代码
import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
#生成数据
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums, 2)
x0 = torch.normal(mean_value * n_data, 1) + bias # 类别0 数据 shape=(100, 2)
y0 = torch.zeros(sample_nums) # 类别0 标签 shape=(100, 1)
x1 = torch.normal(-mean_value * n_data, 1) + bias # 类别1 数据 shape=(100, 2)
y1 = torch.ones(sample_nums) # 类别1 标签 shape=(100, 1)
x_data = torch.cat((x0, x1), 0) #按维数0行拼接
y_data = torch.cat((y0, y1), 0)
#画图
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap="RdYlGn")
plt.show()
# 利用torch.nn实现逻辑回归
class LogisticRegression(nn.Module):
def __init__(self):
super(LogisticRegression, self).__init__()
self.lr = nn.Linear(2, 1)
self.sm = nn.Sigmoid()
def forward(self, x):
x = self.lr(x)
x = self.sm(x)
return x
logistic_model = LogisticRegression()
# if torch.cuda.is_available():
# logistic_model.cuda()
#loss函数和优化
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(logistic_model.parameters(), lr=0.01, momentum=0.9)
#开始训练
#训练10000次
for epoch in range(10000):
# if torch.cuda.is_available():
# x_data=Variable(x).cuda()
# y_data=Variable(y).cuda()
# else:
# x_data=Variable(x)
# y_data=Variable(y)
out=logistic_model(x_data) #根据逻辑回归模型拟合出的y值
loss=criterion(out.squeeze(),y_data) #计算损失函数
print_loss=loss.data.item() #得出损失函数值
#反向传播
loss.backward()
optimizer.step()
optimizer.zero_grad()
mask=out.ge(0.5).float() #以0.5为阈值进行分类
correct=(mask==y_data).sum().squeeze() #计算正确预测的样本个数
acc=correct.item()/x_data.size(0) #计算精度
#每隔20轮打印一下当前的误差和精度
if (epoch+1)%100==0:
print("*"*10)
print("epoch {}".format(epoch+1)) #误差
print("loss is {:.4f}".format(print_loss))
print("acc is {:.4f}".format(acc)) #精度
w0, w1 = logistic_model.lr.weight[0]
w0 = float(w0.item())
w1 = float(w1.item())
b = float(logistic_model.lr.bias.item())
plot_x = np.arange(-7, 7, 0.1)
plot_y = (-w0 * plot_x - b) / w1
plt.xlim(-5, 7)
plt.ylim(-7, 7)
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=logistic_model(x_data)[:,0].cpu().data.numpy(), s=100, lw=0, cmap="RdYlGn")
plt.plot(plot_x, plot_y)
plt.show()输出结果
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。
X 关闭
X 关闭
- 1转转集团发布2022年二季度手机行情报告:二手市场“飘香”
- 2充电宝100Wh等于多少毫安?铁路旅客禁止、限制携带和托运物品目录
- 3好消息!京东与腾讯续签三年战略合作协议 加强技术创新与供应链服务
- 4名创优品拟通过香港IPO全球发售4100万股 全球发售所得款项有什么用处?
- 5亚马逊云科技成立量子网络中心致力解决量子计算领域的挑战
- 6京东绿色建材线上平台上线 新增用户70%来自下沉市场
- 7网红淘品牌“七格格”chuu在北京又开一家店 潮人新宠chuu能红多久
- 8市场竞争加剧,有车企因经营不善出现破产、退网、退市
- 9北京市市场监管局为企业纾困减负保护经济韧性
- 10市场监管总局发布限制商品过度包装标准和第1号修改单

