AlphaTensor — Brief Overview

Some Thoughts and Info on DeepMind’s AlphaTensor

Hassaan Naeem
Oct 19, 2022

Having just read the AlphaTensor paper and the accompanying blog post, I found both quite interesting and enjoyed reading them.

I’m unsure why some people are giving it flak for not being as groundbreaking as it was made out to be; the authors clearly state that their focus is on the algorithmic level of abstraction.

This post will serve as a breakdown (for myself and anyone interested), mostly of the tensor decomposition problem. I would suggest reading the paper itself for the more engineering-focused aspects.

Matrix Multiplication

Anyone reading this knows (or should know) that standard 2⨯2 matrix multiplication is as follows:

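2x2 Matrix Multiplication (from: CueMath):

$$\begin{pmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{pmatrix} \begin{pmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \end{pmatrix} = \begin{pmatrix} a_{11}b_{11} + a_{12}b_{21} & a_{11}b_{12} + a_{12}b_{22} \\ a_{21}b_{11} + a_{22}b_{21} & a_{21}b_{12} + a_{22}b_{22} \end{pmatrix}$$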

There are 8 multiplications and 4 additions.
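To make the count concrete, here is a minimal plain-Python sketch (not from the original post) that multiplies two 2⨯2 matrices the textbook way and tallies the scalar operations as it goes. The function name naive_2x2 is made up for illustration.

```python
def naive_2x2(A, B):
    """Textbook 2x2 product C = AB, counting scalar multiplications and additions."""
    mults = adds = 0
    C = [[0, 0], [0, 0]]
    for i in range(2):
        for j in range(2):
            # c_ij = a_i1 * b_1j + a_i2 * b_2j: two products, one addition
            p0 = A[i][0] * B[0][j]
            p1 = A[i][1] * B[1][j]
            C[i][j] = p0 + p1
            mults += 2
            adds += 1
    return C, mults, adds


C, mults, adds = naive_2x2([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(C, mults, adds)  # [[19, 22], [43, 50]] 8 4
```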

A Note on Tensor Decomposition

The AlphaTensor approach is to turn tensor decomposition into a game (TensorGame). For anyone who cares, Prof. Oeding has a cool lecture explaining tensor decomposition via the Simons Institute.

Tensor (rank) decomposition is in and of itself an NP-hard problem. The opposite direction, namely using the rank-one factors of a tensor to recreate it, is fairly straightforward; the tricky bit is minimizing the number of rank-one factors. If one can reduce the number of these rank-one factors, then one can reduce the number of scalar multiplications required to multiply two matrices. This is what AlphaTensor tackles. In doing so, AlphaTensor finds a fairly large set of decompositions, each of which is a matrix multiplication algorithm, and for certain matrix sizes beats the Strassen algorithm in the number of multiplications needed (see the Other section).

In the paper, the authors use a matrix multiplication tensor (MMT), 𝒯ₙ, of size n² ⨯ n² ⨯ n², representing the multiplication of two 𝑛 ⨯ 𝑛 matrices. Its decomposition into R rank-one terms is written as follows:

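Decomposition of MMT (Fawzi et al., 2022):

$$\mathcal{T}_n = \sum_{r=1}^{R} \mathbf{u}^{(r)} \otimes \mathbf{v}^{(r)} \otimes \mathbf{w}^{(r)}$$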

Here each of the 𝐮⁽ⁿ⁾, 𝐯⁽ⁿ⁾, and 𝐰⁽ⁿ⁾ (n is just replacing r due to Medium’s lack of support for math) is a vector, i.e. a first-order tensor, and their tensor product 𝐮⁽ⁿ⁾ ⊗ 𝐯⁽ⁿ⁾ ⊗ 𝐰⁽ⁿ⁾ is a rank-one tensor. The ⊗ is the tensor product, which for vectors is the outer product, carried out as shown below:

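Outer Product of Vectors (from: Wikipedia):

$$\mathbf{u} \otimes \mathbf{v} = \mathbf{u}\mathbf{v}^{\mathsf{T}} = \begin{pmatrix} u_1 v_1 & \cdots & u_1 v_n \\ \vdots & \ddots & \vdots \\ u_m v_1 & \cdots & u_m v_n \end{pmatrix}$$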
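A minimal plain-Python sketch of the outer product, plus the three-way version used for the rank-one terms 𝐮 ⊗ 𝐯 ⊗ 𝐰. The helper names outer and outer3 are hypothetical; in PyTorch, torch.outer covers the two-vector case.

```python
def outer(u, v):
    """Outer product of two vectors: (u ⊗ v)[i][j] = u[i] * v[j], i.e. the matrix u v^T."""
    return [[ui * vj for vj in v] for ui in u]


def outer3(u, v, w):
    """Three-way outer product: (u ⊗ v ⊗ w)[i][j][k] = u[i] * v[j] * w[k]."""
    return [[[ui * vj * wk for wk in w] for vj in v] for ui in u]


print(outer([1, 2, 3], [4, 5]))  # [[4, 5], [8, 10], [12, 15]]
```

Every entry of the result is a single product, which is why a sum of R such terms is cheap to rebuild once the factors are known.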

Now it is fairly straightforward to imagine breaking up a third-order tensor into a sum of outer products of vectors (rank-one terms), as follows:

Kolda et al., 2008

As an example in PyTorch, see the code below. It is just a reproduction of Figure 1 in the paper, which shows the Strassen algorithm in rank-one tensor form. Each entry of the tensor is either 0 or 1, and the (i, j, k) entry is 1 exactly when the i-th entry of A times the j-th entry of B contributes to the k-th entry of C (all three matrices flattened row by row). For n = 2 we have 𝒯₂ = [[[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.], [1., 0., 0., 0.], [0., 1., 0., 0.]], [[0., 0., 1., 0.], [0., 0., 0., 1.], [0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]]], and the rank-one factors (vectors) are stored below in U, V, and W. If we take the outer product of each matching triple of factors from U, V, and W and sum over all of them, we retrieve the original 𝒯₂.

MMT from rank-one factors
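As a dependency-free companion to the PyTorch example, the following plain-Python sketch (an illustration, not the author's code) builds 𝒯₂ directly from its definition and then rebuilds it from the standard Strassen factors. The layout (row r of U, V, W is the r-th rank-one term, entries ordered (1,1), (1,2), (2,1), (2,2)) and the helper names mmt and from_factors are assumptions made for this sketch.

```python
from itertools import product

# Standard Strassen factors (illustrative layout): row r of U, V, W is the r-th
# rank-one term; entries are indexed (1,1), (1,2), (2,1), (2,2), i.e. row-major.
U = [[1, 0, 0, 1], [0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 0, 1],
     [1, 1, 0, 0], [-1, 0, 1, 0], [0, 1, 0, -1]]
V = [[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, -1], [-1, 0, 1, 0],
     [0, 0, 0, 1], [1, 1, 0, 0], [0, 0, 1, 1]]
W = [[1, 0, 0, 1], [0, 0, 1, -1], [0, 1, 0, 1], [1, 0, 1, 0],
     [-1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]]


def mmt(n):
    """T[a][b][c] = 1 iff entry a of A times entry b of B contributes to entry c of C."""
    N = n * n
    T = [[[0] * N for _ in range(N)] for _ in range(N)]
    for i, j, k in product(range(n), repeat=3):
        # C[i][k] += A[i][j] * B[j][k], with all matrices flattened row by row
        T[i * n + j][j * n + k][i * n + k] = 1
    return T


def from_factors(U, V, W):
    """Rebuild a tensor as the sum over r of U[r] ⊗ V[r] ⊗ W[r]."""
    N = len(U[0])
    T = [[[0] * N for _ in range(N)] for _ in range(N)]
    for u, v, w in zip(U, V, W):
        for a, b, c in product(range(N), repeat=3):
            T[a][b][c] += u[a] * v[b] * w[c]
    return T


print(from_factors(U, V, W) == mmt(2))  # True: 7 rank-one terms rebuild T_2
```

The tensor produced by mmt(2) is the same 0/1 tensor written out in the paragraph above, so seven rank-one terms are enough to represent what the textbook algorithm does with eight.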

Multiplying Matrices Via the MMT

With that out of the way, we can move on to matrix multiplication, which is fairly straightforward. Given two matrices A and B, their product C = AB can be computed from a decomposition of 𝒯ₙ as follows:

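Calculating C = AB (Fawzi et al., 2022):

$$m_r = \Big(\sum_{i=1}^{n^2} u^{(r)}_i a_i\Big)\Big(\sum_{j=1}^{n^2} v^{(r)}_j b_j\Big) \quad (r = 1, \dots, R), \qquad c_k = \sum_{r=1}^{R} w^{(r)}_k m_r \quad (k = 1, \dots, n^2)$$

where a, b, c hold the entries of A, B, C flattened into vectors.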

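As one can see, the number of multiplications between entries of A and B is now exactly R. Everything in between can be thought of as a map that gives each additive term a coefficient (one of {-1, 0, 1} for Strassen); multiplying by such a coefficient is just a sign change, so it does not count as a multiplication.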

Below is some PyTorch code showing it in action, again with the same Strassen factors. The resulting matrix C is what it should be.

Calculating C = AB in PyTorch
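A plain-Python sketch of the same computation (an illustration under assumptions, not the author's PyTorch code): it uses the standard Strassen factors, stored one rank-one term per row with entries ordered (1,1), (1,2), (2,1), (2,2), and the function name multiply_via_factors is hypothetical.

```python
# Standard Strassen factors: row r holds the r-th rank-one term,
# entries ordered (1,1), (1,2), (2,1), (2,2).
U = [[1, 0, 0, 1], [0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 0, 1],
     [1, 1, 0, 0], [-1, 0, 1, 0], [0, 1, 0, -1]]
V = [[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, -1], [-1, 0, 1, 0],
     [0, 0, 0, 1], [1, 1, 0, 0], [0, 0, 1, 1]]
W = [[1, 0, 0, 1], [0, 0, 1, -1], [0, 1, 0, 1], [1, 0, 1, 0],
     [-1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]]


def multiply_via_factors(A, B, U, V, W):
    """C = AB from a rank-R decomposition: R products, the rest is signed sums."""
    n = len(A)
    a = [x for row in A for x in row]  # flatten row by row
    b = [x for row in B for x in row]
    # m_r = (sum_i u_i a_i) * (sum_j v_j b_j): the only products between A and B entries
    m = [sum(ui * ai for ui, ai in zip(u, a)) * sum(vj * bj for vj, bj in zip(v, b))
         for u, v in zip(U, V)]
    # c_k = sum_r w_k^(r) m_r
    c = [sum(w[k] * m_r for w, m_r in zip(W, m)) for k in range(n * n)]
    return [c[i * n:(i + 1) * n] for i in range(n)]


print(multiply_via_factors([[1, 2], [3, 4]], [[5, 6], [7, 8]], U, V, W))
# [[19, 22], [43, 50]]
```

The result matches the textbook product of the same two matrices, using 7 products m_r instead of 8.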

Mathematical Aside (Can Skip)

The reason this works in the first place is, firstly, that a tensor is a multilinear map, meaning it is linear in each of its slots. Secondly, a matrix represents a linear map, and multiplying two matrices, (A, B) ↦ AB, is linear in each argument separately, i.e. bilinear, so it is itself a multilinear map that a tensor can represent.

If one has a vector space V of dimension m and its dual V*, we can take a covector 𝜔 ∈ V* and a vector v ∈ V and, after choosing a basis (eᵢ) of V with dual basis (εⁱ), write them as 𝜔 = 𝜔ᵢεⁱ and v = vⁱeᵢ (Einstein summation in use). We then have 𝜔 ≘ (𝜔₁, … , 𝜔ₘ) and v ≘ (v¹, … , vᵐ)ᵀ, corresponding to row and column vectors, respectively.

We can then take ϕ ∈ End(V), an endomorphism of V. As a vector space, End(V) is isomorphic (V being finite-dimensional) to T¹₁V, the space of (1,1)-tensors over V, i.e. multilinear maps V* ⨯ V → K, where K is the field over which V is taken. Then ϕ = ϕˢₜ eₛ ⊗ εᵗ, and ϕ ≘ ((ϕ¹₁, …, ϕ¹ₘ), (ϕ²₁, …, ϕ²ₘ), … , (ϕᵐ₁, …, ϕᵐₘ)); voilà, a matrix!

If we have two endomorphisms ϕ, ψ ∈ End(V), then their composition ϕ ∘ ψ is also in End(V). Its components are (ϕ ∘ ψ)ˢₜ = … = ψʳₜ · ϕˢᵣ = ϕˢᵣ · ψʳₜ (where · is multiplication in K and the repeated index r is summed over), which is exactly the (s, t) entry of the matrix product. Hence, matrix multiplication is just the rule for finding the matrix that represents the composition of two linear maps.
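A small numerical check of the aside, as a sketch: ϕ and ψ are hypothetical example maps on a 2-dimensional V, written as matrices in a fixed basis, and the helper names are made up. It shows a (1,1)-tensor being fed a covector and a vector, and that applying ψ then ϕ agrees with applying the single matrix given by the index rule (ϕψ)ˢₜ = ϕˢᵣ ψʳₜ.

```python
def matvec(M, x):
    """(Mx)^s = M^s_t x^t, summing over t."""
    return [sum(M[s][t] * x[t] for t in range(len(x))) for s in range(len(M))]


def matmul(P, Q):
    """(PQ)^s_t = P^s_r Q^r_t, summing over the repeated index r."""
    return [[sum(P[s][r] * Q[r][t] for r in range(len(Q)))
             for t in range(len(Q[0]))] for s in range(len(P))]


def evaluate(M, omega, x):
    """A (1,1)-tensor fed a covector and a vector: omega_s M^s_t x^t."""
    return sum(omega[s] * M[s][t] * x[t]
               for s in range(len(omega)) for t in range(len(x)))


phi = [[1, 2], [0, 1]]   # hypothetical example maps, in a fixed basis
psi = [[3, 0], [1, -1]]
v = [2, 5]

# Applying psi, then phi, agrees with applying the single matrix phi * psi
print(matvec(phi, matvec(psi, v)) == matvec(matmul(phi, psi), v))  # True
print(evaluate(phi, [1, 0], v))  # 12
```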

The construction of the tensor is basis-free. As the saying goes, a gentleman only chooses a basis if he must. AlphaTensor makes use of this, and the impact on performance can be seen in the Ablations section of the Supplementary Information.

Reinforcement Learning

I won’t go into too much detail for the rest, as it is standard RL with network and model tweaks for the task at hand. Just take a peek at Figure 2 in the paper.

Task

This then becomes a standard RL problem, where the environment is a single-player game. The game state Sₜ at time t is a tensor, initially set to 𝒯ₙ and updated at each step by subtracting the rank-one tensor 𝐮⁽ⁿ⁾ ⊗ 𝐯⁽ⁿ⁾ ⊗ 𝐰⁽ⁿ⁾ chosen by the agent. The agent’s goal is to reach the zero tensor in as few steps as possible, since the number of steps taken is the number of multiplications R in the resulting algorithm.
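A hypothetical minimal sketch of the game mechanics in plain Python. The class and method names are made up, and rewards and any step limit are left out; it only shows the state update and the end condition. Playing the eight textbook rank-one terms (one per product A[i][j] · B[j][k]) empties the state in 8 moves, whereas Strassen's factors would do it in 7.

```python
from itertools import product


def mmt(n):
    """T[a][b][c] = 1 iff entry a of A times entry b of B contributes to entry c of C."""
    N = n * n
    T = [[[0] * N for _ in range(N)] for _ in range(N)]
    for i, j, k in product(range(n), repeat=3):
        T[i * n + j][j * n + k][i * n + k] = 1
    return T


class TensorGame:
    """Hypothetical minimal single-player environment: subtract rank-one tensors until zero."""

    def __init__(self, target):
        self.state = [[row[:] for row in sl] for sl in target]  # S_0 = target (copied)
        self.moves = 0

    def step(self, u, v, w):
        """S_{t+1} = S_t - u ⊗ v ⊗ w; returns True once the zero tensor is reached."""
        N = len(u)
        for a, b, c in product(range(N), repeat=3):
            self.state[a][b][c] -= u[a] * v[b] * w[c]
        self.moves += 1
        return self.is_zero()

    def is_zero(self):
        return all(x == 0 for sl in self.state for row in sl for x in row)


def unit(idx, N):
    """Standard basis vector of length N with a 1 at position idx."""
    return [int(t == idx) for t in range(N)]


# Play the textbook algorithm: one move per product A[i][j] * B[j][k]
n = 2
game = TensorGame(mmt(n))
for i, j, k in product(range(n), repeat=3):
    game.step(unit(i * n + j, n * n), unit(j * n + k, n * n), unit(i * n + k, n * n))
print(game.is_zero(), game.moves)  # True 8
```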

Agent And Net

The agent, AlphaTensor, is a take on AlphaZero. The neural network takes the state Sₜ as input and outputs a policy and a value. The policy gives a distribution over possible actions, and the value estimates the cumulative reward expected from the current state. Monte Carlo tree search (MCTS), guided by the network, is then used to choose the action to take; since the action space is gigantic, the search works with actions sampled from the policy rather than enumerating all of them.
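To get a feel for how gigantic the action space is, here is a back-of-the-envelope count, assuming (as an illustration) that every entry of 𝐮, 𝐯, 𝐰 is drawn from a small coefficient set such as {-2, -1, 0, 1, 2}. One action is a triple of vectors of length n², so there are |set|^(3n²) triples. This overcounts distinct moves, because different triples, e.g. (-𝐮, -𝐯, 𝐰) and (𝐮, 𝐯, 𝐰), give the same rank-one tensor.

```python
def naive_action_space(n, coeffs=(-2, -1, 0, 1, 2)):
    """Number of (u, v, w) triples for n x n matrices: 3 * n^2 entries, each from coeffs.

    An upper bound on the number of distinct moves, since some triples
    produce the same rank-one tensor.
    """
    return len(coeffs) ** (3 * n * n)


for n in (2, 3, 4):
    print(n, naive_action_space(n))
```

Already for n = 2 this is 5¹² = 244,140,625 triples, which is why exhaustive search over actions is out of the question.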

The network is transformer-based. Training mixes two data sources: games previously played by the agent, and what the authors call synthetic demonstrations.
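The idea behind a synthetic demonstration can be sketched as running the decomposition problem backwards: sample some rank-one factors first, then sum them into a tensor, so the answer is known by construction. The sketch below is a hypothetical plain-Python illustration; the function name, the uniform sampling, and the coefficient set are assumptions, not the paper's exact procedure.

```python
import random
from itertools import product


def synthetic_demonstration(N, max_rank, coeffs=(-2, -1, 0, 1, 2), rng=random):
    """Sample R random rank-one terms, then build the tensor they sum to.

    Returns a tensor of rank at most R together with a decomposition achieving it,
    usable as a supervised training target. The sampling choices here are assumptions.
    """
    R = rng.randint(1, max_rank)
    factors = []
    for _ in range(R):
        u, v, w = ([rng.choice(coeffs) for _ in range(N)] for _ in range(3))
        factors.append((u, v, w))
    T = [[[0] * N for _ in range(N)] for _ in range(N)]
    for u, v, w in factors:
        for a, b, c in product(range(N), repeat=3):
            T[a][b][c] += u[a] * v[b] * w[c]
    return T, factors


T, factors = synthetic_demonstration(4, max_rank=5, rng=random.Random(0))
print(len(factors))
```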

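The 𝒯ₙ discussed earlier is matrix multiplication in the standard basis. As briefly discussed in the Aside section, a tensor is an abstract object and needs no basis; the multilinear map can be expressed in whatever basis one chooses. The authors make use of this by sampling a random change of basis before playing each game, spicing things up a bit: the agent sees a more varied set of tensors, and since an invertible change of basis does not change the rank, any decomposition it finds can be mapped back to one of 𝒯ₙ with the same number of factors.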
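A sketch of why a change of basis is harmless, in plain Python: applying invertible matrices P, Q, S along the three modes of 𝒯₂ maps every rank-one term 𝐮 ⊗ 𝐯 ⊗ 𝐰 to (P𝐮) ⊗ (Q𝐯) ⊗ (S𝐰), so the transformed tensor has a decomposition with exactly as many terms. P, Q, S below are hypothetical unit upper-triangular matrices (determinant 1), not the paper's sampling scheme, and the textbook 8-term decomposition is used for simplicity.

```python
from itertools import product


def mmt(n):
    """T[a][b][c] = 1 iff entry a of A times entry b of B contributes to entry c of C."""
    N = n * n
    T = [[[0] * N for _ in range(N)] for _ in range(N)]
    for i, j, k in product(range(n), repeat=3):
        T[i * n + j][j * n + k][i * n + k] = 1
    return T


def textbook_factors(n):
    """The n**3 rank-one terms of the standard algorithm, one per product A[i][j] * B[j][k]."""
    N = n * n

    def e(idx):  # standard basis vector of length N
        return [int(t == idx) for t in range(N)]

    return [(e(i * n + j), e(j * n + k), e(i * n + k)) for i, j, k in product(range(n), repeat=3)]


def from_factors(factors):
    """Sum of u ⊗ v ⊗ w over all (u, v, w) triples."""
    N = len(factors[0][0])
    T = [[[0] * N for _ in range(N)] for _ in range(N)]
    for u, v, w in factors:
        for a, b, c in product(range(N), repeat=3):
            T[a][b][c] += u[a] * v[b] * w[c]
    return T


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def change_basis(T, P, Q, S):
    """T'[i][j][k] = sum over a, b, c of P[i][a] * Q[j][b] * S[k][c] * T[a][b][c]."""
    N = len(T)
    return [[[sum(P[i][a] * Q[j][b] * S[k][c] * T[a][b][c]
                  for a, b, c in product(range(N), repeat=3))
              for k in range(N)] for j in range(N)] for i in range(N)]


# Hypothetical invertible matrices (unit upper triangular, so determinant 1)
P = [[1, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, -1], [0, 0, 0, 1]]
Q = [[1, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
S = [[1, 0, 0, 0], [0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]]

factors = textbook_factors(2)
T_new = change_basis(mmt(2), P, Q, S)
# Mapping each factor across gives a decomposition of T_new with the same number of terms
moved = [(matvec(P, u), matvec(Q, v), matvec(S, w)) for u, v, w in factors]
print(T_new == from_factors(moved), len(moved))  # True 8
```

Going the other way, multiplying the factors by the inverses of P, Q, S turns any decomposition of the transformed tensor back into one of 𝒯₂.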

Other

The Strassen algorithm uses only 7 multiplications and 18 additions for a 2⨯2 matrix multiplication, and AlphaTensor also finds 7-multiplication algorithms for this case. This is an improvement over the standard 8 multiplications and 4 additions if multiplication is costlier than addition. I would also presume that precision in floating-point (or rather non-integer) arithmetic becomes a big issue in practice with all the additional additions. Nonetheless, in the related work section of the paper the authors stress that the work’s focus is on the algorithmic level of abstraction.
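Both operation counts can be read straight off the factors. The sketch below assumes a naive evaluation: no partial sums are shared between products, subtractions count as additions, and scaling by a coefficient in {-1, 0, 1} is a free sign flip. Under those assumptions, combining k nonzero terms costs k - 1 additions. The factor layout and function names are illustrative.

```python
# Standard Strassen factors (illustrative layout): row r is the r-th rank-one term,
# entries ordered (1,1), (1,2), (2,1), (2,2).
U = [[1, 0, 0, 1], [0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 0, 1],
     [1, 1, 0, 0], [-1, 0, 1, 0], [0, 1, 0, -1]]
V = [[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, -1], [-1, 0, 1, 0],
     [0, 0, 0, 1], [1, 1, 0, 0], [0, 0, 1, 1]]
W = [[1, 0, 0, 1], [0, 0, 1, -1], [0, 1, 0, 1], [1, 0, 1, 0],
     [-1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]]


def nnz(vec):
    """Number of nonzero coefficients."""
    return sum(1 for x in vec if x != 0)


def op_counts(U, V, W):
    """Naive (multiplications, additions) for evaluating C = AB from the factors."""
    mults = len(U)  # one product m_r per rank-one term
    # building each side of m_r: k nonzero terms need k - 1 additions
    adds = sum(max(nnz(u) - 1, 0) + max(nnz(v) - 1, 0) for u, v in zip(U, V))
    # each c_k sums the products m_r with nonzero coefficient w_k^(r)
    adds += sum(max(nnz([w[k] for w in W]) - 1, 0) for k in range(len(W[0])))
    return mults, adds


print(op_counts(U, V, W))  # (7, 18)

# The textbook algorithm as 8 rank-one terms of standard basis vectors
e = [[int(t == idx) for t in range(4)] for idx in range(4)]
textbook = [(e[2 * i + j], e[2 * j + k], e[2 * i + k])
            for i in range(2) for j in range(2) for k in range(2)]
print(op_counts(*zip(*textbook)))  # (8, 4)
```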

Asymptotically, there is also an O(n²·³⁷) algorithm, but it is of no practical use because of the behemoth constants hidden in the big-O notation.

AlphaTensor improved on the best known multiplication counts for certain n⨯n and n⨯m matrix multiplications, the headliner being 4⨯4 matrix multiplication in ℤ/2ℤ (the integers modulo 2, i.e. {0, 1}, the finite field of characteristic 2), where it needs 47 multiplications versus the 49 of Strassen’s algorithm applied twice. The 10–20% speed-up they mention is a separate result: it comes from algorithms tuned for specific hardware, and is measured against regular matrix multiplication, not nested Strassen. These hardware-tuned AlphaTensor algorithms show great improvements on the TPU, as shown in Figure 5 of the paper.

The discovered algorithms, like Strassen’s before them, can be applied recursively to multiply matrices of arbitrary size, and here again the number of rank-one factors dictates the complexity.
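A plain-Python sketch of the recursion, assuming square matrices whose size is a power of two: the 2⨯2 Strassen factors are applied to 2⨯2 blocks instead of scalars, and each block product is computed the same way. With R products per level on k⨯k blocks, an n⨯n multiplication costs R^(log_k n) scalar multiplications, i.e. O(n^(log_k R)). For Strassen that is 7 · 7 = 49 for 4⨯4 and an exponent of log₂ 7 ≈ 2.807, while AlphaTensor’s 47-multiplication 4⨯4 scheme (valid in ℤ/2ℤ) would give log₄ 47 ≈ 2.78. The function names and factor layout are illustrative.

```python
import math

# Illustrative Strassen factors: row r is the r-th rank-one term; entries (and blocks)
# ordered 11, 12, 21, 22; W[r][k] is the coefficient of product r in block k of C.
U = [[1, 0, 0, 1], [0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 0, 1],
     [1, 1, 0, 0], [-1, 0, 1, 0], [0, 1, 0, -1]]
V = [[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, -1], [-1, 0, 1, 0],
     [0, 0, 0, 1], [1, 1, 0, 0], [0, 0, 1, 1]]
W = [[1, 0, 0, 1], [0, 0, 1, -1], [0, 1, 0, 1], [1, 0, 1, 0],
     [-1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]]


def combine(blocks, coeffs):
    """Signed sum of equally sized square blocks: sum_i coeffs[i] * blocks[i]."""
    size = len(blocks[0])
    out = [[0] * size for _ in range(size)]
    for coef, blk in zip(coeffs, blocks):
        if coef:
            for i in range(size):
                for j in range(size):
                    out[i][j] += coef * blk[i][j]
    return out


def split(M):
    """Split a 2h x 2h matrix into its four h x h blocks, ordered 11, 12, 21, 22."""
    h = len(M) // 2
    return [[row[:h] for row in M[:h]], [row[h:] for row in M[:h]],
            [row[:h] for row in M[h:]], [row[h:] for row in M[h:]]]


def join(c11, c12, c21, c22):
    """Inverse of split: glue four blocks back into one matrix."""
    return [r1 + r2 for r1, r2 in zip(c11, c12)] + [r1 + r2 for r1, r2 in zip(c21, c22)]


def recursive_multiply(A, B, counter):
    """Multiply 2^L x 2^L matrices by applying the 2x2 factors to blocks recursively."""
    if len(A) == 1:
        counter[0] += 1  # the only products between (combinations of) A and B entries
        return [[A[0][0] * B[0][0]]]
    a, b = split(A), split(B)
    m = [recursive_multiply(combine(a, u), combine(b, v), counter) for u, v in zip(U, V)]
    return join(*[combine(m, [w[k] for w in W]) for k in range(4)])


def naive_product(A, B):
    n = len(A)
    return [[sum(A[i][t] * B[t][j] for t in range(n)) for j in range(n)] for i in range(n)]


A = [[i * 4 + j for j in range(4)] for i in range(4)]
B = [[(i + 2 * j) % 5 - 2 for j in range(4)] for i in range(4)]
counter = [0]
C = recursive_multiply(A, B, counter)
print(C == naive_product(A, B), counter[0])  # True 49
print(math.log2(7))  # exponent of n for recursive Strassen, roughly 2.807
```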

That’s all. If you spot any errors, point them out.
