Questions
How do you build a custom layer in Keras/TensorFlow?
The Scenario
You are an ML engineer at an e-commerce company, building a new recommendation engine that suggests products to users based on their past behavior.
The model takes two inputs: a user ID and a product ID. It then looks up an embedding for the user and an embedding for the product, and it uses these embeddings to predict whether the user will purchase the product.
You need to implement a custom layer that can learn the embeddings for the users and products.
The Challenge
Explain how you would build a custom Keras layer to learn the user and product embeddings. What are the key methods that you would need to implement, and how would you make the layer serializable so that it can be saved and loaded with the model?
A junior engineer might try to implement the embeddings as standalone `tf.Variable` objects. This would be difficult to integrate into a Keras model, and it would not be serializable. They might also not be aware of the `get_config` method, which is needed to make a custom layer serializable.
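For concreteness, here is a rough sketch of that standalone-variable approach (the vocabulary sizes and embedding width below are made-up examples). The variables can be trained, but nothing ties them to the model's architecture, which is exactly why they are hard to save and load with the model:

```python
import tensorflow as tf

# Anti-pattern sketch: standalone variables, not tracked as layer weights.
# Example sizes only -- 10,000 users, 5,000 products, 32-dimensional embeddings.
user_embedding = tf.Variable(tf.random.normal([10_000, 32]), name="user_embedding")
product_embedding = tf.Variable(tf.random.normal([5_000, 32]), name="product_embedding")

# These variables live outside any Keras layer, so a Keras model built around
# them will not track or serialize them along with the rest of the architecture.
```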
A senior engineer would know that the correct way to implement the embeddings is to create a custom Keras layer. They would be able to explain how to use the `__init__`, `build`, and `call` methods to create the layer, and they would know how to use the `get_config` method to make the layer serializable.
Step 1: Why a Custom Layer?
Before we dive into the code, let’s compare a custom layer with a Lambda layer.
| Feature | Custom Layer | Lambda Layer |
|---|---|---|
| Trainable Weights | Yes, you can create and manage trainable weights using self.add_weight(). | No, Lambda layers cannot have their own trainable weights. |
| Serialization | Yes, you can make a custom layer serializable by implementing get_config. | No, Lambda layers are not easily serializable, especially if they contain complex logic. |
| Reusability | Yes, you can easily reuse a custom layer in multiple models. | No, Lambda layers are defined inline and are not easily reusable. |
| Complexity | More complex to implement than a Lambda layer. | Very easy to implement for simple, stateless operations. |
For our use case, a custom layer is the best choice. We need to create and manage trainable weights for the user and product embeddings, and we need the layer to be serializable so that we can save and load the model.
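For contrast, a Lambda layer is a one-liner for stateless transforms; a minimal sketch (the scaling operation here is just an arbitrary example):

```python
import tensorflow as tf

# A Lambda layer wraps a simple, stateless function. It has no trainable
# weights of its own, so there is nothing for it to learn or serialize
# beyond the wrapped function itself.
scale = tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) / 255.0)
```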
Step 2: Building the Custom Layer
Here’s how we can build a custom layer to learn the user and product embeddings:
```python
import tensorflow as tf

class RecommenderNet(tf.keras.layers.Layer):
    def __init__(self, num_users, num_products, embedding_dim, **kwargs):
        # Pass **kwargs to the base class so standard arguments (e.g. name)
        # survive the get_config / from_config round trip.
        super().__init__(**kwargs)
        self.num_users = num_users
        self.num_products = num_products
        self.embedding_dim = embedding_dim

    def build(self, input_shape):
        # Trainable embedding tables, created when the layer first sees its inputs.
        self.user_embedding = self.add_weight(
            name="user_embedding",
            shape=[self.num_users, self.embedding_dim],
            trainable=True,
        )
        self.product_embedding = self.add_weight(
            name="product_embedding",
            shape=[self.num_products, self.embedding_dim],
            trainable=True,
        )

    def call(self, inputs):
        user_id, product_id = inputs
        # Inputs arrive with shape (batch, 1); squeeze to (batch,) for the lookups.
        user_id = tf.squeeze(user_id, axis=-1)
        product_id = tf.squeeze(product_id, axis=-1)
        user_vec = tf.nn.embedding_lookup(self.user_embedding, user_id)
        product_vec = tf.nn.embedding_lookup(self.product_embedding, product_id)
        # Dot product of the two embeddings, squashed to a purchase probability.
        dot_product = tf.reduce_sum(user_vec * product_vec, axis=1)
        return tf.nn.sigmoid(dot_product)

    def get_config(self):
        # Return everything needed to re-create the layer when the model is loaded.
        config = super().get_config()
        config.update({
            "num_users": self.num_users,
            "num_products": self.num_products,
            "embedding_dim": self.embedding_dim,
        })
        return config
```

Step 3: Using the Custom Layer in a Model
Once we have defined the custom layer, we can use it in a Keras model just like any other layer.
```python
# Example hyperparameters -- in practice these come from your catalogue sizes.
num_users, num_products, embedding_dim = 10_000, 5_000, 32

user_id_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32)
product_id_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32)
output = RecommenderNet(num_users, num_products, embedding_dim)(
    [user_id_input, product_id_input]
)
model = tf.keras.Model(inputs=[user_id_input, product_id_input], outputs=output)
```
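To verify that the layer actually survives serialization, you can save the model and load it back; a minimal sketch, assuming the model defined above and a hypothetical file path:

```python
# Save the model, then reload it. custom_objects tells Keras how to rebuild
# RecommenderNet from the values returned by get_config().
model.save("recommender.keras")
restored = tf.keras.models.load_model(
    "recommender.keras",
    custom_objects={"RecommenderNet": RecommenderNet},
)
```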
Practice Question
You want to be able to save and load your custom layer with a Keras model. Which method do you need to implement?