目录
LSTM参数InputsOutputs案例LSTM参数
官方文档给出的解释为:
总共有七个参数,其中只有前三个是必须的。由于大家普遍使用PyTorch的DataLoader来形成批量数据,因此batch_first也比较重要。LSTM的两个常见的应用场景为文本处理和时序预测,因此下面对每个参数我都会从这两个方面来进行具体解释。
input_size:在文本处理中,由于一个单词没法参与运算,因此我们得通过Word2Vec来对单词进行嵌入表示,将每一个单词表示成一个向量,此时input_size=embedding_size。比如每个句子中有五个单词,每个单词用一个100维向量来表示,那么这里input_size=100;在时间序列预测中,比如需要预测负荷,每一个负荷都是一个单独的值,都可以直接参与运算,因此并不需要将每一个负荷表示成一个向量,此时input_size=1。但如果我们使用多变量进行预测,比如我们利用前24小时每一时刻的[负荷、风速、温度、压强、湿度、天气、节假日信息]来预测下一时刻的负荷,那么此时input_size=7。hidden_size:隐藏层节点个数。可以随意设置。num_layers:层数。nn.LSTMCell与nn.LSTM相比,num_layers默认为1。batch_first:默认为False,意义见后文。Inputs
关于LSTM的输入,官方文档给出的定义为:
(相关资料图)
可以看到,输入由两部分组成:input、(初始的隐状态h_0,初始的单元状态c_0)
其中input:
input(seq_len, batch_size, input_size)seq_len:在文本处理中,如果一个句子有7个单词,则seq_len=7;在时间序列预测中,假设我们用前24个小时的负荷来预测下一时刻负荷,则seq_len=24。batch_size:一次性输入LSTM中的样本个数。在文本处理中,可以一次性输入很多个句子;在时间序列预测中,也可以一次性输入很多条数据。input_size
(h_0, c_0):
h_0(num_directions * num_layers, batch_size, hidden_size) c_0(num_directions * num_layers, batch_size, hidden_size)
h_0和c_0的shape一致。
num_directions:如果是双向LSTM,则num_directions=2;否则num_directions=1。num_layers:batch_size:hidden_size:Outputs
关于LSTM的输出,官方文档给出的定义为:
可以看到,输出也由两部分组成:otput、(隐状态h_n,单元状态c_n)
其中output的shape为:
output(seq_len, batch_size, num_directions * hidden_size)
h_n和c_n的shape保持不变,参数解释见前文。
batch_first
如果在初始化LSTM时令batch_first=True,那么input和output的shape将由:
input(seq_len, batch_size, input_size) output(seq_len, batch_size, num_directions * hidden_size)
变为:
input(batch_size, seq_len, input_size) output(batch_size, seq_len, num_directions * hidden_size)
即batch_size提前。
案例
简单搭建一个LSTM如下所示:
class LSTM(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size): super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.num_layers = num_layers self.output_size = output_size self.num_directions = 1 # 单向LSTM self.batch_size = batch_size self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True) self.linear = nn.Linear(self.hidden_size, self.output_size) def forward(self, input_seq): batch_size, seq_len = input_seq[0], input_seq[1] h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device) c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device) # output(batch_size, seq_len, num_directions * hidden_size) output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64) pred = self.linear(output) # (5, 30, 1) pred = pred[:, -1, :] # (5, 1) return pred
其中定义模型的代码为:
self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True) self.linear = nn.Linear(self.hidden_size, self.output_size)
我们加上具体的数字:
self.lstm = nn.LSTM(self.input_size=1, self.hidden_size=64, self.num_layers=5, batch_first=True) self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)
再看前向传播:
def forward(self, input_seq): batch_size, seq_len = input_seq[0], input_seq[1] h_0 = torch.randn(self.num_directions * self.num_layers, batch_size, self.hidden_size).to(device) c_0 = torch.randn(self.num_directions * self.num_layers, batch_size, self.hidden_size).to(device) # input(batch_size, seq_len, input_size) # output(batch_size, seq_len, num_directions * hidden_size) output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64) pred = self.linear(output) # (5, 30, 1) pred = pred[:, -1, :] # (5, 1) return pred
假设用前30个预测下一个,则seq_len=30,batch_size=5,由于设置了batch_first=True,因此,输入到LSTM中的input的shape应该为:
input(batch_size, seq_len, input_size) = input(5, 30, 1)
经过DataLoader处理后的input_seq为:
input_seq(batch_size, seq_len, input_size) = input_seq(5, 30, 1)
然后将input_seq送入LSTM:
output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
根据前文,output的shape为:
output(batch_size, seq_len, num_directions * hidden_size) = output(5, 30, 64)
全连接层的定义为:
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)
然后将output送入全连接层:
pred = self.linear(output) # pred(5, 30, 1)
得到的预测值shape为(5, 30, 1),由于输出是输入右移,我们只需要取pred第二维度(time)中的最后一个数据:
pred = pred[:, -1, :] # (5, 1)
这样,我们就得到了预测值,然后与label求loss,然后再反向更新参数即可。
到此这篇关于深入学习PyTorch中LSTM的输入和输出的文章就介绍到这了,更多相关PyTorch LSTM内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
X 关闭
X 关闭
- 1转转集团发布2022年二季度手机行情报告:二手市场“飘香”
- 2充电宝100Wh等于多少毫安?铁路旅客禁止、限制携带和托运物品目录
- 3好消息!京东与腾讯续签三年战略合作协议 加强技术创新与供应链服务
- 4名创优品拟通过香港IPO全球发售4100万股 全球发售所得款项有什么用处?
- 5亚马逊云科技成立量子网络中心致力解决量子计算领域的挑战
- 6京东绿色建材线上平台上线 新增用户70%来自下沉市场
- 7网红淘品牌“七格格”chuu在北京又开一家店 潮人新宠chuu能红多久
- 8市场竞争加剧,有车企因经营不善出现破产、退网、退市
- 9北京市市场监管局为企业纾困减负保护经济韧性
- 10市场监管总局发布限制商品过度包装标准和第1号修改单