Following up on a previous post on running federated learning on embedded devices such as Raspberry Pi, we now show how Flower can be used to federate machine learning training on Android devices.
Deploying Flower on Android devices poses the following challenges:
- Training pipelines written in Python-based ML frameworks such as TensorFlow and PyTorch are not easily portable to Android devices, where Java is one of the primary programming languages.
- An alternative is to rewrite the ML training pipeline using a Java-based framework such as Deeplearning4j. However, this means learning a completely new ML framework just to federate our workloads, which goes against Flower's mission of making FL deployment easy!
TensorFlow Lite meets Flower
In this post, we will see how we can use TensorFlow Lite (TFLite) to enable federated learning on Android devices. TFLite-based development has three key benefits:
- As we will see below, we can reuse much of TensorFlow's syntax for defining our model architecture. This makes it easy to take our existing model architectures written in TF and deploy them in a federated setting.
- Even though TFLite was initially designed for on-device inference, it already has well-documented support for on-device model personalization. We will reuse this feature to enable local training of models on Android devices.
- TFLite recently released an example of full-fledged on-device ML training. This is currently not incorporated in Flower, but is on the roadmap.
Set up the model definitions
Let's start. TFLite Model Personalization requires defining two architectures: a Base model and a Head model. Simply put, the Base model is a pre-trained feature extractor (e.g., ResNet50 trained on ImageNet) that is not updated during on-device training. The Head model is a task-specific classifier that is randomly initialized and trained on the local data. In this example, we will not use any pre-trained Base model and instead train the entire model using FL.
To this end, we define our Base model as a single identity (Lambda) layer that passes the input through unchanged. For the Head model, we take an existing model architecture defined with the Keras Sequential API.
```python
import tensorflow as tf

# Base model: an identity mapping, so the whole network is trained with FL.
base = tf.keras.Sequential(
    [tf.keras.Input(shape=(32, 32, 3)), tf.keras.layers.Lambda(lambda x: x)]
)
base.compile(loss="categorical_crossentropy", optimizer="sgd")
# save_format="tf" writes a TensorFlow SavedModel directory.
base.save("identity_model", save_format="tf")

# Head model: a small CNN classifier for the 10 CIFAR-10 classes.
head = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(6, 5, activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Conv2D(16, 5, activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=120, activation="relu"),
        tf.keras.layers.Dense(units=84, activation="relu"),
        tf.keras.layers.Dense(units=10, activation="softmax"),
    ]
)
head.compile(loss="categorical_crossentropy", optimizer="sgd")
```
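As a quick sanity check on how much data each client sends per round, the parameter count of the head above can be worked out by hand. The sketch below does that arithmetic in plain Python, assuming the Keras defaults (stride 1, "valid" padding) and float32 weights; the helper names are hypothetical. It also assumes that each kernel and bias is exchanged as its own tensor, which is our reading of the ByteBuffer[] used by the Android client rather than something stated in this post.

```python
def conv2d_params(kernel, in_ch, out_ch):
    # kernel * kernel * in_ch weights per filter, plus one bias per filter
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(n_in, n_out):
    # full weight matrix plus one bias per output unit
    return n_in * n_out + n_out


def valid_conv_size(size, kernel):
    # output width/height of a stride-1 convolution with "valid" padding
    return size - kernel + 1


size = 32
p_conv1 = conv2d_params(5, 3, 6)          # 456
size = valid_conv_size(size, 5) // 2       # conv -> 28, 2x2 max-pool -> 14
p_conv2 = conv2d_params(5, 6, 16)          # 2416
size = valid_conv_size(size, 5)            # 10
flat = size * size * 16                    # Flatten -> 1600 inputs
p_fc1 = dense_params(flat, 120)            # 192120
p_fc2 = dense_params(120, 84)              # 10164
p_fc3 = dense_params(84, 10)               # 850

layer_params = [p_conv1, p_conv2, p_fc1, p_fc2, p_fc3]
total = sum(layer_params)                  # 206006 parameters
n_tensors = 2 * len(layer_params)          # one kernel + one bias per layer -> 10
payload_bytes = 4 * total                  # float32 -> 824024 bytes
print(total, n_tensors, payload_bytes)     # 206006 10 824024
```

So each round moves roughly 0.8 MB of raw weights per client in each direction, before any protocol overhead, and most of it sits in the first Dense layer because Flatten turns the 10x10x16 feature map into 1,600 inputs.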
Next, we use the TFLite Transfer Converter to convert these model definitions into the .tflite format, which we will import into our Android application later.
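```python
# tfltransfer: the converter package from TensorFlow Lite's model personalization example.
from tfltransfer import bases, heads, optimizers
from tfltransfer.tflite_transfer_converter import TFLiteTransferConverter

# Load the SavedModel written above as the (frozen) base.
base_path = bases.saved_model_base.SavedModelBase("identity_model")
converter = TFLiteTransferConverter(
    10,                          # number of classes
    base_path,
    heads.KerasModelHead(head),  # trainable head: the Keras model defined above
    optimizers.SGD(1e-3),        # optimizer and learning rate for on-device training
    train_batch_size=32,
)
converter.convert_and_save("tflite_model")
```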
The full source code is available in the Flower GitHub repo at examples/android/tflite_convertor/convert_to_tflite.py.
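To make the Base/Head split concrete, here is a minimal pure-Python sketch of what one training step does conceptually: the frozen base (here the identity, as above) only transforms the inputs, and the SGD update touches the head's parameters alone. It is an illustration under simplifying assumptions, not how TFLite implements training: SoftmaxHead, identity_base and train_step are hypothetical names, the head is a single dense softmax layer standing in for the CNN, and the toy example uses a learning rate of 0.5 (instead of the 1e-3 passed to the converter) so that it converges in a few steps.

```python
import math
import random


def identity_base(x):
    """Frozen base: returns its input unchanged, like the Lambda(lambda x: x) layer."""
    return x


class SoftmaxHead:
    """Trainable head: one dense layer + softmax, standing in for the CNN head."""

    def __init__(self, n_features, n_classes, seed=0):
        rng = random.Random(seed)
        self.w = [[rng.uniform(-0.1, 0.1) for _ in range(n_classes)]
                  for _ in range(n_features)]
        self.b = [0.0] * n_classes

    def predict(self, x):
        logits = [self.b[k] + sum(x[i] * self.w[i][k] for i in range(len(x)))
                  for k in range(len(self.b))]
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        return [e / total for e in exps]

    def sgd_step(self, batch, lr):
        """One SGD step on categorical cross-entropy; returns the pre-update mean loss."""
        n_classes, n_features = len(self.b), len(self.w)
        grad_w = [[0.0] * n_classes for _ in range(n_features)]
        grad_b = [0.0] * n_classes
        loss = 0.0
        for x, y in batch:  # y is a one-hot label
            p = self.predict(x)
            loss -= sum(yk * math.log(pk) for yk, pk in zip(y, p) if yk > 0)
            for k in range(n_classes):
                d = p[k] - y[k]  # gradient of softmax + cross-entropy w.r.t. logit k
                grad_b[k] += d
                for i in range(n_features):
                    grad_w[i][k] += d * x[i]
        n = len(batch)
        for k in range(n_classes):
            self.b[k] -= lr * grad_b[k] / n
            for i in range(n_features):
                self.w[i][k] -= lr * grad_w[i][k] / n
        return loss / n


def train_step(head, batch, lr):
    # The base only transforms the inputs; gradients flow into the head alone.
    features = [(identity_base(x), y) for x, y in batch]
    return head.sgd_step(features, lr)


batch = [([1.0, 0.0], [1, 0]), ([0.0, 1.0], [0, 1])]
head = SoftmaxHead(n_features=2, n_classes=2)
first = train_step(head, batch, lr=0.5)
for _ in range(50):
    last = train_step(head, batch, lr=0.5)
print(first > last)  # True: the loss decreases as the head is trained
```

With a pre-trained base, identity_base would be replaced by a fixed feature extractor, and the head would train on its outputs in exactly the same way.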
Android Client
Let us move to the Android client. We first fork the TFLite Model Personalization app from here.
Initial Setup
- Copy the .tflite models generated earlier into the assets directory of the Android project.
- For this tutorial, we will also copy the CIFAR10 training data into the assets directory. In practice, however, the training data (e.g., user images) would be stored on external storage and read from there.
Both of these steps are automated and run at build time. Please see here for details.
- We will write some glue code for data loading and update the names of the class labels. See here for details.
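Because both models are compiled with categorical_crossentropy, the data-loading glue code has to turn each class label into a one-hot vector of length 10, and the UI needs the reverse mapping from a prediction back to a class name. The app does this in Java; the sketch below shows the same mapping in Python, assuming the standard CIFAR-10 label order. The function names are hypothetical.

```python
# CIFAR-10 class names in the standard label-index order (index 0 is airplane, index 9 is truck).
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def one_hot(class_name, classes=CIFAR10_CLASSES):
    """Encode a class name as the one-hot target expected by categorical_crossentropy."""
    vec = [0.0] * len(classes)
    vec[classes.index(class_name)] = 1.0  # raises ValueError for unknown names
    return vec


def predicted_class(probabilities, classes=CIFAR10_CLASSES):
    """Map a softmax output back to the name of the most likely class."""
    best = max(range(len(probabilities)), key=probabilities.__getitem__)
    return classes[best]


print(one_hot("cat"))
print(predicted_class([0.05] * 9 + [0.55]))  # truck
```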
gRPC Compilation
- Add the gRPC libraries for Java to the build.gradle file and copy the transport.proto file used by Flower into the src/main/proto directory of our Android project. See here for details.
- Handle the gRPC communication between the server and the Android client (e.g., FL instructions, serialized model weights). Please see here for details.
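Flower's protocol carries model parameters as a list of raw byte strings, one per tensor (the tensors field of the Parameters message in transport.proto), and the Android client holds them as a ByteBuffer[]. Both sides therefore have to agree on how floats are packed. The sketch below assumes flat little-endian float32 arrays (the native byte order on typical Android ARM devices) and that tensor shapes are known on both ends; it uses only Python's struct module, and the helper names are hypothetical.

```python
import struct


def tensor_to_bytes(values):
    """Pack a flat list of floats as little-endian float32 (assumed wire format)."""
    return struct.pack(f"<{len(values)}f", *values)


def bytes_to_tensor(buf):
    """Inverse of tensor_to_bytes; the buffer length must be a multiple of 4."""
    if len(buf) % 4:
        raise ValueError("buffer length is not a multiple of 4 bytes")
    return list(struct.unpack(f"<{len(buf) // 4}f", buf))


def serialize_weights(tensors):
    # One bytes object per tensor, mirroring the ByteBuffer[] on the Android side.
    return [tensor_to_bytes(t) for t in tensors]


def deserialize_weights(buffers):
    return [bytes_to_tensor(b) for b in buffers]


buffers = serialize_weights([[0.5, -1.25], [3.0]])
print([len(b) for b in buffers])     # [8, 4]
print(deserialize_weights(buffers))  # [[0.5, -1.25], [3.0]]
```

Values that are not exactly representable in float32 (such as 0.1) will not round-trip bit-for-bit from Python's 64-bit floats, which is expected when the model itself trains in float32.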
Flower interface
Finally, we implement the three core methods that enable FL with Flower: fit, evaluate, and get_weights (named getWeights in Java).
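```java
import android.content.Context;
import android.os.ConditionVariable;
import android.util.Pair;
import java.nio.ByteBuffer;

public class FlowerClient {
    private final TransferLearningModelWrapper tlModel;
    private final Context context;
    // fit() blocks on this until local training has finished.
    private final ConditionVariable isTraining = new ConditionVariable();
    private int local_epochs;

    public FlowerClient(Context context) {
        this.tlModel = new TransferLearningModelWrapper(context);
        this.context = context;
    }

    // Current model parameters, one ByteBuffer per tensor.
    public ByteBuffer[] getWeights() {
        return tlModel.getParameters();
    }

    // Load the global weights, train locally, return (new weights, #training examples).
    public Pair<ByteBuffer[], Integer> fit(ByteBuffer[] weights, int epochs) {
        this.local_epochs = epochs;
        tlModel.updateParameters(weights);
        isTraining.close();
        tlModel.train(this.local_epochs);
        // setLastLoss (not shown) must call isTraining.open() after the last epoch.
        tlModel.enableTraining((epoch, loss) -> setLastLoss(epoch, loss));
        isTraining.block();
        return Pair.create(getWeights(), tlModel.getSize_Training());
    }

    // Load the global weights, return (local test statistics, #test examples).
    public Pair<Pair<Float, Float>, Integer> evaluate(ByteBuffer[] weights) {
        tlModel.updateParameters(weights);
        tlModel.disableTraining();
        return Pair.create(tlModel.calculateTestStatistics(), tlModel.getSize_Testing());
    }
}
```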
For details, please see here.
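The Integer returned by fit, tlModel.getSize_Training(), is the number of local training examples. In Federated Averaging (FedAvg), the server uses these counts to weight each client's update, so clients with more data pull the global model further. Below is a minimal sketch of that weighted average, assuming the weights arrive already deserialized as flat float lists, one per tensor; fedavg is a hypothetical helper, not Flower's implementation.

```python
def fedavg(results):
    """Example-weighted average of client weights.

    results: list of (weights, num_examples), where weights is a list of flat
    float lists (one per tensor), as returned by each client's fit().
    """
    total = sum(n for _, n in results)
    if total == 0:
        raise ValueError("no training examples reported")
    n_tensors = len(results[0][0])
    averaged = []
    for t in range(n_tensors):
        acc = [0.0] * len(results[0][0][t])
        for weights, n in results:
            for i, v in enumerate(weights[t]):
                acc[i] += v * n  # weight each value by the client's example count
        averaged.append([a / total for a in acc])
    return averaged


# Client A trained on 1 example, client B on 3, so B gets three times the weight.
print(fedavg([([[1.0, 2.0]], 1), ([[3.0, 6.0]], 3)]))  # [[2.5, 5.0]]
```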
FL in action
- Let us first start our FL server. The server uses a modified strategy called FedAvgAndroid, which supports (de)serialization of the model parameters exchanged between the server and the clients.
python server.py
- Build an Android APK or use the one provided in the repo, and install it on an Android smartphone.
- For this tutorial, we have partitioned the CIFAR10 dataset into 10 IID partitions. See the dataset for details. We specify which partition should be loaded on this Android client and click Load Dataset. You can try loading different data partitions on different clients and see how this affects the global model's performance.
- Enter the IP address and port number of the FL server and click Setup Connection Channel.
- Once the channel is established, press Train Federated to start FL (training begins once at least min_clients clients are connected).
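The partitioning script itself is not shown in this post. As a sketch of what an IID split means, the function below shuffles the example indices with a fixed seed and cuts them into equally sized, disjoint shards; iid_partitions and the seed value are hypothetical, and any leftover examples are dropped.

```python
import random


def iid_partitions(num_examples, num_partitions, seed=42):
    """Shuffle example indices and split them into equally sized, disjoint IID shards.

    The remainder (num_examples % num_partitions) is dropped so all shards are equal.
    """
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)  # fixed seed -> reproducible split
    size = num_examples // num_partitions
    return [indices[p * size:(p + 1) * size] for p in range(num_partitions)]


parts = iid_partitions(50_000, 10)
print(len(parts), len(parts[0]))  # 10 5000
```

For example, the 50,000 CIFAR-10 training images would give ten partitions of 5,000 images each. Because every shard is a uniform random sample, each client sees roughly the same class distribution.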
Future extensions
- Integrate full-fledged on-device TFLite training with Flower.
- Modularize the Android client code and abstract out the (de)serialization methods.
- Did I hear you say "why not use C++ and Android NDK"? Well, that is certainly an option for C++ fans. Importantly, Flower's C++ SDK is in the works and it will be groundbreaking for running FL on IoT-scale devices.