scaledDotProductAttention(query:key:value:mask:scale:name:)
Creates a scaled dot product attention (SDPA) operation and returns the result tensor.
Declaration
func scaledDotProductAttention(query queryTensor: MPSGraphTensor, key keyTensor: MPSGraphTensor, value valueTensor: MPSGraphTensor, mask maskTensor: MPSGraphTensor?, scale: Float, name: String?) -> MPSGraphTensor
Parameters
- queryTensor:
A tensor that represents the query projection.
- keyTensor:
A tensor that represents the key projection.
- valueTensor:
A tensor that represents the value projection.
- maskTensor:
An optional mask tensor that is added element-wise to the scaled product of the query and key matrices (scale * QK^T) before the softmax. If the mask tensor is nil, no mask is applied to QK^T.
- scale:
A scale factor that is applied to the result of the query and key matrix multiplication (QK^T).
- name:
The name for the operation.
Return Value
A valid MPSGraphTensor object.
Discussion
The SDPA operation computes attention as softmax(scale * QK^T + M)V. queryTensor Q has shape [B, Hq, Nq, F] and keyTensor K has shape [B, Hq, Nkv, F], with Q’s H dimension expandable to satisfy the matmul QK^T. The shape of maskTensor M should be broadcast compatible with QK^T so that (QK^T + M) is valid. valueTensor V has shape [B, Hv, Nkv, F] and should satisfy the matmul (QK^T + M)V.
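The following is a minimal sketch of how a call might be set up, not a definitive usage pattern. The tensor sizes, the placeholder names, and the choice of scale = 1/sqrt(F) (the conventional attention scaling, not something this page requires) are assumptions made for illustration. It also assumes an OS version on which this method is available and only builds the graph without running it.

```swift
import Foundation
import MetalPerformanceShadersGraph

// Hypothetical sizes, chosen only for illustration:
// B = 1 (batch), H = 8 (heads), Nq = 16, Nkv = 32, F = 64.
let graph = MPSGraph()

// Q: [B, Hq, Nq, F]
let query = graph.placeholder(shape: [1, 8, 16, 64], dataType: .float32, name: "query")
// K: [B, Hq, Nkv, F]
let key = graph.placeholder(shape: [1, 8, 32, 64], dataType: .float32, name: "key")
// V: [B, Hv, Nkv, F] (here Hv equals Hq)
let value = graph.placeholder(shape: [1, 8, 32, 64], dataType: .float32, name: "value")
// M: broadcast compatible with QK^T, whose shape here is [1, 8, 16, 32]
let mask = graph.placeholder(shape: [1, 1, 16, 32], dataType: .float32, name: "mask")

// Assumption: conventional scaling by 1 / sqrt(F).
let scale = 1.0 / Float(64).squareRoot()

// Adds softmax(scale * QK^T + M)V to the graph as a single operation.
let attention = graph.scaledDotProductAttention(
    query: query,
    key: key,
    value: value,
    mask: mask,
    scale: scale,
    name: "sdpa"
)
```

Passing nil for mask skips the additive term, so the operation reduces to softmax(scale * QK^T)V.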