I have an index tensor of size (2, 3):
>>> index = torch.empty(6).random_(0,8).view(2,3)
tensor([[6., 3., 2.],
        [3., 4., 7.]])
And a value tensor of size (2, 8):
>>> value = torch.zeros(2,8)
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
I want to set the elements of value to 1 at the positions given by index along dim=-1. The output should look like this:
>>> output
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])
I tried value[range(2), index] = 1, but it raises an error. I also tried torch.index_fill, but it doesn't accept batched indices. torch.scatter, as far as I can tell, requires creating an extra tensor of size 2*8 filled with ones, which wastes memory and time.
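For reference, here is a minimal sketch of two approaches that may avoid the extra tensor of ones. It assumes the index is held as an integer (int64) tensor rather than the float tensor produced above, since indexing operations need integer indices. The names out_scatter, out_indexed, and rows are illustrative and not part of the original setup.

```python
import torch

# Integer version of the index shown above. The original comes from
# torch.empty(...).random_(), which yields a float tensor; indexing and
# scatter need int64 indices (index.long() would convert it).
index = torch.tensor([[6, 3, 2],
                      [3, 4, 7]])
value = torch.zeros(2, 8)

# Option 1: scatter with a scalar fill value, so no (2, 8) tensor of ones
# has to be built. This is the out-of-place form; value is left unchanged.
out_scatter = value.scatter(-1, index, 1.0)

# Option 2: advanced indexing. The row indices are given shape (2, 1) so
# they broadcast against the (2, 3) index; range(2) has shape (2,) and
# cannot broadcast against it.
out_indexed = value.clone()
rows = torch.arange(index.size(0)).unsqueeze(1)
out_indexed[rows, index] = 1.0

print(out_scatter)
print(out_indexed)
```

Both variants should reproduce the desired output for this input. Using value.scatter_(-1, index, 1.0) instead would modify value in place.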