Writing GPU Kernels for Apple Silicon with MLX
A comprehensive guide to Metal Shading Language, Apple's GPU architecture, and how MLX lets you write high-performance compute kernels that run natively on M-series chips — no CUDA required.
- Why Apple GPUs Are Different
- The MLX Software Stack
- Metal Shading Language Basics
- Thread Hierarchy & Memory Model
- Your First GPU Kernel
- NVIDIA vs Apple Mapping
- Elementwise & Reduction Kernels
- Matrix Operations & GEMM
- Fused Attention Kernels
- C++ Runtime & Dispatch
- Performance Tuning
- The MLX Engineer's Workflow
Why Apple GPUs Are Different
If you've worked with NVIDIA's CUDA ecosystem, you're accustomed to a clean separation: the CPU lives on one side of a PCIe bus, the GPU on the other, and data constantly ferries between them via expensive DMA transfers. Apple blew that model up.
On M-series chips, the CPU, GPU, and Neural Engine sit on one system-on-chip and share a single pool of LPDDR memory in the same package. There is no separate VRAM pool and no PCIe bus to cross. The GPU reads from the same memory the CPU uses, at high bandwidth (400 GB/s on M1 Max and M2 Max, 800 GB/s on M1 Ultra and M2 Ultra). This single architectural decision changes how you think about data movement when writing kernels.
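Because memory is shared and fast, many ML workloads (LLM decoding in particular) are limited by how quickly weights can be streamed from memory rather than by arithmetic. A back-of-envelope sketch, assuming every weight is read exactly once per generated token and ignoring caches and compute; the 7B-parameter model and the 400 GB/s figure are illustrative inputs, not measurements:

```python
def min_time_seconds(num_bytes: float, bandwidth_gb_per_s: float) -> float:
    """Lower bound on the time needed to stream num_bytes once at the given bandwidth."""
    return num_bytes / (bandwidth_gb_per_s * 1e9)


def max_tokens_per_second(num_params: float, bytes_per_param: float,
                          bandwidth_gb_per_s: float) -> float:
    """Upper bound on decode speed if every weight is read once per token."""
    weight_bytes = num_params * bytes_per_param
    return 1.0 / min_time_seconds(weight_bytes, bandwidth_gb_per_s)


# Illustrative: a 7B-parameter model stored in float16 (2 bytes/param) at 400 GB/s
tps = max_tokens_per_second(7e9, 2, 400)
print(f"{tps:.1f} tokens/s upper bound")
```

Real decode speed is lower, but the bound shows why kernel authors on Apple Silicon obsess over bytes moved.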
Tile-Based Deferred Rendering
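Apple GPUs use a TBDR architecture for graphics: the screen is divided into tiles that are processed entirely on-chip before being written back to memory, minimising bandwidth. Compute dispatches do not get this tiling automatically, but the same principle applies when you tile explicitly through threadgroup memory.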
Unified Memory Architecture
CPU and GPU share the same physical memory. Zero-copy tensor sharing is possible — mx.array lives in memory accessible to both without explicit transfers.
SIMD Groups
Apple's equivalent of CUDA warps. Groups of 32 threads execute in lockstep and can share data via SIMD-group instructions without touching threadgroup memory.
Threadgroup Memory
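On-chip scratchpad local to a threadgroup. Fast, software-managed, analogous to CUDA shared memory. Finite size (typically 32 KB per threadgroup): requesting more than the limit is an error, and large allocations reduce how many threadgroups can be resident on a core at once.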
Apple Silicon integrates the GPU on the same chip as the CPU, sharing its memory fabric. M1 has up to 8 GPU cores; M3 Max reaches 40; M2 Ultra reaches 76. Each GPU core, roughly the Apple counterpart of a streaming multiprocessor (SM) in NVIDIA terms, contains several SIMD units, each 32 lanes wide.
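A core count turns into a peak-throughput estimate with simple arithmetic. This sketch assumes 128 FP32 ALUs per GPU core (four SIMD units of 32 lanes) and counts one fused multiply-add as two FLOPs; the M1 clock of about 1.278 GHz is an assumed input, chosen because it reproduces Apple's quoted figure of roughly 2.6 TFLOPS for the 8-core M1 GPU:

```python
def peak_fp32_tflops(gpu_cores: int, clock_ghz: float,
                     alus_per_core: int = 128) -> float:
    """Peak FP32 throughput in TFLOPS, counting each FMA as 2 floating-point ops."""
    flops = gpu_cores * alus_per_core * 2 * clock_ghz * 1e9
    return flops / 1e12


print(f"{peak_fp32_tflops(8, 1.278):.2f} TFLOPS")
```

Dividing a figure like this by the memory bandwidth gives the arithmetic intensity (FLOPs per byte) a kernel needs before it stops being memory-bound.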
The MLX Software Stack
MLX is Apple's open-source array framework for machine learning on Apple Silicon. It is not a thin wrapper around PyTorch. It is a ground-up reimplementation with lazy evaluation, a functional transformation system, and its own GPU backend built directly on Metal. Understanding the layers helps you know exactly where your kernel code lives.
The Python Layer
Users write high-level array operations. MLX builds a computation graph lazily — nothing executes until a value is actually needed.
```python
import mlx.core as mx

# Lazy: no GPU work yet
a = mx.ones((1024, 1024), dtype=mx.float16)
b = mx.ones((1024, 1024), dtype=mx.float16)
c = mx.matmul(a, b)  # graph node created

# Evaluation triggers GPU dispatch
mx.eval(c)
print(c.shape)  # (1024, 1024)
```
The C++ Runtime
Most of MLX's core is written in C++, with the GPU kernels themselves in Metal. The runtime manages array lifetimes, builds the computation graph, and ultimately calls into Metal to dispatch work to the GPU. When you use mx.compile, it also fuses chains of elementwise operations into single generated kernels.
```cpp
// Simplified version of how MLX defines a primitive operation
class MatMul : public Primitive {
 public:
  MatMul(Stream stream) : Primitive(stream) {}

  // Forward pass: schedules the Metal kernel
  void eval_gpu(
      const std::vector<array>& inputs,
      std::vector<array>& outputs) override;

  // VJP for automatic differentiation
  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;
};
```
Metal API & Pipelines
Metal is Apple's low-level GPU API, analogous to Vulkan Compute or the CUDA Runtime. Kernels are compiled from Metal Shading Language (MSL) source either ahead of time (packaged into a .metallib file) or at runtime. Each kernel function is turned into an MTLComputePipelineState; a compute command encoder then binds that pipeline and its arguments and records the dispatch into an MTLCommandBuffer, which is created from an MTLCommandQueue and committed back to it for execution.
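By default, MLX pre-compiles its built-in Metal kernels into a metallib at build time. When an mx.matmul() node is evaluated, the C++ runtime looks up the pipeline state for the given dtypes (creating and caching it on first use) and encodes the dispatch, so no shader source is compiled on the hot path. The exceptions are kernels generated at runtime, such as the fused kernels produced by mx.compile and custom kernels built with mx.fast.metal_kernel.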
Metal Shading Language Basics
Metal Shading Language (MSL) is a C++14-derived language with Apple-specific extensions for GPU programming. If you know CUDA C++, you'll feel at home within minutes. The core concepts translate almost 1:1, though the syntax differs.
Address Spaces
MSL has explicit address space qualifiers that tell the compiler — and you — where data lives in the memory hierarchy. Understanding these is non-negotiable for writing correct kernels.
- `device`: Global GPU memory (analogous to CUDA's global memory). Largest, slowest, accessible by all threads.
- `threadgroup`: On-chip shared memory local to a threadgroup. Fast, limited size (typically 32 KB per threadgroup).
- `constant`: Read-only data cached for uniform access. Ideal for kernel parameters and lookup tables.
- `thread`: Private to a single thread. Register-class storage and the fastest option, but heavy use raises register pressure and can reduce occupancy.
```metal
kernel void memory_spaces_demo(
    device const float* global_input  [[buffer(0)]],
    device float*       global_output [[buffer(1)]],
    constant float&     scale         [[buffer(2)]],
    threadgroup float*  shared_tile   [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]])
{
    // Thread-private register
    float val = global_input[gid] * scale;

    // Write into on-chip threadgroup memory
    shared_tile[lid] = val;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // All threads in group have written; safe to read
    global_output[gid] = shared_tile[lid];
}
```
Built-In Variables
MSL uses attribute syntax [[...]] to inject hardware-provided values into kernel arguments. These tell each thread exactly where it sits in the execution hierarchy.
- `[[thread_position_in_grid]]`: Absolute position of this thread in the full dispatch grid. The most common index variable.
- `[[thread_position_in_threadgroup]]`: Position within the thread's own threadgroup (local index, like `threadIdx` in CUDA).
- `[[threadgroup_position_in_grid]]`: Which threadgroup this thread belongs to (like `blockIdx`).
- `[[threads_per_threadgroup]]`: The size of the threadgroup in each dimension.
- `[[threads_per_simdgroup]]`: Always 32 on current Apple GPUs.
- `[[simdgroup_index_in_threadgroup]]`: Which SIMD group within the threadgroup.
- `[[thread_index_in_simdgroup]]`: Lane index within the SIMD group (0 to 31).
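For a 1D dispatch whose threadgroup size is a multiple of 32, these built-ins are related by simple arithmetic. The Python sketch below reproduces the relationships on the CPU (the function name is illustrative, not part of any API), which can help when checking index math before writing a kernel:

```python
SIMD_WIDTH = 32  # threads_per_simdgroup on current Apple GPUs


def builtin_indices(gid: int, threads_per_threadgroup: int) -> dict:
    """Derive per-thread built-in values for a 1D dispatch from the grid index."""
    tg = gid // threads_per_threadgroup   # threadgroup_position_in_grid
    lid = gid % threads_per_threadgroup   # thread_position_in_threadgroup
    return {
        "threadgroup_position_in_grid": tg,
        "thread_position_in_threadgroup": lid,
        "simdgroup_index_in_threadgroup": lid // SIMD_WIDTH,
        "thread_index_in_simdgroup": lid % SIMD_WIDTH,
    }


print(builtin_indices(gid=300, threads_per_threadgroup=256))
# {'threadgroup_position_in_grid': 1, 'thread_position_in_threadgroup': 44,
#  'simdgroup_index_in_threadgroup': 1, 'thread_index_in_simdgroup': 12}
```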
Thread Hierarchy & Memory Model
Apple GPUs execute threads in a three-level hierarchy that directly mirrors NVIDIA's grid / block / warp model — with different names and slightly different semantics.
| NVIDIA CUDA | Apple Metal | Description |
|---|---|---|
| Grid | Grid | The entire dispatch space |
| Thread Block | Threadgroup | Cooperative group with shared on-chip memory |
| Warp (32 threads) | SIMD Group (32 threads) | Smallest lockstep execution unit |
| Shared Memory | Threadgroup Memory | On-chip scratchpad local to a threadgroup |
| L1/L2 Cache | GPU Cache | Hardware-managed, transparent |
| Global Memory | Device Memory (Unified) | Main memory pool, shared with CPU |
| Constant Memory | Constant Buffer | Read-only, broadcast-cached |
| __syncthreads() | threadgroup_barrier() | Synchronisation within a group |
| Warp shuffle | SIMD-group functions | Cross-lane data exchange without memory |
SIMD-Group Intrinsics
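One area where MSL stands out is the rich set of SIMD-group operations built directly into the language. Whole-group reductions such as simd_sum are single calls, whereas CUDA code typically builds them from a loop of __shfl_down_sync calls (or uses cooperative groups). These operations exchange data between the 32 lanes of a SIMD group without touching threadgroup memory, saving both bandwidth and barriers.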
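```metal
#include <metal_stdlib>
using namespace metal;

// Writes one partial sum per SIMD group; a second pass (or the host) adds them up.
kernel void simd_reduce_sum(
    device const float* input  [[buffer(0)]],
    device float*       output [[buffer(1)]],
    constant uint&      n      [[buffer(2)]],
    uint gid  [[thread_position_in_grid]],
    uint lane [[thread_index_in_simdgroup]],
    uint sg   [[simdgroup_index_in_threadgroup]],
    uint tg   [[threadgroup_position_in_grid]],
    uint sgs  [[simdgroups_per_threadgroup]])
{
    // Out-of-range threads contribute 0 instead of returning early
    float val = (gid < n) ? input[gid] : 0.0f;

    // Sum across all 32 lanes: no threadgroup memory needed
    val = simd_sum(val);

    // Only lane 0 of each SIMD group writes; the index is unique across threadgroups
    if (lane == 0) {
        output[tg * sgs + sg] = val;
    }
}
```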
Key SIMD-group functions available in MSL:
- `simd_sum(val)`: Reduces all 32 lanes' values to their sum
- `simd_max(val)` / `simd_min(val)`: Max/min reduction across the SIMD group
- `simd_shuffle(val, lane)`: Each lane reads `val` from the lane index it passes in (an arbitrary per-lane gather)
- `simd_shuffle_up(val, delta)` / `simd_shuffle_down(val, delta)`: Shift values by `delta` lanes
- `simd_prefix_inclusive_sum(val)`: Prefix scan across the group (`simd_prefix_exclusive_sum` also exists)
- `simd_broadcast(val, lane)`: Broadcast one lane's value to all lanes
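To see what simd_sum saves you, here is a CPU model of a SIMD-group sum built from simd_shuffle_down: the classic five-step tree (log2 32 = 5) that CUDA code writes by hand with __shfl_down_sync. It simulates the data movement only, under the assumption that the top `delta` lanes keep their own value when the source lane is out of range; only lane 0 is guaranteed to end with the full sum:

```python
SIMD_WIDTH = 32


def shuffle_down(values, delta):
    """Lane i receives values[i + delta]; the top `delta` lanes keep their own value."""
    n = len(values)
    return [values[i + delta] if i + delta < n else values[i] for i in range(n)]


def simd_sum_via_shuffles(values):
    """Tree reduction across one SIMD group; lane 0 ends up holding the total."""
    assert len(values) == SIMD_WIDTH
    vals = list(values)
    delta = SIMD_WIDTH // 2
    while delta > 0:                       # deltas 16, 8, 4, 2, 1
        shifted = shuffle_down(vals, delta)
        vals = [v + s for v, s in zip(vals, shifted)]
        delta //= 2
    return vals[0]


lanes = list(range(32))                    # lane i holds the value i
print(simd_sum_via_shuffles(lanes))        # 496
```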
Your First GPU Kernel
Let's start with the canonical example — element-wise vector addition — and then build progressively more sophisticated kernels. Every kernel author should be able to write this in their sleep.
Element-Wise Addition
```metal
// The 'kernel' qualifier marks this as a compute kernel entry point
// (as opposed to 'vertex' or 'fragment' for graphics)
kernel void vector_add(
    device const float* a   [[buffer(0)]],
    device const float* b   [[buffer(1)]],
    device float*       out [[buffer(2)]],
    uint id [[thread_position_in_grid]])
{
    out[id] = a[id] + b[id];
}

// Half-precision variant: important for ML workloads
kernel void vector_add_half(
    device const half* a   [[buffer(0)]],
    device const half* b   [[buffer(1)]],
    device half*       out [[buffer(2)]],
    uint id [[thread_position_in_grid]])
{
    out[id] = a[id] + b[id];
}
```
Notice the key differences from CUDA immediately:
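- `kernel` instead of `__global__`
- `device` pointer qualifier instead of a bare pointer
- `[[buffer(N)]]` attribute to bind to Metal buffer slots
- `[[thread_position_in_grid]]` instead of `blockIdx.x * blockDim.x + threadIdx.x`
- No explicit `N` parameter, provided you dispatch exactly N threads with `dispatchThreads`: Metal then trims the last threadgroup, so no thread runs past the end of the arrays. If you dispatch whole threadgroups with `dispatchThreadgroups`, pass N and add a bounds check.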
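The difference between the two dispatch styles is easy to quantify. Rounding the grid up to whole threadgroups launches threads past the end of the data, which a kernel without a bounds check would turn into out-of-bounds accesses. A small sketch of the arithmetic (helper names are illustrative):

```python
def threadgroups_needed(n: int, threads_per_threadgroup: int) -> int:
    """Ceiling division: how many whole threadgroups cover n elements."""
    return (n + threads_per_threadgroup - 1) // threads_per_threadgroup


def excess_threads(n: int, threads_per_threadgroup: int) -> int:
    """Threads launched past the end of the data when dispatching whole threadgroups."""
    return threadgroups_needed(n, threads_per_threadgroup) * threads_per_threadgroup - n


n = 1000
print(threadgroups_needed(n, 256), excess_threads(n, 256))  # 4 24
```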
Dispatching from C++
The kernel file is just the compute shader. You still need C++ to compile the pipeline state and dispatch work. Here is a minimal, self-contained example using the raw Metal API — in MLX, the framework handles this for you, but understanding it is essential for writing custom ops.
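```cpp
// C++ with the metal-cpp headers (one .cpp file must also define
// NS_PRIVATE_IMPLEMENTATION and MTL_PRIVATE_IMPLEMENTATION before including them)
#include <Metal/Metal.hpp>
#include <cstring>

void launch_vector_add(
    MTL::Device* device,
    const float* a_host, const float* b_host, float* out_host,
    size_t n)
{
    const size_t bytes = n * sizeof(float);

    // 1. Create shared-storage buffers visible to both CPU and GPU.
    //    newBuffer(ptr, len, opts) copies the host data once; avoiding even that
    //    copy requires the bytesNoCopy variant and page-aligned allocations.
    MTL::Buffer* a_buf   = device->newBuffer(a_host, bytes, MTL::ResourceStorageModeShared);
    MTL::Buffer* b_buf   = device->newBuffer(b_host, bytes, MTL::ResourceStorageModeShared);
    MTL::Buffer* out_buf = device->newBuffer(bytes, MTL::ResourceStorageModeShared);

    // 2. Load the app's default .metallib and look up the kernel function
    MTL::Library*  library = device->newDefaultLibrary();
    MTL::Function* fn = library->newFunction(
        NS::String::string("vector_add", NS::UTF8StringEncoding));

    // 3. Build the pipeline state (expensive: do this once at startup and cache it)
    NS::Error* err = nullptr;
    MTL::ComputePipelineState* pso = device->newComputePipelineState(fn, &err);
    if (!pso) { /* inspect err->localizedDescription(), clean up, bail out */ return; }

    // 4. Create a command buffer and a compute encoder (create the queue once, too)
    MTL::CommandQueue* queue = device->newCommandQueue();
    MTL::CommandBuffer* cmdbuf = queue->commandBuffer();
    MTL::ComputeCommandEncoder* encoder = cmdbuf->computeCommandEncoder();

    // 5. Bind resources and dispatch exactly n threads; Metal trims the last
    //    threadgroup, so the kernel needs no bounds check
    encoder->setComputePipelineState(pso);
    encoder->setBuffer(a_buf, 0, 0);
    encoder->setBuffer(b_buf, 0, 1);
    encoder->setBuffer(out_buf, 0, 2);
    encoder->dispatchThreads(MTL::Size(n, 1, 1), MTL::Size(256, 1, 1));
    encoder->endEncoding();

    cmdbuf->commit();
    cmdbuf->waitUntilCompleted();  // or use completion handlers

    // 6. Read the result back (shared storage: plain CPU memory access)
    std::memcpy(out_host, out_buf->contents(), bytes);

    // Objects from new* calls are ours to release; commandBuffer() and
    // computeCommandEncoder() results are autoreleased (use an NS::AutoreleasePool)
    queue->release(); pso->release(); fn->release(); library->release();
    a_buf->release(); b_buf->release(); out_buf->release();
}
```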
The pipeline state object (MTLComputePipelineState) is expensive to create. Always create it once at initialization and cache it. MLX does this internally: its Metal device object caches pipeline states by kernel name, building each one the first time it is requested.
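The caching pattern itself is simple. In this Python sketch, `fake_create` is a hypothetical stand-in for the expensive newComputePipelineState call, and the cache guarantees it runs at most once per kernel name:

```python
from typing import Callable, Dict


class PipelineCache:
    """Create each pipeline state on first request, then reuse it."""

    def __init__(self, create_pipeline: Callable[[str], object]):
        self._create = create_pipeline       # expensive factory (stand-in)
        self._cache: Dict[str, object] = {}

    def get(self, kernel_name: str) -> object:
        if kernel_name not in self._cache:
            self._cache[kernel_name] = self._create(kernel_name)
        return self._cache[kernel_name]


created = []


def fake_create(name: str) -> str:
    created.append(name)                     # record every (expensive) creation
    return f"pso<{name}>"


cache = PipelineCache(fake_create)
for _ in range(3):
    cache.get("vector_add")
cache.get("vector_add_half")
print(created)  # ['vector_add', 'vector_add_half']
```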
NVIDIA vs Apple: Side-by-Side
For anyone migrating from CUDA, here is a direct side-by-side comparison of the same softmax kernel written in CUDA C++ and in Metal Shading Language. In both versions, each block/threadgroup of 256 threads normalises one 256-element row. The algorithmic logic is the same; mostly the vocabulary changes.
```cuda
// One block of 256 threads normalises one 256-element row (blockDim.x must be 256)
__global__ void softmax_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n)
{
    __shared__ float shmem[256];
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int lid = threadIdx.x;

    float val = (tid < n) ? input[tid] : -INFINITY;
    shmem[lid] = val;
    __syncthreads();

    // Block-level tree reduction in shared memory: max
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (lid < s) shmem[lid] = fmaxf(shmem[lid], shmem[lid + s]);
        __syncthreads();
    }
    float row_max = shmem[0];
    __syncthreads();  // everyone must read shmem[0] before it is overwritten

    float exp_val = __expf(val - row_max);
    shmem[lid] = exp_val;
    __syncthreads();

    // Block-level tree reduction in shared memory: sum
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (lid < s) shmem[lid] += shmem[lid + s];
        __syncthreads();
    }
    float row_sum = shmem[0];

    if (tid < n) output[tid] = exp_val / row_sum;
}
```
```metal
#include <metal_stdlib>
using namespace metal;

// One threadgroup normalises one row; the row length equals the threadgroup
// size (e.g. 256 threads = 8 SIMD groups). shmem needs one float per SIMD group.
kernel void softmax_kernel(
    device const float* input  [[buffer(0)]],
    device float*       output [[buffer(1)]],
    constant uint&      n      [[buffer(2)]],
    threadgroup float*  shmem  [[threadgroup(0)]],
    uint gid    [[thread_position_in_grid]],
    uint lid    [[thread_position_in_threadgroup]],
    uint lane   [[thread_index_in_simdgroup]],
    uint sg     [[simdgroup_index_in_threadgroup]],
    uint num_sg [[simdgroups_per_threadgroup]])
{
    float val = (gid < n) ? input[gid] : -INFINITY;

    // Pass 1: max within each SIMD group, no loop and no shared memory
    float row_max = simd_max(val);
    if (lane == 0) shmem[sg] = row_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Pass 2: combine the per-SIMD-group partials
    if (lid == 0) {
        for (uint i = 1; i < num_sg; ++i) shmem[0] = max(shmem[0], shmem[i]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    row_max = shmem[0];
    threadgroup_barrier(mem_flags::mem_threadgroup);  // all reads done before reuse

    float exp_val = exp(val - row_max);
    float row_sum = simd_sum(exp_val);
    if (lane == 0) shmem[sg] = row_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (lid == 0) {
        for (uint i = 1; i < num_sg; ++i) shmem[0] += shmem[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    row_sum = shmem[0];

    if (gid < n) output[gid] = exp_val / row_sum;
}
```
The SIMD-group intrinsics (simd_max, simd_sum) replace the shared-memory reduction loops within each SIMD group; a short second pass over the per-SIMD-group partials is still needed because a 256-thread threadgroup contains eight SIMD groups. Apple's compiler maps these intrinsics to efficient cross-lane hardware instructions.
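When validating a kernel like this, it helps to have a CPU reference that follows the same two-level structure: per-SIMD-group partial results, then a combine across SIMD groups. A Python sketch, assuming one 256-element row per threadgroup (eight SIMD groups of 32) as in the kernel above:

```python
import math

SIMD_WIDTH = 32


def softmax_two_level(row):
    """Softmax computed the way the kernel does it: per SIMD group, then combined."""
    groups = [row[i:i + SIMD_WIDTH] for i in range(0, len(row), SIMD_WIDTH)]
    # Pass 1: max inside each group (simd_max), then across groups
    row_max = max(max(g) for g in groups)
    exps = [math.exp(x - row_max) for x in row]
    # Pass 2: sum inside each group (simd_sum), then across groups
    exp_groups = [exps[i:i + SIMD_WIDTH] for i in range(0, len(exps), SIMD_WIDTH)]
    row_sum = sum(sum(g) for g in exp_groups)
    return [e / row_sum for e in exps]


def softmax_reference(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


row = [math.sin(i) * 10 for i in range(256)]
out = softmax_two_level(row)
print(abs(sum(out) - 1.0) < 1e-9)  # True
```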
Elementwise & Reduction Kernels
Template-Based Elementwise Kernels
MLX uses C++ templates extensively to avoid duplicating kernel code for every dtype. The pattern is: write one templated MSL kernel, instantiate it for every type combination you need.
```metal
#include <metal_stdlib>
using namespace metal;

// Operation structs: inlined at compile time
struct Add     { template <typename T> T operator()(T x, T y) const { return x + y; } };
struct Mul     { template <typename T> T operator()(T x, T y) const { return x * y; } };
struct Sub     { template <typename T> T operator()(T x, T y) const { return x - y; } };
struct Divide  { template <typename T> T operator()(T x, T y) const { return x / y; } };
struct Maximum { template <typename T> T operator()(T x, T y) const { return max(x, y); } };

// Generic elementwise binary op template
template <typename T, typename Op>
[[kernel]] void binary_op(
    device const T* a   [[buffer(0)]],
    device const T* b   [[buffer(1)]],
    device T*       out [[buffer(2)]],
    uint id [[thread_position_in_grid]])
{
    out[id] = Op{}(a[id], b[id]);
}

// Explicit instantiations for the metallib. host_name sets the name the host
// uses to look the function up; decltype avoids repeating the attributed signature.
template [[host_name("binary_add_f32")]] [[kernel]]
decltype(binary_op<float, Add>) binary_op<float, Add>;
template [[host_name("binary_add_f16")]] [[kernel]]
decltype(binary_op<half, Add>) binary_op<half, Add>;
template [[host_name("binary_mul_f32")]] [[kernel]]
decltype(binary_op<float, Mul>) binary_op<float, Mul>;
```
Vectorised Loads
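Apple GPUs support vector loads of up to 128 bits: a float4, for example, or 8 halfs loaded as two half4 values (MSL vector types top out at 4 components). Using them reduces the number of load and store instructions per byte moved, which often improves throughput for memory-bandwidth-bound kernels.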
```metal
#include <metal_stdlib>
using namespace metal;

// Each thread processes 4 floats with one 128-bit load per input
// (element count assumed to be a multiple of 4)
kernel void add_f32_vec4(
    device const float4* a   [[buffer(0)]],
    device const float4* b   [[buffer(1)]],
    device float4*       out [[buffer(2)]],
    uint id [[thread_position_in_grid]])
{
    out[id] = a[id] + b[id];  // component-wise vector add
}

// Half precision: MSL has no half8, so 8 elements per thread means two
// half4 (64-bit) loads per input (element count a multiple of 8)
kernel void add_f16_x8(
    device const half4* a   [[buffer(0)]],
    device const half4* b   [[buffer(1)]],
    device half4*       out [[buffer(2)]],
    uint id [[thread_position_in_grid]])
{
    uint i = id * 2;
    out[i]     = a[i]     + b[i];
    out[i + 1] = a[i + 1] + b[i + 1];
}
```
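Vectorised kernels like these assume the element count is a multiple of the vector width. One common approach, sketched here with a hypothetical helper, is to run the vector kernel over the largest multiple and handle the few remaining elements with a scalar kernel or a guarded tail:

```python
def split_vectorised(n: int, width: int):
    """Return (vector_threads, tail_start, tail_len) for n elements."""
    vector_threads = n // width           # one thread per full vector
    tail_start = vector_threads * width   # first element the vector kernel skips
    return vector_threads, tail_start, n - tail_start


print(split_vectorised(1_000_003, 4))  # (250000, 1000000, 3)
print(split_vectorised(1_000_003, 8))  # (125000, 1000000, 3)
```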
Matrix Operations & Tiled GEMM
General matrix multiplication (GEMM) is the backbone of every deep learning model. A naïve implementation that reads directly from device memory is bandwidth-limited and catastrophically slow. The solution is tiled GEMM: load sub-matrices into fast threadgroup memory, compute a tile of the output, and accumulate.
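Device memory bandwidth on M3 Max is 300 to 400 GB/s depending on configuration. Threadgroup memory sits on-chip and offers much higher bandwidth and lower latency. With T×T tiles, each element loaded from device memory is reused T times from threadgroup memory, so device traffic drops by roughly a factor of T. A well-tiled GEMM can therefore be many times faster than the naïve version; the exact gain depends on matrix shape, tile size, and how much the hardware caches already absorb.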
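The reuse argument can be made concrete by counting device-memory element reads. This sketch ignores caches, a simplifying assumption (real GPU caches absorb some of the naïve kernel's redundant reads). A naïve kernel reads a full row of A and column of B for every output element; a T×T tiled kernel loads one tile of A and one of B per K step for each output tile:

```python
def naive_reads(m: int, n: int, k: int) -> int:
    """Each of the m*n outputs reads k elements of A and k of B."""
    return 2 * m * n * k


def tiled_reads(m: int, n: int, k: int, t: int) -> int:
    """Each T x T output tile loads one T x T tile of A and one of B per K step."""
    tiles_m = -(-m // t)   # ceiling division
    tiles_n = -(-n // t)
    steps_k = -(-k // t)
    return tiles_m * tiles_n * steps_k * 2 * t * t


m = n = k = 1024
print(naive_reads(m, n, k) // tiled_reads(m, n, k, 16))  # 16
```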
```metal
#include <metal_stdlib>
using namespace metal;

constant uint TILE_SIZE = 16;

// Dispatch 16x16 threadgroups; the host sets each threadgroup buffer to
// TILE_SIZE * TILE_SIZE * sizeof(float) bytes.
kernel void matmul_tiled(
    device const float* A     [[buffer(0)]],      // [M, K]
    device const float* B     [[buffer(1)]],      // [K, N]
    device float*       C     [[buffer(2)]],      // [M, N]
    constant uint&      M     [[buffer(3)]],
    constant uint&      N     [[buffer(4)]],
    constant uint&      K     [[buffer(5)]],
    threadgroup float*  tileA [[threadgroup(0)]], // TILE x TILE
    threadgroup float*  tileB [[threadgroup(1)]], // TILE x TILE
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 lid  [[thread_position_in_threadgroup]])
{
    uint row = tgid.y * TILE_SIZE + lid.y;
    uint col = tgid.x * TILE_SIZE + lid.x;
    float acc = 0.0f;

    // Iterate over tiles along the K dimension
    for (uint t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
        // Cooperatively load a tile of A into threadgroup memory
        uint aRow = row;
        uint aCol = t * TILE_SIZE + lid.x;
        tileA[lid.y * TILE_SIZE + lid.x] =
            (aRow < M && aCol < K) ? A[aRow * K + aCol] : 0.0f;

        // Cooperatively load a tile of B
        uint bRow = t * TILE_SIZE + lid.y;
        uint bCol = col;
        tileB[lid.y * TILE_SIZE + lid.x] =
            (bRow < K && bCol < N) ? B[bRow * N + bCol] : 0.0f;

        // Synchronise: all threads must finish loading before compute
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Dot product over this tile: all reads come from fast threadgroup memory
        for (uint k = 0; k < TILE_SIZE; ++k) {
            acc += tileA[lid.y * TILE_SIZE + k] * tileB[k * TILE_SIZE + lid.x];
        }

        // Synchronise again before the next tile load overwrites threadgroup memory
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (row < M && col < N) C[row * N + col] = acc;
}
```
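A CPU model of the same loop structure is useful for checking the indexing, including the zero padding at ragged edges. This Python sketch mirrors matmul_tiled (the outer loops play the role of threadgroups, the inner ones of threads) and compares it against a direct triple loop on sizes that are deliberately not multiples of 16:

```python
def matmul_naive(A, B, M, N, K):
    """A is [M, K], B is [K, N], both flat row-major lists."""
    return [[sum(A[i * K + p] * B[p * N + j] for p in range(K)) for j in range(N)]
            for i in range(M)]


def matmul_tiled(A, B, M, N, K, T=16):
    """Mirror of the tiled kernel, with zero-padded tiles at the edges."""
    C = [[0.0] * N for _ in range(M)]
    for tg_y in range(-(-M // T)):            # threadgroup rows
        for tg_x in range(-(-N // T)):        # threadgroup columns
            for t in range(-(-K // T)):       # tiles along K
                # "Threadgroup memory": one tile of A and one of B
                tileA = [[A[(tg_y*T + y)*K + t*T + x] if tg_y*T + y < M and t*T + x < K
                          else 0.0 for x in range(T)] for y in range(T)]
                tileB = [[B[(t*T + y)*N + tg_x*T + x] if t*T + y < K and tg_x*T + x < N
                          else 0.0 for x in range(T)] for y in range(T)]
                for ly in range(T):           # threads within the threadgroup
                    for lx in range(T):
                        row, col = tg_y*T + ly, tg_x*T + lx
                        if row < M and col < N:
                            C[row][col] += sum(tileA[ly][p] * tileB[p][lx] for p in range(T))
    return C


M, N, K = 20, 18, 37
A = [float((i * 7) % 5) for i in range(M * K)]
B = [float((i * 3) % 4) for i in range(K * N)]
print(matmul_tiled(A, B, M, N, K) == matmul_naive(A, B, M, N, K))  # True
```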
SIMD Matrix Operations (simdgroup_matrix)
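For even higher throughput, MSL exposes SIMD-group matrix multiply-accumulate through simdgroup_matrix: the 32 threads of a SIMD group cooperatively hold and multiply 8×8 tiles. This fills the role of CUDA's WMMA/tensor-core API at the programming-model level. On M1 through M3 GPUs these operations run on the regular SIMD ALUs rather than on separate tensor units; the speed-up over scalar code comes mainly from efficient cross-lane data sharing and register reuse.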
```metal
#include <metal_stdlib>
#include <metal_simdgroup_matrix>
using namespace metal;

// Operates on 8x8 tiles: dispatch one SIMD group (32 threads) per threadgroup
// and one threadgroup per 8x8 tile of C. M, N and K are assumed multiples of 8.
// Production kernels stage tiles through threadgroup memory; this one loads directly.
kernel void matmul_simdgroup(
    device const half* A [[buffer(0)]],   // [M, K], row-major
    device const half* B [[buffer(1)]],   // [K, N], row-major
    device float*      C [[buffer(2)]],   // [M, N], row-major
    constant uint&     M [[buffer(3)]],
    constant uint&     N [[buffer(4)]],
    constant uint&     K [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]])
{
    // Accumulator in float32 (mixed precision)
    simdgroup_float8x8 acc = make_filled_simdgroup_matrix<float, 8>(0.0f);
    simdgroup_half8x8  matA;
    simdgroup_half8x8  matB;

    // simdgroup_load(dst, src, elements_per_row, origin) with origin = (column, row)
    for (uint k = 0; k < K; k += 8) {
        simdgroup_load(matA, A, K, ulong2(k, tgid.y * 8));  // A rows tgid.y*8.., cols k..
        simdgroup_load(matB, B, N, ulong2(tgid.x * 8, k));  // B rows k.., cols tgid.x*8..
        // acc = matA * matB + acc
        simdgroup_multiply_accumulate(acc, matA, matB, acc);
    }

    // Store the 8x8 result tile
    simdgroup_store(acc, C, N, ulong2(tgid.x * 8, tgid.y * 8));
}
```
Fused Attention Kernels
Flash Attention — the technique of computing multi-head attention without materialising the full N×N attention score matrix in device memory — is the most important fused kernel in modern LLM inference. Here is how you implement it in MSL for Apple GPUs.
The key insight: by processing the sequence in blocks and accumulating a running softmax (tracking the max and sum incrementally), we only ever need to hold one block of Q, K, V in threadgroup memory at a time. Memory usage drops from O(N²) to O(N).
```metal
#include <metal_stdlib>
using namespace metal;

// Flash Attention forward pass: O(N) extra memory, tiled over the KV sequence.
// Teaching version: one thread per query row, head dim D <= MAX_D,
// 2D grid (x: query index, y: batch * head), threadgroups shaped (T, 1).
static constexpr constant uint BLOCK_SIZE = 64;  // KV block size (tune for your hardware)
static constexpr constant uint MAX_D      = 64;  // largest head dim held in registers

kernel void flash_attention(
    device const half* Q     [[buffer(0)]],  // [B*H, N, D]
    device const half* K     [[buffer(1)]],  // [B*H, N, D]
    device const half* V     [[buffer(2)]],  // [B*H, N, D]
    device half*       Out   [[buffer(3)]],  // [B*H, N, D]
    constant uint&     N     [[buffer(4)]],  // sequence length
    constant uint&     D     [[buffer(5)]],  // head dim
    constant float&    scale [[buffer(6)]],  // 1/sqrt(D)
    uint2 gid     [[thread_position_in_grid]],
    uint2 lid     [[thread_position_in_threadgroup]],
    uint2 tg_size [[threads_per_threadgroup]])
{
    threadgroup half k_tile[BLOCK_SIZE * MAX_D];
    threadgroup half v_tile[BLOCK_SIZE * MAX_D];

    const uint  q_idx  = gid.x;
    const ulong base   = ulong(gid.y) * N * D;  // start of this (batch, head) slice
    const bool  active = q_idx < N;             // no early return: barriers follow

    // Query row in registers, plus online-softmax state (Dao et al. 2022)
    float q[MAX_D];
    float o[MAX_D];       // running (unnormalised) output accumulator
    for (uint d = 0; d < D; ++d) {
        q[d] = active ? float(Q[base + q_idx * D + d]) : 0.0f;
        o[d] = 0.0f;
    }
    float m = -INFINITY;  // running max of attention logits
    float s = 0.0f;       // running softmax denominator

    for (uint kv_start = 0; kv_start < N; kv_start += BLOCK_SIZE) {
        const uint rows = min(BLOCK_SIZE, N - kv_start);

        // Every thread in the threadgroup helps load this block of K and V
        for (uint i = lid.x; i < rows * D; i += tg_size.x) {
            k_tile[i] = K[base + kv_start * D + i];
            v_tile[i] = V[base + kv_start * D + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (active) {
            for (uint j = 0; j < rows; ++j) {
                // Logit for key j of this block: (q . k_j) * scale
                float score = 0.0f;
                for (uint d = 0; d < D; ++d) score += q[d] * float(k_tile[j * D + d]);
                score *= scale;

                // Online softmax update: rescale the old state to the new max
                float m_new      = max(m, score);
                float p          = exp(score - m_new);
                float correction = exp(m - m_new);
                s = correction * s + p;
                for (uint d = 0; d < D; ++d)
                    o[d] = correction * o[d] + p * float(v_tile[j * D + d]);
                m = m_new;
            }
        }
        // Tiles must not be overwritten until every thread has finished reading
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Normalise and write output
    if (active) {
        for (uint d = 0; d < D; ++d) Out[base + q_idx * D + d] = half(o[d] / s);
    }
}
```
MLX's mx.fast.scaled_dot_product_attention() is built on the same online-softmax idea. Its production kernels add support for GQA (grouped-query attention), masking, and BF16, and its full-sequence kernel uses simdgroup_matrix operations for the inner matrix products, but the running max/sum bookkeeping is the same.
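The online softmax at the heart of Flash Attention is easiest to trust after checking it numerically. This Python sketch processes the scores in blocks while keeping only the running max `m`, denominator `s`, and weighted sum `o` (a single scalar value per key, for brevity), and matches a full softmax-weighted sum computed in one shot:

```python
import math


def attention_full(scores, values):
    """Softmax over all scores at once, then the weighted sum of values."""
    m = max(scores)
    w = [math.exp(x - m) for x in scores]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)


def attention_online(scores, values, block=4):
    """Same result, visiting the keys block by block with O(1) state."""
    m, s, o = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), block):
        for x, v in zip(scores[start:start + block], values[start:start + block]):
            m_new = max(m, x)
            correction = math.exp(m - m_new)   # rescale earlier contributions
            p = math.exp(x - m_new)
            s = correction * s + p
            o = correction * o + p * v
            m = m_new
    return o / s


scores = [0.5, 3.0, -1.0, 2.5, 10.0, 0.0, 4.2, -3.3, 7.7]
values = [1.0, -2.0, 0.5, 3.0, 1.5, -1.0, 2.0, 0.25, -0.75]
print(abs(attention_full(scores, values) - attention_online(scores, values)) < 1e-12)  # True
```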
C++ Runtime & Custom Ops
If you want to expose a custom Metal kernel to Python via MLX, you need to write the C++ glue: a Primitive subclass that tells the framework how to evaluate your op on the CPU and GPU and, optionally, how to transform it, for example how to differentiate through it (vjp/jvp) or batch it (vmap).
```cpp
// custom_op.h
#pragma once
#include <mlx/mlx.h>

namespace mx = mlx::core;

class ScaledELU : public mx::Primitive {
 public:
  ScaledELU(mx::Stream stream, float alpha)
      : mx::Primitive(stream), alpha_(alpha) {}

  void eval_cpu(const std::vector<mx::array>& inputs,
                std::vector<mx::array>& outputs) override;
  void eval_gpu(const std::vector<mx::array>& inputs,
                std::vector<mx::array>& outputs) override;

  // VJP for autograd support
  std::vector<mx::array> vjp(
      const std::vector<mx::array>& primals,
      const std::vector<mx::array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<mx::array>& outputs) override;

  // Depending on the MLX version, Primitive also requires name() or print();
  // omitted here for brevity.

 private:
  float alpha_;
};

// The Python-visible entry point
mx::array scaled_elu(const mx::array& x, float alpha, mx::StreamOrDevice s = {});

// custom_op.cpp: GPU evaluation.
// The Metal-side calls follow MLX's C++ extension example; these internal APIs
// have changed between releases, so check them against your MLX version.
#include <algorithm>
#include "mlx/backend/metal/device.h"

void ScaledELU::eval_gpu(const std::vector<mx::array>& inputs,
                         std::vector<mx::array>& outputs) {
  auto& x = inputs[0];  // assumed row-contiguous in this sketch
  auto& out = outputs[0];
  out.set_data(mx::allocator::malloc(out.nbytes()));  // malloc_or_wait in older releases

  auto& s = stream();
  auto& d = mx::metal::device(s.device);

  // Pick the kernel matching the input dtype; get_kernel caches the pipeline state.
  // current_binary_dir() is a helper defined in MLX's extension example.
  const char* kname = (x.dtype() == mx::float16) ? "scaled_elu_f16" : "scaled_elu_f32";
  auto lib = d.get_library("mlx_ext", current_binary_dir());
  auto kernel = d.get_kernel(kname, lib);

  auto& enc = d.get_command_encoder(s.index);
  enc.set_compute_pipeline_state(kernel);
  enc.set_input_array(x, 0);
  enc.set_output_array(out, 1);
  enc.set_bytes(alpha_, 2);

  size_t n = x.size();
  size_t tg = std::min<size_t>(n, kernel->maxTotalThreadsPerThreadgroup());
  enc.dispatch_threads(MTL::Size(n, 1, 1), MTL::Size(tg, 1, 1));
}
```
```metal
#include <metal_stdlib>
using namespace metal;

kernel void scaled_elu_f32(
    device const float* x     [[buffer(0)]],
    device float*       out   [[buffer(1)]],
    constant float&     alpha [[buffer(2)]],
    uint id [[thread_position_in_grid]])
{
    float val = x[id];
    out[id] = (val > 0.0f) ? val : alpha * (exp(val) - 1.0f);
}

// Separate kernel for half precision: computes in float, stores half
kernel void scaled_elu_f16(
    device const half* x     [[buffer(0)]],
    device half*       out   [[buffer(1)]],
    constant float&    alpha [[buffer(2)]],
    uint id [[thread_position_in_grid]])
{
    float val = (float)x[id];
    out[id] = (half)((val > 0.0f) ? val : alpha * (exp(val) - 1.0f));
}
```
```python
import mlx.core as mx
from my_custom_ops import scaled_elu  # your compiled extension module

x = mx.array([-2.0, -1.0, 0.0, 1.0, 2.0])

# Runs your Metal kernel transparently
y = scaled_elu(x, alpha=1.0)
mx.eval(y)
print(y)  # approximately [-0.865, -0.632, 0.0, 1.0, 2.0]

# Autograd works via the vjp() you implemented
grad_fn = mx.grad(lambda x: mx.sum(scaled_elu(x, alpha=1.0)))
dx = grad_fn(x)
mx.eval(dx)
print(dx)  # approximately [0.135, 0.368, 1.0, 1.0, 1.0]
```
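Before trusting a hand-written vjp(), it is worth checking the expected gradients against finite differences on the CPU. This plain-Python sketch uses the same formula as the kernel (x for x > 0, otherwise alpha * (exp(x) - 1)) and its analytic derivative; it needs no MLX and only serves as a reference for the values printed above:

```python
import math


def elu(x: float, alpha: float = 1.0) -> float:
    return x if x > 0.0 else alpha * (math.exp(x) - 1.0)


def elu_grad(x: float, alpha: float = 1.0) -> float:
    # Same branch as the kernel: x == 0 takes the exponential side
    return 1.0 if x > 0.0 else alpha * math.exp(x)


def central_difference(f, x: float, h: float = 1e-6) -> float:
    return (f(x + h) - f(x - h)) / (2.0 * h)


xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
for x in xs:
    analytic = elu_grad(x)
    numeric = central_difference(elu, x)
    print(f"x={x:+.1f}  analytic={analytic:.3f}  numeric={numeric:.3f}")
```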
Performance Tuning
Getting correct output from your kernel is step one. Getting maximum performance on Apple Silicon requires understanding the hardware bottlenecks and tuning accordingly.
Occupancy & Threadgroup Size
Occupancy is the ratio of SIMD groups (threads) resident on a GPU core to the maximum it can hold simultaneously. Higher occupancy lets the GPU hide memory latency by switching to SIMD groups that are ready to run. The right threadgroup size balances per-thread register and threadgroup-memory usage against the desire for more concurrent work.
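- Start with 256 threads per threadgroup: a safe default that divides evenly into SIMD groups of 32
- For memory-bound kernels, also try 128 or 64; smaller groups can fit more easily when registers or threadgroup memory limit how many groups are resident, but measure rather than assume
- For kernels that reuse data heavily through threadgroup memory (such as tiled GEMM), larger groups (512 or 1024) can increase reuse, provided register usage allows it
- Use `MTLComputePipelineState.maxTotalThreadsPerThreadgroup` to query the limit for a compiled kernel (it shrinks as register usage grows), and `threadExecutionWidth` for the SIMD width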
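These guidelines can be folded into a small helper for a host-side launcher. In this hypothetical sketch, the two limits Metal reports for a compiled pipeline (maxTotalThreadsPerThreadgroup and threadExecutionWidth) are passed in as plain integers, and a preferred size is rounded down to a legal multiple of the SIMD width:

```python
def pick_threadgroup_size(preferred: int, max_total_threads: int,
                          execution_width: int = 32) -> int:
    """Largest multiple of execution_width that is <= both preferred and the limit."""
    limit = min(preferred, max_total_threads)
    size = (limit // execution_width) * execution_width
    # Never go below one full SIMD group (the pipeline limit is always at least that)
    return max(size, execution_width)


print(pick_threadgroup_size(256, 1024))  # 256
print(pick_threadgroup_size(512, 448))   # 448
print(pick_threadgroup_size(100, 1024))  # 96
```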
Memory Access Patterns
```metal
#include <metal_stdlib>
using namespace metal;

// BAD: strided access. Threads in a SIMD group read non-contiguous addresses:
// thread 0 reads [0], thread 1 reads [N], thread 2 reads [2*N], ...
// For large N, each of the 32 lanes touches a different cache line.
kernel void column_access_bad(
    device const float* mat [[buffer(0)]],
    device float*       out [[buffer(1)]],
    constant uint&      N   [[buffer(2)]],
    uint id [[thread_position_in_grid]])
{
    out[id] = mat[id * N];  // walks down a column of a row-major matrix
}

// GOOD: coalesced access. Threads read consecutive addresses:
// thread 0 reads [0], thread 1 reads [1], thread 2 reads [2], ...
// The 32 lanes of a SIMD group share a handful of cache lines.
kernel void row_access_good(
    device const float* mat [[buffer(0)]],
    device float*       out [[buffer(1)]],
    uint id [[thread_position_in_grid]])
{
    out[id] = mat[id];  // sequential
}

// FIX for strided access: transpose through threadgroup memory.
// Assumes a square N x N matrix with N a multiple of 32, and 32x32 threadgroups.
kernel void column_via_transpose(
    device const float* mat [[buffer(0)]],
    device float*       out [[buffer(1)]],
    constant uint&      N   [[buffer(2)]],
    uint2 lid  [[thread_position_in_threadgroup]],
    uint2 tgid [[threadgroup_position_in_grid]])
{
    // Fixed-size threadgroup arrays are declared in the kernel body. The extra
    // column is the usual CUDA bank-conflict padding; measure its effect here.
    threadgroup float tile[32][33];

    // Load a 32x32 block row by row (coalesced reads)
    tile[lid.y][lid.x] = mat[(tgid.y * 32 + lid.y) * N + tgid.x * 32 + lid.x];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write the transposed block row by row (coalesced writes)
    out[(tgid.x * 32 + lid.y) * N + tgid.y * 32 + lid.x] = tile[lid.x][lid.y];
}
```
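To see why the strided pattern hurts, count how many distinct cache lines one SIMD group touches. The sketch assumes 4-byte floats and a 128-byte cache line; the line size is an illustrative assumption, not a documented Apple figure:

```python
def cache_lines_touched(addresses_bytes, line_size: int = 128) -> int:
    """Number of distinct cache lines covered by a set of byte addresses."""
    return len({addr // line_size for addr in addresses_bytes})


FLOAT_BYTES = 4
N = 1024                     # row length of a row-major matrix
lanes = range(32)            # one SIMD group

coalesced = [lane * FLOAT_BYTES for lane in lanes]      # mat[id]
strided = [lane * N * FLOAT_BYTES for lane in lanes]    # mat[id * N]

print(cache_lines_touched(coalesced), cache_lines_touched(strided))  # 1 32
```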
Instruction-Level Tricks
- Use `fma(a, b, c)` (fused multiply-add) explicitly: it is a single instruction and avoids rounding between the multiply and the add
- Prefer `half` over `float` wherever precision allows: it halves register usage and memory traffic, which often raises occupancy and speeds up bandwidth-bound kernels (on M1-class and later GPUs, raw arithmetic throughput is broadly similar for the two types)
- With fast math enabled (the Metal compiler's default), plain `sin()` and `exp()` use faster approximations; call `precise::sin()` or `precise::exp()` only when you need full accuracy
- Unroll small inner loops with `#pragma unroll` so the compiler can keep values in registers and schedule loads earlier
- Annotate read-only input pointers with `const`, which enables additional optimisations
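```metal
#include <metal_stdlib>
using namespace metal;

// Explicit FMA, unrolled loop, half loads with float accumulation.
// Dispatch with a threadgroup size that is a multiple of 32; each SIMD group
// writes one partial sum, which a second pass (or the host) adds up.
kernel void optimised_dot(
    device const half4* a       [[buffer(0)]],
    device const half4* b       [[buffer(1)]],
    device float*       partial [[buffer(2)]],  // one entry per SIMD group
    constant uint&      len     [[buffer(3)]],  // number of halfs, a multiple of 8
    uint id   [[thread_position_in_grid]],
    uint lane [[thread_index_in_simdgroup]])
{
    uint base = id * 2;  // first half4 for this thread: 8 halfs = two half4 loads
    float acc = 0.0f;

    if (id * 8 < len) {
        // Unrolled: the compiler can keep both loads in registers
        #pragma unroll
        for (uint i = 0; i < 2; ++i) {
            float4 av = float4(a[base + i]);
            float4 bv = float4(b[base + i]);
            // Four scalar FMAs, each a single fused instruction
            acc = fma(av.x, bv.x, acc);
            acc = fma(av.y, bv.y, acc);
            acc = fma(av.z, bv.z, acc);
            acc = fma(av.w, bv.w, acc);
        }
    }

    // Reduce the 32 per-thread partials within this SIMD group
    float total = simd_sum(acc);
    if (lane == 0) partial[id / 32] = total;
}
```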
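Whether half is precise enough can be judged on the CPU by rounding values through IEEE binary16, which also shows why optimised_dot loads half but accumulates in float. Python's struct module supports the 'e' (half-precision) format, so this sketch needs no GPU. Half has an 11-bit significand, so above 2048 not every integer is representable, and a long running sum kept in half stops growing:

```python
import struct


def to_half(x: float) -> float:
    """Round a Python float to IEEE 754 binary16 and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


print(to_half(2049.0))   # 2048.0: halfs are spaced 2 apart in this range
print(to_half(0.1))      # 0.0999755859375

# Adding 1.0 four thousand times with a half accumulator stalls at 2048
acc = 0.0
for _ in range(4096):
    acc = to_half(acc + 1.0)
print(acc)               # 2048.0
```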
Profiling with Metal Debugger
Xcode's Metal debugger (GPU frame capture, which reports per-kernel performance counters) and the Metal System Trace template in Instruments are the primary tools for profiling kernel performance. Key metrics to watch:
- ALU Utilisation: are you compute-bound? Low utilisation in a kernel you expect to be compute-bound usually points to memory stalls or low occupancy
- Threadgroup Memory Read/Write bandwidth — high numbers indicate heavy on-chip traffic (good for GEMM, watch for conflicts)
- Device Memory Bandwidth — compare against theoretical peak to measure roofline efficiency
- Kernel Duration — track per-kernel timing across model iterations to identify regressions
The MLX Engineer's Workflow
Putting it all together: what does a typical day look like for someone writing MLX kernels? The work spans three layers (Python model code, C++ runtime glue, and Metal kernels), and knowing which layer a problem belongs to tells you where to focus your effort.
The Deep MLX Pipeline
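A large share of the performance work in MLX sits in graph-level optimisation and kernel fusion. mx.compile traces a function and fuses chains of elementwise operations into single generated Metal kernels, loosely comparable to what torch.compile's Inductor does on the NVIDIA side. When automatic fusion is not enough, you write the fused kernel yourself.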
Example: Writing a Custom Fused Kernel
Suppose you profile a Llama-style model and discover that two adjacent operations — a silu activation applied to a gate, then an elementwise multiplication with the up-projection — account for disproportionate time due to two separate kernel launches and the associated device memory round-trips. You can fuse them:
```metal
#include <metal_stdlib>
using namespace metal;

// Fuses out = silu(gate) * up into one kernel: one device read per input,
// one write, no intermediate array. Element count assumed a multiple of 4.
kernel void silu_gate_fused(
    device const half4* gate [[buffer(0)]],  // W_gate * x
    device const half4* up   [[buffer(1)]],  // W_up * x
    device half4*       out  [[buffer(2)]],
    uint id [[thread_position_in_grid]])     // one half4 (4 elements) per thread
{
    // One 64-bit load per input
    float4 g = float4(gate[id]);
    float4 u = float4(up[id]);

    // SiLU on the gate: x * sigmoid(x) = x / (1 + exp(-x)), computed in float
    float4 silu_g = g / (1.0f + exp(-g));

    // Fused gate multiply, stored back as half
    out[id] = half4(silu_g * u);
}
```
```python
import mlx.core as mx
import mlx.nn as nn
from my_kernels import silu_gate_fused


class LlamaMLP(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.up_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

    def __call__(self, x):
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        # One fused Metal kernel instead of two separate launches
        hidden = silu_gate_fused(gate, up)
        return self.down_proj(hidden)
```
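A scalar CPU reference is handy for checking a fused kernel's output, including the order of the operands: in a Llama-style MLP, SiLU applies to the gate projection, which is then multiplied by the up projection. This Python sketch computes that in double precision; compare a kernel's half-precision output against it with a loose tolerance:

```python
import math


def silu(x: float) -> float:
    """SiLU / swish: x * sigmoid(x), written as x / (1 + exp(-x))."""
    return x / (1.0 + math.exp(-x))


def silu_gate_reference(gate, up):
    """Reference for the fused kernel: out[i] = silu(gate[i]) * up[i]."""
    return [silu(g) * u for g, u in zip(gate, up)]


gate = [-2.0, 0.0, 1.0, 3.0]
up = [1.0, 5.0, 2.0, -1.0]
print([round(v, 4) for v in silu_gate_reference(gate, up)])
```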
The bottom line: Writing GPU kernels for Apple Silicon means mastering Metal Shading Language for the kernels themselves, C++ as your runtime glue, and Apple's unified memory model as your performance superpower. With no PCIe transfers and a large, high-bandwidth memory pool shared by CPU and GPU, some designs become practical that are awkward on discrete NVIDIA GPUs, such as CPU and GPU working on the same buffers without copies, or running models larger than a typical discrete card's VRAM.
SIMD-group intrinsics let you do warp-level reductions without touching shared memory. Unified memory lets you build zero-copy data pipelines between CPU preprocessing and GPU inference. And fast on-chip threadgroup memory rewards blocking patterns that keep working sets on-chip.
As LLM inference increasingly moves to personal devices and Apple's silicon continues to close the gap with discrete GPUs, MLX and MSL kernels are becoming essential knowledge for any serious ML systems engineer working in the Apple ecosystem. The conceptual skills transfer cleanly from CUDA — only the vocabulary changes.