split(count:alongAxis:)
Splits a tensor into multiple tensors. The tensor is split along dimension axis into count smaller tensors.
Declaration
func split(count: Int, alongAxis axis: Int = 0) -> [MLTensor]
Parameters
- count:
The number of splits to create. It must divide the size of dimension axis evenly.
- axis:
The dimension along which to split this tensor. The axis must be in the range [-rank, rank).
Return Value
An array containing the tensor parts.
Discussion
For example:
// 'value' is a tensor with shape [5, 30]
// Split 'value' into 3 tensors along dimension 1:
let parts = value.split(count: 3, alongAxis: 1)
parts[0] // has shape [5, 10]
parts[1] // has shape [5, 10]
parts[2] // has shape [5, 10]
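
The parameter constraints above can be sketched in plain Swift without Core ML. The helper below, splitShapes(of:count:alongAxis:), is a hypothetical function written for illustration and is not part of the MLTensor API. It only predicts the shapes of the parts, and it assumes that a negative axis counts back from the last dimension, which is the usual reading of a [-rank, rank) range.

```swift
// Hypothetical helper, not part of Core ML: predicts the shapes that
// split(count:alongAxis:) is documented to return, or nil when the
// documented constraints on count and axis are not met.
func splitShapes(of shape: [Int], count: Int, alongAxis axis: Int = 0) -> [[Int]]? {
    let rank = shape.count

    // The axis must lie in the range [-rank, rank).
    guard axis >= -rank, axis < rank else { return nil }

    // Assumption: a negative axis counts back from the last dimension.
    let resolvedAxis = axis < 0 ? axis + rank : axis

    // The count must be positive and divide the dimension size evenly.
    guard count > 0, shape[resolvedAxis] % count == 0 else { return nil }

    // Every part keeps the original shape except along the split axis.
    var partShape = shape
    partShape[resolvedAxis] = shape[resolvedAxis] / count
    return Array(repeating: partShape, count: count)
}

// Mirrors the example above: shape [5, 30] split into 3 parts along axis 1.
print(splitShapes(of: [5, 30], count: 3, alongAxis: 1) ?? [])
// [[5, 10], [5, 10], [5, 10]]
```

Under the same assumption, passing alongAxis: -1 for a rank-2 shape gives the same result as alongAxis: 1. A count of 4 returns nil in this sketch because 4 does not divide 30 evenly.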