LogitSamplers

Documentation for LogitSamplers.

LogitSamplers.Min_p (Type)
Min_p(pbase)

A logit transform that masks out every logit whose probability is below pbase times the maximum probability.

See: https://arxiv.org/pdf/2407.01082
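
The min-p rule can be sketched in plain Julia without the package. The helper name `min_p_mask` is hypothetical, and representing a masked entry as `-Inf` is an assumption made for illustration, not a statement about how `Min_p` is implemented. Because the softmax gives `p_i / p_max = exp(l_i - l_max)`, the probability test reduces to comparing each logit against `l_max + log(pbase)`.

```julia
# Hypothetical helper, not part of LogitSamplers: a plain-Julia min-p mask.
function min_p_mask(logits::AbstractVector{<:Real}, pbase::Real)
    # p_i < pbase * p_max  is equivalent to  l_i < l_max + log(pbase)
    cutoff = maximum(logits) + log(pbase)
    # Assumption: masked entries are represented as -Inf
    return [l < cutoff ? -Inf : float(l) for l in logits]
end

logits = [2.0, 1.0, -1.0]
masked = min_p_mask(logits, 0.2)   # cutoff is 2.0 + log(0.2), about 0.39
# masked == [2.0, 1.0, -Inf]
```

Working on logits directly avoids computing the softmax at all, which is one reasonable design choice for a transform of this kind.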

LogitSamplers.Top_nσ (Type)
Top_nσ(n)

A logit transform that masks out logits that lie more than n standard deviations below the maximum logit.

Top-nσ is temperature-invariant, i.e. the candidate set does not change with temperature.

See: https://arxiv.org/pdf/2411.07641
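
A minimal sketch of the threshold and of the temperature-invariance claim follows. The helper names `top_nsigma_mask` and `kept` are hypothetical, masked entries are represented as `-Inf` by assumption, and the population standard deviation is used here as an assumption (which estimator the package uses is not stated above). Dividing all logits by a temperature rescales both the maximum and the standard deviation by the same factor, so the set of surviving indices is unchanged.

```julia
# Hypothetical helper, not part of LogitSamplers: a plain-Julia top-nσ mask.
function top_nsigma_mask(logits::AbstractVector{<:Real}, n::Real)
    μ = sum(logits) / length(logits)
    # Assumption: population standard deviation (divide by N, not N - 1)
    σ = sqrt(sum(abs2, logits .- μ) / length(logits))
    cutoff = maximum(logits) - n * σ
    # Assumption: masked entries are represented as -Inf
    return [l < cutoff ? -Inf : float(l) for l in logits]
end

logits = [4.0, 3.0, 0.0, -1.0]

# Boolean vector marking which indices survive the mask for n = 1
kept(x) = isfinite.(top_nsigma_mask(x, 1.0))

# The candidate set is the same at temperature 1 and at temperature 0.5
same_set = kept(logits) == kept(logits ./ 0.5)
```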

LogitSamplers.Top_pk (Type)
Top_pk(p, k)

A logit transform that masks out logits that fall outside the top p cumulative probability mass or outside the top k logits.
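
One possible reading of this rule is sketched below: a logit survives only if it ranks among the k largest and the probability mass accumulated by the logits ranked above it is still below p. The helper name `top_pk_mask`, the `-Inf` representation of masked entries, and the exact tie and boundary handling are all assumptions made for illustration, and the package may treat them differently.

```julia
# Hypothetical helper, not part of LogitSamplers: combined top-p / top-k mask.
function top_pk_mask(logits::AbstractVector{<:Real}, p::Real, k::Integer)
    order = sortperm(logits; rev=true)               # indices, largest logit first
    probs = exp.(logits[order] .- maximum(logits))   # numerically stable softmax
    probs ./= sum(probs)
    cum = cumsum(probs)
    masked = fill(-Inf, length(logits))              # assumption: masked means -Inf
    for (rank, i) in enumerate(order)
        # Probability mass held by the strictly higher-ranked entries
        before = rank == 1 ? 0.0 : cum[rank - 1]
        if rank <= k && before < p
            masked[i] = logits[i]
        end
    end
    return masked
end

logits = log.([0.15, 0.5, 0.05, 0.3])
by_k = top_pk_mask(logits, 0.9, 2)   # k is the binding limit: indices 2 and 4 survive
by_p = top_pk_mask(logits, 0.4, 4)   # p is the binding limit: only index 2 survives
```

Under this reading the highest-ranked logit always survives, so the transformed distribution is never empty.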

LogitSamplers.logitsample (Function)
logitsample([rng], logits; dims=:)

Sample an index from a logit distribution using the Gumbel argmax trick.

Examples

julia> logitsample([-Inf, -10.0])
2

julia> logitsample([-Inf -10.0; 20 -Inf])
CartesianIndex(2, 1)

julia> logitsample([-Inf -10.0; 20 -Inf], dims=1)
1×2 Matrix{Int64}:
 2  1
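
The Gumbel argmax trick itself can be sketched for a plain vector as follows. The helper name `gumbel_argmax` is hypothetical and this is not the package's implementation; it only illustrates that adding independent standard Gumbel noise to each logit and taking the argmax draws an index with softmax probability.

```julia
using Random

# Hypothetical helper, not part of LogitSamplers: Gumbel argmax on a vector.
function gumbel_argmax(rng::AbstractRNG, logits::AbstractVector{<:Real})
    # -log(-log(U)) with U ~ Uniform(0, 1) is a standard Gumbel draw
    noise = -log.(-log.(rand(rng, length(logits))))
    return argmax(logits .+ noise)
end

rng = MersenneTwister(0)

# A -Inf logit stays -Inf after adding finite noise, so index 1 is never chosen
idx = gumbel_argmax(rng, [-Inf, -10.0])

# Empirical check: index 1 should be drawn roughly 70% of the time
draws = [gumbel_argmax(rng, log.([0.7, 0.3])) for _ in 1:10_000]
freq1 = count(==(1), draws) / length(draws)
```

The trick needs no explicit normalization of the logits, which is why it pairs naturally with transforms that mask entries rather than renormalize them.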