Implement and Train Text Classification Transformer Models

Learn how to implement and train text classification Transformer models like BERT, DistilBERT and more with only a few lines of code

Text classification is arguably the most common application of NLP. As with most NLP applications, Transformer models have dominated the field in recent years. In this article, we'll discuss how to implement and train text classification Transformer models. We'll use Happy Transformer, a library my team created. It is built on top of Hugging Face's transformers library and allows programmers to implement and train Transformer models with just a few lines of code.

Pretrained Models

There are hundreds of pretrained text classification models to choose from on Hugging Face's model distribution network. So, before you spend too much time worrying about training a model, I suggest checking whether someone has already fine-tuned a model for your particular application. For example, I've already produced content on how to implement pretrained Transformer models for sentiment analysis and hate speech detection. In this tutorial, we'll implement a model called finbert, which was created by a company called Prosus. This model detects the sentiment of financial text.


Happy Transformer is available on PyPI, so we can install it with a single command.

pip install happytransformer


Let's import a class called HappyTextClassification, which we'll use to load the model.

from happytransformer import HappyTextClassification 

From here, we can instantiate an object for the model using the HappyTextClassification class. The first positional argument specifies the type of model and is in all caps. For example, "BERT," "ROBERTA," and "ALBERT" are all valid model types. The second positional argument indicates the model's name, which can be found on the model's webpage. The final parameter is called "num_labels" and specifies the number of classes the model has. In this case, the model has three labels: "positive," "neutral," and "negative."

Important: Do not forget to set num_labels when instantiating a model. Otherwise, an error may occur.

happy_tc = HappyTextClassification("BERT", "ProsusAI/finbert", num_labels=3)


We can begin classifying text with just one line of code using the method "classify_text."

result = happy_tc.classify_text("Tesla's stock just increased by 20%")

Let's print the result so that we can understand it a little better.

print(result)

Output: TextClassificationResult(label='positive', score=0.929110586643219)

As you can see, the output is a dataclass with two variables: "label" and "score." The "label" variable is a string that indicates which class the input was classified into. The "score" variable is a float that specifies the probability the model assigned to that class. We can now isolate these two variables.
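
To make the field access concrete, here is a minimal sketch that uses a stand-in dataclass with the same two fields as the output printed above. The class below is a hypothetical illustration, not Happy Transformer's own definition. With the real library, you would read "label" and "score" from the object returned by classify_text in the same way.

```python
from dataclasses import dataclass


# Hypothetical stand-in that mirrors the two fields shown in the output above.
# It is NOT the library's own class; it only illustrates attribute access.
@dataclass
class TextClassificationResult:
    label: str
    score: float


# Values copied from the output printed earlier in this section.
result = TextClassificationResult(label="positive", score=0.929110586643219)

label = result.label  # the predicted class, as a string
score = result.score  # the probability assigned to that class, as a float

print(label)
print(score)
```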





Here's another example to show the output for a negative input.

result = happy_tc.classify_text("The price of gold just dropped by 5%")




Training – NLP Sentiment Analysis

Let's now discuss training. We'll train a model to detect the sentiment of text relating to NLP. We'll use only two training examples, which is of course far too few to train a robust model, but it is enough for a demonstration.

We must create a CSV file with two columns: text and label. The text column contains the text we wish to classify. The label column contains the class of each example as a non-negative integer. Below is a table that gives an example of a training CSV.

text | label
Wow I love using BERT for text classification | 0
I hate NLP | 1

Here is code to produce the CSV file above:

import csv

# Each case is a (text, label) pair; labels are non-negative integers.
cases = [("Wow I love using BERT for text classification", 0), ("I hate NLP", 1)]

with open("train.csv", "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["text", "label"])  # header row
    for text, label in cases:
        writer.writerow([text, label])
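
Before training, it can help to confirm that the file has the expected shape. The sketch below is an optional check and is not part of Happy Transformer. It writes the same two rows to a temporary file, reads them back with the standard library, and verifies that every row has some text and a non-negative integer label. The function name validate_training_csv and the temporary file location are assumptions made for illustration.

```python
import csv
import os
import tempfile


def validate_training_csv(path):
    """Return the number of data rows; raise ValueError on a malformed row.

    Hypothetical helper for illustration only (not a Happy Transformer function).
    """
    with open(path, newline="") as csvfile:
        reader = csv.DictReader(csvfile)
        if reader.fieldnames != ["text", "label"]:
            raise ValueError(f"unexpected header: {reader.fieldnames}")
        row_count = 0
        # Data rows start on the second row of the file, after the header.
        for row_number, row in enumerate(reader, start=2):
            if not row["text"]:
                raise ValueError(f"row {row_number}: empty text")
            try:
                label = int(row["label"])
            except (TypeError, ValueError):
                raise ValueError(f"row {row_number}: label is not an integer")
            if label < 0:
                raise ValueError(f"row {row_number}: label is negative")
            row_count += 1
    return row_count


# Write the two example rows to a temporary file, then validate it.
cases = [("Wow I love using BERT for text classification", 0), ("I hate NLP", 1)]
path = os.path.join(tempfile.mkdtemp(), "train.csv")
with open(path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["text", "label"])
    writer.writerows(cases)

row_count = validate_training_csv(path)
print(row_count)
```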

First off, we'll load a plain version of DistilBERT to use as a starting point. There are other models you can use, such as BERT, ALBERT, RoBERTa and more. Visit Hugging Face's model distribution network for more models.

happy_tc = HappyTextClassification(model_type="DISTILBERT", model_name="distilbert-base-uncased", num_labels=2)

Then, we can simply call the method "train" on our newly instantiated object, passing in the path to the training CSV.

happy_tc.train("train.csv")

And that's it! We just trained the model. We can now resume using it as we did in the previous section. So, for example, you can now call "happy_tc.classify_text()" as before, and the newly fine-tuned model will be used.

Custom Parameters

We can easily modify the learning parameters, such as the number of epochs, learning rate and more, by using a class called "TCTrainArgs." Let's import TCTrainArgs.

from happytransformer import TCTrainArgs

Now, we can create an object using the TCTrainArgs class to contain the training arguments. The Happy Transformer documentation lists the parameters you can modify. Let's increase the number of training epochs from the default of 3 to 5.

args = TCTrainArgs(num_train_epochs=5)

Let's call happy_tc's train method as before, but this time pass our args object into the method's args parameter.

happy_tc.train("train.csv", args=args)

There we go, we just modified the learning parameters!


HappyTextClassification objects have a built-in method that allows you to quickly evaluate your model. First, format your data in the same way as discussed for training, and then call the ".eval()" method. For simplicity, let's use the training file to evaluate.

result = happy_tc.eval("train.csv")

Result: EvalResult(loss=0.2848379611968994)

We can then isolate the loss variable like so:

print(result.loss)

Output: 0.2848379611968994

I suggest you use one subset of your overall data for training and a separate subset for evaluating. Then, evaluate your model before and after training. If the loss decreases, your model learned. You may also set aside a third subset of your data for running experiments to find optimal learning parameters, but that's a topic for another time.
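
One way to follow this advice is to split a single labeled CSV into a training file and an evaluation file before calling "train" and "eval." The sketch below uses only the standard library. The function name split_csv, the 80/20 ratio, the fixed seed, the generated example rows and the file names are all illustrative assumptions rather than anything Happy Transformer prescribes.

```python
import csv
import os
import random
import tempfile


def split_csv(source_path, train_path, eval_path, eval_fraction=0.2, seed=0):
    """Shuffle the data rows of source_path and write two CSVs with the same header.

    Hypothetical helper for illustration only (not a Happy Transformer function).
    Returns (number of training rows, number of evaluation rows).
    """
    with open(source_path, newline="") as source_file:
        reader = csv.reader(source_file)
        header = next(reader)
        rows = list(reader)

    # A fixed seed keeps the split reproducible between runs.
    random.Random(seed).shuffle(rows)

    # Keep at least one row for evaluation.
    eval_size = max(1, int(len(rows) * eval_fraction))
    eval_rows, train_rows = rows[:eval_size], rows[eval_size:]

    for path, subset in ((train_path, train_rows), (eval_path, eval_rows)):
        with open(path, "w", newline="") as out_file:
            writer = csv.writer(out_file)
            writer.writerow(header)
            writer.writerows(subset)

    return len(train_rows), len(eval_rows)


# Build a small placeholder dataset in a temporary directory, then split it.
workdir = tempfile.mkdtemp()
source_path = os.path.join(workdir, "all.csv")
with open(source_path, "w", newline="") as source_file:
    writer = csv.writer(source_file)
    writer.writerow(["text", "label"])
    for i in range(10):
        writer.writerow([f"example sentence {i}", i % 2])

train_path = os.path.join(workdir, "train.csv")
eval_path = os.path.join(workdir, "eval.csv")
n_train, n_eval = split_csv(source_path, train_path, eval_path)
print(n_train, n_eval)
```

With the two files in place, the workflow described above becomes: call happy_tc.eval() on the evaluation file, call happy_tc.train() on the training file, then call happy_tc.eval() on the evaluation file again and compare the two loss values.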


And that's it! You just learned how to implement and train text classification Transformer models. I hope you enjoyed this article.  Be sure to subscribe to our newsletter and YouTube channel for more content like this.

Stay happy everyone!

