Introducing Metrax: High Performance, Efficient, and Robust Model Evaluation Metrics for JAX

Machine Learning


At Google, as teams migrated from TensorFlow to JAX, they found that JAX had no built-in metrics library, so they had to reimplement by hand the metrics previously provided by TensorFlow. As a result, each team using JAX implemented its own version of precision, F1, RMS error, and so on. Creating metrics may seem like a simple and straightforward topic, but it becomes much less straightforward when you consider large-scale training and evaluation across a data-center-scale distributed computing environment.

This is how the idea for Metrax was born: a high-performance library for efficient and robust model evaluation metrics in JAX. Metrax currently provides predefined metrics for evaluating different types of machine learning models (classification, regression, recommendation, vision, audio, language), with compatibility and consistency across distributed and scaled training environments. This lets you focus on evaluating your model's results rather than (re)implementing different metric definitions. Metrax adds to the ever-evolving ecosystem of JAX-based tools and integrates well with the JAX AI Stack, a suite of tools designed to work together to power your AI tooling needs. Today, Metrax is already used in some of Google's largest software stacks, including Google Search, YouTube, and by the team behind Tunix, Google's own post-training library.

Metrax’s strengths

Of particular note is the built-in ability to compute "at K" metrics for multiple values of K in parallel, which allows for a more comprehensive and faster evaluation of model performance. For example, if you use PrecisionAtK to determine the precision of your model for multiple values of K (K=1, K=8, K=20, and so on), you don't need to call PrecisionAtK separately for each of these values; you can compute them all in one forward pass. Other "at K" metrics you can try include RecallAtK and NDCGAtK. All metrics and their definitions can be found in the documentation.
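To illustrate why computing several values of K at once is cheaper than repeated calls, here is a minimal pure-JAX sketch of the general technique. It is not Metrax's implementation: it ranks each example's items once, then reads off precision at every K from a cumulative count of relevant items. The helper name `precision_at_ks` and the input layout (scores and binary relevance labels of shape `[batch, num_items]`) are hypothetical, chosen only for illustration.

```python
import jax.numpy as jnp


def precision_at_ks(scores, labels, ks):
    """Hypothetical helper: mean precision@K for several K from a single ranking.

    scores: float array [batch, num_items]; higher means more relevant.
    labels: binary array [batch, num_items]; 1 marks a relevant item.
    ks: tuple of ints, each in [1, num_items].
    """
    # Rank items once per example, highest score first.
    order = jnp.argsort(-scores, axis=-1)
    ranked_labels = jnp.take_along_axis(labels, order, axis=-1)
    # hits[:, i] = number of relevant items among the top (i + 1) items.
    hits = jnp.cumsum(ranked_labels, axis=-1)
    ks_arr = jnp.array(ks)
    # Gather the hit counts at position K - 1 for every requested K.
    hits_at_k = hits[:, ks_arr - 1]
    # Precision@K per example, averaged over the batch: one value per K.
    return jnp.mean(hits_at_k / ks_arr, axis=0)


scores = jnp.array([[0.9, 0.1, 0.8, 0.3],
                    [0.2, 0.7, 0.6, 0.4]])
labels = jnp.array([[1, 0, 0, 1],
                    [0, 1, 1, 0]])
print(precision_at_ks(scores, labels, ks=(1, 2, 4)))  # [1.   0.75 0.5 ]
```

The expensive step (sorting) happens once regardless of how many K values you request; each extra K only adds a gather and a division.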

The last thing you want to worry about when working on a machine learning research project is whether your metrics are implemented correctly throughout your system. Therefore, having a well-tested metrics library helps the community create less error-prone code and model evaluations.

Performance

Metrax leverages some of JAX's core strengths, such as vmap and jit, to perform operations like computing multiple "at K" values in a high-performance way. Due to the nature of some metrics, not every provided metric is jit-compatible, but the goal is for all metrics to be well-written and to demonstrate best practices. In addition to classic metrics such as accuracy, precision, and recall, the library also features a robust set of NLP-related metrics such as Perplexity, BLEU, and ROUGE, as well as metrics for vision models such as Intersection over Union (IoU), Signal-to-Noise Ratio (SNR), and Structural Similarity Index (SSIM). You no longer need to vibe-code your metrics implementation. Just use Metrax.
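As a rough illustration of how vmap and jit combine for metric computation, the sketch below evaluates binary precision at several decision thresholds in one compiled call. This is not Metrax code; it assumes precision is defined as TP / (TP + FP) with a `>=` threshold comparison, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def precision_at_threshold(predictions, labels, threshold):
    """Binary precision for one decision threshold: TP / (TP + FP)."""
    predicted_positive = predictions >= threshold
    true_positives = jnp.sum(predicted_positive & (labels == 1))
    false_positives = jnp.sum(predicted_positive & (labels == 0))
    denom = true_positives + false_positives
    # Return 0.0 instead of dividing by zero when nothing is predicted positive.
    return jnp.where(denom > 0, true_positives / jnp.maximum(denom, 1), 0.0)


# vmap maps over the threshold argument only (predictions and labels are shared);
# jit then compiles the batched function into a single XLA computation.
precision_at_thresholds = jax.jit(
    jax.vmap(precision_at_threshold, in_axes=(None, None, 0))
)

predictions = jnp.array([0.9, 0.2, 0.7, 0.4, 0.8])
labels = jnp.array([1, 0, 0, 1, 1])
thresholds = jnp.array([0.3, 0.5, 0.75])
print(precision_at_thresholds(predictions, labels, thresholds))  # [0.75 0.6666667 1.]
```

Writing the metric for a single threshold and letting vmap add the batch dimension keeps the core logic simple while still producing one vectorized, compiled computation.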

Metrax in action

Let's look at how to use Metrax in code. Computing the precision metric from model outputs looks like the example below. You pass the predictions and labels, along with a threshold, and then call compute() to get the value of the metric.

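import jax.numpy as jnp
import metrax

# Example model outputs (scores in [0, 1]) and binary ground-truth labels.
# These values are illustrative only.
predictions = jnp.array([0.9, 0.2, 0.7, 0.4, 0.8])
labels = jnp.array([1, 0, 0, 1, 1])

# Directly compute the metric state.
metric_state = metrax.Precision.from_model_output(
    predictions=predictions,
    labels=labels,
    threshold=0.5
)

# The result is then readily available by calling compute().
result = metric_state.compute()
result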


Evaluations are often performed in batches, so you want to be able to iteratively add more information to your metrics. Metrax supports this workflow with a method called merge(). It is a great function to use within your evaluation loop when aggregating metrics during training runs. You still call compute() once you are ready to get the final value.

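import jax.numpy as jnp
import metrax

# Example data already split into two batches (illustrative values).
predictions_batched = jnp.array([[0.9, 0.2, 0.7, 0.4], [0.8, 0.1, 0.6, 0.3]])
labels_batched = jnp.array([[1, 0, 0, 1], [1, 0, 1, 0]])

# Iteratively merging precision metrics across batches.
metric_state = None
for labels_b, predictions_b in zip(labels_batched, predictions_batched):
    batch_metric_state = metrax.Precision.from_model_output(
        predictions=predictions_b,
        labels=labels_b,
        threshold=0.5
    )
    # The first batch initializes the state; later batches are merged into it.
    metric_state = batch_metric_state if metric_state is None else metric_state.merge(batch_metric_state)

result = metric_state.compute()
result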

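Why merge metric states instead of averaging per-batch results? Precision is a ratio, so the correct approach is to accumulate its underlying counts and divide only at the end. The toy class below is a hypothetical stand-in written in plain Python plus jax.numpy. It mirrors the from_model_output / merge / compute flow described above, but it is not Metrax's implementation, and it assumes a `>=` threshold comparison.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class PrecisionState:
    """Hypothetical mergeable state: precision is derived from summed counts."""
    true_positives: float
    false_positives: float

    @classmethod
    def from_model_output(cls, predictions, labels, threshold=0.5):
        predicted_positive = predictions >= threshold
        return cls(
            true_positives=float(jnp.sum(predicted_positive & (labels == 1))),
            false_positives=float(jnp.sum(predicted_positive & (labels == 0))),
        )

    def merge(self, other):
        # Merging adds the counts; no per-batch precision is averaged.
        return PrecisionState(
            self.true_positives + other.true_positives,
            self.false_positives + other.false_positives,
        )

    def compute(self):
        denom = self.true_positives + self.false_positives
        return self.true_positives / denom if denom > 0 else 0.0


batch_1 = PrecisionState.from_model_output(jnp.array([0.9, 0.8, 0.7]), jnp.array([1, 1, 0]))
batch_2 = PrecisionState.from_model_output(jnp.array([0.6]), jnp.array([1]))

merged = batch_1.merge(batch_2).compute()             # (2 + 1) / (3 + 1) = 0.75
naive = (batch_1.compute() + batch_2.compute()) / 2   # (2/3 + 1) / 2, about 0.833
print(merged, naive)
```

The two numbers differ because the batches have different sizes; only the merged counts give the precision of the full dataset.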

See this notebook for a complete example. This notebook shows other ways you can use Metrax, including scaling to multiple devices and integrating with Flax NNX, a modeling library that abstracts some of the implementation details of building AI models.
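The notebook shows multi-device usage with Metrax itself. As a hedged sketch of the underlying idea only (not the notebook's code), each device below computes true- and false-positive counts on its own shard of the data, and `jax.lax.psum` sums those counts across devices before the final division. It uses `jax.pmap` for brevity; newer JAX code often uses sharding or `shard_map` instead. The data, the 0.5 threshold, and the axis name `"devices"` are illustrative assumptions, and the example also runs on a single CPU device.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()


def local_counts(predictions, labels):
    """Count TP and FP on this device's shard, then sum counts across devices."""
    predicted_positive = predictions >= 0.5
    tp = jnp.sum(predicted_positive & (labels == 1))
    fp = jnp.sum(predicted_positive & (labels == 0))
    # psum adds the per-device counts, so every device ends up with the global totals.
    return jax.lax.psum(tp, axis_name="devices"), jax.lax.psum(fp, axis_name="devices")


global_counts = jax.pmap(local_counts, axis_name="devices")

# Example data with a leading axis of size n_devices (one shard per device).
predictions = jnp.tile(jnp.array([0.9, 0.2, 0.7, 0.8]), (n_devices, 1))
labels = jnp.tile(jnp.array([1, 0, 0, 1]), (n_devices, 1))

tp, fp = global_counts(predictions, labels)
# Every device holds the same totals, so read them from device 0.
precision = tp[0] / (tp[0] + fp[0])
print(precision)
```

Summing counts rather than per-device precision values follows the same reasoning as merging batch states: the ratio is only computed once the global totals are known.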

Contribute

Metrax is developed on GitHub and happily accepts contributions from the community. Some of the metrics currently available were actually added by community contributors. A big shout out to GitHub users @nikolasavic3 and @Mrigankkh for their efforts. So if you have more metrics you’d like to add, please submit a pull request and work with our development team to include it in Metrax. For more information, visit github.com/google/metrax.

Also, be sure to check out the other libraries in the JAX ecosystem (jaxstack.ai). There you can find libraries that integrate well with Metrax, as well as additional content about building machine learning models.