Flowfusion
Documentation for Flowfusion.
- Flowfusion.DistInterpolatingDiscreteFlow
- Flowfusion.DistNoisyInterpolatingDiscreteFlow
- Flowfusion.Guide
- Flowfusion.MaskedState
- Flowfusion.NoisyInterpolatingDiscreteFlow
- Flowfusion.apply_tangent_coordinates
- Flowfusion.batch
- Flowfusion.cmask!
- Flowfusion.dense
- Flowfusion.endslices
- Flowfusion.gen
- Flowfusion.mask
- Flowfusion.onehot
- Flowfusion.tangent_guide
- Flowfusion.unhot
- Flowfusion.unmask
- Flowfusion.unwrap
Flowfusion.DistInterpolatingDiscreteFlow — Type

    DistInterpolatingDiscreteFlow(D::UnivariateDistribution)

D controls the schedule. Note: both training and inference expect the model to output logits, unlike the other InterpolatingDiscreteFlow (where the user needs a manual softmax for inference).
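A minimal construction sketch, assuming Distributions.jl supplies the schedule distribution (the Beta(2,2) here is an arbitrary illustrative choice):

```julia
using Flowfusion, Distributions

# The distribution's CDF acts as the interpolation schedule over t ∈ [0, 1].
P = DistInterpolatingDiscreteFlow(Beta(2, 2))
```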
Flowfusion.DistNoisyInterpolatingDiscreteFlow — Type

    DistNoisyInterpolatingDiscreteFlow(; D1=Beta(2,2), D2=Beta(2,2), ωu=0.2, dummy_token=nothing)

A convex 3-way path over {X0, Uniform, X1} with distribution-backed schedules:

    κ₁(t)  = cdf(D1, t)
    κ̃₂(t) = cdf(D2, t)
    κ₂(t)  = ωu * (1 - κ₁(t)) * κ̃₂(t)   # uniform amplitude scaled by ωu ∈ [0, 1)
    κ₃(t)  = 1 - κ₁(t) - κ₂(t)

Derivatives:

    dκ₂(t) = ωu * (-(dκ₁) * κ̃₂ + (1 - κ₁) * dκ̃₂)
    dκ₃(t) = -(dκ₁ + dκ₂)

ωu directly controls the amount of uniform noise: set ωu = 0 for no uniform component, and ωu → 1 for the maximum gated uniform component. Note: both training and inference expect the model to output logits, unlike the other NoisyInterpolatingDiscreteFlow (where the user needs a manual softmax for inference).
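A construction sketch under the same Distributions.jl assumption; the specific Beta shapes and the ωu value are arbitrary:

```julia
using Flowfusion, Distributions

# Target schedule from D1, uniform-noise gate from D2, amplitude ωu = 0.1.
P = DistNoisyInterpolatingDiscreteFlow(D1 = Beta(2, 2), D2 = Beta(1, 3), ωu = 0.1)
```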
Flowfusion.Guide — Type

    Guide(H::AbstractArray)

Wrapping a model prediction in Guide instructs the solver that the prediction points from the current state toward X1, instead of being a prediction of X1 itself. Used for ManifoldStates, where the prediction is a tangent vector.
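A sketch of how a model closure might wrap its output. Here `nn` is a hypothetical network producing tangent coordinates from the flattened state; `tensor` is the package's flattening accessor mentioned under `unwrap` below:

```julia
# Return a Guide so the solver treats the output as a direction toward X1,
# not as a prediction of X1 itself. `nn` is a hypothetical network.
model(t, Xt) = Guide(nn(t, tensor(Xt)))
```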
Flowfusion.MaskedState — Type

    MaskedState(S::State, cmask, lmask)

Wraps a State with a conditioning mask (`cmask`) and a loss mask (`lmask`).

Conditioning mask behavior:

The typical use is that, during training, it makes sense to construct the conditioning mask on the training observation, `X1`. During inference, the conditioning mask (and the conditioned-upon state) has to be present on `X0`. This dictates the behavior of the masking:

- When `bridge()` is called, the mask, and the state where `cmask=1`, are inherited from `X1`.
- When `gen()` is called, the state and mask will be propagated from `X0` through all of the `Xt`s.

Loss mask behavior:

- Where `lmask=0`, that observation (where the shape/size of the observation is determined by the difference in dimensions between the mask and the state) is not included in the loss.
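A construction sketch, assuming a `ContinuousState` holding a 2×n array (masks apply along the trailing state dimensions); the mask layout is an arbitrary illustration:

```julia
using Flowfusion

n = 10
cmask = vcat(trues(n ÷ 2), falses(n ÷ 2))  # per the bridge() note: cmask=1 is inherited from X1
lmask = .!cmask                            # exclude the conditioned positions from the loss
X1 = MaskedState(ContinuousState(randn(Float32, 2, n)), cmask, lmask)
```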
Flowfusion.NoisyInterpolatingDiscreteFlow — Method

    NoisyInterpolatingDiscreteFlow(κ₁, κ₂, dκ₁, dκ₂, dummy_token)
    NoisyInterpolatingDiscreteFlow(noise; K = 1, dummy_token = nothing)
    NoisyInterpolatingDiscreteFlow()

The second form uses the default cosine schedule, where `noise` is the maximum amplitude of the uniform noise component. The third form uses the default cosine schedule and `noise = 0.2`.

A convex mixture of X0, uniform noise, and X1 (Equation 10 in https://arxiv.org/pdf/2407.15595). Compared to InterpolatingDiscreteFlow, it encourages the model to make multiple switches during inference. κ₁ and κ₂ are the schedules for the target token interpolation and the uniform noise probability; dκ₁ and dκ₂ are their derivatives. Defaults to a cosine schedule. K=2 will resolve the discrete states later than K=1. If K>1, things might break if your X0 is not the dummy_token (also called the masked token), which should be passed to NoisyInterpolatingDiscreteFlow.
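A construction sketch; the noise amplitude and token index are arbitrary illustrative values:

```julia
using Flowfusion

P = NoisyInterpolatingDiscreteFlow(0.1)   # default cosine schedule, noise = 0.1

# With K > 1, X0 should be the dummy (masked) token, which must be passed in:
Pk = NoisyInterpolatingDiscreteFlow(0.1, K = 2, dummy_token = 33)
```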
Flowfusion.apply_tangent_coordinates — Method

    apply_tangent_coordinates(Xt::ManifoldState, ξ; retraction_method=default_retraction_method(Xt.M))

Returns X̂₁, where each point is the result of retracting the corresponding point of Xt by its tangent coordinate vector in ξ.
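A round-trip sketch, assuming `Xt` and `X1` are already-constructed ManifoldStates on the same manifold; `tangent_guide` (documented below) supplies the coordinates:

```julia
# Coordinates pointing from Xt to X1, then retract Xt along them:
ξ = tangent_guide(Xt, X1)
X̂₁ = apply_tangent_coordinates(Xt, ξ)   # ≈ X1 under the default retraction
```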
Flowfusion.batch — Method

    batch(Xs::Vector{T}; dims_from_end = 1)

Doesn't handle padding. To do: add an option to pad when batching along dims that don't all have the same length.
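A usage sketch, assuming the elements are equally-sized states (since padding isn't handled, the batched dims must agree):

```julia
using Flowfusion

Xs = [ContinuousState(randn(Float32, 4, 10)) for _ in 1:3]
Xb = batch(Xs)   # batches the three states; trailing dims must all match
```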
Flowfusion.cmask! — Method

    cmask!(Xt_state, X1_state, cmask)
    cmask!(Xt, X1)

Applies, in place, a conditioning mask, where only the elements (or slices) of Xt where cmask is 1 are noised. Where cmask is 0, the elements are forced to be equal to X1.
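A sketch of the three-argument form, assuming plain arrays and a 0/1 per-column mask (a mask with fewer dimensions than the state selects whole slices):

```julia
Xt = randn(Float32, 2, 6)   # noised state
X1 = randn(Float32, 2, 6)   # conditioning target
cm = [1, 1, 1, 0, 0, 0]     # per-column conditioning mask
cmask!(Xt, X1, cm)          # columns noised or forced to X1 per the rule above
```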
Flowfusion.dense — Method

    dense(X::DiscreteState; T = Float32)

Converts X to an appropriate dense representation. If X is a DiscreteState, then X is converted to a CategoricalLikelihood with default eltype Float32. If X is a "onehot" CategoricalLikelihood, then X is converted to a fully dense one.
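A sketch, assuming DiscreteState is constructed from an alphabet size and integer tokens (treat the constructor call as hypothetical if it differs):

```julia
using Flowfusion

X = DiscreteState(4, rand(1:4, 10))  # 10 tokens over a 4-symbol alphabet (assumed constructor)
L = dense(X)                         # CategoricalLikelihood with Float32 eltype
L64 = dense(X, T = Float64)          # override the eltype
```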
Flowfusion.endslices — Method

    endslices(a, m)

Returns a view of a where the slices specified by m are selected. m can be multidimensional, but the dimensions of m must match the last dimensions of a: for example, if m is a boolean array, then size(a)[end-ndims(m)+1:end] == size(m).
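A small sketch: a boolean vector m selects slices along the last dimension of a:

```julia
a = reshape(collect(1:24), 2, 3, 4)
m = [true, false, true, false]   # size(a)[end-ndims(m)+1:end] == size(m)
v = endslices(a, m)              # view of the 1st and 3rd slices along the last dim
```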
Flowfusion.gen — Method

    gen(P, X0, model, steps; tracker=Returns(nothing), midpoint = false)

Constructs a sequence of (stochastic) bridges between X0 and the predicted X̂₁ under the process P. P, X0, and model can also be tuples, where the Nth element of P will be used with the Nth elements of X0 and model. model is a function that takes t (scalar) and Xₜ (optionally a tuple) and returns X̂₁ (a UState, a flat tensor with the right shape, or a tuple of either if you're combining processes). If X0 is a MaskedState, then anything in X̂₁ will be conditioned on X0 where the conditioning mask X0.cmask is 1.
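An inference sketch, assuming a trained `model(t, Xt)` and a process `P` are in scope, and that `steps` can be given as a sequence of time points in [0, 1]:

```julia
X0 = ContinuousState(randn(Float32, 2, 100))
samp = gen(P, X0, model, 0f0:0.01f0:1f0)   # bridge from X0 toward the predicted X̂₁
```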
Flowfusion.mask — Method

    mask(X, Y)

If Y is a MaskedState, mask(X, Y) returns a MaskedState with the content of X where Y.cmask is 1, and the content of Y where Y.cmask is 0. cmask and lmask are inherited from Y. If Y is not a MaskedState, mask(X, Y) returns X.
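A one-line sketch, where `X̂₁` (a prediction) and `Y` (a MaskedState) are assumed to exist:

```julia
X̂₁m = mask(X̂₁, Y)   # X̂₁'s and Y's content combined according to Y.cmask
```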
Flowfusion.onehot — Method

    onehot(X)

Returns a state where X.state is a onehot array.
Flowfusion.tangent_guide — Method

    tangent_guide(Xt::ManifoldState, X1::ManifoldState)

Computes the coordinate vector (in the default basis) pointing from Xt to X1.
Flowfusion.unhot — Method

    unhot(X)

Returns a state where X.state is not onehot.
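A round-trip sketch, again assuming a DiscreteState built from an alphabet size and integer tokens:

```julia
using Flowfusion

X = DiscreteState(4, rand(1:4, 10))  # assumed constructor
Xh = onehot(X)    # X.state becomes a onehot array
X2 = unhot(Xh)    # back to integer tokens
```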
Flowfusion.unmask — Method

    unmask(X)

unmask(X) = X, unless X is a MaskedState, in which case X.S is returned.
Flowfusion.unwrap — Method

    unwrap(X)

Returns the underlying state or dist of X (X.state if X is a State, X.dist if X is a StateLikelihood, X.S.state if X is a MaskedState, etc.). Unlike tensor(X), this does not flatten the state.