How to Perform Word Prediction With Transformer Models
I think therefore I [MASK]
Masked word prediction is a fundamental task for Transformer models. For example, BERT was pre-trained using a combination of masked word prediction and next sentence prediction. Although the task may seem simple, completing it requires a deep understanding of language, which makes it an appealing pre-training objective for large language models.
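To make the task concrete, here is a minimal sketch of how a masked-word training example can be built from raw text. It is a simplification, not BERT's actual procedure: BERT works on subword tokens and selects roughly 15% of them per sequence, while the hypothetical mask_one_word() function below hides exactly one whitespace-separated word.

```python
import random
from typing import Tuple


def mask_one_word(sentence: str, rng: random.Random) -> Tuple[str, str]:
    """Hide one randomly chosen word behind [MASK].

    Returns (masked_text, hidden_word). Simplified illustration: real BERT
    pre-training masks subword tokens, not whole words.
    """
    words = sentence.split()
    if not words:
        raise ValueError("sentence must contain at least one word")
    index = rng.randrange(len(words))  # position of the word to hide
    hidden_word = words[index]
    words[index] = "[MASK]"
    return " ".join(words), hidden_word


# Seeded RNG so the example is reproducible from run to run.
rng = random.Random(0)
masked_text, hidden_word = mask_one_word("I think therefore I am", rng)
print(masked_text)
print(hidden_word)
```

The hidden word plays the role of the label: a model sees masked_text and is trained to recover hidden_word.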
This tutorial will cover how to implement state-of-the-art NLP Transformer models to perform masked word prediction. We'll use a Python library called "Happy Transformer," which I'm the lead maintainer of. Happy Transformer is built on top of Hugging Face's Transformers library and allows programmers to implement and train Transformer models with just a few lines of code.
Happy Transformer is available on PyPI for download.
pip install happytransformer
We will import the HappyWordPrediction class to begin predicting masks.
from happytransformer import HappyWordPrediction
We will now create a HappyWordPrediction object with the default settings. By default, it uses a model called "distilbert-base-uncased."
happy_wp = HappyWordPrediction()
To start performing word prediction, call happy_wp's predict_mask() method. It accepts three inputs:
text: A string that contains exactly one substring equal to [MASK]. This is the position the model fills in when predict_mask() is called.
targets: An optional list of candidate tokens; when provided, only these candidates are scored. A token may be a single word or a symbol such as a period.
top_k: The number of results to return.
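Because the text input must contain exactly one [MASK] substring, a quick check before calling predict_mask() can catch mistakes early. The has_single_mask() helper below is hypothetical and not part of Happy Transformer; it only counts occurrences of the literal string "[MASK]".

```python
MASK = "[MASK]"


def has_single_mask(text: str) -> bool:
    """Return True when text contains exactly one "[MASK]" substring."""
    return text.count(MASK) == 1


print(has_single_mask("Artificial intelligence is going to [MASK] the world."))  # True
print(has_single_mask("Artificial [MASK] is going to [MASK] the world."))  # False
print(has_single_mask("Artificial intelligence is going to change the world."))  # False
```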
text = "Artificial intelligence is going to [MASK] the world."
result = happy_wp.predict_mask(text, top_k=2)
result: [WordPredictionResult(token='change', score=0.1882605403661728), WordPredictionResult(token='conquer', score=0.15199477970600128)]
As you can see, the method returns a list of dataclass objects, each with a "token" and a "score" attribute. Let's take a closer look at the output.
print(result[0])
print(result[0].token)
print(result[0].score)
The code above shows how to extract the top result. Replace the "0" with "1" to get the second result.
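If you'd like to experiment with indexing and iterating over the results without downloading a model, the sketch below recreates the two results printed earlier using a local stand-in dataclass. This WordPredictionResult class is defined here for illustration only and is not imported from happytransformer; it simply mirrors the token and score fields shown in the output.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class WordPredictionResult:
    # Local stand-in with the same two fields as the printed output.
    token: str
    score: float


# The two results printed earlier, copied into stand-in objects.
result: List[WordPredictionResult] = [
    WordPredictionResult(token="change", score=0.1882605403661728),
    WordPredictionResult(token="conquer", score=0.15199477970600128),
]

top = result[0]              # highest-scoring prediction comes first
print(top.token)             # change
print(round(top.score, 3))   # 0.188
print(result[1].token)       # conquer

# Iterate over every returned prediction with its rank.
for rank, prediction in enumerate(result, start=1):
    print(rank, prediction.token, f"{prediction.score:.3f}")
```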
Predicting with targets is simple: pass a list of strings to predict_mask()'s targets parameter.
text = "Natural [MASK] processing"
targets = ["language", "python"]
result = happy_wp.predict_mask(text, targets=targets)
print(result[0].token)
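Conceptually, passing targets limits the ranking to the candidates you list. The hypothetical rank_targets() function below mimics only that ranking step on a dictionary of scores; in Happy Transformer the scores come from the model itself, and the numbers in the usage example are made up for illustration.

```python
from typing import Dict, List, Tuple


def rank_targets(scores: Dict[str, float], targets: List[str]) -> List[Tuple[str, float]]:
    """Keep only the tokens listed in targets and sort them by score, highest first.

    Targets that have no score are skipped.
    """
    kept = [(token, scores[token]) for token in targets if token in scores]
    return sorted(kept, key=lambda pair: pair[1], reverse=True)


# Made-up scores for illustration only; these are not model outputs.
toy_scores = {"language": 0.9, "python": 0.01, "gas": 0.05}
print(rank_targets(toy_scores, ["language", "python"]))
# [('language', 0.9), ('python', 0.01)]
```

Note that "gas" is ignored even though it outscores "python", because it was not listed as a target.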
To use a different model, provide a "model_type" and a "model_name" when creating a HappyWordPrediction object. The model type is the model's architecture in all caps, such as "BERT," "ROBERTA," or "ALBERT." For the model name, copy the model's name as it appears on this webpage.
happy_wp_albert = HappyWordPrediction("ALBERT", "albert-base-v2")
happy_wp_roberta = HappyWordPrediction("ROBERTA", "roberta-base")
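If you switch between several models, it can help to derive the model type from the model name. The guess_model_type() helper below is a hypothetical convenience, not a Happy Transformer function. It assumes the name (after any "organization/" prefix) starts with the architecture, recognizes only the three types named above, and returns None otherwise.

```python
from typing import Optional

# Prefix -> model_type pairs, limited to the three types named above.
PREFIX_TO_MODEL_TYPE = {
    "roberta": "ROBERTA",
    "albert": "ALBERT",
    "bert": "BERT",
}


def guess_model_type(model_name: str) -> Optional[str]:
    """Guess the all-caps model_type from a name such as "albert-base-v2"."""
    # Drop an optional "organization/" prefix, then compare case-insensitively.
    base_name = model_name.split("/")[-1].lower()
    for prefix, model_type in PREFIX_TO_MODEL_TYPE.items():
        if base_name.startswith(prefix):
            return model_type
    return None


print(guess_model_type("albert-base-v2"))  # ALBERT
print(guess_model_type("roberta-base"))    # ROBERTA
```

When the guess is not None, you could pass it alongside the name, e.g. HappyWordPrediction(guess_model_type(name), name). A prefix match is crude, so double-check the result for unfamiliar model names.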
You can then use the predict_mask() method exactly as before.
I'm going to briefly introduce training a word prediction model. This deserves a full tutorial of its own, but between this blog post and the official documentation, you should be well on your way to fine-tuning a custom model.
To train a HappyWordPrediction model, call its train() method and pass the path to a text file as the only positional argument.
You may modify the training settings by using a class called WPTrainArgs. After creating an object using the WPTrainArgs class, provide the object to HappyWordPrediction.train()'s args parameter.
from happytransformer import WPTrainArgs

args = WPTrainArgs(num_train_epochs=1)
happy_wp.train("train.txt", args=args)
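The training file itself is plain text. As a minimal sketch, the hypothetical write_training_file() helper below writes one sentence per line to a UTF-8 file that you could then pass to train(). The one-example-per-line layout is an assumption; check the Happy Transformer documentation for how the file is split into training examples.

```python
import tempfile
from pathlib import Path
from typing import Iterable, Union


def write_training_file(sentences: Iterable[str], path: Union[str, Path]) -> Path:
    """Write non-blank sentences to a UTF-8 text file, one per line."""
    out_path = Path(path)
    lines = [s.strip() for s in sentences if s.strip()]  # drop blank entries
    out_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return out_path


# Write a tiny sample file into a temporary directory.
tmp_dir = Path(tempfile.mkdtemp())
train_path = write_training_file(
    ["I think therefore I am.", "Natural language processing is fun."],
    tmp_dir / "train.txt",
)
print(train_path.read_text(encoding="utf-8"))
```

With the file in place, the training call is the same as before: happy_wp.train(str(train_path), args=args).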
Visit this webpage for more details about different training parameters.
And that's it! Subscribe to my YouTube channel for more content like this. Also, check out Happy Transformer's GitHub page. There is a Happy Transformer Discord group you can join to learn more about it and to network with other Transformer enthusiasts.
Support Happy Transformer by giving it a star 🌟🌟🌟