I have a tensor a that I would like to mask with the boolean tensor mask, keeping only the selected elements of each row and discarding the rest. To keep the output tensor the same shape as a, the kept elements should be moved to the start of each row and the remaining positions at the end filled with a padding value. I can assume there is only a single contiguous run of True values in each row of the mask.
For example:
import torch

a = torch.arange(1, 17).reshape(4, 4)
# tensor([[ 1,  2,  3,  4],
#         [ 5,  6,  7,  8],
#         [ 9, 10, 11, 12],
#         [13, 14, 15, 16]])
mask = torch.tensor([[False,  True,  True, False],
                     [False,  True,  True,  True],
                     [ True, False, False, False],
                     [ True,  True,  True,  True]])
# desired output (assuming the padding value is 0):
# tensor([[ 2,  3,  0,  0],
#         [ 6,  7,  8,  0],
#         [ 9,  0,  0,  0],
#         [13, 14, 15, 16]])
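One possible batched approach, sketched here under the assumption that a and mask are 2-D tensors of the same shape: the destination column of each kept element is simply the number of True values before it in its row, which a cumulative sum over the mask gives directly. The helper name compact_rows and its pad_value parameter are hypothetical, chosen for illustration. This version does not rely on the True values being contiguous.

```python
import torch

def compact_rows(a: torch.Tensor, mask: torch.Tensor, pad_value=0) -> torch.Tensor:
    """Hypothetical helper: move the elements selected by mask to the front of
    each row of a 2-D tensor and fill the rest of the row with pad_value."""
    out = torch.full_like(a, pad_value)
    # Row and column indices of every selected element.
    rows, cols = mask.nonzero(as_tuple=True)
    # Destination column = (number of selected elements up to and including
    # this one in the same row) - 1.
    dest = mask.long().cumsum(dim=1)[rows, cols] - 1
    # Each (row, dest) pair is unique, so this scatter has no collisions.
    out[rows, dest] = a[rows, cols]
    return out

a = torch.arange(1, 17).reshape(4, 4)
mask = torch.tensor([[False, True, True, False],
                     [False, True, True, True],
                     [True, False, False, False],
                     [True, True, True, True]])
print(compact_rows(a, mask).tolist())
# [[2, 3, 0, 0], [6, 7, 8, 0], [9, 0, 0, 0], [13, 14, 15, 16]]
```

Everything here is a vectorised PyTorch operation, so there is no Python-level loop over rows.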
I can get the desired output by applying torch.masked_select and then torch.nn.functional.pad to each row in a Python loop, but I am struggling to find a way to do this efficiently for the whole batch at once.
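For reference, the per-row loop described above might look like the following sketch. The function name compact_rows_loop is hypothetical; it is mainly useful as a baseline to check a batched version against.

```python
import torch
import torch.nn.functional as F

def compact_rows_loop(a: torch.Tensor, mask: torch.Tensor, pad_value=0) -> torch.Tensor:
    """Hypothetical loop baseline: select, then right-pad, one row at a time."""
    rows = []
    for row, row_mask in zip(a, mask):
        # masked_select returns a 1-D tensor of the selected values.
        kept = torch.masked_select(row, row_mask)
        # Right-pad back to the original row length with pad_value.
        rows.append(F.pad(kept, (0, row.numel() - kept.numel()), value=pad_value))
    return torch.stack(rows)

a = torch.arange(1, 17).reshape(4, 4)
mask = torch.tensor([[False, True, True, False],
                     [False, True, True, True],
                     [True, False, False, False],
                     [True, True, True, True]])
print(compact_rows_loop(a, mask).tolist())
```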
I have also looked into using torch.roll and then zeroing everything after the appropriate index, but torch.roll shifts every row along a dimension by the same amount; it does not accept a different shift for each row.
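A per-row roll can be emulated with torch.gather by building a shifted column index for each row. Under the single-contiguous-run assumption, rolling each row left by the index of its first True and then padding everything past the run length gives the desired output. This is a sketch; the helpers roll_rows_left and compact_contiguous are hypothetical names, and the result is wrong for masks whose True values are not contiguous.

```python
import torch

def roll_rows_left(a: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: roll row i of a 2-D tensor left by shifts[i]."""
    n = a.size(1)
    cols = torch.arange(n, device=a.device)
    # Output (i, j) takes its value from column (j + shifts[i]) mod n.
    idx = (cols.unsqueeze(0) + shifts.unsqueeze(1)) % n
    return torch.gather(a, 1, idx)

def compact_contiguous(a: torch.Tensor, mask: torch.Tensor, pad_value=0) -> torch.Tensor:
    # argmax returns the first maximal index, i.e. where the True run starts
    # (0 for an all-False row, which is then fully padded since count is 0).
    start = mask.long().argmax(dim=1)
    count = mask.sum(dim=1)
    rolled = roll_rows_left(a, start)
    # Keep the first count[i] columns of row i; pad the rest.
    keep = torch.arange(a.size(1), device=a.device).unsqueeze(0) < count.unsqueeze(1)
    return rolled.masked_fill(~keep, pad_value)

a = torch.arange(1, 17).reshape(4, 4)
mask = torch.tensor([[False, True, True, False],
                     [False, True, True, True],
                     [True, False, False, False],
                     [True, True, True, True]])
print(compact_contiguous(a, mask).tolist())
```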