PyTorch 加速数据读取, 提高 GPU 利用率

网友投稿 1438 2022-08-31

PyTorch 加速数据读取, 提高 GPU 利用率

PyTorch 加速数据读取, 提高 GPU 利用率

PyTorch代码:​​prefetch_generator

使用 prefetch_generator 库在后台加载下一 batch 的数据。 需要安装 prefetch_generator 库

pip install prefetch_generator

原本 PyTorch 默认的 DataLoader 会创建一些 worker 线程来预读取新的数据,但是除非这些线程的数据全部都被清空,这些线程才会读下一批数据。 使用 prefetch_generator,我们可以保证线程不会等待,每个线程都总有至少一个数据在加载。

import torchfrom torch.utils.data import DataLoaderfrom prefetch_generator import BackgroundGeneratorclass PrefetchDataLoader(DataLoader): ''' replace DataLoader with PrefetchDataLoader ''' def __iter__(self): return BackgroundGenerator(super().__iter__())

2 data_prefetcher

使用 data_prefetcher 新开 cuda stream 来拷贝 tensor 到 gpu

默认情况下,PyTorch 将所有涉及到 GPU 的操作(比如内核操作,cpu->gpu,gpu->cpu)都排入同一个 stream(default stream)中,并对同一个流的操作序列化,它们永远不会并行。要想并行,两个操作必须位于不同的 stream 中。

而前向传播位于 default stream 中,因此,要想将下一个 batch 数据的预读取(涉及 cpu->gpu)与当前 batch 的前向传播并行处理,就必须: 1 cpu 上的数据 batch 必须 pinned; 2 预读取操作必须在另一个 stream 上进行

class DataPrefetcher(object): ''' prefetcher = DataPrefetcher(train_loader, device=self.device) batch = prefetcher.next() iter_id = 0 while batch is not None: iter_id += 1 if iter_id >= num_iters: break run_step() batch = prefetcher.next() ''' def __init__(self, loader, device): self.loader = loader self.dataset = loader.dataset self.stream = torch.cuda.Stream() self.next_input = None self.next_target = None self.device = device def __len__(self): return len(self.loader) def preload(self): try: self.next_input, self.next_target = next(self.loaditer) except StopIteration: self.next_input = None self.next_target = None return with torch.cuda.stream(self.stream): self.next_input = self.next_input.cuda(device=self.device, non_blocking=True) self.next_target = self.next_target.cuda(device=self.device, non_blocking=True) def __iter__(self): count = 0 self.loaditer = iter(self.loader) self.preload() while self.next_input is not None: torch.cuda.current_stream().wait_stream(self.stream) input = self.next_input target = self.next_target self.preload() count += 1 yield input,

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:ubuntu下常见问题汇总(累积中。。。)
下一篇:Go语言教程之结构体(go 结构)
相关文章

 发表评论

暂时没有评论,来抢沙发吧~