pedronahum/metalhlo
**StableHLO Execution on Apple Metal**
Design Philosophy
MetalHLO draws inspiration from both OpenXLA and MLX, combining the best of both worlds:
From OpenXLA:
- StableHLO as the IR — A stable, portable intermediate representation for ML workloads
- Multi-phase optimization pipeline — Simplification → Canonicalization → Fusion → Layout → Scheduling
- Pattern-based fusion — Recognizing and fusing common patterns like attention, depth attention (Attention Residuals), layer norm, GELU
- Cost-model driven decisions — Using analytical cost models to guide fusion decisions
From MLX:
- Lazy graph compilation — Build the computation graph, compile once, execute many times
- Unified memory architecture — Zero-copy data sharing between CPU and GPU on Apple Silicon
- Simplicity first — Clean APIs that make common cases easy and complex cases possible
- Single-device focus — Optimized for one GPU rather than distributed execution
MetalHLO's unique contribution:
- Three execution backends — MPSGraph for broad compatibility, custom Metal kernels for peak performance, and heterogeneous GPU+ANE+CPU for parallel execution
- Heterogeneous GPU+ANE+CPU execution — Automatically partitions profitable workloads across GPU, MPS/ANE, and CPU using a profitability-gated 4-pass pipeline (see how it works below)
- Progressive optimization levels — O0 to O3 matching compiler conventions
- PJRT plugin for JAX — Standard OpenXLA plugin interface enables `import jax` and `jax.numpy` workloads on Apple GPUs
- C API for portability — Integrate with any language that can call C functions
How Heterogeneous Execution Works
Every Apple Silicon chip contains a GPU, a Neural Engine (ANE, accessible via MPS), and high-performance CPU cores — three independent compute units that share unified memory. Most ML frameworks only use the GPU. MetalHLO is the first StableHLO runtime to use all three simultaneously.
```
              StableHLO Program
                      │
             ┌────────▼────────┐
             │  4-Pass Graph   │
             │    Pipeline     │
             └────────┬────────┘
                      │
      Pass 1: Eligibility Annotation
              (ProfitabilityGuard)
                      │
      Pass 2: Op Fusion
              (chain + attention absorption)
                      │
      Pass 3: Heterogeneous Partitioning
              (optimal split per cluster)
                      │
      Pass 4: Sequential Scheduling
              (serial between clusters,
               parallel within)
                      │
       ┌──────────────┼──────────────┐
       ▼              ▼              ▼
┌────────────┐ ┌────────────┐ ┌────────────┐
│    GPU     │ │  MPS/ANE   │ │    CPU     │
│  (Metal)   │ │            │ │(Accelerate │
│            │ │ • matmul   │ │   /vDSP)   │
│ • matmul   │ │ • conv     │ │            │
│ • elem.    │ │   (routes  │ │ • matmul   │
│ • softmax  │ │   to ANE)  │ │ • elem.    │
│ • layerN.  │ │            │ │ • layerN.  │
└──────┬─────┘ └──────┬─────┘ └──────┬─────┘
       │              │              │
       │       Unified Memory        │  ← Zero-copy: all three units
       │          (shared)           │    read/write the same buffers
       └──────────────┴──────────────┘
                      ▼
                   Results
```
The key insight is that Apple Silicon's unified memory eliminates data transfer cost between compute units — GPU, MPS/ANE, and CPU all read and write the same physical memory. The system uses a profitability-gated cost model: rather than partitioning every operation, a compound gate filters out shapes where the overhead of multi-unit dispatch exceeds the benefit, ensuring zero regressions on unprofitable workloads.
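To make Pass 3 concrete, here is a hedged sketch of splitting a projection's output columns across the three units in proportion to relative throughput (unit names and throughput numbers are invented for illustration; the real split is driven by the cost model):

```python
def split_columns(n_cols, throughput):
    """Split N output columns across units proportional to relative throughput.

    `throughput` maps unit name -> relative rate (illustrative numbers).
    Returns unit -> (start_col, end_col) half-open column ranges.
    """
    total = sum(throughput.values())
    splits, start = {}, 0
    units = list(throughput)
    for i, unit in enumerate(units):
        # The last unit takes the remainder so the slices cover every column.
        if i == len(units) - 1:
            width = n_cols - start
        else:
            width = round(n_cols * throughput[unit] / total)
        splits[unit] = (start, start + width)
        start += width
    return splits

# GPT-2 logit projection has N = 50257 output columns (vocabulary size).
splits = split_columns(50257, {"gpu": 10.0, "ane": 6.0, "cpu": 2.0})
print(splits)
```

Because all three units write disjoint column slices of the same unified-memory buffer, no gather step is needed afterwards.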
Profitability Guard (Compound Gate)
Not every operation benefits from partitioning. Empirical validation on GPT-2 and ViT-B/16 established two hardware-derived thresholds:
- Minimum elements: ≥ 10M output elements (below this, dispatch overhead dominates)
- Minimum output columns: N ≥ 32,768 (narrow shapes can't amortize per-unit slice overhead)
Both conditions must be met. In practice, this means only vocabulary-scale projections (GPT-2 N=50,257, LLaMA N=32,000, CLIP N=49,408) are partitioned, while all other operations pass through to single-unit execution with zero overhead.
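A minimal sketch of the compound gate, assuming the two thresholds above (the names and function shape are illustrative; the actual logic lives in ProfitabilityGuard):

```python
# Hedged sketch of the compound profitability gate. Threshold values come
# from the text; everything else is illustrative.
MIN_ELEMENTS = 10_000_000  # below this, dispatch overhead dominates
MIN_COLUMNS = 32_768       # narrow outputs can't amortize per-unit slices

def should_partition(output_shape):
    """Both thresholds must be met for an op to be partitioned."""
    elements = 1
    for dim in output_shape:
        elements *= dim
    n = output_shape[-1]  # output columns
    return elements >= MIN_ELEMENTS and n >= MIN_COLUMNS

# GPT-2 logit projection (seq=512, vocab=50257): partitioned
print(should_partition([1, 512, 50257]))  # True
# An attention score matrix: rejected (too few elements, N too small)
print(should_partition([1, 8, 512, 512]))  # False
```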
Op-Class Taxonomy
| Class | Examples | Strategy | Speedup |
|-------|----------|----------|---------|
| Compute-bound | matmul, attention | 3-unit GPU+MPS+CPU | 1.14-1.91x |
| MPS-dominant | convolution | MPS/ANE alone (2-5x faster than partitioned) | — |
| Bandwidth-bound | elementwise, layerNorm | 2-unit GPU+CPU; value is cluster absorption | — |
| Gather | embedding | 2-unit GPU+CPU; CPU dominates | — |
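The taxonomy can be read as a simple dispatch rule; a hedged sketch (class names mirror the table, but the function and strategy strings are assumptions, not MetalHLO's internals):

```python
# Illustrative reading of the op-class table as a dispatch rule.
STRATEGY = {
    "matmul": "gpu+mps+cpu",     # compute-bound: 3-unit split
    "attention": "gpu+mps+cpu",
    "convolution": "mps",        # MPS-dominant: MPS/ANE alone wins
    "elementwise": "gpu+cpu",    # bandwidth-bound: value is cluster absorption
    "layer_norm": "gpu+cpu",
    "gather": "gpu+cpu",         # CPU dominates embedding lookups
}

def strategy_for(op_class, profitable):
    # Ops that fail the profitability gate always run single-unit on the GPU.
    if not profitable:
        return "gpu"
    return STRATEGY.get(op_class, "gpu")

print(strategy_for("matmul", True))   # gpu+mps+cpu
print(strategy_for("matmul", False))  # gpu
```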
GPT-2 End-to-End Validation
The heterogeneous pipeline was validated on GPT-2 (124M parameters):
| Metric | Result |
|--------|--------|
| Logit projection speedup | 1.91x (fused GPU+MPS+CPU vs single-unit) |
| Regressions | Zero — compound gate rejects all unprofitable ops |
| Ops partitioned (seq=512) | 1 of 147 (logit projection only) |
| Cross-architecture | ViT-B/16: 0 partitioned at all batch sizes (correct) |
The compound gate is batch-invariant: the same ops are selected regardless of batch size (1 through 8). At batch ≥ 16 on 8GB devices, memory pressure causes regressions — the gate should be extended with an upper memory bound for large-batch safety.
Features
- StableHLO Conformance — 191 of 277 conformance tests pass (86 skipped for MPS/Metal limitations)
- 88% StableHLO Coverage — 92 of 105 operations implemented
- ~99% Practical ML Coverage — All operations needed for production ML workloads
- Triple API — Native Swift API + C API + PJRT plugin for JAX/XLA integration
- Configurable Optimization — O0 to O3 levels with algebraic simplification and operator fusion
- Heterogeneous Execution — GPU+MPS/ANE+CPU auto mode partitions profitable operations across all compute units with a compound profitability gate
- Full Training Support — Forward and backward pass operations
- Apple Silicon Optimized — Leverages MPSGraph, custom Metal kernels, and the Apple Neural Engine
Supported Workloads
| Workload | Support |
|----------|---------|
| CNNs | Full (convolution, pooling, batch norm) |
| Transformers/MLPs | Full (attention, linear layers, activations) |
| RNNs | Full (while loops, dynamic operations) |
| Signal Processing | Full (FFT, IFFT, RFFT, IRFFT) |
| Quantized Models | Full (quantize/dequantize) |
Requirements
- macOS: 14.0+ (Sonoma)
- Swift: 6.0+
- Xcode: 15.0+
- Hardware: Apple Silicon (M1/M2/M3/M4)
Installation
Swift Package Manager
Add MetalHLO to your Package.swift:
```swift
dependencies: [
    .package(url: "https://github.com/pedronahum/MetalHLO.git", from: "0.2.0")
]
```
Then add the dependency to your target:
```swift
.target(
    name: "YourTarget",
    dependencies: ["MetalHLO"]
)
```
Building from Source
```shell
git clone https://github.com/pedronahum/MetalHLO.git
cd MetalHLO
swift build
swift test

# Build the PJRT plugin for JAX integration
swift build -c release --product PJRTMetalHLO
```
Quick Start
Swift — Basic Usage
```swift
import MetalHLO

// Create client
let client = try Client.create()
print("Using device: \(client.deviceName)")

// Compile StableHLO MLIR
let mlir = """
module @add {
  func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<4xf32>
    return %0 : tensor<4xf32>
  }
}
"""
let executable = try client.compile(mlir)

// Create input buffers
let a = client.createBuffer([1.0, 2.0, 3.0, 4.0] as [Float], shape: [4])
let b = client.createBuffer([10.0, 20.0, 30.0, 40.0] as [Float], shape: [4])

// Execute
let outputs = try executable.execute([a, b])
let result = try outputs[0].toFloatArray()
print("Result: \(result)") // [11.0, 22.0, 33.0, 44.0]
```
Swift — With Optimization Configuration
```swift
import MetalHLO

let client = try Client.create()

// Configure aggressive optimization (O3)
let config = CompilationConfig(optimizationLevel: .O3)
let executable = try client.compile(mlir, config: config)

// Or use presets
let debugExe = try client.compile(mlir, config: .debug)     // O0, no caching, debug info
let releaseExe = try client.compile(mlir, config: .release) // O3, caching enabled
let fastExe = try client.compile(mlir, config: .fast)       // O1, quick compilation
```
Swift — Heterogeneous GPU+ANE+CPU Execution
```swift
import MetalHLO

let client = try Client.create()

// Enable heterogeneous execution across GPU, MPS/ANE, and CPU
let config = CompilationConfig(
    optimizationLevel: .O3,
    devicePolicy: .auto // Profitability guard decides which ops to partition
)
let executable = try client.compile(mlir, config: config)

// Execute — profitable operations are partitioned across GPU + MPS/ANE + CPU
let outputs = try executable.execute(inputs)
```
The profitability guard limits partitioning to operations where it measurably helps:
- Partitioned (3-unit): Large vocabulary projections (N ≥ 32K, ≥ 10M elements) — e.g., GPT-2 logit projection
- Single-unit fallback: Everything else — the guard ensures zero overhead on unprofitable shapes
C — Basic Usage
```c
#include <metalhlo.h>
#include <stdio.h>

int main() {
    MHLOClientRef client = NULL;
    MHLOStatusCode status;

    // Create client
    status = mhlo_client_create(&client);
    if (status != MHLO_OK) {
        fprintf(stderr, "Error: %s\n", mhlo_get_last_error());
        return 1;
    }

    // Print device info
    char* device = mhlo_client_device_name(client);
    printf("Device: %s\n", device);
    mhlo_free_string(device);

    // Compile MLIR
    const char* mlir =
        "module @add {\n"
        "  func.func @main(%a: tensor<4xf32>, %b: tensor<4xf32>) -> (tensor<4xf32>) {\n"
        "    %0 = stablehlo.add %a, %b : tensor<4xf32>\n"
        "    return %0 : tensor<4xf32>\n"
        "  }\n"
        "}\n";

    MHLOExecutableRef exe = NULL;
    status = mhlo_compile(client, mlir, &exe);
    if (status != MHLO_OK) {
        fprintf(stderr, "Compile error: %s\n", mhlo_get_last_error());
        return 1;
    }

    // Create input buffers
    float a_data[] = {1.0f, 2.0f, 3.0f, 4.0f};
    float b_data[] = {10.0f, 20.0f, 30.0f, 40.0f};
    int64_t shape[] = {4};

    MHLOBufferRef buf_a = NULL, buf_b = NULL;
    mhlo_buffer_create(client, a_data, sizeof(a_data), shape, 1, MHLO_F32, &buf_a);
    mhlo_buffer_create(client, b_data, sizeof(b_data), shape, 1, MHLO_F32, &buf_b);

    // Execute
    MHLOBufferRef inputs[] = {buf_a, buf_b};
    MHLOBufferRef outputs[1] = {NULL};
    int32_t num_outputs = 0;
    status = mhlo_execute(exe, inputs, 2, outputs, &num_outputs);
    if (status != MHLO_OK) {
        fprintf(stderr, "Execute error: %s\n", mhlo_get_last_error());
        return 1;
    }

    // Read result
    float result[4];
    mhlo_buffer_to_host(outputs[0], result, sizeof(result));
    printf("Result: [%.1f, %.1f, %.1f, %.1f]\n",
           result[0], result[1], result[2], result[3]);

    // Cleanup
    mhlo_buffer_destroy(buf_a);
    mhlo_buffer_destroy(buf_b);
    mhlo_buffer_destroy(outputs[0]);
    mhlo_executable_destroy(exe);
    mhlo_client_destroy(client);
    return 0;
}
```
C — With Optimization Configuration
```c
#include <metalhlo.h>

// Initialize config with defaults (O2)
MHLOCompileConfig config;
mhlo_compile_config_init(&config);

// Use aggressive optimization
config.optimization_level = MHLO_OPT_O3;
config.device_policy = MHLO_DEVICE_AUTO; // Enable GPU+ANE+CPU heterogeneous execution
config.enable_caching = true;
config.enable_debug_info = false;

// Compile with configuration
MHLOExecutableRef exe = NULL;
mhlo_compile_with_config(client, mlir, &config, &exe);

// Get execution statistics after running
MHLOExecutionStats stats;
mhlo_executable_get_stats(exe, &stats);
printf("Executions: %lld, Avg time: %.3f ms\n",
       stats.execution_count, stats.average_execution_time_ms);
```
Optimization Levels
MetalHLO provides four optimization levels that control the aggressiveness of the compilation pipeline:
| Level | Description | Use Case |
|-------|-------------|----------|
| O0 | No optimization | Debugging, fastest compilation |
| O1 | Basic optimization | Quick iteration during development |
| O2 | Standard optimization (default) | Production with balanced compile time |
| O3 | Aggressive optimization | Maximum runtime performance |
What Each Level Does
O0 — No Optimization
- Direct translation to Metal kernels
- Fastest compile time
- Useful for debugging IR issues
O1 — Basic Optimization
- Algebraic simplification (x + 0 → x, x * 1 → x)
- Dead code elimination
- Constant folding
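As an illustration of what these O1 rewrites do, here is a hedged sketch on a toy expression tree (the node encoding and rule set are invented for exposition; MetalHLO's IR is StableHLO, not Python tuples):

```python
# Toy O1-style simplifier: constant folding plus two algebraic identities.
def simplify(node):
    if isinstance(node, tuple):
        op, lhs, rhs = node
        lhs, rhs = simplify(lhs), simplify(rhs)
        # Constant folding: both operands are literal numbers.
        if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)):
            return {"add": lhs + rhs, "mul": lhs * rhs}[op]
        # Algebraic identities: x + 0 -> x, x * 1 -> x.
        if op == "add" and rhs == 0:
            return lhs
        if op == "mul" and rhs == 1:
            return lhs
        return (op, lhs, rhs)
    return node  # leaf: a variable name or a literal

expr = ("add", ("mul", "x", 1), ("add", 2, 3))  # (x * 1) + (2 + 3)
print(simplify(expr))  # ('add', 'x', 5)
```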
O2 — Standard Optimization (Default)
- All O1 optimizations
- Shape canonicalization (reshape/transpose/broadcast fusion)
- Common subexpression elimination
- Pattern-based fusion (softmax, GELU, layer norm, attention)
- Producer-consumer fusion
O3 — Aggressive Optimization
- All O2 optimizations
- Multiple fusion iterations
- Sibling fusion (multi-output fusion)
- Horizontal fusion (batching small operations)
- Cross-layer fusion
- Layout optimization
Optimization Pipeline
MetalHLO's optimization pipeline runs in phases, inspired by XLA's approach:
```
┌─────────────────────────────────────────┐
│          StableHLO MLIR Input           │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│  Phase 1: SIMPLIFICATION                │
│  • Constant folding                     │
│  • Algebraic simplification             │
│  • Dead code elimination                │
│  • Common subexpression elimination     │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│  Phase 2: CANONICALIZATION              │
│  • Reshape canonicalization             │
│  • Transpose canonicalization           │
│  • Broadcast canonicalization           │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│  Phase 3: PATTERN FUSION                │
│  • Softmax detection & fusion           │
│  • GELU/SiLU activation fusion          │
│  • Layer norm / RMS norm fusion         │
│  • Attention pattern fusion             │
│  • Depth attention (Attention Residuals)│
│  • MatMul + Bias + Activation fusion    │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│  Phase 4: GENERIC FUSION                │
│  • Producer-consumer fusion             │
│  • Sibling fusion (multi-output)        │
│  • Horizontal fusion (op batching)      │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│  Phase 5: LAYOUT & SCHEDULING           │
│  • Memory layout optimization           │
│  • Buffer assignment                    │
│  • Kernel scheduling                    │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│        Metal Kernel Generation          │
└─────────────────────────────────────────┘
```
Key Optimization Passes
| Pass | Description |
|------|-------------|
| Algebraic Simplifier | Applies identity rules (x+0→x, x*1→x), inverse rules (exp(log(x))→x), and strength reduction |
| Dead Code Elimination | Removes operations whose results are unused |
| CSE | Common Subexpression Elimination — reuses identical computations |
| Reshape Canonicalizer | Fuses consecutive reshapes, eliminates no-op reshapes |
| Transpose Canonicalizer | Composes consecutive transposes, eliminates identity transposes |
| Broadcast Canonicalizer | Fuses broadcasts, moves broadcasts through elementwise ops |
| Pattern Fusion | Recognizes ML patterns (softmax, GELU, attention, depth attention) and replaces them with optimized kernels |
| Producer-Consumer Fusion | Fuses elementwise operations with their consumers to reduce memory traffic |
| Sibling Fusion | Fuses operations that share inputs (multi-output fusion) |
| Horizontal Fusion | Batches multiple small independent operations |
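The phase structure above can be sketched as a small pass manager; a hedged illustration (the PassManager shape and toy passes are assumptions for exposition, not MetalHLO's actual optimizer API):

```python
# Illustrative phase-ordered pass runner over a toy "module" (a list of op names).
class PassManager:
    def __init__(self, phases):
        self.phases = phases  # list of (phase_name, [pass_fn, ...])

    def run(self, module):
        for _name, passes in self.phases:
            for p in passes:  # passes run in order within each phase
                module = p(module)
        return module

dce = lambda m: [op for op in m if not op.endswith(":unused")]  # drop dead ops
cse = lambda m: list(dict.fromkeys(m))                          # dedupe, keep order

pm = PassManager([("simplification", [dce, cse])])
print(pm.run(["matmul", "matmul", "add:unused", "softmax"]))  # ['matmul', 'softmax']
```

Running passes in a fixed phase order (simplify before fuse, fuse before layout) keeps each pass simple: later passes can assume the invariants earlier phases established.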
Supported Operations
Fully Implemented (92 ops)
| Category | Operations |
|----------|------------|
| Binary Arithmetic | add, subtract, multiply, divide, maximum, minimum, power |
| Unary Math | negate, abs, exp, log, sqrt, rsqrt, sin, cos, tanh, floor, ceil, sign, tan, logistic, is_finite, expm1, log1p, cbrt, round_nearest_afz, round_nearest_even |
| Bitwise | not, and, or, xor, shift_left, shift_right_arithmetic, shift_right_logical, popcnt |
| Type Conversion | convert, bitcast_convert |
| Matrix | dot, dot_general, transpose, reshape, broadcast_in_dim, reverse |
| Dynamic Shape | dynamic_slice, dynamic_update_slice, dynamic_reshape, dynamic_broadcast_in_dim, dynamic_pad, dynamic_iota, dynamic_gather |
| Convolution | convolution |
| Reduction | reduce (sum, max, min, mean), reduce_window |
| Normalization | batch_norm_inference, batch_norm_training, batch_norm_grad |
| FFT | fft (FFT, IFFT, RFFT, IRFFT) |
| Sorting | sort |
| Comparison | compare (EQ, NE, LT, LE, GT, GE), select, clamp |
| Indexing | slice, pad, concatenate, gather, scatter |
| RNG | rng (uniform, normal), rng_bit_generator |
| Constants | constant, iota |
| Control Flow | while, if |
| Quantization | uniform_quantize, uniform_dequantize |
| Complex Numbers | complex, real, imag |
| Select/Scatter | select_and_scatter |
| Custom Calls | fused_scaled_dot_product_attention, fused_depth_attention, fused_layer_norm, fused_rms_norm, fused_matmul_bias_activation, fused_softmax, fused_gelu, fused_rope |
Operations Requiring MPS Kernel Bridging (3 ops)
| Operation | Notes |
|-----------|-------|
| triangular_solve | Requires MPSMatrixSolveTriangular |
| cholesky | Requires MPSMatrixDecompositionCholesky |
| map | Requires JIT compilation |
Excluded by Design (14 ops)
Multi-device operations (all_gather, all_reduce, etc.), communication operations (infeed, outfeed, etc.), and tuple operations are excluded per the single-device focus.
API Reference
Swift API
Client
```swift
public final class Client: @unchecked Sendable {
    /// Create a client for the default Metal device
    static func create() throws -> Client

    /// The underlying Metal device name
    var deviceName: String { get }

    /// Compile StableHLO MLIR (default O2 optimization)
    func compile(_ mlir: String) throws -> Executable

    /// Compile with explicit configuration
    func compile(_ mlir: String, config: CompilationConfig) throws -> Executable

    /// Create buffers from host data
    func createBuffer(_ data: [Float], shape: [Int]) -> Buffer
    func createBuffer<T: Numeric>(_ data: [T], shape: [Int], elementType: ElementType) throws -> Buffer
}
```
CompilationConfig
```swift
public struct CompilationConfig: Sendable {
    var optimizationLevel: OptimizationLevel  // .O0, .O1, .O2, .O3
    var devicePolicy: DevicePolicy            // .gpuOnly, .auto
    var enableCaching: Bool
    var generateDebugInfo: Bool

    // Presets
    static let `default`: CompilationConfig  // O2, GPU only
    static let debug: CompilationConfig      // O0, no cache, debug info
    static let release: CompilationConfig    // O3, caching
    static let fast: CompilationConfig       // O1, caching
}
```
Executable
```swift
public final class Executable: @unchecked Sendable {
    var inputCount: Int { get }
    var outputCount: Int { get }
    var inputTypes: [TensorType] { get }
    var outputTypes: [TensorType] { get }

    func execute(_ inputs: [Buffer]) throws -> [Buffer]
    func executeWithTiming(_ inputs: [Buffer]) throws -> ([Buffer], ExecutionTiming)
}
```
Buffer
```swift
public final class Buffer: @unchecked Sendable {
    var shape: [Int] { get }
    var count: Int { get }
    var elementType: ElementType { get }

    func toFloatArray() throws -> [Float]
    func toInt32Array() throws -> [Int32]
    func toData() throws -> Data
}
```
C API
```c
// Version
const char* mhlo_version(void);

// Client
MHLOStatusCode mhlo_client_create(MHLOClientRef* out_client);
void mhlo_client_destroy(MHLOClientRef client);
char* mhlo_client_device_name(MHLOClientRef client);

// Compilation
MHLOStatusCode mhlo_compile(MHLOClientRef client, const char* mlir, MHLOExecutableRef* out);
MHLOStatusCode mhlo_compile_with_config(MHLOClientRef client, const char* mlir,
                                        const MHLOCompileConfig* config, MHLOExecutableRef* out);
void mhlo_compile_config_init(MHLOCompileConfig* config);
void mhlo_executable_destroy(MHLOExecutableRef exe);

// Execution
MHLOStatusCode mhlo_execute(MHLOExecutableRef exe, const MHLOBufferRef* inputs, int32_t num_inputs,
                            MHLOBufferRef* outputs, int32_t* num_outputs);
MHLOStatusCode mhlo_executable_get_stats(MHLOExecutableRef exe, MHLOExecutionStats* stats);
void mhlo_executable_reset_stats(MHLOExecutableRef exe);

// Buffers
MHLOStatusCode mhlo_buffer_create(MHLOClientRef client, const void* data, size_t size,
                                  const int64_t* shape, int32_t rank, MHLOElementType type,
                                  MHLOBufferRef* out);
void mhlo_buffer_destroy(MHLOBufferRef buffer);
MHLOStatusCode mhlo_buffer_to_host(MHLOBufferRef buffer, void* out, size_t size);

// Utilities
const char* mhlo_get_last_error(void);
void mhlo_free_string(const char* str);
```
PJRT Plugin (JAX Integration)
MetalHLO implements the PJRT C API (v0.90), the standard device plugin interface for the OpenXLA ecosystem. This enables JAX, TensorFlow, and PyTorch/XLA to execute StableHLO programs on Apple Silicon GPUs via MetalHLO.
How It Works
The plugin is a dynamic library (libPJRTMetalHLO.dylib) that exports a GetPjrtApi() symbol returning a fully populated PJRT_Api vtable. JAX discovers and loads this plugin automatically via Python's entry-point mechanism.
Implemented PJRT Functions
| Category | Functions |
|----------|-----------|
| Error | Error_Destroy, Error_Message, Error_GetCode |
| Client | Client_Create, Client_Destroy, Client_PlatformName, Client_Devices, Client_AddressableDevices, Client_Compile, Client_BufferFromHostBuffer |
| Executable | LoadedExecutable_Execute, LoadedExecutable_Destroy, LoadedExecutable_GetExecutable, Executable_Name, Executable_NumOutputs, Executable_Serialize, Executable_Destroy, Executable_Fingerprint, Executable_OutputElementTypes, Executable_OutputDimensions, Executable_OutputMemoryKinds, Executable_DeserializeAndLoad |
| Buffer | Buffer_ToHostBuffer, Buffer_Destroy, Buffer_ElementType, Buffer_Dimensions, Buffer_OnDeviceSizeInBytes, Buffer_Device, Buffer_Memory, Buffer_ReadyEvent, Buffer_CopyToDevice |
| Device | Device_GetDescription, DeviceDescription_Id, DeviceDescription_ProcessIndex, DeviceDescription_Attributes |
| Memory | Device_AddressableMemories, Device_DefaultMemory, Memory_Id, Memory_Kind, Memory_Kind_Id |
| Event | Event_Await, Event_OnReady, Event_Destroy, Event_IsReady |
Building the Plugin
```shell
# Build the dynamic library
swift build -c release --product PJRTMetalHLO

# The dylib is at .build/release/libPJRTMetalHLO.dylib
```
Using with JAX
Install the Python package that registers MetalHLO as a JAX backend:
```shell
# From the repository root
pip install -e python/
```
Then use JAX as usual — MetalHLO will be available as a backend:
```python
import jax
import jax.numpy as jnp

# Check available backends
print(jax.devices())  # Should include MetalHLO device

# Run computations on MetalHLO
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([5.0, 6.0, 7.0, 8.0])
result = x + y  # Executes on Apple GPU via MetalHLO
```
You can also set the METALHLO_PLUGIN_PATH environment variable to point to the dylib if it's not in the default build locations:
```shell
export METALHLO_PLUGIN_PATH=/path/to/libPJRTMetalHLO.dylib
```
Plugin Capabilities
- Platform name: metalhlo
- Memory model: Unified memory (CPU/GPU shared, zero-copy transfers)
- Compilation: Uses O2 optimization by default (pattern fusion, CSE, algebraic simplification)
- Serialization: Full executable serialize/deserialize support for caching compiled programs
- Buffer management: Host-to-device and device-to-host transfers with proper event synchronization
- Device attributes: Reports Metal GPU family, recommended max working set size, and unified memory architecture
Current Limitations
- Single-device execution only (no multi-GPU or distributed)
- Static shapes required (no dynamic shape inference at the PJRT level)
- Token and tuple operations are not supported
- Some JAX operations may require additional PJRT functions not yet implemented
Architecture
```
   JAX / XLA           C/C++ Projects         Swift Projects
       │                     │                      │
       ▼                     ▼                      ▼
┌──────────────┐      ┌─────────────┐     ┌───────────────────┐
│ PJRT Plugin  │      │    C API    │     │     Swift API     │
│ GetPjrtApi() │      │   mhlo_*    │     │  MetalHLO.Client  │
└──────┬───────┘      └──────┬──────┘     └─────────┬─────────┘
       │                     │                      │
       └─────────────────────┴──────────────────────┘
                             ▼
        ┌───────────────────────────────────────────┐
        │               MetalHLOCore                │
        │                                           │
        │  ┌───────────┐     ┌───────────────────┐  │
        │  │  Parser   │  →  │     Optimizer     │  │
        │  └───────────┘     │   (PassManager)   │  │
        │                    └─────────┬─────────┘  │
        │                              │            │
        │        ┌─────────────────────┼─────────┐  │
        │        ▼                     ▼         ▼  │
        │ ┌──────────┐  ┌──────────┐  ┌────────────┐│
        │ │ MPSGraph │  │  Metal   │  │ Heterogen. ││
        │ │ Backend  │  │  Kernel  │  │ GPU+ANE+CPU││
        │ │ (default)│  │ (O0-O3)  │  │   (auto)   ││
        │ └──────────┘  └──────────┘  └────────────┘│
        └───────────────────────────────────────────┘
               │              │              │
               ▼              ▼              ▼
        ┌──────────────────────────────────────────┐
        │    Apple Metal / MPSGraph / ANE / CPU    │
        └──────────────────────────────────────────┘
```
Examples
Two-Layer MLP
```swift
let mlir = """
module @mlp {
  func.func @main(%x: tensor<32x784xf32>, %w1: tensor<784x256xf32>, %b1: tensor<256xf32>,
                  %w2: tensor<256x10xf32>, %b2: tensor<10xf32>) -> (tensor<32x10xf32>) {
    // Layer 1: y1 = relu(x @ w1 + b1)
    %0 = stablehlo.dot %x, %w1 : (tensor<32x784xf32>, tensor<784x256xf32>) -> tensor<32x256xf32>
    %1 = stablehlo.broadcast_in_dim %b1, dims = [1] : (tensor<256xf32>) -> tensor<32x256xf32>
    %2 = stablehlo.add %0, %1 : tensor<32x256xf32>
    %zero1 = stablehlo.constant dense<0.0> : tensor<32x256xf32>
    %3 = stablehlo.maximum %2, %zero1 : tensor<32x256xf32>
    // Layer 2: y2 = y1 @ w2 + b2
    %4 = stablehlo.dot %3, %w2 : (tensor<32x256xf32>, tensor<256x10xf32>) -> tensor<32x10xf32>
    %5 = stablehlo.broadcast_in_dim %b2, dims = [1] : (tensor<10xf32>) -> tensor<32x10xf32>
    %6 = stablehlo.add %4, %5 : tensor<32x10xf32>
    return %6 : tensor<32x10xf32>
  }
}
"""

// Compile with O3 for production
let config = CompilationConfig.release
let executable = try client.compile(mlir, config: config)
```
Transformer Attention (Auto-Fused)
```swift
// MetalHLO automatically detects and fuses this attention pattern
let mlir = """
module @attention {
  func.func @main(%q: tensor<1x8x512x64xf32>, %k: tensor<1x8x512x64xf32>,
                  %v: tensor<1x8x512x64xf32>) -> (tensor<1x8x512x64xf32>) {
    // Transpose K
    %kt = stablehlo.transpose %k, dims = [0, 1, 3, 2] : ... -> tensor<1x8x64x512xf32>
    // Q @ K^T
    %scores = stablehlo.dot_general %q, %kt, ... -> tensor<1x8x512x512xf32>
    // Scale by 1/sqrt(d_k)
    %scale = stablehlo.constant dense<0.125> : tensor<f32>
    %scaled = stablehlo.multiply %scores, %scale : tensor<1x8x512x512xf32>
    // Softmax
    %max = stablehlo.reduce %scaled, max over [3] : tensor<1x8x512x1xf32>
    %shifted = stablehlo.subtract %scaled, %max : tensor<1x8x512x512xf32>
    %exp = stablehlo.exponential %shifted : tensor<1x8x512x512xf32>
    %sum = stablehlo.reduce %exp, add over [3] : tensor<1x8x512x1xf32>
    %probs = stablehlo.divide %exp, %sum : tensor<1x8x512x512xf32>
    // Attention @ V
    %out = stablehlo.dot_general %probs, %v, ... -> tensor<1x8x512x64xf32>
    return %out : tensor<1x8x512x64xf32>
  }
}
"""

// With O2+, this pattern is detected and fused into a single optimized kernel
let exe = try client.compile(mlir, config: .release)
```
Testing
Quick Start
Run the core unit tests (fast, ~600 tests):
```shell
swift test --filter 'MetalHLOCoreTests'
```
Run the full test suite:
```shell
swift test
```
Recommended: Serial Execution
For the most reliable test runs, use the --no-parallel flag. This prevents Metal/MPSGraph resource conflicts that can occur when multiple GPU operations run concurrently:
```shell
swift test --no-parallel
```
Running Specific Tests
```shell
# By test suite name
swift test --filter "Binary"
swift test --filter "Reduction"
swift test --filter "CAPITests"
swift test --filter "AlgebraicSimplifier"

# PJRT plugin tests
swift test --filter 'PJRTPluginTests'

# Conformance tests
swift test --filter 'OfficialInterpretTests'
swift test --filter 'Optimization'

# Specific operation tests
swift test --filter 'testAdd'
swift test --filter 'scatter'
```
Test Organization
| Target | Description | Test Count |
|--------|-------------|------------|
| MetalHLOCoreTests | Core compiler and optimizer tests | ~586 |
| MetalHLOTests | Integration and conformance tests | ~400+ |
| PJRTMetalHLOTests | PJRT plugin API and execution tests | 10 |
Test Coverage
| Test Suite | Tests |
|------------|-------|
| Lexer & Parser | 25 |
| Binary Operations | 7 |
| Unary Operations | 12 |
| Matrix Operations | 10 |
| Activation Functions | 35 |
| CNN Operations | 10 |
| Control Flow | 8 |
| Custom Call Handlers | 46 |
| Optimizer Passes | 50 |
| C API | 15 |
| Integration Tests | 30 |
| PJRT Plugin | 10 |
| StableHLO Conformance | 191 |
| Total | 449 |
Limitations
Execution Model
- Static Shapes Only — Dynamic shapes require shape inference
- Single Device — No multi-device or distributed execution
- Apple Silicon Only — Intel Macs with AMD GPUs are not supported
- macOS Only — iOS/iPadOS support is a future consideration
Unsupported Types
The following types are not supported due to Metal/MPS limitations:
| Type | Status | Reason |
|------|--------|--------|
| i2, i4, ui2, ui4 | Not supported | Incompatible overflow semantics when promoted |
| complex64, complex128 | Partial | Basic operations work; full arithmetic not supported by MPS |
| f4E2M1FN, f6E2M3FN, etc. | Not supported | Exotic float types not supported by Metal |
| 64-bit integer bitwise | Limited | MPS doesn't support 64-bit integer bitwise operations |
| Distributed/collective ops | Not supported | Single-device focus (all_gather, all_reduce, etc.) |
| Token operations | Not supported | after_all, infeed, outfeed — I/O primitives |
Operation-Specific Type Restrictions
Some operations only support floating-point types in MPSGraph:
| Operation | Supported Types | Excluded Types |
|-----------|-----------------|----------------|
| dot, dot_general | f16, f32, f64, bf16 | All integer types (MPSGraph matmul is float-only) |
| convolution | f16, f32, f64, bf16 | All integer types (MPSGraph conv is float-only) |
| fft | f16, f32, f64, bf16 | All integer types (MPSGraph FFT is float-only) |
| triangular_solve | Not implemented | Requires MPSMatrix bridge |
| cholesky | Not implemented | Requires MPSMatrix bridge |
Operation-Specific Limitations
| Operation | Status | Notes |
|-----------|--------|-------|
| gather | Improved | Works with all optimization levels (O0-O3); embedding lookup and index_vector_dim handling; batching dims supported when at leading positions |
| scatter | Improved | Supports add/max/min/mul computation modes via MPS; batching dims require leading positions |
| convert | Working | Type conversion between numeric types works with all optimization levels |
| slice | Working | Static slice extraction with starts/limits/strides |
| dynamic_slice | Working | Works when slice_sizes equal input dims (start indices clamped) |
| reduce_window | Partial | Works for common patterns; complex regions limited |
| convolution | Partial | Standard patterns work; complex dimension permutations may fail |
| scatter (region form) | Not supported | Parser does not handle the generic-form region syntax ({ ^bb0(...): ... }) required for scatter's update computation body |
| stablehlo.while | Partial | Small loop counts (≤ 1000 iterations) are unrolled inline; large loops fall back to MPSGraph, which may crash on complex loop bodies (see below) |
Scatter computation modes:
```
// Supported via MPS scatter modes
scatter %operand, %indices, %updates, computation = add  // Adds update to existing value
scatter %operand, %indices, %updates, computation = max  // Takes maximum
scatter %operand, %indices, %updates, computation = min  // Takes minimum
scatter %operand, %indices, %updates, computation = mul  // Multiplies existing by update
scatter %operand, %indices, %updates                     // Default: replaces value
```
Batching dimension support: Gather and scatter parse batching dimension attributes (operand_batching_dims, start_indices_batching_dims for gather; input_batching_dims, scatter_indices_batching_dims for scatter). The implementation includes a transpose-gather/scatter-transpose pattern for arbitrary batch dimension positions. It is best tested with batch dimensions at leading positions; complex non-leading configurations may require additional work.
Known Issues Under Investigation
Scatter region parsing: The StableHLO scatter operation requires an inline region body to specify the update computation (e.g., identity/replace or add/accumulate). The MLIR uses a generic form with a region followed by an attribute dictionary:
```mlir
%1 = "stablehlo.scatter"(%input, %indices, %updates) ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
  stablehlo.return %arg1 : tensor<f32>
}) { scatter_dimension_numbers = ... }
```
The parser does not currently handle this ({ ... }) region syntax. We are working on extending the parser to support generic-form operations with region bodies.
Large stablehlo.while loops (jax.lax.scan): Programs that use jax.lax.scan with large iteration counts (e.g., 20,000 time steps in a simulation) compile to stablehlo.while loops. MetalHLO handles these in two ways:
- Small loops (≤ 1000 iterations) are unrolled inline into a single flat kernel. This works but produces very large Metal shaders.
- Large loops (> 1000 iterations) fall back to the MPSGraph backend, which has a native while-loop executor. However, the MPSGraph while-loop handler currently crashes on complex loop bodies that include multiple operations (collision, forcing, streaming, boundary conditions).
We are investigating two approaches to resolve this:
- Fix MPSGraph while-loop compilation for complex multi-operation loop bodies.
- Implement a Metal-native loop executor that dispatches the loop body kernel repeatedly from the CPU/command buffer level, avoiding the need to unroll into a single shader or rely on MPSGraph's while-loop support.
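The current dispatch logic reduces to a trip-count threshold check. A minimal sketch (illustrative Python; the function name is hypothetical, the 1000-iteration threshold is the one stated above):

```python
UNROLL_THRESHOLD = 1000  # loops at or below this trip count are unrolled inline

def lower_while(trip_count: int) -> str:
    """Pick a lowering strategy for a stablehlo.while with a static trip count.

    Hypothetical helper mirroring the two paths described above: small loops
    become one flat (but very large) Metal shader; large loops fall back to
    MPSGraph's native while-loop executor.
    """
    if trip_count <= UNROLL_THRESHOLD:
        return "unroll-inline"   # single flat Metal kernel
    return "mpsgraph-while"      # native loop executor; may crash on complex bodies

print(lower_while(500))     # unroll-inline
print(lower_while(20_000))  # mpsgraph-while
```

The proposed Metal-native loop executor would add a third path: keep the loop body as one compiled kernel and re-dispatch it from the CPU each iteration, so neither unrolling nor MPSGraph's while support is needed.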
StableHLO Conformance
MetalHLO includes a conformance test suite based on the official StableHLO interpreter tests.
Conformance Results
| Metric | Count |
|--------|-------|
| Total Tests in Suite | 277 |
| Tests Run | 191 |
| Passed | 191 |
| Failed | 0 |
| Skipped | 86 (MPS/Metal limitations) |
Running Conformance Tests
```bash
# Run all conformance tests (recommended: use --no-parallel for stability)
swift test --filter "Official" --no-parallel

# Run specific operation tests
swift test --filter "OfficialInterpretTests/testAdd"
swift test --filter "OfficialInterpretTests/testTranspose"
```

Type Promotion
To maximize compatibility, MetalHLO automatically promotes unsupported types to supported ones:
| Source Type | Promoted To | Reason |
|-------------|-------------|--------|
| f64 | f32 | MPS doesn't support f64 |
| bf16 | f32 | Better precision for comparison |
| f8E3M4, f8E4M3, f8E5M2 | f16 | Exotic float types |
| i8, i16, i32, i64 | f32 | Integer arithmetic via float |
| ui8, ui16, ui32, ui64 | f32 | Unsigned integers via float |
This allows tests written for other backends to run correctly on Metal.
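The promotion rule amounts to a lookup with a pass-through default for already-supported types. A minimal sketch (illustrative Python; the function name is hypothetical, the mapping mirrors the table):

```python
# Type-promotion table from above; unlisted types pass through unchanged.
PROMOTIONS = {
    "f64": "f32", "bf16": "f32",
    "f8E3M4": "f16", "f8E4M3": "f16", "f8E5M2": "f16",
    "i8": "f32", "i16": "f32", "i32": "f32", "i64": "f32",
    "ui8": "f32", "ui16": "f32", "ui32": "f32", "ui64": "f32",
}

def promote(dtype: str) -> str:
    """Map an element type that Metal/MPS cannot execute to a supported one."""
    return PROMOTIONS.get(dtype, dtype)

print(promote("f64"))     # f32
print(promote("f8E5M2"))  # f16
print(promote("f32"))     # f32 (already supported)
```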
Skipped Tests (Fundamental Limitations)
The following test categories are skipped due to fundamental MPS/Metal limitations:
| Category | Reason |
|----------|--------|
| Complex types (complex64, complex128) | MPS doesn't support complex arithmetic |
| Small integers (i2, i4, ui2, ui4) | Incompatible overflow semantics when promoted |
| Integer matmul (dot, dot_general with int types) | MPSGraph matmul only supports floating-point |
| Integer convolution | MPSGraph convolution only supports floating-point |
| Integer FFT | MPSGraph FFT only supports floating-point |
| Integer overflow tests | Float arithmetic doesn't wrap on overflow like integers |
| Integer division tests | Float division doesn't truncate like integer division |
| Large integer constants | Values like INT64_MAX can't be exactly represented in Float32 |
| Exotic float types (f4E2M1FN, etc.) | Not supported by Metal |
| Unsigned integer bitwise ops | MPS treats all integers as signed |
| 64-bit integer bitwise ops | MPS doesn't support 64-bit integer bitwise operations |
Performance Tips
- Use O3 for production — Aggressive fusion significantly reduces memory bandwidth
- Batch your inputs — Larger batch sizes amortize kernel launch overhead
- Reuse executables — Compile once, execute many times
- Enable caching — Repeated compilations of the same MLIR return cached results
- Profile with timing — Use `executeWithTiming()` to identify bottlenecks
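The "reuse executables" and "enable caching" tips both follow from the same arithmetic: with a one-time compile cost C and per-run cost r, the amortized cost per run is C/n + r, so repeated runs make compilation essentially free. A minimal sketch of an MLIR-keyed compilation cache of the kind described above (illustrative Python; all names are hypothetical):

```python
import hashlib

_cache: dict[str, object] = {}

def compile_cached(mlir_text: str, compile_fn):
    """Return a cached executable for identical MLIR text, compiling at most once."""
    key = hashlib.sha256(mlir_text.encode()).hexdigest()
    if key not in _cache:
        _cache[key] = compile_fn(mlir_text)  # expensive: done once per unique program
    return _cache[key]

calls = []
exe1 = compile_cached("module { ... }", lambda s: calls.append(s) or "exe")
exe2 = compile_cached("module { ... }", lambda s: calls.append(s) or "exe")
assert exe1 == exe2 and len(calls) == 1  # second lookup hits the cache
```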
Benchmarks
MetalHLO includes a comprehensive benchmarking framework with multi-backend comparison across four execution paths.
### Multi-Backend Performance Comparison
All results measured on **Apple M1 (8 GB)**, macOS 15.6, release build, quick mode (3 warmup, 10 measurements).
Times are mean in milliseconds. **Bold** indicates the fastest backend for each benchmark.
#### Matrix Operations
| Benchmark | Description | MPSGraph | Metal O2 | Metal O3 | GPU+ANE | Best vs MPSGraph |
|-----------|-------------|----------|----------|----------|---------|-------------------|
| MAT-DOT-001 | GEMM 128x128 | 0.30 | **0.17** | 0.24 | 0.25 | 1.72x (O2) |
| MAT-DOT-002 | GEMM 512x512 | 1.03 | **0.77** | 1.71 | 0.85 | 1.34x (O2) |
| MAT-DOT-003 | GEMM 1024x1024 | **4.17** | 5.10 | 5.32 | 4.53 | MPSGraph best |
| MAT-DOT-004 | GEMM 2048x2048 | **19.55** | 23.00 | 22.63 | 22.17 | MPSGraph best |
| MAT-DOT-005 | GEMM 4096x4096 | **122.7** | 226.1 | 227.6 | 234.4 | MPSGraph best |
| MAT-DOT-006 | Transformer 32x4096x768 | 2.36 | 1.22 | 1.17 | **1.10** | 2.14x (ANE) |
| MAT-DOT-007 | MLP 128x768x3072 | 3.11 | 1.90 | 2.06 | **1.63** | 1.91x (ANE) |
| MAT-DOT-008 | Matvec 1x4096x4096 | **3.93** | 10.91 | 11.68 | 10.94 | MPSGraph best |
| MAT-BATCH-001 | Batched 8x512x512 | 6.10 | 5.41 | **4.69** | 5.58 | 1.30x (O3) |
| MAT-BATCH-002 | Batched 4x1024x1024 | 4.59 | **2.85** | 3.03 | 2.99 | 1.61x (O2) |
| MAT-BATCH-003 | Attention heads | 1.61 | 1.12 | **0.67** | 1.28 | 2.42x (O3) |
| MAT-BATCH-004 | Multi-head attention | 0.81 | 0.38 | 0.38 | **0.36** | 2.29x (ANE) |
| MAT-TR-001 | Transpose 1024x1024 | 1.45 | 0.84 | **0.50** | 1.27 | 2.87x (O3) |
| MAT-TR-002 | Transpose 3D 32x128x64 | 0.49 | 0.65 | **0.25** | 0.56 | 1.99x (O3) |
| MAT-RSH-001 | Reshape flatten 1024x1024 | 1.34 | 0.53 | **0.48** | 0.57 | 2.80x (O3) |
| MAT-RSH-002 | Reshape batch 32x64x128 | 0.48 | **0.41** | 0.46 | 0.42 | 1.16x (O2) |
**Takeaway:** MPSGraph wins on large GEMMs (≥1024x1024) and matvec where Apple's tuned kernels dominate. Metal O3 wins on batched operations, transpose, and reshape (1.3-2.9x faster). O2 wins on small-to-mid GEMMs. ANE excels on transformer-shaped and MLP matmuls.
#### Element-wise Arithmetic
| Benchmark | Description | MPSGraph | Metal O2 | Metal O3 | GPU+ANE | Best vs MPSGraph |
|-----------|-------------|----------|----------|----------|---------|-------------------|
| ARITH-B-001 | Add 1024x1024 | 2.52 | 0.96 | **0.75** | 0.75 | 3.35x (ANE) |
| ARITH-B-002 | Add 4096x4096 | 12.20 | 10.30 | 16.43 | **8.75** | 1.40x (ANE) |
| ARITH-B-003 | Add 8192x8192 | **57.7** | 169.1 | 138.5 | 148.1 | MPSGraph best |
| ARITH-B-004 | Mul 1024x1024 | 2.66 | **0.79** | 0.83 | 0.88 | 3.38x (O2) |
| ARITH-B-005 | Mul 4096x4096 | 11.91 | **10.12** | 10.43 | 13.97 | 1.18x (O2) |
| ARITH-B-006 | Div 1024x1024 | 2.16 | 0.97 | **0.80** | 1.15 | 2.69x (O3) |
| ARITH-B-007 | Pow 1024x1024 | 2.52 | **0.97** | 0.98 | 1.03 | 2.59x (O2) |
| ARITH-B-008 | Max 4096x4096 | 11.85 | 12.46 | 9.42 | **8.32** | 1.42x (ANE) |
| ARITH-U-001 | Exp 1024x1024 | 1.23 | 0.61 | **0.57** | 0.85 | 2.14x (O3) |
| ARITH-U-002 | Log 4096x4096 | 8.07 | 5.01 | **4.90** | 5.08 | 1.65x (O3) |
| ARITH-U-003 | Tanh 1024x1024 | 1.54 | **0.52** | 0.52 | 0.81 | 2.96x (O3) |
| ARITH-U-004 | Sqrt 4096x4096 | 8.02 | **4.61** | 4.79 | 5.30 | 1.74x (O2) |
| ARITH-U-005 | Rsqrt 4096x4096 | 8.71 | 5.65 | 5.74 | **5.57** | 1.56x (ANE) |
| ARITH-U-006 | Sigmoid 1024x1024 | 1.34 | 0.70 | **0.70** | 0.79 | 1.93x (O3) |
| ARITH-BC-001 | Add row broadcast | 1.35 | 1.34 | **0.93** | 1.10 | 1.45x (O3) |
| ARITH-BC-002 | Add scalar broadcast | 1.29 | 1.00 | 0.99 | **0.81** | 1.58x (ANE) |
| ARITH-BC-003 | Mul last-dim broadcast | 0.48 | **0.39** | 0.44 | 0.54 | 1.23x (O2) |
**Takeaway:** Metal backends are **1.2-3.4x faster** than MPSGraph on element-wise operations. The advantage is largest on 1024x1024 tensors (2.1-3.4x); on 4096x4096 tensors MPSGraph has improved significantly (macOS 15.6), narrowing the gap to 1.2-1.7x. O2, O3, and ANE all trade wins.
#### Reduction Operations
| Benchmark | Description | MPSGraph | Metal O2 | Metal O3 | GPU+ANE | Best vs MPSGraph |
|-----------|-------------|----------|----------|----------|---------|-------------------|
| RED-001 | Global sum 1024x1024 | 0.91 | 0.96 | 0.89 | **0.86** | 1.06x (ANE) |
| RED-002 | Row-wise sum 1024x1024 | 1.03 | **0.87** | 1.09 | 0.95 | 1.18x (O2) |
| RED-003 | Column-wise sum 1024x1024 | 0.90 | **0.82** | 0.83 | 0.84 | 1.10x (O2) |
| RED-004 | Row-wise max 4096x4096 | **4.21** | 6.25 | 5.91 | 6.17 | MPSGraph best |
| RED-005 | LayerNorm reduction 32x128x768 | 2.72 | 1.29 | **0.95** | 1.09 | 2.86x (O3) |
| RED-006 | Attention reduction 32x12x512x512 | **22.49** | 47.37 | 46.81 | 37.49 | MPSGraph best |
**Takeaway:** Metal backends win on 1024x1024 reductions (1.1-1.2x). O3 excels on LayerNorm reduction (2.9x). On larger reductions (4096x4096 and attention-shaped), MPSGraph wins thanks to macOS 15.6 improvements.
#### Convolution
| Benchmark | Description | MPSGraph | Metal O2 | Metal O3 | GPU+ANE | Best vs MPSGraph |
|-----------|-------------|----------|----------|----------|---------|-------------------|
| CONV-001 | ResNet first layer | 1.92 | **1.91** | 2.24 | 2.78 | ~1.00x (O2) |
| CONV-002 | ResNet stage2 3x3 | **1.16** | 3.36 | 3.09 | 1.16 | MPSGraph best |
| CONV-003 | ResNet stage3 3x3 | **1.06** | 3.64 | 3.15 | 1.36 | MPSGraph best |
| CONV-004 | ResNet stage4 3x3 | **1.73** | 3.45 | 2.77 | 2.01 | MPSGraph best |
| CONV-005 | Batched conv | **11.79** | 34.94 | 34.42 | 20.22 | MPSGraph best |
| CONV-006 | 1x1 pointwise | 0.83 | 0.71 | **0.68** | 1.75 | 1.22x (O3) |
| CONV-007 | Depthwise-like | 1.61 | 4.88 | 5.24 | **1.54** | 1.05x (ANE) |
**Takeaway:** MPSGraph dominates convolutions thanks to Apple's highly optimized `MPSCNNConvolution` kernels. Metal O3 wins on 1x1 pointwise convolutions. ANE occasionally beats MPSGraph on depthwise patterns.
#### Normalization
| Benchmark | Description | MPSGraph | Metal O2 | Metal O3 | GPU+ANE | Best vs MPSGraph |
|-----------|-------------|----------|----------|----------|---------|-------------------|
| NORM-BN-001 | ResNet BN | 0.42 | **0.24** | 0.28 | 0.39 | 1.76x (O2) |
| NORM-BN-002 | Batched ResNet BN | 8.06 | 3.83 | **2.43** | FAIL | 3.32x (O3) |
| NORM-BN-003 | Mid-layer BN | 0.35 | **0.18** | 0.21 | 0.19 | 1.90x (O2) |
| NORM-BN-004 | Late-layer BN | 0.29 | **0.17** | 0.38 | 0.19 | 1.70x (O2) |
| NORM-LN-001 | BERT-base LayerNorm | 0.57 | **0.52** | 0.56 | 0.59 | 1.09x (O2) |
| NORM-LN-002 | BERT-base batched LN | **4.42** | 6.17 | 6.87 | 6.46 | MPSGraph best |
| NORM-LN-003 | BERT-large single LN | **1.36** | 1.77 | 2.17 | 1.63 | MPSGraph best |
| NORM-LN-004 | Long sequence LN | **15.67** | 20.90 | 21.10 | 23.02 | MPSGraph best |
**Takeaway:** Metal O2 excels at batch normalization (1.7-1.9x faster). O3 wins on large batched BN (3.3x). MPSGraph wins on layer normalization, where its native implementation is well-optimized.
#### Transformer Components
| Benchmark | Description | MPSGraph | Metal O2 | Metal O3 | GPU+ANE | Best vs MPSGraph |
|-----------|-------------|----------|----------|----------|---------|-------------------|
| XFMR-INF-001 | Self-attention seq=128 | 2.76 | 2.51 | **2.19** | 2.41 | 1.26x (O3) |
| XFMR-INF-002 | Self-attention seq=512 | **7.61** | 7.85 | 8.28 | 7.82 | MPSGraph best |
| XFMR-INF-003 | Self-attention BS=8 seq=128 | **8.80** | 9.45 | 9.43 | 9.22 | MPSGraph best |
| XFMR-INF-004 | Transformer FFN BS=8 | **11.27** | 18.24 | 16.71 | 16.02 | MPSGraph best |
| XFMR-INF-005 | Softmax 8x12x128x128 | 2.53 | 2.34 | 2.29 | **2.29** | 1.11x (ANE) |
| XFMR-INF-006 | Encoder block BS=1 seq=128 | 7.71 | **7.46** | 8.48 | 7.84 | 1.03x (O2) |
**Takeaway:** MPSGraph wins on 3 of 6 transformer workloads (larger sequence lengths, FFN, batched attention). O3 excels on self-attention (1.3x). ANE is competitive on softmax.
#### MLP Inference
| Benchmark | Description | MPSGraph | Metal O2 | Metal O3 | GPU+ANE | Best vs MPSGraph |
|-----------|-------------|----------|----------|----------|---------|-------------------|
| MLP-INF-001 | 784->256->10 BS=1 | 0.64 | **0.45** | 0.45 | 0.80 | 1.43x (O2) |
| MLP-INF-002 | 784->256->10 BS=32 | **0.41** | 0.43 | 0.44 | 0.60 | MPSGraph best |
| MLP-INF-003 | 784->256->10 BS=128 | 0.60 | 0.72 | 0.47 | **0.45** | 1.35x (ANE) |
| MLP-INF-004 | Deep MLP 4-layer BS=32 | 1.05 | 0.91 | 0.80 | **0.75** | 1.40x (ANE) |
| MLP-INF-005 | FFN 768->3072->768 BS=32 | 4.35 | 1.77 | **1.58** | 1.84 | 2.75x (O3) |
**Takeaway:** Metal backends are 1.4-2.8x faster than MPSGraph on MLPs. O3 wins the large FFN (2.8x). ANE excels on deep and batched MLPs. MPSGraph is fastest only on small batch sizes.
#### Training
| Benchmark | Description | MPSGraph | Metal O2 | Metal O3 | GPU+ANE | Best vs MPSGraph |
|-----------|-------------|----------|----------|----------|---------|-------------------|
| TRAIN-001 | MLP fwd+bwd BS=32 | **0.69** | 0.91 | 0.90 | 1.00 | ~1.00x |
| TRAIN-003 | Attention fwd+bwd BS=8 | **7.25** | 10.29 | 10.82 | 11.12 | MPSGraph best |
**Takeaway:** MPSGraph leads on training workloads where its backward-pass graph optimization provides an advantage.
### Backend Win Summary
Overall wins across all 67 passing benchmarks:
| Backend | Wins | Best For |
|---------|------|----------|
| **MPSGraph** | 20 | Large GEMMs, convolutions, layer norm, large reductions, training |
| **Metal O2** | 18 | Batch norm, small matmuls, element-wise ops |
| **Metal O3** | 16 | Transpose, reshape, batched ops, softmax, FFN fusion |
| **GPU+ANE** | 13 | Element-wise, MLP inference, transformer-shaped matmuls |
### When to Use Each Backend
| Use Case | Recommended Backend | Why |
|----------|-------------------|-----|
| Large matrix multiply (≥1024x1024) | MPSGraph (default) | Apple's tuned `MPSMatrixMultiplication` |
| Convolution-heavy models (CNNs) | MPSGraph (default) | `MPSCNNConvolution` is highly optimized |
| Training (forward + backward) | MPSGraph (default) | Graph-level backward pass optimization |
| Element-wise heavy workloads | Metal O2/O3 | 1.2-3.4x faster than MPSGraph |
| Batch normalization | Metal O2 | 1.7-1.9x faster custom kernels |
| Transpose / reshape / batched ops | Metal O3 | Pattern fusion + scheduling (2-3x faster) |
| MLP / FFN inference | Metal O3 or ANE | O3 wins large FFN (2.8x), ANE wins deep MLPs |
| Debugging/development | MPSGraph (default) | Broadest compatibility, no compilation |
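The recommendations above condense into a simple dispatch heuristic (illustrative Python sketch; the workload labels and function name are hypothetical, and the choices mirror the table):

```python
def recommend_backend(workload: str) -> str:
    """Map a workload class to the benchmark-backed backend choice above."""
    metal_wins = {
        "elementwise": "Metal O2/O3",            # 1.2-3.4x over MPSGraph
        "batchnorm": "Metal O2",                 # 1.7-1.9x faster custom kernels
        "transpose_reshape_batched": "Metal O3", # pattern fusion + scheduling
        "mlp_ffn_inference": "Metal O3 or ANE",  # O3 for large FFN, ANE for deep MLPs
    }
    # Everything else (large GEMM, convolutions, training, debugging)
    # stays on the MPSGraph default.
    return metal_wins.get(workload, "MPSGraph")

print(recommend_backend("elementwise"))  # Metal O2/O3
print(recommend_backend("large_gemm"))   # MPSGraph
```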
### Benchmark Categories
| Category | Benchmarks | Description |
|----------|------------|-------------|
| **Matrix Operations** | MAT-DOT-001 to MAT-DOT-008, MAT-BATCH-001 to MAT-BATCH-004, MAT-TR-*, MAT-RSH-* | GEMM, batched GEMM, transpose, reshape |
| **Arithmetic** | ARITH-B-001 to ARITH-B-008, ARITH-BC-*, ARITH-U-* | Binary, broadcast, and unary operations |
| **Reduction** | RED-001 to RED-009 | Sum, max, mean, pooling operations |
| **Convolution** | CONV-001 to CONV-007 | Standard conv2d patterns (ResNet, VGG) |
| **Normalization** | NORM-BN-*, NORM-LN-* | Batch norm, layer norm |
| **Control Flow** | CF-001 to CF-005 | While loops, conditionals |
| **Indexing** | IDX-001 to IDX-007 | Slice, gather, scatter, pad |
| **Model Inference** | MLP-INF-*, CNN-INF-*, XFMR-INF-* | MLP, CNN, Transformer components |
| **Training** | TRAIN-001 to TRAIN-003 | Forward + backward pass benchmarks |
| **Compiler Analysis** | COMP-001 to COMP-005 | Compilation time for various program sizes |
| **Fusion Analysis** | FUSION-001 to FUSION-004 | Fused vs naive execution comparison |
| **Memory** | MEM-001 to MEM-003 | Peak allocation, buffer reuse |
| **Power Efficiency** | PWR-001 to PWR-003 | Throughput per watt estimates |
### Running Benchmarks
```bash
# Build in release mode for accurate measurements
swift build -c release --product benchmark-runner
# Multi-backend comparison (recommended)
.build/release/benchmark-runner --compare -q -c matrix # Matrix operations
.build/release/benchmark-runner --compare -q -c arithmetic # Element-wise ops
.build/release/benchmark-runner --compare -q -c model_transformer # Transformer
.build/release/benchmark-runner --compare -q # All categories
# Single-backend benchmarks
.build/release/benchmark-runner --category matrix
.build/release/benchmark-runner --all
# MLX comparison (requires MLX)
swift build -c release --product mlx-comparison
.build/release/mlx-comparison --quick
.build/release/mlx-comparison --category matrix
# Heterogeneous fusion benchmarks (GPU+MPS+CPU partitioning)
swift build -c release --product HeterogeneousFusionBenchmark
.build/release/HeterogeneousFusionBenchmark
# GPT-2 end-to-end validation (profitability guard + crossover analysis)
swift build -c release --product GPT2EndToEnd
.build/release/GPT2EndToEnd
# Hardware calibration test
swift build -c release --product CalibrateTest
.build/release/CalibrateTest
```
### Benchmark Framework Features
The benchmark framework provides:
- **Timing Statistics**: Mean, std dev, min, max, p95, p99
- **GPU Synchronization**: Accurate timing with Metal command buffer completion
- **Warmup Support**: Configurable warmup iterations to avoid cold-start effects
- **Reproducibility**: Seeded random data generators
- **Multiple Output Formats**: Console, JSON, CSV
```swift
// Running benchmarks programmatically
import MetalHLOBenchmarks
let config = BenchmarkConfig(warmupIterations: 10, measurementIterations: 50)
let runner = try BenchmarkRunner(config: config)
// Run all matrix benchmarks
let results = try runner.run(OperationBenchmarks.matrixBenchmarks())
// Access timing statistics
for result in results {
print("\(result.id): \(result.timing.mean * 1000)ms ± \(result.timing.stdDev * 1000)ms")
}
```
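The reported figures are standard sample statistics over the measured iterations. A minimal sketch of how mean, standard deviation, and nearest-rank percentiles can be derived from raw timings (illustrative Python, not MetalHLO's implementation):

```python
import statistics

def timing_stats(samples: list[float]) -> dict[str, float]:
    """Summarize raw per-iteration timings (seconds)."""
    s = sorted(samples)

    def pct(p: float) -> float:
        # Nearest-rank percentile: smallest sample covering p% of the data.
        return s[max(0, round(p / 100 * len(s)) - 1)]

    return {
        "mean": statistics.fmean(s),
        "stddev": statistics.stdev(s) if len(s) > 1 else 0.0,
        "min": s[0], "max": s[-1],
        "p95": pct(95), "p99": pct(99),
    }

print(timing_stats([1.0, 2.0, 3.0, 4.0]))  # mean 2.5, min 1.0, max 4.0
```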
### GPU Utilization Metrics
The framework includes utilities for measuring GPU efficiency:
```swift
import MetalHLOBenchmarks
// Detect hardware and calculate utilization
let calculator = GPUMetricsCalculator()
// After running a 1024x1024 matmul in 2.5ms
let metrics = calculator.calculateMatMulMetrics(
m: 1024, n: 1024, k: 1024,
executionTimeSeconds: 0.0025
)
print(metrics.formatted())
// Hardware: Apple M1
// Compute: 0.86 / 2.60 TFLOPS (33.1% utilization)
// Memory BW: 6.7 / 68.25 GB/s (9.8% utilization)
```
### Fusion Effectiveness Analysis
Measure the benefit of operation fusion:
```swift
import MetalHLOBenchmarks
let runner = FusionAnalysisRunner()
let results = try runner.runAll()
for result in results {
print(result.formatted())
// FUSION-001: Fused=1.23ms, Naive=3.45ms, Speedup=2.80x (expected: 2-3x)
}
```

License
Apache License 2.0 - see LICENSE for details.
References
- StableHLO Specification
- OpenXLA/XLA Compiler
- MLX by Apple ML Research
- MPSGraph Documentation
- PJRT C API Specification
- MLIR Language Reference
MetalHLO — Bringing StableHLO to Apple Silicon
Package Metadata
Repository: pedronahum/metalhlo
Default branch: main
README: README.md