Onion
Documentation for Onion.
Onion.AdaLN
Onion.Attention
Onion.Bottleneck
Onion.CrossFrameIPA
Onion.DecoderBlock
Onion.DyT
Onion.EncoderBlock
Onion.FSQ
Onion.FlexibleUNet
Onion.Framemover
Onion.GaussianFourierProjection
Onion.IPAblock
Onion.RMSNorm
Onion.ResidualBlock
Onion.RoPE
Onion.StarGLU
Onion.TimeEmbedding
Onion.TransformerBlock
Onion.chunk
Onion.cross_att_padding_mask
Onion.glut
Onion.reverse_tuple
Onion.sample_uniform_causal_chunk_mask
Onion.self_att_padding_mask
Onion.unchunk
Onion.AdaLN — Type
AdaLN(dim::Int, cond_dim::Int)
Adaptive Layer Normalization.
aln = AdaLN(5, 3)
h = randn(Float32, 5,10,1)
cond = randn(Float32, 3,1)
h = aln(h, cond)
Onion.Attention — Type
Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false)
Attention layer that supports both self-attention and cross-attention (as in Llama3).
Self-attention example
dim = 64
n_heads = 8
n_kv_heads = 4
seqlen = 10
attn = Attention(dim, n_heads, n_kv_heads)
x = randn(Float32, dim, seqlen, 1)
output = attn(x) # Self-attention
Cross-attention example
# The dim × length × batch shapes below follow the other examples in these docs
query = randn(Float32, dim, seqlen, 1)
key = randn(Float32, dim, 20, 1)
value = randn(Float32, dim, 20, 1)
output = attn(query, key, value) # Cross-attention
Onion.Bottleneck — Type
Bottleneck(channels::Int; time_emb=false, emb_dim=256, dropout=0.0, activation=relu)
A bottleneck block for UNet architecture with optional time embeddings and dropout.
Arguments
channels::Int: Number of input and output channels
time_emb=false: Whether to use time embeddings
emb_dim=256: Dimension of time embeddings
dropout=0.0: Dropout probability (0.0 means no dropout)
activation=relu: Activation function to use
Examples
bn = Bottleneck(256, time_emb=true, emb_dim=256, dropout=0.2)
h = randn(Float32, 8, 8, 256, 1)
t = randn(Float32, 256, 1)
h = bn(h, t)
Onion.CrossFrameIPA — Type
CrossFrameIPA(dim::Int, ipa; ln = Flux.LayerNorm(dim))
Constructs a layer that takes one embedding and two sets of frames. It runs layernorm on the embedding, then makes a cross-attention IPA call with the single embedding but both sets of frames. Useful for self-conditioning, where two sets of frames need to communicate with each other.
Onion.DecoderBlock — Type
DecoderBlock(in_channels::Int, out_channels::Int; time_emb=false, emb_dim=256, dropout=0.0, activation=relu)
A decoder block for UNet architecture with optional time embeddings and dropout.
Arguments
in_channels::Int: Number of input channels
out_channels::Int: Number of output channels
time_emb=false: Whether to use time embeddings
emb_dim=256: Dimension of time embeddings
dropout=0.0: Dropout probability (0.0 means no dropout)
activation=relu: Activation function to use
Examples
dec = DecoderBlock(256, 128, time_emb=true, emb_dim=256, dropout=0.1)
h = randn(Float32, 8, 8, 256, 1)
skip = randn(Float32, 16, 16, 128, 1)
t = randn(Float32, 256, 1)
h = dec(h, skip, t)
Onion.DyT — Method
DyT(dim::Integer; init_alpha::T = 0.5f0)
Make a Dynamic Tanh (DyT) layer for normalizing the input tensor.
See "Transformers without Normalization" for more details.
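A minimal usage sketch, assuming DyT is applied like the other normalization layers in these docs (feature dimension first); the shapes are illustrative:
dyt = DyT(64)
h = randn(Float32, 64, 10, 1)
h = dyt(h)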
Onion.EncoderBlock — Type
EncoderBlock(in_channels::Int, out_channels::Int; time_emb=false, emb_dim=256, dropout=0.0, activation=relu)
An encoder block for UNet architecture with optional time embeddings and dropout.
Arguments
in_channels::Int: Number of input channels
out_channels::Int: Number of output channels
time_emb=false: Whether to use time embeddings
emb_dim=256: Dimension of time embeddings
dropout=0.0: Dropout probability (0.0 means no dropout)
activation=relu: Activation function to use
Examples
enc = EncoderBlock(3, 64, time_emb=true, emb_dim=256, dropout=0.1)
h = randn(Float32, 32, 32, 3, 1)
t = randn(Float32, 256, 1)
skip, h = enc(h, t)
Onion.FSQ — Type
FSQ(l, chunk_size)
Finite Scalar Quantization. l is the number of quantization levels. For a sequence with d channels, the codebook size would be l^d. chunk_size is the number of channels that get combined/separated when chunk/unchunk are called.
Onion.FlexibleUNet — Type
FlexibleUNet(;
    in_channels=3,
    out_channels=3,
    depth=3,
    base_channels=64,
    channel_multipliers=[1, 2, 4],
    time_embedding=false,
    num_classes=0,
    embedding_dim=128,
    time_emb_dim=256,
    dropout=0.0,
    dropout_depth=0,
    activation=relu
)
A flexible UNet architecture with configurable depth and channel dimensions. Supports optional time and class embeddings for diffusion models and conditional generation.
Arguments
in_channels=3: Number of input channels
out_channels=3: Number of output channels
depth=3: Number of encoder/decoder blocks
base_channels=64: Base channel dimension (multiplied at each level)
channel_multipliers=[1, 2, 4]: Multipliers for channel dimensions at each level
time_embedding=false: Whether to use time embeddings
num_classes=0: Number of class labels for conditional generation
embedding_dim=128: Dimension for class embeddings
time_emb_dim=256: Dimension for time embeddings
dropout=0.0: Dropout probability to apply to inner layers
dropout_depth=0: Number of layers to apply dropout to, starting from the innermost layers (0 means no dropout). Maximum value is 1+depth (bottleneck + all encoding/decoding levels)
activation=relu: Activation function to use throughout the network
Examples
# Basic model without dropout
model = FlexibleUNet(
    in_channels=3,
    out_channels=3,
    depth=4,
    base_channels=32,
    channel_multipliers=[1, 2, 4, 8],
    time_embedding=true
)

# Model with dropout applied to the 3 innermost layers
model = FlexibleUNet(
    in_channels=3,
    out_channels=3,
    depth=4,
    base_channels=32,
    channel_multipliers=[1, 2, 4, 8],
    time_embedding=true,
    dropout=0.2,
    dropout_depth=3
)
x = randn(Float32, 32, 32, 3, 1)
t = randn(Float32, 1)
labels = [5]
y = model(x, t, labels)
Onion.Framemover — Type
Framemover(dim::Int; init_gain = 0.1f0)
Differentiable rigid body updates (AF2-style).
Onion.GaussianFourierProjection — Type
GaussianFourierProjection(embed_dim::Int, scale::T=32.0f0)
Creates a Gaussian Fourier feature projection for time embeddings. Used in diffusion models.
Arguments
embed_dim::Int: Embedding dimension. Should be even.
scale::T=32.0f0: Scaling factor for the random weights.
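A minimal sketch, assuming the projection is called on a vector of time values (one per batch element) and returns an embed_dim × batch array; the call convention and output layout here are assumptions, not stated in the docstring:
proj = GaussianFourierProjection(256)
t = rand(Float32, 16)
emb = proj(t)  # assumed 256×16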
Onion.IPAblock — Type
IPAblock(dim::Int, ipa; ln1 = Flux.LayerNorm(dim), ln2 = Flux.LayerNorm(dim), ff = StarGLU(dim, 3dim))
For use with Invariant Point Attention, either from InvariantPointAttention.jl or MessagePassingIPA.jl. If ipablock.ipa is from InvariantPointAttention.jl, then call ipablock(frames, x; pair_feats = nothing, cond = nothing, mask = 0, kwargs...). If ipablock.ipa is from MessagePassingIPA.jl, then call ipablock(g, frames, x, pair_feats; cond = nothing). Pass in cond if you're using e.g. an AdaLN that takes a second argument.
Onion.RMSNorm — Type
RMSNorm(dim::Int; eps::T=1f-5)
Root Mean Square Layer Normalization. As used in Llama3.
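A minimal sketch, assuming RMSNorm is applied with the feature dimension first, as in the other examples in these docs; the shapes are illustrative:
norm = RMSNorm(64)
h = randn(Float32, 64, 10, 1)
h = norm(h)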
Onion.ResidualBlock — Type
ResidualBlock(channels::Int; kernel_size=3, time_emb=false, emb_dim=256, dropout=0.0, activation=relu)
A ResNet-style residual block with optional time embeddings, dropout, and configurable activation.
Arguments
channels::Int: Number of input and output channels
kernel_size=3: Size of convolutional kernel
time_emb=false: Whether to use time embeddings
emb_dim=256: Dimension of time embeddings
dropout=0.0: Dropout probability (0.0 means no dropout)
activation=relu: Activation function to use (e.g., relu, swish, etc.)
Examples
# Basic block with dropout
rb = ResidualBlock(64, dropout=0.1)
# Block with time embeddings and custom activation
rb = ResidualBlock(64, time_emb=true, emb_dim=256, dropout=0.1, activation=swish)
# Usage
h = randn(Float32, 32, 32, 64, 1)
t = randn(Float32, 256, 1)
h = rb(h, t)
Onion.RoPE — Type
RoPE(dim::Int, max_length; theta::T=10000f0)
Rotary Position Embeddings (as in Llama3).
dim = 64
n_heads = 8
n_kv_heads = 4
seqlen = 10
t = TransformerBlock(dim, n_heads, n_kv_heads)
h = randn(Float32, dim, seqlen, 1)
rope = RoPE(dim ÷ n_heads, 1000)
h = t(h, 1, rope[1:seqlen]) #Note the subsetting to match seqlen
Onion.StarGLU — Type
StarGLU(dim::Int, ff_hidden_dim::Int; act=Flux.swish)
Gated Linear Unit with flexible activation function (default: swish, making it a SwiGLU layer as used in Llama3).
l = StarGLU(6, 8)
h = randn(Float32, 6, 10, 1)
h = l(h)
Onion.TimeEmbedding — Type
TimeEmbedding(embed_dim::Int, num_classes::Int, embedding_dim::Int)
Creates time and optional class embeddings for diffusion models.
Arguments
embed_dim::Int: Output dimension for time embeddings
num_classes::Int: Number of classes for conditional generation
embedding_dim::Int: Dimension for class embeddings
Examples
time_emb = TimeEmbedding(256, 10, 128)
t = randn(Float32, 16)
labels = rand(1:10, 16)
h = time_emb(t, labels)
Onion.TransformerBlock — Type
TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int = n_heads, ff_hidden_dim = 4 * dim; norm_eps=1f-5, qkv_bias=false)
TransformerBlock{Attention,FeedForward,AttentionNorm,FeedForwardNorm}
Transformer block with grouped-query attention (GQA, as in Llama3). No KV caching (see Jjama3.jl for KV caching).
dim = 64
n_heads = 8
n_kv_heads = 4
seqlen = 10
rope = RoPE(dim ÷ n_heads, 1000)
t = TransformerBlock(dim, n_heads, n_kv_heads)
h = randn(Float32, dim, seqlen, 1)
#Use without a mask:
h = t(h, 1, rope[1:seqlen])
#Use with a causal mask:
mask = Onion.causal_mask(h)
h = t(h, 1, rope[1:seqlen], mask)
Onion.chunk — Method
chunk(x, q::FSQ, chunk_size)
Make a long quantized sequence shorter and wider (to make it more transformer-friendly). x may have a batch dimension. Contiguous chunks of chunk_size are recoded as a single integer in the product space q.l^chunk_size.
Onion.cross_att_padding_mask — Method
cross_att_padding_mask(padmask, other_dim; T = Float32)
Takes a sequence-level padmask and a dimension other_dim, and returns a cross-attention mask that is length-by-other_dim-by-batch. This prevents information flow from padded key positions to any query positions (but ignores padding in the query positions, because nothing should flow out of those).
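A shape-only sketch; the Float32 padmask and the reading of other_dim as the query length are assumptions, while the output shape follows the docstring:
padmask = ones(Float32, 12, 2)             # key length 12, batch 2; 0 marks padding
padmask[10:12, 2] .= 0                     # last three key positions of sequence 2 are padding
mask = cross_att_padding_mask(padmask, 8)  # 12×8×2 (length × other_dim × batch)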
Onion.glut — Method
glut(t::AbstractArray, d::Int, pos::Int)
glut(t::Real, d::Int, pos::Int) = t
glut adds dimensions to the middle. The resulting array will have d dimensions. pos is where to add the dimensions. pos=0 adds dims to the start, pos=1 after the first element, etc. If t is scalar, it is returned unmodified (because scalars don't need to match dims to broadcast).
Typically when broadcasting x .* t, you would call something like glut(t, ndims(x), 1).
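A shape-only sketch of the typical broadcast just described; the concrete sizes are illustrative and assume glut inserts singleton dimensions at pos:
x = randn(Float32, 5, 10, 2)    # e.g. features × length × batch
t = randn(Float32, 5)           # one value per feature
size(glut(t, ndims(x), 1))      # (5, 1, 1), so x .* glut(t, ndims(x), 1) broadcasts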
Onion.reverse_tuple — Method
reverse_tuple(t::Tuple)
Helper function that reverses the order of elements in a tuple. Used to maintain type stability when reversing the order of skip connections.
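For example:
reverse_tuple((1, 2, 3))  # (3, 2, 1)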
Onion.sample_uniform_causal_chunk_mask — Method
sample_uniform_causal_chunk_mask(x, chunk_size)
Generate a mask of all the "chunks" towards the end of the sequence, separately for each batch. The mask dims will be length-by-batch, but contiguous chunks of chunk_size will always be masked together.
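A shape sketch, assuming x follows the feature × length × batch layout used in the other examples (the docstring only specifies that the returned mask is length-by-batch):
x = randn(Float32, 8, 12, 4)
mask = sample_uniform_causal_chunk_mask(x, 3)  # 12×4; masked positions come in chunks of 3 at the end of each sequence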
Onion.self_att_padding_mask — Method
self_att_padding_mask(padmask; T = Float32)
Takes a sequence-level padmask (i.e. length-by-batch, where 0 indicates a padded position) and returns a (non-causal) self-attention mask that is length-by-length-by-batch and which prevents information flow from padded positions to unpadded positions.
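A shape-only sketch; the Float32 padmask here is an assumption (the docstring only requires 0 at padded positions), while the output shape follows the docstring:
padmask = ones(Float32, 10, 2)          # length 10, batch 2
padmask[9:10, 2] .= 0                   # last two positions of sequence 2 are padding
mask = self_att_padding_mask(padmask)   # 10×10×2 mask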
Onion.unchunk — Method
unchunk(x, q::FSQ)
Take a sequence that has been chunked, and expand it back to the original length. x == unchunk(chunk(x,q),q) should be true.