Federated finetuning of a ViT

This example shows how to use Flower’s Simulation Engine to federate the finetuning of a Vision Transformer (ViT-Base-16) that has been pretrained on ImageNet. To keep things simple, we’ll be finetuning it on the Oxford Flowers-102 dataset, creating 20 partitions using Flower Datasets. We’ll be finetuning just the exit head of the ViT, which means training is fairly cheap and each client requires only ~1GB of VRAM (for a batch size of 32 images).
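As a rough illustration of what model.py does, here is a minimal sketch of loading torchvision’s pretrained ViT-B/16 and freezing everything except its classification head (the function name get_model and the exact head replacement are assumptions for illustration, not copied from the example):

import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

def get_model(num_classes: int = 102):  # illustrative name, not necessarily the one in model.py
    # Load a ViT-B/16 pretrained on ImageNet
    model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
    # Freeze the entire backbone...
    for param in model.parameters():
        param.requires_grad = False
    # ...then replace the head with a fresh, trainable linear layer
    # (ViT-B/16 has a 768-dimensional hidden representation)
    model.heads = torch.nn.Sequential(torch.nn.Linear(768, num_classes))
    return model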

Running the example

If you haven’t cloned the Flower repository already, you might want to check out this code example on its own and discard the rest. We prepared a single-line command that you can copy into your shell which will check out the example for you:

git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/vit-finetune . && rm -rf flower && cd vit-finetune

This will create a new directory called vit-finetune containing the following files:

-- README.md         <- You're reading this right now
-- main.py           <- Main file that launches the simulation
-- client.py         <- Contains Flower client code and ClientApp
-- server.py         <- Contains Flower server code and ServerApp
-- model.py          <- Defines model and train/eval functions
-- dataset.py        <- Downloads, partitions and processes dataset
-- pyproject.toml    <- Example dependencies, installable using Poetry
-- requirements.txt  <- Example dependencies, installable using pip

Installing Dependencies

Project dependencies (such as torch and flwr) are defined in pyproject.toml and requirements.txt. We recommend Poetry to install those dependencies and manage your virtual environment (see the Poetry installation guide), or pip, but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

Poetry

poetry install
poetry shell

pip

With an activated environment, install the dependencies for this example:

pip install -r requirements.txt

Run with start_simulation()

Running the example is quite straightforward. You can control the number of rounds with the --num-rounds flag (which defaults to 20).

python main.py
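For example, to run for 50 rounds instead of the default:

python main.py --num-rounds=50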

Running the example as-is on an RTX 3090 Ti should take ~15s/round running 5 clients in parallel (plus the global model during centralized evaluation stages) on a single GPU. Note that more clients could fit in VRAM, but since GPU utilization is already high (99%-100%) we are probably better off not doing that (at least in this case).

You can adjust the client_resources passed to start_simulation() so more or fewer clients run at the same time on the GPU. Take a look at the documentation for more details on how you can customise your simulation.
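For illustration, a call along these lines caps concurrency at 5 clients per GPU (client_fn and strategy stand in for the objects defined in client.py and server.py; the resource values are just an example):

import flwr as fl

history = fl.simulation.start_simulation(
    client_fn=client_fn,    # constructs the client for a given partition
    num_clients=20,         # one client per dataset partition
    config=fl.server.ServerConfig(num_rounds=20),
    strategy=strategy,      # e.g. FedAvg with centralized evaluation
    # Each client reserves 20% of a GPU, so up to 5 run concurrently per GPU
    client_resources={"num_cpus": 4, "num_gpus": 0.2},
)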

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3090 Ti     Off | 00000000:0B:00.0 Off |                  Off |
| 44%   74C    P2             441W / 450W |   7266MiB / 24564MiB |    100%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    173812      C   python                                     1966MiB |
|    0   N/A  N/A    174510      C   ray::ClientAppActor.run                    1056MiB |
|    0   N/A  N/A    174512      C   ray::ClientAppActor.run                    1056MiB |
|    0   N/A  N/A    174513      C   ray::ClientAppActor.run                    1056MiB |
|    0   N/A  N/A    174514      C   ray::ClientAppActor.run                    1056MiB |
|    0   N/A  N/A    174516      C   ray::ClientAppActor.run                    1056MiB |
+---------------------------------------------------------------------------------------+
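In this run, each of the five ray::ClientAppActor processes corresponds to one client training in parallel, while the remaining python process holds the global model used during the centralized evaluation stages.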

Run with Flower Next (preview)
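You can launch the same simulation through the flower-simulation CLI, pointing it at the ClientApp and ServerApp objects defined in client.py and server.py: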

flower-simulation \
    --client-app=client:app \
    --server-app=server:app \
    --num-supernodes=20 \
    --backend-config='{"client_resources": {"num_cpus":4, "num_gpus":0.25}}'
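With this backend-config, each ClientApp reserves 4 CPUs and a quarter of a GPU, so up to four clients can run concurrently on each available GPU. Adjust these values to match your hardware.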