GRU(_:recurrentWeight:inputWeight:bias:initState:descriptor:name:)
Creates a GRU operation and returns the value and optionally the training state tensor.
Declaration
func GRU(_ source: MPSGraphTensor, recurrentWeight: MPSGraphTensor, inputWeight: MPSGraphTensor?, bias: MPSGraphTensor?, initState: MPSGraphTensor?, descriptor: MPSGraphGRUDescriptor, name: String?) -> [MPSGraphTensor]Parameters
- source:
A tensor containing the source data
x[t]with the data layout [T,N,I]. In caseinputWeight = nilandbidirectional = NOthen the layout is [T,N,3H] and forinputWeight = nilandbidirectional = YESthe layout is [T,N,6H]. - recurrentWeight:
A tensor containing the recurrent weights
R. Forbidirectionalthe layout is [2,3H,H] and otherwise it is [3H,H]. - inputWeight:
A tensor containing the input weights matrix
W- optional, if missing the operation assumes a diagonal unit-matrix. Forbidirectionalthe layout is [6H,I] and otherwise it is [3H,I]. - bias:
A tensor containing the bias
b- optional, if missing the operation assumes zeroes. Forbidirectionalthe layout is [6H] and otherwise it is [3H]. - initState:
The initial internal state of the LSTM
h[-1]- optional, if missing the operation assumes zeroes. Forbidirectionalthe layout is [N,2H] and otherwise it is [N,H]. - descriptor:
A descriptor that defines the parameters for the GRU operation.
- name:
The name for the operation.
Return Value
A valid MPSGraphTensor array of size 1 or 2 depending on value of descriptor.training. The layout of the state output is [T,N,H] or [T,N,2H] for bidirectional, and the layout of the trainingState output is [T,N,3H] or [T,N,6H] for bidirectional.
Discussion
This operation returns tensors h and optionally z that are defined recursively as follows:
for t = 0 to T-1
z[t] = fz( (h[t-1] m) R^T + x[t] W^T + b ),
r[t] = fr( (h[t-1] m) R^T + x[t] W^T + b ),
c[t] = (h[t-1] r[t] m) R^T
o[t] = fo( c[t] + x[t] W^T + b )
h[t] = z[t]h[t-1] + (1-z[t])o[t]If resetAfter = YES then c[t] is replaced by
c[t] = ( (h[t-1] m) R^T + b2 ) r[t]If flipZ = YES then h[t] is replaced by
h[t] = (1-z[t])h[t-1] + z[t]o[t].W is optional inputWeight, R is recurrentWeight, b is optional bias, m is optional mask, x[t] is source h[t] is the first output, z[t] is the second output (optional) and h[-1] is initState. b2 is an optional resetBias vector, only used when resetAfter = YES. See MPSGraphGRUDescriptor for different activation options for f().