
Step by Step Guide to Build and Compare FedAvg and FedProx Federated Learning on Non-IID CIFAR-10 with NVIDIA FLARE

In this tutorial, we build an advanced federated learning experiment with NVIDIA FLARE. We compare FedAvg and FedProx on a non-IID CIFAR-10 setup, where client data is split using a Dirichlet distribution to simulate realistic label imbalance across federated sites. We use the NVFlare Job API to define and launch federated jobs, while the Client API handles local training, model exchange, and communication between simulated clients and the server. Finally, we run both algorithms on the same partitioned dataset and visualize how their global model accuracy evolves across communication rounds.

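!pip install -q "nvflare>=2.5" torch torchvision matplotlib
import os, glob, shutil, sys
import numpy as np
import torch, torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt

# Experiment settings shared by both runs
NUM_SITES    = 3      # number of simulated clients
NUM_ROUNDS   = 5      # federated communication rounds
LOCAL_EPOCHS = 1      # local epochs per round
ALPHA        = 0.3    # Dirichlet concentration; smaller values give stronger label skew
MAX_SAMPLES  = 4000   # cap on training samples per client
BATCH_SIZE   = 64
LR           = 0.01
DATA_ROOT    = "/tmp/nvflare/data"
RESULTS_DIR  = "/tmp/nvflare/results"
os.makedirs(DATA_ROOT, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Download CIFAR-10 once up front so client processes can load it with download=False
torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True,  download=True)
torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True)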

We install the required libraries and import the main packages needed for the federated learning experiment. We define the experiment settings, including the number of clients, training rounds, batch size, learning rate, and the non-IID Dirichlet alpha value. We also create the required data and results folders, then download CIFAR-10 once so that all simulated clients can reuse the same dataset safely.

CLIENT_SCRIPT = r'''
import argparse, os, csv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as T
import nvflare.client as flare

class Net(nn.Module):
    """Small CNN for CIFAR-10 (no batchnorm -> clean state_dict for FedAvg)."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool  = nn.MaxPool2d(2, 2)
        # Two 2x2 poolings shrink the 32x32 input to 8x8
        self.fc1   = nn.Linear(64 * 8 * 8, 128)
        self.fc2   = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def dirichlet_partition(labels, num_sites, alpha, seed=42):
    """Deterministic non-IID label-skew partition. All client processes use the
    same seed, so they independently agree on the same global split."""
    rng = np.random.default_rng(seed)
    num_classes = int(labels.max()) + 1
    site_idx = [[] for _ in range(num_sites)]
    for c in range(num_classes):
        idx_c = np.where(labels == c)[0]
        rng.shuffle(idx_c)
        # Per-class proportions over sites, then cut points into the shuffled indices
        props = rng.dirichlet([alpha] * num_sites)
        cuts  = (np.cumsum(props) * len(idx_c)).astype(int)[:-1]
        for s, part in enumerate(np.split(idx_c, cuts)):
            site_idx[s].extend(part.tolist())
    return [np.array(s) for s in site_idx]

@torch.no_grad()
def evaluate(model, loader, device):
    """Return top-1 accuracy of `model` over `loader`."""
    model.eval()
    correct = total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        correct += (pred == y).sum().item()
        total   += y.size(0)
    return correct / total
'''

We create the client-side training script as a separate Python file so NVFlare can import and run it inside each simulated client process. We define a small CNN model for CIFAR-10 classification and add a deterministic Dirichlet partitioning function to create non-IID client datasets. We also define an evaluation function that measures the global model's accuracy on the shared CIFAR-10 test set.
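To get a feel for what the Dirichlet alpha controls, the following minimal sketch applies the same partitioning logic to synthetic labels and prints a per-site class histogram for a few alpha values. The synthetic label array, the helper `class_histogram`, and the chosen alpha values are illustrative assumptions, not part of the tutorial's pipeline.

```python
import numpy as np

def dirichlet_partition(labels, num_sites, alpha, seed=42):
    """Same per-class Dirichlet split as the client script, on any label array."""
    rng = np.random.default_rng(seed)
    num_classes = int(labels.max()) + 1
    site_idx = [[] for _ in range(num_sites)]
    for c in range(num_classes):
        idx_c = np.where(labels == c)[0]
        rng.shuffle(idx_c)
        props = rng.dirichlet([alpha] * num_sites)
        cuts = (np.cumsum(props) * len(idx_c)).astype(int)[:-1]
        for s, part in enumerate(np.split(idx_c, cuts)):
            site_idx[s].extend(part.tolist())
    # dtype=int keeps empty shards usable as index arrays
    return [np.array(s, dtype=int) for s in site_idx]

def class_histogram(labels, parts, num_classes):
    """Hypothetical helper: rows are sites, columns are class counts."""
    return np.stack([np.bincount(labels[p], minlength=num_classes) for p in parts])

# Synthetic stand-in for CIFAR-10 targets: 10 classes, 500 samples each
labels = np.repeat(np.arange(10), 500)

for alpha in (0.1, 0.3, 100.0):
    parts = dirichlet_partition(labels, num_sites=3, alpha=alpha)
    print(f"alpha={alpha}")
    print(class_histogram(labels, parts, num_classes=10))
```

Small alpha values tend to concentrate each class on one or two sites, while large values approach an even split, which is why the tutorial's alpha of 0.3 produces noticeably skewed shards.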

CLIENT_SCRIPT += r'''
def main():
    p = argparse.ArgumentParser()
    p.add_argument("--num_sites", type=int, default=3)
    p.add_argument("--alpha", type=float, default=0.3)
    p.add_argument("--local_epochs", type=int, default=1)
    p.add_argument("--mu", type=float, default=0.0)
    p.add_argument("--max_samples", type=int, default=4000)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--lr", type=float, default=0.01)
    p.add_argument("--data_root", type=str, default="/tmp/nvflare/data")
    p.add_argument("--results_dir", type=str, default="/tmp/nvflare/results")
    p.add_argument("--tag", type=str, default="fedavg")
    args = p.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tf = T.Compose([T.ToTensor(),
                    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_set = torchvision.datasets.CIFAR10(args.data_root, train=True,  download=False, transform=tf)
    test_set  = torchvision.datasets.CIFAR10(args.data_root, train=False, download=False, transform=tf)

    flare.init()
    site_name = flare.get_site_name()
    # Site names look like "site-1"; convert to a zero-based shard index
    site_id   = int(site_name.split("-")[-1]) - 1
    labels   = np.array(train_set.targets)
    my_idx   = dirichlet_partition(labels, args.num_sites, args.alpha)[site_id]
    # Note: indices are ordered class by class, so this cap keeps the
    # lowest-numbered classes first
    if len(my_idx) > args.max_samples:
        my_idx = my_idx[:args.max_samples]
    train_loader = DataLoader(Subset(train_set, my_idx), batch_size=args.batch_size, shuffle=True)
    test_loader  = DataLoader(test_set, batch_size=512, shuffle=False)
    print(f"[{site_name}] mu={args.mu}  local samples={len(my_idx)}", flush=True)

    model     = Net().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    while flare.is_running():
        # Receive the current global model from the server
        input_model = flare.receive()
        rnd = input_model.current_round
        global_state = {k: torch.as_tensor(v) for k, v in input_model.params.items()}
        model.load_state_dict(global_state)

        # Evaluate the received global model before any local training
        acc = evaluate(model, test_loader, device)
        if site_id == 0:
            with open(os.path.join(args.results_dir, args.tag + ".csv"), "a", newline="") as f:
                csv.writer(f).writerow([rnd, acc])
        print(f"[{site_name}] round {rnd}: global test acc = {acc:.4f}", flush=True)

        # Frozen copy of the global weights for the FedProx proximal term
        global_w = [w.detach().clone() for w in model.parameters()]
        model.train()
        steps = 0
        for _ in range(args.local_epochs):
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                loss = criterion(model(x), y)
                if args.mu > 0:
                    # FedProx: penalize drift from the global weights; mu=0 is plain FedAvg
                    prox = sum(((w - g) ** 2).sum() for w, g in zip(model.parameters(), global_w))
                    loss = loss + (args.mu / 2.0) * prox
                loss.backward()
                optimizer.step()
                steps += 1

        # Send the locally updated weights back for aggregation
        out = flare.FLModel(
            params={k: v.cpu().numpy() for k, v in model.state_dict().items()},
            metrics={"test_accuracy": acc},
            meta={"NUM_STEPS_CURRENT_ROUND": steps},
        )
        flare.send(out)

if __name__ == "__main__":
    main()
'''
with open("client_train.py", "w") as f:
    f.write(CLIENT_SCRIPT)
# Make the generated script importable so the server side can reuse Net
sys.path.insert(0, os.getcwd())
from client_train import Net

We build the main client training workflow that runs inside every NVFlare site. We initialize the NVFlare Client API, identify the current client site, load that client's non-IID data shard, and prepare the local model, optimizer, and loss function. We then receive global weights from the server, evaluate the model, train locally with either FedAvg or FedProx, and send the updated model back for aggregation.
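The only difference between the two algorithms in this script is the proximal penalty added to the local loss. As a sketch in our own notation (the symbols $F_k$ for client $k$'s local loss and $w^t$ for the global weights received in round $t$ are assumptions made here for illustration), the local objective and its gradient can be written as:

```latex
\min_{w}\; h_k(w; w^t) = F_k(w) + \frac{\mu}{2}\,\lVert w - w^t \rVert^2,
\qquad
\nabla_w h_k(w; w^t) = \nabla F_k(w) + \mu\,(w - w^t)
```

In the script, `criterion(model(x), y)` is a mini-batch estimate of $F_k(w)$, `global_w` holds $w^t$, and `args.mu` is $\mu$. Setting $\mu = 0$ removes the penalty, so the same code path reduces to the plain FedAvg local update.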

from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob
from nvflare.job_config.script_runner import ScriptRunner

def run_experiment(tag, mu):
    # Start from a clean CSV log and workspace for this run
    csv_path = os.path.join(RESULTS_DIR, tag + ".csv")
    if os.path.exists(csv_path):
        os.remove(csv_path)
    workspace = f"/tmp/nvflare/{tag}_ws"
    shutil.rmtree(workspace, ignore_errors=True)

    job = FedAvgJob(
        name=f"cifar10_{tag}",
        n_clients=NUM_SITES,
        num_rounds=NUM_ROUNDS,
        initial_model=Net(),
    )
    args = (f"--num_sites {NUM_SITES} --alpha {ALPHA} --local_epochs {LOCAL_EPOCHS} "
            f"--mu {mu} --max_samples {MAX_SAMPLES} --batch_size {BATCH_SIZE} --lr {LR} "
            f"--data_root {DATA_ROOT} --results_dir {RESULTS_DIR} --tag {tag}")
    # Attach the same client script to every simulated site
    for i in range(NUM_SITES):
        job.to(ScriptRunner(script="client_train.py", script_args=args),
               target=f"site-{i+1}")

    gpu = "0" if torch.cuda.is_available() else None
    print(f"\n===== Running {tag} (mu={mu}) on {'GPU' if gpu else 'CPU'} =====")
    job.simulator_run(workspace, gpu=gpu)
    return workspace

run_experiment("fedavg",  mu=0.0)
run_experiment("fedprox", mu=0.1)

We define the server-side federated experiment using the NVFlare Job API. We create a FedAvg job, attach the client training script to each simulated site, and pass all experimental arguments, such as alpha, learning rate, local epochs, and the FedProx mu value. We then run two experiments on the same non-IID setup: one with standard FedAvg and another with FedProx.
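Both runs use the same server-side aggregation, so it helps to see what a weighted average of client updates looks like. The sketch below is a standalone NumPy illustration of the idea, not NVFlare's internal implementation: the function `weighted_average`, the toy parameter names, and the use of per-client step counts as weights are assumptions made for this example (the client script does report its step count under `NUM_STEPS_CURRENT_ROUND`).

```python
import numpy as np

def weighted_average(client_params, client_weights):
    """Average per-client parameter dicts, weighting each client by a positive number.

    Hypothetical helper for illustration; it does not call any NVFlare API.
    """
    total = float(sum(client_weights))
    if total <= 0:
        raise ValueError("client weights must sum to a positive value")
    keys = client_params[0].keys()
    return {
        k: sum(w * p[k] for p, w in zip(client_params, client_weights)) / total
        for k in keys
    }

# Toy updates from two sites (made-up values, not real model weights)
site_updates = [
    {"fc.weight": np.array([1.0, 2.0]), "fc.bias": np.array([0.0])},
    {"fc.weight": np.array([3.0, 4.0]), "fc.bias": np.array([1.0])},
]
steps = [10, 30]  # e.g. local optimizer steps taken by each site this round

global_params = weighted_average(site_updates, steps)
print(global_params)
```

A site that performed more local steps pulls the average further toward its own weights, which is one reason heavily skewed shards can make the plain average drift.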

def load_curve(tag):
    """Read (round, accuracy) rows logged by site-1 and return them sorted by round."""
    rounds, accs = [], []
    with open(os.path.join(RESULTS_DIR, tag + ".csv")) as f:
        for line in f:
            r, a = line.strip().split(",")
            rounds.append(int(r)); accs.append(float(a))
    order = np.argsort(rounds)
    return np.array(rounds)[order], np.array(accs)[order]

plt.figure(figsize=(7, 4.5))
for tag, label in [("fedavg", "FedAvg"), ("fedprox", "FedProx (mu=0.1)")]:
    r, a = load_curve(tag)
    plt.plot(r, a, marker="o", label=label)
plt.title(f"Global model accuracy on non-IID CIFAR-10 (Dirichlet alpha={ALPHA}, {NUM_SITES} clients)")
plt.xlabel("Federated round"); plt.ylabel("Global test accuracy")
plt.grid(alpha=0.3); plt.legend(); plt.tight_layout(); plt.show()

# Locate the final aggregated checkpoint saved in each workspace, if present
for tag in ["fedavg", "fedprox"]:
    try:
        ckpt = glob.glob(f"/tmp/nvflare/{tag}_ws/**/FL_global_model.pt", recursive=True)[0]
        obj = torch.load(ckpt, map_location="cpu", weights_only=False)
        state = obj.get("model", obj) if isinstance(obj, dict) else obj
        print(f"[{tag}] final global model checkpoint: {ckpt}")
    except Exception as e:
        print(f"[{tag}] could not locate final checkpoint ({e})")

We load the saved accuracy logs for both FedAvg and FedProx and sort them by communication round. We plot the global test accuracy curves to visually compare how the two methods perform on non-IID CIFAR-10 data. Because each client evaluates the global model it receives before training, the point logged for a given round reflects the aggregate of the preceding rounds, and the model aggregated after the last round does not appear on the plot. We also try to locate the final aggregated global model checkpoint from each NVFlare workspace for optional inspection or reuse.
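Alongside the plot, a short numeric summary of each curve can make the comparison easier to report. The following sketch parses CSV text in the same `round,accuracy` layout the client script writes and returns the final and best accuracy. The helper name `summarize_curve` and the sample rows are hypothetical placeholders, not results from this experiment.

```python
import csv
import io

def summarize_curve(csv_text):
    """Return final and best accuracy from 'round,accuracy' CSV text.

    Hypothetical helper for illustration; rows may arrive in any order.
    """
    rows = []
    for row in csv.reader(io.StringIO(csv_text)):
        if not row:  # skip blank lines
            continue
        rows.append((int(row[0]), float(row[1])))
    if not rows:
        raise ValueError("no rows found in CSV text")
    rows.sort(key=lambda ra: ra[0])
    final_round, final_acc = rows[-1]
    best_round, best_acc = max(rows, key=lambda ra: ra[1])
    return {
        "final_round": final_round,
        "final_acc": final_acc,
        "best_round": best_round,
        "best_acc": best_acc,
    }

# Placeholder rows only, deliberately out of order
sample = "1,0.25\n0,0.10\n\n2,0.22\n"
print(summarize_curve(sample))
```

To apply it to a real run, one could pass the contents of a results file, for example `summarize_curve(open(os.path.join(RESULTS_DIR, "fedavg.csv")).read())`, assuming the file exists.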

In conclusion, we completed an end-to-end federated learning workflow using NVIDIA FLARE, from preparing non-IID CIFAR-10 client shards to training and evaluating FedAvg and FedProx. We saw how the Job API helps us configure server-side orchestration, while the Client API lets each simulated client receive global weights, train locally, and send updated models back for aggregation. We also tracked global test accuracy across rounds and plotted learning curves to understand how FedProx differs from standard FedAvg under heterogeneous data conditions.


The post Step by Step Guide to Build and Compare FedAvg and FedProx Federated Learning on Non-IID CIFAR-10 with NVIDIA FLARE appeared first on MarkTechPost.
