lstm(initialHiddenStates:initialCellStates:inputHiddenWeight:hiddenHiddenWeight:bias:direction:activation:recurrentActivation:cellActivation:outputSequence:)
Adds an LSTM operation to the current graph.
Declaration
func lstm(initialHiddenStates: BNNSGraph.Builder.Tensor<T>, initialCellStates: BNNSGraph.Builder.Tensor<T>, inputHiddenWeight: BNNSGraph.Builder.Tensor<T>, hiddenHiddenWeight: BNNSGraph.Builder.Tensor<T>, bias: BNNSGraph.Builder.Tensor<T>, direction: BNNSGraph.Builder.Direction, activation: BNNSGraph.Builder.Activation, recurrentActivation: BNNSGraph.Builder.Activation, cellActivation: BNNSGraph.Builder.Activation, outputSequence: Bool) -> (output: BNNSGraph.Builder.Tensor<T>, hiddenStates: BNNSGraph.Builder.Tensor<T>, memoryStates: BNNSGraph.Builder.Tensor<T>)
Parameters
- initialHiddenStates:
The initial hidden states with the shape (N, Hout).
- initialCellStates:
The initial cell states with the shape (N, Hout).
- hiddenHiddenWeight:
The hidden-hidden weight with the shape (4*Hout, Hout).
- bias:
The bias (the sum of the input-hidden and hidden-hidden biases) with the shape (4*Hout,).
- direction:
An enumeration that specifies a forward or backward RNN.
- activation:
An enumeration that controls the output activation function.
- recurrentActivation:
An enumeration that controls the recurrent activation function.
- cellActivation:
An enumeration that controls the cell activation function.
- outputSequence:
When true, output is of shape (L, N, Hout) and contains hidden states from every step, h[:, ...]. When false, output is of shape (1, N, Hout) and contains hidden states from the last step, h[-1, ...].
Discussion
For each time t from 0 to L-1, this operation computes the following:
Input gate:
i[t, ...] = RA(matmul(W_ii, x[t, ...]) + b_ii + matmul(W_hi, h[t-1, ...]) + b_hi)
Forget gate:
f[t, ...] = RA(matmul(W_if, x[t, ...]) + b_if + matmul(W_hf, h[t-1, ...]) + b_hf)
Cell gate:
g[t, ...] = CA(matmul(W_ig, x[t, ...]) + b_ig + matmul(W_hg, h[t-1, ...]) + b_hg)
Output gate:
o[t, ...] = RA(matmul(W_io, x[t, ...]) + b_io + matmul(W_ho, h[t-1, ...]) + b_ho)
Cell state:
c[t, ...] = f[t, ...] * c[t-1, ...] + i[t, ...] * g[t, ...]
Hidden state:
h[t, ...] = o[t, ...] * A(c[t, ...])
where:
- A is the activation function
- RA is the recurrentActivation function
- CA is the cellActivation function
- inputHiddenWeight = concat(W_ii, W_if, W_io, W_ig, axis=-2)
- hiddenHiddenWeight = concat(W_hi, W_hf, W_ho, W_hg, axis=-2)
- bias = concat(b_ii + b_hi, b_if + b_hf, b_ig + b_hg, b_io + b_ho, axis=-1)
- initialHiddenStates is used for h[t-1, ...] at the first step
- initialCellStates is used for c[t-1, ...] at the first step
- * denotes the Hadamard (elementwise) product
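The gate equations above can be sketched as a single LSTM step in plain Swift. This is a minimal illustration, not the BNNS implementation: it assumes Hin = Hout = 1 and batch size N = 1 so each matmul reduces to a scalar product, and it hard-codes sigmoid for the recurrent activation (RA) and tanh for both the cell activation (CA) and the output activation (A), whereas the real operation lets you choose each independently.

```swift
import Foundation

// Stand-in for the recurrent activation RA in this sketch.
func sigmoid(_ x: Double) -> Double { 1.0 / (1.0 + exp(-x)) }

// Hypothetical scalar parameters for one cell: per-gate input-hidden
// weights (w*), hidden-hidden weights (wh*), and fused biases (b_i* + b_h*).
struct LSTMCellParams {
    var wii = 0.0, whi = 0.0, bi = 0.0   // input gate
    var wif = 0.0, whf = 0.0, bf = 0.0   // forget gate
    var wig = 0.0, whg = 0.0, bg = 0.0   // cell gate
    var wio = 0.0, who = 0.0, bo = 0.0   // output gate
}

// One time step: given x[t], h[t-1], and c[t-1], returns (h[t], c[t]).
func lstmStep(x: Double, hPrev: Double, cPrev: Double,
              p: LSTMCellParams) -> (h: Double, c: Double) {
    let i = sigmoid(p.wii * x + p.whi * hPrev + p.bi)   // input gate
    let f = sigmoid(p.wif * x + p.whf * hPrev + p.bf)   // forget gate
    let g = tanh(p.wig * x + p.whg * hPrev + p.bg)      // cell gate
    let o = sigmoid(p.wio * x + p.who * hPrev + p.bo)   // output gate
    let c = f * cPrev + i * g                           // cell state
    let h = o * tanh(c)                                 // hidden state
    return (h, c)
}
```

With all parameters zero, every sigmoid gate evaluates to 0.5 and the cell gate to tanh(0) = 0, so the step halves the previous cell state: c[t] = 0.5 * c[t-1] and h[t] = 0.5 * tanh(c[t]).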
The input tensor x is of shape (L, N, Hin).
hiddenStates is of shape (N, Hout) and contains hidden states from the last step, h[-1, ...].
memoryStates is of shape (N, Hout) and contains memory states from the last step, c[-1, ...].
- inputHiddenWeight:
The input-hidden weight with the shape (4*Hout, Hin).
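The fused (4*Hout, …) layouts follow the concatenation rules in the discussion: per-gate blocks are stacked along the second-to-last axis for the weights and along the last axis for the bias. A sketch of that packing with plain row-major Swift arrays, where the per-gate tensors and their fill values are purely illustrative:

```swift
import Foundation

let Hin = 3, Hout = 2

// Hypothetical per-gate input-hidden weights, each of shape (Hout, Hin).
let W_ii = Array(repeating: Array(repeating: 1.0, count: Hin), count: Hout)
let W_if = Array(repeating: Array(repeating: 2.0, count: Hin), count: Hout)
let W_io = Array(repeating: Array(repeating: 3.0, count: Hin), count: Hout)
let W_ig = Array(repeating: Array(repeating: 4.0, count: Hin), count: Hout)

// inputHiddenWeight = concat(W_ii, W_if, W_io, W_ig, axis=-2):
// stacking the gate blocks along the rows gives shape (4*Hout, Hin).
let inputHiddenWeight = W_ii + W_if + W_io + W_ig

// Hypothetical fused per-gate biases (b_i* + b_h*), each of shape (Hout,).
let b_i = [0.1, 0.1], b_f = [0.2, 0.2], b_g = [0.3, 0.3], b_o = [0.4, 0.4]

// Concatenating along the last axis gives a single vector of shape (4*Hout,).
let bias = b_i + b_f + b_g + b_o
```

The same row-stacking applies to hiddenHiddenWeight, whose per-gate blocks are (Hout, Hout), giving the (4*Hout, Hout) shape listed above.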