import tensorflow as tf


class CustomRNNLayer(tf.keras.layers.Layer):
    """Elman-style RNN: h_t = act(x_t @ W_xh + h_{t-1} @ W_hh + b_h)."""

    def __init__(self, hidden_units, activation='tanh', **kwargs):
        super().__init__(**kwargs)
        self.hidden_units = hidden_units
        self.activation = tf.keras.activations.get(activation)
    def build(self, input_shape):
        # Input-to-hidden, hidden-to-hidden, and bias parameters.
        self.W_xh = self.add_weight(shape=(input_shape[-1], self.hidden_units),
                                    initializer='glorot_uniform',
                                    name='W_xh')
        self.W_hh = self.add_weight(shape=(self.hidden_units, self.hidden_units),
                                    initializer='orthogonal',
                                    name='W_hh')
        self.b_h = self.add_weight(shape=(self.hidden_units,),
                                   initializer='zeros',
                                   name='b_h')
        super().build(input_shape)
    def call(self, inputs, initial_state=None):
        # inputs: (batch, time, features). Start from zeros unless a state is
        # given; match the input dtype so mixed-precision inputs don't break matmul.
        if initial_state is None:
            initial_state = tf.zeros(
                (tf.shape(inputs)[0], self.hidden_units), dtype=inputs.dtype)
        # A Python for-loop needs a statically known length; tf.shape(inputs)[1]
        # is a tensor and cannot drive range() under tf.function.
        sequence_length = inputs.shape[1]
        hidden_states = []
        h = initial_state
        for t in range(sequence_length):
            x = inputs[:, t, :]  # features at timestep t: (batch, features)
            h = self.activation(
                tf.matmul(x, self.W_xh) + tf.matmul(h, self.W_hh) + self.b_h
            )
            hidden_states.append(h)
        # Reassemble the per-step states into (batch, time, hidden_units).
        return tf.stack(hidden_states, axis=1)
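

# Usage sketch (assumes eager execution; the shapes below are illustrative,
# not part of the layer itself): 4 sequences, 10 timesteps, 8 features each.
if __name__ == "__main__":
    layer = CustomRNNLayer(hidden_units=16)
    x = tf.random.normal((4, 10, 8))
    states = layer(x)
    print(states.shape)  # -> (4, 10, 16)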