Pytorch如何加载自己的数据集(使用DataLoader读取Dataset)
目录
1.Pytorch加载数据集会用到官方整理好的数据集2.Dataset3.DataLoader4.查看数据5.总结1.Pytorch加载数据集会用到官方整理好的数据集
很多时候我们需要加载自己的数据集,这时候我们需要使用Dataset和DataLoader
Dataset
:是被封装进DataLoader里,实现该方法封装自己的数据和标签。DataLoader
:被封装入DataLoaderIter里,实现该方法达到数据的划分。
2.Dataset
阅读源码后,我们可以指导,继承该方法必须实现两个方法:
(资料图片)
_getitem_()
_len_()
因此,在实现过程中我们测试如下:
import torch import numpy as np # 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法 class GetLoader(torch.utils.data.Dataset): # 初始化函数,得到数据 def __init__(self, data_root, data_label): self.data = data_root self.label = data_label # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回 def __getitem__(self, index): data = self.data[index] labels = self.label[index] return data, labels # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼 def __len__(self): return len(self.data) # 随机生成数据,大小为10 * 20列 source_data = np.random.rand(10, 20) # 随机生成标签,大小为10 * 1列 source_label = np.random.randint(0,2,(10, 1)) # 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels torch_data = GetLoader(source_data, source_label)
3.DataLoader
提供对Dataset
的操作,操作如下:
torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
参数含义如下:
dataset
:加载torch.utils.data.Dataset对象数据batch_size
:每个batch的大小shuffle
:是否对数据进行打乱drop_last
:是否对无法整除的最后一个datasize进行丢弃num_workers
:表示加载的时候子进程数
因此,在实现过程中我们测试如下(紧跟上述用例):
from torch.utils.data import DataLoader # 读取数据 datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2)
此时,我们的数据已经加载完毕了,只需要在训练过程中使用即可。
4.查看数据
我们可以通过迭代器(enumerate)
进行输出数据,测试如下:
for i, data in enumerate(datas): # i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels print("第 {} 个Batch \n{}".format(i, data))
输出结果如下图:
结果说明:由于数据的是10个,batchsize大小为6,且drop_last=False,因此第一个大小为6,第二个为4。
每一个batch中包含data和对应的labels。
当我们想取出data和对应的labels时候,只需要用下表就可以啦,测试如下:
# 表示输出数据 print(data[0]) # 表示输出标签 print(data[1])
结果如图:
5.总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
X 关闭
X 关闭
- 15G资费不大降!三大运营商谁提供的5G网速最快?中国信通院给出答案
- 2联想拯救者Y70发布最新预告:售价2970元起 迄今最便宜的骁龙8+旗舰
- 3亚马逊开始大规模推广掌纹支付技术 顾客可使用“挥手付”结账
- 4现代和起亚上半年出口20万辆新能源汽车同比增长30.6%
- 5如何让居民5分钟使用到各种设施?沙特“线性城市”来了
- 6AMD实现连续8个季度的增长 季度营收首次突破60亿美元利润更是翻倍
- 7转转集团发布2022年二季度手机行情报告:二手市场“飘香”
- 8充电宝100Wh等于多少毫安?铁路旅客禁止、限制携带和托运物品目录
- 9好消息!京东与腾讯续签三年战略合作协议 加强技术创新与供应链服务
- 10名创优品拟通过香港IPO全球发售4100万股 全球发售所得款项有什么用处?