上节课介绍了 Pytorch 内置的 nn.DataParallel 类,做多 GPU 并行训练的方法。当使用 nn.DataParallel 训练模型时,每个 Batch 的数据会被平均分散到各个 GPU 上进行计算。默认情况下,DataParallel 会将第一个 GPU 作为主 GPU,它负责接收输入数据,并将其分发给其他 GPU 进行计算,此外,它还负责收集其他 GPU 计算的结果,并执行一些汇总操作,例如计算梯度的平均值,更新参数后再同步给其他 GPU。

所以,第一块GPU的负荷会远远大于其他 GPU,它可能会成为整个训练过程的瓶颈。为了防止第一块 GPU 超限报错,而其他 GPU 有大量闲置的情况,我们可以手动分配数据,把更多数据分配到非主 GPU 上计算,而主 GPU 只处理较少的数据,将更多资源留给分发和汇总操作。

代码示例

1、定义 BalancedDataParallel 类

我们定义一个 BalancedDataParallel 类,来处理多 GPU 负载均衡的问题,这个类的来源是 transformer-xl 的源码。

内容不可见,请联系管理员开通权限。

2、调用逻辑改动

内容不可见,请联系管理员开通权限。

3、观察GPU显存占用情况

内容不可见,请联系管理员开通权限。

好的,那到目前为止,我们才真正的把代码讲完了,下节课,就可以用 Kaggle 的免费 GPU 资源,来做模型训练了。

本文链接:http://ichenhua.cn/edu/note/677

版权声明:本文为「陈华编程」原创课程讲义,请给与知识创作者起码的尊重,未经许可不得传播或转售!