Federated Learning without Refactoring Overhead with NVIDIA FLARE



Federated learning (FL) is no longer a research curiosity. It is a practical response to a hard constraint: the most valuable data usually cannot be moved. Regulatory boundaries, data sovereignty rules, and organizational risk tolerance routinely prevent centralized aggregation. At the same time, sheer data gravity makes even permitted transfers slow, expensive, and vulnerable at scale.

The latest version of NVIDIA FLARE addresses this reality with a federated compute runtime that moves training logic to the data while leaving the raw data in place. In high-stakes environments, centrally aggregating data is often neither possible nor practical, so a modern federated platform must treat data separation, compliance, and privacy-enhancing technologies as first-class requirements.

It’s not the concept of FL that has historically slowed adoption, but the developer experience. Many projects die after the pilot if the path from “train a local script” to “run jobs across federated sites” requires deep refactoring, new class hierarchies, or brittle configuration.

The evolution of the FLARE API targets exactly this problem: eliminate refactoring overhead by splitting the work into two concrete steps that correspond to how teams actually build and ship ML systems.

  • Step 1 (Client API): Convert your existing local training script to a federated client in 5-6 lines of code without changing your training loop structure.
  • Step 2 (Job Recipe): Select an FL workflow and bind it to your client training script, swapping only the execution environment to run the same job across simulation, PoC, and production environments.

“No data copy” as a system requirement

In regulated and sensitive environments, centralizing datasets is increasingly off the table. A practical federated computing platform must support:

  • No data copy: Data stays local; only model updates (or equivalent signals) move.
  • Compliance posture: Deployment and governance controls that support sovereignty and audit requirements.
  • Privacy-enhancing technologies: Multiple layers of defense, such as homomorphic encryption, differential privacy, and confidential computing (a toy sketch of one such layer follows below).
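
To make “multiple layers of defense” concrete, here is a toy sketch of one layer: clipping and noising the per-site weight delta before it leaves the site. This hand-rolls the differential-privacy idea for illustration only; FLARE provides its own privacy filters, and the clip_norm and sigma values here are arbitrary placeholders, not calibrated privacy parameters.

import torch

def privatize_update(local_state, global_state, clip_norm=1.0, sigma=0.01):
   # Clip and noise the site-local weight delta (toy illustration only;
   # not FLARE's built-in DP, and the constants are not calibrated).
   noisy = {}
   for name, w in local_state.items():
       if not w.is_floating_point():  # skip integer buffers (e.g., BatchNorm counters)
           noisy[name] = w
           continue
       delta = w - global_state[name]            # the site-local update
       norm = torch.norm(delta)
       if norm > clip_norm:                      # bound each update's influence
           delta = delta * (clip_norm / norm)
       delta = delta + sigma * torch.randn_like(delta)  # add Gaussian noise
       noisy[name] = global_state[name] + delta
   return noisy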

The refactoring cliff: why FL projects stall

Teams typically hit one of two cliffs after the pilot.

  • Code cliff: Converting a working PyTorch/TensorFlow/Lightning training script to FL may require invasive restructuring: new abstractions, messaging glue, and framework-specific scaffolding.
  • Lifecycle cliff: Even once simulation works, moving to PoC and production triggers job redefinition, reconfiguration, and rewrites with environment-specific branches.

FLARE flattens both cliffs by standardizing the workflow into two steps.

  1. Adapt the training script (Client API)
  2. Execute it as a portable job (Job Recipe)

The two steps are designed to combine cleanly, so you can move quickly from zero to an operational federated job.

Step 1: Convert local training scripts to federated clients (client API)

Who it’s for: Practitioners and ML engineers who want to reuse existing training code with minimal changes.

The mental model is intentionally simple.

  1. Initialize the client runtime
  2. Loop while the job is running
  3. Receive the current global model
  4. Train locally with your existing code
  5. Send back updated weights and metrics

FLARE’s client API is designed to minimize code changes and avoids forcing heavy “doer/learner” inheritance. You communicate with the runtime through FLModel structures or simple data exchange.
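
Distilled to a skeleton, the loop looks like this. Here, model and train_one_round() are placeholders for your own model object and existing training code; the flare calls are the same ones used in the full example below.

# Skeleton of the Client API handshake; `model` and `train_one_round()`
# stand in for your own model object and existing training loop.
import nvflare.client as flare

flare.init()                                    # 1. initialize the runtime
while flare.is_running():                       # 2. loop while the job runs
   input_model = flare.receive()                # 3. receive the global model
   model.load_state_dict(input_model.params)    # 4. load it, then train locally
   train_one_round(model)
   flare.send(flare.FLModel(params=model.state_dict()))  # 5. send the update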

Example 1a: Converting PyTorch to FLARE

Below is a concrete pattern that applies to many scripts. The key touchpoints are flare.init(), flare.receive() to load the global model weights, and flare.send() to return the updated weights and metrics.

The local training script (train.py) is shown first, followed by the federated version (client.py). The only additions are the import, flare.init(), and the receive()/send() loop.

train.py

# train.py

import torch
import torchvision
import torchvision.transforms as transforms

from model import Net

batch_size = 4
epochs = 1
lr = 0.01
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
transform = transforms.Compose(
   [
       transforms.ToTensor(),
       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
   ]
)

train_dataset = torchvision.datasets.CIFAR10(
   root="/tmp/data/cifar10", transform=transform, download=True, train=True
)

trainloader = torch.utils.data.DataLoader(
   train_dataset, batch_size=batch_size, shuffle=True
)

model.to(device)

for epoch in range(epochs):
   running_loss = 0.0

   for i, batch in enumerate(trainloader):
       images, labels = batch[0].to(device), batch[1].to(device)

       optimizer.zero_grad()

       predictions = model(images)
       cost = loss(predictions, labels)
       cost.backward()
       optimizer.step()

       running_loss += cost.cpu().detach().numpy() / batch_size

       if i % 3000 == 2999:
           print(
               f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / 3000}"
           )
           running_loss = 0.0

   print(
       f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / (i + 1)}"
   )

print("Finished Training")

torch.save(model.state_dict(), "./cifar_net.pth")

client.py

# client.py

# 1. Import client API
import nvflare.client as flare
import torch
import torchvision
import torchvision.transforms as transforms

from model import Net

batch_size = 4
epochs = 1
lr = 0.01
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
transform = transforms.Compose(
   [
       transforms.ToTensor(),
       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
   ]
)

train_dataset = torchvision.datasets.CIFAR10(
   root="/tmp/data/cifar10", transform=transform, download=True, train=True
)

trainloader = torch.utils.data.DataLoader(
   train_dataset, batch_size=batch_size, shuffle=True
)

# 2. Initialize FLARE
flare.init()

# At each round while FLARE is running
while flare.is_running():
   # 3. Receive the global model
   input_model = flare.receive()

   # 4. Load global model
   model.load_state_dict(input_model.params)
   model.to(device)

   for epoch in range(epochs):
       running_loss = 0.0

       for i, batch in enumerate(trainloader):
           images, labels = batch[0].to(device), batch[1].to(device)

           optimizer.zero_grad()

           predictions = model(images)
           cost = loss(predictions, labels)
           cost.backward()
           optimizer.step()

           running_loss += cost.cpu().detach().numpy() / batch_size

           if i % 3000 == 2999:
               print(
                   f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / 3000}"
               )
               running_loss = 0.0

       print(
           f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / (i + 1)}"
       )

   print("Finished Training")

   torch.save(model.state_dict(), "./cifar_net.pth")

   # 5. Send back the updated model
   output_model = flare.FLModel(
       params=model.cpu().state_dict(),
       meta={"NUM_STEPS_CURRENT_ROUND": len(trainloader) * epochs},
   )
   flare.send(output_model)
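
Step 5 of the mental model mentions metrics as well as weights. If your script also evaluates the model locally, the reply can carry those results too. A small hedged sketch, assuming the metrics field on FLModel (present in recent NVFlare releases) and an evaluate() helper of your own:

# Hedged variant of step 5: also report local evaluation results.
# evaluate() is a placeholder for your own test loop; the metrics
# field on FLModel is available in recent NVFlare releases.
accuracy = evaluate(model, testloader)

output_model = flare.FLModel(
   params=model.cpu().state_dict(),
   metrics={"accuracy": accuracy},
   meta={"NUM_STEPS_CURRENT_ROUND": len(trainloader) * epochs},
)
flare.send(output_model)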

Example 1b: Converting PyTorch Lightning to FLARE

Lightning integration keeps the same intent (receive the global model, train, and send back updates) but exposes it in a way that fits Lightning: you import the Lightning client adapter and patch the Trainer.

The typical flow is import, patch, (optional) validation, and train as usual.

# lightning_client.py
import pytorch_lightning as pl
from pytorch_lightning import Trainer

import nvflare.client.lightning as flare  # Lightning Client API

from model import LitNet
from data import CIFAR10DataModule

def main():
   model = LitNet()
   dm = CIFAR10DataModule()

   trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1)

   # Patch trainer to participate in FL
   flare.patch(trainer)

   while flare.is_running():
       # Optional: validate current global model (useful for server-side selection flows)
       trainer.validate(model, datamodule=dm)

       # Train starting from received global model (handled internally after patch)
       trainer.fit(model, datamodule=dm)


if __name__ == "__main__":
   main()

Takeaway: Lightning users don’t need to drop into custom federated messaging. You can participate correctly in FL rounds while keeping the Trainer abstraction.

Step 2: Package the federation job and run it anywhere (job recipe)

Who it’s for: Data scientists and application teams that need code-first job definitions that stay stable across environments.

After completing step 1, you have a complete federated client script. Step 2 turns it into a federated job that can run repeatedly and move cleanly through its lifecycle.

Job recipes are designed to replace JSON-based job configurations with Python-based job definitions.

  • Code first: Define complete FL jobs in Python instead of complex configuration files
  • Write once, run anywhere: The same recipe is executed in simulator, PoC or production environment
  • Speed of implementation: Move from experiment to deployment without changing code structure

Example 2a: Run a FedAvg recipe in simulation

The important connection is that the recipe references the client training script you created in step 1 (e.g., train_script="client.py") and runs it within the chosen environment.

# job.py
from nvflare.app_common.workflows.job import FedAvgRecipe  # exact import path can vary by NVFlare version
from nvflare.job_config import SimEnv  # exact import path can vary by NVFlare version

from model import SimpleNetwork

def main():
   n_clients = 3
   num_rounds = 5
   batch_size = 32

   recipe = FedAvgRecipe(
       name="hello-pt",
       min_clients=n_clients,
       num_rounds=num_rounds,
       model=SimpleNetwork(),
       train_script="client.py",  # <-- the client script from step 1
       train_args=f"--batch_size {batch_size} --epochs 1",
   )

   env = SimEnv(num_clients=n_clients, num_threads=n_clients)
   recipe.execute(env=env)

if __name__ == "__main__":
   main()

This is the “write once” concept in practice. Once the recipe correctly references the client script, the rest is a matter of where you execute it.

Example 2b: Swap environments to move from simulation to the real world

Job recipes formalize this incremental workflow by swapping execution environments, as sketched after the list below.

  1. SimEnv (simulation): Fast development and quick debugging
  2. PocEnv (proof of concept): Local multi-process runtime for realistic testing
  3. ProdEnv (production): Distributed deployment on secure, scalable infrastructure
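
In code, promotion is a one-line change to the environment object while the recipe stays fixed. A minimal sketch, assuming PocEnv and ProdEnv live alongside SimEnv; the import path, constructor arguments, and startup-kit path are all version-dependent assumptions:

# Hedged sketch: `recipe` is the FedAvgRecipe from Example 2a; only the
# environment object changes. Import paths and constructor arguments
# vary by NVFlare version, and the startup-kit path is a placeholder.
from nvflare.job_config import SimEnv, PocEnv, ProdEnv  # assumed paths

env = SimEnv(num_clients=3, num_threads=3)       # development
# env = PocEnv(num_clients=3)                    # local multi-process PoC
# env = ProdEnv(startup_kit_location="/path/to/admin/startup")  # production

recipe.execute(env=env)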

Getting started

  • Start with scripts you already trust.
  • Step 1: Add the client API handshake (or patch the Lightning Trainer).
  • Step 2: Wrap the script in a job recipe and run it in simulation first, then PoC, then production, swapping only the environment.

FLARE in the news

FLARE is being leveraged in real-world deployments, from Eli Lilly TuneLab’s Federated Learning Platform (built using NVFlare by Rhino Federated Computing), to Taiwan MOHW’s National Healthcare Federated Learning Initiative, to Tri-labs’ (Sandia/LANL/LLNL) federated AI pilot across sensitive datasets.

Go further

Start with scripts you already trust. Add a minimal FLARE client handshake (receive → train → send). When you’re ready, scale from a single-node simulation to a multisite deployment.


