I need to perform matrix multiplication on two 4D arrays, m and n, with shapes 2x2x2x2 and 2x3x2x2 respectively, which should result in a 2x3x2x2 array. After a lot of research (mostly on this site), it seems this can be done efficiently with either np.einsum or np.tensordot, but I am unable to replicate the answer I get from Matlab (verified by hand). I understand how einsum and tensordot work for matrix multiplication on 2D arrays (clearly explained here), but I cannot get the axis indices right for the 4D arrays. Clearly I’m missing something! My actual problem deals with two 23x23x3x3 arrays of complex numbers, but my test arrays are:
import numpy as np
a = np.array([[1, 7], [4, 3]])
b = np.array([[2, 9], [4, 5]]) 
c = np.array([[3, 6], [1, 0]]) 
d = np.array([[2, 8], [1, 2]]) 
e = np.array([[0, 0], [1, 2]])
f = np.array([[2, 8], [1, 0]])
m = np.array([[a, b], [c, d]])              # (2,2,2,2)
n = np.array([[e, f, a], [b, d, c]])        # (2,3,2,2)
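To pin down the operation itself, here is a brute-force sketch. It assumes the intended product is a block matrix product, treating m as a 2x2 grid and n as a 2x3 grid of 2x2 matrices, so that result[i, j] is the sum over k of m[i, k] @ n[k, j]. That reading is not stated explicitly in the question, but it is consistent with the expected values. The helper name block_matmul_loops is purely illustrative.

```python
import numpy as np

a = np.array([[1, 7], [4, 3]])
b = np.array([[2, 9], [4, 5]])
c = np.array([[3, 6], [1, 0]])
d = np.array([[2, 8], [1, 2]])
e = np.array([[0, 0], [1, 2]])
f = np.array([[2, 8], [1, 0]])
m = np.array([[a, b], [c, d]])        # shape (2, 2, 2, 2)
n = np.array([[e, f, a], [b, d, c]])  # shape (2, 3, 2, 2)


def block_matmul_loops(m, n):
    """Hypothetical reference helper: block matrix product via explicit loops."""
    rows, inner = m.shape[:2]
    cols = n.shape[1]
    out = np.zeros((rows, cols, m.shape[2], n.shape[3]),
                   dtype=np.result_type(m, n))
    for i in range(rows):
        for j in range(cols):
            for k in range(inner):
                # Ordinary 2D matrix product of the (i, k) and (k, j) blocks.
                out[i, j] += m[i, k] @ n[k, j]
    return out


ref = block_matmul_loops(m, n)
print(ref.shape)  # (2, 3, 2, 2)
print(ref[0, 0])  # [[47 77]
                  #  [31 67]]
```

The loops are slow for large arrays, but they give a reference that any einsum or tensordot version can be compared against with np.array_equal.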
I realise the complex numbers may present further issues, but for now I am just trying to understand how the indexing works with einsum and tensordot. The answer I’m chasing is this 2x3x2x2 array:
+----+-----------+-----------+-----------+
|    | 0         | 1         | 2         |
+====+===========+===========+===========+
|  0 | [[47 77]  | [[22 42]  | [[44 40]  |
|    |  [31 67]] |  [24 74]] |  [33 61]] |
+----+-----------+-----------+-----------+
|  1 | [[42 70]  | [[24 56]  | [[41 51]  |
|    |  [10 19]] |  [ 6 20]] |  [ 6 13]] |
+----+-----------+-----------+-----------+
My closest attempt uses np.tensordot:
mn = np.tensordot(m,n, axes=([1,3],[0,2]))
This gives me a 2x2x3x2 array with the correct numbers, but not in the right order:
+----+-----------+-----------+
|    | 0         | 1         |
+====+===========+===========+
|  0 | [[47 77]  | [[31 67]  |
|    |  [22 42]  |  [24 74]  |
|    |  [44 40]] |  [33 61]] |
+----+-----------+-----------+
|  1 | [[42 70]  | [[10 19]  |
|    |  [24 56]  |  [ 6 20]  |
|    |  [41 51]] |  [ 6 13]] |
+----+-----------+-----------+
I’ve also tried to implement some of the solutions from here but have not had any luck.
Any ideas on how I might improve this would be greatly appreciated, thanks.
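One possible way to get the desired axis order is sketched below. It assumes the target is the block matrix product result[i, j] = sum over k of m[i, k] @ n[k, j]. np.tensordot places the uncontracted axes of its first argument before those of its second, so the attempt above comes out ordered as (i, a, j, c), and swapping the two middle axes should give (i, j, a, c). The same contraction can be written directly with np.einsum. The index letters and variable names are illustrative choices, and random complex arrays stand in for the real 23x23x3x3 data.

```python
import numpy as np

a = np.array([[1, 7], [4, 3]])
b = np.array([[2, 9], [4, 5]])
c = np.array([[3, 6], [1, 0]])
d = np.array([[2, 8], [1, 2]])
e = np.array([[0, 0], [1, 2]])
f = np.array([[2, 8], [1, 0]])
m = np.array([[a, b], [c, d]])        # shape (2, 2, 2, 2), axes (i, k, a, b)
n = np.array([[e, f, a], [b, d, c]])  # shape (2, 3, 2, 2), axes (k, j, b, c)

# tensordot sums over k (axis 1 of m, axis 0 of n) and b (axis 3 of m, axis 2 of n).
# The output axes are (i, a, j, c), so swap the middle two to get (i, j, a, c).
mn_td = np.tensordot(m, n, axes=([1, 3], [0, 2])).transpose(0, 2, 1, 3)

# einsum states the same contraction and the output order in one expression.
mn_es = np.einsum('ikab,kjbc->ijac', m, n)

print(mn_es.shape)                   # (2, 3, 2, 2)
print(np.array_equal(mn_td, mn_es))  # True
print(mn_es[0, 0])                   # [[47 77]
                                     #  [31 67]]

# Stand-in for the real problem: random complex arrays of shape 23x23x3x3.
rng = np.random.default_rng(0)
shape = (23, 23, 3, 3)
p = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
q = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
pq = np.einsum('ikab,kjbc->ijac', p, q)
print(pq.shape, pq.dtype)            # (23, 23, 3, 3) complex128
```

Neither einsum nor tensordot conjugates its inputs, so complex entries are multiplied as they are. If a conjugate is needed anywhere, it has to be applied explicitly, for example with np.conj.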