Federated Learning with Flower and PyTorch: A Step-by-Step Tutorial
This article provides a comprehensive guide to federated learning using the Flower framework and PyTorch. It will walk you through the process of building a federated learning system, explaining each step in detail.
Introduction to Federated Learning
Federated learning is a machine learning approach that enables training models on decentralized data residing on multiple devices or organizations. This is particularly useful when data privacy is a concern, because the raw data never leaves the device or organization that holds it; only model parameters (or updates to them) are exchanged.
In this tutorial, we'll use Flower and PyTorch to build a federated learning system. We'll start by using PyTorch for the model training pipeline and data loading. Then, we'll federate the PyTorch-based pipeline using Flower.
Step 0: Preparation
Before diving into the code, let's ensure we have all the necessary tools and libraries installed.
Installing Dependencies
We need to install the necessary packages for PyTorch (torch and torchvision) and Flower (flwr):
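```shell
# flwr[simulation] also installs Ray, which Flower's simulation engine uses; tqdm draws progress bars
pip install "flwr[simulation]" torch torchvision tqdm
```
GPU Acceleration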
It's possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: Runtime > Change runtime type > Hardware accelerator: GPU > Save). With GPU acceleration enabled, PyTorch detects a CUDA device and the training code below runs on cuda:0; otherwise, everything runs on the CPU.
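A quick way to confirm which device PyTorch will use is sketched below. It assumes only that PyTorch is installed and applies the same "first GPU if CUDA is available, else CPU" rule as the training code.

```python
import torch

# First CUDA GPU if one is visible to PyTorch, otherwise the CPU
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")

if DEVICE.type == "cuda":
    # Name of the GPU that cuda:0 refers to
    print(torch.cuda.get_device_name(0))
```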
Loading the Data
Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes.
We simulate having multiple datasets from multiple organizations (also called the "cross-silo" setting in federated learning) by splitting the original CIFAR-10 dataset into multiple partitions. Each partition will represent the data from a single organization. We're doing this purely for experimentation purposes; in the real world, there's no need for data splitting because each organization already has its own data (so the data is naturally partitioned).
Each organization will act as a client in the federated learning system. So, having ten organizations participate in a federation means having ten clients connected to the federated learning server.
Let's now load the CIFAR-10 training and test sets, partition the training set into ten smaller datasets (each split into a training and a validation set), and wrap each resulting partition in a PyTorch DataLoader.
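A minimal sketch of this loading step follows. It assumes torchvision's CIFAR10 dataset class, a per-channel normalization transform, a fixed random seed, a batch size of 32, and a 10% validation hold-out per partition. The helper names partition_into_loaders and load_datasets, the ./dataset download directory, and these settings are illustrative choices rather than details taken from the source.

```python
import torch
from torch.utils.data import DataLoader, Dataset, random_split

NUM_CLIENTS = 10  # one partition per simulated organization
BATCH_SIZE = 32  # assumption: any reasonable batch size works


def partition_into_loaders(dataset: Dataset, num_clients: int, batch_size: int,
                           val_fraction: float = 0.1, seed: int = 42):
    """Split `dataset` into `num_clients` equally sized partitions and wrap each
    one in a (train, validation) pair of DataLoaders."""
    if len(dataset) % num_clients != 0:
        raise ValueError("dataset size must be divisible by num_clients")
    partition_size = len(dataset) // num_clients
    partitions = random_split(dataset, [partition_size] * num_clients,
                              generator=torch.Generator().manual_seed(seed))

    trainloaders, valloaders = [], []
    for part in partitions:
        # Hold out a fixed fraction of every partition for local validation
        len_val = round(len(part) * val_fraction)
        part_train, part_val = random_split(part, [len(part) - len_val, len_val],
                                            generator=torch.Generator().manual_seed(seed))
        trainloaders.append(DataLoader(part_train, batch_size=batch_size, shuffle=True))
        valloaders.append(DataLoader(part_val, batch_size=batch_size))
    return trainloaders, valloaders


def load_datasets(num_clients: int = NUM_CLIENTS, batch_size: int = BATCH_SIZE):
    """Download CIFAR-10, partition the training set, and keep the test set whole."""
    # Imported here so the partitioning helper above also works without torchvision
    import torchvision.transforms as transforms
    from torchvision.datasets import CIFAR10

    # Convert images to tensors and rescale each RGB channel from [0, 1] to [-1, 1]
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    trainloaders, valloaders = partition_into_loaders(trainset, num_clients, batch_size)
    testloader = DataLoader(testset, batch_size=batch_size)  # the test set is not split
    return trainloaders, valloaders, testloader
```

Calling `trainloaders, valloaders, testloader = load_datasets()` downloads CIFAR-10 on first use. With 10 clients, the 50,000 training images become partitions of 5,000, each split into 4,500 training and 500 validation examples.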
We now have a list of ten training sets and ten validation sets (trainloaders and valloaders) representing the data of ten different organizations. Each trainloader/valloader pair contains 4500 training examples and 500 validation examples. There's also a single testloader (we did not split the test set). Again, this is only necessary for building research or educational systems; actual federated learning systems have their data naturally distributed across multiple partitions.
Step 1: Centralized Training with PyTorch
Next, we're going to use PyTorch to define a simple convolutional neural network. This introduction assumes basic familiarity with PyTorch, so it doesn't cover the PyTorch-related aspects in full detail. If you want to dive deeper into PyTorch, we recommend Deep Learning with PyTorch: A 60 Minute Blitz.
Defining the Model
We use the simple CNN described in the PyTorch tutorial:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 3x32x32 input -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)  # halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16x5x5 after the second pooling
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # one output per CIFAR-10 class

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
Training and Testing Functions
Let's continue with the usual training and test functions:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def train(net: nn.Module, trainloader: DataLoader, epochs: int) -> None:
    """Train the network on the training set."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters())
    net.train()
    for _ in range(epochs):
        for images, labels in tqdm(trainloader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()


def test(net: nn.Module, testloader: DataLoader):
    """Evaluate the network on a dataset; return (mean batch loss, accuracy)."""
    criterion = nn.CrossEntropyLoss()
    correct, loss = 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    loss /= len(testloader)  # average the per-batch losses
    accuracy = correct / len(testloader.dataset)
    return loss, accuracy
```
Training the Model
We now have all the basic building blocks we need: a dataset, a model, a training function, and a test function. Let's put them together to train the model on the dataset of one of our organizations (trainloaders[0]). This simulates the reality of most machine learning projects today: each organization has its own data and trains models only on this internal data:
```python
net = Net().to(DEVICE)
trainloader = trainloaders[0]
valloader = valloaders[0]

train(net, trainloader, epochs=5)
loss, accuracy = test(net, valloader)
print(f"Loss: {loss:.2f}, Accuracy: {accuracy:.2f}")
```
Training the simple CNN on our CIFAR-10 split for 5 epochs should result in a validation accuracy of about 41%. That is not good, but it doesn't really matter for the purposes of this tutorial. The intent was just to show a simple centralized training pipeline that sets the stage for what comes next: federated learning!
Step 2: Federated Learning with Flower
Step 1 demonstrated a simple centralized training pipeline. All data was in one place (i.e., a single trainloader and a single valloader). Next, we'll simulate a situation where we have multiple datasets in multiple organizations and where we train a model over these organizations using federated learning.
Updating Model Parameters
In federated learning, the server sends the global model parameters to the client, and the client updates the local model with the parameters received from the server. It then trains the model on the local data (which changes the model parameters locally) and sends the updated/changed model parameters back to the server (or, alternatively, it sends just the gradients back to the server, not the full model parameters).
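The loop described above can be illustrated without Flower or PyTorch. The toy sketch below is plain NumPy and assumes a linear least-squares model, three hypothetical clients with synthetic data of different sizes, and full-batch gradient descent as the local training step; the server step is example-weighted averaging (FedAvg). It only mirrors the protocol and is not how Flower is implemented.

```python
import numpy as np


def local_update(global_w, X, y, lr=0.1, epochs=5):
    """Client: start from the global weights and run a few steps of full-batch
    gradient descent on the local mean-squared-error loss."""
    w = global_w.copy()
    for _ in range(epochs):
        grad = 2.0 * X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w


def federated_round(global_w, clients):
    """Server: send global_w to every client, collect the locally trained
    weights, and average them weighted by each client's number of examples."""
    results = [(local_update(global_w, X, y), len(y)) for X, y in clients]
    total = sum(n for _, n in results)
    return sum(w_k * (n / total) for w_k, n in results)


# Three hypothetical clients with differently sized synthetic datasets
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
clients = []
for n in (50, 100, 150):
    X = rng.normal(size=(n, 2))
    y = X @ true_w + 0.01 * rng.normal(size=n)
    clients.append((X, y))

global_w = np.zeros(2)
for _ in range(20):
    global_w = federated_round(global_w, clients)
print(global_w)  # close to [2, -1]
```

No client ever shares X or y; only weight vectors travel between the "server" loop and the clients.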
We need two helper functions to update the local model with parameters received from the server and to get the updated model parameters from the local model: set_parameters and get_parameters. The following two functions do just that for the PyTorch model above.
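```python
from typing import List

import numpy as np
import torch
import torch.nn as nn


def get_parameters(net: nn.Module) -> List[np.ndarray]:
    # Export every state_dict entry (weights, biases, buffers) as a NumPy array
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net: nn.Module, parameters: List[np.ndarray]) -> None:
    # Pair the received arrays with the model's state_dict keys by position
    params_dict = zip(net.state_dict().keys(), parameters)
    # torch.tensor copies each array and keeps its dtype (torch.Tensor would cast to float32)
    state_dict = {k: torch.tensor(v) for k, v in params_dict}
    net.load_state_dict(state_dict, strict=True)
```
The details of how this works are not really important here (feel free to consult the PyTorch documentation if you want to learn more). In essence, we use state_dict to access PyTorch model parameter tensors. The parameter tensors are then converted to and from a list of NumPy ndarrays, which Flower knows how to serialize and deserialize.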
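To see the two helpers in action, the sketch below round-trips parameters between two independently initialized copies of a small stand-in model (an nn.Sequential used purely for illustration, not the tutorial's Net). After set_parameters, both copies produce identical outputs.

```python
from typing import List

import numpy as np
import torch
import torch.nn as nn


def get_parameters(net: nn.Module) -> List[np.ndarray]:
    return [val.cpu().numpy() for val in net.state_dict().values()]


def set_parameters(net: nn.Module, parameters: List[np.ndarray]) -> None:
    state_dict = {k: torch.tensor(v) for k, v in zip(net.state_dict().keys(), parameters)}
    net.load_state_dict(state_dict, strict=True)


# Two independently initialized copies of a small stand-in model
torch.manual_seed(0)
sender = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
receiver = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

params = get_parameters(sender)  # what a client would send to the server
print([p.shape for p in params])  # [(3, 4), (3,), (2, 3), (2,)]
set_parameters(receiver, params)  # what a client does with parameters it receives

x = torch.randn(5, 4)
print(torch.equal(sender(x), receiver(x)))  # True
```

Because the arrays are matched to state_dict keys by position, sender and receiver must share the same architecture.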
Implementing a Flower Client
With that out of the way, let's move on to the interesting part. Federated learning systems consist of a server and multiple clients. In Flower, we create clients by implementing subclasses of flwr.client.Client or flwr.client.NumPyClient. We use NumPyClient in this tutorial because it is easier to implement and requires us to write less boilerplate.
To implement the Flower client, we create a subclass of flwr.client.NumPyClient and implement the three methods get_parameters, fit, and evaluate:
get_parameters: Return the current local model parameters.
fit: Receive model parameters from the server, train the model parameters on the local data, and return the (updated) model parameters to the server.
evaluate: Receive model parameters from the server, evaluate the model parameters on the local data, and return the evaluation result to the server.
We mentioned that our clients will use the previously defined PyTorch components for model training and evaluation. Let's see a simple Flower client implementation that brings everything together:
```python
from typing import Dict

import flwr as fl


class FlowerClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config: Dict):
        return get_parameters(self.net)

    def fit(self, parameters, config: Dict):
        # Load the global parameters, train locally for one epoch, send the result back
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        # Second value: number of training examples (used to weight aggregation)
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config: Dict):
        # Evaluate the global parameters on this client's local validation set
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}
```
Our class FlowerClient defines how local training and evaluation are performed and allows Flower to trigger them through fit and evaluate. Each instance of FlowerClient represents a single client in our federated learning system. Federated learning systems have multiple clients (otherwise, there's not much to federate), so each client is represented by its own instance of FlowerClient. If we have, for example, three clients in our workload, then we'd have three instances of FlowerClient. Flower calls FlowerClient.fit on the respective instance when the server selects a particular client for training (and FlowerClient.evaluate for evaluation).
Using the Virtual Client Engine
In this tutorial, we want to simulate a federated learning system with 10 clients on a single machine. This means that the server and all 10 clients will live on a single machine and share resources such as CPU, GPU, and memory. Having 10 clients would mean having 10 instances of FlowerClient in memory. Doing this on a single machine can quickly exhaust the available memory resources, even if only a subset of these clients participates in a single round of federated learning.
In addition to the regular mode where the server and clients run on separate machines, Flower therefore provides special simulation capabilities that create FlowerClient instances only when they are actually needed for training or evaluation. To let Flower create clients on demand, we implement a function called client_fn that returns a new FlowerClient instance. Flower calls client_fn whenever it needs an instance of a particular client to call fit or evaluate (these instances are usually discarded after use, so they should not keep any local state). Clients are identified by a client ID, or cid for short. The cid can be used, for example, to load a different local data partition for each client, as shown below:
```python
def client_fn(cid: str) -> FlowerClient:
    """Create a Flower client representing a single organization."""
    # Load model
    net = Net().to(DEVICE)

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on its own unique data partition
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]

    # Create a Flower client
    return FlowerClient(net, trainloader, valloader)
```
Starting the Training
We now have the class FlowerClient which defines client-side training/evaluation and client_fn which allows Flower to create FlowerClient instances whenever it needs to call fit or evaluate on one particular client. The last step is to start the actual simulation using flwr.simulation.start_simulation.
The function start_simulation accepts a number of arguments, amongst them the client_fn used to create FlowerClient instances, the number of clients to simulate (num_clients), the number of federated learning rounds (num_rounds), and the strategy. The strategy encapsulates the federated learning approach/algorithm, for example, Federated Averaging (FedAvg).
Flower has a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in FedAvg implementation and customize it using a few basic parameters. Finally, the call to start_simulation, as you might guess, starts the simulation:
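```python
import flwr as fl

# FedAvg: train on all clients in every round, evaluate on half of them
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # sample 100% of available clients for training
    fraction_evaluate=0.5,  # sample 50% of available clients for evaluation
    min_fit_clients=10,  # never sample fewer than 10 clients for training
    min_evaluate_clients=5,  # never sample fewer than 5 clients for evaluation
    min_available_clients=10,  # wait until all 10 clients are available
)

# Start the simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=10,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
)
```
Behind the Scenes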
So how does this work? How does Flower execute this simulation?
When we call start_simulation, we tell Flower that there are 10 clients (num_clients=10). Flower then asks the FedAvg strategy to select clients. FedAvg knows that it should select 100% of the available clients (fraction_fit=1.0), so it selects all 10 clients (100% of 10). For federated evaluation, fraction_evaluate=0.5 means that a random subset of 5 of the 10 clients is sampled in each round.
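The sample-size arithmetic can be sketched as follows. This assumes, as we understand Flower's FedAvg, that the strategy takes the configured fraction of the available clients, rounds down, and then applies the configured minimum; sample_size is a hypothetical helper, not a Flower function.

```python
def sample_size(num_available: int, fraction: float, min_clients: int) -> int:
    """Clients sampled per round: the requested fraction of the available
    clients (rounded down), but never fewer than min_clients."""
    return max(int(num_available * fraction), min_clients)


# Values from the FedAvg configuration above, with 10 simulated clients
print(sample_size(10, fraction=1.0, min_clients=10))  # training: 10
print(sample_size(10, fraction=0.5, min_clients=5))   # evaluation: 5
```

The minimum matters for small federations: with only 6 available clients, int(6 * 0.5) is 3, which a minimum of 5 raises to 5.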
Flower then asks the selected 10 clients to train the model. When the server receives the model parameter updates from the clients, it hands those updates over to the strategy (FedAvg) for aggregation. The strategy aggregates those updates and returns the new global model, which then gets used in the next round of federated learning.
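For reference, the standard FedAvg aggregation rule can be written as a weighted average in which each client's parameters count in proportion to the number of examples it reports from fit:

```latex
w_{t+1} = \sum_{k \in S_t} \frac{n_k}{n} \, w_{t+1}^{k}, \qquad n = \sum_{k \in S_t} n_k
```

Here S_t is the set of clients selected in round t, n_k is the example count returned by client k, and w_{t+1}^k are client k's parameters after local training. When every client reports the same n_k, as with the equally sized partitions in this tutorial, the rule reduces to a plain mean of the client parameters.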
Aggregating Custom Metrics
Flower can automatically aggregate losses returned by individual clients, but it cannot do the same for metrics in the generic metrics dictionary (the one with the accuracy key). Metrics dictionaries can contain very different kinds of metrics and even key/value pairs that are not metrics at all, so the framework does not (and cannot) know how to handle them automatically.
As users, we need to tell the framework how to handle/aggregate these custom metrics, and we do so by passing metric aggregation functions to the strategy. The strategy will then call these functions whenever it receives fit or evaluate metrics from clients. The two possible functions are fit_metrics_aggregation_fn and evaluate_metrics_aggregation_fn.
Let's create a simple weighted averaging function to aggregate the accuracy metric we return from evaluate:
```python
from typing import Dict, List, Tuple


def weighted_average(metrics: List[Tuple[int, Dict]]) -> Dict:
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}
```
The only thing left to do is to tell the strategy to call this function whenever it receives evaluation metric dictionaries from the clients:
```python
import flwr as fl

strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=0.5,
    min_fit_clients=10,
    min_evaluate_clients=5,
    min_available_clients=10,
    evaluate_metrics_aggregation_fn=weighted_average,
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=10,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
)
```
We now have a full system that performs federated training and federated evaluation. It uses the weighted_average function to aggregate custom evaluation metrics and calculates a single accuracy metric across all clients on the server side.
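To see why the weighting matters, here is a small standalone check of the same weighted_average rule on made-up evaluation results from three hypothetical clients with unequal validation set sizes (the numbers are illustrative only):

```python
from typing import Dict, List, Tuple


def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    # Sum of (examples * accuracy) divided by the total number of examples
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


# Hypothetical (num_examples, metrics) pairs, in the shape the strategy passes in
results = [(100, {"accuracy": 0.30}), (300, {"accuracy": 0.50}), (600, {"accuracy": 0.60})]
print(weighted_average(results))  # {'accuracy': 0.54}
```

A plain, unweighted mean of the three accuracies would be about 0.47, giving the 100-example client as much influence as the 600-example one. In this tutorial every valloader holds 500 examples, so the weighted and unweighted averages coincide.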

