Networked Federated Multi-Task Learning

Many important application domains generate distributed collections of heterogeneous local datasets. These local datasets are often related via an intrinsic network structure that arises from domain-specific notions of similarity between local datasets. Different notions of similarity are induced by spatio-temporal proximity, statistical dependencies or functional relations. We use this network structure to adaptively pool similar local datasets into nearly homogeneous training sets for learning tailored models. Our main conceptual contribution is to formulate networked federated learning using the concept of generalized total variation (GTV) minimization as a regularizer. This formulation is highly flexible and can be combined with almost any parametric model, including Lasso or deep neural networks. We unify and considerably extend some well-known approaches to federated multi-task learning. Our main algorithmic contribution is a novel federated learning algorithm that is well suited for distributed computing environments such as edge computing over wireless networks. This algorithm is robust against model misspecification and numerical errors arising from limited computational resources, including processing time or wireless channel bandwidth. As our main technical contribution, we offer precise conditions on the local models as well as on their network structure such that our algorithm learns nearly optimal local models. Our analysis reveals an interesting interplay between the (information) geometry of the local models and the (cluster) geometry of their network.


Introduction
Many important application domains generate distributed collections of local datasets that are related via an intrinsic network structure [1]. Such a network structure can arise from spatio-temporal proximity, statistical dependencies or functional relations [2,3,4]. The network structure of data might also arise from a distributed computing infrastructure that generates the data [5]. In what follows we assume that the network structure of the data is known. The challenge of learning the network structure from raw data is beyond the scope of this paper (see Section 6 for future research directions).
Two timely application domains facing distributed networked data are the high-precision management of pandemics and the Internet of Things (IoT) [6,7]. Here, local datasets are generated either by smartphones and wearables or by IoT devices [8,9]. These local datasets are related via physical contact networks, social networks [10], co-morbidity networks [11], or communication networks [12].
Federated learning (FL) is a recent main thread of machine learning which studies collaborative learning from distributed data [13,14,15]. FL methods provide privacy protection as they do not require exchange of raw data that might be sensitive [16,17]. Moreover, FL offers robustness against malicious data perturbation due to its intrinsic averaging or aggregation over large collections of (mostly benign) datasets [18].

Related Work
Similar to [19], we frame FL as a multi-task learning problem that is solved using regularized empirical risk minimization (RERM). For each local dataset we obtain a separate learning task that amounts to finding an optimal choice for the weights of a local model. These individual learning tasks are coupled by a known network structure which we represent as a weighted undirected "empirical graph" (see Section 2). In contrast to [20], we do not use a probabilistic model for the network structure of the individual tasks. To capture the intrinsic cluster structure of networked data, we construct the regularization term using the concept of generalized total variation (GTV) which unifies widely-used measures for the variation of node attributes (see Section 2.2).
Our approach unifies and considerably extends some recent approaches to distributed FL including the network Lasso (nLasso) [21,22,23,24] and FL methods that use the graph Laplacian quadratic form as regularizer [19,25]. While the methods in [22,23,24,26,27] are limited to generalized linear models, the proposed FL algorithm can be combined with non-linear parametrized models including graphical Lasso or deep neural networks [28,29]. In contrast to [25], we do not require access to the entire networked data but only a small subset of local datasets.
The main algorithmic device behind our method is a primal-dual approach to solving large-scale convex optimization problems [30,31,32]. We obtain a novel networked FL algorithm as the direct application of this primal-dual approach to GTV minimization (see Section 3). This generic primal-dual approach is appealing for FL applications due to its scalability and robustness against modelling errors or faults in the computational infrastructure [33,34,35,36]. The iterations of primal-dual methods typically involve smaller (atomic) optimization problems. Given finite computational resources, these atomic optimization problems can only be solved up to some non-zero optimization error. The effect of inexact updates in primal-dual methods has been analyzed recently [37].
A large body of existing work has studied computational and statistical aspects of RERM in multi-task learning (see, e.g., [38,39,40,41]). Investigations of statistical properties of RERM have mainly focused on (group-) sparsity as the regularization principle [42,43,44]. However, we are not aware of studies about the interplay between estimation error and the network cluster structure in the context of FL. We close this gap by providing a precise characterization of network structures and available data such that RERM is successful (see Section 4).
Our approach relies on the clustering assumption that local datasets that form a network cluster have similar statistical properties and, in turn, similar optimizers for model weights. What sets our approach apart from existing methods for clustered FL [45,46] is that we exploit a known network structure to pool local datasets. This pooling is guided by the network structure and aimed at obtaining sufficiently large training sets for learning the local model weights.

Contribution
We propose and study a novel FL algorithm that learns separate (parametrized) models for each local dataset within a collection of networked data. The networked data is represented by a given undirected empirical graph whose edges connect (statistically) similar datasets. Our method exploits the cluster structure of the empirical graph to adaptively pool datapoints for learning model weights.
This paper provides three main contributions.
• As our main conceptual contribution, we show that GTV minimization is a useful framework for federated multi-task learning in networked data. The GTV minimization framework unifies and significantly extends the nLasso and the MOCHA method.
• Our main methodological contribution is a novel family of distributed FL algorithms (see Algorithm 1). This family is obtained by solving GTV minimization with a primal-dual method for structured optimization problems [34]. Each member of this family corresponds to a particular choice for the local models and for the penalty function used to define the GTV. By exploiting the network structure of data, our algorithm can also handle partially observed datasets, which is relevant for semi-supervised learning settings.
• Our main analytical contribution is an upper bound on the estimation error incurred by GTV-based RERM. This upper bound reveals sufficient conditions on the local datasets and network structure such that our method achieves the performance of an oracle-based method that perfectly knows the true cluster structure of the data network.

Figure 1: We represent networked data and corresponding models using an undirected empirical graph G = (V, E). Each node i ∈ V of the graph carries a local dataset X^(i) and model weights w^(i) which are scored using a local loss function L(X^(i); w^(i)). Two nodes are connected by a weighted edge {i, j} if they carry similar datasets. The amount of similarity is encoded in an edge weight A_{i,j} > 0 (indicated by varying thickness). We rely on a clustering assumption in that the optimal weight vectors for nodes in the same cluster C ⊆ V are in the proximity of a cluster-wise optimal weight vector w^(C). Here, we indicate a partition of the empirical graph into three disjoint clusters C_1, C_2, C_3. Note that our FL method does not require the (typically unknown) partition but rather learns it from the local datasets and the network structure of G.
Notation. The identity matrix of size n×n is denoted I_n, with the subscript omitted if the size is clear from context. The Euclidean norm of a vector w = (w_1, ..., w_n)^T is ∥w∥_2 := √(∑_{r=1}^n w_r^2) and the ℓ_1 norm is ∥w∥_1 := ∑_{r=1}^n |w_r|. Given a positive semi-definite matrix M, we define the norm ∥w∥_M := √(w^T M w). It will be convenient to use the notation (1/2τ) instead of (1/(2τ)). We will need the clipping function T^(γ)(w) := γ w/∥w∥_2 for ∥w∥_2 ≥ γ and T^(γ)(w) := w otherwise.
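The clipping function T^(γ) admits a direct implementation; a minimal sketch in Python (NumPy assumed; the function name is ours, not from the paper):

```python
import numpy as np

def clip(w, gamma):
    """Clipping function T^(gamma): rescale w to norm gamma if ||w||_2 >= gamma,
    otherwise return w unchanged."""
    norm = np.linalg.norm(w)
    if norm >= gamma:
        return gamma * w / norm
    return w
```

Geometrically, this is the Euclidean projection onto the ball of radius γ centred at the origin.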

Problem Formulation
Networked data can be represented elegantly using an undirected "empirical" graph G = (V, E), as illustrated in Figure 1. Every node i ∈ V of the empirical graph carries a separate local dataset X^(i). The local dataset X^(i) might be a single sensor measurement, an entire time series or even a whole collection of videos. Our approach is agnostic towards the details of the data representation, such as whether the datapoints are labelled or unlabelled. A key aspect of our method is that it accesses X^(i) only indirectly via a local loss function (see Section 2.1). Another key aspect of our approach is that we use local datasets only at nodes in a (small) training set M = {i_1, ..., i_M} ⊆ V. To compensate for not using local datasets outside M, we exploit the cluster structure of G. This is relevant for applications where accessing local datasets is computationally costly.
An undirected edge {i, j} ∈ E indicates that the corresponding local datasets X^(i) and X^(j) have similar statistical properties. The strength of the similarity is quantified by the edge weight A_{i,j} > 0. The neighbourhood of a node i ∈ V is N_i := {j ∈ V : {i, j} ∈ E}. It will be convenient to define the head and tail of an undirected edge {i, j} as e_+ := min{i, j} and e_− := max{i, j}, respectively.
In what follows, we will make use of two vector spaces that are naturally associated with an empirical graph G. The space R^{V×n} of networked weight vectors consists of maps w : V → R^n : i ↦ w^(i) that assign each node i ∈ V a vector w^(i) ∈ R^n. Another (in some sense dual) space R^{E×n} is obtained from all maps u : E → R^n : e ↦ u^(e) that assign each edge e ∈ E a vector u^(e) ∈ R^n. These two spaces are linked via the block-incidence matrix D ∈ R^{n|E|×n|V|}, whose n×n blocks are

    D_{e,i} = I for i = e_+, D_{e,i} = −I for i = e_−, and D_{e,i} = 0 otherwise.    (1)

Applying D to a given networked weight vector w ∈ R^{V×n} results in the networked dual weights u = Dw ∈ R^{E×n} with u^(e) = w^(e_+) − w^(e_−).
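The block-incidence matrix (1) can be built directly from an edge list; a minimal sketch in Python (NumPy assumed; the function name and edge representation are ours):

```python
import numpy as np

def block_incidence(edges, num_nodes, n):
    """Block-incidence matrix D of (1): for each edge e = {i, j}, place the
    n x n identity block +I at the head e_+ = min{i, j} and -I at the
    tail e_- = max{i, j}."""
    I = np.eye(n)
    D = np.zeros((n * len(edges), n * num_nodes))
    for k, (i, j) in enumerate(edges):
        head, tail = min(i, j), max(i, j)
        D[k * n:(k + 1) * n, head * n:(head + 1) * n] = I
        D[k * n:(k + 1) * n, tail * n:(tail + 1) * n] = -I
    return D
```

Applying D to the stacked weight vectors then yields the edge-wise differences u^(e) = w^(e_+) − w^(e_−).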

Networked Models
Our goal is to learn, for each node i ∈ V, the weights (parameters) w^(i) of a local model that is tailored to the local dataset X^(i). The usefulness of a particular choice for the weights w^(i) is measured by a loss function L(X^(i); w^(i)). In principle, our approach allows for the use of an arbitrary loss function. However, unless otherwise noted, we tacitly assume each loss function L(X^(i); w^(i)) to be a convex function of the weight vector w^(i) [47]. Specific choices for the local loss function, corresponding to some well-known parametric models, are discussed in Appendix B.
We define networked (model) weights as a map w : V → R^n : i ↦ w^(i) that assigns each node i ∈ V a local weight vector w^(i). The space of all such networked weights is denoted R^{V×n}. To evaluate the quality of given networked weights w ∈ R^{V×n}, we compute the training error

    f(w) := ∑_{i∈M} L^(i)(w^(i)).    (2)

Here we used the shorthand L^(i)(w^(i)) := L(X^(i); w^(i)). Unless noted otherwise, we assume that L^(i)(·), for i ∈ V, are convex functions. Additional restrictions are placed on L^(i)(·) in Section 4 to facilitate the analysis of the statistical properties of our proposed algorithm. This algorithm aims at learning networked weights w such that the training error (2) is small.
The criterion (2) alone is insufficient to guide the learning of the weights w (i) since it completely ignores the weights w (i) at nodes i ∈ V\M outside the training set. We therefore need to impose some additional structure on the collection of weight vectors w (i) , for i ∈ V. To this end, we exploit the network structure of the empirical graph G. We assume that the empirical graph can be decomposed into few tight-knit clusters. We then require the weight vectors w (i) to be approximately constant for all nodes i ∈ V belonging to the same cluster.

Generalized Total Variation Minimization
To learn the weights w^(i) of the local models, one for each local dataset X^(i), we assume that local datasets forming a tight-knit subset (or cluster) C ⊆ V have similar statistical properties. It is therefore sensible to enforce the weight vectors w^(i) to be approximately constant over all nodes i ∈ C of the same cluster. In effect, we learn a cluster-wise weight vector w^(C) by pooling the local datasets X^(i) with i ∈ C into an intermediate training set.
Let us define the variation of networked weights w ∈ R^{V×n} using the map u : E → R^n : e ↦ u^(e) which assigns each edge e ∈ E the difference u^(e) := w^(e_+) − w^(e_−). Using the block-incidence matrix (1), we can write this more compactly as u = Dw.
To enforce similar weight vectors w^(i) ≈ w^(j) for nodes i, j ∈ V in the same cluster, we require them to have a small generalized total variation (GTV)

    ∥w∥_GTV := ∑_{{i,j}∈E} A_{i,j} φ(w^(i) − w^(j)).    (3)

More generally, we define the GTV for an arbitrary edge set S ⊆ E as ∥w∥_S := ∑_{e∈S} A_e φ(u^(e)). The GTV (3) uses a non-negative penalty function φ(v) ∈ R_+ which is typically increasing with the norm ∥v∥ of the argument v ∈ R^n. Our approach allows for different choices of the penalty function φ(·). Two specific choices are φ(v) := ∥v∥_2, which is used by nLasso [21], and φ(v) := (1/2)∥v∥_2^2, which is used by MOCHA [19]. Another recent FL method for networked data uses the choice φ(v) := ∥v∥_1 [48].
Enforcing a small GTV (3) requires the weights w^(i) to change only over few edges e ∈ E with small weight A_e. To balance the training error (2) against the GTV (3), we solve

    ŵ ∈ argmin_{w∈R^{V×n}} f(w) + λ ∥w∥_GTV.    (4)

Note that GTV minimization (4) is an instance of RERM with the GTV as the regularizer. The empirical risk incurred by networked weights w ∈ R^{V×n} is measured by the training error f(w) (2).
We will solve (4) using a primal-dual method, leading to our FL algorithm in Section 3. The statistical properties of the solutions of (4) are the subject of Section 4.
The regularization parameter λ > 0 in (4) allows us to trade a small GTV ∥ŵ∥_GTV against a small training error f(ŵ) (2). The choice of λ can be guided by cross-validation [49] or by our analysis of the solutions of (4) in Section 4. Loosely speaking, increasing λ results in the solutions of (4) becoming increasingly clustered, i.e., the weight vectors ŵ^(i) become constant over increasingly larger subsets of nodes. Choosing λ larger than some critical value, which depends on the local datasets and the network structure of G, results in weight vectors ŵ^(i) that are identical for all nodes i ∈ V.
Different choices for the penalty function φ(·) offer different trade-offs between the computational complexity and statistical properties of the resulting FL algorithm. Non-smooth TV minimization (φ(u) = ∥u∥_2) is computationally more challenging than methods using the graph Laplacian quadratic form (φ(u) = ∥u∥_2^2). Statistically, non-smooth TV minimization can be more accurate on network structures that are challenging for approaches relying on the graph Laplacian quadratic form [50,51].
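To make the penalty choices concrete, the GTV (3) can be computed directly from an edge list; a minimal Python sketch (NumPy assumed; function and variable names are ours):

```python
import numpy as np

def gtv(w, edges, weights, phi):
    """Generalized total variation (3): sum of A_e * phi(w^(i) - w^(j))
    over all edges e = {i, j}. Here w is a (num_nodes, n) array."""
    return sum(A * phi(w[i] - w[j]) for (i, j), A in zip(edges, weights))

# The penalty choices discussed above:
phi_nlasso = lambda v: np.linalg.norm(v)            # phi(v) = ||v||_2   (nLasso)
phi_mocha  = lambda v: 0.5 * np.linalg.norm(v)**2   # phi(v) = (1/2)||v||_2^2 (MOCHA)
phi_l1     = lambda v: np.sum(np.abs(v))            # phi(v) = ||v||_1
```

Swapping the penalty function is a one-line change, which mirrors how the GTV framework covers nLasso, MOCHA-style Laplacian regularization, and the ℓ_1 variant as special cases.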

Networked Federated Multi-Task Learning Algorithm
We now discuss our FL algorithm, which is obtained by applying an established primal-dual approach to GTV minimization (4). Let us first rewrite (4) more compactly (see (1) and (3)) as

    ŵ ∈ argmin_{w∈R^{V×n}} f(w) + g(Dw), where g(u) := λ ∑_{e∈E} A_e φ(u^(e)).    (5)

The objective function of the GTV minimization problem (5) consists of two components. The first component is the training error f(w) (2), which depends on the networked weights w ∈ R^{V×n}. The second component is the scaled GTV g(u) (3), which depends on the dual weights u = Dw with u^(e) = w^(e_+) − w^(e_−), for e ∈ E.
We solve (5) jointly with its dual problem

    û ∈ argmax_{u∈R^{E×n}} −g*(u) − f*(−D^T u).    (6)

The objective function in (6) is composed of the convex conjugates [47] g*(u) := sup_{z∈R^{E×n}} ∑_{e∈E} (u^(e))^T z^(e) − g(z) and f*(w) := sup_{z∈R^{V×n}} ∑_{i∈V} (w^(i))^T z^(i) − f(z). The domain of the dual problem (6) is the dual space R^{E×n} of maps u : E → R^n that assign a separate dual vector u^(e) to each edge e ∈ E.
The duality between (5) and (6) is strong [47], i.e.,

    f(ŵ) + g(Dŵ) = −g*(û) − f*(−D^T û).    (9)

The strong duality (9) allows us to bound the sub-optimality of given networked weights. Indeed, for any given dual variable ũ, the objective function value −g*(ũ) − f*(−D^T ũ) is a lower bound on the optimal value of (5).
We obtain Algorithm 1 by applying a proximal point algorithm to jointly solve (5) and (6). The details of this application are provided in Appendix C. The key components of Algorithm 1 are the node-wise and edge-wise updates in steps 6 and 10. The results of these local computations are then propagated to the adjacent edges and nodes. Let us next discuss the node-wise and edge-wise updates in detail.
The node-wise updates in step 6 of Algorithm 1 use the primal update operator

    v ↦ argmin_{w^(i)∈R^n} L^(i)(w^(i)) + (1/2τ) ∥w^(i) − v∥_2^2.    (10)

The primal update operator (10) solves a regularized variant of the local loss L^(i)(w^(i)). The regularizer is the squared Euclidean distance to the argument v of the operator and will be used to enforce similar weight vectors at well-connected nodes (belonging to the same cluster). Note that the primal update operator (10) depends on the choice of loss function. We discuss specific choices for loss functions and corresponding primal update operators in Appendix B.
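For concreteness, consider the squared-error loss L^(i)(w) = ∥y − Xw∥_2^2 of a linear model (one of the choices discussed in Appendix B). The primal update (10) then has a closed-form solution, obtained by setting the gradient 2Xᵀ(Xw − y) + (1/τ)(w − v) to zero; a sketch under this assumed loss (the function name is ours):

```python
import numpy as np

def primal_update(X, y, v, tau):
    """Primal update operator (10) for the squared-error loss
    L(w) = ||y - Xw||_2^2: solves min_w ||y - Xw||_2^2 + (1/(2*tau))*||w - v||_2^2
    via the normal equations (2 X^T X + (1/tau) I) w = 2 X^T y + (1/tau) v."""
    n = X.shape[1]
    A = 2 * X.T @ X + np.eye(n) / tau
    b = 2 * X.T @ y + v / tau
    return np.linalg.solve(A, b)
```

As τ → ∞ the update approaches the unregularized least-squares solution, while small τ keeps the result close to the propagated vector v.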
Algorithm 1 involves dual updates, for each edge e ∈ E, using the dual update operator

    u^(e) ↦ argmin_{v∈R^n} (λ A_e φ)*(v) + (1/2σ) ∥v − u^(e)∥_2^2.    (11)

Here, we used the convex conjugate φ*(v) := sup_{z∈R^n} v^T z − φ(z) of the GTV penalty function φ(v) (see (3)). For the choice φ(v) = ∥v∥_2, the conjugate (λA_e φ)* is the indicator of the ball of radius λA_e, and (11) reduces to the clipping function T^(λA_e) from Section 1. In contrast to the primal update operator (10), which involves the local loss function L^(i)(w^(i)) (and thereby the local dataset X^(i)), the dual update operator (11) depends only on the choice of penalty function φ(·). The dual update step (11) can be carried out in parallel for each edge e ∈ E.
The results of the parallel primal and dual updates in steps 6 and 10 are spread to the neighbouring (incident) edges and nodes in steps 3 and 9. For data whose network structure has a bounded maximum node degree, the computational complexity of Algorithm 1 scales linearly with the number of local datasets.
A computationally appealing property of Algorithm 1 is its robustness against errors occurring during the updates in steps 6 and 10. This is important for applications where the update operators (10) and (11) can be evaluated only approximately or where their results have to be forwarded over imperfect wireless links. We can ensure convergence to an approximate solution of GTV minimization (5) even if we implement the primal update operator (10) only numerically (with sufficient accuracy). Let us denote the perturbed update, e.g., obtained by some numerical method for solving (10), by w̃_{k+1} and the exact update by w_{k+1}. Then, Algorithm 1 converges to a solution of (5) as long as ∑_{k=1}^∞ ∥w̃_{k+1} − w_{k+1}∥ < ∞ (see [34, Sec. 5]). Algorithm 1 combines the information contained in the local datasets with their network structure to iteratively improve the weight vectors ŵ^(i)_k for each node i ∈ V.
Step 6 adapts the current weight vectors ŵ^(i)_k to better fit the labeled local datasets X^(i) for i ∈ M. These updates are then propagated beyond the training set via steps 3 and 9 of Algorithm 1.
We can use different stopping criteria for Algorithm 1 such as a fixed number of iterations or requiring a minimum decrease of the objective function (5) during each iteration. We can also mix criteria and stop iterating either after a maximum number of iterations has been reached or when the reduction of the objective function falls below a prescribed threshold.
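The overall structure of Algorithm 1 can be sketched as a primal-dual loop over node-wise and edge-wise updates. The following simplified Python sketch is ours, not the paper's pseudocode: it assumes the nLasso penalty φ(v) = ∥v∥_2 (so the dual update reduces to clipping), and the step sizes, extrapolation, and fixed iteration count are illustrative choices:

```python
import numpy as np

def clip(u, gamma):
    """Clipping T^(gamma): projection onto the ball of radius gamma."""
    norm = np.linalg.norm(u)
    return gamma * u / norm if norm >= gamma else u

def networked_fl_sketch(loss_prox, edges, weights, num_nodes, n, lam,
                        num_iter=500, tau=0.5, rho=0.5):
    """Simplified primal-dual sketch for GTV minimization with phi(v) = ||v||_2.
    loss_prox[i] implements the primal update (10) of node i; None marks nodes
    outside the training set M (zero loss, so the update is the identity)."""
    w = np.zeros((num_nodes, n))
    u = np.zeros((len(edges), n))
    for _ in range(num_iter):
        # node-wise primal updates: prox of the local loss at w - tau * D^T u
        v = w.copy()
        for k, (i, j) in enumerate(edges):
            head, tail = min(i, j), max(i, j)
            v[head] -= tau * u[k]
            v[tail] += tau * u[k]
        w_new = np.array([loss_prox[i](v[i], tau) if loss_prox[i] is not None
                          else v[i] for i in range(num_nodes)])
        # edge-wise dual updates: ascent step on extrapolated weights,
        # then clipping to enforce the dual constraint ||u^(e)||_2 <= lam * A_e
        for k, (i, j) in enumerate(edges):
            head, tail = min(i, j), max(i, j)
            diff = (2 * w_new[head] - w[head]) - (2 * w_new[tail] - w[tail])
            u[k] = clip(u[k] + rho * diff, lam * weights[k])
        w = w_new
    return w
```

With a sufficiently large λ, the weight vectors of connected nodes are pulled together, which illustrates the pooling effect of the GTV regularizer.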

Statistical Aspects
This section studies the statistical properties of Algorithm 1 by analyzing the exact solutions to (5). Our main analytical contribution is an upper bound on the deviation between solutions of (5) and weight vectors obtained by pooling local datasets using the true underlying cluster structure of G. This upper bound applies only under certain conditions on the local loss functions and the cluster structure of G. These conditions will be formalized as three explicit assumptions.
The first assumption restricts our analysis to loss functions L (i) (·) that are convex and differentiable. This is a rather mild restriction as it includes many popular choices for parametrized models and loss functions (see Appendix B).

Assumption 1 (Smooth and Convex Loss).
For each node i ∈ V, the local loss L^(i)(v) is convex and smooth with a continuous Hessian ∇²L^(i)(v), whose entries are ∂²L^(i)(v)/(∂v_r ∂v_s). There is a positive constant C > 0 such that the eigenvalues of the Hessian ∇²L^(i) are lower bounded as

    ∇²L^(i)(v) ⪰ 2C I for all v ∈ R^n.    (12)

The factor 2 in (12) is only for notational convenience. For the special case of linear regression models and squared error loss (see Appendix B), the condition (12) is satisfied with high probability if the features of the datapoints are not too strongly correlated.
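For the linear regression example just mentioned, condition (12) can be checked directly: the Hessian of the squared-error loss L(w) = ∥y − Xw∥_2^2 is 2XᵀX, so (12) holds with C equal to the smallest eigenvalue of XᵀX. A sketch (NumPy assumed; the function name is ours):

```python
import numpy as np

def hessian_lower_bound(X):
    """Largest C such that (12) holds for the squared-error loss
    L(w) = ||y - Xw||_2^2, whose Hessian is 2 * X^T X."""
    # eigvalsh returns eigenvalues in ascending order
    return float(np.linalg.eigvalsh(X.T @ X)[0])
```

Strongly correlated features make XᵀX nearly singular, driving C toward zero, which matches the remark above.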
The idea behind using GTV minimization (5) is that the minimizers of the loss functions L^(i)(·) of all nodes i ∈ C in the same cluster are near a common cluster-wise optimal weight vector

    w̄^(C) := argmin_{v∈R^n} ∑_{i∈C} L^(i)(v).    (13)

Thus, the collection of local loss functions is approximately minimized simultaneously by the piece-wise constant weight vectors

    w̄^(i) := w̄^(C) for all i ∈ C and each cluster C of the partition.    (14)

Note that in practice we cannot use (13) to learn the model weights w^(i) since we do not know the true cluster structure of the empirical graph G. Only if we had an oracle that provides a true cluster C could we actually compute the optimal weight vector (13) for all nodes i ∈ C. The role of (13) is to provide a (practically infeasible) benchmark for assessing the statistical properties of GTV minimization (5). We provide conditions on the (cluster structure of the) empirical graph G such that solutions of (5) are close to the oracle-based optimal weights (13).
Our main analytical result below is an upper bound on the deviation between the oracle-based weights w̄ (see (14)) and the solution ŵ of GTV minimization (5). This upper bound depends on the discrepancy between the node-wise (local) minimizers of L^(i)(·), for i ∈ C, and the cluster-wise optimal weight vector w̄^(C), which we measure via the gradient norm ∥∇L^(i)(w̄^(C))∥_2. We now formalize our clustering assumption by requiring this norm to be uniformly bounded over all nodes of the empirical graph.
Assumption 2 (Clustering). There is a constant U and a partition P = {C_1, ..., C_{|P|}} of the empirical graph G into disjoint clusters such that, for each cluster C_l ∈ P,

    ∥∇L^(i)(w̄^(C_l))∥_2 ≤ U for all nodes i ∈ C_l.    (15)

Here, we used the cluster-wise optimal weight vector (13) for cluster C_l.
In principle, our analysis applies to an arbitrary choice for the partition P. However, the analysis is most useful if the partition has a small boundary ∂P, i.e., the set of edges connecting nodes in different clusters of P. In particular, we focus on partitions such that the total boundary weight ∑_{e∈∂P} A_e is small. We emphasize that the partition P in Assumption 2 is only required for the analysis of Algorithm 1; Algorithm 1 itself does not require the specification of P.
Our third assumption requires that the GTV of (approximately) piece-wise constant networked weights (see Assumption 2) cannot significantly exceed the size of the weights on the training set M. This assumption also restricts our analysis to GTV penalty functions φ(·) that are norms [53]. Assumption 3 (Sampling Condition). Consider the GTV minimization (5) with the penalty function φ(·) in (3) being a norm and the local loss functions satisfying (12) with C > 1. There are positive constants L > 1 and K > 0 such that the sampling condition (17) holds. The norm ∥w∥_M := √(∑_{i∈M} ∥w^(i)∥_2^2) measures the size of the weights w^(i) on the training set M.
Assumption 3 ensures that piece-wise constant weight vectors with sufficiently large GTV cannot be arbitrarily small on M. The condition (17) can be verified via the existence of sufficiently large network flows between the nodes in the training set M and the boundary edges ∂P that connect nodes in different clusters [50,54]. Using the flow condition [50, Lemma 6], it has been shown that (17) is satisfied with high probability for empirical graphs obtained as realizations of a stochastic block model (SBM) for a certain range of SBM parameters [55].
Theorem 1. Consider local loss functions L^(i)(·) that satisfy Assumption 1 and are approximately minimized by piece-wise constant weights w̄ according to Assumption 2. We can access the local loss functions only for nodes in a training set M ⊆ V that satisfies Assumption 3. Then the estimation error incurred by solutions ŵ of GTV minimization (5) with λ = UL√M/K satisfies the upper bound (18).

Theorem 1 uses the specific choice λ = UL√M/K for the GTV regularization parameter λ in (4). This provides a means for guiding the choice of λ based on (estimates of) the constants U, L, K and the size M of the training set M. If these constants cannot be determined reliably, it might be more convenient to tune λ via cross-validation techniques [49].
The bound (18) reveals an interesting interplay between the geometry of the loss functions L^(i), via the constants C and U in (12) and (15), and the cluster geometry of the empirical graph G, via the constants L and K in Assumption 3. It can be shown that (17) holds with a prescribed constant L if sufficiently large flows can be routed from the cluster boundary ∂P to the training set. According to (18), ensuring (17) with a larger L allows us to tolerate a larger constant U in (15), which means that we can tolerate a larger discrepancy between the true minimizers of L^(i) and the piece-wise constant weights (14).

Numerical Experiments
This section reports the results of some illustrative numerical experiments that verify the performance of Algorithm 1. We provide the code to reproduce these experiments at https://github.com/sahelyiyi/FederatedLearning. The experiments have been carried out on a standard desktop computer and revolve around synthetic datasets whose empirical graph is a realization of an SBM [56]. The nodes V are partitioned into two equal-sized clusters C_1, C_2. Two nodes in the same cluster are connected by a (unit weight) edge with probability p_in. Two nodes in different clusters are connected by a (unit weight) edge with probability p_out (typically p_out ≪ p_in). Each node i ∈ V of G holds a local dataset X^(i) consisting of m_i = 5 datapoints (x^(i,1), y^(i,1)), ..., (x^(i,m_i), y^(i,m_i)) with feature vectors x^(i,r) ∈ R^2 and scalar labels y^(i,r), for r = 1, ..., m_i. The feature vectors are i.i.d. realizations (draws) of a standard multivariate normal distribution N(0, I_{2×2}). The labels of the datapoints are generated by the linear model

    y^(i,r) = (x^(i,r))^T w̄^(i) + σ ε^(i,r).    (19)

Here, we used i.i.d. realizations ε^(i,r) of a standard Gaussian random variable N(0, 1). The true underlying weight vector w̄^(i) is piece-wise constant over the partition P = {C_1, C_2}.
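The data generation described above can be sketched in a few lines of Python; this is our own minimal sketch, not the repository code, and the cluster weight vectors and random seed are illustrative choices:

```python
import numpy as np

def generate_networked_data(num_nodes=20, n=2, m=5, p_in=0.5, p_out=0.01,
                            sigma=0.0, seed=0):
    """Two-cluster SBM with node-wise linear-regression datasets, cf. (19)."""
    rng = np.random.default_rng(seed)
    half = num_nodes // 2
    # true weights, piece-wise constant over the partition P = {C_1, C_2}
    w_true = [np.array([1.0, -1.0]) if i < half else np.array([-1.0, 1.0])
              for i in range(num_nodes)]
    edges = []
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            p = p_in if (i < half) == (j < half) else p_out
            if rng.random() < p:
                edges.append((i, j))
    data = []
    for i in range(num_nodes):
        X = rng.standard_normal((m, n))                      # features ~ N(0, I)
        y = X @ w_true[i] + sigma * rng.standard_normal(m)   # labels via (19)
        data.append((X, y))
    return data, edges, w_true
```

In the noiseless case, pooling all datasets of one cluster and solving least squares recovers the cluster-wise optimal weights (13), which is the oracle benchmark against which Algorithm 1 is compared.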
We assume to have access only to the local datasets X^(i) in a training set M of relative size ρ := |M|/|V|. We choose the training set uniformly at random from all subsets of V with prescribed size ρ|V| (using rounding when necessary). Given the local datasets in the training set, we learn the weights ŵ^(i) using Algorithm 1. We used a fixed number of iterations (3000 for Figure 2, 2000 for Figure 3) and the empirically chosen GTV regularization parameter λ = 10^{-2} in (5).
We compute MSE := (1/|V|) ∑_{i∈V} ∥ŵ^(i) − w̄^(i)∥_2^2 for varying relative size ρ of the training set, noise strength σ (see (19)) and inter-cluster edge probability p_out. In all experiments, the intra-cluster edge probability is fixed to p_in = 1/2. The curves and error bars represent the average and standard deviation of the MSE values obtained from 10 simulation runs for each choice of ρ, p_out and σ. Figure 2 depicts the results for the noiseless case σ = 0 in (19). Figure 3 depicts how the MSE varies with the noise level σ in (19) for a fixed relative training set size ρ = 0.6.

Conclusion
Using a multi-task learning perspective, we have proposed and studied an efficient method for FL from networks of local datasets. Each local dataset gives rise to a separate learning task, which is to learn the weights of a tailored ("personalized") model for that local dataset. These individual tasks are coupled via a known network structure. We formalize this multi-task learning problem as RERM with the GTV of the model weights as regularizer. We obtain a computationally and statistically appealing FL algorithm by solving this RERM problem with an established primal-dual method for large-scale optimization. Future research directions include the joint learning of the network structure and the networked model weights. We also aim at a more fine-grained convergence analysis of our FL algorithm that takes into account the cluster structure of the empirical graph.