LogitSamplers

Documentation for LogitSamplers.

LogitSamplers.GumbelNoise (Type)
GumbelNoise(; scaling=true, rng=Random.default_rng())

A logit transform that adds Gumbel noise to the logits.

Taking the argmax of the noised logits is equivalent to sampling from the categorical distribution given by the softmax of the logits.
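
The Gumbel argmax equivalence can be checked empirically. The following is a minimal sketch in plain Julia that does not use LogitSamplers itself; `gumbel`, `add_gumbel`, and `softmax` are hypothetical helpers written for illustration, and the noise is assumed to be standard Gumbel(0, 1).

```julia
using Random

# Standard Gumbel(0, 1) noise via inverse transform: g = -log(-log(u)), u ~ Uniform[0, 1).
gumbel(rng::AbstractRNG) = -log(-log(rand(rng)))

# Hypothetical helper (not part of LogitSamplers): add independent Gumbel noise to each logit.
add_gumbel(rng::AbstractRNG, logits::AbstractVector{<:Real}) = [l + gumbel(rng) for l in logits]

# Numerically stable softmax, used only to compare against the empirical frequencies.
function softmax(logits)
    e = exp.(logits .- maximum(logits))
    return e ./ sum(e)
end

rng = MersenneTwister(0)
logits = [1.0, 0.0, -1.0]
n = 100_000
counts = zeros(Int, length(logits))
for _ in 1:n
    # The argmax of the noised logits is one categorical sample.
    counts[argmax(add_gumbel(rng, logits))] += 1
end
freqs = counts ./ n
```

With enough draws, `freqs` should be close to `softmax(logits)`.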

LogitSamplers.Min_p (Type)
Min_p(pbase)

A logit transform that masks out logits whose probability is below pbase times the maximum probability.

See: https://arxiv.org/pdf/2407.01082
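
As an illustration of the rule, here is a minimal sketch in plain Julia. `min_p_mask` is a hypothetical helper, not the LogitSamplers implementation, and it assumes that masking means setting a logit to `-Inf`.

```julia
# Hypothetical helper (not part of LogitSamplers): keep a logit only if its
# softmax probability is at least pbase times the largest probability.
function min_p_mask(logits::AbstractVector{<:Real}, pbase::Real)
    x = float.(logits)
    probs = exp.(x .- maximum(x))   # shift by the maximum for numerical stability
    probs ./= sum(probs)
    threshold = pbase * maximum(probs)
    # Assumption: masked entries are set to -Inf so they can never be sampled.
    return ifelse.(probs .>= threshold, x, eltype(x)(-Inf))
end

logits = [2.0, 1.0, 0.0, -3.0]
masked = min_p_mask(logits, 0.2)
```

Because the probability ratio to the maximum is `exp(logit - max_logit)`, only the first two entries survive here: `exp(-1)` is about 0.37, while `exp(-2)` is about 0.14 and falls below 0.2.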

LogitSamplers.Top_nσ (Type)
Top_nσ(n)

A logit transform that masks out logits that lie more than n standard deviations below the maximum logit.

Top-nσ is temperature-invariant, i.e. the candidate set does not change with temperature.

See: https://arxiv.org/pdf/2411.07641
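
The threshold rule and its temperature invariance can be sketched in plain Julia. `top_nsigma_mask` and `kept` are hypothetical helpers, not the LogitSamplers implementation; the sketch assumes the sample standard deviation from `Statistics.std`, finite input logits, and masking by `-Inf`.

```julia
using Statistics

# Hypothetical helper (not part of LogitSamplers): keep logits within
# n standard deviations of the maximum logit, mask the rest with -Inf.
# Assumes all input logits are finite, since std would otherwise be NaN.
function top_nsigma_mask(logits::AbstractVector{<:Real}, n::Real)
    x = float.(logits)
    threshold = maximum(x) - n * std(x)
    return ifelse.(x .>= threshold, x, eltype(x)(-Inf))
end

# Indices of the surviving candidates.
kept(x) = findall(isfinite, x)

logits = [5.0, 4.5, 1.0, 0.0, -1.0]
base = kept(top_nsigma_mask(logits, 1.0))

# Dividing by a temperature scales the maximum and the standard deviation
# by the same factor, so the candidate set is unchanged.
hot = kept(top_nsigma_mask(logits ./ 2.0, 1.0))
cold = kept(top_nsigma_mask(logits ./ 0.5, 1.0))
```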

LogitSamplers.Top_pk (Type)
Top_pk(p, k)

A logit transform that masks out logits outside the top p cumulative probability or top k logits.
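
One possible reading of this rule is sketched below in plain Julia. `top_pk_mask` is a hypothetical helper, not the LogitSamplers implementation. It assumes that a logit is kept only if it is both among the top k logits and inside the smallest set of top logits whose cumulative probability reaches p; the package may combine the two criteria differently.

```julia
# Hypothetical helper (not part of LogitSamplers): combined top-p / top-k masking.
# Assumption: a logit must satisfy BOTH criteria to be kept.
function top_pk_mask(logits::AbstractVector{<:Real}, p::Real, k::Integer)
    x = float.(logits)
    order = sortperm(x; rev=true)            # indices from largest to smallest logit
    probs = exp.(x[order] .- x[order[1]])    # shift by the maximum for stability
    probs ./= sum(probs)
    cum = cumsum(probs)
    # Size of the nucleus: smallest prefix whose cumulative probability reaches p.
    n_p = something(findfirst(>=(p), cum), length(x))
    n_keep = min(n_p, k)
    masked = fill(eltype(x)(-Inf), length(x))
    masked[order[1:n_keep]] = x[order[1:n_keep]]
    return masked
end

logits = [0.0, 3.0, 1.0, 2.0]
by_k = top_pk_mask(logits, 0.9, 2)   # top-k is the binding constraint
by_p = top_pk_mask(logits, 0.5, 3)   # top-p is the binding constraint
```

In the first call the nucleus for p = 0.9 contains three logits, so k = 2 decides the cut; in the second call the largest logit alone already carries more than half of the probability mass.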

LogitSamplers.logitsample (Method)
logitsample([rng], logits; dims=:)

Sample indices from a logit distribution using the Gumbel argmax trick.

See also logitsample_categorical.

Examples

julia> logitsample([-Inf, -10.0])
2

julia> logits = [-Inf -10
                   30  10];

julia> logitsample(logits)
CartesianIndex(2, 1)

julia> logitsample(logits, dims=1)
1×2 Matrix{CartesianIndex{2}}:
 CartesianIndex(2, 1)  CartesianIndex(2, 2)

julia> logitsample(logits, dims=2)
2×1 Matrix{CartesianIndex{2}}:
 CartesianIndex(1, 2)
 CartesianIndex(2, 1)

LogitSamplers.logitsample_categorical (Method)
logitsample_categorical([rng], logits; dims::Int=1)

Sample indices from a logit distribution using the Gumbel argmax trick, and return them as integer indices along the specified dimension rather than as CartesianIndex values.

See also logitsample.

Examples

julia> logitsample_categorical([-Inf, -10.0])
1-element Vector{Int64}:
 2

julia> logits = [-Inf -10
                   30  10];

julia> logitsample_categorical(logits) # dims=1 by default
1×2 Matrix{Int64}:
 2  2

julia> logitsample_categorical(logits, dims=1)
1×2 Matrix{Int64}:
 2  2

julia> logitsample_categorical(logits, dims=2)
2×1 Matrix{Int64}:
 2
 1
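
The two return conventions in the examples above (CartesianIndex values versus integer indices along one dimension) can be related with Base functions alone. The sketch below uses a deterministic `argmax` in place of sampling, purely to show the index conversion; it does not call LogitSamplers, and `categorical` is a name chosen for illustration.

```julia
logits = [-Inf -10.0
            30.0  10.0]

# argmax along dims=2 returns one CartesianIndex per row.
idx = argmax(logits; dims=2)

# Extracting the second component of each CartesianIndex gives the
# position along dimension 2, with the reduced array shape preserved.
categorical = getindex.(idx, 2)
```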