Every minute spent training a deep learning model is a minute not spent on something else, and in today’s fast-paced world of research, that minute is worth a lot. Facebook published a paper this morning detailing its own approach to this problem. The company says it has managed to reduce the training time of a ResNet-50 deep learning model on ImageNet from 29 hours to one hour.
Facebook managed to reduce training time so dramatically by distributing training in larger “minibatches” across a greater number of GPUs. In the previous benchmark case, batches of 256 images were spread across eight GPUs. But today’s work involves batch sizes of 8,192 images distributed across 256 GPUs.
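To make the arithmetic concrete, here is a minimal sketch of how the two setups compare per GPU. It assumes each minibatch is split evenly across the GPUs, which the article does not state outright, and the function name `per_gpu_batch` is made up for illustration.

```python
def per_gpu_batch(minibatch_size: int, num_gpus: int) -> int:
    """Images each GPU handles per step, assuming an even split (an assumption)."""
    if num_gpus <= 0:
        raise ValueError("num_gpus must be positive")
    if minibatch_size % num_gpus != 0:
        raise ValueError("minibatch does not divide evenly across GPUs")
    return minibatch_size // num_gpus


# Previous benchmark: 256 images across 8 GPUs.
baseline = per_gpu_batch(256, 8)
# Facebook's run: 8,192 images across 256 GPUs.
scaled = per_gpu_batch(8192, 256)

print(f"baseline: {baseline} images per GPU")
print(f"scaled:   {scaled} images per GPU")
print(f"minibatch grew by a factor of {8192 // 256}")
```

Under that assumption, each GPU does the same amount of work per step in both setups. What changes is the number of GPUs and, with it, the total minibatch size.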
Most people don’t have 256 GPUs lying around, but big tech companies and well-funded research groups do. Being able to scale training across so many GPUs to reduce training time, without a dramatic loss in accuracy, is a big deal.
To overcome some of the difficulties that previously made large batch sizes infeasible, the team used lower learning rates in the early stages of the training process. Without getting too lost in the details, the ResNet-50 model is trained with stochastic gradient descent.
One of the key variables in stochastic gradient descent is the learning rate, the degree by which weights change at each step of the training process. How this variable is adjusted as minibatch size changes is the key to optimizing effectively.
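One way to picture the relationship between learning rate and minibatch size is the sketch below. It assumes the learning rate scales linearly with minibatch size and ramps up gradually over the first few epochs. The base rate of 0.1 per 256 images and the five warmup epochs are assumed for illustration rather than taken from this article, and the function name `learning_rate` is hypothetical. Any later decay of the rate is left out.

```python
def learning_rate(
    epoch: float,
    minibatch_size: int,
    base_lr: float = 0.1,        # assumed rate for the reference minibatch
    base_batch: int = 256,       # reference minibatch size
    warmup_epochs: float = 5.0,  # assumed length of the warmup phase
) -> float:
    """Learning rate at a given (possibly fractional) epoch.

    Assumes linear scaling with minibatch size plus a linear warmup
    from base_lr up to the scaled target rate.
    """
    target = base_lr * minibatch_size / base_batch
    if epoch >= warmup_epochs:
        return target
    # Early in training, use a lower rate and ramp up linearly.
    progress = max(epoch, 0.0) / warmup_epochs
    return base_lr + (target - base_lr) * progress


for epoch in (0, 1, 2.5, 5, 10):
    print(f"epoch {epoch}: lr = {learning_rate(epoch, 8192):.3f}")
```

With a minibatch of 8,192 images, the rate in this sketch starts at the base value and climbs to 32 times that value by the end of the warmup, since 8,192 is 32 times 256.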
Machine learning developers spend their days dealing with compromises. Greater accuracy often requires larger data sets, which demand additional training time and compute resources. It would be possible to prioritize either accuracy or speed alone to post more impressive numbers, but training a model with poor accuracy in 20 seconds isn’t super valuable.
Unlike in most research projects, Facebook’s AI Research (FAIR) and Applied Machine Learning (AML) teams worked side by side on increasing minibatch sizes. From here, the groups plan to investigate some of the additional questions raised by today’s work.
“This work throws out more questions than it answers,” said Pieter Noordhuis, a member of Facebook’s AML team. “There’s a tipping point beyond 8,000 images where error rates go up again and we don’t know why.”
Facebook used Caffe2, its open source deep learning framework, and its Big Basin GPU servers for this experiment. Facebook has published additional information for readers who want to dig more deeply into the details.