Inside FlashOptim, a new trick to reduce LLM training memory by 50%

Training large language models is an expensive endeavor, largely because of the amount of accelerator memory each parameter consumes during training. To reduce costs, Databricks researchers introduced FlashOptim, a set of memory optimization techniques designed for popular deep learning optimizers. FlashOptim serves as a drop-in replacement that reduces memory consumption per parameter by more than 50%, without sacrificing training throughput or model quality. According to the researchers, this efficiency “allows practitioners and researchers with limited hardware to train larger models than previously possible.”

Before we explore how FlashOptim works, it helps to understand why training neural networks demands so much hardware. During training, every model parameter brings with it a heavy burden of additional variables that must be stored in GPU memory. First, there are the parameters themselves: the weights of the neural network being trained.

Developers rely on mixed-precision training to speed up computation, often using 16-bit floating point numbers for the forward and backward passes. However, standard methods require keeping high-precision 32-bit master weights in memory to prevent errors when accumulating very small gradient updates. Second, the training system computes a gradient for every parameter during the backward pass of backpropagation. The gradient determines the direction and magnitude of the update required. Gradients are typically stored as 32-bit floating point numbers, so they occupy an additional 4 bytes of memory per parameter.

Third, modern optimizers such as Adam and AdamW track historical statistics to smooth the learning trajectory. Adam maintains two state variables for every parameter: momentum, a moving average of past gradients, and variance, a moving average of squared gradients. Both states are typically kept in 32-bit precision, so the optimizer alone consumes 8 bytes of memory per parameter. Finally, the model computes intermediate outputs known as “activations” during the forward pass. Because the backward pass needs these activations to compute gradients, the system must hold them in memory temporarily. Unlike weights, gradients, and optimizer state, which scale strictly with the size of the model, activation memory also scales with the batch size (the number of training samples fed to the model before the weights are updated).

Combining parameters, gradients, and optimizer state, a typical training setup with Adam requires approximately 16 bytes of memory per parameter. This means that if a developer wants to train a language model with 7 billion parameters, they will need to provision at least 112 GB of accelerator memory just to hold the model and its optimization variables. This calculation also does not include the additional memory required to process the data batches.
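
As a quick sanity check, the arithmetic behind those figures looks like this (a minimal sketch assuming fp32 gradients and fp32 Adam state, as described above):

```python
params = 7e9                         # 7-billion-parameter model

bytes_per_param = (
    4      # fp32 master weights
    + 4    # fp32 gradients
    + 4    # Adam momentum (fp32)
    + 4    # Adam variance (fp32)
)          # = 16 bytes, before the 16-bit working copy and activation memory

print(params * bytes_per_param / 1e9, "GB")   # -> 112.0 GB
```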

The deep learning community has developed several workarounds for these hardware constraints, but each comes with significant tradeoffs. One common method is distributed training with tensor sharding: frameworks such as PyTorch’s Fully Sharded Data Parallel (FSDP) split the memory load across a cluster of GPUs. Although this is standard operating procedure inside resource-rich technology organizations, it strictly requires access to a fleet of accelerators. For independent developers, researchers, or small teams working with a single GPU, this approach is simply not an option.
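
For context, sharding with FSDP is a small code change in itself; the hard requirement is the hardware. A minimal sketch, assuming the script is launched across multiple GPUs with torchrun:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes a multi-GPU launch, e.g. `torchrun --nproc_per_node=8 train.py`.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.TransformerEncoderLayer(d_model=4096, nhead=32, batch_first=True).cuda()
model = FSDP(model)   # parameters, gradients, and optimizer state are sharded across ranks
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
```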

Another option is CPU offloading. GPU memory is expensive and scarce, while host system memory is relatively cheap and plentiful. Offloading techniques temporarily move certain memory-intensive tensors from the GPU to the host machine’s RAM and bring them back only when the accelerator needs them for a particular computation. The downside is that transferring gigabytes of data over the PCIe bus creates a serious communication bottleneck. This shuffling adds overhead and complexity, ultimately slowing down the training loop.
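
The pattern itself is easy to illustrate (a generic sketch, not FlashOptim’s code): park a large tensor in host RAM and pay the PCIe round trip only when the GPU actually needs it.

```python
import torch

# Stand-in for a memory-hungry tensor such as Adam's variance state (~4 GB of fp32 here).
state = torch.zeros(1_000_000_000, device="cuda")

# Offload: copy to host RAM and free the GPU copy.
state_cpu = state.cpu()
del state
torch.cuda.empty_cache()

# Later, when the optimizer step actually needs it, transfer it back over the PCIe bus.
state = state_cpu.cuda()
```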

A third common workaround is parameter-efficient methods such as low-rank adaptation (LoRA). Rather than updating every parameter in a large model, these techniques freeze most of the original weights. The system then computes gradients and optimizer state only for a small subset of the original weights, or for a small set of new auxiliary weights injected into the architecture. The problem is that deliberately ignoring large parts of the network fundamentally changes the training dynamics. Parameter-efficient fine-tuning is an approximation that does not follow exactly the same learning trajectory as full-parameter fine-tuning, which can limit performance on complex tasks.
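
A minimal sketch of the LoRA idea (generic, not any particular library’s implementation): the base weight is frozen, and only two small low-rank matrices receive gradients and optimizer state.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wrap a frozen linear layer and train only a small rank-r correction."""
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)            # frozen: no gradients, no optimizer state
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + (x @ self.A.t() @ self.B.t()) * self.scale

layer = LoRALinear(nn.Linear(4096, 4096))
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)   # ~65k adapter weights trained instead of ~16.8M frozen base weights
```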

Databricks researchers took a different route, building FlashOptim as a suite of techniques that compress parameter-related memory directly within common deep learning optimizers. FlashOptim achieves this through improved floating-point splitting, enhanced optimizer state quantization, and optimized kernel fusion.

Typically, developers maintain 32-bit master weights alongside downcast 16-bit versions used for the actual forward and backward passes. The 16-bit weights contain very little information that is not already present in the master weights, so keeping both in memory is largely redundant. Previous attempts to split these weights kept 16 bits of base weight plus 16 bits of error correction, but that approach wasted valuable bits trying to cover the huge range of a standard floating point number.

The Databricks team made a key observation: under round-to-nearest, the rounding error between a 32-bit master weight and its 16-bit downcast version must fall within a narrow, predictable range. Instead of storing that error as a wide-range floating point number, FlashOptim’s “improved floating point splitting” technique rescales this small error interval and maps it onto an 8-bit integer. By combining the 16-bit base weight with this 8-bit error correction, FlashOptim reconstructs a 24-bit master weight. This innovation reduces the weight memory requirement from 4 to 3 bytes per parameter with virtually no loss in accuracy.
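
The article does not spell out the exact encoding, but the core idea can be sketched in a few lines of PyTorch (a hypothetical illustration, assuming a bf16 base and an int8 code for the round-to-nearest residual, which is bounded by roughly half a bf16 ULP):

```python
import torch

def bf16_half_ulp(base: torch.Tensor) -> torch.Tensor:
    # bf16 carries 8 bits of precision, so its ULP near |x| is about 2**(floor(log2|x|) - 7).
    exponent = torch.floor(torch.log2(base.float().abs().clamp_min(1e-38)))
    return 0.5 * torch.exp2(exponent - 7)

def split(w32: torch.Tensor):
    """Sketch: represent an fp32 master weight as a 2-byte bf16 base plus a 1-byte residual code."""
    base = w32.bfloat16()                       # round-to-nearest downcast
    residual = w32 - base.float()               # small, predictable rounding error
    code = torch.clamp(torch.round(residual / bf16_half_ulp(base) * 127), -127, 127)
    return base, code.to(torch.int8)

def rebuild(base: torch.Tensor, code: torch.Tensor) -> torch.Tensor:
    """Recover a ~24-bit approximation of the original fp32 master weight."""
    return base.float() + code.float() / 127 * bf16_half_ulp(base)

w = torch.randn(4)
print((w - rebuild(*split(w))).abs().max())     # reconstruction error is tiny
```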

The second major advance is “enhanced optimizer state quantization.” Traditional attempts to shrink optimizer state simply group numbers together and compress them into 8-bit integers. This linear quantization implicitly assumes that optimizer values are spread evenly across their range. Measurements showed that optimizer state distributions significantly violate this assumption: the variance, for example, accumulates squared gradients, producing a highly skewed, heavy-tailed distribution. Forcing these skewed numbers into evenly spaced 8-bit bins introduces substantial quantization error. Before converting the numbers, FlashOptim applies a mathematical trick called a companding function, which compresses extreme values and reshapes the distribution to be more uniform. After this companding step, the values map far more cleanly onto the 8-bit bins and the quantization error drops significantly. This reduces the optimizer state from 8 bytes per parameter to just 2 bytes, plus a small fraction of a byte for the per-group scaling factors.
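
The article does not name the companding function, so the power-law transform below is only a stand-in; what it illustrates is the structure: compand to flatten the heavy tail, then quantize group-wise to int8 with one scale per group.

```python
import torch

def quantize_state(state: torch.Tensor, group_size: int = 256, power: float = 0.25):
    """Group-wise int8 quantization with an (assumed) power-law compander.
    Assumes state.numel() is divisible by group_size."""
    groups = state.float().reshape(-1, group_size)
    companded = groups.sign() * groups.abs().pow(power)           # squash the heavy tail
    scale = companded.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)
    codes = torch.clamp(torch.round(companded / scale * 127), -127, 127).to(torch.int8)
    return codes, scale                                           # ~1 byte/value plus one scale per group

def dequantize_state(codes: torch.Tensor, scale: torch.Tensor, power: float = 0.25):
    companded = codes.float() / 127 * scale
    return (companded.sign() * companded.abs().pow(1.0 / power)).reshape(-1)

variance = torch.rand(4096) ** 8            # heavy-tailed, like Adam's second moment
codes, scale = quantize_state(variance)
print((variance - dequantize_state(codes, scale)).abs().max())    # small quantization error
```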

FlashOptim packages these techniques into a fused, optimized kernel. Splitting weights, dequantizing state, performing the mathematical update, and recompressing everything requires moving large amounts of data back and forth, and a naive implementation would create a serious memory-bandwidth bottleneck. FlashOptim solves this by implementing the entire optimizer step as a single fused Triton kernel designed for Nvidia hardware. The GPU loads the compressed data into fast local memory, decompresses it, computes the update, recompresses the results, and writes everything back in one seamless operation. This is what allows FlashOptim to reduce memory consumption during training without causing any real slowdown.
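
The real kernel is written in Triton and processes tiles entirely in on-chip memory; the plain PyTorch sketch below (with simplified per-tensor scales rather than the real packed layout) only lays out the sequence of stages a single fused launch would cover.

```python
import torch

def fused_step_reference(base16, err8, m8, v8, s, grad16, lr=1e-4, b1=0.9, b2=0.95, eps=1e-8, t=1):
    """Unfused reference for the work one fused kernel launch covers.
    The names and the simple per-tensor scales in `s` are illustrative only."""
    # 1. Decompress packed weights and optimizer state into fp32 working values.
    w = base16.float() + err8.float() * s["err"]
    m = m8.float() * s["m"]
    v = v8.float() * s["v"]

    # 2. Apply an ordinary Adam-style update on the reconstructed values.
    g = grad16.float()
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    w = w - lr * (m / (1 - b1 ** t)) / ((v / (1 - b2 ** t)).sqrt() + eps)

    # 3. Recompress and write back: bf16 base + int8 residual, int8 state.
    base16 = w.bfloat16()
    err8 = torch.clamp(torch.round((w - base16.float()) / s["err"]), -127, 127).to(torch.int8)
    m8 = torch.clamp(torch.round(m / s["m"]), -127, 127).to(torch.int8)
    v8 = torch.clamp(torch.round(v / s["v"]), -127, 127).to(torch.int8)
    return base16, err8, m8, v8
```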

To prove the framework’s real-world viability, the researchers tested FlashOptim on several standard vision and language benchmarks, including training models based on the GPT-2 architecture and performing supervised fine-tuning on the larger Llama-3.1-8B model. Across the stochastic gradient descent (SGD), AdamW, and Lion optimizers, models trained with FlashOptim matched the loss trajectory, convergence rate, and final validation accuracy of their memory-intensive standard counterparts.

During the Llama-3.1-8B fine-tuning test, peak GPU memory fell from 175 GB to 113 GB, an overall reduction of 36%. A closer look at the breakdown shows that optimizer memory dropped by 61% and parameter memory by 50%. These compressions do not slow down the training process, because FlashOptim performs the arithmetic inside its highly efficient fused kernel; in fact, the optimizer step time during the Llama-3.1 test decreased slightly, from 12.5 ms to 11.5 ms.

For developers, FlashOptim’s greatest value may be its simplicity. Because it is a drop-in replacement for common optimizers, developers don’t have to rewrite training loops, change optimization semantics, or devise new tuning strategies. The researchers plan to release FlashOptim as an open-source PyTorch library on GitHub.
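
Since the library has not been published yet, the snippet below is purely illustrative and the package name is a placeholder; the point is that only the optimizer construction line would change, while the training loop stays untouched.

```python
import torch

# Hypothetical drop-in swap: the `flashoptim` package name is a placeholder.
# from flashoptim import AdamW          # instead of torch.optim.AdamW
from torch.optim import AdamW

model = torch.nn.Linear(4096, 4096)
optimizer = AdamW(model.parameters(), lr=1e-4)   # the only line that would change

for _ in range(10):                              # training loop stays the same
    batch = torch.randn(8, 4096)
    loss = model(batch).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```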


