API
Patterns
Einops.ArrowPattern
— Type
ArrowPattern{L,R}
A pair of tuples representing the left and right sides of a rearrange/reduce/repeat pattern. The tuples are stored as type parameters, so the pattern is known at compile time.
An instance ArrowPattern{L,R}() is displayed as L --> R, since --> is the operator used to construct it.
Einops.:-->
— Function
-->
Create an ArrowPattern from a left and a right tuple. A side that is not a tuple is automatically wrapped in a single-element tuple.
Examples
julia> pattern1 = (:a, :b, :c) --> (:c, (:b, :a)) # nested tuple
(:a, :b, :c) --> (:c, (:b, :a))
julia> typeof(pattern1)
ArrowPattern{(:a, :b, :c), (:c, (:b, :a))}
julia> pattern2 = :a --> (1, :a) # single-element autoconversion
(:a,) --> (1, :a)
julia> typeof(pattern2)
ArrowPattern{(:a,), (1, :a)}
julia> (:a, ..) --> :a # exported ellipsis notation
(:a, EllipsisNotation.Ellipsis()) --> (:a,)
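The idea of keeping a pattern entirely in the type can be sketched in a few lines of plain Julia. ToyArrow, toyarrow, lhs and rhs are hypothetical names introduced only for illustration. This is a minimal sketch under that assumption, not Einops' definition of ArrowPattern.

```julia
# Hypothetical toy type, NOT Einops' ArrowPattern: the two sides of the
# pattern live only in the type parameters L and R.
struct ToyArrow{L,R} end

# Wrap a non-tuple side in a single-element tuple, mirroring the
# autoconversion described above.
astuple(x::Tuple) = x
astuple(x) = (x,)

toyarrow(left, right) = ToyArrow{astuple(left), astuple(right)}()

# The sides can be recovered from the type alone, so a method specialized on
# ToyArrow{L,R} sees them as compile-time constants.
lhs(::ToyArrow{L,R}) where {L,R} = L
rhs(::ToyArrow{L,R}) where {L,R} = R

# Display as "L --> R", like ArrowPattern does.
Base.show(io::IO, ::ToyArrow{L,R}) where {L,R} = print(io, L, " --> ", R)

p = toyarrow((:a, :b, :c), (:c, (:b, :a)))
println(p)   # (:a, :b, :c) --> (:c, (:b, :a))
```

Because L and R are type parameters, each distinct pattern is a distinct type. The compiler can therefore specialize methods on it, at the cost of compiling anew for every new pattern.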
Einops.@einops_str
— Macro
@einops_str -> Union{ArrowPattern,Tuple}
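Parse a pattern string written in the syntax of the Python einops implementation, for parity with it. A string containing -> becomes an ArrowPattern. A string without one becomes a plain tuple, as used by parse_shape and pack/unpack.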
Examples
julia> einops"a 1 b c -> (c b) a"
(:a, 1, :b, :c) --> ((:c, :b), :a)
julia> einops"embed token (head batch) -> (embed head) token batch"
(:embed, :token, (:head, :batch)) --> ((:embed, :head), :token, :batch)
julia> einops"i j, j k -> i k" # for einsum
((:i, :j), (:j, :k)) --> (:i, :k)
julia> einops"a b _ d" # for parse_shape
(:a, :b, -, :d)
julia> einops"i j * k" # for pack/unpack
(:i, :j, *, :k)
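The mapping from pattern strings to tuples can be made concrete with a small parser in plain Julia. It is a hypothetical sketch, not the code behind @einops_str. The names parse_token, parse_side, parse_pattern and toy"..." are illustrative. The parser returns a Pair instead of an ArrowPattern so that it runs without Einops, and it does not handle the ellipsis.

```julia
# Hypothetical sketch, NOT the parser behind @einops_str. Supports names,
# integers, "_" (mapped to -), "*", parenthesized groups, comma-separated
# inputs and a single "->"; no ellipsis.
parse_token(t) = t == "_" ? (-) :
                 t == "*" ? (*) :
                 all(isdigit, t) ? parse(Int, t) : Symbol(t)

function parse_side(s::AbstractString)
    tokens = split(replace(s, "(" => " ( ", ")" => " ) "))
    stack = Any[Any[]]                      # innermost open group is last
    for t in tokens
        if t == "("
            push!(stack, Any[])
        elseif t == ")"
            length(stack) > 1 || error("unbalanced parentheses in \"$s\"")
            group = Tuple(pop!(stack))      # close the group as a tuple
            push!(stack[end], group)
        else
            push!(stack[end], parse_token(t))
        end
    end
    length(stack) == 1 || error("unbalanced parentheses in \"$s\"")
    return Tuple(stack[1])
end

function parse_pattern(s::AbstractString)
    occursin("->", s) || return parse_side(s)   # e.g. for parse_shape, pack
    l, r = split(s, "->")
    # Several inputs (as for einsum) are separated by commas.
    left = occursin(",", l) ? Tuple(parse_side.(split(l, ","))) : parse_side(l)
    return left => parse_side(r)
end

# A string macro parses once, at macro-expansion time.
macro toy_str(s)
    QuoteNode(parse_pattern(s))
end
```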
parse_shape
Einops.parse_shape
— Function
parse_shape(x, pattern)
Capture the shape of an array as a NamedTuple by naming its dimensions in the pattern with Symbols. Use - to ignore a single dimension and .. to ignore any number of dimensions.
For proper type inference, the pattern must be passed as Val(pattern) when it contains an ellipsis.
Examples
julia> parse_shape(rand(2,3,4), (:a, :b, -))
(a = 2, b = 3)
julia> parse_shape(rand(2,3), (-, -))
NamedTuple()
julia> parse_shape(rand(2,3,4,5), (:first, :second, :third, :fourth))
(first = 2, second = 3, third = 4, fourth = 5)
julia> parse_shape(rand(2,3,4), Val((:a, :b, ..)))
(a = 2, b = 3)
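For patterns without an ellipsis, the behavior shown in the examples can be reproduced with a short loop over the pattern. toy_parse_shape is a hypothetical name. This sketch does not support .. and makes no attempt at the type stability that Einops obtains through Val(pattern).

```julia
# Hypothetical sketch, not Einops' implementation: no ellipsis support.
function toy_parse_shape(x::AbstractArray, pattern::Tuple)
    length(pattern) == ndims(x) || throw(DimensionMismatch(
        "pattern has $(length(pattern)) entries, array has $(ndims(x)) dimensions"))
    names = Symbol[]
    sizes = Int[]
    for (d, p) in enumerate(pattern)
        p isa Symbol || continue      # `-` (or anything else) is skipped
        push!(names, p)
        push!(sizes, size(x, d))
    end
    return NamedTuple{Tuple(names)}(Tuple(sizes))
end

println(toy_parse_shape(rand(2, 3, 4), (:a, :b, -)))   # (a = 2, b = 3)
```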
rearrange
Einops.rearrange
— Function
rearrange(array::AbstractArray, left --> right; context...)
rearrange(arrays, left --> right; context...)
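Rearrange the axes of array according to the pattern left --> right. Axis sizes that cannot be inferred from the input, such as the factors of a split axis, are passed as keyword arguments in context.
Any such rearrangement can be expressed as a reshape, then a permutedims, then another reshape.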
Examples
julia> x = rand(2,3,5);
julia> y = rearrange(x, (:a, :b, :c) --> (:c, :b, :a));
julia> size(y)
(5, 3, 2)
julia> y == permutedims(x, (3,2,1))
true
julia> z = rearrange(x, (:a, :b, :c) --> (:a, (:c, :b)));
julia> size(z)
(2, 15)
julia> z == reshape(permutedims(x, (1,3,2)), 2,5*3)
true
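The reshape/permutedims/reshape decomposition can be written out by hand. The sketch below uses only Base Julia and does not call Einops. It assumes the pattern (:a, (:b, :c)) --> ((:c, :a), :b) with b = 3, a hypothetical example chosen to exercise all three steps. It follows the column-major grouping convention of the examples above, where the first name in a group varies fastest.

```julia
# Three-step recipe for (:a, (:b, :c)) --> ((:c, :a), :b) with b = 3.
x = reshape(collect(1:30), 2, 15)   # axes (a, (b, c)) with a=2, b=3, c=5

x3 = reshape(x, 2, 3, 5)            # 1. split the grouped axis: (a, b, c)
xp = permutedims(x3, (3, 1, 2))     # 2. reorder the axes:        (c, a, b)
y  = reshape(xp, 5 * 2, 3)          # 3. merge c and a:           ((c, a), b)

# Check every element against the index arithmetic implied by the pattern.
for a in 1:2, b in 1:3, c in 1:5
    @assert y[c + 5 * (a - 1), b] == x[a, b + 3 * (c - 1)]
end
println(size(y))                    # (10, 3)
```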
reduce
Base.reduce
— Function
reduce(f::Function, x::AbstractArray, left --> right; context...)
Reduce x over the axes that appear on the left side of the pattern but not on the right, using a function such as sum, prod, minimum, maximum, any, all, or Statistics.mean.
f must accept a dims::Tuple{Vararg{Int}} keyword argument naming the dimensions to reduce over. It should reduce those dimensions to singletons rather than drop them, as sum(x; dims) does.
This method is not meant for binary reduction operators such as +, *, min, max, |, and &, which is what Base.reduce normally expects. Note also that Python's min and max correspond to Julia's minimum and maximum, respectively.
Examples
julia> x = randn(64, 32, 35);
julia> y = reduce(sum, x, (:c, :b, :t) --> (:c, :b));
julia> size(y)
(64, 32)
julia> y == dropdims(sum(x, dims=3), dims=3)
true
julia> using Statistics: mean
julia> z = reduce(mean, x, (:c, :b, (:t5, :t)) --> ((:t5, :c), :b), t5=5);
julia> size(z)
(320, 32)
julia> z == reshape(permutedims(dropdims(mean(reshape(x, 64,32,5,7), dims=4), dims=4), (3,1,2)), 320,32)
true
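Any function that follows this keyword contract should be usable as f. The sketch below defines a hypothetical rms reducer in plain Julia and applies the manual equivalent of the first example's pattern. That rms can be passed to Einops' reduce is assumed from the contract above and is not exercised here.

```julia
# Hypothetical reducer (not part of Einops) that follows the contract above:
# it takes `dims` as a keyword and leaves the reduced axes as singletons,
# as Base.sum(f, x; dims) does.
rms(x; dims) = sqrt.(sum(abs2, x; dims) ./ prod(size(x, d) for d in dims))

x = randn(64, 32, 35)
r = rms(x; dims=(3,))
println(size(r))       # (64, 32, 1): the reduced axis is kept

# Manual equivalent of (:c, :b, :t) --> (:c, :b): reduce the axis that is
# missing on the right, then drop the resulting singleton.
y = dropdims(r; dims=3)
println(size(y))       # (64, 32)
```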
repeat
Base.repeat
— Function
repeat(x::AbstractArray, left --> right; context...)
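Repeat the elements of x along axes that appear only on the right side of the pattern. The sizes of these new axes are passed as keyword arguments in context, e.g. r=2.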
Examples
julia> x = rand(2,3);
julia> y = repeat(x, (:a, :b) --> (:a, :b, 1, :r), r=2);
julia> size(y)
(2, 3, 1, 2)
julia> y == reshape(repeat(x, 1,1,2), 2,3,1,2)
true
julia> z = repeat(x, (:a, :b) --> (:a, (:b, :r)), r=2);
julia> size(z)
(2, 6)
julia> z == reshape(repeat(x, 1,1,2), 2,6)
true
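The position of the new axis inside a group decides whether copies are tiled or interleaved. This Base-only sketch builds both layouts by hand. The second layout corresponds to (:a, :b) --> (:a, (:r, :b)), a pattern not among the examples above. Its correspondence to this layout is assumed from the column-major convention those examples follow, and Einops is not called here.

```julia
x = reshape(collect(1:6), 2, 3)    # [1 3 5; 2 4 6]

# (:a, :b) --> (:a, (:b, :r)) with r = 2: b varies fastest, so the whole
# matrix is tiled, as in the example above.
tiled = reshape(repeat(x, 1, 1, 2), 2, 6)

# (:a, :b) --> (:a, (:r, :b)) with r = 2: r varies fastest, so each column
# is repeated in place before moving on to the next one.
interleaved = reshape(repeat(reshape(x, 2, 1, 3), 1, 2, 1), 2, 6)

println(tiled)          # [1 3 5 1 3 5; 2 4 6 2 4 6]
println(interleaved)    # [1 1 3 3 5 5; 2 2 4 4 6 6]
```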
einsum
Einops.einsum
— Function
einsum(arrays..., (left --> right))
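Compute the Einstein summation specified by the pattern. The left side holds one index tuple per input array. Indices that do not appear on the right side are summed over.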
Examples
julia> x, y = rand(2,3), rand(3,4);
julia> einsum(x, y, ((:i, :j), (:j, :k)) --> (:i, :k)) == x * y
true
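Spelled out as loops, the example pattern ((:i, :j), (:j, :k)) --> (:i, :k) sums over j, the only index missing on the right. matmul_by_loops is a hypothetical name. This sketch illustrates the semantics only and says nothing about how Einops evaluates einsum.

```julia
# Explicit-loop meaning of ((:i, :j), (:j, :k)) --> (:i, :k).
function matmul_by_loops(x::AbstractMatrix, y::AbstractMatrix)
    I, J = size(x)
    J2, K = size(y)
    J == J2 || throw(DimensionMismatch("index j has sizes $J and $J2"))
    out = zeros(promote_type(eltype(x), eltype(y)), I, K)
    for k in 1:K, j in 1:J, i in 1:I      # i innermost: column-major order
        out[i, k] += x[i, j] * y[j, k]    # j is summed over
    end
    return out
end

x, y = rand(2, 3), rand(3, 4)
z = matmul_by_loops(x, y)
```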
pack and unpack
Einops.pack
— Function
pack(unpacked_arrays, pattern)
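Pack a vector of arrays into a single array according to the pattern. The * marks the axes that are flattened and concatenated across inputs, and the remaining axes must match. pack returns the packed array together with the shape that * covered in each input.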
Examples
julia> inputs = [rand(2,3,5), rand(2,3,7,5), rand(2,3,7,9,5)];
julia> packed_array, packed_shapes = pack(inputs, (:i, :j, *, :k));
julia> size(packed_array)
(2, 3, 71, 5)
julia> packed_shapes
3-element Vector{NTuple{N, Int64} where N}:
()
(7,)
(7, 9)
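For the fixed pattern (:i, :j, *, :k), packing amounts to flattening the * axes of each input and concatenating along that position. The sketch below uses only Base Julia. toy_pack_ij_star_k is a hypothetical name, and the sketch is not Einops' implementation, which handles arbitrary patterns.

```julia
# Hypothetical sketch of pack for the fixed pattern (:i, :j, *, :k).
function toy_pack_ij_star_k(arrays)
    # The axes between j and k are the ones * stands for; record them so
    # that the operation can be undone later.
    shapes = [size(a)[3:end-1] for a in arrays]
    # Flatten those axes into one (an empty group becomes a singleton) ...
    flat = [reshape(a, size(a, 1), size(a, 2), :, size(a, ndims(a))) for a in arrays]
    # ... and concatenate along the flattened * axis.
    return cat(flat...; dims=3), shapes
end

inputs = [rand(2, 3, 5), rand(2, 3, 7, 5), rand(2, 3, 7, 9, 5)]
packed, shapes = toy_pack_ij_star_k(inputs)
println(size(packed))   # (2, 3, 71, 5)
```

The 71 is 1 + 7 + 7 × 9: an input with no * axes still occupies one slot.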
Einops.unpack
— Function
unpack(packed_array, packed_shapes, pattern)
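Unpack a single array into a vector of arrays according to the pattern. This reverses pack: packed_shapes gives, for each output array, the shape that the * axis expands into.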
Examples
julia> inputs = [rand(2,3,5), rand(2,3,7,5), rand(2,3,7,9,5)];
julia> inputs == unpack(pack(inputs, (:i, :j, *, :k))..., (:i, :j, *, :k))
true
julia> packed_array = rand(2,3,16);
julia> packed_shapes = [(), (7,), (4, 2)];
julia> unpack(packed_array, packed_shapes, (:i, :j, *)) .|> size
3-element Vector{Tuple{Int64, Int64, Vararg{Int64}}}:
(2, 3)
(2, 3, 7)
(2, 3, 4, 2)
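Unpacking reverses the packing steps: slice the * axis into consecutive chunks whose lengths are the products of packed_shapes, then reshape each chunk. The sketch below covers the fixed pattern (:i, :j, *, :k) only. toy_unpack_ij_star_k is a hypothetical name, and this is not Einops' implementation.

```julia
# Hypothetical sketch of unpack for the fixed pattern (:i, :j, *, :k).
function toy_unpack_ij_star_k(packed::AbstractArray{<:Any,4}, shapes)
    # Slots each array occupies along the * axis (dimension 3); an empty
    # shape () occupies exactly one slot.
    lens = [reduce(*, s; init=1) for s in shapes]
    offsets = cumsum([0; lens])
    I, J, _, K = size(packed)
    return [reshape(packed[:, :, offsets[n]+1:offsets[n+1], :], I, J, shapes[n]..., K)
            for n in eachindex(shapes)]
end

# Round trip with the shapes from the pack example above.
inputs = [rand(2, 3, 5), rand(2, 3, 7, 5), rand(2, 3, 7, 9, 5)]
packed = cat([reshape(a, 2, 3, :, 5) for a in inputs]...; dims=3)
outputs = toy_unpack_ij_star_k(packed, [(), (7,), (7, 9)])
```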