I can use `tf.matmul(A, B)` for batch matrix multiplication when
`A.shape == (..., a, b)` and `B.shape == (..., b, c)`,
where the leading `...` dimensions are identical.
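For reference, here is a minimal sketch of that equal-leading-dimensions case. It uses NumPy's `np.matmul` (an assumption: NumPy stands in for TensorFlow so the check runs without a session) and compares the batched product against an explicit per-batch loop. The sizes `n = 7`, `a = 2`, `b = 3`, `c = 4` are arbitrary.

```python
import numpy as np

# Batch matmul with identical leading ("...") dimensions: (n, a, b) @ (n, b, c).
n, a, b, c = 7, 2, 3, 4
rng = np.random.default_rng(0)
A = rng.uniform(0, 1, (n, a, b))
B = rng.uniform(0, 1, (n, b, c))

batched = np.matmul(A, B)  # one (a, b) @ (b, c) product per batch index
looped = np.stack([A[k] @ B[k] for k in range(n)])  # the same products, written out

print(batched.shape)  # (7, 2, 4)
print(np.allclose(batched, looped))  # True
```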
But I also want broadcasting over the batch dimensions, for example:
`A.shape == (a, b, 2, d)` and `B.shape == (a, 1, d, c)` should give `result.shape == (a, b, 2, c)`.
I expect the result to be `a` × `b` independent matrix multiplications, each of a `(2, d)` matrix by a `(d, c)` matrix.
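To pin down the expected semantics, the sketch below builds the result with an explicit double loop, `result[i, j] = A[i, j] @ B[i, 0]`, and compares it with `np.matmul`, which broadcasts the size-1 axis of `B` under NumPy's broadcasting rules. NumPy is used only as a reference here; whether `tf.matmul` broadcasts the same way depends on the TensorFlow version, so check the documentation for the version in use.

```python
import numpy as np

# Shapes from the question: A is (a, b, 2, d), B is (a, 1, d, c).
a, b, c, d = 3, 4, 5, 6
rng = np.random.default_rng(0)
A = rng.uniform(0, 1, (a, b, 2, d))
B = rng.uniform(0, 1, (a, 1, d, c))

# Reference definition: result[i, j] = A[i, j] @ B[i, 0] for every (i, j).
expected = np.empty((a, b, 2, c))
for i in range(a):
    for j in range(b):
        expected[i, j] = A[i, j] @ B[i, 0]

# np.matmul broadcasts the size-1 batch axis of B against the b axis of A.
result = np.matmul(A, B)
print(result.shape)  # (3, 4, 2, 5)
print(np.allclose(result, expected))  # True
```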
How can I do this?
Test code:
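```python
import tensorflow as tf  # TF 1.x API (tf.Session)
import numpy as np

a = 3
b = 4
c = 5
d = 6

x_shape = (a, b, 2, d)
y_shape = (a, d, c)  # B without its size-1 axis
z_shape = (a, b, 2, c)

x = np.random.uniform(0, 1, x_shape)
y = np.random.uniform(0, 1, y_shape)
z = np.empty(z_shape)

with tf.Session() as sess:
    # Loop over the b axis: each step is an ordinary batch matmul
    # of (a, 2, d) by (a, d, c), giving (a, 2, c).
    for i in range(b):
        x_now = x[:, i, :, :]
        z[:, i, :, :] = sess.run(
            tf.matmul(x_now, y)
        )

print(z)
```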
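A loop-free way to express the same computation, sketched in NumPy under the assumption that a NumPy reference is acceptable for checking: `np.einsum` contracts over the `d` axis while keeping the `a`, `b`, row, and `c` axes, so `y` can stay at shape `(a, d, c)` without a size-1 axis. TensorFlow also provides `tf.einsum`, which takes an equation string in the same notation; that call is not exercised here.

```python
import numpy as np

a, b, c, d = 3, 4, 5, 6
rng = np.random.default_rng(0)
x = rng.uniform(0, 1, (a, b, 2, d))
y = rng.uniform(0, 1, (a, d, c))

# Loop over b, as in the test code (NumPy instead of a TF session).
z_loop = np.empty((a, b, 2, c))
for i in range(b):
    z_loop[:, i, :, :] = np.matmul(x[:, i, :, :], y)

# One contraction: sum over k (the d axis); keep a, b, i (the 2 rows), j (the c columns).
z = np.einsum('abik,akj->abij', x, y)
print(z.shape)  # (3, 4, 2, 5)
print(np.allclose(z, z_loop))  # True
```

Spelling out the subscripts makes the intended batch structure explicit, at the cost of being less familiar to read than a plain `matmul`.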