A clear introduction to testing machine learning projects using standard libraries such as Pytest and Pytest-cov.


Testing is a critical element of software development, but in my experience, it is widely neglected in machine learning projects: many people know that they should test their code, but not many know how to do so and actually test it.
This guide will introduce you to the basics of testing the different parts of a machine learning pipeline, focusing on fine-tuning BERT for text classification on the IMDb dataset and using industry-standard libraries such as pytest and pytest-cov for testing.
I highly recommend following along with the code in this GitHub repository:
Below is a brief overview of the project:
bert-text-classification/
├── src/
│ ├── data_loader.py
│ ├── evaluation.py
│ ├── main.py
│ ├── trainer.py
│ └── utils.py
├── tests/
│ ├── conftest.py
│ ├── test_data_loader.py
│ ├── test_evaluation.py
│ ├── test_main.py
│ ├── test_trainer.py
│ └── test_utils.py
├── models/
│ └── imdb_bert_finetuned.pth
├── environment.yml
├── requirements.txt
├── README.md
└── setup.py
A common approach is to split your code into multiple parts.
src: contains the main files used to load the dataset and to train and evaluate the model.
tests: contains the test scripts, usually one test file per source script. I personally use the following convention: for a script called XXX.py, the corresponding test script is called test_XXX.py and lives in the tests folder. For example, the tests for evaluation.py go in the file test_evaluation.py.
Note: the tests folder also contains a conftest.py file. It is not strictly a test script, but it holds some configuration for the tests, in particular fixtures, which we will explain in more detail later.
You can just read this article, but I learn better by doing, so I highly recommend cloning the repository and starting to play around with the code. To do so, you will need to clone the GitHub repository, create an environment, and get the model.
# clone github repo
git clone https://github.com/FrancoisPorcher/awesome-ai-tutorials/tree/main

# enter corresponding folder
cd MLOps/how_to_test/

# create environment
conda env create -f environment.yml
conda activate how_to_test
You will also need a model to run the evaluation. To reproduce the results, run the main file. Training takes 2 to 20 minutes (depending on whether you have CUDA, MPS, or only a CPU).
python src/main.py
If you don’t want to fine-tune BERT (but I highly recommend you fine-tune BERT yourself), you can take the standard version of BERT and add a linear layer to get two classes with the following command:
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
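If you go this route, you will also need the weights to be available where the project expects them (the models/imdb_bert_finetuned.pth path shown in the tree above). Here is a minimal sketch, assuming the project loads a plain PyTorch state dict from that path; check the repository's loading code before relying on it:
import torch
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)

# Assumption: the evaluation code loads a state dict from this path.
# Adjust the path or format if the repository does it differently.
torch.save(model.state_dict(), "models/imdb_bert_finetuned.pth")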
You're all set!
Let's write some tests:
But first, a quick introduction to Pytest.
pytest is a mature, industry-standard testing framework that makes it easy to write tests.
The great thing about pytest is that it lets you test at different levels of granularity: a single function, a single script, or the entire project. Let's learn how to do all three.
What does a test look like?
Tests are functions that test the behavior of other functions. By convention, to test a function called foo, you write a test function called test_foo.
Inside it, we define assertions that verify the function we are testing behaves as expected.
Let's take an example to make the idea clearer.
The data_loader.py script contains a fairly standard function, clean_text, which converts text to lowercase and strips surrounding whitespace. It is defined as follows:
def clean_text(text: str) -> str:
    """
    Clean the input text by converting it to lowercase and stripping whitespace.

    Args:
        text (str): The text to clean.

    Returns:
        str: The cleaned text.
    """
    return text.lower().strip()
To make sure this function works properly, we can write a test_clean_text function in the file test_data_loader.py:
from src.data_loader import clean_text


def test_clean_text():
    # test capital letters
    assert clean_text("HeLlo, WoRlD!") == "hello, world!"
    # test spaces removed
    assert clean_text(" Spaces ") == "spaces"
    # test empty string
    assert clean_text("") == ""
Note the use of the assert statement here: if the assertion evaluates to True, nothing happens; if it evaluates to False, an AssertionError is raised.
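As a quick standalone illustration of how assert behaves (not part of the project code):
# A true assertion does nothing
assert 1 + 1 == 2

# A false assertion raises an AssertionError with the optional message
assert 1 + 1 == 3, "arithmetic is broken"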
Now let's invoke the test by running the following command in the terminal:
pytest tests/test_data_loader.py::test_clean_text
This command tells pytest to run the tests in the test_data_loader.py script located in the tests folder, and, more specifically, to run only the test called test_clean_text.
If the test passes, you will see the following result:
What happens if the test doesn't pass?
For this example, let's imagine I change clean_text so that it no longer strips whitespace:
def clean_text(text: str) -> str:
    # return text.lower().strip()
    return text.lower()
Now the function will no longer remove the spaces and the test will fail. If you run the test again, you will see this:
This time we know why the test failed. Great!
Why bother testing a single function?
Because testing takes time. Even for a project this small, evaluating on the entire IMDb dataset can take several minutes, and sometimes you just want to check a single behavior without re-running the whole test suite every time.
Now let's move to the next level of granularity: testing scripts.
How can I test the entire script?
Now, let's make things a bit more complicated and add a tokenize_text function to the data_loader.py script. This function takes a string (or a list of strings) as input and outputs a tokenized version of it.
# src/data_loader.py
from typing import Dict

import torch
from transformers import BertTokenizer


def clean_text(text: str) -> str:
    """
    Clean the input text by converting it to lowercase and stripping whitespace.

    Args:
        text (str): The text to clean.

    Returns:
        str: The cleaned text.
    """
    return text.lower().strip()


def tokenize_text(
    text: str, tokenizer: BertTokenizer, max_length: int
) -> Dict[str, torch.Tensor]:
    """
    Tokenize a single text using the BERT tokenizer.

    Args:
        text (str): The text to tokenize.
        tokenizer (BertTokenizer): The tokenizer to use.
        max_length (int): The maximum length of the tokenized sequence.

    Returns:
        Dict[str, torch.Tensor]: A dictionary containing the tokenized data.
    """
    return tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
To help us understand a bit more what this function does, let's try it out with an example.
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
txt = ["Hello, @! World! qwefqwef"]
tokenize_text(txt, tokenizer=tokenizer, max_length=16)
This will output the following result:
{'input_ids': tensor([[ 101, 7592, 1010, 1030, 999, 2088, 999, 1053, 8545, 2546, 4160, 8545,2546, 102, 0, 0]]),
'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])}
- max_length: the maximum length of a sequence. Here we chose 16, but the sentence only needs 14 tokens, so the last two positions are padded.
- input_ids: each token is mapped to an ID from the tokenizer's vocabulary. Note that token 101 is the CLS token and token 102 is the SEP token; these two special tokens mark the beginning and end of a sentence. For more details, read the paper "Attention Is All You Need". (See also the snippet right after this list.)
- token_type_ids: not very important here. If you pass two sequences as input, the tokens of the second sentence get a value of 1.
- attention_mask: tells the model which tokens it should attend to in the self-attention mechanism. Because the sentence is padded, the attention mechanism does not need to look at the last two tokens, which therefore get a 0.
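If you want to double-check the special tokens yourself, the tokenizer can map IDs back to tokens. A small standalone snippet, assuming the same bert-base-uncased tokenizer as above:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Map the special token IDs back to their string form
print(tokenizer.convert_ids_to_tokens([101, 102]))  # ['[CLS]', '[SEP]']
print(tokenizer.cls_token_id, tokenizer.sep_token_id)  # 101 102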
So let's write a test_tokenize_text function to check that tokenize_text works properly:
def test_tokenize_text():
    """
    Test the tokenize_text function to ensure it correctly tokenizes text using BERT tokenizer.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Example input texts
    txt = ["Hello, @! World!", "Spaces "]

    # Tokenize the text
    max_length = 128
    res = tokenize_text(text=txt, tokenizer=tokenizer, max_length=max_length)

    # let's test that the output is a dictionary and that the keys are correct
    assert all(key in res for key in ["input_ids", "token_type_ids", "attention_mask"]), "Missing keys in the output dictionary."

    # let's check the dimensions of the output tensors
    assert res["input_ids"].shape[0] == len(txt), "Incorrect number of input_ids."
    assert res["input_ids"].shape[1] == max_length, "Incorrect number of tokens."

    # let's check that all the associated tensors are pytorch tensors
    assert all(isinstance(res[key], torch.Tensor) for key in res), "Not all values are PyTorch tensors."
Now let's run the whole script. The test_data_loader.py file currently contains two test functions: test_clean_text and test_tokenize_text.
You can run a full test from the terminal using this command:
pytest tests/test_data_loader.py
And you should get a result like this:
Congratulations! Now that you know how to test an entire script, let's move to the last level: testing the entire codebase.
How do I test my entire codebase?
Continuing with the same reasoning, you can write other tests per script and they should have a similar structure.
├── tests/
│ ├── conftest.py
│ ├── test_data_loader.py
│ ├── test_evaluation.py
│ ├── test_main.py
│ ├── test_trainer.py
│ └── test_utils.py
Note that some of the objects in these test functions never change. For example, the tokenizer we use is the same in every test script. Pytest has a better way to deal with this: fixtures.
Fixtures are a way to set up context or state before a test runs and clean up after it runs, providing a mechanism for managing test dependencies and injecting reusable code into tests.
A fixture is defined with the @pytest.fixture decorator.
The tokenizer is a good example of a fixture we can use. To set it up, we add the following to the conftest.py file in the tests folder:
import pytest
from transformers import BertTokenizer


@pytest.fixture()
def bert_tokenizer():
    """Fixture to initialize the BERT tokenizer."""
    return BertTokenizer.from_pretrained("bert-base-uncased")
And now, in the test_data_loader.py file, you can use the bert_tokenizer fixture simply by passing it as an argument to test_tokenize_text:
def test_tokenize_text(bert_tokenizer):
    """
    Test the tokenize_text function to ensure it correctly tokenizes text using BERT tokenizer.
    """
    tokenizer = bert_tokenizer

    # Example input texts
    txt = ["Hello, @! World!", "Spaces "]

    # Tokenize the text
    max_length = 128
    res = tokenize_text(text=txt, tokenizer=tokenizer, max_length=max_length)

    # let's test that the output is a dictionary and that the keys are correct
    assert all(key in res for key in ["input_ids", "token_type_ids", "attention_mask"]), "Missing keys in the output dictionary."

    # let's check the dimensions of the output tensors
    assert res["input_ids"].shape[0] == len(txt), "Incorrect number of input_ids."
    assert res["input_ids"].shape[1] == max_length, "Incorrect number of tokens."

    # let's check that all the associated tensors are pytorch tensors
    assert all(isinstance(res[key], torch.Tensor) for key in res), "Not all values are PyTorch tensors."
Fixtures are a very powerful and versatile tool. If you want to learn more about them, the official documentation is the go-to resource, but what we have covered so far is enough for most ML testing needs.
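One option worth knowing about is the fixture's scope: by default a fixture is recreated for every test function, but you can share a single instance across the whole test session, which avoids reloading the tokenizer over and over. A minimal sketch of what conftest.py could look like with a session-scoped fixture (the repository may simply use the default scope):
import pytest
from transformers import BertTokenizer


@pytest.fixture(scope="session")
def bert_tokenizer():
    """Load the BERT tokenizer once and reuse it for every test in the session."""
    return BertTokenizer.from_pretrained("bert-base-uncased")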
Let’s run the entire codebase by executing the following command from your terminal:
pytest tests
And then the following message appears:
Congratulations!
In the previous sections, we learned how to test our code. In large projects, it is also important to measure the coverage of your tests, that is, how much of your code is actually tested.
pytest-cov is a pytest plugin that generates a test coverage report.
That being said, don't be fooled by the coverage percentage – achieving 100% coverage doesn't mean your code is bug-free – it's just a tool to help you identify which parts of your code need further testing.
To generate a coverage report from the terminal, run the following command:
pytest --cov=src --cov-report=html tests/
And we get:
Let's see how to read it:
- statements: the total number of executable statements in your code. It counts all executable lines, including conditionals, loops, function calls, and so on.
- missing: the number of statements that were not executed during the test run, i.e. lines of code not covered by any test.
- coverage: the percentage of statements that were executed during the tests. It is calculated by dividing the number of executed statements by the total number of statements.
- excluded: lines of code explicitly excluded from coverage measurement. This is useful for ignoring code that is not relevant for test coverage, such as debug statements (see the small example after this list).
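As an example of exclusion, coverage.py recognizes a special comment that removes a line or block from the report. A small illustration with a hypothetical helper, not taken from the repository:
def debug_dump(batch):  # pragma: no cover
    """Debug helper deliberately excluded from coverage measurement."""
    print(batch)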
The coverage of main.py is at 0%; this is normal, as we have not written any real tests for it in test_main.py. We can also see that only 19% of the evaluation code is covered. The report shows you exactly which parts of the code are tested, so you know where to focus first.
Congratulations, you succeeded!
Thanks for reading! Before you leave:
For more great tutorials, check out this roundup of AI tutorials on GitHub.
Get my articles delivered to your inbox. Subscribe here.
If you want access to premium Medium articles, you need a $5/month membership. If you subscribe through my link, you support me with a portion of your fee at no extra cost to you.