bandPart(_:numLower:numUpper:name:)
Computes the band part of an input tensor.
Declaration
func bandPart(_ inputTensor: MPSGraphTensor, numLower: Int, numUpper: Int, name: String?) -> MPSGraphTensor
Parameters
- inputTensor:
The input tensor.
- numLower:
The number of diagonals in the lower triangle (below the main diagonal) to keep. Pass -1 to keep all subdiagonals.
- numUpper:
The number of diagonals in the upper triangle (above the main diagonal) to keep. Pass -1 to keep all superdiagonals.
- name:
Name for the operation.
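To show how numLower and numUpper combine, the sketch below builds a 0/1 mask marking which entries of a small matrix are kept. It runs on the CPU in plain Swift and does not call MPSGraph; the helpers inBand and bandMask are hypothetical names, and the rule they encode is the band condition given in the Discussion.

```swift
// Hypothetical helper (not part of MPSGraph): true when (i, j) lies inside
// the band. A negative numLower or numUpper leaves that side unbounded.
func inBand(_ i: Int, _ j: Int, numLower: Int, numUpper: Int) -> Bool {
    let lowerOK = numLower < 0 || (i - j) <= numLower
    let upperOK = numUpper < 0 || (j - i) <= numUpper
    return lowerOK && upperOK
}

// Hypothetical helper: builds a rows x cols mask, 1 = kept, 0 = zeroed.
func bandMask(rows: Int, cols: Int, numLower: Int, numUpper: Int) -> [[Int]] {
    (0..<rows).map { i in
        (0..<cols).map { j in
            inBand(i, j, numLower: numLower, numUpper: numUpper) ? 1 : 0
        }
    }
}

// Common parameter choices on a 3 x 3 matrix.
let upperTriangular = bandMask(rows: 3, cols: 3, numLower: 0, numUpper: -1)
let lowerTriangular = bandMask(rows: 3, cols: 3, numLower: -1, numUpper: 0)
let diagonalOnly = bandMask(rows: 3, cols: 3, numLower: 0, numUpper: 0)
let tridiagonal = bandMask(rows: 3, cols: 3, numLower: 1, numUpper: 1)

for row in upperTriangular { print(row) }
```

Setting one side to 0 and the other to -1 yields a triangular matrix, (0, 0) keeps only the main diagonal, and (1, 1) keeps a tridiagonal band.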
Return Value
A valid MPSGraphTensor object.
Discussion
This operation copies a diagonal band of values from the input tensor to a result tensor of the same size. A coordinate [..., i, j] is in the band if:
(numLower < 0 || (i - j) <= numLower) && (numUpper < 0 || (j - i) <= numUpper)
Values outside the band are set to 0.
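As a CPU reference for checking results, the sketch below applies the band rule to a row-major buffer whose last two dimensions are [rows, cols], treating any leading dimensions as a batch, which is one reading of the [..., i, j] notation. It is an illustrative assumption, not Apple's implementation; bandPartReference is a hypothetical name and the flat [Float] layout is chosen for simplicity.

```swift
// Hypothetical reference implementation (not Apple's code). Assumes a
// row-major buffer whose last two dimensions are [rows, cols]; every
// leading dimension is treated as a batch of independent matrices.
func bandPartReference(_ values: [Float], shape: [Int],
                       numLower: Int, numUpper: Int) -> [Float] {
    precondition(shape.count >= 2, "band part needs at least a rank-2 shape")
    precondition(values.count == shape.reduce(1, *), "shape does not match element count")
    let rows = shape[shape.count - 2]
    let cols = shape[shape.count - 1]
    var result = values  // result has the same size as the input
    for k in 0..<values.count {
        let j = k % cols           // column index
        let i = (k / cols) % rows  // row index within the current matrix
        let inBand = (numLower < 0 || (i - j) <= numLower)
            && (numUpper < 0 || (j - i) <= numUpper)
        if !inBand { result[k] = 0 }  // values outside the band become 0
    }
    return result
}

// A batch of two 3 x 3 matrices holding 1...18.
let input = (1...18).map { Float($0) }
// Keep the main diagonal and one subdiagonal in each matrix.
let banded = bandPartReference(input, shape: [2, 3, 3], numLower: 1, numUpper: 0)
print(banded)
```

Each matrix in the batch is masked independently: with numLower = 1 and numUpper = 0, the first matrix [[1, 2, 3], [4, 5, 6], [7, 8, 9]] becomes [[1, 0, 0], [4, 5, 0], [0, 8, 9]].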