MLX doesn't natively support pipeline parallel training: in a few places, the design assumptions of MLX go against what is required by pipeline parallel backprop.

Apple Silicon

Apple Silicon devices have recently started to be used to run AI models locally. The hardware profile of the Mac (specifically, a relatively high memory capacity and memory bandwidth) lends itself well to LLM inference.

MLX is Apple's ML framework optimized for Apple silicon with native understanding of the unified memory model, and support for communication over Thunderbolt. Last year, I was working as a research scientist at EXO Labs, who are building infrastructure to make local inference clusters painless.

LoRA

In addition to holding the model parameters in memory, training also requires storing gradients and optimizer state, adding further memory overhead (with Adam, expect roughly a 4x memory overhead for training versus inference). LoRA reduces the memory footprint by cutting the number of trainable parameters by multiple orders of magnitude. In our case with DeepSeek, we reduce to ~150M trainable parameters versus 671B params for the base model.
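
As a rough back-of-the-envelope illustration (plain Python; assuming fp32 weights, gradients, and Adam moments, which is where the 4x figure comes from - the actual deployment stores weights in lower precision, so these numbers are illustrative only):

```python
FP32 = 4  # bytes per parameter

def train_state_bytes(n_total, n_trainable):
    """Weights for every parameter, plus gradient + Adam moments (m, v)
    for the trainable subset only -- all fp32 for simplicity."""
    return n_total * FP32 + n_trainable * 3 * FP32

# A full finetune needs 4x the inference footprint (weight + grad + m + v):
assert train_state_bytes(1, 1) == 4 * train_state_bytes(1, 0)

full = train_state_bytes(671e9, 671e9)  # full finetune of a 671B model
lora = train_state_bytes(671e9, 150e6)  # LoRA: ~150M trainable params
print(f"full: {full / 1e12:.1f} TB, LoRA: {lora / 1e12:.2f} TB")
```

The point is that with LoRA, the gradient and optimizer-state terms become negligible next to the weights themselves.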

Pipeline Parallelism

Models are often still too large to fit on a single device, therefore requiring a form of model-parallelism to train - splitting the parameters so each device stores a fraction of the full model.

Model parallelism refers to parallelism where only a fraction of the model is stored on each device, compared with data parallelism where each device stores the entire model. Model parallelism reduces the per-device memory usage. We want to be able to finetune a large model such as DeepSeek (671GB), requiring multiple devices with a total memory capacity over 671GB.

We'll be using pipeline parallelism: exploiting the sequential nature of language models to split up the model's layers into multiple stages $g_i$. Each device holds one stage, with each stage executed sequentially. Applying all stages in order recovers the original model function:

$$ \begin{align*} y_0 &= g_0(x) \\ y_i &= g_i(y_{i-1}), \quad i=1, ..., p-1 \\ \implies y = y_{p-1} = m(x) &= g_{p-1}(g_{p-2}(\ldots g_1(g_0(x))\ldots)) \end{align*} $$
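
This stage decomposition can be sketched in a few lines of plain Python (a toy chain of scalar "layers", purely illustrative):

```python
# A toy "model": a chain of 8 simple layers.
layers = [lambda x, k=k: 2 * x + k for k in range(8)]

def full_model(x):
    for layer in layers:
        x = layer(x)
    return x

def make_stage(stage_layers):
    # A stage g_i applies its contiguous slice of layers in order.
    def g(y):
        for layer in stage_layers:
            y = layer(y)
        return y
    return g

# Split the layers into p contiguous stages g_0 .. g_{p-1}.
p = 4
chunk = len(layers) // p
stages = [make_stage(layers[i * chunk:(i + 1) * chunk]) for i in range(p)]

# Applying all stages in order recovers the original model function.
y = 3.0
for g in stages:  # y_i = g_i(y_{i-1})
    y = g(y)
assert y == full_model(3.0)
```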

During inference, only one device is active at a time for any given request. Since the devices always execute sequentially, multiple requests can easily be overlapped.

Overlapping inference requests
Figure 1: Overlapping 3 inference requests executed on 4 machines.
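
The overlap in figure 1 can be reproduced with a tiny schedule model (plain Python, unit-time stages): request $r$ runs on stage $s$ once it has left stage $s-1$ and request $r-1$ has left stage $s$.

```python
def schedule(n_requests, n_stages):
    """Return the time step at which request r executes on stage s."""
    t = {}
    for r in range(n_requests):
        for s in range(n_stages):
            after_prev_stage = t.get((r, s - 1), -1)    # r must finish stage s-1
            after_prev_request = t.get((r - 1, s), -1)  # r-1 must free stage s
            t[(r, s)] = max(after_prev_stage, after_prev_request) + 1
    return t

t = schedule(3, 4)
# Overlapped, the last request finishes at step 5; run back to back, it
# would finish at step 11 (3 requests x 4 stages, minus one for 0-indexing).
assert t[(2, 3)] == 5
```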

Pipeline Parallel Training

When running the model at inference time, data only flows forwards through the model. During training, however, a backwards pass (the backpropagation step) is used to compute gradients. This passes data through the model in the opposite direction, from the output back to the input, so the order in which devices execute is reversed. This makes the scheduling of multiple batches through the model much more complicated. A number of scheduling algorithms have been designed to handle this complexity, such as GPipe, 1F1B, and Zero Bubble Pipeline Parallelism.

Different scheduling algorithms
Figure 2: Different scheduling algorithms for pipeline parallelism. Diagram credits: Xinyi Wan

This challenge of scheduling batches is not the focus of this post. The schedule used here is very simple, with only a single batch processed at a time, as shown in the diagram below. As we'll see, there is plenty of complexity in performing the backwards pass itself using MLX in a pipeline parallel setting (pipeline parallelism is not natively supported by MLX). We instead need to roll our own per-stage backwards-pass logic to perform the distributed backpropagation.

Naive pipeline schedule
Figure 3: A naive pipeline schedule is adopted for this post.

Gradients in Pipeline Parallel

As mentioned before, the model is split into $p$ stages executed across $p$ devices:

$$y_i = g_i(y_{i-1})$$

In the forwards pass, device $i$ receives the output $y_{i-1}$ from the previous device, and performs $y_i = g_i(y_{i-1})$. The result $y_i$ is then sent to the next device $i+1$.

The model $m$ is parameterized by the weight set $\theta$ - the goal of the backwards pass is to calculate gradients for the weights $\frac{\partial L}{\partial \theta}$. Stage $i$ is parameterized by a subset of $\theta$, $g_i = g_i(\cdot; \theta_i)$. At the start of stage $i$'s backwards pass, it receives intermediate gradients $dy_i = \frac{\partial L}{\partial y_i}$ from rank $i+1$. This is used to calculate two quantities: (1) the gradients of parameters $\frac{\partial L}{\partial \theta_i}$, used by the optimizer and (2) intermediate gradients at the pipeline boundary $dy_{i-1}$ that will be passed to rank $i-1$.

$$ \begin{align} \frac{\partial L}{\partial \theta_i} &= \frac{\partial L}{\partial y_i} \cdot \frac{\partial g_i(y_{i-1};\theta_i)}{\partial \theta_i} \tag{1} \\ \frac{\partial L}{\partial y_{i-1}} &= \frac{\partial L}{\partial y_i} \cdot \frac{\partial g_i(y_{i-1};\theta_i)}{\partial y_{i-1}} \tag{2} \end{align} $$

The details of differentiating the layer itself are hidden inside the $\partial g_i$ terms, which the ML framework handles for us. Our attention is required at the boundaries between pipeline stages: how partial gradients received from other devices are fed into the backpropagation algorithm.
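
To make equations (1) and (2) concrete, here is a worked scalar example (plain Python, not MLX) with a linear stage $g_1(y) = \theta_1 y$ and a downstream loss $L = \tfrac{1}{2} y_1^2$:

```python
# One intermediate stage: y1 = theta1 * y0, with loss L = 0.5 * y1**2 downstream.
theta1, y0 = 3.0, 2.0
y1 = theta1 * y0

dy1 = y1            # received from the next rank: dL/dy1 = y1 for this loss
dtheta1 = dy1 * y0  # eq. (1): dL/dtheta1 = dL/dy1 * d(theta1*y0)/dtheta1
dy0 = dy1 * theta1  # eq. (2): dL/dy0    = dL/dy1 * d(theta1*y0)/dy0, sent to rank i-1

# Finite-difference check that dy0 really is dL/dy0:
L = lambda y: 0.5 * (theta1 * y) ** 2
eps = 1e-6
assert abs((L(y0 + eps) - L(y0 - eps)) / (2 * eps) - dy0) < 1e-4
```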

Computation stage information flow
Figure 4: Information flow for a single computation stage

We now know what computations and communications are needed to perform pipeline parallel training. So, how do we do this with MLX?

value_and_grad

MLX's execution model is different from PyTorch's. MLX is lazily executed: as operations are called, MLX builds a graph of intended computations without executing anything. The eval() function is then used to tell MLX to perform the computation. This gives MLX a full picture of all operations to execute, making kernel scheduling and fusion more efficient.

To perform backprop, the MLX core library provides the value_and_grad higher-order function. If $f$ is a function that performs the model forwards pass and returns a scalar loss, then:

$$ \text{value\_and\_grad} : (f : (m, X, Y) \mapsto L) \mapsto \left(df: (m, X, Y) \mapsto \left(L, \frac{\partial L}{\partial \theta}\right)\right) $$

In other words, value_and_grad takes a computational graph that computes the forwards and loss and builds from it a graph computing the forwards, loss, and backwards pass. A typical usage of value_and_grad would look as follows:

def loss_fn(model, X, y):
    yhat = model(X)
    return mlx.nn.losses.cross_entropy(yhat, y).mean()

# loss_and_grad is a function for the combined forwards & backwards pass for a given batch
loss_and_grad = mlx.nn.value_and_grad(model, loss_fn)

# build computational graph
loss, grads = loss_and_grad(model, X, y)

mx.eval(loss) # forwards pass
mx.eval(grads) # backwards pass

# or, for efficiency, both forwards & backwards can be performed at the same time
# mx.eval(loss, grads)

Our challenge now is to perform the backwards pass pipeline parallel - involving changing the value_and_grad to handle gradients at the boundaries between devices.

Backpropagating from partial derivatives $dy_i$

The model loss only exists on the final pipeline stage, as this is the only rank that realizes the final output of the model. Backprop on the final rank is simple: starting from the final loss, value_and_grad can be used directly. But the other pipeline stages don't have a loss - instead they receive a partial gradient $dy_i = \frac {\partial L} {\partial y_i}$, and must use this as the starting point for backpropagation.

It's fairly straightforward how this would look in PyTorch. Gradients are stored in the .grad property of a tensor. To use the $dy_i$ values for backprop, set the gradient of $y_i$ and start backprop from there. PyTorch exposes this via .backward(dy), so a single stage would look as follows:

# y_i: output of this stage
# x:   input tensor for this stage (requires_grad=True)

dy_i = torch.empty_like(y_i)
dist.recv(dy_i, src=i+1)     # get dy_i from next stage

y_i.backward(dy_i)           # backprop locally
dy_prev = x.grad             # gradient w.r.t input

dist.send(dy_prev, dst=i-1)  # send dy_{i-1} to previous stage

However, given that value_and_grad is the idiomatic way to compute a backwards pass in MLX, it's non-obvious how the backwards pass can be performed starting from the derivatives $dy_i$ received from stage $i+1$.

The Solution: Auxiliary Loss Function

Given that MLX requires a scalar loss as a starting point to compute gradients with value_and_grad, can a loss function be constructed with the desired properties? When differentiating backwards from this auxiliary loss $\hat L$, the gradient of $y_i$ (already computed in the forwards pass) must be $dy_i$ (the partial derivatives received from rank $i+1$). This requirement pins down the correct form for the auxiliary loss:

$$ \begin{align*} \frac{\partial \hat L}{\partial y_i} &= dy_i = \frac{\partial L}{\partial y_i} \\ \implies \hat L &= y_i \cdot dy_i \end{align*} $$

So what is this saying? Backpropagating from an auxiliary loss given by the dot product of $y_i$ and $dy_i$ has the effect of setting the correct gradient for $y_i$. Of course, the backprop algorithm won't stop at $y_i$ - it continues to use the $y_i$ gradients to exhaustively compute gradients for the entire computational graph - in this case, the whole pipeline stage. This includes the gradient of the stage input $x_i = y_{i-1}$. This gradient $dy_{i-1}$ is then passed to the preceding rank.
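
A quick numeric sanity check (plain Python) that the gradient of $\hat L = \sum_j y_j \, dy_j$ with respect to each $y_j$ is exactly $dy_j$:

```python
# y: this stage's output (flattened); dy: partials received from rank i+1.
y  = [1.0, -2.0, 0.5]
dy = [0.3, 4.0, -1.0]

aux_loss = lambda vec: sum(a * b for a, b in zip(vec, dy))

# The gradient of aux_loss w.r.t. each y_j should be exactly dy_j:
eps = 1e-6
for j in range(len(y)):
    bumped = list(y)
    bumped[j] += eps
    grad_j = (aux_loss(bumped) - aux_loss(y)) / eps
    assert abs(grad_j - dy[j]) < 1e-6
```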

Capturing intermediate gradient $dy_{i-1}$

nn.value_and_grad computes gradients with respect to the model parameters $\frac{\partial L}{\partial \theta}$. This should suffice - all we need is the parameter gradients in order to perform an optimizer step, right? Actually, for pipeline parallel training we also need the gradients w.r.t. the stage input, $dy_{i-1} = \frac{\partial L}{\partial y_{i-1}}$, in order to pass them to rank $i-1$.

nn.value_and_grad therefore isn't appropriate; it only knows how to compute gradients w.r.t. model parameters. However, nn.value_and_grad is actually a thin wrapper around mlx.core.value_and_grad:

$$ \begin{gathered} \text{mlx.core.value\_and\_grad} : \\[4pt] \left( f : \left(X_1, \ldots, X_n\right) \mapsto L \right) \times \left(N \subseteq \{1, \ldots, n\}\right) \\[4pt] \mapsto \left(df: \left(X_1, \ldots, X_n\right) \mapsto \left(L, \left\{ \frac{\partial L}{\partial X_i} : i \in N \right\} \right)\right) \end{gathered} $$

In other words, mlx.core.value_and_grad allows differentiation w.r.t. arbitrary inputs to the function - it allows us to select which inputs require a gradient to be computed. One of these inputs has to be the model parameters (for parameter gradients), and we can also choose to differentiate w.r.t. the $X$ value. In the below example, argnums=(0,1) is used to select differentiation w.r.t. the first two inputs: model parameters and x.

def local_loss(params, x, yhat):
    model.update(params)
    y = model(x)
    return mlx.nn.losses.cross_entropy(y, yhat).mean()

loss, (dtheta, dx) = mx.value_and_grad(local_loss, argnums=(0, 1))(
    model.trainable_parameters(), x, yhat
)

Putting it together: backprop for intermediate pipeline stage

We combine the injection of gradients $dy_i$ from the next stage with the calculation of input gradients $dy_{i-1}$ in order to build the full pipeline stage. This code makes the simplifying assumption that we are not the first or last rank: $i \neq 0, p-1$.

dy_loc = mx.distributed.recv_like(y, src=dist.rank + 1) # receive partial derivatives from rank i+1

def local_loss(params, x):
    model.update(params)

    # we don't pass tokens but instead intermediate embeddings x
    y_loc = model(tokens=None, input_embeddings=x)

    return mx.sum(y_loc * mx.stop_gradient(dy_loc))

loss, (g_s, dx) = mx.value_and_grad(local_loss, argnums=(0,1))(
    model.trainable_parameters(), x
)

dx_tok = mx.distributed.send(dx, dst=dist.rank - 1)  # send dy_{i-1} to rank i-1

mx.stop_gradient ensures that dy_loc is treated as a constant tensor of numbers, not something that can be backpropagated through. Without stop_gradient, the backprop engine would attempt to propagate gradients through mx.distributed.recv - causing an error.

Computational Graph & Dependencies

We now know how to build a computational graph for each of the pipeline stages that propagates $y$-values forwards and partial derivatives $dy$ backwards. So maybe we're ready to call mx.eval() on all the devices and watch the magic happen? Unfortunately, if we just try to eval like this, we hit a deadlock and can't progress. MLX distributed is pretty new, so there was limited tooling to narrow down exactly why this was happening - some send/recv in the graph was failing somehow, and it took a lot of work to figure out what was going wrong!

Notice that figure 4 is partitioned into two halves - the forwards pass on the left, and the backwards pass on the right - that both start with a recv and finish with a send. For both sides, the recv operation is a leaf node with no dependencies. The MLX scheduler is therefore free to start with either the recv(x) or recv(dy) operations, as neither depends on the other.

The recv operations in MLX are blocking: if rank $i$ starts with recv(dy), it won't execute anything else until the data is received. It's waiting to receive the derivatives, but they depend on later ranks' backwards passes, which depend on rank $i$'s forwards pass. But since rank $i$ is blocked on recv(dy), it can't execute the forwards pass: recv(x) never gets run, so the forwards pass never happens.

The solution is to artificially insert a dependency between the forwards and backwards passes, ensuring the MLX scheduler completes the forwards pass before starting the backwards pass.

dy = mlx.core.depends(dy, tok_y)
Data flow with artificial dependency
Figure 5: An artificial dependency is inserted between $y_i$ and $dy_i$ to enforce compute order.

Result - Llama 3.2 1B

We've built a computational graph on each device that can be executed in parallel and gives the desired outcome - allowing pipeline parallel training of Llama 3.2 1B! It runs stably and, most importantly, the loss goes down. Unfortunately, scaling up to a large model (DeepSeek at 671GB) was non-trivial.

Large Model Training

The algorithm as described so far builds one computational graph per rank that encapsulates the forwards pass, the backwards pass, and the dependency between the two. This is executed by calling a single mx.eval() on all ranks simultaneously. This works great, and provided the original proof-of-concept for pipeline parallel training of Llama 3.2 1B on 4 devices. But when scaled up to DeepSeek (671GB) on 4 devices, the system started crashing unexpectedly.

When all 4 devices eval() at the same time, every rank other than 0 immediately opens a recv() socket, waiting for its input $y_{i-1}$. Due to the length of the DeepSeek computation, some devices could be waiting upwards of 20s on their recv. The timeout on recv is short enough that the final device times out before $y_{p-2}$ has been sent to it.

The solution here is to split each device's graph into two parts - the forwards and backwards pass - and execute each of these in sequence. For this, we need to understand MLX send/recv tokens:

Send & Recv Tokens

The function signature of mx.distributed.send is:

send(x: array, dst: int) -> array

...why does it return an array? The send operation returns a send token, call it x_tok. This array has the same value as the input x, but evaluating x_tok causes the send operation to be executed. Evaluating just x will not cause any communication.
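
The token mechanism can be mimicked with a tiny lazy-evaluation sketch (plain Python; the Lazy class and send function here are a toy analogy, not MLX internals): the token carries the same value as its input, but forcing it performs the communication side effect.

```python
sent = []  # records performed communications

class Lazy:
    """A node in a lazy graph: holds a thunk, computed only on eval()."""
    def __init__(self, thunk):
        self.thunk = thunk
    def eval(self):
        return self.thunk()

def send(x, dst):
    # Returns a "token": same value as x, but evaluating it performs the send.
    def thunk():
        value = x.eval()
        sent.append((value, dst))  # the communication side effect
        return value
    return Lazy(thunk)

x = Lazy(lambda: 42)
x_tok = send(x, dst=1)

x.eval()                  # evaluating x alone triggers no communication
assert sent == []
x_tok.eval()              # evaluating the token performs the send
assert sent == [(42, 1)]
```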

Two-Stage Graph

Returning the tokens for each send/recv allows fine-grained control over which parts of the computational graph are executed. Each pipeline stage's work is split into 4 phases:

  1. Receive $y_{i-1}$ from stage $i-1$
  2. Compute $y_i$ and send to next stage
  3. Receive $dy_i$ from stage $i+1$
  4. Compute $dy_{i-1}$ and send to stage $i-1$

loss, g_s, tokens = build_graph(model, inputs, targets, lengths)
p = dist.size  # number of pipeline stages

fwd_eval = tokens['y']
fwd_recv_eval = tokens['x']
bwd_eval = tokens['dx']
bwd_recv_eval = tokens['dy']

if dist.rank == dist.size - 1:
    fwd_eval.append(loss)

# forward phase
for stage in range(p-1):
    dist.barrier()
    if dist.rank == stage + 1:
        mx.eval(*fwd_recv_eval)
    if dist.rank == stage:
        mx.eval(*fwd_eval)

# backward phase
for stage in range(p-1, 0, -1):
    dist.barrier()
    if dist.rank == stage - 1:
        mx.eval(*bwd_recv_eval)
    if dist.rank == stage:
        mx.eval(*bwd_eval)

Not only does this alleviate the crashes, since each device now only has to wait on a single forwards pass before receiving its values, but it also gives finer control over the scheduling of computations, paving the way towards more elaborate pipeline schedules (see the section on future work below).

Results

With finer-grained control over the execution of the forwards and backwards passes, we are able to successfully finetune DeepSeek! On 2x M3 Ultra Mac Studios, our training throughput is ~30 tokens/second. This is at 50% utilization due to the pipeline schedule used (figure 3); with better scheduling, throughput could theoretically be doubled.

Training loss curve Before and after output comparison
Figure 6: (left) DeepSeek finetune training curve, (right) before/after output demonstrating a successful SFT.

Future Work - Batch Pipeline Scheduling with MLX

Implementing more complicated pipeline schedules such as those in figure 2 is left to future work. The primary challenge will be handling model state (KV cache, activations) for multiple minibatches simultaneously: model activations need to be persisted and matched up with the correct batch in the backwards pass.

How would this be implemented in PyTorch? A hook would be registered to save activations during the forwards pass, that could then be reloaded into the model when performing backwards pass.

saved_acts = {}

def save_activation(name):
    def hook(module, input, output):
        # Save the activation for reuse in the backwards pass
        saved_acts[name] = output.detach()
        return output
    return hook

handle = model.layer1.register_forward_hook(save_activation('layer1'))

MLX doesn't have an equivalent of hooks, so it's hard to see how this might look. The activations are actually part of the forwards computational graph, so I could imagine the activations being returned by the model forward() call.

y, cache, activations = model(x, cache)

Activations can be evaluated as part of the forwards pass, and will still be referenced by the unrealized graph for the backwards pass. The backwards graph is stored and evaluated when the time comes, and the activations should line up. Care will need to be taken over how tensors are released from memory: tensors can't be released too early, but if they aren't freed properly a memory leak will occur.

Appendix - MLX Distributed Setup

There are a number of hurdles in setting up the networking for MLX distributed clusters. Whilst writing the mlx-train repo I built a few dev tools to simplify the process of setting up a distributed cluster. There are a number of gotchas, which I'll detail here. This section is primarily intended for the MLX team and people trying to run mlx-train themselves.

mlx.distributed_config

The mlx.distributed_config script written by the MLX team is used to set up the network between the devices. I only used the Thunderbolt ring backend, but it also supports an MPI backend. It strips away the Thunderbolt bridge layer and replaces it with a subnet for every Thunderbolt cable. Each subnet contains only two devices - the ones on either end of the cable.

mlx.launch

mlx.launch is used to execute a distributed MLX program. It works by opening an SSH connection to each device, running the MLX program remotely on each device, and streaming the output back to the master. Below is a list of challenges I encountered: