观点:返回最大值的index pytorch方式
【资料图】
目录
返回最大值的indexpytorch 找最大值返回最大值的index
import torch a=torch.tensor([[.1,.2,.3], [1.1,1.2,1.3], [2.1,2.2,2.3], [3.1,3.2,3.3]]) print(a.argmax(dim=1)) print(a.argmax())
输出:
tensor([ 2, 2, 2, 2])
tensor(11)
pytorch 找最大值
题意:使用神经网络实现,从数组中找出最大值。
提供数据:两个 csv 文件,一个存训练集:n 个 m 维特征自然数数据,另一个存每条数据对应的 label ,就是每条数据中的最大值。
这里将随机构建训练集:
#%%
import numpy as np
import pandas as pd
import torch
import random
import torch.utils.data as Data
import torch.nn as nn
import torch.optim as optim
def GetData(m, n):
dataset = []
for j in range(m):
max_v = random.randint(0, 9)
data = [random.randint(0, 9) for i in range(n)]
dataset.append(data)
label = [max(dataset[i]) for i in range(len(dataset))]
data_list = np.column_stack((dataset, label))
data_list = data_list.astype(np.float32)
return data_list
#%%
# 数据集封装 重载函数len, getitem
class GetMaxEle(Data.Dataset):
def __init__(self, trainset):
self.data = trainset
def __getitem__(self, index):
item = self.data[index]
x = item[:-1]
y = item[-1]
return x, y
def __len__(self):
return len(self.data)
# %% 定义网络模型
class SingleNN(nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(SingleNN, self).__init__()
self.hidden = nn.Linear(n_feature, n_hidden)
self.relu = nn.ReLU()
self.predict = nn.Linear(n_hidden, n_output)
def forward(self, x):
x = self.hidden(x)
x = self.relu(x)
x = self.predict(x)
return x
def train(m, n, batch_size, PATH):
# 随机生成 m 个 n 个维度的训练样本
data_list =GetData(m, n)
dataset = GetMaxEle(data_list)
trainset = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True)
net = SingleNN(n_feature=10, n_hidden=100,
n_output=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#
total_epoch = 100
for epoch in range(total_epoch):
for index, data in enumerate(trainset):
input_x, labels = data
labels = labels.long()
optimizer.zero_grad()
output = net(input_x)
# print(output)
# print(labels)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
# scheduled_optimizer.step()
print(f"Epoch {epoch}, loss:{loss.item()}")
# %% 保存参数
torch.save(net.state_dict(), PATH)
#测试
def test(m, n, batch_size, PATH):
data_list = GetData(m, n)
dataset = GetMaxEle(data_list)
testloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
dataiter = iter(testloader)
input_x, labels = dataiter.next()
net = SingleNN(n_feature=10, n_hidden=100,
n_output=10)
net.load_state_dict(torch.load(PATH))
outputs = net(input_x)
_, predicted = torch.max(outputs, 1)
print("Ground_truth:",labels.numpy())
print("predicted:",predicted.numpy())
if __name__ == "__main__":
m = 1000
n = 10
batch_size = 64
PATH = "./max_list.pth"
train(m, n, batch_size, PATH)
test(m, n, batch_size, PATH)初始的想法是使用全连接网络+分类来实现, 但是结果不尽人意,主要原因:不同类别之间的样本量差太大,几乎90%都是最大值。
比如代码中随机构建 10 个 0~9 的数字构成一个样本[2, 3, 5, 8, 9, 5, 3, 9, 3, 6], 该样本标签是9。
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
X 关闭
X 关闭
- 1转转集团发布2022年二季度手机行情报告:二手市场“飘香”
- 2充电宝100Wh等于多少毫安?铁路旅客禁止、限制携带和托运物品目录
- 3好消息!京东与腾讯续签三年战略合作协议 加强技术创新与供应链服务
- 4名创优品拟通过香港IPO全球发售4100万股 全球发售所得款项有什么用处?
- 5亚马逊云科技成立量子网络中心致力解决量子计算领域的挑战
- 6京东绿色建材线上平台上线 新增用户70%来自下沉市场
- 7网红淘品牌“七格格”chuu在北京又开一家店 潮人新宠chuu能红多久
- 8市场竞争加剧,有车企因经营不善出现破产、退网、退市
- 9北京市市场监管局为企业纾困减负保护经济韧性
- 10市场监管总局发布限制商品过度包装标准和第1号修改单

