循环神经网络在学习过程中的主要问题是长期依赖问题。
LSTM引入了中间变量C_t,使用三个门来控制信息的保留和获取,f_t表示遗忘门,i_t表示记忆门,o_t表示输出门。LSTM的公式为:
简化后的PyTorch代码如下:
1 2 3 4 5 6 7 8 |
def forward(self, x): i = self.conv_i(x) f = self.conv_f(x) g = self.conv_g(x) o = self.conv_o(x) c = f * c + i * g h = o * F.tanh(c) |