Pytorch使用技巧之Dataloader中的collate_fn参数详析
以MNIST为例
from torchvision import datasets mnist = datasets.MNIST(root="./data/", train=True, download=True) print(mnist[0])
结果
(
, 5)
MINIST数据集的dataset是由一张图片和一个label组成的元组
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:x) for each in dataloader: print(each) break
结果
[(
, 0), ( , 2)]
collate_fn为lamda x:x时表示对传入进来的数据不做处理
下面自定义collate_fn看看什么效果
def collate(data): img = [] label = [] for each in data: img.append(each[0]) label.append(each[1]) return img,label dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:collate(x)) for each in dataloader: print(each) break
结果
([
, ], [9, 3])
说明:若不设置collate_fn参数则会使用默认处理函数
但必须保证传进来的数据都是tensor格式否则会报错
附:DataLoader完整的参数表如下:
class torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
DataLoader在数据集上提供单进程或多进程的迭代器
几个关键的参数意思:
- shuffle:设置为True的时候,每个世代都会打乱数据集
- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能
- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留
总结
到此这篇关于Pytorch使用技巧之Dataloader中的collate_fn参数的文章就介绍到这了,更多相关Dataloader中的collate_fn参数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
下一篇:Python合并重叠矩形框
X 关闭
X 关闭
- 15G资费不大降!三大运营商谁提供的5G网速最快?中国信通院给出答案
- 2联想拯救者Y70发布最新预告:售价2970元起 迄今最便宜的骁龙8+旗舰
- 3亚马逊开始大规模推广掌纹支付技术 顾客可使用“挥手付”结账
- 4现代和起亚上半年出口20万辆新能源汽车同比增长30.6%
- 5如何让居民5分钟使用到各种设施?沙特“线性城市”来了
- 6AMD实现连续8个季度的增长 季度营收首次突破60亿美元利润更是翻倍
- 7转转集团发布2022年二季度手机行情报告:二手市场“飘香”
- 8充电宝100Wh等于多少毫安?铁路旅客禁止、限制携带和托运物品目录
- 9好消息!京东与腾讯续签三年战略合作协议 加强技术创新与供应链服务
- 10名创优品拟通过香港IPO全球发售4100万股 全球发售所得款项有什么用处?