跳转至

如何设计自己的数据变换⚓︎

在本教程中,我们将介绍MMagic中变换流水线的设计。

The structure of this guide are as follows:

MMagic中的数据流水线⚓︎

按照典型的惯例,我们使用 DatasetDataLoader 来加载多个线程的数据。 Dataset 返回一个与模型的forward方法的参数相对应的数据项的字典。

数据准备流水线和数据集是分开的。通常,一个数据集定义了如何处理标注,而一个数据管道定义了准备一个数据字典的所有步骤。

一个流水线由一连串的操作组成。每个操作都需要一个字典作为输入,并为下一个变换输出一个字典。

这些操作被分为数据加载、预处理和格式化。

在MMagic中,所有数据变换都继承自 BaseTransform。 变换的输入和输出类型都是字典。

数据变换的一个简单示例⚓︎

>>> from mmagic.transforms import LoadPairedImageFromFile
>>> transforms = LoadPairedImageFromFile(
>>>     key='pair',
>>>     domain_a='horse',
>>>     domain_b='zebra',
>>>     flag='color'),
>>> data_dict = {'pair_path': './data/pix2pix/facades/train/1.png'}
>>> data_dict = transforms(data_dict)
>>> print(data_dict.keys())
dict_keys(['pair_path', 'pair', 'pair_ori_shape', 'img_mask', 'img_photo', 'img_mask_path', 'img_photo_path', 'img_mask_ori_shape', 'img_photo_ori_shape'])

一般来说,变换流水线的最后一步必须是 PackInputs. PackInputs 将把处理过的数据打包成一个包含两个字段的字典:inputsdata_samples. inputs 是你想用作模型输入的变量,它可以是 torch.Tensor 的类型, torch.Tensor 的字典,或者你想要的任何类型。 data_samples 是一个 DataSample 的列表. 每个 DataSample 都包含真实值和对应输入的必要信息。

BasicVSR的一个示例⚓︎

下面是一个BasicVSR的流水线示例。

train_pipeline = [
    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
    dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
    dict(type='SetValues', dictionary=dict(scale=scale)),
    dict(type='PairedRandomCrop', gt_patch_size=256),
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='horizontal'),
    dict(
        type='Flip', keys=['img', 'gt'], flip_ratio=0.5, direction='vertical'),
    dict(type='RandomTransposeHW', keys=['img', 'gt'], transpose_ratio=0.5),
    dict(type='MirrorSequence', keys=['img', 'gt']),
    dict(type='PackInputs')
]

