API
Patterns
Einops.ArrowPattern (Type)
ArrowPattern{L,R}
A pair of tuples representing the left and right sides of a rearrange/reduce/repeat pattern. These tuples are stored as type parameters, so the pattern is known at compile time.
An instance ArrowPattern{L,R}() is displayed as L --> R, since --> is the operator used to construct it.
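A minimal sketch of what storing a pattern in type parameters means, using a hypothetical Pattern type with hypothetical accessors left, right and nleft (illustrative names, not the Einops.jl definitions). The instance carries no data; everything is recovered from its type.

```julia
# Hypothetical stand-in for ArrowPattern: the pattern lives entirely in the type.
struct Pattern{L,R} end

# Construct from two tuples; tuples of Symbols and Ints are valid type parameters.
Pattern(l::Tuple, r::Tuple) = Pattern{l,r}()

# Accessors read the tuples back from the type alone (no fields are stored).
left(::Pattern{L,R}) where {L,R} = L
right(::Pattern{L,R}) where {L,R} = R
nleft(p::Pattern) = length(left(p))

p = Pattern((:a, :b, :c), (:c, (:b, :a)))
typeof(p)    # the full pattern is visible in the type
sizeof(p)    # 0: the instance itself stores nothing
```

Because methods can dispatch on Pattern{L,R}, quantities such as nleft(p) depend only on the type, which is what lets the compiler treat the pattern as a constant.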
Einops.:--> (Function)
-->
Create an ArrowPattern from a left and a right tuple. Non-tuple elements are automatically wrapped in a single-element tuple.
Examples
julia> pattern1 = (:a, :b, :c) --> (:c, (:b, :a)) # nested tuple
(:a, :b, :c) --> (:c, (:b, :a))
julia> typeof(pattern1)
ArrowPattern{(:a, :b, :c), (:c, (:b, :a))}
julia> pattern2 = :a --> (1, :a) # single-element autoconversion
(:a,) --> (1, :a)
julia> typeof(pattern2)
ArrowPattern{(:a,), (1, :a)}
julia> (:a, ..) --> :a # exported ellipsis notation
(:a, EllipsisNotation.Ellipsis()) --> (:a,)
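The single-element wrapping rule can be sketched with two dispatch methods. The names astuple and makepattern are hypothetical, and makepattern returns a plain pair of tuples where the real --> builds an ArrowPattern.

```julia
# Hypothetical helper mirroring the wrapping rule described for -->:
# tuples (including nested ones) pass through unchanged,
# anything else becomes a one-element tuple.
astuple(x::Tuple) = x
astuple(x) = (x,)

# Normalize both sides of a pattern.
makepattern(l, r) = (astuple(l), astuple(r))

makepattern(:a, (1, :a))    # ((:a,), (1, :a))
```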
Einops.@einops_str (Macro)
@einops_str -> Union{ArrowPattern,Tuple}
Parse an einops pattern string, for parity with the Python implementation.
Examples
julia> einops"a 1 b c -> (c b) a"
(:a, 1, :b, :c) --> ((:c, :b), :a)
julia> einops"embed token (head batch) -> (embed head) token batch"
(:embed, :token, (:head, :batch)) --> ((:embed, :head), :token, :batch)
julia> einops"i j, j k -> i k" # for einsum
((:i, :j), (:j, :k)) --> (:i, :k)
julia> einops"a b _ d" # for parse_shape
Val{(:a, :b, -, :d)}()
julia> einops"i j * k" # for pack/unpack
Val{(:i, :j, *, :k)}()
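To show how a pattern string maps onto the nested tuples in these examples, here is a hypothetical mini-parser (parse_side and parse_arrow are illustrative names, not Einops.jl internals). It only handles the arrow form with names, integers and parentheses; the comma-separated einsum form and the _ and * markers are left out, and it returns a plain pair of tuples rather than an ArrowPattern.

```julia
# Hypothetical mini-parser for one side of an einops pattern string.
# Names become Symbols, integer literals become Ints, and each
# parenthesised group becomes a nested tuple.
function parse_side(s::AbstractString)
    stack = Any[Any[]]                        # stack[end] is the group being filled
    for m in eachmatch(r"\(|\)|[^\s()]+", s)
        tok = m.match
        if tok == "("
            push!(stack, Any[])               # open a new group
        elseif tok == ")"
            length(stack) > 1 || error("unmatched ')' in: ", s)
            group = pop!(stack)
            push!(stack[end], Tuple(group))   # close it as a nested tuple
        else
            n = tryparse(Int, tok)
            push!(stack[end], n === nothing ? Symbol(tok) : n)
        end
    end
    length(stack) == 1 || error("unmatched '(' in: ", s)
    return Tuple(stack[1])
end

# Split "left -> right" and parse both sides.
function parse_arrow(s::AbstractString)
    parts = split(s, "->")
    length(parts) == 2 || error("expected exactly one '->' in: ", s)
    return parse_side(parts[1]), parse_side(parts[2])
end

parse_arrow("a 1 b c -> (c b) a")    # ((:a, 1, :b, :c), ((:c, :b), :a))
```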
parse_shape
Einops.parse_shape (Function)
parse_shape(x, pattern)
Capture the shape of an array by naming its dimensions in a pattern: use Symbols to name dimensions, - to ignore a single dimension, and .. to ignore any number of dimensions.
For proper type inference, the pattern must be passed as Val(pattern) when an ellipsis is present. This is done automatically when using @einops_str.
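A runtime sketch of the no-ellipsis case, using a hypothetical parse_shape_sketch function (not the Einops.jl implementation). Because it collects the field names while running, the compiler cannot infer the NamedTuple type it returns; that is the kind of problem that moving the pattern into the type domain (as with Val) avoids.

```julia
# Hypothetical sketch of parse_shape without ellipsis support:
# Symbols name a dimension, the function `-` marks one to skip.
function parse_shape_sketch(x::AbstractArray, pattern::Tuple)
    length(pattern) == ndims(x) ||
        throw(DimensionMismatch("pattern has $(length(pattern)) entries, array has $(ndims(x)) dims"))
    names = Symbol[]
    lens = Int[]
    for (p, n) in zip(pattern, size(x))
        p isa Symbol || continue    # skip ignored dimensions
        push!(names, p)
        push!(lens, n)
    end
    return NamedTuple{Tuple(names)}(Tuple(lens))
end

parse_shape_sketch(rand(2, 3, 4), (:a, :b, -))    # (a = 2, b = 3)
```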
Examples
julia> parse_shape(rand(2,3,4), (:a, :b, -))
(a = 2, b = 3)
julia> parse_shape(rand(2,3), (-, -))
NamedTuple()
julia> parse_shape(rand(2,3,4,5), einops"first second third fourth")
(first = 2, second = 3, third = 4, fourth = 5)
julia> parse_shape(rand(2,3,4), Val((:a, :b, ..)))
(a = 2, b = 3)
rearrange
Einops.rearrange (Function)
rearrange(array::AbstractArray, left --> right; context...)
rearrange(arrays, left --> right; context...)
Rearrange the axes of array according to the pattern specified by left --> right.
Any such rearrangement can be expressed as a reshape, followed by a permutedims, followed by another reshape.
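To make the three steps concrete, here is a Base-only decomposition of the hypothetical pattern ((:h, :w), :c) --> (:c, (:w, :h)) with h=2. It assumes, consistently with the second example in this section, that the first name in a group varies fastest, as in Julia's column-major layout.

```julia
# Hand-written Base equivalent of the hypothetical pattern
#   ((:h, :w), :c) --> (:c, (:w, :h))   with h = 2
# assuming the first name in a group varies fastest (column-major).
x = rand(6, 5)                    # axes ((h w), c) with h = 2, w = 3, c = 5

# 1. reshape: split the grouped axis (h w) into separate axes
x3 = reshape(x, 2, 3, 5)          # axes (h, w, c)

# 2. permutedims: move the axes into the target order
y = permutedims(x3, (3, 2, 1))    # axes (c, w, h)

# 3. reshape: merge (w h) into a single axis, w varying fastest
z = reshape(y, 5, 6)              # axes (c, (w h))
```

When nothing needs to be split or merged, both reshapes are no-ops and only the permutedims remains, as in the first example.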
Examples
julia> x = rand(2,3,5);
julia> y = rearrange(x, (:a, :b, :c) --> (:c, :b, :a));
julia> size(y)
(5, 3, 2)
julia> y == permutedims(x, (3,2,1))
true
julia> z = rearrange(x, (:a, :b, :c) --> (:a, (:c, :b)));
julia> size(z)
(2, 15)
julia> z == reshape(permutedims(x, (1,3,2)), 2,5*3)
true
reduce
Base.reduce (Function)
reduce(f::Function, x::AbstractArray, left --> right; context...)
Reduce an array over the dimensions specified by the pattern, using a reduction function such as sum, prod, minimum, maximum, any, all, or Statistics.mean.
f must accept a dims::Tuple{Vararg{Int}} keyword argument that selects the dimensions to reduce over. It should reduce those dimensions to singletons rather than drop them, as sum(x; dims=...) does.
This method is not meant for the binary operators that Base.reduce normally takes, such as +, *, min, max, | and &. Note also that the reductions Python einops calls min and max correspond to minimum and maximum in Julia.
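Any function with this dims convention should work as f, assuming reduce only ever calls f(x; dims=...). As an illustration, here is a hypothetical user-defined logsumexp (not part of Einops.jl or Base) that could then be used as reduce(logsumexp, x, (:c, :b, :t) --> (:c, :b)).

```julia
# Hypothetical reducer that follows the dims convention described above:
# the reduced dimensions are kept as singletons rather than dropped.
function logsumexp(x::AbstractArray; dims)
    m = maximum(x; dims=dims)                        # size 1 along `dims`
    return m .+ log.(sum(exp.(x .- m); dims=dims))   # numerically stable log(sum(exp(x)))
end

x = randn(4, 3, 5)
size(logsumexp(x; dims=(3,)))    # (4, 3, 1): third dimension kept as a singleton
```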
Examples
julia> x = randn(64, 32, 35);
julia> y = reduce(sum, x, (:c, :b, :t) --> (:c, :b));
julia> size(y)
(64, 32)
julia> y == dropdims(sum(x, dims=3), dims=3)
true
julia> using Statistics: mean
julia> z = reduce(mean, x, (:c, :b, (:t5, :t)) --> ((:t5, :c), :b), t5=5);
julia> size(z)
(320, 32)
julia> z == reshape(permutedims(dropdims(mean(reshape(x, 64,32,5,7), dims=4), dims=4), (3,1,2)), 320,32)
true
repeat
Base.repeat (Function)
repeat(x::AbstractArray, left --> right; context...)
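Repeat the elements of x along the axes specified by the pattern. Axes that appear only on the right-hand side are new, and their sizes are passed as keyword arguments (e.g. r=2).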
Examples
julia> x = rand(2,3);
julia> y = repeat(x, (:a, :b) --> (:a, :b, 1, :r), r=2);
julia> size(y)
(2, 3, 1, 2)
julia> y == reshape(repeat(x, 1,1,2), 2,3,1,2)
true
julia> z = repeat(x, (:a, :b) --> (:a, (:b, :r)), r=2);
julia> size(z)
(2, 6)
julia> z == reshape(repeat(x, 1,1,2), 2,6)
true
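The order inside a group on the right-hand side decides whether copies are tiled or interleaved. The following Base-only sketch shows both layouts; the reading of (:a, (:r, :b)) is inferred from the column-major convention in these examples (first name in a group varies fastest), not taken from the Einops.jl documentation.

```julia
x = rand(2, 3)
t = repeat(x, 1, 1, 2)    # axes (a, b, r) with r = 2

# (:a, :b) --> (:a, (:b, :r)): b varies fastest, so whole copies of x are tiled
tiled = reshape(t, 2, 6)

# (:a, :b) --> (:a, (:r, :b)): r varies fastest, so each column is repeated in place
interleaved = reshape(permutedims(t, (1, 3, 2)), 2, 6)
```

In Base terms, tiled matches repeat(x, 1, 2) and interleaved matches repeat(x, inner=(1, 2)).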
einsum
Einops.einsum (Function)
einsum(arrays..., (left --> right))
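Compute the einsum operation specified by the pattern. Each input array is labeled by one tuple on the left, and indices that do not appear on the right are summed over.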
Examples
julia> x, y = rand(2,3), rand(3,4);
julia> einsum(x, y, ((:i, :j), (:j, :k)) --> (:i, :k)) == x * y
true
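Spelled out as loops, the matrix-product pattern in this example sums over j because j is absent from the right-hand side. matmul_einsum is a hypothetical illustration of the semantics, not how Einops.jl evaluates einsum.

```julia
# Explicit-loop version of ((:i, :j), (:j, :k)) --> (:i, :k):
# j appears on the left but not on the right, so it is summed over.
function matmul_einsum(x::AbstractMatrix, y::AbstractMatrix)
    ni, nj = size(x)
    nj2, nk = size(y)
    nj == nj2 || throw(DimensionMismatch("index j has sizes $nj and $nj2"))
    out = zeros(promote_type(eltype(x), eltype(y)), ni, nk)
    for k in 1:nk, j in 1:nj, i in 1:ni    # i innermost for column-major access
        out[i, k] += x[i, j] * y[j, k]
    end
    return out
end
```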
pack and unpack
Einops.pack (Function)
pack(unpacked_arrays, pattern)
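Pack a vector of arrays into a single array according to the pattern. The dimensions matched by * are flattened and concatenated along a single axis, and the shapes of those dimensions are returned alongside the packed array so that unpack can reverse the operation.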
Examples
julia> inputs = [rand(2,3,5), rand(2,3,7,5), rand(2,3,7,9,5)];
julia> packed_array, packed_shapes = pack(inputs, (:i, :j, *, :k));
julia> size(packed_array)
(2, 3, 71, 5)
julia> packed_shapes
3-element Vector{NTuple{N, Int64} where N}:
()
(7,)
(7, 9)
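The axis sizes in this example can be reproduced with reshape and cat: 1 + 7 + 7*9 = 71. A sketch with a hypothetical pack_sketch function that takes the number of named axes before and after * instead of a pattern; the concatenation order (input order) is an assumption.

```julia
# Hypothetical sketch of pack for a pattern with `nbefore` named axes before *
# and `nafter` after it, e.g. (:i, :j, *, :k) gives nbefore = 2, nafter = 1.
function pack_sketch(arrays, nbefore::Int, nafter::Int)
    # the axes matched by * in each array (possibly none)
    shapes = [size(a)[nbefore+1:end-nafter] for a in arrays]
    # flatten those axes into one, so every array has nbefore + 1 + nafter axes
    flat = [reshape(a, size(a)[1:nbefore]..., prod(s), size(a)[end-nafter+1:end]...)
            for (a, s) in zip(arrays, shapes)]
    # concatenate along the flattened axis
    return cat(flat...; dims=nbefore + 1), shapes
end

inputs = [rand(2, 3, 5), rand(2, 3, 7, 5), rand(2, 3, 7, 9, 5)]
packed, shapes = pack_sketch(inputs, 2, 1)
size(packed)    # (2, 3, 71, 5)
```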
Einops.unpack (Function)
unpack(packed_array, packed_shapes, pattern)
Unpack a single array into a vector of arrays according to the pattern.
Examples
julia> inputs = [rand(2,3,5), rand(2,3,7,5), rand(2,3,7,9,5)];
julia> inputs == unpack(pack(inputs, (:i, :j, *, :k))..., (:i, :j, *, :k))
true
julia> packed_array = rand(2,3,16);
julia> packed_shapes = [(), (7,), (4, 2)];
julia> unpack(packed_array, packed_shapes, (:i, :j, *)) .|> size
3-element Vector{Tuple{Int64, Int64, Vararg{Int64}}}:
(2, 3)
(2, 3, 7)
(2, 3, 4, 2)
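The inverse step slices the packed axis into consecutive blocks of prod(shape) entries and reshapes each block back. A sketch with a hypothetical unpack_sketch function that takes the number of named axes before * instead of a pattern:

```julia
# Hypothetical inverse of packing: `nbefore` named axes precede the packed axis.
function unpack_sketch(packed::AbstractArray, shapes, nbefore::Int)
    nd = ndims(packed)
    before = size(packed)[1:nbefore]
    after = size(packed)[nbefore+2:nd]    # axes after the packed one
    parts = AbstractArray[]
    offset = 0
    for s in shapes
        n = prod(s)                        # prod(()) == 1 for an empty shape
        # select entries offset+1 : offset+n along the packed axis only
        idx = ntuple(d -> d == nbefore + 1 ? (offset+1:offset+n) : Colon(), nd)
        push!(parts, reshape(packed[idx...], before..., s..., after...))
        offset += n
    end
    offset == size(packed, nbefore + 1) ||
        throw(DimensionMismatch("shapes cover $offset of $(size(packed, nbefore + 1)) entries"))
    return parts
end

size.(unpack_sketch(rand(2, 3, 16), [(), (7,), (4, 2)], 2))    # [(2, 3), (2, 3, 7), (2, 3, 4, 2)]
```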