I have a class in which I instantiate a Keras model to perform predictions. The class is organized roughly like this:
import tensorflow as tf

class MyClass:
    def __init__(self):
        self.model = None

    def load(self, path):
        self.model = tf.keras.models.load_model(path)

    def inference(self, data):
        # ...
        pred = self.model.predict(data)
        # ...
        return pred
I have been trying to run the MyClass.inference method in parallel. I tried it with joblib.Parallel:
from joblib import Parallel, delayed

n_jobs = 8
myobj = MyClass()
myobj.load(<Path_to_model>)

results = Parallel(n_jobs=n_jobs)(delayed(myobj.inference)(d) for d in mydata)
But I get the following error: TypeError: cannot pickle 'weakref' object
Apparently, this is a known issue with Keras (https://github.com/tensorflow/tensorflow/issues/34697) that should have been fixed in TF 2.6.0. But after upgrading TensorFlow to 2.6.0, I still get the same error. I even tried tf-nightly, as suggested in the same issue, but it did not work either.
I also tried replacing pickle with dill (via import dill as pickle), but that did not fix it.
The only thing that actually worked was switching Parallel from the default loky backend to the threading backend. However, in the scenario I tried, using threading takes roughly the same time as (or slightly longer than) running the MyClass.inference calls sequentially.
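For completeness, the threading variant is just a backend swap on the same call (a sketch reusing myobj, mydata and n_jobs from above):

from joblib import Parallel, delayed

# Threads share the process, so the preloaded model does not need to be pickled,
# but the calls are not truly concurrent the way separate processes would be.
results = Parallel(n_jobs=n_jobs, backend="threading")(
    delayed(myobj.inference)(d) for d in mydata
)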
My question is: what are my options here? Is there any way to run a preloaded Keras model's predict in parallel, for example with other Python libraries?