I am trying to implement a discriminative loss function for instance segmentation of images based on this paper: https://arxiv.org/pdf/1708.02551.pdf (This link is just for the readers' reference; I don't expect anyone to read it to help me out!)
My problem: Once I move from a simple loss function to a more complicated one (like the one in the code snippet below), the loss drops to zero after the first epoch. I checked the weights, and almost all of them hover around -300. They are not exactly identical, but they are very close to each other (differing only in the decimal places).
Relevant code that implements the discriminative loss function:
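import tensorflow as tf  # TensorFlow 1.x API; tf.unsorted_segment_sum is tf.math.unsorted_segment_sum in 2.x

def regDLF(y_true, y_pred):
    # alpha, beta, gamma, delta_v, delta_d, image_height, image_width and nDim
    # are module-level settings defined elsewhere; they are only read here.

    # Flatten labels to [H*W] and embeddings to [H*W, nDim]. Both reshapes need
    # exactly one image's worth of elements, i.e. one image per batch.
    y_true = tf.reshape(y_true, [image_height*image_width])
    X = tf.reshape(y_pred, [image_height*image_width, nDim])

    # Map every pixel to a cluster index, then compute the cluster means mu
    # as (sum of embeddings per cluster) / (pixel count per cluster).
    uniqueLabels, uniqueInd = tf.unique(y_true)
    numUnique = tf.size(uniqueLabels)
    Sigma = tf.unsorted_segment_sum(X, uniqueInd, numUnique)
    ones_Sigma = tf.ones((tf.shape(X)[0], 1))
    ones_Sigma = tf.unsorted_segment_sum(ones_Sigma,uniqueInd, numUnique)
    mu = tf.divide(Sigma, ones_Sigma)

    # Regularization term: mean norm of the cluster means.
    Lreg = tf.reduce_mean(tf.norm(mu, axis = 1))

    # Variance term: distance of each pixel to its own cluster mean, divided by
    # Lreg, hinged at delta_v (the clip acts as max(0, .)), then squared.
    T = tf.norm(tf.subtract(tf.gather(mu, uniqueInd), X), axis = 1)
    T = tf.divide(T, Lreg)
    T = tf.subtract(T, delta_v)
    T = tf.clip_by_value(T, 0, T)
    T = tf.square(T)
    # Pixel counts per cluster again, this time with shape [numUnique].
    ones_Sigma = tf.ones_like(uniqueInd, dtype = tf.float32)
    ones_Sigma = tf.unsorted_segment_sum(ones_Sigma,uniqueInd, numUnique)
    clusterSigma = tf.unsorted_segment_sum(T, uniqueInd, numUnique)
    clusterSigma = tf.divide(clusterSigma, ones_Sigma)
    Lvar = tf.reduce_mean(clusterSigma, axis = 0)

    # Distance term: build all numUnique*numUnique ordered pairs of means.
    # Interleaved rows: mu_0, mu_1, ..., mu_0, mu_1, ...
    # Band rows: mu_0 repeated numUnique times, then mu_1, and so on.
    # The pairs with identical indices (difference of zero) are included.
    mu_interleaved_rep = tf.tile(mu, [numUnique, 1])
    mu_band_rep = tf.tile(mu, [1, numUnique])
    mu_band_rep = tf.reshape(mu_band_rep, (numUnique*numUnique, nDim))
    mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)
    mu_diff = tf.norm(mu_diff, axis = 1)
    mu_diff = tf.divide(mu_diff, Lreg)
    mu_diff = tf.subtract(2*delta_d, mu_diff)
    mu_diff = tf.clip_by_value(mu_diff, 0, mu_diff)
    mu_diff = tf.square(mu_diff)
    numUniqueF = tf.cast(numUnique, tf.float32)  # computed but not used below
    Ldist = tf.reduce_mean(mu_diff)

    # Weighted sum of the three terms.
    L = alpha * Lvar + beta * Ldist + gamma * Lreg
    return L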
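In equation form, the snippet above computes the following. This is a transcription of the code as written, not of the paper. The symbols are introduced here for illustration only: $C$ is the number of unique labels (`numUnique`), $N_c$ the number of pixels with label $c$, $x_i$ the embedding of pixel $i$ (a row of `X`), and $[z]_+ = \max(0, z)$.

```latex
\begin{align}
\mu_c &= \frac{1}{N_c} \sum_{i \in c} x_i \\
L_{reg} &= \frac{1}{C} \sum_{c=1}^{C} \lVert \mu_c \rVert \\
L_{var} &= \frac{1}{C} \sum_{c=1}^{C} \frac{1}{N_c} \sum_{i \in c}
          \left[ \frac{\lVert \mu_c - x_i \rVert}{L_{reg}} - \delta_v \right]_+^2 \\
L_{dist} &= \frac{1}{C^2} \sum_{a=1}^{C} \sum_{b=1}^{C}
          \left[ 2\delta_d - \frac{\lVert \mu_a - \mu_b \rVert}{L_{reg}} \right]_+^2 \\
L &= \alpha L_{var} + \beta L_{dist} + \gamma L_{reg}
\end{align}
```

The double sum in $L_{dist}$ runs over all ordered pairs, including $a = b$, because the code averages over all `numUnique*numUnique` rows of `mu_diff`.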
Question: I know it's hard to understand what the code does without reading the paper, but I have a couple of questions:
Is there something glaringly wrong with the loss function defined above?
Does anyone have a general idea as to why the loss could drop to zero after the first epoch?
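To make these questions easier to check on a small input, here is a minimal NumPy sketch that mirrors the computations in regDLF step by step for one flattened image. The function name `reg_dlf_numpy` and the default hyperparameter values are placeholders introduced for this sketch; they are not the values used in my training run. `np.unique` sorts the labels while `tf.unique` keeps first-appearance order, but that only permutes the clusters and should not change any of the averaged terms.

```python
import numpy as np


def reg_dlf_numpy(labels, X, alpha=1.0, beta=1.0, gamma=0.001,
                  delta_v=0.5, delta_d=1.5):
    """NumPy mirror of regDLF for one flattened image.

    labels: (N,) instance ids, X: (N, nDim) embeddings.
    The default hyperparameters are placeholders for illustration.
    """
    labels = np.asarray(labels)
    X = np.asarray(X, dtype=np.float64)

    # Cluster index per pixel and pixel count per cluster.
    _, idx = np.unique(labels, return_inverse=True)
    idx = idx.reshape(-1)
    C = int(idx.max()) + 1
    counts = np.bincount(idx, minlength=C).astype(np.float64)

    # Cluster means: per-cluster sum of embeddings divided by the count.
    sums = np.zeros((C, X.shape[1]))
    np.add.at(sums, idx, X)
    mu = sums / counts[:, None]

    # Regularization term: mean norm of the cluster means.
    l_reg = np.linalg.norm(mu, axis=1).mean()

    # Variance term: hinged, squared, Lreg-normalized distance to own mean,
    # averaged within each cluster and then across clusters.
    d = np.linalg.norm(mu[idx] - X, axis=1) / l_reg
    hinge_v = np.maximum(d - delta_v, 0.0) ** 2
    l_var = (np.bincount(idx, weights=hinge_v, minlength=C) / counts).mean()

    # Distance term over all C*C ordered pairs, including a == b,
    # matching the reduce_mean over every row of mu_diff.
    pair = np.linalg.norm(mu[:, None, :] - mu[None, :, :], axis=2) / l_reg
    hinge_d = np.maximum(2.0 * delta_d - pair, 0.0) ** 2
    l_dist = hinge_d.mean()

    total = alpha * l_var + beta * l_dist + gamma * l_reg
    return {"Lvar": l_var, "Ldist": l_dist, "Lreg": l_reg, "L": total}


if __name__ == "__main__":
    # Toy input: two instances, two pixels each, 2-D embeddings.
    toy_labels = [0, 0, 1, 1]
    toy_X = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
    print(reg_dlf_numpy(toy_labels, toy_X))
```

Printing the three terms separately for a hand-made input like this should show which of Lvar, Ldist and Lreg dominates, and whether any of them becomes NaN or infinite (for example if Lreg is zero, since the other two terms divide by it).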
Thank you very much for your time and help!