在执行模型前向推理时,往往涉及到一些列数据预处理操作,比如Resize,Normalize等,这些操作通常在CPU上完成,然后CPU将预处理后的图片传送到GPU上执行推理。

由于GPU的运算速度远快于CPU,所以能不能将这些数据预处理操作放到GPU上执行从而加快数据加载的速度呢?

NVIDIA DALI 可以!

NVIDIA DALI (Data Loading Library) 是一个加速数据加载和预处理的库,专为深度学习任务设计。它能将图像和视频的复杂预处理操作(尤其是在模型训练阶段,通常涉及大量的数据增强预处理操作)从 CPU 转移到 GPU 上,从而减少数据加载瓶颈,提升 GPU 的利用率。DALI 支持多种格式(如 JPEG、PNG、TFRecord 等),并能与主流深度学习框架(如 PyTorch 和 TensorFlow)无缝集成,使得数据预处理和模型前向推理可以高效并行进行。

本文介绍DALI的使用方法,以及如何将PyTorch的数据加载器替换成DALI的数据加载器,并测试加速效果。

安装NVIDIA DALI

对于CUDA 11.x,执行如下命令行:

1
pip install --extra-index-url https://pypi.nvidia.com --upgrade nvidia-dali-cuda110

对于CUDA 12.x,执行如下命令行:

1
pip install --extra-index-url https://pypi.nvidia.com --upgrade nvidia-dali-cuda120

搭建PyTorch原生数据加载器

在开始介绍DALI的使用方法之前,首先搭建PyTorch原生的数据加载器,之后再使用DALI进行重写。

S1. 导入必要的包

1
2
3
4
5
6
7
8
9
10
import glob
import os
import cv2
import numpy as np
import imgaug.augmenters as iaa
import torch
from skimage import color, io, transform
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

S2. 定义工具函数

本案例涉及到的数据预处理操作包括Resize和Normalize两个操作,因此这里给出它们的定义。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def load_img(path):
img=cv2.imread(path)
img=cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img

class Resize():
def __init__(self,s):
self.aug=iaa.Sequential([iaa.Resize(size=[s,s], interpolation='nearest')])
def __call__(self,image):
image=np.expand_dims(image, axis=0).astype(np.float32)/255
image=self.aug(images=image)
image=np.squeeze(image,0)

return image

class Normalize(object):
def __call__(self, image):

tmpImg = np.zeros((image.shape[0],image.shape[1],3))
image = image/np.max(image)

tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225

tmpImg = tmpImg.transpose((2, 0, 1))

return torch.from_numpy(np.ascontiguousarray(tmpImg))

class BuildDataset(Dataset):
def __init__(self, imgs_path_list, transforms=None):
self.img_paths = imgs_path_list
self.transforms = transforms

def __len__(self):
return len(self.img_paths)

def __getitem__(self,idx):
image = load_img(self.img_paths[idx])
if self.transforms:
image = self.transforms(image)

return image

S3. 搭建原生PyTorch数据加载流水线

在定义好以上数据预处理函数后,直接使用 torchvision.transforms将它们串联起来即可:

1
2
3
4
transform = transforms.Compose([
Resize(512),
Normalize()
])

S4. 实例化测试

再开始测试之前,首先来准备测试用的数据,这里我选取了13k+张图片进行测试:

1
2
images_files = glob.glob('./imgs/*.png')
len(images_files)# 13087

现在来构建PyTorch原生数据加载器:

1
2
3
4
5
batch_size=8
workers=8

dataset = BuildDataset(images_files, transforms=None)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=workers, shuffle=False, drop_last=False)

执行测试,查看数据加载耗时:

1
2
3
4
s=time.time()
for data in dataloader:
pass
print(time.time()-s)# 64.35477137565613

结果显示,PyTorch原生数据加载器加载这些图片需要64.35s.

搭建DALI数据加载器

S1. 导入必要的包

1
2
3
4
from nvidia.dali.pipeline import Pipeline,pipeline_def
from nvidia.dali import ops, fn
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy

S2. 搭建DALI数据加载流水线

为了实现在GPU上的预处理,DALI内置了常用的数据据预处理函数,因此这里只需要调用对应的API即可。

这里,参照PyTorch数据加载器中使用的两个数据预处理操作,用DALI中对应的API进行重写。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
@pipeline_def
def dali_pipeline():
jpegs, _ = fn.readers.file(files=images_files, seed=1234, name="main_reader")
images = fn.decoders.image(jpegs, device="mixed")

