During pre-processing, I want to apply a function chosen at random from a set of functions to each dataset example, using the tf.data.Dataset and tf.function APIs.
Specifically, my data are 3D volumes and I wish to apply one rotation from a set of 24 predefined rotation functions. I would like to write this code within a tf.function, which limits the use of packages like numpy and of Python list indexing.
For example, I would like to do something like this:
import tensorflow as tf

@tf.function
def func1(tensor):
    # Apply some rotation here
    ...

@tf.function
def func2(tensor):
    ...

...

@tf.function
def func24(tensor):
    ...

@tf.function
def apply(tensor):
    list_of_funcs = [func1, func2, ..., func24]
    # Randomly sample an index from 0-23 (maxval is exclusive for integer dtypes)
    a = tf.random.uniform([1], minval=0, maxval=24, dtype=tf.int32)
    # This line fails: a Python list cannot be indexed with a Tensor
    return list_of_funcs[a](tensor)
However, I cannot index list_of_funcs with a tensor; doing so raises TypeError: list indices must be integers or slices, not Tensor. Additionally, as far as I know, I cannot collect these functions into a tf.Tensor and use tf.gather.
So my question: how can I reasonably and neatly sample from these functions in a tf.function?
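For reference, one direction that seems to fit is tf.switch_case, which picks one of several Python callables based on an integer tensor. Below is a minimal sketch of that idea, not a definitive implementation. The names rot_a, rot_b, rot_c, ROTATIONS and apply_random_rotation are hypothetical stand-ins (only three rotations are shown instead of 24), and the sketch assumes cubic volumes so that every branch returns the same shape.

```python
import tensorflow as tf

# Hypothetical stand-ins for the 24 rotation functions (assumption: cubic volumes).
def rot_a(volume):
    # Identity: leave the volume unchanged.
    return volume

def rot_b(volume):
    # 90-degree rotation in the plane of axes 0 and 1 (transpose, then flip).
    return tf.reverse(tf.transpose(volume, perm=[1, 0, 2]), axis=[0])

def rot_c(volume):
    # 180-degree rotation in the plane of axes 0 and 1 (flip both axes).
    return tf.reverse(volume, axis=[0, 1])

ROTATIONS = [rot_a, rot_b, rot_c]

@tf.function
def apply_random_rotation(volume):
    # Scalar int32 index in [0, len(ROTATIONS)); maxval is exclusive.
    idx = tf.random.uniform([], minval=0, maxval=len(ROTATIONS), dtype=tf.int32)
    # tf.switch_case expects zero-argument callables, so bind each function
    # (via a default argument) and close over the input volume.
    branch_fns = [lambda f=f: f(volume) for f in ROTATIONS]
    return tf.switch_case(idx, branch_fns=branch_fns)
```

With something like this, the function could be passed to dataset.map(apply_random_rotation), and the index would be drawn anew for each example. All branches need to return tensors of the same dtype, which is why the sketch keeps every rotation shape-preserving.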