I have a tf.data.Dataset that looks like this:
<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>
The second element (index 1 with zero-based indexing) holds the labels. I want to cast the labels to tf.uint8.
How can tf.cast be used on a tf.data.Dataset?
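Here is a minimal sketch of one common approach, assuming the dataset yields `(images, labels)` tuples as shown above. A `tf.data.Dataset` is not a tensor, so `tf.cast` cannot be applied to it directly. Instead, `Dataset.map` runs a function on every element (here, every batch), and that function calls `tf.cast` on the label tensor only. The small random dataset and the helper name `cast_labels` are hypothetical stand-ins for the real `image_dataset_from_directory` output.

```python
import tensorflow as tf

# Hypothetical stand-in for the real dataset: 4 random "images" with int32
# labels, batched so elements look like ((None, 256, 256, 3), (None,)).
images = tf.random.uniform((4, 256, 256, 3), dtype=tf.float32)
labels = tf.constant([0, 1, 1, 0], dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)

def cast_labels(image_batch, label_batch):
    # Dataset.map unpacks each (images, labels) tuple into two arguments.
    # Leave the images untouched and cast only the labels to uint8.
    return image_batch, tf.cast(label_batch, tf.uint8)

ds_uint8 = ds.map(cast_labels)

# The element spec now reports tf.uint8 for the label component.
print(ds_uint8.element_spec)
```

The same `ds.map(cast_labels)` call should work on the dataset returned by `image_dataset_from_directory`. Note that uint8 only holds values 0 to 255, so this is only safe when every label fits in that range.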
Similar Questions
"How to convert tf.int64 to tf.float32?" is very similar, but it is not about a tf.data.Dataset.
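To show the difference, here is a short sketch: `tf.cast` works directly on a single tensor, but a dataset has to be cast element by element, for example with `Dataset.map`. The values are made up for illustration.

```python
import tensorflow as tf

# tf.cast converts a single tensor in one call.
x = tf.constant([1, 2, 3], dtype=tf.int64)
y = tf.cast(x, tf.float32)
print(y.dtype)

# A tf.data.Dataset is not a tensor, so the cast is applied to each
# element inside Dataset.map instead.
ds = tf.data.Dataset.from_tensor_slices(x)
ds_float = ds.map(lambda t: tf.cast(t, tf.float32))
print(list(ds_float.as_numpy_iterator()))
```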
Repro
Adapted from the "Image classification from scratch" tutorial:
curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
unzip kagglecatsanddogs_5340.zip
Then in Python with tensorflow~=2.4:
import tensorflow as tf

ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages", batch_size=32
)
print(ds)