Just for completeness, I am adding a bit on top of benjaminplanche's answer. If your custom layer AttentionLayer has any initial parameters that configure its behaviour, you need to implement the get_config method of the class. Otherwise the model will fail to load. I am writing this because I had a lot of trouble loading custom layers with arguments, so I'll leave it here.
For example, a dummy implementation of your layer:
```python
from tensorflow.keras.layers import Layer


class AttentionLayer(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        return super().build(input_shape)

    def call(self, x):
        # Implementation of how to look with attention goes here!
        return x

    def compute_output_shape(self, input_shape):
        return input_shape
```
This will load with any of the approaches detailed in benjaminplanche's answer, i.e. using custom_objects={'AttentionLayer': AttentionLayer}. However, if your layer has some arguments, loading will fail.
Imagine the __init__ method of your class has 2 parameters:
```python
class AttentionLayer(Layer):
    def __init__(self, param1, param2, **kwargs):
        self.param1 = param1
        self.param2 = param2
        super().__init__(**kwargs)
```
Then, when you load it with:
```python
model = load_model('my_model.h5', custom_objects={'AttentionLayer': AttentionLayer})
```
It would throw this error:
```
Traceback (most recent call last):
  File "/path/to/file/cstm_layer.py", line 62, in <module>
    h = AttentionLayer()(x)
TypeError: __init__() missing 2 required positional arguments: 'param1' and 'param2'
```
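Why does this happen? In the (tf.)Keras 2 versions this answer targets, the base Layer.get_config only records generic settings such as name, trainable and dtype, and the default Layer.from_config essentially calls cls(**config). The sketch below reproduces that mechanism in plain Python so it runs without TensorFlow; BaseLayer is a hypothetical, simplified stand-in and not the real Keras class. Newer Keras releases may behave differently, so treat it only as an illustration.

```python
class BaseLayer:
    """Hypothetical stand-in for keras.layers.Layer (NOT the real Keras class)."""

    def __init__(self, name=None, trainable=True):
        self.name = name or type(self).__name__.lower()
        self.trainable = trainable

    def get_config(self):
        # Like Keras, the base config only knows about generic layer settings.
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # Mimics Keras' default from_config: rebuild the layer as cls(**config).
        return cls(**config)


class AttentionLayer(BaseLayer):
    def __init__(self, param1, param2, **kwargs):
        super().__init__(**kwargs)
        self.param1 = param1
        self.param2 = param2


layer = AttentionLayer(param1=8, param2="tanh")
config = layer.get_config()  # param1 and param2 are missing from the config

try:
    AttentionLayer.from_config(config)  # same failure as when loading the model
except TypeError as err:
    print(err)
```

Because get_config is not overridden, the rebuilt call is AttentionLayer(name=..., trainable=...), which lacks the two required arguments.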
To solve it, you need to implement the get_config method in your custom layer class. For example:
```python
class AttentionLayer(Layer):
    def __init__(self, param1, param2, **kwargs):
        self.param1 = param1
        self.param2 = param2
        super().__init__(**kwargs)

    # ...

    def get_config(self):
        # For serialization with 'custom_objects'
        config = super().get_config()
        config['param1'] = self.param1
        config['param2'] = self.param2
        return config
```
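So when you save the model, the saving routine calls get_config and serializes the inner state of your custom layer, i.e. self.param1 and self.param2. When you load it, the loader passes that config back to the layer's from_config method, which by default calls AttentionLayer(**config), so param1 and param2 reach __init__ and the layer's inner state is initialized correctly.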
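A quick way to check that get_config is complete, without writing a model file, is to do the same per-layer round trip the loader does: build the config and feed it to from_config. This is a minimal sketch assuming TensorFlow 2.x with its bundled Keras; the values 8 and "tanh" and the layer name "attn" are arbitrary placeholders.

```python
from tensorflow import keras


class AttentionLayer(keras.layers.Layer):
    def __init__(self, param1, param2, **kwargs):
        super().__init__(**kwargs)
        self.param1 = param1
        self.param2 = param2

    def call(self, x):
        # Placeholder: the attention logic is omitted, inputs pass through.
        return x

    def get_config(self):
        # Start from the base config (name, trainable, dtype, ...) and add our args.
        config = super().get_config()
        config.update({"param1": self.param1, "param2": self.param2})
        return config


layer = AttentionLayer(param1=8, param2="tanh", name="attn")
config = layer.get_config()

# Rebuild the layer from its config, as load_model does for each layer.
restored = AttentionLayer.from_config(config)
print(restored.name, restored.param1, restored.param2)  # attn 8 tanh
```

If from_config raises here, load_model with custom_objects will fail for the same reason.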