val_pipeline = [
    dict(type='GenerateSegmentIndices', interval_list=[1]),
    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
    dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
    dict(type='PackInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
    dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
    dict(type='MirrorSequence', keys=['img']),
    dict(type='PackInputs')
]

对于每个操作,我们列出了添加/更新/删除的相关字典字段,标记为 '*' 的字典字段是可选的。

Pix2Pix的一个示例⚓︎

下面是一个在aerial2maps数据集上Pix2Pix训练的流水线示例。

source_domain = 'aerial'
target_domain = 'map'

pipeline = [
    dict(
        type='LoadPairedImageFromFile',
        io_backend='disk',
        key='pair',
        domain_a=domain_a,
        domain_b=domain_b,
        flag='color'),
    dict(
        type='TransformBroadcaster',
        mapping={'img': [f'img_{domain_a}', f'img_{domain_b}']},
        auto_remap=True,
        share_random_params=True,
        transforms=[
            dict(
                type='mmagic.Resize', scale=(286, 286),
                interpolation='bicubic'),
            dict(type='mmagic.FixedCrop', crop_size=(256, 256))
        ]),
    dict(
        type='Flip',
        keys=[f'img_{domain_a}', f'img_{domain_b}'],
        direction='horizontal'),
    dict(
        type='PackInputs',
        keys=[f'img_{domain_a}', f'img_{domain_b}', 'pair'])

MMagic中支持的数据变换⚓︎

数据加载⚓︎

Transform Modification of Results' keys
LoadImageFromFile - add: img, img_path, img_ori_shape, \*ori_img
RandomLoadResizeBg - add: bg
LoadMask - add: mask
GetSpatialDiscountMask - add: discount_mask

预处理⚓︎

Transform Modification of Results' keys
Resize - add: scale_factor, keep_ratio, interpolation, backend - update: specified by keys
MATLABLikeResize - add: scale, output_shape - update: specified by keys
RandomRotation - add: degrees - update: specified by keys
Flip - add: flip, flip_direction - update: specified by keys
RandomAffine - update: specified by keys
RandomJitter - update: fg (img)
ColorJitter - update: specified by keys
BinarizeImage - update: specified by keys
RandomMaskDilation - add: img_dilate_kernel_size
RandomTransposeHW - add: transpose
RandomDownSampling - update: scale, gt (img), lq (img)
RandomBlur - update: specified by keys
RandomResize - update: specified by keys
RandomNoise - update: specified by keys
RandomJPEGCompression - update: specified by keys
RandomVideoCompression - update: specified by keys
DegradationsWithShuffle - update: specified by keys
GenerateFrameIndices - update: img_path (gt_path, lq_path)
GenerateFrameIndiceswithPadding - update: img_path (gt_path, lq_path)
TemporalReverse - add: reverse - update: specified by keys
GenerateSegmentIndices - add: interval - update: img_path (gt_path, lq_path)
MirrorSequence - update: specified by keys
CopyValues - add: specified by dst_key
UnsharpMasking - add: img_unsharp
Crop - add: img_crop_bbox, crop_size - update: specified by keys
RandomResizedCrop - add: img_crop_bbox - update: specified by keys
FixedCrop - add: crop_size, crop_pos - update: specified by keys
PairedRandomCrop - update: gt (img), lq (img)
CropAroundCenter - add: crop_bbox - update: fg (img), alpha (img), trimap (img), bg (img)
CropAroundUnknown - add: crop_bbox - update: specified by keys
CropAroundFg - add: crop_bbox - update: specified by keys
ModCrop - update: gt (img)
CropLike - update: specified by target_key
GetMaskedImage - add: masked_img
GenerateFacialHeatmap - add: heatmap
GenerateCoordinateAndCell - add: coord, cell - update: gt (img)
Normalize - add: img_norm_cfg - update: specified by keys
RescaleToZeroOne - update: specified by keys

格式化⚓︎

Transform Modification of Results' keys
ToTensor update: specified by keys.
FormatTrimap - update: trimap
PackInputs - add: inputs, data_sample - remove: all other keys

Albumentations⚓︎

MMagic 支持添加 Albumentations 库中的 transformation,请浏览 https://albumentations.ai/docs/getting_started/transforms_and_targets 获取更多 transformation 的信息。

使用 Albumentations 的示例如下:

albu_transforms = [
   dict(
         type='Resize',
         height=100,
         width=100,
   ),
   dict(
         type='RandomFog',
         p=0.5,
   ),
   dict(
         type='RandomRain',
         p=0.5
   ),
   dict(
         type='RandomSnow',
         p=0.5,
   ),
]
pipeline = [
   dict(
         type='LoadImageFromFile',
         key='img',
         color_type='color',
         channel_order='rgb',
         imdecode_backend='cv2'),
   dict(
         type='Albumentations',
         keys=['img'],
         transforms=albu_transforms),
   dict(type='PackInputs')
]

扩展和使用自定义流水线⚓︎

一个简单的MyTransform示例⚓︎

  1. 在文件中写入一个新的流水线,例如在 my_pipeline.py中。它接受一个字典作为输入,并返回一个字典。
import random
from mmcv.transforms import BaseTransform
from mmagic.registry import TRANSFORMS


@TRANSFORMS.register_module()
class MyTransform(BaseTransform):
    """Add your transform

    Args:
        p (float): Probability of shifts. Default 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def transform(self, results):
        if random.random() > self.p:
            results['dummy'] = True
        return results

    def __repr__(self):

        repr_str = self.__class__.__name__
        repr_str += (f'(p={self.p})')

        return repr_str
  1. 在你的配置文件中导入并使用该流水线。

确保导入相对于你的训练脚本所在的位置。

train_pipeline = [
    ...
    dict(type='MyTransform', p=0.2),
    ...
]

一个翻转变换的示例⚓︎

这里我们以一个简单的翻转变换为例:

import random
import mmcv
from mmcv.transforms import BaseTransform, TRANSFORMS

@TRANSFORMS.register_module()
class MyFlip(BaseTransform):
    def __init__(self, direction: str):
        super().__init__()
        self.direction = direction

    def transform(self, results: dict) -> dict:
        img = results['img']
        results['img'] = mmcv.imflip(img, direction=self.direction)
        return results

因此,我们可以实例化一个 MyFlip 对象,用它来处理数据字典。

import numpy as np

transform = MyFlip(direction='horizontal')
data_dict = {'img': np.random.rand(224, 224, 3)}
data_dict = transform(data_dict)
processed_img = data_dict['img']

或者,我们可以在配置文件的数据流水线中使用 MyFlip 变换。

pipeline = [
    ...
    dict(type='MyFlip', direction='horizontal'),
    ...
]

请注意,如果你想在配置中使用 MyFlip ,你必须确保在程序运行过程中导入包含 MyFlip 的文件。


最后更新: November 27, 2023
创建日期: November 27, 2023