如何设计自己的数据变换⚓︎
在本教程中,我们将介绍MMagic中变换流水线的设计。
The structure of this guide are as follows:
MMagic中的数据流水线⚓︎
按照典型的惯例,我们使用 Dataset
和 DataLoader
来加载多个线程的数据。 Dataset
返回一个与模型的forward方法的参数相对应的数据项的字典。
数据准备流水线和数据集是分开的。通常,一个数据集定义了如何处理标注,而一个数据管道定义了准备一个数据字典的所有步骤。
一个流水线由一连串的操作组成。每个操作都需要一个字典作为输入,并为下一个变换输出一个字典。
这些操作被分为数据加载、预处理和格式化。
在MMagic中,所有数据变换都继承自 BaseTransform
。
变换的输入和输出类型都是字典。
数据变换的一个简单示例⚓︎
>>> from mmagic.transforms import LoadPairedImageFromFile
>>> transforms = LoadPairedImageFromFile(
>>> key='pair',
>>> domain_a='horse',
>>> domain_b='zebra',
>>> flag='color'),
>>> data_dict = {'pair_path': './data/pix2pix/facades/train/1.png'}
>>> data_dict = transforms(data_dict)
>>> print(data_dict.keys())
dict_keys(['pair_path', 'pair', 'pair_ori_shape', 'img_mask', 'img_photo', 'img_mask_path', 'img_photo_path', 'img_mask_ori_shape', 'img_photo_ori_shape'])
一般来说,变换流水线的最后一步必须是 PackInputs
.
PackInputs
将把处理过的数据打包成一个包含两个字段的字典:inputs
和 data_samples
.
inputs
是你想用作模型输入的变量,它可以是 torch.Tensor
的类型, torch.Tensor
的字典,或者你想要的任何类型。
data_samples
是一个 DataSample
的列表. 每个 DataSample
都包含真实值和对应输入的必要信息。
BasicVSR的一个示例⚓︎
下面是一个BasicVSR的流水线示例。
train_pipeline = [
dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
dict(type='SetValues', dictionary=dict(scale=scale)),
dict(type='PairedRandomCrop', gt_patch_size=256),
dict(
type='Flip',
keys=['img', 'gt'],
flip_ratio=0.5,
direction='horizontal'),
dict(
type='Flip', keys=['img', 'gt'], flip_ratio=0.5, direction='vertical'),
dict(type='RandomTransposeHW', keys=['img', 'gt'], transpose_ratio=0.5),
dict(type='MirrorSequence', keys=['img', 'gt']),
dict(type='PackInputs')
]
val_pipeline = [
dict(type='GenerateSegmentIndices', interval_list=[1]),
dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
dict(type='PackInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
dict(type='MirrorSequence', keys=['img']),
dict(type='PackInputs')
]
对于每个操作,我们列出了添加/更新/删除的相关字典字段,标记为 '*' 的字典字段是可选的。
Pix2Pix的一个示例⚓︎
下面是一个在aerial2maps数据集上Pix2Pix训练的流水线示例。
source_domain = 'aerial'
target_domain = 'map'
pipeline = [
dict(
type='LoadPairedImageFromFile',
io_backend='disk',
key='pair',
domain_a=domain_a,
domain_b=domain_b,
flag='color'),
dict(
type='TransformBroadcaster',
mapping={'img': [f'img_{domain_a}', f'img_{domain_b}']},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmagic.Resize', scale=(286, 286),
interpolation='bicubic'),
dict(type='mmagic.FixedCrop', crop_size=(256, 256))
]),
dict(
type='Flip',
keys=[f'img_{domain_a}', f'img_{domain_b}'],
direction='horizontal'),
dict(
type='PackInputs',
keys=[f'img_{domain_a}', f'img_{domain_b}', 'pair'])
MMagic中支持的数据变换⚓︎
数据加载⚓︎
Transform | Modification of Results' keys |
---|---|
LoadImageFromFile
|
- add: img, img_path, img_ori_shape, \*ori_img |
RandomLoadResizeBg
|
- add: bg |
LoadMask
|
- add: mask |
GetSpatialDiscountMask
|
- add: discount_mask |
预处理⚓︎
Transform | Modification of Results' keys | |
---|---|---|
Resize
|
- add: scale_factor, keep_ratio, interpolation, backend
- update: specified by keys
|
|
MATLABLikeResize
|
- add: scale, output_shape
- update: specified by keys
|
|
RandomRotation
|
- add: degrees
- update: specified by keys
| |
Flip
|
- add: flip, flip_direction
- update: specified by keys
|
|
RandomAffine
|
- update: specified by keys
|
|
RandomJitter
|
- update: fg (img) | |
ColorJitter
|
- update: specified by keys
|
|
BinarizeImage
|
- update: specified by keys
|
|
RandomMaskDilation
|
- add: img_dilate_kernel_size | |
RandomTransposeHW
|
- add: transpose | |
RandomDownSampling
|
- update: scale, gt (img), lq (img) | |
RandomBlur
|
- update: specified by keys
|
|
RandomResize
|
- update: specified by keys
|
|
RandomNoise
|
- update: specified by keys
|
|
RandomJPEGCompression
|
- update: specified by keys
|
|
RandomVideoCompression
|
- update: specified by keys
|
|
DegradationsWithShuffle
|
- update: specified by keys
|
|
GenerateFrameIndices
|
- update: img_path (gt_path, lq_path) | |
GenerateFrameIndiceswithPadding
|
- update: img_path (gt_path, lq_path) | |
TemporalReverse
|
- add: reverse
- update: specified by keys
|
|
GenerateSegmentIndices
|
- add: interval - update: img_path (gt_path, lq_path) | |
MirrorSequence
|
- update: specified by keys
|
|
CopyValues
|
- add: specified by dst_key
|
|
UnsharpMasking
|
- add: img_unsharp | |
Crop
|
- add: img_crop_bbox, crop_size
- update: specified by keys
|
|
RandomResizedCrop
|
- add: img_crop_bbox
- update: specified by keys
|
|
FixedCrop
|
- add: crop_size, crop_pos
- update: specified by keys
|
|
PairedRandomCrop
|
- update: gt (img), lq (img) | |
CropAroundCenter
|
- add: crop_bbox - update: fg (img), alpha (img), trimap (img), bg (img) | |
CropAroundUnknown
|
- add: crop_bbox
- update: specified by keys
|
|
CropAroundFg
|
- add: crop_bbox
- update: specified by keys
|
|
ModCrop
|
- update: gt (img) | |
CropLike
|
- update: specified by target_key
|
|
GetMaskedImage
|
- add: masked_img | |
GenerateFacialHeatmap
|
- add: heatmap | |
GenerateCoordinateAndCell
|
- add: coord, cell - update: gt (img) | |
Normalize
|
- add: img_norm_cfg
- update: specified by keys
|
|
RescaleToZeroOne
|
- update: specified by keys
|
格式化⚓︎
Transform | Modification of Results' keys |
---|---|
ToTensor
|
update: specified by keys .
|
FormatTrimap
|
- update: trimap |
PackInputs
|
- add: inputs, data_sample - remove: all other keys |
Albumentations⚓︎
MMagic 支持添加 Albumentations 库中的 transformation,请浏览 https://albumentations.ai/docs/getting_started/transforms_and_targets 获取更多 transformation 的信息。
使用 Albumentations 的示例如下:
albu_transforms = [
dict(
type='Resize',
height=100,
width=100,
),
dict(
type='RandomFog',
p=0.5,
),
dict(
type='RandomRain',
p=0.5
),
dict(
type='RandomSnow',
p=0.5,
),
]
pipeline = [
dict(
type='LoadImageFromFile',
key='img',
color_type='color',
channel_order='rgb',
imdecode_backend='cv2'),
dict(
type='Albumentations',
keys=['img'],
transforms=albu_transforms),
dict(type='PackInputs')
]
扩展和使用自定义流水线⚓︎
一个简单的MyTransform示例⚓︎
- 在文件中写入一个新的流水线,例如在
my_pipeline.py
中。它接受一个字典作为输入,并返回一个字典。
import random
from mmcv.transforms import BaseTransform
from mmagic.registry import TRANSFORMS
@TRANSFORMS.register_module()
class MyTransform(BaseTransform):
"""Add your transform
Args:
p (float): Probability of shifts. Default 0.5.
"""
def __init__(self, p=0.5):
self.p = p
def transform(self, results):
if random.random() > self.p:
results['dummy'] = True
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'(p={self.p})')
return repr_str
- 在你的配置文件中导入并使用该流水线。
确保导入相对于你的训练脚本所在的位置。
train_pipeline = [
...
dict(type='MyTransform', p=0.2),
...
]
一个翻转变换的示例⚓︎
这里我们以一个简单的翻转变换为例:
import random
import mmcv
from mmcv.transforms import BaseTransform, TRANSFORMS
@TRANSFORMS.register_module()
class MyFlip(BaseTransform):
def __init__(self, direction: str):
super().__init__()
self.direction = direction
def transform(self, results: dict) -> dict:
img = results['img']
results['img'] = mmcv.imflip(img, direction=self.direction)
return results
因此,我们可以实例化一个 MyFlip
对象,用它来处理数据字典。
import numpy as np
transform = MyFlip(direction='horizontal')
data_dict = {'img': np.random.rand(224, 224, 3)}
data_dict = transform(data_dict)
processed_img = data_dict['img']
或者,我们可以在配置文件的数据流水线中使用 MyFlip
变换。
pipeline = [
...
dict(type='MyFlip', direction='horizontal'),
...
]
请注意,如果你想在配置中使用 MyFlip
,你必须确保在程序运行过程中导入包含 MyFlip
的文件。
创建日期: November 27, 2023