I am trying to optimize a model with the following two loss functions:

```python
import tensorflow as tf
kls = tf.keras.losses  # assumed alias for the Keras losses module
def loss_1(pred, weights, logits):
    weighted_sparse_ce = kls.SparseCategoricalCrossentropy(from_logits=True)
    policy_loss = weighted_sparse_ce(pred, logits, sample_weight=weights)
    return policy_loss
```

and

```python
def loss_2(y_pred, y):
    return kls.mean_squared_error(y_pred, y)
```
However, Keras in TensorFlow 2 (`model.compile` / `model.fit`) expects a loss function to take exactly two arguments, in this order:

```python
def fn(y_true, y_pred):
    ...
```
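To make that convention concrete, here is a minimal sketch showing that Keras hands a custom loss only two tensors: the `y` given to `fit` as `y_true`, and the model output as `y_pred`. Extra per-sample data such as `weights` has no slot of its own. The model, data, and the `recording_loss` name are hypothetical.

```python
import numpy as np
import tensorflow as tf

# Hypothetical loss that records the shapes Keras passes in, then computes a plain MSE.
seen = {}

def recording_loss(y_true, y_pred):
    # Keras calls this with exactly two tensors: the `y` given to fit() and the model output.
    seen["y_true_shape"] = tuple(y_true.shape)
    seen["y_pred_shape"] = tuple(y_pred.shape)
    return tf.reduce_mean(tf.square(y_true - y_pred))

# Tiny hypothetical model and data: 4 samples, 3 features, 1 output.
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
# run_eagerly=True so the recorded shapes are concrete rather than symbolic.
model.compile(optimizer="sgd", loss=recording_loss, run_eagerly=True)

x = np.random.rand(4, 3).astype("float32")
y = np.random.rand(4, 1).astype("float32")
model.fit(x, y, batch_size=4, epochs=1, verbose=0)
```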
As a workaround for `loss_1`, I pack `pred` and `weights` into a single tensor before passing it as the target to `model.fit`, and then unpack them inside `loss_1`. This is inelegant and messy: `pred` and `weights` have different data types, so every call to `model.fit` requires an extra cast, pack, unpack, and un-cast.
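For reference, the pack/unpack workaround described above looks roughly like the sketch below. It rests on assumptions: the column layout (action in column 0, weight in column 1) and the helper names `pack_targets` and `packed_policy_loss` are hypothetical, and `kls` is assumed to be `tf.keras.losses`.

```python
import numpy as np
import tensorflow as tf

kls = tf.keras.losses  # assumed alias, as in the question

def pack_targets(actions, weights):
    # Hypothetical helper: cast the integer actions to float32 so both arrays
    # share one dtype, then stack them as the two columns of a single y_true.
    return np.stack(
        [np.asarray(actions, dtype=np.float32), np.asarray(weights, dtype=np.float32)],
        axis=-1,
    )

def packed_policy_loss(y_true, y_pred):
    # Unpack: column 0 holds the actions (un-cast back to int32), column 1 the weights.
    actions = tf.cast(y_true[:, 0], tf.int32)
    weights = y_true[:, 1]
    weighted_sparse_ce = kls.SparseCategoricalCrossentropy(from_logits=True)
    return weighted_sparse_ce(actions, y_pred, sample_weight=weights)
```

With a model whose output is the logits, this would be wired up as `model.compile(optimizer="adam", loss=packed_policy_loss)` followed by `model.fit(x, pack_targets(actions, weights))`, which is exactly the cast/pack/unpack/un-cast round trip described above.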
I am also aware of the `sample_weight` argument to `fit`, which comes close to solving this. It might be workable, except that I am using two loss functions and want the sample weights applied to only one of them. And even if it did work here, it would not generalize to other kinds of custom loss functions that need extra arguments.
So, stated concisely, my question is:
What is the best way to write a loss function that takes an arbitrary number of arguments in TensorFlow 2?
I have also tried passing a `tf.tuple`, but that does not fit the input TensorFlow expects for a loss function either.