In this post, we try to understand entity embeddings applied to categorical variables. An entity embedding is a way to represent a categorical variable in a new, dense space. It was (and still is) mostly used in natural language processing, to interpret a word within its context (using the surrounding words), but it can now be applied to any categorical variable. I came across the idea (applying embeddings to categorical variables) while studying the fast.ai MOOC with Jeremy Howard, and I found it awesome. Since then, I have used it in a few projects, but I never took the time to plot and study the embeddings created by a model.
In the following example, I will use an old model that I created a few months ago to predict the price of a flat (you can find the app using the model here).
I will create a neural net to predict the price, with an embedding for the location (French department) only.
Then, I will plot the embedding to see what the model has learned.
Why?
Why entity embeddings can be useful for categorical variables:
They reduce memory usage and speed up the neural net compared to one-hot encoding (dimensionality is reduced, depending on the chosen size of the embeddings); see the sketch after this list.
They can help the model generalize better (especially for high-cardinality variables).
Embedding vectors can be reused in other models (but keep in mind they were trained against a specific target).
They can reveal properties of the variable (similar categories end up with close vectors).
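To make the dimensionality point concrete, here is a minimal sketch (the category count and embedding size are purely illustrative numbers, not values from my model):

```python
import numpy as np

n_categories = 96   # illustrative, e.g. one level per French department
emb_size = 8        # illustrative embedding size

# One-hot encoding: a sparse 96-dim vector with a single 1
one_hot = np.eye(n_categories)[42]                # shape (96,)

# Embedding: a lookup into a trainable (96, 8) weight matrix
embedding_matrix = np.random.randn(n_categories, emb_size)
dense_vector = embedding_matrix[42]               # shape (8,)
```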
Method
I won't explain the technicalities behind embeddings, because you can find them explained almost anywhere on the internet by people far better qualified than me. You can look at posts explaining word2vec, and if you want to understand the specific application to categorical variables, you can read the paper by Cheng Guo and Felix Berkhahn, "Entity Embeddings of Categorical Variables". I think it is the first published application of entity embeddings to categorical variables; they used the technique in a Kaggle competition and finished in third place.
Dimension of the entity embeddings: it can be difficult to find the best embedding size. Most of the time it comes down to trial and error, but Jeremy Howard proposes the following rule of thumb: embedding size = min(50, number of categories / 2). You can find this rule of thumb in the excellent fast.ai MOOC "Deep Learning for Coders".
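As a quick worked example of this rule of thumb (the department count of 96 is the number of metropolitan French departments):

```python
def embedding_size(n_categories: int) -> int:
    # Rule of thumb from the fast.ai course: min(50, number of categories / 2)
    return min(50, n_categories // 2)

print(embedding_size(96))   # French departments -> 48
print(embedding_size(7))    # e.g. day of the week -> 3
```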
Code
I use the following code (together with custom transformers that you can find on my GitHub):
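The original code isn't reproduced here, but a minimal sketch of the architecture described above (a single embedding for the department, concatenated with the numeric features, feeding a few dense layers that predict the price) might look like the following in Keras. Column names, sizes, and hyperparameters are illustrative assumptions, not the original values:

```python
from tensorflow.keras.layers import Input, Embedding, Flatten, Concatenate, Dense
from tensorflow.keras.models import Model

n_departments = 96                        # assumed cardinality of the department variable
emb_size = min(50, n_departments // 2)    # rule of thumb from above -> 48

dept_in = Input(shape=(1,), name="department")        # integer-encoded department
num_in = Input(shape=(3,), name="numeric_features")   # e.g. surface, rooms, floor

emb = Embedding(n_departments, emb_size, name="dept_embedding")(dept_in)
emb = Flatten()(emb)

x = Concatenate()([emb, num_in])
x = Dense(64, activation="relu")(x)
x = Dense(32, activation="relu")(x)
out = Dense(1, name="price")(x)

model = Model(inputs=[dept_in, num_in], outputs=out)
model.compile(optimizer="adam", loss="mse")
```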
Result
The result is interesting. I plotted the embeddings in two dimensions using t-SNE and PCA (the previous plot is based on PCA). Because the embedding was trained against a rent price target, we can find patterns based on price and location. For example, all the departments belonging to Île-de-France are very close to each other on the plot.
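Continuing from the model sketch above, producing such a plot amounts to pulling the trained embedding matrix out of the model and projecting it to 2D (t-SNE works the same way via sklearn.manifold.TSNE):

```python
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Trained embedding weights: one row per department, shape (96, emb_size)
weights = model.get_layer("dept_embedding").get_weights()[0]

# Project to two dimensions with PCA
coords = PCA(n_components=2).fit_transform(weights)

plt.figure(figsize=(10, 8))
plt.scatter(coords[:, 0], coords[:, 1])
for i, (x, y) in enumerate(coords):
    plt.annotate(str(i), (x, y))  # label each point with its department index
plt.title("Department embeddings projected with PCA")
plt.show()
```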
I was really curious to see whether any patterns could be found in the department embeddings in this context. Apparently that is the case, and there are probably other interesting patterns that I didn't notice.