Scaling Distributed Machine Learning with In-Network Aggregation

Distributed ML training is increasingly a network-bound workload. The work considers data-parallel training, where the input data is partitioned across workers, and focuses exclusively on the widely used distributed synchronous SGD. The challenge is that data-parallel SGD requires computing the sum of the model updates from all workers after every iteration. Each model update has as many parameters as the model itself, and models are growing exponentially; today they exceed 32 GB. The performance bottleneck in distributed training is therefore shifting from compute to communication: model updates keep growing, while accelerator performance improves faster than network bandwidth.
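A minimal sketch of where this per-iteration communication happens, assuming a PyTorch-style setup with torch.distributed already initialized (the model, optimizer, and batch here are placeholders):

```python
import torch.distributed as dist

def train_step(model, optimizer, batch, world_size):
    # Each worker computes gradients on its own shard of the input data.
    loss = model(batch).mean()   # placeholder forward pass / loss
    loss.backward()

    # Synchronous SGD: sum (then average) every gradient tensor across all
    # workers before applying the update. This is the communication phase
    # whose cost grows with model size.
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size

    optimizer.step()
    optimizer.zero_grad()
```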

Today’s ML toolkits implement this communication phase in one of two ways: with a parameter server (PS) architecture, in which workers push their updates to one or more servers that aggregate them and send back the result, or with all-reduce collectives, in which workers exchange and sum the updates directly among themselves.
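As a rough illustration of the PS pattern only (not any toolkit's actual implementation), one worker rank can play the role of a single, non-sharded parameter server using reduce and broadcast; PS_RANK and exchange_updates_ps are hypothetical names:

```python
import torch.distributed as dist

PS_RANK = 0  # hypothetical: rank 0 plays the role of a single, non-sharded PS

def exchange_updates_ps(grads, world_size):
    # Parameter-server style exchange: every worker pushes its gradients to
    # the PS, which sums them; all workers then pull the aggregated result.
    for g in grads:
        dist.reduce(g, dst=PS_RANK, op=dist.ReduceOp.SUM)  # push + aggregate at PS
        dist.broadcast(g, src=PS_RANK)                     # pull the aggregate back
        g /= world_size
```

The other option, all-reduce, corresponds to the dist.all_reduce call in the earlier sketch.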

Even at 100 Gbps, DeepLight, LSTM, NCF, and BERT spend 17% of their batch time in communication. As GPUs become faster, network performance will become a serious issue even at 100 Gbps.

The work proposes an alternative approach to model update exchange for ML workloads using in-network aggregation (INA). Workers send their model updates over the network, where an aggregation primitive sums the updates and distributes only the aggregated result.
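A toy, host-side model of that primitive may help fix the idea; the class and method names are invented for illustration, and the real primitive runs in the switch dataplane rather than in Python:

```python
import numpy as np

class InNetworkAggregator:
    """Toy model of the aggregation primitive: collect one chunk of update
    values from every worker, sum them, and hand back a single aggregated
    chunk. Purely illustrative; the real primitive runs in the switch."""

    def __init__(self, num_workers, chunk_size):
        self.num_workers = num_workers
        self.total = np.zeros(chunk_size)
        self.seen = set()

    def receive(self, worker_id, chunk):
        if worker_id not in self.seen:      # count each worker's contribution once
            self.total += chunk
            self.seen.add(worker_id)
        if len(self.seen) == self.num_workers:
            return self.total.copy()        # aggregate is ready to multicast back
        return None
```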

The fundamental advantage of INA over PS and all-reduce is that it avoids end-host processing for aggregation: each worker sends a single stream of model updates and receives a single aggregated stream back, with the sum computed directly on the data path. This sidesteps the chain of communication steps that all-reduce requires and the dedicated aggregation machines of the PS approach, letting aggregation complete at sub-RTT latency.
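A back-of-the-envelope calculation, under simplifying assumptions (full-precision updates, a non-sharded PS, a bandwidth-optimal ring all-reduce), suggests the win comes mainly from fewer dependent steps and no end-host summation rather than from raw per-worker volume:

```python
def per_worker_traffic(model_bytes, n_workers):
    # Per-iteration bytes each worker sends + receives, under simplifying
    # assumptions (no compression, full-precision updates).
    ps = 2 * model_bytes                                   # push update, pull aggregate
    ring_allreduce = 2 * model_bytes * (n_workers - 1) / n_workers
    ina = 2 * model_bytes                                  # send once, receive aggregate once
    return ps, ring_allreduce, ina

# Example: a 1 GB model update across 8 workers. All three schemes move
# roughly 2 GB per worker, but ring all-reduce needs 2*(n-1) = 14 dependent
# steps, and a non-sharded PS funnels n * 1 GB through one server link; both
# PS and all-reduce also spend end-host cycles on summation, whereas INA
# sums on the data path in a single round trip through the switch.
print(per_worker_traffic(1 << 30, 8))
```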

Even though techniques like gradient compression can reduce the data volume of model updates, they are not necessarily superior to INA: compression itself costs compute, and its benefit depends on the compression ratio that can be achieved without hurting accuracy.

The proposed system, SwitchML, implements the aggregation primitive in a programmable dataplane switch. The authors solve the challenges posed by the switches' limited computation and storage capabilities. The system must also tolerate packet loss, since DNN training jobs are long-running.
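To illustrate how an update far larger than switch memory can still be aggregated, here is a toy sketch of slot-pool streaming aggregation; the packet format, sizes, and names are illustrative rather than SwitchML's actual design:

```python
import numpy as np

NUM_SLOTS = 4   # illustrative pool size; real switch memory is similarly tiny
CHUNK = 64      # update values carried per packet / held per slot

class SlotPool:
    """Toy switch-side state for streaming aggregation: a small pool of slots
    is reused over and over, so a model update far larger than switch memory
    can be aggregated chunk by chunk."""

    def __init__(self, num_workers):
        self.num_workers = num_workers
        self.sums = np.zeros((NUM_SLOTS, CHUNK), dtype=np.int64)
        self.count = [0] * NUM_SLOTS

    def on_packet(self, slot, values):
        self.sums[slot] += values
        self.count[slot] += 1
        if self.count[slot] == self.num_workers:
            # Every worker contributed this chunk: emit the aggregate to all
            # workers and recycle the slot for the next chunk in the stream.
            result = self.sums[slot].copy()
            self.sums[slot] = 0
            self.count[slot] = 0
            return result
        return None
```

In this sketch a worker sends the next chunk for a slot only after receiving the aggregate for the chunk it previously placed there, which paces the stream and bounds switch state; handling of lost or retransmitted packets is omitted.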

SwitchML uses the following techniques: a combined switch-host architecture that splits the aggregation work between the switch dataplane and the end hosts; pool-based streaming aggregation, which reuses a small pool of aggregation slots so that model updates far larger than switch memory can be processed chunk by chunk (sketched above); quantized fixed-point integer arithmetic, since the switch dataplane cannot operate on floating-point values (sketched below); and a failure-recovery protocol that tolerates packet loss without slowing down the common case.
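The quantization step can be sketched as follows, assuming a single global scaling factor (SwitchML chooses scaling adaptively, which this toy does not model):

```python
import numpy as np

SCALE = 1 << 20  # illustrative fixed-point scaling factor; adaptive scaling
                 # is not modeled here

def to_fixed_point(grad):
    # Workers convert float32 gradients to integers before sending, because
    # the switch dataplane only supports integer arithmetic.
    return np.round(grad * SCALE).astype(np.int64)

def from_fixed_point(total, num_workers):
    # Convert the integer sum back to a floating-point average on the worker.
    return total.astype(np.float32) / (SCALE * num_workers)

# Round trip: quantize on each worker, sum (conceptually inside the switch),
# then dequantize the aggregate.
grads = [np.random.randn(8).astype(np.float32) for _ in range(4)]
total = sum(to_fixed_point(g) for g in grads)
avg = from_fixed_point(total, num_workers=4)
```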

They benchmark SwitchML against state-of-the-art all-reduce communication libraries, Gloo and NCCL. SwitchML outperforms NCCL in aggregated tensor elements per second (ATE/s). They also find that SwitchML provides significant training batch-processing speedups compared to the NCCL backends of PyTorch and TensorFlow. At 100 Gbps, the four network-bottlenecked DNNs (DeepLight, LSTM, NCF, and BERT) see speedups of up to 2.27x over NCCL-RDMA and 5.55x over NCCL-TCP.

Strengths

Future Work