To be clear, I am not
- Asking how to prevent gradients from being propagated to certain tensors (in this case you can just set `requires_grad = False` for that tensor).
- Asking how to prevent gradients from being propagated from an entire tensor (in that case you can just call `tensor.detach()`, see this question). Both mechanisms are sketched below.
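Here is a minimal sketch of those two whole-tensor mechanisms (made-up tensors, just for illustration):

```
import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=False)   # gradients are never computed for y

# x contributes twice: once normally and once as a detached constant
loss = (x + 2 * x.detach() + y).sum()
loss.backward()
print(x.grad)  # tensor([1., 1., 1.]): only the non-detached use of x receives a gradient
```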
I'm wondering how to forgo gradient computation for some elements of a loss tensor that give a NaN gradient every time -- essentially, to call .detach() on individual elements of a tensor. The way to do this in TensorFlow is tf.stop_gradient, see this question.
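For reference, the masking idiom with `tf.stop_gradient` looks roughly like this (my own sketch with a made-up mask and loss, not copied verbatim from the linked answer):

```
import tensorflow as tf

t = tf.Variable([1.0, 2.0, 3.0])
mask = tf.constant([0.0, 1.0, 1.0])  # 0 where gradients should be blocked

with tf.GradientTape() as tape:
    # entries with mask == 0 only reach the loss through stop_gradient,
    # so they are treated as constants during backprop
    t_mixed = mask * t + tf.stop_gradient((1.0 - mask) * t)
    loss = tf.reduce_sum(t_mixed ** 2)

print(tape.gradient(loss, t))  # gradient is 0 for the first (blocked) entry
```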
Some context: my neural network computes a pairwise distance matrix from its predicted coordinates, with entries d_ij = || coordinates_i - coordinates_j ||, and I want to backpropagate through the distance matrix creation step. However, the norm involves a square root, which is not differentiable at 0, and the diagonal of the distance matrix is 0 by construction. I therefore get NaN gradients for the diagonal of the distance matrix, and I would like to mask out the gradients on the diagonal.
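The NaN is easy to reproduce in isolation (a tiny check of my own, separate from the full example below): the derivative of sqrt is infinite at 0, and the chain rule multiplies it by a zero inner derivative, giving NaN for the "distance of a point to itself":

```
import torch

x = torch.zeros(1, requires_grad=True)
d = torch.sqrt(x * x)  # like a diagonal entry d_ii = ||x_i - x_i|| = 0
d.backward()
print(x.grad)  # tensor([nan]): d/dx sqrt(x^2) at x = 0 evaluates as inf * 0 = nan
```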
Minimal working example:
```
import torch

def compute_distance_matrix(coordinates):
    # Pairwise distances via the Gram matrix:
    # d_ij^2 = <x_i, x_i> + <x_j, x_j> - 2 <x_i, x_j>
    L = len(coordinates)
    gram_matrix = torch.mm(coordinates, torch.transpose(coordinates, 0, 1))
    gram_diag = torch.diagonal(gram_matrix, dim1=0, dim2=1)
    # gram_diag: L
    diag_1 = torch.matmul(gram_diag.unsqueeze(-1), torch.ones(1, L).to(coordinates.device))
    # diag_1: L x L
    diag_2 = torch.transpose(diag_1, dim0=0, dim1=1)
    # diag_2: L x L
    distance_matrix = torch.sqrt(diag_1 + diag_2 - (2 * gram_matrix))
    return distance_matrix

# In reality, pred_coordinates is an output of the network, but we initialize
# it here for a minimal working example
L = 10
pred_coordinates = torch.randn(L, 3, requires_grad=True)
true_coordinates = torch.randn(L, 3, requires_grad=False)
obj = torch.nn.MSELoss()
optimizer = torch.optim.Adam([pred_coordinates])

for i in range(500):
    optimizer.zero_grad()
    pred_distance_matrix = compute_distance_matrix(pred_coordinates)
    true_distance_matrix = compute_distance_matrix(true_coordinates)
    loss = obj(pred_distance_matrix, true_distance_matrix)
    loss.backward()
    print(loss.item())
    optimizer.step()
```
gives
```
1.2868314981460571
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
...
```