2026-06-18·12 min read

Scaling

Notes on model scaling, arithmetic intensity, and accelerator bounds.

Introduction

These are my notes on How to Scale Your Model, written in my own words. This document is meant to be written, not read.

1. Bounds

Terminology

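There are three places where time goes when algorithms run on accelerators:

  1. The computation itself, represented by:

    $T_{math} = \frac{\text{Computation FLOPs}}{\text{Accelerator FLOP/s}}$

  2. Communication within a chip, moving data from accelerator memory (high-bandwidth memory, HBM) to the compute cores.

  3. Communication between chips, over the network that connects them.

Both kinds of communication cost the bytes moved divided by the relevant bandwidth (memory bandwidth within a chip, network bandwidth between chips):

$$T_{comms} = \frac{\text{Communication Bytes}}{\text{Network or Memory Bandwidth Bytes/s}}$$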
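To make the two timing formulas concrete, here is a small sketch with made-up numbers (hypothetical: 2e12 FLOPs, 1e9 bytes, a 1e14 FLOP/s accelerator, 1e11 bytes/s of bandwidth). It also assumes something the list does not state: if computation and communication overlap perfectly, a step takes max(T_math, T_comms), and if they do not overlap at all, it takes their sum. The helper name `time_bounds` is illustrative.

```python
def time_bounds(flops, comm_bytes, flops_per_s, bytes_per_s):
    """Return (T_math, T_comms, lower bound, upper bound) in seconds."""
    t_math = flops / flops_per_s
    t_comms = comm_bytes / bytes_per_s
    lower = max(t_math, t_comms)  # perfect overlap: the slower cost dominates
    upper = t_math + t_comms      # no overlap: the costs add up
    return t_math, t_comms, lower, upper

# Hypothetical workload and hardware numbers, for illustration only.
t_math, t_comms, lower, upper = time_bounds(
    flops=2e12, comm_bytes=1e9, flops_per_s=1e14, bytes_per_s=1e11
)
print(f"T_math={t_math * 1e3:.1f} ms, T_comms={t_comms * 1e3:.1f} ms")  # 20.0 ms, 10.0 ms
print(f"step time between {lower * 1e3:.1f} and {upper * 1e3:.1f} ms")   # 20.0 and 30.0
```

Here T_math exceeds T_comms, so this hypothetical step is compute bound: the "limiting reactant" is the math.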

To maximize efficiency, we look at each of these costs and determine the "limiting reactant": does the bottleneck lie in computation or in communication?

Arithmetic intensity

One way to determine the bottleneck is via arithmetic intensity (AI):

$$\text{Arithmetic Intensity} = \frac{\text{Computation FLOPs}}{\text{Communication Bytes}}$$

In other words, for every byte moved, how much work do we perform on it? A high AI (not to be confused with artificial intelligence) points to a compute bottleneck, since we do a lot of computation per byte moved, while a low AI points to a communication bottleneck. "High" and "low" are relative to the accelerator's own intensity: its peak FLOP/s divided by its bandwidth in bytes/s, which we call the critical arithmetic intensity. An algorithm whose AI exceeds the critical arithmetic intensity is compute bound; one below it is communication bound.
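A minimal sketch of that comparison, assuming Google's published TPU v5e figures (about 197 TFLOP/s bf16 peak and 819 GB/s HBM bandwidth). The function names `critical_intensity` and `classify` are hypothetical helpers, not part of any library.

```python
PEAK_FLOPS_PER_S = 1.97e14  # assumed: TPU v5e bf16 peak
HBM_BYTES_PER_S = 8.19e11   # assumed: TPU v5e HBM bandwidth

def critical_intensity(peak_flops_per_s, bytes_per_s):
    """FLOPs the chip can perform in the time it takes to move one byte."""
    return peak_flops_per_s / bytes_per_s

def classify(ai, peak_flops_per_s=PEAK_FLOPS_PER_S, bytes_per_s=HBM_BYTES_PER_S):
    """Compare an algorithm's AI against the chip's critical AI."""
    if ai > critical_intensity(peak_flops_per_s, bytes_per_s):
        return "compute bound"
    return "communication bound"

print(round(critical_intensity(PEAK_FLOPS_PER_S, HBM_BYTES_PER_S)))  # 241
print(classify(0.5))   # communication bound
print(classify(1000))  # compute bound
```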

A simple example of AI is the dot product of two vectors: $\text{bf16}[N] \cdot \text{bf16}[N] = \text{bf16}[1]$. There are $N$ multiplications and $N-1$ additions. Each bf16 value occupies 2 bytes, so each input vector is $2N$ bytes, and writing the result back to memory costs another 2 bytes: $AI = \frac{N + N - 1}{2N + 2N + 2} \approx \frac{1}{2}$.
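The dot-product count can be checked numerically. This sketch simply encodes the FLOP and byte counts from the text (2 bytes per bf16 value, including the 2-byte result write); the function name is hypothetical.

```python
BF16_BYTES = 2  # bytes per bf16 value

def dot_product_intensity(n):
    """AI of bf16[n] . bf16[n] -> bf16[1], counting the result write."""
    flops = n + (n - 1)                            # n multiplies, n - 1 additions
    bytes_moved = 2 * n * BF16_BYTES + BF16_BYTES  # read two vectors, write one scalar
    return flops / bytes_moved

for n in (4, 1024, 1_000_000):
    print(n, dot_product_intensity(n))  # climbs toward 0.5 as n grows
```

Whatever the vector length, the dot product does about half a FLOP per byte, far below any accelerator's critical intensity, so it is always communication bound.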

Roofline plots

Roofline plots visualize computation and communication bounds. The x-axis represents AI and the y-axis represents realized FLOP/s, which can never exceed the hardware's peak. An important relationship that is easy to miss in the book is that, while communication bound, realized FLOP/s is the product of bandwidth and AI: FLOPs per byte moved, multiplied by bytes moved per second (bandwidth), gives FLOPs performed per second. Putting the two regimes together, $\text{Realized FLOP/s} = \min(\text{Peak FLOP/s}, \text{Bandwidth} \times \text{AI})$.

At lower AI, we are communication bound. As we increase the AI, we eventually become compute bound.

The inflection point itself sits at an AI equal to the hardware's peak FLOP/s divided by its bandwidth, while variables such as dimension size and batch size set the algorithm's AI and therefore which side of the inflection point it lands on. A higher bandwidth reaches the inflection point at a lower AI: the compute-bound ceiling is fixed by the hardware's peak FLOP/s, and a higher bandwidth gives a steeper communication-bound slope, so realized FLOP/s hits that ceiling sooner.
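A minimal roofline sketch: realized FLOP/s is bandwidth times AI, capped at the peak. It assumes a ceiling of 1.97e14 FLOP/s (published TPU v5e bf16 peak), a bandwidth of 8.19e11 bytes/s (v5e HBM), and a hypothetical chip with twice that bandwidth; the function name is illustrative.

```python
PEAK_FLOPS_PER_S = 1.97e14  # assumed compute ceiling (TPU v5e bf16 peak)

def realized_flops_per_s(ai, bytes_per_s, peak_flops_per_s=PEAK_FLOPS_PER_S):
    """Roofline: bandwidth * AI while communication bound, capped at peak."""
    return min(peak_flops_per_s, bytes_per_s * ai)

# A v5e-like HBM bandwidth and a hypothetical chip with twice that bandwidth.
for bytes_per_s in (8.19e11, 1.638e12):
    row = [f"{realized_flops_per_s(ai, bytes_per_s):.2e}" for ai in (50, 100, 200, 400)]
    print(f"{bytes_per_s:.2e} B/s:", row)
```

With the doubled bandwidth the curve is already at the ceiling by AI = 200; with the original bandwidth it is still on the slope at 200 and only reaches the ceiling by 400.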

A quick note on sharding

On a single chip, whether a matmul reaches the inflection point depends on the batch size (B). For a [B, D] x [D, F] matmul, the [D, F] weight matrix is loaded once and reused across all B inputs, so the FLOPs performed per byte of weights loaded scale with B. This assumes large matmuls where B is small compared to D and F (the output dimension), so that loading the weights dominates the bytes moved. Under this regime, the arithmetic intensity approximately equals the batch size: in bf16, $AI = \frac{2BDF}{2BD + 2DF + 2BF} \approx B$.
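A quick numerical check that AI tracks B when B is much smaller than D and F. This is a sketch assuming bf16 (2 bytes per value) and counting reads of both inputs plus the output write; `matmul_intensity` is a hypothetical helper.

```python
def matmul_intensity(B, D, F, bytes_per_value=2):
    """AI of a [B, D] x [D, F] matmul, assuming bf16 operands by default."""
    flops = 2 * B * D * F  # one multiply and one add per (b, d, f)
    bytes_moved = bytes_per_value * (B * D + D * F + B * F)  # read X and W, write output
    return flops / bytes_moved

for B in (16, 128, 512):
    print(B, round(matmul_intensity(B, D=8192, F=8192), 1))
# 16 15.9
# 128 124.1
# 512 455.1
```

The approximation degrades as B grows toward D and F, because the activation and output bytes stop being negligible next to the weights.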

Conversely, when we shard the matmul across chips along the contracting dimension D, whether we are compute bound depends on D rather than B. Increasing B increases compute and inter-chip communication equally, so B cancels out. Increasing D, however, increases compute without increasing communication between chips: the contraction over D happens locally on each chip (e.g. X[:, :D//2] @ Y[:D//2, :] on one of two chips), and the partial result each chip exchanges keeps the same [B, F] shape regardless of D.
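A sketch of the two-chip case with NumPy standing in for the chips. It assumes a 2-way split along D and bf16 partial sums of shape [B, F] sent between chips; `sharded_matmul_intensity` is a hypothetical helper, and the NumPy part only checks the shapes and the math, not real communication.

```python
import numpy as np

def sharded_matmul_intensity(B, D, F, bytes_per_value=2):
    """AI w.r.t. inter-chip traffic for a 2-way split along D (assumes bf16 partials)."""
    flops_per_chip = 2 * B * (D // 2) * F  # each chip contracts over its half of D
    comm_bytes = bytes_per_value * B * F   # a partial sum of shape [B, F] is exchanged
    return flops_per_chip / comm_bytes     # = D / 2, independent of B

rng = np.random.default_rng(0)
B, D, F = 4, 8, 3
X = rng.standard_normal((B, D))
Y = rng.standard_normal((D, F))
partial_0 = X[:, :D // 2] @ Y[:D // 2, :]  # what "chip 0" computes locally
partial_1 = X[:, D // 2:] @ Y[D // 2:, :]  # what "chip 1" computes locally
assert partial_0.shape == (B, F)           # shape does not depend on D
assert np.allclose(partial_0 + partial_1, X @ Y)  # summing partials recovers X @ Y

print(sharded_matmul_intensity(B=128, D=8192, F=8192))   # 4096.0
print(sharded_matmul_intensity(B=1024, D=8192, F=8192))  # 4096.0
```

Changing B leaves the intensity untouched, while doubling D would double it.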

2. TPUs

Architecture

TPUs are highly specialized for matrix multiplications. They have a simple architecture: a TensorCore that communicates with high-bandwidth memory (HBM). The TensorCore contains the matrix multiply unit (MXU), the vector unit (VPU), the scalar unit, and vector memory (VMEM).

The MXU performs matmuls of size bf16[8, 128] @ bf16[128, 128] → f32[8, 128] for most generations. Matrix dimensions need to be zero-padded up to a multiple of 128 for such a computation. The VPU performs general operations such as activations and other elementwise vector operations. The scalar unit acts like a CPU: it handles control flow and memory address calculation and issues instructions to the other units. VMEM is a small, high-bandwidth on-chip memory that sits between HBM and the compute units; data is copied from HBM into VMEM before the TensorCore computes on it.
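A sketch of the padding step, assuming a 128-wide MXU and zero padding of both dimensions up to the next multiple of 128. `pad_for_mxu` is a hypothetical helper; the real compiler does this internally.

```python
import numpy as np

MXU_DIM = 128  # MXU side length assumed here (most generations, per the text)

def pad_for_mxu(w, multiple=MXU_DIM):
    """Zero-pad both dimensions of a 2D matrix up to the next multiple of the MXU size."""
    pad_rows = -w.shape[0] % multiple
    pad_cols = -w.shape[1] % multiple
    return np.pad(w, ((0, pad_rows), (0, pad_cols)))  # default mode pads with zeros

w = np.ones((100, 300), dtype=np.float32)
padded = pad_for_mxu(w)
print(padded.shape)            # (128, 384)
print(w.size / padded.size)    # 0.6103515625: fraction of the padded matrix that is real data
```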

If matrices fit in VMEM's small storage, a lower batch size suffices to reach the critical AI: VMEM's bandwidth is much higher than HBM's, so the critical arithmetic intensity (peak FLOP/s divided by bandwidth) is lower. The problem, though, is that VMEM has far less storage than HBM.
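Combining this with the single-chip result that AI ≈ B gives a critical batch size of peak FLOP/s divided by bandwidth. The sketch assumes TPU v5e-like peak and HBM figures and a purely hypothetical 20x VMEM-to-HBM bandwidth ratio chosen only for illustration; it is not a measured number.

```python
PEAK_FLOPS_PER_S = 1.97e14  # assumed: TPU v5e bf16 peak
HBM_BYTES_PER_S = 8.19e11   # assumed: TPU v5e HBM bandwidth
VMEM_TO_HBM_RATIO = 20      # hypothetical: how much faster VMEM is than HBM

def critical_batch_size(peak_flops_per_s, bytes_per_s):
    """Since AI ~= B for a large single-chip matmul, B must exceed peak / bandwidth."""
    return peak_flops_per_s / bytes_per_s

from_hbm = critical_batch_size(PEAK_FLOPS_PER_S, HBM_BYTES_PER_S)
from_vmem = critical_batch_size(PEAK_FLOPS_PER_S, HBM_BYTES_PER_S * VMEM_TO_HBM_RATIO)
print(round(from_hbm), round(from_vmem))  # 241 12
```

Whatever the true ratio, the critical batch size shrinks by the same factor the bandwidth grows.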

Everything above describes a single TensorCore. A TPU chip generally consists of one or two TensorCores that share the chip's HBM. Four TPU chips are arranged on a tray, connected to a CPU host over PCIe. At a larger scale, chips are connected to their nearest neighbors (four in a 2D arrangement, six in a 3D one) over the ICI (inter-chip interconnect) network to form a pod.

Chips within a pod are connected in a torus: the links along each axis wrap around, so the maximum distance between two chips on an axis of $N$ chips is $N/2$ hops. Larger pods of size 16x16x16 are called superpods.
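A sketch of why wraparound gives a maximum of N/2 hops, assuming each axis is a ring of N chips. The helper `ring_hops` is hypothetical; in a torus the shortest path is the sum of the per-axis ring distances, so the worst case in 3D is the sum of the per-axis maxima.

```python
def ring_hops(i, j, n, wraparound=True):
    """Hops between positions i and j on one axis of n chips."""
    d = abs(i - j)
    return min(d, n - d) if wraparound else d

n = 16
print(max(ring_hops(0, j, n) for j in range(n)))                    # 8, i.e. n // 2
print(max(ring_hops(0, j, n, wraparound=False) for j in range(n)))  # 15, i.e. n - 1

# Worst case across a 16x16x16 torus: the per-axis maxima add up.
print(sum(axis // 2 for axis in (16, 16, 16)))  # 24
```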

Systolic arrays

I want to briefly mention the beautiful design of the systolic array used in the TPU's MXU to perform matmuls efficiently. The systolic array reuses values by "flowing" them from each processing element to its neighbor while every element tracks its own accumulated value. fleetwood has an excellent animation.

For a traditional matmul of two $N \times N$ square matrices, we perform $N^3$ multiply-accumulate operations and read $2N^3$ values from memory: each operation needs two values, assuming the worst case with no reuse. With a systolic array, we still perform $N^3$ operations but read each matrix from memory only once, for a total of $2N^2$ reads, a factor of $N$ fewer.
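A cycle-by-cycle sketch of the idea, assuming an output-stationary dataflow (each processing element keeps one output entry; rows of `a` flow right, columns of `b` flow down, skewed by one cycle per row/column). This is an illustrative variant, not necessarily the dataflow the MXU actually uses, and `systolic_matmul` is a hypothetical name. It counts how many input values are fetched at the array's edges.

```python
import numpy as np

def systolic_matmul(a, b):
    """Output-stationary systolic array sketch for square n x n inputs."""
    n = a.shape[0]
    acc = np.zeros((n, n))    # PE (i, j) accumulates out[i, j] in place
    a_reg = np.zeros((n, n))  # a-values currently held by each PE, flowing right
    b_reg = np.zeros((n, n))  # b-values currently held by each PE, flowing down
    reads = 0                 # values fetched from memory at the array's edges
    for t in range(3 * n - 2):
        # Every PE passes its values to its right / lower neighbor.
        a_reg[:, 1:] = a_reg[:, :-1].copy()
        b_reg[1:, :] = b_reg[:-1, :].copy()
        # Skewed feeding: row i of a enters i cycles late, column j of b enters j cycles late.
        for i in range(n):
            k = t - i
            if 0 <= k < n:
                a_reg[i, 0] = a[i, k]
                reads += 1
            else:
                a_reg[i, 0] = 0.0
        for j in range(n):
            k = t - j
            if 0 <= k < n:
                b_reg[0, j] = b[k, j]
                reads += 1
            else:
                b_reg[0, j] = 0.0
        # PE (i, j) now holds a[i, k] and b[k, j] for k = t - i - j; multiply-accumulate.
        acc += a_reg * b_reg
    return acc, reads

n = 8
rng = np.random.default_rng(0)
a = rng.standard_normal((n, n))
b = rng.standard_normal((n, n))
out, reads = systolic_matmul(a, b)
assert np.allclose(out, a @ b)
print(reads, 2 * n**2, 2 * n**3)  # 128 128 1024
```

Each input value is read once (2N² reads) and then reused as it flows across the array, instead of the 2N³ reads of the no-reuse count.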

There are drawbacks to systolic arrays, however, such as requiring the whole [128, 128] array to be populated. For a [64, 64] matmul, three quarters of the array hold zero padding, wasting computational resources. Moreover, to multiply a larger matrix such as [512, 512], we tile the larger matmul into 128x128 chunks. Yet this introduces a memory bandwidth problem: tiles must be streamed into and out of the array, and the same operand tiles may need to be loaded more than once.
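A sketch of the tiling, assuming a 128x128 array and a naive blocked loop in which every tile-by-tile product stands in for one pass through the array. `tiled_matmul` is a hypothetical helper, and real schedules can reuse tiles more cleverly than this loop does; it only illustrates why tiling multiplies the traffic.

```python
import numpy as np

TILE = 128  # array side length assumed for most TPU generations

def tiled_matmul(a, b, tile=TILE):
    """Blocked matmul; each tile @ tile product stands in for one array pass."""
    m, k = a.shape
    k2, n = b.shape
    assert k == k2 and m % tile == 0 and k % tile == 0 and n % tile == 0
    out = np.zeros((m, n))
    tile_loads = 0
    for i in range(0, m, tile):
        for j in range(0, n, tile):
            for p in range(0, k, tile):
                # Load one tile of each operand and accumulate into the output tile.
                out[i:i + tile, j:j + tile] += a[i:i + tile, p:p + tile] @ b[p:p + tile, j:j + tile]
                tile_loads += 2
    return out, tile_loads

rng = np.random.default_rng(0)
a = rng.standard_normal((512, 512))
b = rng.standard_normal((512, 512))
out, loads = tiled_matmul(a, b)
assert np.allclose(out, a @ b)
print(loads)  # 128 tile loads, versus 32 distinct tiles in a and b

# Occupancy of a 128x128 array when a [64, 64] block is padded up to one tile.
print(64 * 64 / (TILE * TILE))  # 0.25
```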