Pytorch并行计算:nn.parallel.replicate, scatter, gather, parallel_apply

谁说胖子不能爱 提交于 2020-11-14 08:30:42
import torch
import torch.nn as nn
import ipdb


class DataParallelModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(10, 20)


    def forward(self, x):
        x = self.block1(x)
        return x

def data_parallel(module, input, device_ids, output_device=None):
    if not device_ids:
        return module(input)

    if output_device is None:
        output_device = device_ids[0]

    replicas = nn.parallel.replicate(module, device_ids)
    print(f"replicas:{replicas}")
	
    inputs = nn.parallel.scatter(input, device_ids)
    print(f"inputs:{type(inputs)}")
    for i in range(len(inputs)):
        print(f"input {i}:{inputs[i].shape}")
		
    replicas = replicas[:len(inputs)]
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    print(f"outputs:{type(outputs)}")
    for i in range(len(outputs)):
        print(f"output {i}:{outputs[i].shape}")
		
    result = nn.parallel.gather(outputs, output_device)
    return result

model = DataParallelModel()
x = torch.rand(16,10)
result = data_parallel(model.cuda(),x.cuda(), [0,1])
print(f"result:{type(result)}")

最后输出为

replicas:[DataParallelModel(
  (block1): Linear(in_features=10, out_features=20, bias=True)
), DataParallelModel(
  (block1): Linear(in_features=10, out_features=20, bias=True)
)]
inputs:<class 'tuple'>
input 0:torch.Size([8, 10])
input 1:torch.Size([8, 10])
outputs:<class 'list'>
output 0:torch.Size([8, 20])
output 1:torch.Size([8, 20])
result: torch.Size([16, 20])

可以看到整个流程如下:

  • replicas: 将模型复制若干份,这里只有两个GPU,所以复制两份
  • scatter: 将输入数据若干等分,这里划分成了两份,会返回一个tuple。因为batch size=16,所以刚好可以划分成8和8,那如果是15怎么办呢?没关系,它会自动划分成8和7,这个你自己可以做实验感受一下。
  • parallel_apply: 现在模型和数据都有了,所以当然就是并行化的计算咯,最后返回的是一个list,每个元素是对应GPU的计算结果。
  • gather:每个GPU计算完了之后需要将结果发送到第一个GPU上进行汇总,可以看到最终的tensor大小是[16,20],这符合预期。

<footer style="color:white;;background-color:rgb(24,24,24);padding:10px;border-radius:10px;"><br> <h3 style="text-align:center;color:tomato;font-size:16px;" id="autoid-2-0-0"><br> <b>MARSGGBO</b><b style="color:white;"><span style="font-size:25px;">♥</span>原创</b><br> <br><br> <br><br> <b style="color:white;"><br> 2019-9-17<p></p> </b><p><b style="color:white;"></b><br> </p></h3><br> </footer>

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!