Flowfusion
Documentation for Flowfusion.
- Flowfusion.Guide
- Flowfusion.MaskedState
- Flowfusion.apply_tangent_coordinates
- Flowfusion.batch
- Flowfusion.cmask!
- Flowfusion.dense
- Flowfusion.endslices
- Flowfusion.gen
- Flowfusion.mask
- Flowfusion.onehot
- Flowfusion.tangent_guide
- Flowfusion.unhot
- Flowfusion.unmask
- Flowfusion.unwrap
Flowfusion.Guide — Type

Guide(H::AbstractArray)

Wrapping a model prediction in Guide instructs the solver that the prediction points to X1 from the current state, instead of being a prediction of X1 itself. Used for ManifoldStates, where the prediction is a tangent coordinate vector.
Flowfusion.MaskedState — Type

MaskedState(S::State, cmask, lmask)

Wraps a State S with a conditioning mask (cmask) and a loss mask (lmask).
Conditioning mask behavior:
The typical use is to construct the conditioning mask, during training, on the training observation X1. During inference, the conditioning mask (and the conditioned-upon state) has to be present on X0. This dictates the behavior of the masking:
- When bridge() is called, the mask, and the state where cmask=1, are inherited from X1.
- When gen() is called, the state and mask will be propagated from X0 through all of the Xts.
Loss mask behavior:
- Where lmask=0, that observation (where the shape/size of the observation is determined by the difference in dimensions between the mask and the state) is not included in the loss.
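A minimal construction sketch. The MaskedState signature is as documented above; ContinuousState and the particular masks are illustrative assumptions:

```julia
using Flowfusion

# Training observation: a 2×32 continuous state (ContinuousState assumed).
X1 = ContinuousState(randn(Float32, 2, 32))

# Mark the first 10 positions as conditioning, and exclude them from the loss.
cmask = falses(32); cmask[1:10] .= true
lmask = .!cmask

X1m = MaskedState(X1, cmask, lmask)
```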
Flowfusion.apply_tangent_coordinates — Method

apply_tangent_coordinates(Xt::ManifoldState, ξ; retraction_method=default_retraction_method(Xt.M))

Returns X̂₁, where each point is the result of retracting the corresponding point of Xt by its tangent coordinate vector in ξ.
Flowfusion.batch — Method

batch(Xs::Vector{T}; dims_from_end = 1)

Batches a vector of states Xs into a single state. Doesn't handle padding; an option to pad should be added for batching along dims that don't have the same length.
Flowfusion.cmask! — Method

cmask!(Xt_state, X1_state, cmask)
cmask!(Xt, X1)

Applies, in place, a conditioning mask: only elements (or slices) of Xt where cmask is 1 are noised; where cmask is 0, the elements are forced to be equal to X1.
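A hedged sketch of the in-place behavior described above, applied to raw state arrays (whether cmask! takes raw arrays or State wrappers here is an assumption):

```julia
Xt = randn(Float32, 3, 10)   # noised state values
X1 = randn(Float32, 3, 10)   # conditioning observation
cmask = rand(Bool, 10)       # one flag per column (slice)

cmask!(Xt, X1, cmask)        # columns of Xt where cmask is 0 now equal the corresponding columns of X1
```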
Flowfusion.dense — Method

dense(X::DiscreteState; T = Float32)

Converts X to an appropriate dense representation. If X is a DiscreteState, then X is converted to a CategoricalLikelihood with default eltype Float32. If X is a "onehot" CategoricalLikelihood, then X is converted to a fully dense one.
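For example, assuming DiscreteState(K, indices) is the usual discrete state constructor (an assumption here):

```julia
X = DiscreteState(4, rand(1:4, 10))  # 10 positions over 4 categories (constructor assumed)
L = dense(X)                         # CategoricalLikelihood with Float32 entries
```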
Flowfusion.endslices — Method

endslices(a, m)

Returns a view of a where the slices specified by m are selected. m can be multidimensional, but the dimensions of m must match the last dimensions of a. For example, if m is a boolean array, then size(a)[end-ndims(m)+1:end] == size(m).
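A purely illustrative sketch of the shape contract:

```julia
a = randn(3, 4, 5)
m = rand(Bool, 4, 5)    # matches the last two dimensions of a
v = endslices(a, m)     # view of the length-3 slices a[:, i, j] for which m[i, j] is true
```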
Flowfusion.gen — Method

gen(P, X0, model, steps; tracker=Returns(nothing), midpoint = false)

Constructs a sequence of (stochastic) bridges between X0 and the predicted X̂₁ under the process P. P and X0 can also be tuples, where the Nth element of P will be used for the Nth elements of X0 and model. model is a function that takes t (scalar) and Xₜ (optionally a tuple) and returns the prediction X̂₁ (a UState, a flat tensor with the right shape, or a tuple of either if you're combining processes). If X0 is a MaskedState, then anything in X̂₁ will be conditioned on X0 where the conditioning mask X0.cmask is 1.
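A minimal end-to-end sketch. BrownianMotion, ContinuousState, and passing steps as a range of time points are assumptions here; the "model" is a stand-in that always predicts the same target:

```julia
using Flowfusion

P  = BrownianMotion(0.15f0)                   # assumed continuous process constructor
X0 = ContinuousState(randn(Float32, 2, 16))   # starting state
target = randn(Float32, 2, 16)

# Stand-in model: a flat tensor with the right shape is an accepted prediction.
model(t, Xt) = target

X̂₁ = gen(P, X0, model, 0f0:0.05f0:1f0)
```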
Flowfusion.mask — Method

mask(X, Y)

If Y is a MaskedState, mask(X, Y) returns a MaskedState with the content of X where elements of Y.cmask are 1, and the content of Y where Y.cmask is 0. cmask and lmask are inherited from Y. If Y is not a MaskedState, mask(X, Y) returns X.
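For instance (ContinuousState and the shapes below are illustrative assumptions):

```julia
Y = MaskedState(ContinuousState(zeros(Float32, 2, 8)), rand(Bool, 8), trues(8))
X = ContinuousState(randn(Float32, 2, 8))

Z = mask(X, Y)   # MaskedState: X's content where Y.cmask is 1, Y's where it is 0
```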
Flowfusion.onehot — Method

onehot(X)

Returns a state where X.state is a onehot array.
Flowfusion.tangent_guide — Method

tangent_guide(Xt::ManifoldState, X1::ManifoldState)

Computes the coordinate vector (in the default basis) pointing from Xt to X1.
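tangent_guide and apply_tangent_coordinates are approximately inverse to one another, and a model predicting such coordinates would wrap them in Guide (see above). A hedged round-trip sketch, assuming ManifoldState wraps an array of Manifolds.jl points (the constructor shown is an assumption):

```julia
using Flowfusion, Manifolds

M  = Sphere(2)
Xt = ManifoldState(M, [rand(M) for _ in 1:5])
X1 = ManifoldState(M, [rand(M) for _ in 1:5])

ξ  = tangent_guide(Xt, X1)              # coordinates pointing from Xt to X1
X̂₁ = apply_tangent_coordinates(Xt, ξ)   # retract Xt by ξ; approximately recovers X1
```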
Flowfusion.unhot — Method

unhot(X)

Returns a state where X.state is not onehot.
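A round-trip sketch combining onehot and unhot (DiscreteState(K, indices) is assumed to be the discrete state constructor):

```julia
X  = DiscreteState(4, rand(1:4, 10))  # integer category indices (constructor assumed)
Xh = onehot(X)                        # Xh.state is a one-hot array over the 4 categories
X′ = unhot(Xh)                        # back to integer category indices
```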
Flowfusion.unmask — Method

unmask(X)

unmask(X) = X, unless X is a MaskedState, in which case X.S is returned.
Flowfusion.unwrap — Method

unwrap(X)

Returns the underlying state or dist of X (X.state if X is a State, X.dist if X is a StateLikelihood, X.S.state if X is a MaskedState, etc.). Unlike tensor(X), this does not flatten the state.
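For contrast with unmask above (ContinuousState assumed):

```julia
S  = ContinuousState(randn(Float32, 2, 4))
Xm = MaskedState(S, trues(4), trues(4))

unmask(Xm)   # the wrapped State, S
unwrap(Xm)   # the underlying array, S.state
unwrap(S)    # also S.state
```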