循环神经网络
Recurrent Neural Network
- NN: 学习的是函数
- CNN: 学习的是特征
- RNN: 学习的是程序, 可以看成是状态机
状态和模型
- IID数据
- 分类问题
- 回归问题
- 特征表达
- 大部分数据都不满足 IID
- 序列分析 (Tagging, Annotation)
- 序列生成, 如语言翻译, 自动文本生成
- 内容提取 (Content Extraction), 如图像描述
Example: 输出不仅依赖输入, 还跟之前的时刻有关
序列样本
- RNN不仅仅能够处理序列输出, 也能得到序列输出, 这里序列指的是向量的序列.
- RNN学习出来的是程序, 不是函数
Example:
- one to one: 普通NN网络
- one to many: 输入一张图, 输出是一句话
- many to one: 输入一句话, 输出是情感判断 or 情感预测
- 第一个many to many: 语言翻译
- 第二个many to many: 玩游戏
序列预测
- 输入的时间变化向量序列:
- 在 时刻通过模型来估计
- Problem:
- 对内部状态难以建模和观察
- 对长时间范围的场景(context)难以建模和观察
- Solution:
- 引入内部隐含状态变量
序列预测模型
- 在时间 的更新计算:
- 预测计算:
其中,
- 是离散输入, { ,...,,..., }
- 是状态层, 在0时刻初始化
- 是激励函数(Sigmoid/tanh), 是损失函数/优化目标(e.g. Softmax)
- 在整个计算过程中, W保持不变
RNN训练
- 前向(forward)计算, 相同的W矩阵需要乘以多次
- 多步之前的输入x, 会影响当前的输出
- 在后向(backward)计算的时候, 同样相同的矩阵也会乘以多次
BPTT算法
Back Propogation Through Time.
BPTT算法实现
- RNN前向计算:
- 计算 的偏导, 需要把所有 Time Step 加起来:
- 链式规则:
- 其中,
Gradient vanishing/exploding
根据 , 可知: 因此: This can become very small or very large quickly, and the locality assumption of gradient descent breaks down. --> Vanishing or exploding gradient
解决方案
- Gradient exploding - Clipping, 设置阈值(threshold):
- Gradient vanishing - Wrec 初始化为1, 使用ReLU替换tanh
LSTM模型
Long Short Term Memory 是应用最为广泛, 成功的RNN.
- 不需要记忆复杂的BPTT公式, 利用时序展开, 构造层次关系, 可以开发复杂的BPTT算法
- 做了逐点的控制, 避免一个W连乘. 具备一定已知gradient vanishing / exploding 特性.
LSTM Cell
一个Cell由三个Gate(input、forget、output)和一个cell单元组成。Gate使用一个sigmoid激活函数,而input和cell state通常会使用tanh来转换
- Gates:
- 输入变换:
- 状态更新:
LSTM应用
- 手写输入:
- char-rnn
- handwriting
- 图文翻译:
- nerualtalk2
- 论文: Show and tell: A neural image caption generator
反馈与建议
- 微博:@Girl_AI
参考文献
- Bengio, Y, et al. On the difficulty of training Recurrent Neural Networks
- Bengio, Y., et al. (1994). Learning long-term dependencies with gradient descent is difficult.
- The Unreasonable Effectiveness of Recurrent Neural Networks
- LSTM implementation explained
- 七月算法机器学习在线班