class MyDense(tf.keras.layers.Layer):
    """Fully connected layer computing ``activation(X @ kernel + bias)``.

    A minimal custom Keras layer illustrating the ``__init__`` / ``build`` /
    ``call`` / ``compute_output_shape`` / ``get_config`` protocol.

    Args:
        units: Number of output units (size of the last output dimension).
        activation: Anything ``tf.keras.activations.get`` accepts, e.g.
            ``None``, a standard name such as ``"relu"`` or ``"selu"``, or a
            callable.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer.__init__``.
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Accepts standard strings like "relu", "selu", ... and resolves them
        # to the corresponding activation function once, up front.
        self.activation = tf.keras.activations.get(activation)

    def build(self, batch_input_shape):
        """Create the layer's variables via ``add_weight()``.

        Keras calls ``build()`` the first time the layer is used, passing the
        input shape; that shape is needed here because the kernel's first
        dimension equals the size of the input's last dimension.
        """
        self.kernel = self.add_weight(
            name="kernel",
            shape=[batch_input_shape[-1], self.units],
            initializer="he_normal",
        )
        self.bias = self.add_weight(
            name="bias", shape=[self.units], initializer="zeros"
        )
        # Must be called at the end, after all weights have been created.
        super().build(batch_input_shape)

    def call(self, X):
        """Return ``activation(X @ kernel + bias)`` for the input tensor ``X``."""
        return self.activation(X @ self.kernel + self.bias)

    def compute_output_shape(self, batch_input_shape):
        """Return the input shape with its last dimension replaced by ``units``.

        ``batch_input_shape`` may be a ``tf.TensorShape`` or a plain
        tuple/list; it is normalised through ``tf.TensorShape`` first so that
        ``as_list()`` is always available.
        """
        dims = tf.TensorShape(batch_input_shape).as_list()
        return tf.TensorShape(dims[:-1] + [self.units])

    # To save a model with save() and reload it with
    # tf.keras.models.load_model(), implement get_config() on custom layers
    # (and custom models) so their constructor hyperparameters are stored.
    def get_config(self):
        """Return the base layer config extended with this layer's hyperparameters."""
        base_config = super().get_config()
        return {
            **base_config,
            "units": self.units,
            "activation": tf.keras.activations.serialize(self.activation),
        }