Train BERT Classifier in Browser with TensorflowJS

In this tutorial you will learn how to set up a BERT model for TensorflowJS and train a simple spam classifier on top of BERT (transfer learning) within the browser. To do so, we will set up a model from HuggingFace, make it compatible with TensorflowJS, and train it on a spam/ham dataset twice (once in Python and once in the browser).

Set Up BERT Model

First we import the packages:
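A minimal sketch of the imports for the Python part could look like this (the exact set of packages is an assumption):

```python
# Sketch of the imports used in the Python part of this tutorial.
import tensorflow as tf
import pandas as pd
from transformers import AutoTokenizer, TFAutoModel
```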

We set up a Tiny BERT model to save resources for browser usage. We add input layers for the inputs that BERT needs. It's important to freeze the BERT weights, because we don't want to retrain them. As mentioned at the beginning, we want to train the spam classifier twice, so for the last layer we need to differentiate between the browser-based training and the training as part of a Python script. For the browser setup we only add a flatten layer and add the classification layer later with TFJS; the Python script model can already include the classification layer. Finally, we save the model in the SavedModel format.
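A sketch of what this setup could look like, assuming the Tiny BERT checkpoint prajjwal1/bert-tiny (hidden size 128) and a sequence length of 128, which together give the 128 * 128 flatten output mentioned later:

```python
# Sketch of the model setup; the checkpoint name and sequence length are assumptions.
import tensorflow as tf
from transformers import TFAutoModel

MAX_LEN = 128
FOR_BROWSER = True  # switch between the browser variant and the Python variant

bert = TFAutoModel.from_pretrained("prajjwal1/bert-tiny", from_pt=True)
bert.trainable = False  # freeze the BERT weights, we don't want to retrain them

# The three inputs BERT expects.
input_ids = tf.keras.layers.Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids")
attention_mask = tf.keras.layers.Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask")
token_type_ids = tf.keras.layers.Input(shape=(MAX_LEN,), dtype=tf.int32, name="token_type_ids")

# last_hidden_state has shape (batch, MAX_LEN, 128) for Tiny BERT.
sequence_output = bert(
    {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids}
)[0]

flat = tf.keras.layers.Flatten()(sequence_output)  # (batch, 128 * 128)
if FOR_BROWSER:
    # Browser variant: no classification layer, it gets added later with TFJS.
    outputs = flat
else:
    # Python variant: the classification layer is included right away.
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(flat)

model = tf.keras.Model([input_ids, attention_mask, token_type_ids], outputs)
model.save("bert_spam")  # SavedModel format
```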

Before we train the model within the browser, we check whether the model gives us good results (this step is optional). For this we use a small spam/ham dataset that you can find here. To make the dataset work with BERT, we use the encoding function from transformers.
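A sketch of the encoding step; the file name and column names of the CSV are placeholders:

```python
# Sketch of encoding the dataset with the transformers tokenizer.
import pandas as pd
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")

df = pd.read_csv("spam.csv")  # assumed columns: "text" and "label" (0 = ham, 1 = spam)
encodings = tokenizer(
    df["text"].tolist(),
    max_length=128,
    truncation=True,
    padding="max_length",
    return_tensors="tf",
)
```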

Now we can train the model and evaluate it on the test dataset. For both training and testing we get an accuracy of around 97% - 99%. Pretty good - let's train the model within the browser!
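Training and evaluation could look roughly like the following sketch; the hyperparameters and the validation split are assumptions, and only the dense head actually trains since BERT is frozen:

```python
# Sketch of training the Python variant; hyperparameters are assumptions.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
model.fit(
    x=[encodings["input_ids"], encodings["attention_mask"], encodings["token_type_ids"]],
    y=df["label"].values,
    validation_split=0.2,
    epochs=5,
    batch_size=32,
)
```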

Convert SavedModel into TFJS Model

Converting the model from the SavedModel format to the TFJS format can be done with the tensorflowjs_converter (version 2.8.2 used here). The model will be converted into the graph model format.
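The converter call could look like this (the input and output paths are placeholders):

```bash
# Sketch of the conversion into the TFJS graph model format.
tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_format=tfjs_graph_model \
    ./bert_spam \
    ./public/model
```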

Train BERT in Browser

The code is written in TypeScript and can be used with any frontend framework that can run the TensorflowJS library. I used NextJS for setting everything up. (Note: I used TensorflowJS version 2.8.1 and TensorflowJS Converter version 2.8.2.)

First we need to make sure that the converted model is copied into an accessible folder, convert the vocab.txt used above into JSON format, and make it accessible as well. We also need to set up our own tokenizer; a good example can be found here.
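Converting vocab.txt into JSON can be done with a few lines of Python, for instance (the paths are placeholders):

```python
# Sketch: turn vocab.txt into a JSON array that the browser tokenizer can load.
import json

with open("vocab.txt", encoding="utf-8") as f:
    vocab = [line.rstrip("\n") for line in f]

with open("public/vocab.json", "w", encoding="utf-8") as f:
    json.dump(vocab, f, ensure_ascii=False)
```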

Now we need to load our converted model and the tokenizer, and add some preprocessing functionality:
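A sketch of this step in TypeScript; the file paths and the shape of the tokenizer output are assumptions based on the setup above:

```typescript
// Sketch of loading the converted graph model and preparing BERT's inputs.
import * as tf from "@tensorflow/tfjs";

const MAX_LEN = 128;

export let model: tf.GraphModel;

export async function loadModel(): Promise<void> {
  model = await tf.loadGraphModel("/model/model.json");
}

export interface BertInputs {
  inputIds: tf.Tensor2D;
  attentionMask: tf.Tensor2D;
  tokenTypeIds: tf.Tensor2D;
}

// Pad/clip the token ids from the tokenizer to MAX_LEN and build the three inputs.
export function preprocess(tokenIds: number[][]): BertInputs {
  const padded = tokenIds.map((ids) => {
    const clipped = ids.slice(0, MAX_LEN);
    return clipped.concat(new Array(MAX_LEN - clipped.length).fill(0));
  });
  const mask = padded.map((ids) => ids.map((id) => (id > 0 ? 1 : 0)));
  return {
    inputIds: tf.tensor2d(padded, [padded.length, MAX_LEN], "int32"),
    attentionMask: tf.tensor2d(mask, [mask.length, MAX_LEN], "int32"),
    tokenTypeIds: tf.zeros([padded.length, MAX_LEN], "int32") as tf.Tensor2D,
  };
}
```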

It's time to put all the puzzle pieces together and add the training functionality. As mentioned at the beginning, we don't retrain BERT and only use it as a frozen model. Therefore we add a function that returns the raw output of the BERT layer for preprocessed inputs. After that, we feed those results into a classification layer. Remember, we didn't set up a classification layer for the TensorflowJS model, so we add it right here. Since we added a flatten layer to the BERT model, its output is a two-dimensional tensor of shape [number of examples, 128 * 128].
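In code this could look like the following sketch, which continues the previous snippet; the graph model's input names (input_ids etc.) are assumed to match the Keras input names from the Python setup:

```typescript
// Sketch of the frozen BERT forward pass and the TFJS classification head.
// Continues the previous snippet (uses `model` and `BertInputs` from there).
import * as tf from "@tensorflow/tfjs";

// Run the frozen BERT graph model; thanks to the flatten layer the result
// has shape [numExamples, 128 * 128].
export function getBertOutput(inputs: BertInputs): tf.Tensor2D {
  return model.execute({
    input_ids: inputs.inputIds,
    attention_mask: inputs.attentionMask,
    token_type_ids: inputs.tokenTypeIds,
  }) as tf.Tensor2D;
}

// The classification layer we left out in Python is defined here with TFJS.
export const classifier = tf.sequential({
  layers: [
    tf.layers.dense({
      inputShape: [128 * 128],
      units: 1,
      activation: "sigmoid",
    }),
  ],
});

classifier.compile({
  optimizer: tf.train.adam(1e-3),
  loss: "binaryCrossentropy",
  metrics: ["accuracy"],
});
```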

Finally we can train the model in the browser. We just need to load the spam/ham dataset, preprocess the data, and let the model train! As you can see, we hit the 97 - 99% accuracy here as well.
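A sketch of the final training step, continuing the previous snippets; the tokenizer's API (a tokenize method returning token ids) is an assumption about the tokenizer set up earlier:

```typescript
// Sketch of training in the browser; continues the previous snippets
// (uses preprocess, getBertOutput, and classifier from there).
import * as tf from "@tensorflow/tfjs";

// Placeholder for the tokenizer set up earlier; its API is an assumption.
declare const tokenizer: { tokenize(text: string): number[] };

// `labels` holds 0 (ham) / 1 (spam) values for each text.
export async function train(texts: string[], labels: number[]): Promise<void> {
  const tokenIds = texts.map((text) => tokenizer.tokenize(text));
  const inputs = preprocess(tokenIds);
  const features = getBertOutput(inputs);
  const targets = tf.tensor2d(labels, [labels.length, 1]);

  await classifier.fit(features, targets, {
    epochs: 5,
    batchSize: 32,
    validationSplit: 0.2,
    callbacks: {
      onEpochEnd: (epoch, logs) => console.log(`epoch ${epoch}:`, logs),
    },
  });
}
```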

Conclusion

It's totally possible to train a model on top of BERT within the browser, even though it's not an out-of-the-box feature of TFJS and you have to put some effort into it. This tutorial can be used as a starting point, for example for training more advanced text classifiers or personalized chatbots within the browser.

Link to GitHub Repository

Demo