MANISH AI
Deep Dive · Apple Silicon GPUs

Writing GPU Kernels
for Apple Silicon with MLX

A comprehensive guide to Metal Shading Language, Apple's GPU architecture, and how MLX lets you write high-performance compute kernels that run natively on M-series chips — no CUDA required.

⏱ 25 min read ⚙ MLX · Metal · MSL · C++ 🍎 M1 / M2 / M3 / M4
Contents
  1. Why Apple GPUs Are Different
  2. The MLX Software Stack
  3. Metal Shading Language Basics
  4. Thread Hierarchy & Memory Model
  5. Your First GPU Kernel
  6. NVIDIA vs Apple: Side-by-Side
  7. Elementwise & Reduction Kernels
  8. Matrix Operations & Tiled GEMM
  9. Fused Attention Kernels
  10. C++ Runtime & Custom Ops
  11. Performance Tuning
  12. The MLX Engineer's Workflow

Why Apple GPUs Are Different

If you've worked with NVIDIA's CUDA ecosystem, you're accustomed to a clean separation: the CPU lives on one side of a PCIe bus, the GPU on the other, and data constantly ferries between them via expensive DMA transfers. Apple blew that model up.

On M-series chips, the CPU, GPU, Neural Engine, and memory controllers all sit on one system-on-chip, with the DRAM mounted in the same package. There is no discrete GPU memory pool and no PCIe bus to cross. The GPU reads from the same LPDDR pool that the CPU uses, at high bandwidth (up to 400 GB/s on M3 Max, and roughly double that on Ultra parts). This single architectural decision changes a great deal about how you write kernels.

🏗

Tile-Based Deferred Rendering

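Apple GPUs use a TBDR architecture for graphics. The screen is divided into tiles that are processed entirely on-chip before being written back to memory, minimizing bandwidth. Compute kernels do not pass through the tiler, but the same principle of keeping working sets on-chip applies to them.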

🧠

Unified Memory Architecture

CPU and GPU share the same physical memory. Zero-copy tensor sharing is possible — mx.array lives in memory accessible to both without explicit transfers.

SIMD Groups

Apple's equivalent of CUDA warps. Groups of 32 threads execute in lockstep and can share data via SIMD-group instructions without touching threadgroup memory.

💾

Threadgroup Memory

On-chip scratchpad local to a threadgroup. Fast, software-managed, analogous to CUDA shared memory. Finite size — spill to device memory and latency skyrockets.

Architecture note

Apple Silicon integrates the GPU directly with the CPU fabric on the same chip. M1 has up to 8 GPU cores; M3 Max reaches 40; M2 Ultra pushes 76. Each GPU core contains multiple SIMD engines, each running 32 lanes wide, which makes a GPU core the rough Apple equivalent of a streaming multiprocessor (SM) in NVIDIA terms.

The MLX Software Stack

MLX is Apple's open-source array framework for machine learning on Apple Silicon. It is not a thin wrapper around PyTorch. It is a ground-up reimplementation with lazy evaluation, a functional transformation system, and its own GPU backend built directly on Metal. Understanding the layers helps you know exactly where your kernel code lives.

Python API (mlx.core)
C++ Runtime (Primitives, Graph, Dispatch)
Metal API (Command Queues, Buffers, Pipelines)
MSL Kernels (.metal source files)
Apple GPU (Silicon)

The Python Layer

Users write high-level array operations. MLX builds a computation graph lazily — nothing executes until a value is actually needed.

Python
import mlx.core as mx

# Lazy — no GPU work yet
a = mx.ones((1024, 1024), dtype=mx.float16)
b = mx.ones((1024, 1024), dtype=mx.float16)
c = mx.matmul(a, b)          # Graph node created

# Evaluation triggers GPU dispatch
mx.eval(c)
print(c.shape)               # (1024, 1024)

The C++ Runtime

The majority of MLX — perhaps 60–70% by line count — is C++. The runtime manages tensor lifetimes, constructs and optimises the computation graph, fuses operations into single kernel launches where profitable, and ultimately calls into Metal to dispatch work to the GPU.

C++ — Primitive
// Simplified version of how MLX defines a primitive operation
class MatMul : public Primitive {
public:
  MatMul(Stream stream) : Primitive(stream) {}

  // Forward pass — schedules the Metal kernel
  void eval_gpu(
      const std::vector<array>& inputs,
      std::vector<array>& outputs
  ) override;

  // VJP for automatic differentiation
  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs
  ) override;
};

Metal API & Pipelines

Metal is Apple's low-level GPU API, analogous to Vulkan Compute or the CUDA Runtime. Kernels are compiled from Metal Shading Language (MSL) source either ahead-of-time (packaged into a .metallib file) or at runtime. Each kernel function is compiled into an MTLComputePipelineState. To run it, you bind the pipeline state and its arguments on a compute command encoder, which records into an MTLCommandBuffer that is committed to an MTLCommandQueue for execution.

MLX tip

MLX pre-compiles most of its built-in Metal kernels into a metallib at build time. When a matmul node created by mx.matmul() is evaluated, the C++ runtime looks up the correct pipeline state for the given dtypes and dispatches it, so the common path carries no JIT compilation delay during training.

Metal Shading Language Basics

Metal Shading Language (MSL) is a C++14-derived language with Apple-specific extensions for GPU programming. If you know CUDA C++, you'll feel at home within minutes. The core concepts translate almost 1:1, though the syntax differs.

Address Spaces

MSL has explicit address space qualifiers that tell the compiler — and you — where data lives in the memory hierarchy. Understanding these is non-negotiable for writing correct kernels.

MSL — Address Space Demo
kernel void memory_spaces_demo(
    device const float* global_input   [[buffer(0)]],
    device float*       global_output  [[buffer(1)]],
    constant float&     scale          [[buffer(2)]],
    threadgroup float*  shared_tile    [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]]
) {
    // Thread-private register
    float val = global_input[gid] * scale;

    // Write into on-chip threadgroup memory
    shared_tile[lid] = val;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // All threads in group have written; safe to read
    global_output[gid] = shared_tile[lid];
}

Built-In Variables

MSL uses attribute syntax [[...]] to inject hardware-provided values into kernel arguments. These tell each thread exactly where it sits in the execution hierarchy.
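To make the relationship between these built-in values concrete, here is a small CPU-side Python model for a 1-D dispatch with uniform threadgroups. The function name thread_ids is made up for this sketch; the dictionary keys are the MSL attribute names used in the kernels in this guide, and the 32-lane SIMD width is the figure quoted earlier.

```python
SIMD_WIDTH = 32  # lanes per SIMD group, as stated earlier in this guide


def thread_ids(threadgroup_position: int, thread_in_group: int,
               threads_per_threadgroup: int) -> dict:
    """Model the 1-D built-in values a single thread would receive."""
    return {
        "thread_position_in_grid":
            threadgroup_position * threads_per_threadgroup + thread_in_group,
        "thread_position_in_threadgroup": thread_in_group,
        "threadgroup_position_in_grid": threadgroup_position,
        "simdgroup_index_in_threadgroup": thread_in_group // SIMD_WIDTH,
        "thread_index_in_simdgroup": thread_in_group % SIMD_WIDTH,
    }


# Thread 70 of threadgroup 2, with 256 threads per threadgroup
ids = thread_ids(threadgroup_position=2, thread_in_group=70,
                 threads_per_threadgroup=256)
for name, value in ids.items():
    print(f"{name}: {value}")
```

The global index plays the role of CUDA's blockIdx.x * blockDim.x + threadIdx.x; in MSL the hardware hands it to the kernel directly.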

Thread Hierarchy & Memory Model

Apple GPUs execute threads in a three-level hierarchy that directly mirrors NVIDIA's grid / block / warp model — with different names and slightly different semantics.

NVIDIA CUDA         | Apple Metal              | Description
Grid                | Grid                     | The entire dispatch space
Thread Block        | Threadgroup              | Cooperative group with shared on-chip memory
Warp (32 threads)   | SIMD Group (32 threads)  | Smallest lockstep execution unit
Shared Memory       | Threadgroup Memory       | On-chip scratchpad local to a threadgroup
L1/L2 Cache         | GPU Cache                | Hardware-managed, transparent
Global Memory       | Device Memory (Unified)  | Main memory pool, shared with CPU
Constant Memory     | Constant Buffer          | Read-only, broadcast-cached
__syncthreads()     | threadgroup_barrier()    | Synchronisation within a group
Warp shuffle        | SIMD-group functions     | Cross-lane data exchange without memory

SIMD-Group Intrinsics

One area where MSL genuinely excels over basic CUDA is its rich set of SIMD-group operations built directly into the language spec. These allow data exchange between the 32 lanes of a SIMD group without touching threadgroup memory at all, saving precious bandwidth.

MSL — SIMD-Group Intrinsics
kernel void simd_reduce_sum(
    device const float* input  [[buffer(0)]],
    device float*       output [[buffer(1)]],
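    uint gid  [[thread_position_in_grid]],
    uint lane [[thread_index_in_simdgroup]]
) {
    float val = input[gid];

    // Sum across all 32 lanes, no threadgroup memory needed
    val = simd_sum(val);

    // Only lane 0 of each SIMD group writes the result.
    // gid / 32 is a grid-wide SIMD-group index (assumes the threadgroup size
    // is a multiple of 32). Indexing by simdgroup_index_in_threadgroup alone
    // would make different threadgroups overwrite the same output slots.
    if (lane == 0) {
        output[gid / 32] = val;
    }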
}

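Key SIMD-group functions available in MSL include simd_sum, simd_max and simd_min for reductions, simd_prefix_inclusive_sum and simd_prefix_exclusive_sum for scans, and simd_broadcast, simd_shuffle, simd_shuffle_down and simd_shuffle_xor for moving values between lanes.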
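To see why a SIMD-group sum needs no threadgroup memory, the sketch below models the classic shuffle-down tree reduction in plain Python. It is an illustration of the idea only: the helper names are invented here, the list stands in for the 32 lane registers, and how the hardware actually implements simd_sum is not specified by this guide.

```python
SIMD_WIDTH = 32  # lanes per SIMD group


def shuffle_down_model(values, delta):
    """Each lane reads the value held by lane + delta.

    Lanes whose source would fall past the end keep their own value.
    """
    n = len(values)
    return [values[i + delta] if i + delta < n else values[i] for i in range(n)]


def simd_sum_model(values):
    """Tree reduction over one SIMD group in log2(32) = 5 shuffle steps."""
    assert len(values) == SIMD_WIDTH
    lanes = list(values)
    offset = SIMD_WIDTH // 2
    while offset > 0:
        shifted = shuffle_down_model(lanes, offset)
        lanes = [own + other for own, other in zip(lanes, shifted)]
        offset //= 2
    # In this model only lane 0 is guaranteed to hold the full total
    return lanes[0]


print(simd_sum_model(list(range(SIMD_WIDTH))))  # 496
```

One difference worth noting: the real simd_sum returns the total to every lane, whereas this model only tracks lane 0.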

Your First GPU Kernel

Let's start with the canonical example — element-wise vector addition — and then build progressively more sophisticated kernels. Every kernel author should be able to write this in their sleep.

Element-Wise Addition

MSL — vector_add.metal
// The 'kernel' qualifier marks this as a compute kernel entry point
// (as opposed to 'vertex' or 'fragment' for graphics)
kernel void vector_add(
    device const float* a   [[buffer(0)]],
    device const float* b   [[buffer(1)]],
    device       float* out [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    out[id] = a[id] + b[id];
}

// Half-precision variant — important for ML workloads
kernel void vector_add_half(
    device const half* a   [[buffer(0)]],
    device const half* b   [[buffer(1)]],
    device       half* out [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    out[id] = a[id] + b[id];
}

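Notice the key differences from CUDA immediately: the entry point is marked kernel rather than __global__, every pointer carries an explicit address space (device), arguments are bound to numbered slots with [[buffer(n)]], and the global thread index arrives as a [[thread_position_in_grid]] argument instead of being computed from blockIdx, blockDim and threadIdx.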

Dispatching from C++

The kernel file is just the compute shader. You still need C++ to compile the pipeline state and dispatch work. Here is a minimal, self-contained example using the raw Metal API — in MLX, the framework handles this for you, but understanding it is essential for writing custom ops.

C++ — Metal Dispatch
// C++ using Apple's metal-cpp headers
#include <Foundation/Foundation.hpp>
#include <Metal/Metal.hpp>
#include <cstring>

void launch_vector_add(
    MTL::Device* device,
    float* a_host, float* b_host, float* out_host,
    size_t n
) {
    // 1. Create Metal buffers. This overload copies the host data once into
    //    shared storage; from then on CPU and GPU access the same memory.
    auto* a_buf   = device->newBuffer(a_host, n * sizeof(float),
                         MTL::ResourceStorageModeShared);
    auto* b_buf   = device->newBuffer(b_host, n * sizeof(float),
                         MTL::ResourceStorageModeShared);
    auto* out_buf = device->newBuffer(n * sizeof(float),
                         MTL::ResourceStorageModeShared);

    // 2. Load compiled .metallib and get the function
    auto* library = device->newDefaultLibrary();
    auto* fn      = library->newFunction(
                         NS::String::string("vector_add",
                         NS::UTF8StringEncoding));

    // 3. Build pipeline state (expensive; in real code do this once at startup
    //    and check pso and err before continuing)
    NS::Error* err = nullptr;
    auto* pso = device->newComputePipelineState(fn, &err);

    // 4. Create command buffer and encoder
    auto* queue   = device->newCommandQueue();
    auto* cmdbuf  = queue->commandBuffer();
    auto* encoder = cmdbuf->computeCommandEncoder();

    // 5. Bind resources and dispatch
    encoder->setComputePipelineState(pso);
    encoder->setBuffer(a_buf,   0, 0);
    encoder->setBuffer(b_buf,   0, 1);
    encoder->setBuffer(out_buf, 0, 2);

    // dispatchThreads launches exactly n threads, so the kernel needs no
    // bounds check even when n is not a multiple of the threadgroup size
    MTL::Size grid = MTL::Size(n, 1, 1);
    MTL::Size threads_per_tg = MTL::Size(256, 1, 1);
    encoder->dispatchThreads(grid, threads_per_tg);

    encoder->endEncoding();
    cmdbuf->commit();
    cmdbuf->waitUntilCompleted();  // or use completion handlers

    // 6. Read the result back through the shared buffer
    std::memcpy(out_host, out_buf->contents(), n * sizeof(float));

    // 7. Release the objects created with new* above
    a_buf->release(); b_buf->release(); out_buf->release();
    pso->release(); fn->release(); library->release(); queue->release();
}
Performance note

The pipeline state object (MTLComputePipelineState) is expensive to create. Always create it once at initialization and cache it. MLX does this via its internal kernel registry which maps operation signatures to pre-built pipeline states.
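The caching advice can be sketched as a tiny registry that builds each pipeline at most once. This is an illustration in plain Python, not MLX's actual registry: PipelineCache and the build function passed to it are hypothetical names, and the returned string merely stands in for a compiled pipeline state.

```python
class PipelineCache:
    """Build each pipeline state at most once and reuse it afterwards."""

    def __init__(self, build_fn):
        self._build_fn = build_fn   # the expensive compile step
        self._pipelines = {}        # kernel name -> built pipeline
        self.builds = 0             # how many times we actually compiled

    def get(self, kernel_name):
        if kernel_name not in self._pipelines:
            self._pipelines[kernel_name] = self._build_fn(kernel_name)
            self.builds += 1
        return self._pipelines[kernel_name]


def fake_build(kernel_name):
    # Stand-in for creating a MTLComputePipelineState from a kernel function
    return f"pso<{kernel_name}>"


cache = PipelineCache(fake_build)
for _ in range(1000):               # e.g. one lookup per training step
    cache.get("vector_add")
print(cache.builds)                 # 1
```

Keying on the kernel name alone is enough here; a real registry would also need to distinguish dtypes and any other parameters baked into the kernel.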

NVIDIA vs Apple: Side-by-Side

For anyone migrating from CUDA, here is a direct side-by-side comparison of the same softmax kernel written in both CUDA C++ and Metal Shading Language. The algorithmic logic is identical; only the vocabulary changes.

CUDA C++ — softmax (NVIDIA)
__global__ void softmax_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    __shared__ float shmem[256];

    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int lid = threadIdx.x;

    float val = (tid < n) ? input[tid] : -INFINITY;
    shmem[lid] = val;
    __syncthreads();

    // Block-level max reduction (tree reduction in shared memory)
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (lid < s) shmem[lid] = fmaxf(shmem[lid], shmem[lid+s]);
        __syncthreads();
    }
    float row_max = shmem[0];
    __syncthreads();  // every thread must read row_max before shmem is reused

    float exp_val = __expf(val - row_max);
    shmem[lid] = exp_val;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (lid < s) shmem[lid] += shmem[lid+s];
        __syncthreads();
    }
    float row_sum = shmem[0];

    if (tid < n) output[tid] = exp_val / row_sum;
}
MSL — softmax (Apple)
kernel void softmax_kernel(
    device const float* input        [[buffer(0)]],
    device       float* output       [[buffer(1)]],
    constant uint&  n               [[buffer(2)]],
    threadgroup float* shmem        [[threadgroup(0)]],  // one float per SIMD group
    uint gid    [[thread_position_in_grid]],
    uint lid    [[thread_position_in_threadgroup]],
    uint lane   [[thread_index_in_simdgroup]],
    uint sg     [[simdgroup_index_in_threadgroup]],
    uint num_sg [[simdgroups_per_threadgroup]]
) {
    float val = (gid < n) ? input[gid] : -INFINITY;

    // SIMD-group-level max via intrinsic, no manual loop
    float row_max = simd_max(val);

    // For multi-SIMD-group threadgroups, combine the partials in shmem
    if (lane == 0) shmem[sg] = row_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (lid == 0) {
        for (uint i = 1; i < num_sg; ++i) shmem[0] = max(shmem[0], shmem[i]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    row_max = shmem[0];
    // Every thread must read row_max before shmem is reused for the sums
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float exp_val = exp(val - row_max);

    float row_sum = simd_sum(exp_val);
    if (lane == 0) shmem[sg] = row_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (lid == 0) {
        for (uint i = 1; i < num_sg; ++i) shmem[0] += shmem[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    row_sum = shmem[0];

    if (gid < n) output[gid] = exp_val / row_sum;
}

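The SIMD-group intrinsics (simd_max, simd_sum) replace most of the shared-memory tree reduction, leaving only a short pass over one partial result per SIMD group, which makes the MSL code more expressive at this level. CUDA offers comparable warp shuffles; the CUDA version above simply does not use them. Apple's compiler maps these intrinsics to efficient hardware instructions on the SIMD engines. As written, both kernels normalise over a single threadgroup, so one threadgroup corresponds to one softmax row.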
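The two-level structure of the MSL kernel (reduce inside each SIMD group, then combine one partial per group) can be checked against a direct softmax with a CPU-side Python model. The function names are invented for this sketch, and Python's built-in max and sum stand in for simd_max and simd_sum.

```python
import math

SIMD_WIDTH = 32  # lanes per SIMD group


def softmax_direct(values):
    """Plain numerically stable softmax, used as the reference."""
    row_max = max(values)
    exps = [math.exp(v - row_max) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_two_level(values):
    """Softmax over one threadgroup's values, reduced per SIMD group first."""
    starts = range(0, len(values), SIMD_WIDTH)
    # Level 1: one partial max per SIMD group (simd_max in the kernel)
    partial_max = [max(values[i:i + SIMD_WIDTH]) for i in starts]
    # Level 2: a single thread combines the partials held in threadgroup memory
    row_max = max(partial_max)

    exps = [math.exp(v - row_max) for v in values]
    # Same two levels again for the denominator (simd_sum in the kernel)
    partial_sum = [sum(exps[i:i + SIMD_WIDTH]) for i in starts]
    row_sum = sum(partial_sum)
    return [e / row_sum for e in exps]


row = [0.01 * i for i in range(256)]  # 256 threads = 8 SIMD groups
two_level = softmax_two_level(row)
direct = softmax_direct(row)
print(all(math.isclose(a, b, rel_tol=1e-12) for a, b in zip(two_level, direct)))
```

With 256 threads the second level touches only 8 values, which is why the kernel can afford to let a single thread do it.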

Elementwise & Reduction Kernels

Template-Based Elementwise Kernels

MLX uses C++ templates extensively to avoid duplicating kernel code for every dtype. The pattern is: write one templated MSL kernel, instantiate it for every type combination you need.

MSL — Templated Elementwise Op
// Generic elementwise binary op template
template <typename T, typename Op>
kernel void binary_op(
    device const T* a   [[buffer(0)]],
    device const T* b   [[buffer(1)]],
    device       T* out [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    out[id] = Op{}(a[id], b[id]);
}

// Operation structs — inlined at compile time
struct Add    { template<typename T> T operator()(T x, T y) { return x + y; } };
struct Mul    { template<typename T> T operator()(T x, T y) { return x * y; } };
struct Sub    { template<typename T> T operator()(T x, T y) { return x - y; } };
struct Divide { template<typename T> T operator()(T x, T y) { return x / y; } };
struct Maximum{ template<typename T> T operator()(T x, T y) { return max(x, y); } };

// Explicit instantiations for the metallib
template [[host_name("binary_add_f32")]]
kernel void binary_op<float,  Add>(device const float* , device const float* , device float* , uint);
template [[host_name("binary_add_f16")]]
kernel void binary_op<half,   Add>(device const half*  , device const half*  , device half*  , uint);
template [[host_name("binary_mul_f32")]]
kernel void binary_op<float,  Mul>(device const float* , device const float* , device float* , uint);
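On the host side, something has to map an operation and dtype to the host_name strings given to the instantiations above. A minimal sketch of that lookup follows, assuming the binary_{op}_{dtype} naming pattern visible in those host_name attributes; the helper name and the dtype table are hypothetical and not taken from MLX.

```python
# Suffixes as they appear in the host_name attributes above
DTYPE_SUFFIX = {"float32": "f32", "float16": "f16"}

# Only the (op, dtype) pairs that were explicitly instantiated exist in the metallib
INSTANTIATED = {("add", "float32"), ("add", "float16"), ("mul", "float32")}


def binary_kernel_name(op: str, dtype: str) -> str:
    """Return the host-visible kernel name, or raise if it was never instantiated."""
    if (op, dtype) not in INSTANTIATED:
        raise KeyError(f"no instantiation for {op} on {dtype}")
    return f"binary_{op}_{DTYPE_SUFFIX[dtype]}"


print(binary_kernel_name("add", "float16"))  # binary_add_f16
```

The explicit set makes a practical point visible: a template that is never instantiated produces no kernel, so a missing combination fails at lookup time rather than at compile time.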

Vectorised Loads

Apple GPUs support 128-bit vector loads. Processing 4 floats in a single instruction (or 8 halfs as two half4 vectors) substantially increases memory throughput. This is critical for memory-bandwidth-bound kernels.

MSL — Vectorised float4 Load
kernel void add_f32_vec4(
    device const float4* a   [[buffer(0)]],
    device const float4* b   [[buffer(1)]],
    device       float4* out [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    // Each thread processes 4 floats in one 128-bit load
    out[id] = a[id] + b[id];   // SIMD vector arithmetic
}

// For half precision: 8 elements per thread. MSL vector types stop at
// 4 components (there is no half8), so each thread handles two half4 values.
kernel void add_f16_vec8(
    device const half4* a   [[buffer(0)]],
    device const half4* b   [[buffer(1)]],
    device       half4* out [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    uint i = id * 2;
    out[i]     = a[i]     + b[i];
    out[i + 1] = a[i + 1] + b[i + 1];
}

Matrix Operations & Tiled GEMM

General matrix multiplication (GEMM) is the backbone of every deep learning model. A naïve implementation that reads directly from device memory is bandwidth-limited and catastrophically slow. The solution is tiled GEMM: load sub-matrices into fast threadgroup memory, compute a tile of the output, and accumulate.

Memory hierarchy

Device memory bandwidth on M3 Max is 300 to 400 GB/s, depending on the configuration. Threadgroup memory sits on-chip and offers far higher bandwidth and lower latency within a threadgroup. A well-tiled GEMM can achieve on the order of 15–25× more effective throughput than the naïve version, depending on tile size and matrix shape, by maximising data reuse in threadgroup memory.

MSL — Tiled GEMM Kernel
constant uint TILE_SIZE = 16;

kernel void matmul_tiled(
    device const float* A       [[buffer(0)]],  // [M, K]
    device const float* B       [[buffer(1)]],  // [K, N]
    device       float* C       [[buffer(2)]],  // [M, N]
    constant uint& M           [[buffer(3)]],
    constant uint& N           [[buffer(4)]],
    constant uint& K           [[buffer(5)]],
    threadgroup float* tileA   [[threadgroup(0)]],  // TILE x TILE
    threadgroup float* tileB   [[threadgroup(1)]],  // TILE x TILE
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 lid  [[thread_position_in_threadgroup]]
) {
    uint row = tgid.y * TILE_SIZE + lid.y;
    uint col = tgid.x * TILE_SIZE + lid.x;

    float acc = 0.0f;

    // Iterate over tiles along the K dimension
    for (uint t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {

        // Cooperatively load a tile of A into threadgroup memory
        uint aRow = row;
        uint aCol = t * TILE_SIZE + lid.x;
        tileA[lid.y * TILE_SIZE + lid.x] =
            (aRow < M && aCol < K) ? A[aRow * K + aCol] : 0.0f;

        // Cooperatively load a tile of B
        uint bRow = t * TILE_SIZE + lid.y;
        uint bCol = col;
        tileB[lid.y * TILE_SIZE + lid.x] =
            (bRow < K && bCol < N) ? B[bRow * N + bCol] : 0.0f;

        // Synchronise: all threads must finish loading before compute
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Compute dot product of this tile — all reads from fast shmem
        for (uint k = 0; k < TILE_SIZE; ++k) {
            acc += tileA[lid.y * TILE_SIZE + k] *
                   tileB[k * TILE_SIZE + lid.x];
        }

        // Synchronise again before next tile load overwrites shmem
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (row < M && col < N) C[row * N + col] = acc;
}
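The benefit of tiling is easiest to see by counting device-memory reads. The sketch below is a CPU-side Python model, not a port of the kernel: dictionaries stand in for the two threadgroup tiles, and device_loads counts how many elements of A and B are fetched from "device memory". The function names are invented for this illustration.

```python
def matmul_naive(A, B, M, N, K):
    """Reference matmul: every multiply reads both operands from device memory."""
    C = [0.0] * (M * N)
    device_loads = 0
    for r in range(M):
        for c in range(N):
            for k in range(K):
                C[r * N + c] += A[r * K + k] * B[k * N + c]
                device_loads += 2
    return C, device_loads


def matmul_tiled(A, B, M, N, K, tile):
    """Tiled matmul: each tile element is read from device memory once, then reused."""
    C = [0.0] * (M * N)
    device_loads = 0
    for ty in range(0, M, tile):              # one "threadgroup" per output tile
        for tx in range(0, N, tile):
            rows = range(ty, min(ty + tile, M))
            cols = range(tx, min(tx + tile, N))
            for t in range(0, K, tile):       # walk along the K dimension
                ks = range(t, min(t + tile, K))
                tile_a = {(r, k): A[r * K + k] for r in rows for k in ks}
                tile_b = {(k, c): B[k * N + c] for k in ks for c in cols}
                device_loads += len(tile_a) + len(tile_b)
                for r in rows:
                    for c in cols:
                        C[r * N + c] += sum(tile_a[r, k] * tile_b[k, c] for k in ks)
    return C, device_loads


M = N = K = 4
A = [float(i) for i in range(M * K)]
B = [float(i % 3) for i in range(K * N)]
c_ref, loads_ref = matmul_naive(A, B, M, N, K)
c_tiled, loads_tiled = matmul_tiled(A, B, M, N, K, tile=2)
print(c_ref == c_tiled, loads_ref, loads_tiled)
```

For square matrices that divide evenly into tiles, the ratio of naive to tiled loads equals the tile size, which is the data-reuse factor the tiled kernel is built around.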

SIMD Matrix Operations (simdgroup_matrix)

For even higher throughput, MSL exposes hardware matrix multiply-accumulate (MMA) units via simdgroup_matrix. These are Apple's equivalent of CUDA's tensor cores — purpose-built for small matrix FMAs that run at much higher throughput than scalar ALU chains.

MSL — simdgroup_matrix MMA
// Using Apple's simdgroup_matrix for tensor-core-equivalent throughput
// Operates on 8x8 tiles; each SIMD group (32 threads) handles one tile.
// Assumes M, N and K are multiples of 8 and one SIMD group per threadgroup.
kernel void matmul_simdgroup(
    device const half*  A   [[buffer(0)]],  // [M, K]
    device const half*  B   [[buffer(1)]],  // [K, N]
    device       float* C   [[buffer(2)]],  // [M, N]
    constant uint& M       [[buffer(3)]],
    constant uint& N       [[buffer(4)]],
    constant uint& K       [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]]
) {
    // Accumulator in float32 (mixed precision), zero-initialised
    simdgroup_matrix<float, 8, 8> acc =
        make_filled_simdgroup_matrix<float, 8, 8>(0.0f);

    simdgroup_matrix<half, 8, 8> matA;
    simdgroup_matrix<half, 8, 8> matB;

    for (uint k = 0; k < K; k += 8) {
        // Load 8x8 tiles from device memory directly.
        // The third argument is the row stride in elements:
        // K for A [M, K] and N for B [K, N]. Origins are (column, row).
        simdgroup_load(matA, A, K, ulong2(k, tgid.y * 8));
        simdgroup_load(matB, B, N, ulong2(tgid.x * 8, k));

        // Hardware MMA: acc = matA * matB + acc
        simdgroup_multiply_accumulate(acc, matA, matB, acc);
    }

    // Store 8x8 result tile
    simdgroup_store(acc, C, N, ulong2(tgid.x * 8, tgid.y * 8));
}

Fused Attention Kernels

Flash Attention — the technique of computing multi-head attention without materialising the full N×N attention score matrix in device memory — is the most important fused kernel in modern LLM inference. Here is how you implement it in MSL for Apple GPUs.

The key insight: by processing the sequence in blocks and accumulating a running softmax (tracking the max and sum incrementally), we only ever need to hold one block of Q, K, V in threadgroup memory at a time. Memory usage drops from O(N²) to O(N).
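The running-softmax update is the part that is easiest to get wrong, so here is a CPU-side Python check of it for a single query vector. It is a sketch of the technique only, with invented function names: attention_online processes keys and values block by block while carrying a running max m, a running denominator s and an output accumulator o, and attention_direct materialises all scores first.

```python
import math


def attention_direct(q, keys, values, scale):
    """Reference: compute every score, then a normal softmax-weighted sum."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def attention_online(q, keys, values, scale, block=2):
    """Same result, but only one block of K/V is looked at per step."""
    m = -math.inf                    # running max of the logits
    s = 0.0                          # running softmax denominator
    o = [0.0] * len(values[0])       # running, unnormalised output
    for start in range(0, len(keys), block):
        for k, v in zip(keys[start:start + block], values[start:start + block]):
            score = scale * sum(a * b for a, b in zip(q, k))
            m_new = max(m, score)
            correction = math.exp(m - m_new)   # rescales everything seen so far
            p = math.exp(score - m_new)
            s = correction * s + p
            o = [correction * od + p * vd for od, vd in zip(o, v)]
            m = m_new
    return [od / s for od in o]


q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -1.0], [2.0, 0.1]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
scale = 1.0 / math.sqrt(len(q))
a = attention_direct(q, keys, values, scale)
b = attention_online(q, keys, values, scale)
print(all(math.isclose(x, y, rel_tol=1e-12) for x, y in zip(a, b)))
```

Note that the output accumulator is weighted by the value rows, not the key rows; the keys only contribute to the scores.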

MSL — Flash Attention (Simplified)
// Flash Attention forward pass for one (batch, head) slice, tiled over the sequence
constant uint BLOCK_SIZE = 64;   // KV block size (tune for your hardware)

kernel void flash_attention(
    device const half* Q        [[buffer(0)]],  // [N, D] slice of [B, H, N, D]
    device const half* K        [[buffer(1)]],  // [N, D]
    device const half* V        [[buffer(2)]],  // [N, D]
    device       half* Out      [[buffer(3)]],  // [N, D]
    constant uint&  N           [[buffer(4)]],  // seq length
    constant uint&  D           [[buffer(5)]],  // head dim, must be <= 64
    constant float& scale       [[buffer(6)]],  // 1/sqrt(D)
    threadgroup half*  k_tile   [[threadgroup(0)]],  // BLOCK_SIZE * D halfs
    threadgroup half*  v_tile   [[threadgroup(1)]],  // BLOCK_SIZE * D halfs
    uint q_idx   [[thread_position_in_grid]],
    uint lid     [[thread_position_in_threadgroup]],
    uint tg_size [[threads_per_threadgroup]]
) {
    // No early return: every thread must reach the barriers below
    bool active = q_idx < N;

    // Load query vector for this thread's position (kept in registers)
    float q[64];
    for (uint d = 0; d < D; ++d)
        q[d] = active ? (float)Q[q_idx * D + d] : 0.0f;

    // Online softmax accumulators (Dao et al. 2022 trick)
    float m = -INFINITY;  // running max of attention logits
    float s = 0.0f;       // running denominator
    float o[64] = {0};    // running output accumulator

    // Iterate over KV blocks
    for (uint kv_start = 0; kv_start < N; kv_start += BLOCK_SIZE) {
        uint kv_len = min(kv_start + BLOCK_SIZE, N) - kv_start;

        // Load this K and V block into threadgroup memory cooperatively
        for (uint i = lid; i < kv_len * D; i += tg_size) {
            k_tile[i] = K[kv_start * D + i];
            v_tile[i] = V[kv_start * D + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint j = 0; j < kv_len; ++j) {
            // Q·K^T for one key in this block
            float score = 0.0f;
            for (uint d = 0; d < D; ++d)
                score += q[d] * (float)k_tile[j * D + d];
            score *= scale;

            // Online softmax update, weighting the matching V row
            float m_new = max(m, score);
            float exp_s = exp(score - m_new);
            float correction = exp(m - m_new);
            s = correction * s + exp_s;
            for (uint d = 0; d < D; ++d)
                o[d] = correction * o[d] + exp_s * (float)v_tile[j * D + d];
            m = m_new;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Normalise and write output
    if (active) {
        for (uint d = 0; d < D; ++d)
            Out[q_idx * D + d] = (half)(o[d] / s);
    }
}
MLX implementation

MLX implements mx.fast.scaled_dot_product_attention() around the same online-softmax idea. The actual MLX kernel adds support for GQA (grouped-query attention), masking, and BF16, and uses simdgroup_matrix operations for the inner GEMM, but the online softmax structure is the same.

C++ Runtime & Custom Ops

If you want to expose a custom Metal kernel to Python via MLX, you need to write the C++ glue: a Primitive subclass that tells the framework how to evaluate your op on the CPU and the GPU and how to differentiate through it.

C++ — Custom MLX Primitive
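// custom_op.h
// Simplified sketch: in the MLX sources these types live in namespace
// mlx::core (shortened to mlx:: here), and Primitive declares further
// virtual methods that are omitted for brevity.
#include <mlx/mlx.h>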

class ScaledELU : public mlx::Primitive {
public:
  explicit ScaledELU(float alpha, mlx::Stream s)
      : mlx::Primitive(s), alpha_(alpha) {}

  void eval_cpu(
      const std::vector<mlx::array>& inputs,
      std::vector<mlx::array>& outputs
  ) override;

  void eval_gpu(
      const std::vector<mlx::array>& inputs,
      std::vector<mlx::array>& outputs
  ) override;

  // VJP for autograd support
  std::vector<mlx::array> vjp(
      const std::vector<mlx::array>& primals,
      const std::vector<mlx::array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<mlx::array>& outputs
  ) override;

private:
  float alpha_;
};

// The Python-visible entry point
mlx::array scaled_elu(mlx::array x, float alpha);

// custom_op.cpp — GPU evaluation
void ScaledELU::eval_gpu(
    const std::vector<mlx::array>& inputs,
    std::vector<mlx::array>& outputs
) {
    auto& x   = inputs[0];
    auto& out = outputs[0];
    out.set_data(mlx::allocator().malloc_or_wait(out.nbytes()));

    // Get the Metal device wrapper and compute encoder for this op's stream
    auto& d = metal::device(stream().device);
    auto kernel = d.get_kernel("scaled_elu_f32");  // cached PSO
    auto encoder = d.get_command_encoder(stream().index);

    encoder->setComputePipelineState(kernel);
    encoder->setBuffer(x.buffer(),   x.data_offset() * x.itemsize(),   0);
    encoder->setBuffer(out.buffer(), out.data_offset() * out.itemsize(), 1);
    encoder->setBytes(&alpha_, sizeof(float), 2);

    size_t n = x.size();
    // The threadgroup size must not exceed what this pipeline supports
    size_t tg = std::min(n, (size_t)kernel->maxTotalThreadsPerThreadgroup());
    encoder->dispatchThreads(
        MTL::Size(n, 1, 1),
        MTL::Size(tg, 1, 1)
    );
}
MSL — scaled_elu.metal
kernel void scaled_elu_f32(
    device const float* x      [[buffer(0)]],
    device       float* out    [[buffer(1)]],
    constant float& alpha     [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    float val = x[id];
    out[id] = (val > 0.0f) ? val : alpha * (exp(val) - 1.0f);
}

// Also instantiate for half precision
kernel void scaled_elu_f16(
    device const half* x     [[buffer(0)]],
    device       half* out   [[buffer(1)]],
    constant float& alpha    [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    float val = (float)x[id];
    out[id]   = (half)((val > 0.0f) ? val : alpha * (exp(val) - 1.0f));
}
Python — Using the Custom Op
import mlx.core as mx
from my_custom_ops import scaled_elu  # pybind11 module

x = mx.array([-2.0, -1.0, 0.0, 1.0, 2.0])

# Runs your Metal kernel transparently
y = scaled_elu(x, alpha=1.0)
mx.eval(y)
print(y)   # [-0.865, -0.632, 0.0, 1.0, 2.0]

# Autograd works via the vjp() you implemented
grad_fn = mx.grad(lambda x: mx.sum(scaled_elu(x, alpha=1.0)))
dx = grad_fn(x)
mx.eval(dx)
print(dx)  # [0.135, 0.368, 1.0, 1.0, 1.0]
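Before trusting a hand-written vjp(), it helps to have a CPU reference for both the forward value and the derivative. The sketch below is plain Python, independent of MLX, and the function names are chosen for this illustration. It follows the same branch as the MSL kernel (val > 0), so the derivative at exactly 0 comes from the exponential branch.

```python
import math


def scaled_elu_ref(x, alpha=1.0):
    """Forward pass, same branch structure as the Metal kernel."""
    return x if x > 0.0 else alpha * (math.exp(x) - 1.0)


def scaled_elu_grad_ref(x, alpha=1.0):
    """Derivative with respect to x; a VJP multiplies this by the cotangent."""
    return 1.0 if x > 0.0 else alpha * math.exp(x)


def finite_difference(f, x, h=1e-6):
    """Central difference, for checking the analytic derivative."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
print([round(scaled_elu_ref(x), 3) for x in xs])       # [-0.865, -0.632, 0.0, 1.0, 2.0]
print([round(scaled_elu_grad_ref(x), 3) for x in xs])  # [0.135, 0.368, 1.0, 1.0, 1.0]
```

Comparing scaled_elu_grad_ref against finite_difference at a few points away from 0 is a cheap way to catch sign or scaling mistakes in the derivative.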

Performance Tuning

Getting correct output from your kernel is step one. Getting maximum performance on Apple Silicon requires understanding the hardware bottlenecks and tuning accordingly.

Occupancy & Threadgroup Size

Occupancy is the ratio of active threadgroups to the maximum a GPU core can hold simultaneously. Higher occupancy lets the GPU hide memory latency by switching between ready threadgroups. The right threadgroup size balances register pressure against the desire for more concurrent threadgroups.
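The definition above can be turned into a back-of-the-envelope calculator. This is a rough sketch under loud assumptions: the two per-core limits below are HYPOTHETICAL placeholder values chosen to make the arithmetic visible, not published Apple figures, and register pressure (which also limits occupancy) is not modelled at all.

```python
def resident_threadgroups(threads_per_tg, tg_memory_bytes,
                          core_thread_limit, core_tg_memory_bytes):
    """Threadgroups that fit on one GPU core under a thread and a memory limit."""
    by_threads = core_thread_limit // threads_per_tg
    if tg_memory_bytes == 0:
        return by_threads
    by_memory = core_tg_memory_bytes // tg_memory_bytes
    return min(by_threads, by_memory)


def occupancy(threads_per_tg, tg_memory_bytes,
              core_thread_limit, core_tg_memory_bytes):
    """Active threadgroups divided by the maximum the core could hold."""
    ceiling = core_thread_limit // threads_per_tg
    active = resident_threadgroups(threads_per_tg, tg_memory_bytes,
                                   core_thread_limit, core_tg_memory_bytes)
    return active / ceiling


# HYPOTHETICAL per-core limits, for illustration only
CORE_THREAD_LIMIT = 2048
CORE_TG_MEMORY = 32 * 1024

# 256-thread groups that each use 16 KB of threadgroup memory
print(occupancy(256, 16 * 1024, CORE_THREAD_LIMIT, CORE_TG_MEMORY))  # 0.25
# Same groups using only 4 KB each
print(occupancy(256, 4 * 1024, CORE_THREAD_LIMIT, CORE_TG_MEMORY))   # 1.0
```

The pattern it shows holds regardless of the exact limits: a kernel that is generous with threadgroup memory can cap occupancy well before the thread limit does.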

Memory Access Patterns

MSL — Coalesced vs Strided Access
// BAD: Strided access — threads in a SIMD group read non-contiguous addresses
// Thread 0 reads [0], Thread 1 reads [N], Thread 2 reads [2*N], ...
// This generates up to 32 separate cache line loads (one per lane) for one SIMD group
kernel void column_access_bad(
    device const float* mat [[buffer(0)]],
    constant uint& N       [[buffer(1)]],
    uint id [[thread_position_in_grid]]
) {
    float val = mat[id * N];  // Column-major — terrible for row-major storage!
}

// GOOD: Coalesced access — threads read consecutive addresses
// Thread 0 reads [0], Thread 1 reads [1], Thread 2 reads [2], ...
// All 32 threads in a SIMD group satisfied by 1-2 cache line loads
kernel void row_access_good(
    device const float* mat [[buffer(0)]],
    uint id [[thread_position_in_grid]]
) {
    float val = mat[id];  // Sequential — one cache line serves 8+ threads
}

// FIX strided access: transpose into threadgroup memory first
kernel void column_via_transpose(
    device const float* mat      [[buffer(0)]],
    device       float* out      [[buffer(1)]],
    constant uint& N             [[buffer(2)]],
    uint2 lid  [[thread_position_in_threadgroup]],
    uint2 tgid [[threadgroup_position_in_grid]]
) {
    // Statically sized threadgroup array; the +1 column avoids bank conflicts.
    // Assumes a square N x N matrix, N a multiple of 32, and 32x32 threadgroups.
    threadgroup float tile[32][33];
    // Load row of matrix into tile (coalesced)
    tile[lid.y][lid.x] = mat[(tgid.y * 32 + lid.y) * N + tgid.x * 32 + lid.x];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Write column of tile to output (coalesced after transpose)
    out[(tgid.x * 32 + lid.y) * N + tgid.y * 32 + lid.x] = tile[lid.x][lid.y];
}
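The cost of strided access can be estimated by counting how many distinct cache lines one SIMD group touches. The sketch below is a simple Python model; the 128-byte line size is an assumption made for the illustration (it is not stated in this guide), and the function name is invented.

```python
SIMD_WIDTH = 32  # lanes per SIMD group


def cache_lines_touched(element_indices, element_bytes=4, line_bytes=128):
    """Distinct cache lines needed to serve one load per lane.

    line_bytes=128 is an assumed cache line size for illustration.
    """
    return len({(i * element_bytes) // line_bytes for i in element_indices})


N = 1024  # row length of a row-major float matrix

# column_access_bad: lane i reads mat[i * N]
strided = [lane * N for lane in range(SIMD_WIDTH)]
# row_access_good: lane i reads mat[i]
contiguous = list(range(SIMD_WIDTH))

print(cache_lines_touched(strided))     # 32
print(cache_lines_touched(contiguous))  # 1
```

Under these assumptions the strided pattern moves 32 full lines to deliver 32 floats, so most of the bandwidth spent is wasted, while the contiguous pattern uses every byte it fetches.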

Instruction-Level Tricks

MSL — Optimisation Attributes
// Explicit FMA and a fully unrolled loop
kernel void optimised_dot(
    device const half4* a [[buffer(0)]],
    device const half4* b [[buffer(1)]],
    device       float* c [[buffer(2)]],
    constant uint& len   [[buffer(3)]],   // element count, a multiple of 8
    uint id [[thread_position_in_grid]]
) {
    uint base = id * 2;   // 8 halfs per thread (two half4 loads)
    float acc = 0.0f;

    if (id * 8 < len) {
        // Fully unrolled: both half4 pairs stay in registers
        #pragma clang loop unroll(full)
        for (uint i = 0; i < 2; ++i) {
            half4 av = a[base + i];
            half4 bv = b[base + i];
            // Four scalar FMAs, accumulating in float
            acc = fma((float)av.x, (float)bv.x, acc);
            acc = fma((float)av.y, (float)bv.y, acc);
            acc = fma((float)av.z, (float)bv.z, acc);
            acc = fma((float)av.w, (float)bv.w, acc);
        }
    }
    // Reduction within the SIMD group: every lane receives its group's total
    c[id] = simd_sum(acc);
}

Profiling with Metal Debugger

Xcode's Metal Debugger and the GPU timeline in Instruments are the primary tools for profiling kernel performance. The key metrics to watch follow from the sections above: occupancy, memory bandwidth, and ALU utilisation.

The MLX Engineer's Workflow

Putting it all together: what does a typical day look like for someone writing MLX kernels? The work is distributed across three layers, and understanding the ratio tells you where to focus your energy.

C++ ~55% · MSL ~35% · Python ~10%

The Deep MLX Pipeline

The most intellectually interesting work in MLX sits in the graph optimisation and kernel fusion passes — the equivalent of what Triton's autotuner and torch.compile's Inductor do on the NVIDIA side:

MLX Graph → Fusion Pass → Kernel Gen → Metal Runtime → Apple GPU

Example: Writing a Custom Fused Kernel

Suppose you profile a Llama-style model and discover that two adjacent operations — a silu activation applied to a gate, then an elementwise multiplication with the up-projection — account for disproportionate time due to two separate kernel launches and the associated device memory round-trips. You can fuse them:

MSL — Fused SiLU Gate (Llama MLP)
// Fuses: out = silu(gate) * up into one kernel, one device memory read per array.
// Assumes the element count is a multiple of 4 and the buffers are 8-byte aligned.
kernel void silu_gate_fused(
    device const half* gate [[buffer(0)]],   // W_gate * x
    device const half* up   [[buffer(1)]],   // W_up * x
    device       half* out  [[buffer(2)]],
    uint id [[thread_position_in_grid]]         // each thread processes 4 elements
) {
    uint base = id * 4;

    // Load 4 halfs using a single 64-bit load
    half4 g = *(device const half4*)(gate + base);
    half4 u = *(device const half4*)(up   + base);

    // SiLU: x * sigmoid(x) = x / (1 + exp(-x)), computed in float
    float4 gf = (float4)g;
    float4 uf = (float4)u;
    float4 silu_g = gf / (1.0f + exp(-gf));

    // Fused gate multiply
    *(device half4*)(out + base) = (half4)(silu_g * uf);
}
Python — Using the Fused Op in a Model
import mlx.core as mx
import mlx.nn as nn
from my_kernels import silu_gate_fused

class LlamaMLP(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.up_proj   = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

    def __call__(self, x):
        gate = self.gate_proj(x)
        up   = self.up_proj(x)
        # One fused Metal kernel instead of two separate launches
        hidden = silu_gate_fused(gate, up)
        return self.down_proj(hidden)
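A fused kernel like this is worth checking against the unfused computation, and the saving is easy to quantify. The sketch below is a plain-Python reference that follows the prose description (SiLU applied to the gate, then multiplied by the up-projection); the function names and the traffic model are made up for this illustration and count array elements moved to or from device memory.

```python
import math


def silu(x):
    """SiLU: x * sigmoid(x) = x / (1 + exp(-x))."""
    return x / (1.0 + math.exp(-x))


def silu_gate_unfused(gate, up):
    """Two 'kernels': the first writes an intermediate array, the second reads it."""
    activated = [silu(g) for g in gate]
    return [a * u for a, u in zip(activated, up)]


def silu_gate_fused_ref(gate, up):
    """One 'kernel': the intermediate never leaves registers."""
    return [silu(g) * u for g, u in zip(gate, up)]


def device_traffic(n, fused):
    """Elements read from or written to device memory for n-element inputs."""
    if fused:
        return 3 * n   # read gate, read up, write out
    return 5 * n       # read gate, write tmp, read tmp, read up, write out


gate = [-2.0, -0.5, 0.0, 0.5, 2.0]
up = [1.0, 2.0, 3.0, 4.0, 5.0]
print(silu_gate_unfused(gate, up) == silu_gate_fused_ref(gate, up))
print(device_traffic(len(gate), fused=False), device_traffic(len(gate), fused=True))
```

For a bandwidth-bound elementwise op, going from 5n to 3n elements of traffic is the main source of the speedup, on top of saving one kernel launch.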

The bottom line: Writing GPU kernels for Apple Silicon means mastering Metal Shading Language as your primary compute shading language, C++ as your runtime glue, and understanding Apple's unified memory model as your performance superpower. The absence of PCIe transfers, combined with high memory bandwidth shared by CPU and GPU, creates opportunities for kernel designs that are awkward or impractical on discrete NVIDIA GPUs.

SIMD-group intrinsics let you do warp-level reductions without touching shared memory. Unified memory lets you build zero-copy data pipelines between CPU preprocessing and GPU inference. And the tile-based GPU architecture rewards blocking patterns that keep working sets on-chip.

As LLM inference increasingly moves to personal devices and Apple's silicon continues to close the gap with discrete GPUs, MLX and MSL kernels are becoming essential knowledge for any serious ML systems engineer working in the Apple ecosystem. The conceptual skills transfer cleanly from CUDA — only the vocabulary changes.

Written with precision for engineers crossing the CUDA → Metal divide.

MLX · Metal Shading Language · Apple M-Series · Unified Memory Architecture