Using Long Short-Term Memory Layers (LSTM)
Add long short-term memory (LSTM) layers to recurrent neural networks to avoid long-term dependency problems.
Overview
An LSTM layer consists of four gates that manipulate cell-state data:
- Forget gate
Controls whether to erase data from the cell-state.
- Input gate
Controls whether to write data to the cell-state.
- Candidate gate
Controls what data to write to the cell-state.
- Output gate
Controls what data to pass as the output hidden state.
The following figure illustrates the components of an LSTM layer. The inputs are the cell-state (c), the hidden state (h), and the input data (x). The outputs are the updated cell-state (c) and hidden state (h):
[Image]
Note that the default activation function for the forget, input, and output gates is sigmoid; the default activation function for the candidate gate is tanh.
BNNS provides direct apply functions for forward and backward LSTM passes, so you don’t need to create an explicit LSTM layer. Instead, you create descriptors for the input, output, and gates, combine them into a parameters structure, and pass that structure to the apply function.
The input and output descriptors require n-dimensional array descriptors for the data, hidden state, and cell-state:
let inputDescriptor = BNNSLSTMDataDescriptor(
    data_desc: inputDataDescriptor,
    hidden_desc: inputHiddenDescriptor,
    cell_state_desc: inputCellStateDescriptor)
let outputDescriptor = BNNSLSTMDataDescriptor(
    data_desc: outputDataDescriptor,
    hidden_desc: outputHiddenDescriptor,
    cell_state_desc: outputCellStateDescriptor)

Define each gate by specifying input, hidden, and cell-state weights, and bias:
let forgetGate = BNNSLSTMGateDescriptor(
    iw_desc: (forgetGateInputWeightsDescriptor,
              forgetGateInputWeightsDescriptor),
    hw_desc: forgetGateHiddenWeightsDescriptor,
    cw_desc: forgetGateCellStateWeightsDescriptor,
    b_desc: forgetGateBiasDescriptor,
    activation: BNNSActivation(function: .sigmoid))
let inputGate = BNNSLSTMGateDescriptor(
    iw_desc: (inputGateInputWeightsDescriptor,
              inputGateInputWeightsDescriptor),
    hw_desc: inputGateHiddenWeightsDescriptor,
    cw_desc: inputGateCellStateWeightsDescriptor,
    b_desc: inputGateBiasDescriptor,
    activation: BNNSActivation(function: .sigmoid))
let candidateGate = BNNSLSTMGateDescriptor(
    iw_desc: (candidateGateInputWeightsDescriptor,
              candidateGateInputWeightsDescriptor),
    hw_desc: candidateGateHiddenWeightsDescriptor,
    cw_desc: candidateGateCellStateWeightsDescriptor,
    b_desc: candidateGateBiasDescriptor,
    activation: BNNSActivation(function: .tanh))
let outputGate = BNNSLSTMGateDescriptor(
    iw_desc: (outputGateInputWeightsDescriptor,
              outputGateInputWeightsDescriptor),
    hw_desc: outputGateHiddenWeightsDescriptor,
    cw_desc: outputGateCellStateWeightsDescriptor,
    b_desc: outputGateBiasDescriptor,
    activation: BNNSActivation(function: .sigmoid))

The following code shows how each gate computes its output:
for (size_t o = 0; o < iw_desc.size[1]; o++)
{
    float res = bias[o];                          // initialize with the bias value
    for (size_t i = 0; i < iw_desc.size[0]; i++)  // input weights: matrix-vector multiply
        res += input[i] * input_weights[o][i];
    for (size_t i = 0; i < hw_desc.size[0]; i++)  // hidden weights: matrix-vector multiply
        res += hidden[i] * hidden_weights[o][i];
    for (size_t i = 0; i < cw_desc.size[0]; i++)  // cell-state weights: matrix-vector multiply
        res += cell[i] * cell_weights[o][i];
    out[o] = activation.func(res);                // apply the activation function to output element o
}

Given the descriptors for the input, output, and gates, you’re ready to create the parameters structure:
var layerParams = BNNSLayerParametersLSTM(
    input_size: ... ,
    hidden_size: ... ,
    batch_size: ... ,
    num_layers: ... ,
    seq_len: ... ,
    dropout: ... ,
    lstm_flags: BNNSLayerFlagsLSTMDefaultActivations.rawValue,
    sequence_descriptor: BNNSNDArrayDescriptor(),
    input_descriptor: inputDescriptor,
    output_descriptor: outputDescriptor,
    input_gate: inputGate,
    forget_gate: forgetGate,
    candidate_gate: candidateGate,
    output_gate: outputGate,
    hidden_activation: BNNSActivation(function: .identity))

LSTM provides the option to define a training cache that stores intermediate results to accelerate backward computation. Use BNNSComputeLSTMTrainingCacheCapacity(_:) to compute the size, in bytes, of the training cache:
let trainingCacheBufferBytes = BNNSComputeLSTMTrainingCacheCapacity(&layerParams)
let trainingCache = UnsafeMutableRawPointer.allocate(
    byteCount: trainingCacheBufferBytes,
    alignment: 1)
defer {
    trainingCache.deallocate()
}

Finally, call BNNSDirectApplyLSTMBatchTrainingCaching(_:_:_:_:) to apply the LSTM layer:
BNNSDirectApplyLSTMBatchTrainingCaching(&layerParams,
                                        nil,
                                        trainingCache,
                                        trainingCacheBufferBytes)