I am using Keras 2.2.4 and am trying to implement a loss function for pixel-wise classification as described here, but I am running into some of the difficulties presented here. I am doing 3D segmentation, so my target tensor has shape (b_size, width_x, width_y, width_z, nb_classes). I implemented the following loss function, where the weight map has the same shape as the target and prediction tensors:
import tensorflow as tf
from keras import backend as K

def dice_xent_loss(y_true, y_pred, weight_map):
    """Adaptation of https://arxiv.org/pdf/1809.10486.pdf for multilabel 
    classification with overlapping pixels between classes. Dec 2018.
    """
    loss_dice = weighted_dice(y_true, y_pred, weight_map)
    loss_xent = weighted_binary_crossentropy(y_true, y_pred, weight_map)
    return loss_dice + loss_xent

def weighted_binary_crossentropy(y_true, y_pred, weight_map):
    weighted_xent = K.binary_crossentropy(y_true, y_pred) * weight_map
    return tf.reduce_mean(weighted_xent) / (tf.reduce_sum(weight_map) + K.epsilon())

def weighted_dice(y_true, y_pred, weight_map):
    if weight_map is None:
        raise ValueError("Weight map cannot be None")
    if y_true.shape != weight_map.shape:
        raise ValueError("Weight map must be the same size as target vector")
    dice_numerator = 2.0 * K.sum(y_pred * y_true * weight_map, axis=[1, 2, 3])
    dice_denominator = (K.sum(weight_map * y_true, axis=[1, 2, 3]) +
                        K.sum(y_pred * weight_map, axis=[1, 2, 3]))
    loss_dice = dice_numerator / (dice_denominator + K.epsilon())
    h1 = tf.square(tf.minimum(0.1, loss_dice) * 10 - 1)
    h2 = tf.square(tf.minimum(0.01, loss_dice) * 100 - 1)
    return (1.0 - tf.reduce_mean(loss_dice) +
            tf.reduce_mean(h1) * 10 +
            tf.reduce_mean(h2) * 10)
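As a shape sanity check, the Dice ratio above can be mirrored in plain NumPy. This is only a sketch: `weighted_dice_ratio_np` is a hypothetical helper name, the array sizes are made up, and `eps` stands in for `K.epsilon()` (assumed to be the Keras default of 1e-7). It shows that summing over axes 1, 2 and 3 leaves one ratio per (sample, class), and that a perfect prediction gives a ratio close to 1.

```python
import numpy as np

def weighted_dice_ratio_np(y_true, y_pred, weight_map, eps=1e-7):
    """NumPy mirror of the Dice ratio (hypothetical helper, not Keras code)."""
    spatial_axes = (1, 2, 3)  # width_x, width_y, width_z
    numerator = 2.0 * np.sum(y_pred * y_true * weight_map, axis=spatial_axes)
    denominator = (np.sum(weight_map * y_true, axis=spatial_axes) +
                   np.sum(y_pred * weight_map, axis=spatial_axes))
    return numerator / (denominator + eps)

# Made-up sizes: (b_size, width_x, width_y, width_z, nb_classes)
rng = np.random.RandomState(0)
shape = (2, 4, 4, 4, 3)
y_true = (rng.rand(*shape) > 0.5).astype(np.float64)
weight_map = np.ones(shape)

# Perfect prediction: y_pred equals y_true
ratio_perfect = weighted_dice_ratio_np(y_true, y_true, weight_map)
print(ratio_perfect.shape)  # one value per (sample, class)
```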
I am compiling the model with sample_weight_mode="temporal", as suggested, and passing the weights to model.fit as sample_weight=weights. I still get the following error:
  File "overfit_one_case.py", line 153, in <module>
    main()
  File "overfit_one_case.py", line 81, in main
    sample_weight_mode="temporal")
  File "/home/igt/anaconda2/envs/niftynet/lib/python2.7/site-packages/keras/engine/training.py", line 342, in compile
    sample_weight, mask)
  File "/home/igt/anaconda2/envs/niftynet/lib/python2.7/site-packages/keras/engine/training_utils.py", line 404, in weighted
    score_array = fn(y_true, y_pred)
TypeError: dice_xent_loss() takes exactly 3 arguments (2 given)
In training_utils.py, Keras calls my custom loss with only y_true and y_pred, without any weights. Any idea how to solve this? Another constraint is that I am doing transfer learning on this particular model, so I cannot add weight_map to the Input layer as suggested here.
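Since the loss is called as fn(y_true, y_pred), one workaround sometimes used is to concatenate the weight map onto the target along the class axis and split it apart again inside a two-argument loss, which leaves the model inputs untouched. The sketch below illustrates only the packing and slicing logic in NumPy. `pack_targets` and `packed_weighted_bce` are hypothetical names, and the sketch normalises by the weight sum (a weighted average), which is a choice made for this example rather than the exact formula in the question. In Keras the same slicing would be done on tensors, and whether a target with 2 * nb_classes channels is accepted in Keras 2.2.4 is an assumption not verified here.

```python
import numpy as np

def pack_targets(y_true, weight_map):
    """Stack target and weight map along the last (class) axis."""
    return np.concatenate([y_true, weight_map], axis=-1)

def packed_weighted_bce(y_packed, y_pred, eps=1e-7):
    """Two-argument loss: recovers y_true and weight_map from y_packed."""
    nb_classes = y_pred.shape[-1]
    y_true = y_packed[..., :nb_classes]      # first half: labels
    weight_map = y_packed[..., nb_classes:]  # second half: weights
    p = np.clip(y_pred, eps, 1.0 - eps)      # avoid log(0)
    bce = -(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    return np.sum(bce * weight_map) / (np.sum(weight_map) + eps)

# Made-up sizes: (b_size, width_x, width_y, width_z, nb_classes)
rng = np.random.RandomState(0)
shape = (1, 2, 2, 2, 2)
y_true = (rng.rand(*shape) > 0.5).astype(np.float64)
weight_map = np.ones(shape)
y_pred = np.full(shape, 0.5)  # maximally uncertain prediction

y_packed = pack_targets(y_true, weight_map)
loss = packed_weighted_bce(y_packed, y_pred)
print(y_packed.shape, loss)
```

With uniform weights and a constant prediction of 0.5, every voxel contributes log 2, so the weighted average should come out at roughly 0.693 regardless of the labels.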
 
    