AI@Edge — Federated Learning: Learn on Private and Partitioned Datasets
Authors: Supriyo Chakraborty, Wendy Chong, Utpal Mangla, Satish Sadagopan, Mudhakar Srivatsa, Mathews Thomas, Dinesh Verma, Shiqiang Wang
In our first blog, we discussed some of the challenges in implementing AI/ML at the edge. In our second article, we discussed a solution to one of those challenges: how to transfer large volumes of data to a central location for model training. That solution was coresets, machine learning algorithms that succinctly summarize the data from edge devices without losing data quality, while enabling real-time decisions with pre-trained AI models at the point of action and at the speed of 5G.
But is it essential to gather data at a centralized location for model training at all? In this blog, we discuss how AI/ML models can learn from data across multiple edges without sharing the raw data, thereby offering a higher level of data privacy. Federated learning algorithms train on data that stays distributed across many participants and combine what each participant learns into a global model, without revealing the raw (personal) data. Before we get into the details, let us return to the ambulance use case to revisit the key challenges.
As covered in the previous article on the ambulance use case, large volumes of many types of data are collected for analysis, such as the patient's vital statistics, the status and location of the ambulance, and the video feed from the ambulance to the hospital. To get the best results from the models, we should train them on diverse data that spans multiple geographical edges:
- Road data not only from city roads, but also from highways.
- Various weather conditions, such as rain, snow, and sunny days, as well as both day and nighttime conditions.
- Patient illness data across demographics such as age, gender, race, physical condition, and pre-existing medical conditions.
- Network data collected as the ambulance moves through different locations, and thus spanning multiple vRANs (virtualized Radio Access Networks), IP/transport layers, and MEC (Mobile Edge Computing) servers.
Fusing data across multiple geographical edges and patients helps us train models that improve the overall treatment provided to the patient in transit. However, there are some key points to consider:
- How do we ensure the security and privacy of the edge data? This data can be used to train the models at the edge, but private data should not be shared with others in its raw form.
- How can the models running at a given edge be enhanced by what is learned at other edges, while ensuring the data is not shared? There has to be some level of collaboration between the edge nodes to make this possible.
- How do we scale model training in a cloud and multi-edge environment? With some of the training occurring at the edge and the rest in the central cloud, how does one scale this mode of training while keeping the data secure?
- How do we dynamically orchestrate resources across the cloud and multiple edges to facilitate model training and scoring in a Telco network cloud? Such solutions should leverage the inherent strengths of 5G networks, such as dynamic network slicing, improving the bandwidth or QoS of a slice, or migrating the ambulance traffic to another suitable slice.
Federated learning has recently emerged as a popular approach for large-scale deep neural network training. It is formulated as a multi-round strategy in which the training of a neural network model is distributed among multiple agents. In each round, a random subset of agents, each with its own local data and computational resources, is selected for training. The selected agents train the model locally and share only their parameter (model weight) updates with a central parameter server, which aggregates the updates. Motivated by privacy concerns, the server is designed to have no visibility into an agent's local data or training process. The aggregation algorithm is usually a weighted average of the agent updates, typically with each agent weighted by its number of local training samples, which produces the new global model.
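The round structure described above can be sketched in a few lines. This is a minimal, hypothetical example, assuming a toy linear-regression model trained with plain gradient descent (standing in for a deep neural network) and NumPy arrays for the weights; the names `local_train` and `fedavg_round`, the synthetic agent data, and the hyperparameters are illustrative and not part of any particular framework.

```python
import numpy as np

def local_train(w, X, y, lr=0.1, steps=5):
    """Run a few steps of full-batch gradient descent on a least-squares loss
    using only this agent's data; the raw data never leaves the agent."""
    w = w.copy()
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / len(y)  # gradient of 0.5 * mean squared error
        w -= lr * grad
    return w

def fedavg_round(w_global, agents, fraction, rng):
    """One federated round: sample a random subset of agents, let each train
    locally, then average the returned weights on the server."""
    k = max(1, int(fraction * len(agents)))
    chosen = rng.choice(len(agents), size=k, replace=False)
    new_weights, counts = [], []
    for i in chosen:
        X, y = agents[i]
        new_weights.append(local_train(w_global, X, y))  # only weights are shared
        counts.append(len(y))
    # Weighted average, each agent weighted by its number of local samples.
    return np.average(np.stack(new_weights), axis=0,
                      weights=np.asarray(counts, dtype=float))

rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
agents = []
for n in (50, 80, 120, 200):  # hypothetical agents with different data volumes
    X = rng.normal(size=(n, 2))
    y = X @ true_w + 0.01 * rng.normal(size=n)
    agents.append((X, y))

w = np.zeros(2)
for _ in range(30):
    w = fedavg_round(w, agents, fraction=0.5, rng=rng)
print(np.round(w, 2))
```

Averaging the returned models is equivalent to averaging the updates themselves: because the normalized weights sum to one, the weighted mean of `w_global + delta_i` equals `w_global` plus the weighted mean of the deltas.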
To make the most efficient use of the limited computation and communication resources in edge computing systems, researchers at IBM Research, Imperial College London, and Yale University have explored the design of efficient federated learning algorithms [1]. These algorithms track the resource variation and data diversity at the edge and use them to determine the optimal parameter aggregation frequency, gradient sparsity, and model size. The key idea is that the server only needs to aggregate model parameters when necessary. When different agents have similar data, far less information needs to be exchanged than when agents have diverse data. The optimal degree of information exchange also depends on the trade-off between computation and communication. When computation resources (CPU cycles, memory, etc.) are abundant but communication bandwidth is low, it is beneficial to compute more and communicate less, and vice versa. Various experiments show that our algorithms can significantly reduce model training time compared to the standard (non-optimized) federated averaging (FedAvg) algorithm. When combined with model pruning, they can also substantially reduce inference time, making it possible to deploy trained models to resource-limited edge devices.
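To see why the balance between computation and communication matters, consider a simplified resource model in the spirit of [1]: every local training step costs `step_cost` units of some resource (time or energy, say), every aggregation costs `agg_cost`, and the total budget is `budget`. The sketch below uses hypothetical costs and a hypothetical `schedule` helper to count how many aggregation rounds and local steps fit in the budget for different numbers of local steps `tau` between aggregations.

```python
def schedule(budget, step_cost, agg_cost, tau):
    """Return (rounds, total_local_steps) that fit within a resource budget
    when each round runs `tau` local steps followed by one aggregation."""
    rounds = int(budget // (tau * step_cost + agg_cost))
    return rounds, rounds * tau

# Hypothetical costs in arbitrary resource units (e.g. milliseconds).
budget = 1000.0
for label, step_cost, agg_cost in [("fast link", 1.0, 1.0), ("slow link", 1.0, 50.0)]:
    for tau in (1, 10, 50):
        rounds, steps = schedule(budget, step_cost, agg_cost, tau)
        print(f"{label}: tau={tau:>2} -> {rounds:>3} aggregations, {steps:>3} local steps")
```

On the slow link, aggregating after every step leaves room for only 19 local steps, while `tau=50` allows 500; on the fast link, frequent aggregation is cheap. This sketch only does the resource accounting. It ignores the fact that a larger `tau` lets models trained on diverse data drift further apart, which is the convergence side of the trade-off that [1] takes into account when choosing the aggregation frequency.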
The ambulance use case exhibits both data/resource similarity and diversity, which makes it an ideal candidate for the federated learning approaches described above.
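How often agents with such a mix of similar and diverse data need to synchronize can be adjusted on the fly. The sketch below is a hypothetical heuristic, much simpler than the adaptive control algorithm in [1]: at each aggregation, when the server already holds every agent's model, it measures how far the models are from their average, then lengthens the interval `tau` before the next aggregation when agents agree (similar data) and shortens it when they diverge (diverse data). The thresholds `low` and `high` and the doubling and halving rule are assumptions chosen for illustration.

```python
import numpy as np

def model_dispersion(local_models):
    """Mean Euclidean distance of the agents' models from their average.
    The server computes this at aggregation time, when it holds every model."""
    W = np.stack(local_models)
    return float(np.linalg.norm(W - W.mean(axis=0), axis=1).mean())

def next_interval(tau, dispersion, low, high, tau_min=1, tau_max=64):
    """Hypothetical rule: aggregate less often when agents agree and more
    often when their models are drifting apart."""
    if dispersion < low:
        return min(tau * 2, tau_max)
    if dispersion > high:
        return max(tau // 2, tau_min)
    return tau

# Hypothetical local models received at one aggregation.
similar = [np.array([1.00, 2.00]), np.array([1.02, 1.98]), np.array([0.99, 2.01])]
diverse = [np.array([1.0, 2.0]), np.array([3.0, -1.0]), np.array([-2.0, 0.5])]

print(next_interval(8, model_dispersion(similar), low=0.1, high=1.0))
print(next_interval(8, model_dispersion(diverse), low=0.1, high=1.0))
```

Measuring dispersion at aggregation time, rather than in between, keeps the heuristic from adding any extra communication of its own.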
While federated learning has many benefits, it introduces a new vulnerability in the form of misbehaving agents. Researchers at IBM Research and Princeton University have explored the possibility of an adversary controlling a small number of agents with the goal of corrupting the trained model [2]. The adversary's objective is to cause the jointly trained global model to misclassify a set of chosen inputs with high confidence; that is, it seeks to poison the global model in a targeted manner. Because the attack is targeted, the adversary also tries to ensure that the global model still converges to a point with good performance on the test or validation data, so that the poisoning is hard to notice. Such attacks underscore the need to admit only trusted agents into federated learning and to develop model training algorithms that are robust to adversarial attacks [3]. These defenses are increasingly being adapted to the federated learning setup (e.g., in the design of robust model fusion algorithms) to detect such attacks and secure the learning process against them.
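Robust model fusion replaces plain averaging with an aggregation rule that limits the influence of any single agent. As a toy illustration (not the targeted attack studied in [2], and not necessarily the defense used in [3]), the sketch below compares the mean with the coordinate-wise median, one well-known robust aggregation rule, when one hypothetical malicious agent out of ten submits a boosted update pointing the opposite way. The update values are made up for the example.

```python
import numpy as np

def mean_aggregate(updates):
    """Plain (unweighted) averaging of agent updates."""
    return np.mean(np.stack(updates), axis=0)

def median_aggregate(updates):
    """Coordinate-wise median: each parameter takes the median across agents,
    so a minority of extreme values cannot drag it arbitrarily far."""
    return np.median(np.stack(updates), axis=0)

rng = np.random.default_rng(1)
# Nine hypothetical honest agents with similar updates.
honest = [np.array([0.1, -0.2, 0.05]) + 0.01 * rng.normal(size=3) for _ in range(9)]
# A hypothetical malicious agent pushes in the opposite direction and scales
# ("boosts") its update so that it dominates the plain average.
malicious = -90.0 * np.array([0.1, -0.2, 0.05])
updates = honest + [malicious]

print("mean:  ", np.round(mean_aggregate(updates), 3))
print("median:", np.round(median_aggregate(updates), 3))
```

The plain mean is dragged toward the attacker's direction, while the median stays close to the honest updates. Robust rules of this kind raise the bar for an attacker but are not a complete defense on their own.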
In summary, federated learning provides a secure model training strategy in which the training of a neural network model is distributed among multiple agents at the edge. For emergency vehicles, it can provide insights for patient care, optimal routes, Telco orchestration, and more. It allows learning across multiple edges without sharing the raw data between them, often resulting in better models for deployment at the edge.
References:
[1] S. Wang, T. Tuor, T. Salonidis, K. K. Leung, C. Makaya, T. He, and K. Chan, “Adaptive federated learning in resource constrained edge computing systems,” IEEE Journal on Selected Areas in Communications, vol. 37, no. 6, pp. 1205–1221, Jun. 2019.
[2] A. N. Bhagoji, S. Chakraborty, P. Mittal, and S. Calo, “Analyzing federated learning through an adversarial lens,” in International Conference on Machine Learning (ICML), 2019.
[3] M.-I. Nicolae, M. Sinn, M. N. Tran, B. Buesser, A. Rawat, M. Wistuba, V. Zantedeschi, N. Baracaldo, B. Chen, H. Ludwig, I. M. Molloy, and B. Edwards, “Adversarial Robustness Toolbox v1.0.0,” arXiv preprint, 2018.