"You're a little scary sometimes, you know that? Brilliant ... but scary." — Ron Weasley
Was Ron referring to Hermione or GPT-2?
This article discusses how to implement a fine-tuned version of GPT-2 to generate Harry Potter fan fiction. According to its description, it was fine-tuned on "the top 100 rated fanfiction stories." After experimenting with it, I found that it works quite well, especially after adjusting the text generation algorithm. So, let's get right to generating fan fiction for the best-selling book series of all time!
First off, we'll install Happy Transformer. Happy Transformer is an open-source Python library that makes it easy to implement and train Transformer models. I am the lead maintainer of it, and if you want to network with other Transformer model enthusiasts, then I suggest you join Happy Transformer's Discord server.
pip install happytransformer
Now, we're going to import a class called HappyGeneration to load the model.
from happytransformer import HappyGeneration
We need to provide two positional arguments to create a HappyGeneration object: the model type and the model name. We're using a GPT-2 model, so our model type is simply "GPT2". The model name comes from the model's page on Hugging Face's Model Hub.
happy_gen = HappyGeneration("GPT2", "ceostroff/harry-potter-gpt2-fanfiction")
From here, we can begin generating text with just one line of code using happy_gen's generate_text method. This method requires a single positional input: a string. The model then attempts to continue whatever text was provided and returns a dataclass object.
result = happy_gen.generate_text("Pass the")
print(result)
The returned dataclass contains a single field called text, which holds the model's continuation of the input.
Output: GenerationResult(text=' letter to the Ministry, and then go to the Ministry. I\'ll be there shortly. I\'ll be there in a few minutes."Harry nodded and walked out of the office. He was glad that he had a friend in the Ministry, and that')
Now, we can extract the text field:
print(result.text)
Output: letter to the Ministry, and then go to the Ministry. I'll be there shortly. I'll be there in a few minutes."Harry nodded and walked out of the office. He was glad that he had a friend in the Ministry, and that
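The first output looked like GenerationResult(text='...') because printing a dataclass shows its automatically generated representation. The stand-in below is hypothetical, not Happy Transformer's actual class definition; it only assumes, as described above, a dataclass with a single text field:

```python
from dataclasses import dataclass


# Hypothetical stand-in for the object generate_text() returns;
# the real class is defined inside Happy Transformer.
@dataclass
class GenerationResult:
    text: str


result = GenerationResult(text=" letter to the Ministry")
print(result)       # GenerationResult(text=' letter to the Ministry')
print(result.text)  # prints the bare string, leading space included
```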
Let's import the GENSettings class to begin modifying the text generation parameters. By default, an algorithm called "greedy" is used, which is prone to repetition. This algorithm simply selects the most likely next token at every step, where tokens are words, pieces of words, and punctuation symbols.
from happytransformer import GENSettings
Head over to this webpage to learn about different text generation algorithms we can modify.
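To see why greedy decoding tends to loop, here is a toy sketch. The NEXT table and its probabilities are made up for illustration, and it only looks at the previous token, whereas GPT-2 conditions on the whole context. The toy tokens are whole words with punctuation attached for brevity. The underlying issue is the same: always taking the top choice can lead straight back to a phrase already generated.

```python
# Hypothetical next-token probabilities for a toy model (made-up numbers,
# not taken from GPT-2). Each key is the previous token.
NEXT = {
    "I'll": {"be": 0.9, "see": 0.1},
    "be": {"there": 0.7, "back": 0.3},
    "there": {"shortly.": 0.6, "soon.": 0.4},
    "shortly.": {"I'll": 0.8, "Harry": 0.2},
    "see": {"you": 1.0},
}


def greedy_decode(start, steps):
    """Greedy decoding: always append the single most likely next token."""
    tokens = [start]
    for _ in range(steps):
        options = NEXT.get(tokens[-1])
        if not options:  # the toy model has no continuation for this token
            break
        tokens.append(max(options, key=options.get))
    return tokens


print(" ".join(greedy_decode("I'll", 8)))
# I'll be there shortly. I'll be there shortly. I'll
```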
We want to prevent repetition. In the above example, notice how "I'll be there" appears twice. We can prevent this by setting GENSettings's no_repeat_ngram_size parameter to two. This means that any sequence of two consecutive tokens (a 2-gram) can appear at most once. For example, "I am" and "Canada." are both 2-grams. So, if "Canada." occurs within the generated text, then it cannot occur again (note: the period counts as a token).
args = GENSettings(no_repeat_ngram_size=2)
We can now generate text as before, except that we set generate_text's args parameter to the object we created above.
result = happy_gen.generate_text("Pass the", args=args)
print(result.text)
Output: letter to the Ministry, and then go to Gringotts. I'll be back in a few minutes."Harry nodded, then turned back to his friends. "I'll see you in the morning, Hermione."Hermione smiled. She'd
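Here is a toy sketch of the no_repeat_ngram_size rule described above: before choosing each token, it bans any token that would complete an n-gram already present in the text. The NEXT table, banned_tokens and greedy_no_repeat are hypothetical names with made-up probabilities; this illustrates the rule, not Happy Transformer's or Hugging Face's actual implementation.

```python
# Hypothetical next-token probabilities (made-up numbers, not from GPT-2).
NEXT = {
    "I'll": {"be": 0.9, "see": 0.1},
    "be": {"there": 0.7, "back": 0.3},
    "there": {"shortly.": 0.6, "soon.": 0.4},
    "shortly.": {"I'll": 0.8, "Harry": 0.2},
    "see": {"you": 1.0},
}


def banned_tokens(tokens, n):
    """Return tokens that would complete an n-gram already present in `tokens`."""
    if len(tokens) < n:
        return set()  # no complete n-gram yet, so nothing can repeat
    prefix = tokens[len(tokens) - (n - 1):]  # the last n-1 tokens
    banned = set()
    for i in range(len(tokens) - n + 1):
        if tokens[i:i + n - 1] == prefix:
            banned.add(tokens[i + n - 1])
    return banned


def greedy_no_repeat(start, steps, n):
    """Greedy decoding that never repeats an n-gram."""
    tokens = [start]
    for _ in range(steps):
        banned = banned_tokens(tokens, n)
        options = {tok: p for tok, p in NEXT.get(tokens[-1], {}).items()
                   if tok not in banned}
        if not options:  # nothing allowed (or no continuation): stop
            break
        tokens.append(max(options, key=options.get))
    return tokens


print(" ".join(greedy_no_repeat("I'll", 8, n=2)))
# I'll be there shortly. I'll see you
```

The single token "I'll" still appears twice; only the 2-gram "I'll be" is blocked, which pushes the toy model onto its second choice.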
Be sure to check out other algorithms you can use, such as beam search, generic sampling and top-k sampling. Also, if you wish to generate longer sequences, then increase GENSettings's max_length parameter (50 by default).
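Top-k sampling, one of the alternatives mentioned above, keeps only the k most likely next tokens and samples among them in proportion to their probabilities. Below is a minimal sketch over a made-up distribution; the function name top_k_sample and the PROBS values are hypothetical, and a real model applies the same idea over its full vocabulary at every step.

```python
import random


def top_k_sample(probs, k, rng):
    """Sample one token from the k most likely entries of `probs`."""
    top = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:k]
    tokens = [tok for tok, _ in top]
    weights = [p for _, p in top]
    # random.choices treats the weights as relative, so the kept
    # probabilities are effectively renormalized.
    return rng.choices(tokens, weights=weights, k=1)[0]


# Made-up next-token probabilities after "Pass the".
PROBS = {"letter": 0.5, "salt": 0.3, "Quaffle": 0.15, "exam": 0.05}

rng = random.Random(42)  # seeded so runs are reproducible
print([top_k_sample(PROBS, k=2, rng=rng) for _ in range(5)])
```

With k=1 this reduces to greedy decoding; a larger k trades predictability for variety.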
Fine-tuning models with Happy Transformer is incredibly easy. For example, let's say you have a text file called "train.txt" that contains Harry Potter fan fiction. Then, you only need the following lines of code to train a model. Here we are using a model called GPT-Neo, EleutherAI's open-source alternative to GPT-3.
from happytransformer import HappyGeneration
happy_gen = HappyGeneration("GPT-NEO", "EleutherAI/gpt-neo-125M")
# Fine-tune the model on the text in train.txt
happy_gen.train("train.txt")
I suggest you copy this code and modify both the input and the text generation parameters to see if you can generate creative text. Stay happy everyone!
Check out this new course I created on how to create a web app to display GPT-Neo with 100% Python. It covers how to implement and train text generation models in more depth than this article.
Subscribe to my YouTube channel for new videos on NLP.
Special thanks to Caitlin Ostroff for publishing this model. Her Hugging Face profile can be accessed here.
Book a Call
We may be able to help you or your company with your next NLP project. Feel free to book a free 15-minute call with us.