replacing(with:where:)
Returns a new tensor in which the elements of self are replaced by the corresponding elements of replacement wherever the associated element in mask is true.
Declaration
func replacing(with replacement: MLTensor, where mask: MLTensor) -> MLTensor
Parameters
- replacement:
The replacement values used where mask is true.
- mask:
The Boolean mask that determines whether the corresponding element / row should be taken from self (if the element in mask is false) or replacement (if true).
Return Value
A new tensor of the same shape and type as self.
Discussion
For example:
let x = MLTensor([1, 2, 3], scalarType: Float.self)
let y = MLTensor([4, 5, 6], scalarType: Float.self)
let mask = MLTensor([false, true, false])
let z = x.replacing(with: y, where: mask)
await z.shapedArray(of: Float.self) // is [1, 5, 3]
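The selection rule can also be sketched on plain Swift arrays, which may help when reasoning about results without a tensor at hand. The helper name replacingSketch is hypothetical and not part of Core ML. It assumes all three inputs have the same length and does not model multi-dimensional shapes or any broadcasting the real method may perform.

```swift
// Hypothetical helper, not part of Core ML: models the element-wise
// selection on plain arrays. Assumes equal-length, one-dimensional inputs.
func replacingSketch(_ base: [Float], with replacement: [Float], where mask: [Bool]) -> [Float] {
    precondition(base.count == replacement.count && base.count == mask.count,
                 "this sketch assumes equal-length inputs")
    // Take the replacement value where the mask is true, otherwise keep the base value.
    return base.indices.map { mask[$0] ? replacement[$0] : base[$0] }
}

let sketchResult = replacingSketch([1, 2, 3], with: [4, 5, 6], where: [false, true, false])
print(sketchResult) // [1.0, 5.0, 3.0]
```

An all-false mask returns the base values unchanged, and an all-true mask returns the replacement values.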