
Fine-Tune a Transformer Model for Grammar Correction

Learn how to train a Transformer model called T5 to be your very own grammar corrector

In this article we'll discuss how to train a state-of-the-art Transformer model to perform grammar correction. We'll use a model called T5, which at the time of writing outperformed the human baseline on the General Language Understanding Evaluation (GLUE) benchmark, making it one of the most powerful NLP models in existence. T5 was created by Google AI and released to the world for anyone to download and use.

We'll use my very own Python package called Happy Transformer for this tutorial. Happy Transformer is built on top of Hugging Face's Transformers library and makes it easy to implement and train transformer models with just a few lines of code. So, you don't need a deep understanding of NLP or Python to follow this tutorial, even though we'll be training what is known to be one of the most capable AI models in the world.

Scroll down to the "Pretrained Model" section to learn how to download and use a grammar correction model I trained and uploaded to Hugging Face's model distribution network.

Installation

Happy Transformer is available on PyPI, and thus can be pip installed.

pip install happytransformer

Model

T5 comes in several different sizes, and we'll use the base model, which has 220 million parameters. The largest available model has 11 billion parameters, while the smallest has 60 million.

T5 is a text-to-text model, meaning that given input text, it generates a standalone piece of output text. Thus, we'll import a class called HappyTextToText from Happy Transformer, which we'll use to load the model. We'll provide the model type (T5) as the first positional parameter and the model name (t5-base) as the second.

from happytransformer import HappyTextToText

happy_tt = HappyTextToText("T5", "t5-base")

Data Collection

We'll use a well-known dataset called JFLEG to train the model. According to its description on its Hugging Face page, it is "a gold standard benchmark for developing and evaluating GEC systems with respect to fluency" (GEC stands for grammatical error correction). In addition, at the time of writing its paper had 106 citations according to Google Scholar, which suggests that it is well respected within the NLP community [1]. It is under a CC BY-NC-SA 4.0 license, which means that you must provide attribution, must not use it for commercial purposes, and must apply the same license to any derivatives*.

*This is not legal advice. Read the full license for more information.

The dataset is available on Hugging Face's datasets distribution network and can be accessed using their Datasets library. Since this library is a dependency for Happy Transformer, we do not need to install it and can go straight to importing a function called load_dataset from the library.  

from datasets import load_dataset

The dataset's id is "jfleg", and it has two splits: "validation" and "test". We'll use the validation set for training and the test set for evaluation.

train_dataset = load_dataset("jfleg", split='validation[:]')

eval_dataset = load_dataset("jfleg", split='test[:]')

Data Examination  

We just successfully downloaded the dataset. Let's now explore it by iterating over some cases. Both the train and eval datasets are structured the same way and have two features: sentence and corrections. The sentence feature contains a single string for each case, while the corrections feature contains a list of 4 human-generated corrections.

for case in train_dataset["corrections"][:2]:
  print(case)
  print(case[0])
  print("--------------------------------------------------------")

Result:

['So I think we would not be alive if our ancestors did not develop sciences and technologies . ',...

So I think we would not be alive if our ancestors did not develop sciences and technologies .

--------------------------------------------------------

['Not for use with a car . ', 'Do not use in the car . ', 'Car not for use . ', 'Can not use the car . ']

Not for use with a car .

--------------------------------------------------------

Data Preprocessing  

Now, we must process the data into the proper format for Happy Transformer. Both the training and evaluating data need the same format: a CSV file with two columns, input and target. The input column contains grammatically incorrect text, and the target column contains the corrected version of the text from the input column.

Below is code that processes data into the proper format. We must specify the task we wish to perform by adding the same prefix to each input. In this case, we'll use the prefix "grammar: ". This is done because T5 models are able to perform multiple tasks like translation and summarization with a single model, and a unique prefix is used for each task so that the model learns which task to perform. We also need to skip over cases that contain a blank string to avoid errors while fine-tuning.

import csv

def generate_csv(csv_path, dataset):
    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writter = csv.writer(csvfile)
        writter.writerow(["input", "target"])
        for case in dataset:
            # Adding the task's prefix to the input
            input_text = "grammar: " + case["sentence"]
            for correction in case["corrections"]:
                # A few of the cases contain blank strings, so skip them.
                # Check the raw sentence, since input_text always contains the prefix.
                if case["sentence"] and correction:
                    writter.writerow([input_text, correction])


generate_csv("train.csv", train_dataset)
generate_csv("eval.csv", eval_dataset)

We just generated our training and evaluating data! In total, we generated 3016 training examples and 2988 evaluating examples.
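Before spending GPU time on training, it can be worth reading the files back to confirm they have the format described above. The following is a minimal sketch using only Python's standard csv module; the function name check_csv is hypothetical and not part of Happy Transformer. It checks the header, the "grammar: " prefix on every input, and that no target is blank, then returns the row count (which you can compare with the totals reported above) and the number of distinct inputs.

```python
import csv
import os
from typing import Tuple


def check_csv(csv_path: str, prefix: str = "grammar: ") -> Tuple[int, int]:
    """Validate a generated CSV and return (number of rows, number of distinct inputs)."""
    num_rows = 0
    distinct_inputs = set()
    with open(csv_path, newline="", encoding="utf-8") as csvfile:
        reader = csv.DictReader(csvfile)
        # The header must be exactly the two columns Happy Transformer expects.
        if reader.fieldnames != ["input", "target"]:
            raise ValueError(f"unexpected header: {reader.fieldnames}")
        for row in reader:
            if not row["input"].startswith(prefix):
                raise ValueError(f"missing task prefix: {row['input']!r}")
            if not row["target"]:
                raise ValueError(f"blank target for input: {row['input']!r}")
            num_rows += 1
            distinct_inputs.add(row["input"])
    return num_rows, len(distinct_inputs)


if __name__ == "__main__":
    # Only check the files if they have already been generated.
    for path in ("train.csv", "eval.csv"):
        if os.path.exists(path):
            print(path, check_csv(path))
```

Since each sentence comes with several corrections, you should expect more rows than distinct inputs.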


Evaluating Before Training

We'll evaluate the model before and after fine-tuning using a common metric called loss. Loss can be described as how "wrong" the model's predictions are compared to the correct answers. So, if the loss decreases after fine-tuning, then that suggests the model learned. It's important that we use separate data for training and evaluating to show that the model can generalize its obtained knowledge to solve unseen cases.
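To make "how wrong" a bit more concrete: for a text-to-text model like T5, the evaluation loss is, to my understanding, the average token-level cross-entropy, meaning the negative log-probability the model assigns to each correct target token, averaged over the tokens. The sketch below is a simplified illustration with made-up probabilities; the real computation happens inside Hugging Face's Transformers over the model's full vocabulary.

```python
import math
from typing import Sequence


def cross_entropy(correct_token_probs: Sequence[float]) -> float:
    """Mean negative log-probability of the correct target tokens."""
    if not correct_token_probs:
        raise ValueError("need at least one token probability")
    return -sum(math.log(p) for p in correct_token_probs) / len(correct_token_probs)


# Hypothetical probabilities a model assigns to each correct target token.
confident = [0.9, 0.8, 0.95]
unsure = [0.3, 0.2, 0.4]

print(cross_entropy(confident))  # lower loss: the model favours the right tokens
print(cross_entropy(unsure))     # higher loss: the model spreads its probability elsewhere
```

A perfect model would assign probability 1 to every correct token and reach a loss of 0, which is why a lower loss after fine-tuning suggests the model has learned.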

There are other metrics you can use to evaluate grammar correction models. One of the most popular ones is called GLEU, which you can learn more about here [2]. Loss is the simplest to compute with Happy Transformer, so we'll use it instead.

Let's determine the loss of the model on the evaluating dataset prior to any training. To accomplish this, we'll call happy_tt's eval() method and provide the path to our CSV that contains our evaluating data.

before_result = happy_tt.eval("eval.csv")

The result is a dataclass object with a single attribute called loss, which we can access as shown below.

print("Before loss:", before_result.loss)

Result: Before loss: 1.2803919315338135

Training

Let's now train the model. We can do so by calling happy_tt's train() method. For simplicity, we'll use the default parameters other than the batch size, which we'll increase to 8. If you experience an out-of-memory error, then I suggest you reduce the batch size. You can visit this webpage to learn how to modify various parameters like the learning rate and the number of epochs.

from happytransformer import TTTrainArgs

args = TTTrainArgs(batch_size=8)
happy_tt.train("train.csv", args=args)
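If you would rather not reduce the batch size by hand after an out-of-memory error, you could wrap training in a small retry loop. The sketch below is a hypothetical helper, not part of Happy Transformer: it calls a training function with a batch size and halves the batch size whenever the error message mentions "out of memory", which is how PyTorch reports CUDA memory exhaustion (as a RuntimeError or a subclass of it).

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def train_with_backoff(
    train_fn: Callable[[int], T], batch_size: int, min_batch_size: int = 1
) -> T:
    """Call train_fn(batch_size), halving the batch size after each out-of-memory error."""
    while True:
        try:
            return train_fn(batch_size)
        except RuntimeError as err:
            # PyTorch raises CUDA out-of-memory errors as RuntimeError (or a
            # subclass of it) with "out of memory" in the message.
            if "out of memory" not in str(err) or batch_size // 2 < min_batch_size:
                raise
            batch_size //= 2
            print(f"Out of memory; retrying with batch_size={batch_size}")


# Hypothetical usage with the objects from this tutorial (not run here):
# train_with_backoff(
#     lambda bs: happy_tt.train("train.csv", args=TTTrainArgs(batch_size=bs)), 8
# )
```

One caveat: if memory runs out partway through training, the model's weights may already be partially updated, so reloading the model inside train_fn gives a cleaner restart.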

Evaluating After Training

Like before, let's determine the model's loss.

after_result = happy_tt.eval("eval.csv")

print("After loss:", after_result.loss)

Result: After loss: 0.451170414686203

There we go, as you can see the loss decreased! Now, let's evaluate the model in a more qualitative way by providing it with examples.
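To put the two loss values in perspective, we can compute the relative reduction. The sketch below uses only the numbers from the "Before loss" and "After loss" results above. It also converts each loss to a perplexity with exp(loss), which assumes the loss is a mean natural-log cross-entropy per token; treat that part as an interpretation rather than something Happy Transformer reports.

```python
import math

# Values copied from the "Before loss" and "After loss" results above.
before_loss = 1.2803919315338135
after_loss = 0.451170414686203

relative_reduction = (before_loss - after_loss) / before_loss

# Assumption: the reported loss is a mean natural-log cross-entropy per token,
# so exp(loss) can be read as a per-token perplexity.
perplexity_before = math.exp(before_loss)
perplexity_after = math.exp(after_loss)

print(f"Loss reduced by {relative_reduction:.1%}")
print(f"Perplexity: {perplexity_before:.2f} -> {perplexity_after:.2f}")
```

This prints a reduction of about 64.8% and a drop in perplexity from roughly 3.60 to 1.57.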

Inference

Let's now use the model to correct the grammar of a few example sentences. To accomplish this, we'll use happy_tt's generate_text() method. We'll also use an algorithm called beam search for the generation. You can view the different text generation parameters you can modify on this webpage, along with different configurations you could use for common algorithms.

from happytransformer import TTSettings

beam_settings = TTSettings(num_beams=5, min_length=1, max_length=20)
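To show what num_beams controls, here is a simplified beam search over a made-up "model" whose next-token probabilities depend only on the previous token. Greedy decoding (num_beams=1) always takes the single most likely next token, while beam search keeps the num_beams best partial outputs at each step, which lets it find a sequence with a higher overall probability. Hugging Face's actual implementation differs in details such as length normalization and when to stop, so this is an illustration of the idea, not the library's code.

```python
import math
from typing import Callable, Dict, List, Optional, Tuple

EOS = "</s>"

# Toy "model": next-token probabilities depend only on the previous token.
# These numbers are made up purely for illustration.
TOY_TABLE: Dict[Optional[str], Dict[str, float]] = {
    None: {"a": 0.6, "b": 0.4},
    "a": {"a": 0.3, "b": 0.3, EOS: 0.4},
    "b": {"a": 0.05, "b": 0.05, EOS: 0.9},
}


def toy_model(tokens: List[str]) -> Dict[str, float]:
    """Return log-probabilities of each possible next token."""
    last = tokens[-1] if tokens else None
    return {tok: math.log(p) for tok, p in TOY_TABLE[last].items()}


def beam_search(
    next_log_probs: Callable[[List[str]], Dict[str, float]],
    num_beams: int,
    max_length: int,
) -> Tuple[List[str], float]:
    """Simplified beam search: keep the num_beams best candidates per step."""
    beams: List[Tuple[List[str], float]] = [([], 0.0)]
    finished: List[Tuple[List[str], float]] = []
    for _ in range(max_length):
        # Expand every live beam by every possible next token.
        candidates = [
            (tokens + [tok], score + lp)
            for tokens, score in beams
            for tok, lp in next_log_probs(tokens).items()
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:num_beams]:
            # Candidates ending in EOS are complete and leave the beam.
            if tokens[-1] == EOS:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # sequences that hit max_length without EOS
    return max(finished, key=lambda c: c[1])


greedy = beam_search(toy_model, num_beams=1, max_length=5)
beam = beam_search(toy_model, num_beams=2, max_length=5)
print("greedy:", greedy[0], round(math.exp(greedy[1]), 2))  # greedy: ['a', '</s>'] 0.24
print("beam:", beam[0], round(math.exp(beam[1]), 2))        # beam: ['b', '</s>'] 0.36
```

Greedy decoding commits to "a" because it is the most likely first token, but "b" followed by the end token is more likely overall (0.36 versus 0.24). Keeping two beams is enough to find it.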

Example 1

example_1 = "grammar: This sentences, has bads grammar and spelling!"
result_1 = happy_tt.generate_text(example_1, args=beam_settings)
print(result_1.text)

Result: This sentence has bad grammar and spelling!

Example 2

example_2 = "grammar: I am enjoys, writtings articles ons AI."

result_2 = happy_tt.generate_text(example_2, args=beam_settings)
print(result_2.text)

Result: I enjoy writing articles on AI.


Next Steps

There are some ways you can potentially improve performance. I suggest moving some of the evaluating cases into the training data and then optimizing the hyperparameters with a technique like grid search. You can then include the evaluating cases in the training set to fine-tune a final model using your best set of hyperparameters.
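Grid search simply tries every combination of the candidate values and keeps the one with the best score. Here is a minimal, library-agnostic sketch; grid_search and toy_objective are hypothetical names. The toy objective stands in for an evaluation loss so the sketch runs without a GPU, and the commented-out objective shows how it might connect to this tutorial, assuming TTTrainArgs accepts learning_rate and num_train_epochs (check the Happy Transformer documentation).

```python
import itertools
from typing import Any, Callable, Dict, List, Tuple


def grid_search(
    objective: Callable[[Dict[str, Any]], float],
    grid: Dict[str, List[Any]],
) -> Tuple[Dict[str, Any], float]:
    """Try every combination in `grid` and return the one with the lowest score."""
    keys = list(grid)
    best_params: Dict[str, Any] = {}
    best_score = float("inf")
    for values in itertools.product(*(grid[key] for key in keys)):
        params = dict(zip(keys, values))
        score = objective(params)  # e.g. evaluation loss; lower is better
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Hypothetical real objective (not run here):
# def objective(params):
#     model = HappyTextToText("T5", "t5-base")  # start from a fresh model each trial
#     model.train("train.csv", args=TTTrainArgs(batch_size=8, **params))
#     return model.eval("eval.csv").loss


def toy_objective(params: Dict[str, Any]) -> float:
    # Stand-in for an evaluation loss so the sketch runs without a GPU.
    return abs(params["learning_rate"] - 1e-4) * 1e4 + abs(params["num_train_epochs"] - 2)


grid = {"learning_rate": [1e-5, 1e-4, 1e-3], "num_train_epochs": [1, 2, 3]}
best_params, best_score = grid_search(toy_objective, grid)
print(best_params, best_score)
```

The number of trials is the product of the list lengths (9 here), and each real trial is a full fine-tuning run, so keep the grid small.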

I also suggest that you apply basic data preprocessing. Some of the cases within the dataset contain excess spaces, such as a space before a period or comma, and if these are not corrected, the model will learn to produce spaces where they are not needed. So, you can apply the code below to correct the input and output text for both your training and evaluating data.

# (pattern, replacement) pairs, applied in order
replacements = [
  (" .", "."),
  (" ,", ","),
  (" '", "'"),
  (" ?", "?"),
  (" !", "!"),
  (" :", ":"),
  (" ;", ";"),
  (" n't", "n't"),
  ("2 0 0 6", "2006"),
  ("5 5", "55"),
  ("4 0 0", "400"),
  ("1 7-5 0", "1750"),
  ("2 0 %", "20%"),
  ("5 0", "50"),
  ("1 2", "12"),
  ("1 0", "10"),
  ('" ballast water', '"ballast water')
]

def remove_excess_spaces(text):
  for rep in replacements:
    text = text.replace(rep[0], rep[1])

  return text

Now, replace the writerow() call at the bottom of the generate_csv() function with the following lines, keeping them inside the if block.

input_text = remove_excess_spaces(input_text)
correction = remove_excess_spaces(correction)
writter.writerow([input_text, correction])
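The replacement list above targets specific patterns found in this dataset. As an alternative (a swapped-in technique, not the approach used for the pretrained model below), a few regular expressions can remove spaces before any common punctuation mark and inside common contractions. normalize_spacing is a hypothetical name. Note that it does not handle the digit fixes such as "2 0 0 6", which still need the explicit list, and that it also strips leading and trailing whitespace.

```python
import re


def normalize_spacing(text: str) -> str:
    """Remove tokenization-style spaces before punctuation and in contractions."""
    # "do n't" -> "don't"
    text = re.sub(r" n't\b", "n't", text)
    # "it 's" -> "it's", "we 're" -> "we're", "I 'm" -> "I'm", etc.
    text = re.sub(r" '(s|re|ve|ll|d|m)\b", r"'\1", text)
    # "word ." -> "word.", "word ," -> "word,", and so on
    text = re.sub(r" ([.,!?;:])", r"\1", text)
    # Drop leading/trailing spaces, e.g. the trailing space after a final period
    return text.strip()


print(normalize_spacing("So I think we do n't know , really . "))
```

This prints "So I think we don't know, really." A general rule like this is shorter to maintain, but it is also blunter, so it is worth spot-checking its output on your own data.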

Finally, you can save your model and load it again later, as explained on this webpage.

Put Your Skills to the Test

You can fine-tune a grammar correction model and upload it to Hugging Face’s model distribution network to enhance your learning. In particular, I suggest you look into using a newly released dataset by Google for grammar correction called C4_200M Synthetic Dataset for Grammatical Error Correction [3]*. Then, follow this tutorial on how to upload a model to Hugging Face’s model distribution network after you’ve trained a model.

Please email me (eric@vennify.ca) if you publish a model using the techniques discussed in this tutorial. I may publish an article on how to use it.

*License: Creative Commons Attribution 4.0 International

Pretrained Model

I published a model on Hugging Face's model distribution network using the dataset and techniques covered in this tutorial. I also applied the suggestions from the Next Steps section. I've included code below that demonstrates how to use it.

happy_tt = HappyTextToText("T5", "vennify/t5-base-grammar-correction")

result = happy_tt.generate_text("grammar: I boughts ten apple.", args=beam_settings)
print(result.text)

Result: I bought ten apples.
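When correcting several sentences, it is easy to forget the "grammar: " prefix the model was trained with. A small helper can add it for you. The sketch below is hypothetical: correct_all takes any function that maps an input string to corrected text, so the commented usage wraps happy_tt.generate_text() with the beam_settings defined earlier. Keep in mind that beam_settings caps max_length at 20, which (as I understand Hugging Face's generation settings) is counted in tokens, so longer sentences may need a larger value.

```python
from typing import Callable, Iterable, List

PREFIX = "grammar: "  # must match the prefix used when building train.csv


def correct_all(sentences: Iterable[str], generate: Callable[[str], str]) -> List[str]:
    """Prefix each non-blank sentence with the task prefix and run it through `generate`."""
    results = []
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            results.append(sentence)  # keep blank entries as they are
            continue
        results.append(generate(PREFIX + sentence))
    return results


# Hypothetical usage with the model loaded above (not run here):
# corrected = correct_all(
#     ["I boughts ten apple.", "She go to school yesterday."],
#     lambda text: happy_tt.generate_text(text, args=beam_settings).text,
# )
```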

Conclusion

I hope you take what you learned and apply it to train your own model, then release it to the world by publishing it on Hugging Face's model distribution network. By doing so, you'll be helping the many people who are eager to implement quality grammar correction models. Or maybe you just scrolled to the bottom of this article to learn how to implement the pretrained model. In any case, I hope you learned something useful, and stay happy!

Resources

Support Happy Transformer by giving it a star 🌟🌟🌟

Subscribe to my YouTube Channel for an upcoming video on grammar correction.

Join Happy Transformer’s Discord community to network and ask questions to others who have read this article and are passionate about NLP

Code used in this tutorial

References

[1] C. Napoles, K. Sakaguchi, J. Tetreault, JFLEG: A Fluency Corpus and Benchmark for Grammatical Error Correction, EACL 2017

[2] C. Napoles, K. Sakaguchi, M. Post, J. Tetreault, Ground Truth for Grammatical Error Correction Metrics, IJCNLP 2015

[3] F. Stahlberg, S. Kumar, Synthetic Data Generation for Grammatical Error Correction with Tagged Corruption Models, ACL 2021