CannotWaitForTheseOptimisers
Documentation for CannotWaitForTheseOptimisers.
CannotWaitForTheseOptimisers.AdaptiveGradNormControl
CannotWaitForTheseOptimisers.Apollo
CannotWaitForTheseOptimisers.GradNormControl
CannotWaitForTheseOptimisers.Muon
CannotWaitForTheseOptimisers.NormGrowthCap
CannotWaitForTheseOptimisers.AdaptiveGradNormControl
— Type
AdaptiveGradNormControl(accumulator, τ = 1.0; epsilon = 1e-8, lb = 0.1, momentum = 0.90, throw = true, clipreportthresh = Inf)
Gradient norm control using exponential moving statistics. Clips gradients when the current norm exceeds mean + τ * std.
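A minimal usage sketch, assuming the rule composes with other Optimisers.jl rules via OptimiserChain and that the accumulator follows the two-element Float64 convention described for GradNormControl below; the toy model, gradient, and hyperparameters are placeholders:

```julia
using Optimisers, CannotWaitForTheseOptimisers

model = (W = randn(Float32, 4, 4), b = zeros(Float32, 4))   # toy parameter tree (placeholder)
acc   = zeros(Float64, 2)                                    # assumed two-element accumulator, as in GradNormControl
rule  = OptimiserChain(AdaptiveGradNormControl(acc, 1.0; momentum = 0.90), AdamW(eta = 1e-3))
state = Optimisers.setup(rule, model)

grad = (W = randn(Float32, 4, 4), b = randn(Float32, 4))     # stand-in for real gradients
state, model = Optimisers.update!(state, model, grad)        # clipping runs before the AdamW step
```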
CannotWaitForTheseOptimisers.Apollo
— Type
Apollo(opt::AdamW = AdamW(), r::Function = dim -> ceil(Int, sqrt(dim)); u = 100, sort_dims = true)
Apollo(η::Real, args...; kw...)
Apollo(arg, rank::Int; kw...)
Apollo(η::Real, rank::Int; kw...)
Apollo optimizer from Zhu et al. (https://arxiv.org/abs/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage. First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). Second argument can be a rank, or a function to compute the rank from the second dimension (or the product of all dims > 1) of the weight matrix (or tensor).
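A brief sketch of the constructor forms listed above; the ranks, learning rates, and toy model are illustrative placeholders, not recommendations:

```julia
using Optimisers, CannotWaitForTheseOptimisers

rule_default = Apollo()                        # default AdamW, rank = ceil(Int, sqrt(dim))
rule_lr      = Apollo(1e-3)                    # default AdamW with learning rate 1e-3
rule_rank    = Apollo(1e-3, 64)                # fixed rank of 64 for every tensor
rule_custom  = Apollo(AdamW(eta = 1e-3), dim -> max(8, dim ÷ 16))  # custom rank function

model = (W = randn(Float32, 256, 256),)        # toy parameter tree (placeholder)
state = Optimisers.setup(rule_rank, model)
```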
CannotWaitForTheseOptimisers.GradNormControl
— Type
GradNormControl(accumulator, τ = 1.1; epsilon = 1e-8, lb = 0.1, throw = true, scale = true, clipreportthresh = Inf)
NormGrowthCap with additional control, accumulation, and reporting options. accumulator must be a two-element Float64 array, into which the unscaled and scaled gradient norms are added, allowing you to monitor the sum of the norms. Printing and resetting it is up to you.
CannotWaitForTheseOptimisers.Muon
— Type
Muon(opt = AdamW(eta = 0.0003, beta = (0.9,0.95), lambda = 0.01), η = 0.02, μ = 0.95, λ = 0.01, fallback = Returns(false))
Muon(; [opt, eta, mu, lambda, fallback])
Muon - MomentUm Orthogonalized by Newton-schulz (https://github.com/KellerJordan/Muon)
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which each 2D parameter's update is replaced with the nearest orthogonal matrix using Newton-Schulz iteration.
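For intuition, here is a plain cubic Newton-Schulz iteration that drives a matrix towards its nearest (semi-)orthogonal counterpart. This is only an illustrative sketch; the iteration Muon actually uses is a tuned variant (see the linked repository):

```julia
using LinearAlgebra

# Illustrative only: basic cubic Newton-Schulz orthogonalization of a gradient-like matrix.
function newton_schulz_orthogonalize(G::AbstractMatrix; steps::Int = 20)
    X = G / (norm(G) + eps(eltype(G)))        # normalise so all singular values are below 1
    for _ in 1:steps
        X = 1.5f0 * X - 0.5f0 * X * (X' * X)  # pushes every singular value towards 1
    end
    return X
end

U = newton_schulz_orthogonalize(randn(Float32, 64, 32))
maximum(abs, U' * U - I)   # small: the columns are approximately orthonormal
```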
Parameters
- Fallback optimizer (opt): Optimizer to use for 1D parameters, or when the fallback function returns true
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights
- Momentum (μ == mu): Controls the acceleration of gradient descent in the prominent direction
- Weight decay (λ == lambda): Controls the strength of the $L_2$ regularisation
- Fallback function (fallback): Controls when, in addition to 1D arrays, the fallback optimizer should be used. It will be passed the parameter array and must return a boolean.
Note: Works best with large batch sizes and may not be suitable for fine-tuning. In nanoGPT speedrun experiments, Muon is used for the internal layer >2D weights, and AdamW is used for the 1D weights, embeddings, and heads.
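A minimal setup sketch along those lines; the fallback heuristic, hyperparameters, and toy model are hypothetical placeholders, not part of the package:

```julia
using Optimisers, CannotWaitForTheseOptimisers

# Hypothetical heuristic: treat very wide 2D arrays as embedding-like and send them to AdamW.
is_embedding_like(x) = ndims(x) == 2 && size(x, 2) > 10_000

rule = Muon(opt = AdamW(eta = 3e-4, beta = (0.9, 0.95), lambda = 0.01),
            eta = 0.02, mu = 0.95,
            fallback = is_embedding_like)        # 1D parameters use the fallback automatically

model = (W = randn(Float32, 128, 128), b = zeros(Float32, 128))  # toy parameter tree (placeholder)
state = Optimisers.setup(rule, model)            # W is handled by Muon, b by the AdamW fallback
```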
Optimisers.adjust!(optimiser_state, η::Real) will adjust the fallback optimizer's eta to η * (opt.eta / eta) and Muon's eta to η, preserving their ratio, but Optimisers.adjust!(optimiser_state, eta = η) will only adjust Muon's learning rate (allowing you to adjust the fallback optimizer's learning rate separately).
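Continuing the sketch above (with state produced by Optimisers.setup), that behaviour looks like:

```julia
# Sets Muon's eta to 0.01 and rescales the AdamW fallback's eta to keep their ratio.
Optimisers.adjust!(state, 0.01)

# Sets only Muon's eta; the fallback optimizer's learning rate is left unchanged.
Optimisers.adjust!(state, eta = 0.01)
```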
CannotWaitForTheseOptimisers.NormGrowthCap
— Type
NormGrowthCap(τ = 1.01; ϵ = 1e-8, lb = 1e-7, throw = true, scale = true)
Gradient norm growth limiter. τ controls the maximum that the gradient norm can grow from one step to the next, such that if ||dx|| / ||dx_prev|| > τ and ||dx|| > lb, then dx = dx * τ * ||dx_prev|| / (||dx|| + ϵ).
Inspired by Chen et al. and used with Apollo in Zhu et al., but with Optimisers.jl this will apply per-tensor instead of per-model. This implementation also introduces lb as a hard minimum on the gradient norm threshold, and never rescales grads below this, preventing a tensor from getting "trapped" near zero. This can be a fixed minimum, or scaled by the square root of the number of parameters in the tensor (with scale = true).
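A short sketch of chaining the limiter in front of Apollo, in the spirit of Zhu et al.; the threshold, learning rate, and toy model are placeholders:

```julia
using Optimisers, CannotWaitForTheseOptimisers

rule  = OptimiserChain(NormGrowthCap(1.01), Apollo(1e-3))  # cap per-tensor norm growth, then Apollo
model = (W = randn(Float32, 64, 64),)                      # toy parameter tree (placeholder)
state = Optimisers.setup(rule, model)

grad = (W = randn(Float32, 64, 64),)                       # stand-in for real gradients
state, model = Optimisers.update!(state, model, grad)
```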