LogitSamplers
Documentation for LogitSamplers.
LogitSamplers.GumbelNoise
LogitSamplers.Min_p
LogitSamplers.Temperature
LogitSamplers.Top_nσ
LogitSamplers.Top_pk
LogitSamplers.logitsample
LogitSamplers.logitsample_categorical
LogitSamplers.GumbelNoise — Type
GumbelNoise(; scaling=true, rng=Random.default_rng())
A logit transform that adds Gumbel noise to the logits.
Taking the argmax of the perturbed logits is equivalent to sampling from the categorical distribution that the logits define.
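The equivalence stated above can be checked empirically. The following is a minimal sketch that uses only Base and Random, not the package itself; `gumbel_perturb` is a hypothetical helper, and the package's `scaling` keyword is not modelled here.

```julia
using Random

# Hypothetical helper, not part of LogitSamplers: add independent
# Gumbel(0, 1) noise to each logit, using -log(-log(u)) with u ~ Uniform(0, 1).
function gumbel_perturb(rng::AbstractRNG, logits::AbstractVector{<:Real})
    u = rand(rng, length(logits))
    return logits .- log.(-log.(u))
end

rng = MersenneTwister(0)
logits = log.([0.1, 0.2, 0.7])   # logits whose softmax is [0.1, 0.2, 0.7]

# Each argmax of the perturbed logits is one draw from softmax(logits).
draws = [argmax(gumbel_perturb(rng, logits)) for _ in 1:100_000]

# Empirical frequencies; these should be close to [0.1, 0.2, 0.7].
freqs = [count(==(i), draws) / length(draws) for i in 1:3]
```

Because the noise is added elementwise, the perturbation composes with any masking applied beforehand: a logit of -Inf stays at -Inf and can never win the argmax.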
LogitSamplers.Min_p — Type
Min_p(pbase)
A logit transform that masks out logits whose probability is below pbase times the maximum probability.
See: https://arxiv.org/pdf/2407.01082
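As a rough illustration of the min-p rule, the sketch below recomputes the mask with Base functions only. `softmax_vec` and `min_p_mask` are hypothetical names, and representing a masked logit as -Inf is an assumption rather than something this page states.

```julia
# Hypothetical helpers, not the package implementation.
softmax_vec(x) = (e = exp.(x .- maximum(x)); e ./ sum(e))

function min_p_mask(logits::AbstractVector{<:Real}, pbase::Real)
    probs = softmax_vec(logits)
    # The cutoff scales with the most likely token's probability.
    threshold = pbase * maximum(probs)
    # Assumption: masked entries are set to -Inf so they can never be sampled.
    return [p >= threshold ? float(l) : -Inf for (l, p) in zip(logits, probs)]
end

logits = log.([0.5, 0.3, 0.15, 0.05])
masked = min_p_mask(logits, 0.2)   # cutoff is 0.2 * 0.5 = 0.1, so only the last entry is masked
```

The cutoff adapts to the model's confidence: when one token dominates, the threshold is high and few candidates survive, while a flat distribution keeps many.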
LogitSamplers.Temperature — Type
Temperature(T)
A logit transform that scales (divides) the logits by a temperature parameter.
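The effect of dividing by T can be seen on a small vector. This is a sketch with hypothetical helper names (`softmax_vec`, `apply_temperature`) written against Base only; it does not call the package.

```julia
# Hypothetical helpers, not the package implementation.
softmax_vec(x) = (e = exp.(x .- maximum(x)); e ./ sum(e))
apply_temperature(logits, T) = logits ./ T

logits = [2.0, 1.0, 0.0]

base  = softmax_vec(logits)
sharp = softmax_vec(apply_temperature(logits, 0.5))  # T < 1 concentrates mass on the top logit
flat  = softmax_vec(apply_temperature(logits, 2.0))  # T > 1 spreads mass more evenly
```

The ordering of the logits never changes under temperature scaling; only the gaps between them shrink or grow.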
LogitSamplers.Top_nσ — Type
Top_nσ(n)
A logit transform that masks out logits lying more than n standard deviations below the maximum logit.
Top-nσ is temperature-invariant, i.e. the candidate set does not change with temperature.
See: https://arxiv.org/pdf/2411.07641
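The temperature invariance follows because dividing the logits by T divides both the maximum and the standard deviation by T, so the threshold moves with the logits. A minimal sketch, assuming the threshold is `maximum(logits) - n * std(logits)` and that masked entries become -Inf; `top_nsigma_mask` and `kept` are hypothetical names, and the package may compute the standard deviation differently.

```julia
using Statistics

# Hypothetical helper, not the package implementation.
function top_nsigma_mask(logits::AbstractVector{<:Real}, n::Real)
    # Assumption: sample standard deviation over all (finite) logits.
    threshold = maximum(logits) - n * std(logits)
    return [l >= threshold ? float(l) : -Inf for l in logits]
end

kept(x) = findall(isfinite, x)   # indices that survive the mask

logits = [5.0, 4.5, 1.0, 0.0, -2.0]

kept_T1 = kept(top_nsigma_mask(logits, 1.0))          # temperature 1
kept_T2 = kept(top_nsigma_mask(logits ./ 2.0, 1.0))   # temperature 2, same candidate set
```

Note that this sketch expects finite inputs: if an earlier transform has already set some logits to -Inf, `std` returns NaN and the comparison masks everything.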
LogitSamplers.Top_pk — Type
Top_pk(p, k)
A logit transform that masks out logits outside the top p cumulative probability or top k logits.
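One plausible reading of this rule is sketched below: a token is kept only while fewer than k tokens have been kept and the probability mass kept so far is still below p. The names `softmax_vec` and `top_pk_mask` are hypothetical, and the exact tie-breaking and boundary behaviour of the package may differ from this sketch.

```julia
# Hypothetical helpers, not the package implementation.
softmax_vec(x) = (e = exp.(x .- maximum(x)); e ./ sum(e))

function top_pk_mask(logits::AbstractVector{<:Real}, p::Real, k::Integer)
    probs = softmax_vec(logits)
    order = sortperm(probs; rev=true)      # most probable first
    masked = fill(-Inf, length(logits))    # assumption: masked entries are -Inf
    cumulative = 0.0
    for (rank, i) in enumerate(order)
        # Stop once k tokens are kept or the kept mass has reached p.
        (rank > k || cumulative >= p) && break
        masked[i] = logits[i]
        cumulative += probs[i]
    end
    return masked
end

logits = log.([0.5, 0.3, 0.15, 0.05])
by_p = top_pk_mask(logits, 0.7, 4)   # the cumulative-probability limit binds first
by_k = top_pk_mask(logits, 1.0, 1)   # the top-k limit binds first
```

Setting p to 1.0 reduces this sketch to plain top-k, and setting k to the vocabulary size reduces it to plain top-p (nucleus) filtering.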
LogitSamplers.logitsample — Method
logitsample([rng], logits; dims=:)
Sample indices from a logit distribution using the Gumbel argmax trick.
See also logitsample_categorical.
Examples
julia> logitsample([-Inf, -10.0])
2
julia> logits = [-Inf -10
30 10];
julia> logitsample(logits)
CartesianIndex(2, 1)
julia> logitsample(logits, dims=1)
1×2 Matrix{CartesianIndex{2}}:
CartesianIndex(2, 1) CartesianIndex(2, 2)
julia> logitsample(logits, dims=2)
2×1 Matrix{CartesianIndex{2}}:
CartesianIndex(1, 2)
CartesianIndex(2, 1)
LogitSamplers.logitsample_categorical — Method
logitsample_categorical([rng], logits; dims::Int=1)
Sample indices from a logit distribution using the Gumbel argmax trick, and return the corresponding indices over the specified dimension.
See also logitsample.
Examples
julia> logitsample_categorical([-Inf, -10.0])
1-element Vector{Int64}:
2
julia> logits = [-Inf -10
30 10];
julia> logitsample_categorical(logits) # dims=1 by default
1×2 Matrix{Int64}:
2 2
julia> logitsample_categorical(logits, dims=1)
1×2 Matrix{Int64}:
2 2
julia> logitsample_categorical(logits, dims=2)
2×1 Matrix{Int64}:
2
1
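The difference between the two return types can be illustrated with Base functions alone. In the sketch below, a plain `argmax` stands in for the sampling step (an assumption made so the result is deterministic), and the integer index along the chosen dimension is read off each `CartesianIndex`; this mirrors the outputs shown in the examples but is not taken from the package source.

```julia
logits = [-Inf -10.0
          30.0  10.0]

# One CartesianIndex per column, as in the logitsample(logits, dims=1) output.
cart = argmax(logits; dims=1)

# Keep only the component along dimension 1, as in the
# logitsample_categorical(logits, dims=1) output.
rows = getindex.(cart, 1)
```

The integer form is convenient when each column is a separate categorical distribution, for example one token choice per sequence in a batch.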