images = fn.resize(images, resize_x=512, resize_y=512,device='gpu', interp_type=types.INTERP_NN)# images=fn.resize(images, size=[512, 512],device='gpu')

# Now apply the normalization with mean and std
images = fn.crop_mirror_normalize(
images,
mean=[0.485*255 , 0.456*255, 0.406*255 ],
std=[0.229*255 , 0.224*255 , 0.225*255 ],
device='gpu',
mirror=0)

return images

S3. 实例化测试

1
2
3
4
5
6
7
8
9
pipe=dali_pipeline(batch_size=batch_size, num_threads=workers, device_id=0)
pipe.build()

it = DALIGenericIterator([pipe], ['data'],size=len(images_files), last_batch_policy=LastBatchPolicy.PARTIAL)

s=time.time()
for i,data in enumerate(it):
pass
print(time.time()-s)# 16.748019218444824

执行测试,结果显示,使用DALI数据加载器加载这些图片需要16.75s,远快于原生PyTorch数据加载器的64.35s.

DALI从内存加载数据

在上面的例子中,数据源来自磁盘上存储的图片。然而,在某些应用场景下,可能涉及到直接从内存读取数据,比如,在针对高分辨率图像进行推理时,由于图像尺寸过大,往往需要先将高分辨率图像读取到内存中,并将其切分成多个小图,然后将这些小图分别使用模型进行预测,最后再将结果合并。

在这种场景下,若使用DALI数据加载器,则必须从内存加载数据。DALI提供了 external_source方法,用于读取其它数据源。

这里,我们仍然使用上面例子中的13k+张图片进行演示,只不过,需要先把这些图片统统读取到内存:

1
2
3
4
5
6
7
8
9
10
images_files = glob.glob('./imgs/*.png')
len(images_files)# 13087

def load_img(path):
img=cv2.imread(path)
img=cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img

img_arrays = [load_img(path) for path in images_files]
img_arrays[0].shape# (640, 640, 3)

接下来,构建DALI数据加载器流水线,和之前的例子不同的是,这里不再使用 fn.readers.filefn.decoders.image读取并解码来自磁盘的图片,而是使用 fn.external_source读取外部数据源,数据源的源头使用 source参数进行指定:

1
2
3
4
5
6
7
8
9
10
11
12
13
@dali.pipeline_def()
def i_pipeline():

images = fn.external_source(source=eii, batch=True, device="cpu")
images = images.gpu() # Transfer to GPU
images = fn.resize(images, resize_x=512, resize_y=512, device='gpu', interp_type=dali.types.INTERP_NN)
images = fn.crop_mirror_normalize(
images,
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
device='gpu'
)
return images

相较于之前的例子,这里的 source参数需要自行编写,它对应一个迭代器,用于每个batch的数据迭代。

在实现这个数据迭代器时,需要特别注意,当剩余的数据量不足一个batch时要进行单独处理,否则会导致数据丢失。这一点针对模型训练来说无关紧要,但是在执行高分辨率图像推理时则会导致错误。

数据迭代器实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class ExternalInputIterator(object):
def __init__(self, batch_size, img_arrays):
self.batch_size = batch_size
self.img_arrays = img_arrays
self.length = len(self.img_arrays)

def __iter__(self):
self.i = 0
return self

def __next__(self):
if self.i >= self.length:
raise StopIteration
batch = []
for _ in range(self.batch_size):
if self.i >= self.length:
# Fill remaining batch with zeros if images are exhausted
break
img = self.img_arrays[self.i]
batch.append(img)
self.i += 1

# Make sure the batch is correctly shaped: [batch_size, H, W, C]
batch = np.array(batch)
return batch

万事俱备,现在执行测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
batch_size = 10
workers=8
# Instantiate ExternalInputIterator
eii = ExternalInputIterator(batch_size, img_arrays)

# Build and run pipeline
pipe = i_pipeline(batch_size=batch_size, num_threads=workers, device_id=0)
pipe.build()

# Create iterator
it = DALIGenericIterator([pipe], ['data'])

# Iterate over the data
for idx, data in enumerate(it):
print(f"Batch {idx+1} shape:", data[0]['data'].shape)

输出:

1
2
3
4
5
6
Batch 1 shape: torch.Size([10, 3, 512, 512])
Batch 2 shape: torch.Size([10, 3, 512, 512])
Batch 3 shape: torch.Size([10, 3, 512, 512])
Batch 4 shape: torch.Size([10, 3, 512, 512])
Batch 5 shape: torch.Size([10, 3, 512, 512])
...

over,如果本文对你有帮助的话,请不要吝啬您的点赞、收藏与转发,感谢~