随着深度学习模型越来越大,单卡训练逐渐成为过去式(或者仅用于调试),越来越多的训练代码需要多卡分布式训练乃至于多机多卡分布式训练。各种分布式方案中,PyTorch自带的DistributedDataParallel(简称DDP)开箱即用,是主流方案之一。
本文记录关于PyTorch分布式训练DDP中的find_unused_parameters参数含义,本文主要的参考资料为pytorch分布式训练论文《PyTorch distributed: experiences on accelerating data parallel training》(VLDB 2020)和pytorch关于分布式训练的一系列官方文档。
一个Toy Example
PyTorch的分布式训练方案的核心是对各个显卡之中的模型的梯度求平均,因此我们首先构造一个可以方便地控制梯度的模型。
from torch.nn import Module
from torch import nn
import torch
class Net(Module):
def __init__(self):
super(Net, self).__init__()
self.layer = nn.Linear(1, 1, bias=False)
self.layer.weight.data.zero_()
self.layer.weight.data += 1
def forward(self, x):
return self.layer(x)
net = Net()
BS = 6
input = torch.zeros(BS, 1).to(torch.float32) + 1
output = net(input)
target = torch.zeros_like(input)
loss = (0.5 * (output - target) ** 2).sum()
loss.backward()
print(f'grad {net.layer.weight.grad.item()}')
这个模型的梯度就是它的batch size,改变BS参数多运行几次就能看出来了。
这样,我们就可以控制在哪个参数上有梯度、梯度是多少了。
设置find_unused_parameters的作用
接下来我们写一份分布式训练的代码,它会根据该节点的rank(序号)调用第一层或者第二层的参数。
from torch.nn import Module
from torch import nn
import torch
class Net(Module):
def __init__(self):
super(Net, self).__init__()
self.layer1 = nn.Linear(1, 1, bias=False)
self.layer2 = nn.Linear(1, 1, bias=False)
self.layer1.weight.data.zero_()
self.layer1.weight.data += 1
self.layer2.weight.data.zero_()
self.layer2.weight.data += 1
def forward(self, x, rank=0):
layer = f'layer{1 if rank % 2 == 1 else 2}'
return getattr(self, layer)(x)
import torch.distributed as dist
dist.init_process_group('gloo')
rank = int(dist.get_rank())
find_unused_parameters = False
net = nn.parallel.DistributedDataParallel(Net(), find_unused_parameters=find_unused_parameters)
input = torch.zeros(rank + 1, 1).to(torch.float32) + 1
output = net(input, rank=rank)
target = torch.zeros_like(input)
loss = (0.5 * (output - target) ** 2).sum()
loss.backward()
print(f'rank {rank}, layer 1, grad {net.module.layer1.weight.grad}')
print(f'rank {rank}, layer 2, grad {net.module.layer2.weight.grad}')
代码保存为a.py,执行命令 torchrun --standalone --nnodes=1 --nproc_per_node=4 a.py ,得到的输出结果表明反向传播之后的梯度情况为:
这说明结果是不对的,代码里存在bug。根据DistributedDataParallel的要求,反向传播之后每个模型的参数的梯度应该是一样的。
将代码中的find_unused_parameters参数改为True,再执行一次,可以看到结果为:
这次结果就对了,每个节点的梯度都一样了,可以进行optimizer.step操作。
结果理解
大体上来说,当4个结点共同进行分布式训练时,只有当4个结点都计算得到了参数梯度之后,DistributedDataParallel才会对梯度进行平均。所以,当find_unused_parameters=False时,每一层的参数只得到了两份梯度,并不会触发梯度平均,所以我们看到的是图1的样子。
当设置find_unused_parameters=True时,DistributedDataParallel会跟踪每个节点的计算图,标记那些没用梯度的参数,并将其梯度视为0,然后再进行梯度平均,就得到了图2的结果。
一张图概括,当find_unused_parameters=False时,如果某个参数的梯度没有n份(n为分布式训练的节点总数),这个参数的梯度将不会被平均(每个节点的梯度都不一样,将导致各个节点的参数发散);当find_unused_parameters=True时,会执行下图的流程。
这样一来,或许就能够明白find_unused_parameters参数的含义了,就是: 寻找没用到的参数。很简单直白的翻译。
提醒
pytorch的DistributedDataParallel训练模式的隐藏要求是每个节点的梯度必须相同,这样才能保证每个节点的模型参数相同。如果模型中存在部分模块在特定数据中不会被启用的情况,并且是默认的find_unused_parameters=False的设置,将导致训练过程无意义,各个节点的模型不一致。
更好的解决方案
find_unused_parameters=True的设置会带来额外的运行时开销(而且还不小)。
一种更好的办法是构建一个相同的计算图,用0和1这些选择变量来执行选择操作,这样就不用设置find_unused_parameters参数了。例如:
from torch.nn import Module
from torch import nn
import torch
class Net(Module):
def __init__(self):
super(Net, self).__init__()
self.layer1 = nn.Linear(1, 1, bias=False)
self.layer2 = nn.Linear(1, 1, bias=False)
self.layer1.weight.data.zero_()
self.layer1.weight.data += 1
self.layer2.weight.data.zero_()
self.layer2.weight.data += 1
def forward(self, x, rank=0):
layer1_coeff = 1 if rank % 2 == 1 else 0
layer2_coeff = 1 - layer1_coeff
return self.layer1(x) * layer1_coeff + self.layer2(x) * layer2_coeff
import torch.distributed as dist
dist.init_process_group('gloo')
rank = int(dist.get_rank())
find_unused_parameters = False
net = nn.parallel.DistributedDataParallel(Net(), find_unused_parameters=find_unused_parameters)
input = torch.zeros(rank + 1, 1).to(torch.float32) + 1
output = net(input, rank=rank)
target = torch.zeros_like(input)
loss = (0.5 * (output - target) ** 2).sum()
loss.backward()
print(f'rank {rank}, layer 1, grad {net.module.layer1.weight.grad}')
print(f'rank {rank}, layer 2, grad {net.module.layer2.weight.grad}')
得到的结果与find_unused_parameters=True是一样的。
扩展知识
熟悉timm的朋友都知道,timm里面有个DropPath层,对应一个drop_path函数。里面用的也是类似的技巧,用乘以0或者1来表示选择,而不能直接用Python的if-else来选择是否调用某个模块。
总结
分布式训练的本质是多个节点协同训练,为了实现这种协同,多个节点的计算结构必须得是一致的。只要稍有差别,就很难处理,而且会带来很大的运行时开销。
本文暂时没有评论,来添加一个吧(●'◡'●)