Fused Quantization in LLMs
A practical demonstration of how fusing dequantization and matrix multiplication works.
Introduction: The Need for Speed in Large Language Models
Following up on my LLM Quantization blog post, I'll now dive deeper into one specific optimization technique. As I discussed in that post, quantization is a powerful method for reducing the size and improving the speed of Large Language Models (LLMs).
But simply quantizing isn't enough. How the quantized operations are executed also matters, and this is where fused quantization comes into play: combining multiple steps, such as dequantization and matrix multiplication, into a single optimized operation. Fused operations are usually implemented as highly optimized CUDA or Triton kernels for GPUs, but the underlying principle can be demonstrated with standard NumPy on a CPU.
In this post, I'll walk through a NumPy example to show how fusing operations can lead to substantial performance gains, using MXFP4, the 4-bit microscaling format that OpenAI's GPT-OSS models use for their MoE weights.
Understanding MXFP4 Quantization
Before diving into the fused operation, let's quickly review MXFP4. This 4-bit floating-point format is designed to handle the wide dynamic range often found in LLM weights.
- Block-based Quantization: Instead of a single scale for the entire tensor, the data is divided into small blocks (32 elements in MXFP4).
- Shared 8-bit Scale: Each block has its own 8-bit scale factor (E8M0 format) that effectively acts as a shared exponent for all values within that block.
- E2M1 Encoding: Individual values within the block are represented using 4 bits (1 sign bit, 2 exponent bits, 1 mantissa bit).
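A quick back-of-the-envelope check of what this layout costs in storage, using only the numbers above (32-element blocks, 4-bit elements, one 8-bit scale per block). The BF16 and FP32 comparisons are simply the 16- and 32-bit widths of those formats:

```python
BLOCK_SIZE = 32   # elements per MXFP4 block
ELEM_BITS = 4     # E2M1: 1 sign + 2 exponent + 1 mantissa bit
SCALE_BITS = 8    # one E8M0 scale shared by the whole block

bits_per_block = BLOCK_SIZE * ELEM_BITS + SCALE_BITS  # 128 + 8 = 136
bits_per_element = bits_per_block / BLOCK_SIZE        # 4.25

# Compression relative to common higher-precision formats
ratio_vs_bf16 = 16 / bits_per_element
ratio_vs_fp32 = 32 / bits_per_element

print(f"{bits_per_element} bits/element, "
      f"{ratio_vs_bf16:.2f}x smaller than BF16, {ratio_vs_fp32:.2f}x smaller than FP32")
# 4.25 bits/element, 3.76x smaller than BF16, 7.53x smaller than FP32
```

The shared scale adds only a quarter of a bit per element, which is why block sizes in this range are a good trade-off between dynamic range and overhead.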
To dequantize an MXFP4 value, you decode its 4-bit E2M1 code into a float (a 16-entry lookup table covers every possible code) and multiply it by the block's decoded scale, 2^(scale - 127).
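Worked through by hand for a single value, the two decoding steps look like this. It is a minimal scalar sketch; the helper names are illustrative and not from any library, and the E8M0 NaN code (0xFF) is deliberately ignored:

```python
def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code (1 sign, 2 exponent, 1 mantissa bit, bias 1)."""
    sign = -1.0 if (code >> 3) & 0x1 else 1.0
    exp = (code >> 1) & 0x3
    man = code & 0x1
    if exp == 0:  # subnormal: 0.M * 2^(1 - bias) = M / 2
        return sign * man / 2.0
    return sign * (1.0 + man / 2.0) * 2.0 ** (exp - 1)

def decode_e8m0(scale: int) -> float:
    """Decode an E8M0 scale byte: a pure power of two with bias 127 (0xFF/NaN not handled)."""
    return 2.0 ** (scale - 127)

code = 0b0101   # sign 0, exponent 0b10, mantissa 1 -> 1.5 * 2^1 = 3.0
scale = 130     # 2^(130 - 127) = 8.0
value = decode_e2m1(code) * decode_e8m0(scale)
print(value)                                # 24.0
print([decode_e2m1(c) for c in range(8)])   # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```

The second print lists every non-negative value E2M1 can represent; codes 8 to 15 are the same magnitudes with the sign bit set.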
The Code: NumPy Experiment
I'll use a snippet of Python code (similar in spirit to what you'd find in an optimized inference library) that simulates operations on a quantized Mixture-of-Experts (MoE) MLP layer. For the demonstration, I load pre-quantized weights and scales that I copied from one MLP layer of a real GPT-OSS checkpoint.
```python
# Import necessary libraries
import gdown
import numpy as np

# Link to a real GPT-OSS quantized tensor, copied from one MLP layer.
MXFP4_TENSOR_LINK = 'https://drive.google.com/uc?id=1EMCfy_FWfkpICZ7j6oeINsiRS7jAUoqe'
tensor_npz = 'tensor.npz'

# Download and load the quantized weights (packed uint8) and their E8M0 scales (uint8)
gdown.download(MXFP4_TENSOR_LINK, tensor_npz, quiet=False)
mlp_tensors = np.load(tensor_npz)
mlp_weight, mlp_scale = mlp_tensors['arr_0'], mlp_tensors['arr_1']
print(f'weight tensor: {mlp_weight.shape}, scale tensor: {mlp_scale.shape}')
# Expected output: weight tensor: (128, 2880, 90, 16), scale tensor: (128, 2880, 90)
```
The mlp_weight tensor has shape (128, 2880, 90, 16). The first axis indexes the model's 128 "experts" (common in MoE models). Each expert holds a 2880 × 2880 weight matrix: 2880 output rows, each split along the 2880-wide input dimension into 90 blocks of 32 elements (2880 / 32 = 90). Since each uint8 byte packs two 4-bit values (a low nibble and a high nibble), a 32-element block occupies 16 bytes, which is the last axis. The mlp_scale tensor, with shape (128, 2880, 90), stores one 8-bit E8M0 scale per block.
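The shape bookkeeping above can be captured in a few lines. This is a small illustrative helper (the name `mxfp4_logical_shape` is mine, not from GPT-OSS or any library) that maps a packed shape to the logical matrix shape it encodes:

```python
def mxfp4_logical_shape(packed_shape: tuple[int, ...], block_size: int = 32) -> tuple[int, ...]:
    """Map a packed MXFP4 weight shape (..., out_dim, n_blocks, bytes_per_block)
    to the logical (..., out_dim, in_dim) shape it encodes."""
    *lead, n_blocks, bytes_per_block = packed_shape
    # Each uint8 byte packs two 4-bit elements, so one block occupies block_size // 2 bytes
    assert bytes_per_block * 2 == block_size, "expected two 4-bit values per byte"
    return (*lead, n_blocks * block_size)

print(mxfp4_logical_shape((128, 2880, 90, 16)))  # (128, 2880, 2880)
# The scale tensor has one entry per block: the packed shape without its last axis
print((128, 2880, 90, 16)[:-1])                  # (128, 2880, 90)
```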
Dequantization Utilities
Next, let's look at the core dequantization logic:
```python
def make_fp4_e2m1_lut() -> np.ndarray:
    """Build a 16-entry lookup table mapping each FP4 (E2M1) code to its float value."""
    lut = np.zeros(16, dtype=np.float32)
    for code in range(16):
        s = (code >> 3) & 0x1  # sign bit
        E = (code >> 1) & 0x3  # 2 exponent bits
        M = code & 0x1         # 1 mantissa bit
        bias = 1               # E2M1 exponent bias is 1
        if E == 0:
            val = M / 2.0      # subnormal values: 0 or 0.5
        else:
            frac = 1.0 + (M / 2.0)
            exp = E - bias
            val = np.ldexp(frac, exp)
        lut[code] = (-1.0) ** s * val
    return lut

_FP4_LUT = make_fp4_e2m1_lut()

def e8m0_decode(scales_u8: np.ndarray) -> np.ndarray:
    """Decode E8M0 scale bytes to powers of two (E8M0 has a bias of 127).

    Note: the MX spec reserves 0xFF for NaN; that case is not handled here.
    """
    return np.exp2(scales_u8.astype(np.int16) - 127)

def mxfp4_dequantize(packed_fp4: np.ndarray, scales_u8: np.ndarray) -> np.ndarray:
    """Dequantize an MXFP4 tensor of packed bytes (..., B) with one scale per block."""
    assert packed_fp4.dtype == np.uint8
    assert scales_u8.dtype == np.uint8
    assert packed_fp4.shape[:-1] == scales_u8.shape, \
        f"scales shape {scales_u8.shape} must match packed_fp4.shape[:-1] {packed_fp4.shape[:-1]}"
    # Unpack nibbles: each byte holds 2 FP4 values, low nibble first,
    # so byte i holds element 2*i (low nibble) and element 2*i + 1 (high nibble)
    low = packed_fp4 & 0x0F   # lower 4 bits
    high = packed_fp4 >> 4    # upper 4 bits
    nibbles = np.stack([low, high], axis=-1).reshape(*packed_fp4.shape[:-1], -1)  # (..., 2*B)
    # FP4 decode via LUT
    elems = _FP4_LUT[nibbles]
    # Decode scales and broadcast over the elements of each block
    scales = e8m0_decode(scales_u8)[..., None]
    return elems * scales
```
The mxfp4_dequantize function takes packed uint8 weights and uint8 scales. It first unpacks the two 4-bit values (nibbles) stored in each byte, then uses a precomputed lookup table (_FP4_LUT) to map each 4-bit code to its floating-point value. Finally, it multiplies these values by the block's decoded E8M0 scale to produce float32 weights. Because quantization is lossy, this reconstructs the quantized approximation of each weight block in float32, not the exact original weights.
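The nibble order is easy to get wrong, so a round-trip check is worthwhile. The sketch below assumes the "low nibble first" convention (byte i holds element 2i in its low nibble and element 2i + 1 in its high nibble); `pack_fp4_codes` is a hypothetical helper written only for this test, and the LUT values are the same ones make_fp4_e2m1_lut() produces:

```python
import numpy as np

# The 16 E2M1 values in code order
FP4_LUT = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], dtype=np.float32)

def pack_fp4_codes(codes: np.ndarray) -> np.ndarray:
    """Hypothetical helper: pack 4-bit codes (..., 2*B) into bytes (..., B),
    even-indexed elements in the low nibble, odd-indexed in the high nibble."""
    codes = codes.astype(np.uint8)
    return (codes[..., 0::2] | (codes[..., 1::2] << 4)).astype(np.uint8)

def unpack_and_dequantize(packed: np.ndarray, scales_u8: np.ndarray) -> np.ndarray:
    """Interleave low/high nibbles back into element order, decode, and scale."""
    low = packed & 0x0F
    high = packed >> 4
    codes = np.stack([low, high], axis=-1).reshape(*packed.shape[:-1], -1)
    scale = np.exp2(scales_u8.astype(np.float32) - 127.0)[..., None]
    return FP4_LUT[codes] * scale

rng = np.random.default_rng(0)
codes = rng.integers(0, 16, size=(3, 32))           # 3 blocks of 32 random codes
scales = np.array([127, 128, 120], dtype=np.uint8)  # scale factors 1, 2, 2^-7

packed = pack_fp4_codes(codes)                      # (3, 16)
recovered = unpack_and_dequantize(packed, scales)   # (3, 32)
expected = FP4_LUT[codes] * np.exp2(scales.astype(np.float32) - 127.0)[:, None]
print(packed.shape, recovered.shape, np.array_equal(recovered, expected))  # (3, 16) (3, 32) True
```

If the unpacking used a different order than the packing (for example, all low nibbles followed by all high nibbles), the comparison would fail even though every individual value still decodes to a valid FP4 number.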
The Fused Operation: `mxfp4_mlp_matmul_activation`
This is where the "fused" part comes in. This function performs dequantization and matrix multiplication in an interleaved fashion:
```python
def mxfp4_mlp_matmul_activation(
    x: np.ndarray,
    weight_packed: np.ndarray,
    scale_u8: np.ndarray,
    expert_idx: int,
    bias: np.ndarray | None = None,
) -> np.ndarray:
    """Fused dequantize + matmul for a single expert: y = x @ W[expert_idx].T (+ bias)."""
    assert weight_packed.dtype == np.uint8 and scale_u8.dtype == np.uint8
    assert weight_packed.shape[:-1] == scale_u8.shape
    assert 0 <= expert_idx < weight_packed.shape[0]
    _, out_dim, n_blocks, bytes_per_block = weight_packed.shape  # (128, 2880, 90, 16) for GPT-OSS
    block = 2 * bytes_per_block                                 # 32 elements per block
    in_dim = n_blocks * block                                   # 2880 for GPT-OSS
    x = np.asarray(x).astype(np.float32, copy=False)
    assert x.shape[-1] == in_dim, f"expected last dim {in_dim}, got {x.shape[-1]}"
    # Output buffer (..., out_dim), initialized to zeros
    y = np.zeros(x.shape[:-1] + (out_dim,), dtype=np.float32)
    # Grab the specific expert's weights and scales once
    Wp_e = weight_packed[expert_idx]  # (out_dim, n_blocks, bytes_per_block) -> (2880, 90, 16)
    Sc_e = scale_u8[expert_idx]       # (out_dim, n_blocks) -> (2880, 90)
    # For each input block j: dequantize the (out_dim, 32) weight slice,
    # then accumulate y += x_block @ W_block.T
    for j in range(n_blocks):  # 90 blocks
        # Slice of the input activation that block j multiplies
        x_block = x[..., j * block:(j + 1) * block]  # (..., 32)
        # Dequantize only the current block: (2880, 16) bytes + (2880,) scales -> (2880, 32)
        W_block = mxfp4_dequantize(Wp_e[:, j, :], Sc_e[:, j])
        # Partial matmul for this block, accumulated into the output
        y += np.einsum('...k,ok->...o', x_block, W_block, optimize=True)
    if bias is not None:
        y += np.asarray(bias, dtype=np.float32)
    return y
```
In this function, instead of dequantizing the entire mlp_weight tensor (or even one whole expert) at once, I iterate over the expert's 90 blocks along the input dimension. For each block j, I extract the matching slice of the input activation (x_block), dequantize only that block's weights (W_block), and immediately multiply the two with np.einsum, accumulating the result into the output y.
This means only one 2880 × 32 float32 block of weights (about 369 KB) is materialized at any time, instead of the full 2880 × 2880 expert matrix (about 33 MB), and each block is consumed by the multiplication right after it is produced.
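Processing the weights block by block is valid because a matmul over the input dimension is a sum, so it splits into per-block partial products: y = Σ_j x[:, block j] @ W[:, block j].T. A small self-contained check with random float32 data (the sizes are arbitrary stand-ins, not the GPT-OSS ones):

```python
import numpy as np

rng = np.random.default_rng(42)
batch, in_dim, out_dim, block = 4, 256, 64, 32
x = rng.standard_normal((batch, in_dim)).astype(np.float32)
W = rng.standard_normal((out_dim, in_dim)).astype(np.float32)

# Reference: one matmul against the full (out_dim, in_dim) weight matrix
y_full = x @ W.T

# Block by block: accumulate the partial product of each 32-wide input slice
y_blocked = np.zeros((batch, out_dim), dtype=np.float32)
for j in range(in_dim // block):
    cols = slice(j * block, (j + 1) * block)
    y_blocked += x[:, cols] @ W[:, cols].T

# Equal up to float32 rounding, since the summation order differs
print(np.allclose(y_full, y_blocked, rtol=1e-4, atol=1e-3))  # True
```

This rounding difference is also why the fused and unfused outputs later in the post agree only to about five or six significant digits.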
The Comparison: Fused vs. Separate Operations
Now, let's compare the performance using dummy data. We'll generate random input activations and select 4 random experts (in an MoE model, only a few experts are active for a given input; the router's gating weights are left out here, so the expert outputs are simply summed).
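For context, here is a minimal sketch of how a router typically picks those few experts: score every expert per token, keep the top k, and softmax over just those k scores. This is a generic top-k gating scheme under my own assumptions (random router weights, k = 4, the hypothetical helper `top_k_route`), not GPT-OSS's exact router code:

```python
import numpy as np

def top_k_route(x: np.ndarray, router_w: np.ndarray, k: int = 4):
    """Generic top-k MoE routing sketch.

    x: (tokens, hidden), router_w: (num_experts, hidden).
    Returns expert indices (tokens, k) and gate weights (tokens, k) that sum to 1 per token.
    """
    logits = x @ router_w.T                               # (tokens, num_experts)
    idx = np.argsort(logits, axis=-1)[:, ::-1][:, :k]     # top-k experts, highest score first
    top = np.take_along_axis(logits, idx, axis=-1)        # (tokens, k)
    gates = np.exp(top - top.max(axis=-1, keepdims=True)) # numerically stable softmax
    gates /= gates.sum(axis=-1, keepdims=True)
    return idx, gates

rng = np.random.default_rng(0)
tokens = rng.standard_normal((10, 2880)).astype(np.float32)
router_w = rng.standard_normal((128, 2880)).astype(np.float32) * 0.01  # hypothetical router weights
idx, gates = top_k_route(tokens, router_w, k=4)
print(idx.shape, gates.shape)           # (10, 4) (10, 4)
print(np.allclose(gates.sum(-1), 1.0))  # True
```

In a real layer each token can select different experts, and its output is the gate-weighted sum of those experts' outputs; the benchmark instead runs the same 4 experts for every token to keep the timing comparison simple.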
```python
# Create a random activation batch (10 tokens) and pick 4 distinct random experts
dummy_activations = np.random.randn(10, 2880).astype(np.float32)
experts = np.random.choice(128, size=4, replace=False)  # Select 4 distinct experts
```
Scenario 1: Fused Dequantization and Matmul
```python
# %%time
# Measure performance of the fused operation
output_fused = np.zeros((10, 2880), dtype=np.float32)
for expert in experts:
    output_fused += mxfp4_mlp_matmul_activation(dummy_activations, mlp_weight, mlp_scale, expert)
print(f"\nOutput shape (fused): {output_fused.shape}")
print(output_fused[0, :5])
```
Typical output (your times may vary based on CPU/system load):
```text
Output shape (fused): (10, 2880)
[ 83.90713692 29.92747545 -7.78181982 -54.15209579 -121.69489574]
CPU times: user 1.11 s, sys: 802 µs, total: 1.11 s
Wall time: 592 ms
```
Scenario 2: Per-Expert Dequantization and Matmul
This is a middle-ground approach. Instead of dequantizing the entire 128-expert weight tensor, we only dequantize the weights for one expert at a time, right before its matrix multiplication.
```python
# %%time
# Measure performance of per-expert dequantization
output_per_expert = np.zeros((10, 2880), dtype=np.float32)
for expert in experts:
    # Dequantize only the weights for the current expert: (2880, 90, 32)
    mat = mxfp4_dequantize(mlp_weight[expert], mlp_scale[expert])
    expert_mat = mat.reshape((2880, -1))  # (2880, 2880)
    # Perform matrix multiplication
    output_per_expert += np.einsum('...k,ok->...o', dummy_activations, expert_mat, optimize=True)
print(f"\nOutput shape (per-expert): {output_per_expert.shape}")
print(output_per_expert[0, :5])
```
Typical output:
```text
Output shape (per-expert): (10, 2880)
[ 83.90706062 29.92745209 -7.78180981 -54.15212631 -121.69487953]
CPU times: user 601 ms, sys: 0 ns, total: 601 ms
Wall time: 406 ms
```
Scenario 3: Separate (Full) Dequantization and Matmul
```python
# %%time
# Measure performance of separate dequantization and matmul
# Step 1: Dequantize ALL weights for ALL experts first.
# This creates a very large float32 tensor in memory:
# shape (128, 2880, 90, 32), i.e. 128 x 2880 x 2880 values, about 4.25 GB
full_precision_weights = mxfp4_dequantize(mlp_weight, mlp_scale)
output_separate = np.zeros((10, 2880), dtype=np.float32)
for expert in experts:
    # Step 2: Reshape the expert's blocks into a (2880, 2880) matrix
    expert_mat = full_precision_weights[expert].reshape((2880, -1))
    # Step 3: Perform matrix multiplication
    output_separate += np.einsum('...k,ok->...o', dummy_activations, expert_mat, optimize=True)
print(f"\nOutput shape (separate): {output_separate.shape}")
print(output_separate[0, :5])
```
Typical output:
```text
Output shape (separate): (10, 2880)
[ 83.90706062 29.92745209 -7.78180981 -54.15212631 -121.69487953]
CPU times: user 5.83 s, sys: 2.56 s, total: 8.4 s
Wall time: 8.32 s
```
The Results: Analysis
Comparing the "Wall time" of the three approaches:
- Fused Operation (Block-by-Block): ~592 ms
- Per-Expert Dequantization: ~406 ms
- Separate (Full) Dequantization: ~8.32 s
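The fully separate operation is dramatically slower, as expected: it dequantizes all 128 experts (about 4.25 GB of float32 weights) even though only 4 are used, so it does roughly 32 times more dequantization work and moves far more memory than either alternative. Interestingly, in this CPU-based NumPy test, per-expert dequantization (~406 ms) beats the block-by-block fused approach (~592 ms). A likely explanation is per-call overhead: the fused version issues 90 separate dequantize and einsum calls per expert (360 of each for 4 experts), each paying Python and NumPy dispatch costs and each performing only a small 32-wide matmul, while the per-expert version dequantizes once and hands a single large matmul to BLAS. In NumPy, the "fused" loop is also not truly fused: each block is still written to a temporary array before it is used.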
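One way to probe the per-call overhead explanation is to time the same matmul work issued as one large call versus 90 small 32-column calls, with no dequantization involved. This is a rough sketch using time.perf_counter; absolute numbers depend entirely on your CPU and BLAS build, so treat them only as a relative comparison:

```python
import time
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((10, 2880)).astype(np.float32)
W = rng.standard_normal((2880, 2880)).astype(np.float32)  # one expert's (out, in) matrix
BLOCK = 32

def one_call(x, W):
    return x @ W.T

def many_calls(x, W, block=BLOCK):
    y = np.zeros((x.shape[0], W.shape[0]), dtype=np.float32)
    for j in range(W.shape[1] // block):  # 90 calls for 2880 / 32
        cols = slice(j * block, (j + 1) * block)
        y += x[:, cols] @ W[:, cols].T
    return y

def best_of(fn, *args, repeats=5):
    """Return the best wall time in seconds over a few repeats, plus the last result."""
    best, out = float("inf"), None
    for _ in range(repeats):
        t0 = time.perf_counter()
        out = fn(*args)
        best = min(best, time.perf_counter() - t0)
    return best, out

t_one, y_one = best_of(one_call, x, W)
t_many, y_many = best_of(many_calls, x, W)
print(f"one call: {t_one * 1e3:.2f} ms, 90 calls: {t_many * 1e3:.2f} ms")
print("same result:", np.allclose(y_one, y_many, rtol=1e-3, atol=1e-2))
```

Taking the best of several repeats reduces noise from other processes; if the 90-call version is consistently slower here, that gap is overhead the NumPy "fused" loop pays on top of its dequantization work.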
However, this CPU result says little about GPUs. There, launching separate kernels for dequantization and matmul (even per expert), and writing the dequantized weights to memory only to read them back, adds significant overhead. A truly fused kernel, which dequantizes each block in registers or shared memory and never writes the full-precision weights out, is expected to outperform the other approaches in an optimized CUDA or Triton implementation. The key takeaway remains: fusing operations to minimize memory traffic is paramount.
Why is Fused Quantization So Much Faster?
While our demonstration uses NumPy and CPU timings, the takeaway applies broadly to optimized LLM inference: reducing redundant data movement and combining computational steps is critical for performance. Fused quantization lets LLMs run faster and more efficiently, making them easier to deploy in real-world applications.
- Reduced Memory Bandwidth: When dequantizing separately, the entire quantized weight tensor must first be converted to a higher precision (e.g., FP32) and stored in memory. This large intermediate tensor then needs to be read again for the matrix multiplication. Fused operations avoid creating this large intermediate tensor; they dequantize small blocks of weights just-in-time and use them immediately for computation. This significantly reduces the amount of data moved between different memory levels (e.g., from main memory to CPU caches or GPU shared memory), which is often a major bottleneck.
- Better Cache Utilization: By operating on smaller blocks of data, fused kernels can keep relevant weights and activations in faster, smaller cache memories for longer. This minimizes costly fetches from slower main memory.
- Elimination of Kernel Launch Overheads: On GPUs, each separate kernel launch for dequantization and for matrix multiplication carries a fixed overhead. A fused kernel needs only one launch, removing that cost.
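To put rough numbers on the bandwidth point for this layer, here is an idealized count of weight bytes touched per forward pass for the 4 active experts. It assumes each byte is read or written exactly once and ignores caches entirely, so it is a simplification, not a measurement:

```python
OUT_DIM, N_BLOCKS, BYTES_PER_BLOCK, BLOCK = 2880, 90, 16, 32
ACTIVE_EXPERTS, TOTAL_EXPERTS = 4, 128
FP32_BYTES = 4

# Quantized storage for one expert: packed 4-bit payload + one E8M0 scale byte per block
packed_bytes = OUT_DIM * N_BLOCKS * BYTES_PER_BLOCK + OUT_DIM * N_BLOCKS
# The same expert materialized as float32
fp32_bytes = OUT_DIM * N_BLOCKS * BLOCK * FP32_BYTES

# Ideal fused kernel: read the packed weights once, dequantize on-chip
fused = ACTIVE_EXPERTS * packed_bytes
# Unfused, per expert: read packed, write a float32 copy, read it back for the matmul
per_expert = ACTIVE_EXPERTS * (packed_bytes + 2 * fp32_bytes)
# Full dequantization: convert all 128 experts, then read back only the 4 that are used
full = TOTAL_EXPERTS * (packed_bytes + fp32_bytes) + ACTIVE_EXPERTS * fp32_bytes

for name, nbytes in [("fused", fused), ("per-expert", per_expert), ("full", full)]:
    print(f"{name:>10}: {nbytes / 1e6:8.1f} MB")
```

Under these assumptions the ideal fused path touches about 17.6 MB, the per-expert path about 283 MB, and full dequantization nearly 5 GB, which mirrors the ordering of the separate and per-expert timings above.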