零维护

 找回密码
 立即注册
快捷导航
搜索
热搜: 活动 交友 discuz
查看: 105|回复: 1

mmsegmentation源码阅读--FCN(一)

[复制链接]

2

主题

2

帖子

6

积分

新手上路

Rank: 1

积分
6
发表于 2022-11-26 18:06:49 | 显示全部楼层 |阅读模式
阅读本文前,请阅读相关说明,帮助您了解本系列文章
mmsegmentation的安装

mmsegmentation仓库地址:GitHub - open-mmlab/mmsegmentation: OpenMMLab Semantic Segmentation Toolbox and Benchmark.
具体安装细节可参考文档:Prerequisites - MMSegmentation 0.29.0 documentation
# 我的安装
# 创建环境
conda create -n mmseg python=3.7
conda activate mmseg

# 根据自身配置下载pytorch
pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

# 安装mmsegmentation相关
pip install -U openmim
mim install mmcv-full

# 可手动下载mmsegmentation仓库
cd mmsegmentation
pip install -v -e .
# 安装完成后可使用文档中的测试命令进行验证安装是否成功数据集准备

数据集准备可参考文档:Prepare datasets (本系列文章以ADE20K为例)
ADE20K数据集下载地址:http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip
将下载好的数据集组织成文档中的目录组织形式
Pycharm Debug设置


  • mmsegmentation中的train.py脚本在tools文件夹下,如果直接使用Pycharm进行debug train脚本,会将tools作为默认的工作目录,从而会导致某些路径找不到或者某些包无法导入等问题,所以需要在debug train脚本的时候将工作目录设置为mmdetection的根目录
  • 在train.py中有一个必传参数config,用于指定具体的配置文件。本文以FCN为例,为此将其设置成了默认值。本文使用的配置文件为configs/fcn/fcn_r50-d8_512x512_80k_ade20k.py
def parse_args():
    parser = argparse.ArgumentParser(description='Train a segmentor')
    parser.add_argument('--config', default='configs/fcn/fcn_r50-d8_512x512_80k_ade20k.py', help='train config file path') # 必须指定的参数, 指定具体的配置文件关于配置文件的命名方式可参考文档:Tutorial 1: Learn about Configs
源码阅读

通过调试tools文件夹下的train.py脚本分析模型整个前向过程,本文尽量按照代码的执行顺序进行梳理
tools/train.py

主要用于传递和设置各种参数。与其相关联的文件为mmseg/apis/train.py,其中的train_segmentor函数在tools/train.py中被调用。
数据读取

配置文件configs/fcn/fcn_r50-d8_512x512_80k_ade20k.py中的配置信息如下,有关数据集的配置继承自_base_/datasets/ade20k.py
_base_ = [
    '../_base_/models/fcn_r50-d8.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(
    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))为此,有关数据集的配置信息存在于_base_/datasets/ade20k.py中,如下
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
在数据集配置文件中,使用的数据集为ADE20K数据集,与之相关的类在mmseg/datasets/ade.py
该类为ADE20KDataset, 其继承自CustomDataset,此类位于mmseg/datasets/custom.py
CustomDataset继承自Pytorch的Dataset类.
继承体系: ADE20KDataset --> CustomDataset --> Dataset
_getitem_方法进行了一系列的数据处理, 该函数在CustomDataset类中
class CustomDataset(Dataset):

    CLASSES = None

    PALETTE = None

    def __init__(self,
                 pipeline,
                 img_dir,
                 img_suffix='.jpg',
                 ann_dir=None,
                 seg_map_suffix='.png',
                 split=None,
                 data_root=None,
                 test_mode=False,
                 ignore_index=255,
                 reduce_zero_label=False,
                 classes=None,
                 palette=None,
                 gt_seg_map_loader_cfg=None,
                 file_client_args=dict(backend='disk')):
        self.pipeline = Compose(pipeline)  # 初始化为Compose类, pipeline即为配置文件中train_pipeline的一些列操作用于处理数据
        self.img_dir = img_dir
        self.img_suffix = img_suffix
        self.ann_dir = ann_dir
        self.seg_map_suffix = seg_map_suffix
        self.split = split
        self.data_root = data_root
        self.test_mode = test_mode
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.label_map = None
        self.CLASSES, self.PALETTE = self.get_classes_and_palette(
            classes, palette)
        self.gt_seg_map_loader = LoadAnnotations(
        ) if gt_seg_map_loader_cfg is None else LoadAnnotations(
            **gt_seg_map_loader_cfg)

        self.file_client_args = file_client_args
        self.file_client = mmcv.FileClient.infer_client(self.file_client_args)

        if test_mode:
            assert self.CLASSES is not None, \
                '`cls.CLASSES` or `classes` should be specified when testing'

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.img_dir):
                self.img_dir = osp.join(self.data_root, self.img_dir)
            if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
                self.ann_dir = osp.join(self.data_root, self.ann_dir)
            if not (self.split is None or osp.isabs(self.split)):
                self.split = osp.join(self.data_root, self.split)

        # load annotations
        self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
                                               self.ann_dir,
                                               self.seg_map_suffix, self.split)

    def __len__(self):

        return len(self.img_infos)

    def __getitem__(self, idx):

        if self.test_mode:
            return self.prepare_test_img(idx)
        else:  # 满足
            return self.prepare_train_img(idx)

    def prepare_train_img(self, idx):

        img_info = self.img_infos[idx]  # 获取图像信息
        ann_info = self.get_ann_info(idx)  # 获取图像标注信息
        results = dict(img_info=img_info, ann_info=ann_info)  # 初始化result字典 后续会被不断更新
        self.pre_pipeline(results)  # 向字典中添加一系列信息
        return self.pipeline(results)  # 调用Compose类的__call__函数进行数据处理

    def pre_pipeline(self, results):  # 向字典中添加一系列信息

        results['seg_fields'] = []
        results['img_prefix'] = self.img_dir
        results['seg_prefix'] = self.ann_dir
        if self.custom_classes:  # 不满足
            results['label_map'] = self.label_map
Compose类位于mmseg/datasets/pipelines/compose.py,如下
class Compose(object):

    def __init__(self, transforms):
        assert isinstance(transforms, collections.abc.Sequence)
        self.transforms = []  # List, 每个元素为数据集配置文件中train_pipeline的各项操作
        for transform in transforms:
            if isinstance(transform, dict):
                transform = build_from_cfg(transform, PIPELINES)
                self.transforms.append(transform)
            elif callable(transform):
                self.transforms.append(transform)
            else:
                raise TypeError('transform must be callable or a dict')

    def __call__(self, data):

        for t in self.transforms:   # 遍历每个数据操作
            # 执行t数据操作,  分别为LoadImageFromFile,LoadAnnotations,Resize,RandomCrop,RandomFlip,PhotoMetricDistortion, Normalize,Pad,DefaultFormatBundle,Collect
            # 分别调用各个类的__call__函数, 更新data中的信息
            data = t(data)
            if data is None:  # 一般不满足
                return None
        return dataLoadImageFromFile类位于mmseg/datasets/pipelines/loading.py,如下
class LoadImageFromFile(object):

    def __init__(self,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='cv2'):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.file_client_args = file_client_args.copy()
        self.file_client = None
        self.imdecode_backend = imdecode_backend

    def __call__(self, results):

        if self.file_client is None: # 满足
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results.get('img_prefix') is not None: # 满足
            filename = osp.join(results['img_prefix'],
                                results['img_info']['filename'])   # 组合出图像的路径
        else:
            filename = results['img_info']['filename']
        img_bytes = self.file_client.get(filename)
        img = mmcv.imfrombytes(
            img_bytes, flag=self.color_type, backend=self.imdecode_backend)  # 读取图像, ndarray类型 shape [H, W, C]
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename  # 图像路径
        results['ori_filename'] = results['img_info']['filename']  # 图像名称
        results['img'] = img  # 读入的图像
        results['img_shape'] = img.shape  # 图像尺寸,后续会被更新
        results['ori_shape'] = img.shape  # 图像原始尺寸
        # Set initial values for default meta_keys  初始化一些key,后续会被更新
        results['pad_shape'] = img.shape  # 图像尺寸,后续会被更新
        results['scale_factor'] = 1.0 # 图像scale_factor,后续会被更新
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results['img_norm_cfg'] = dict(  # norm参数,后续会被更新
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        return resultsLoadAnnotations类位于mmseg/datasets/pipelines/loading.py,如下
class LoadAnnotations(object):

    def __init__(self,
                 reduce_zero_label=False,
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='pillow'):
        self.reduce_zero_label = reduce_zero_label
        self.file_client_args = file_client_args.copy()
        self.file_client = None
        self.imdecode_backend = imdecode_backend

    def __call__(self, results):

        if self.file_client is None:  # 满足
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results.get('seg_prefix', None) is not None:  # 满足
            filename = osp.join(results['seg_prefix'],
                                results['ann_info']['seg_map'])  # 组合出标注的路径
        else:
            filename = results['ann_info']['seg_map']
        img_bytes = self.file_client.get(filename)
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze().astype(np.uint8)  # 读取标注,ndarray类型 shape [H, W]
        # modify if custom classes
        if results.get('label_map', None) is not None:  # 不满足
            # Add deep copy to solve bug of repeatedly
            # replace `gt_semantic_seg`, which is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        # reduce zero_label
        if self.reduce_zero_label:  # 该参数对于不同数据集会有不同的设置 对于ADE20K默认为True,因为0代表背景,但是不包含在ADE20K的150个类别中
            # avoid using underflow conversion
            gt_semantic_seg[gt_semantic_seg == 0] = 255  # reduce_zero_label转换
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        results['gt_semantic_seg'] = gt_semantic_seg  # 添加标注
        results['seg_fields'].append('gt_semantic_seg')
        return resultsResize类位于mmseg/datasets/pipelines/transforms.py,如下
class Resize(object):
   
    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None,
                 keep_ratio=True,
                 min_size=None):
        if img_scale is None:
            self.img_scale = None
        else:
            if isinstance(img_scale, list):
                self.img_scale = img_scale
            else:
                self.img_scale = [img_scale]
            assert mmcv.is_list_of(self.img_scale, tuple)

        if ratio_range is not None:
            # mode 1: given img_scale=None and a range of image ratio
            # mode 2: given a scale and a range of image ratio
            assert self.img_scale is None or len(self.img_scale) == 1
        else:
            # mode 3 and 4: given multiple scales or a range of scales
            assert multiscale_mode in ['value', 'range']

        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio
        self.min_size = min_size

    @staticmethod
    def random_sample_ratio(img_scale, ratio_range):

        assert isinstance(img_scale, tuple) and len(img_scale) == 2
        min_ratio, max_ratio = ratio_range
        assert min_ratio <= max_ratio
        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio  # 在ratio_range中随机选择一个ratio
        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)  # 获得目标scale尺度范围
        return scale, None
   
    def _random_scale(self, results):
        
        if self.ratio_range is not None:  # 满足
            if self.img_scale is None:
                h, w = results['img'].shape[:2]
                scale, scale_idx = self.random_sample_ratio((w, h),
                                                            self.ratio_range)
            else:  # 满足
                scale, scale_idx = self.random_sample_ratio(
                    self.img_scale[0], self.ratio_range)
        elif len(self.img_scale) == 1:
            scale, scale_idx = self.img_scale[0], 0
        elif self.multiscale_mode == 'range':
            scale, scale_idx = self.random_sample(self.img_scale)
        elif self.multiscale_mode == 'value':
            scale, scale_idx = self.random_select(self.img_scale)
        else:
            raise NotImplementedError

        results['scale'] = scale  # 添加scale信息
        results['scale_idx'] = scale_idx  # None

    def _resize_img(self, results):

        if self.keep_ratio:  # 满足
            if self.min_size is not None:  # 不满足
                # TODO: Now 'min_size' is an 'int' which means the minimum
                # shape of images is (min_size, min_size, 3). 'min_size'
                # with tuple type will be supported, i.e. the width and
                # height are not equal.
                if min(results['scale']) < self.min_size:
                    new_short = self.min_size
                else:
                    new_short = min(results['scale'])

                h, w = results['img'].shape[:2]
                if h > w:
                    new_h, new_w = new_short * h / w, new_short
                else:
                    new_h, new_w = new_short, new_short * w / h
                results['scale'] = (new_h, new_w)

            img, scale_factor = mmcv.imrescale(   # 图像将被resize到scale内且尽可能的大
                results['img'], results['scale'], return_scale=True)
            # the w_scale and h_scale has minor difference
            # a real fix should be done in the mmcv.imrescale in the future
            new_h, new_w = img.shape[:2]   # resize后图像的高 宽
            h, w = results['img'].shape[:2]  # resize前图像的高 宽
            w_scale = new_w / w  # 宽的放缩尺度
            h_scale = new_h / h  # 高的放缩尺度
        else:
            img, w_scale, h_scale = mmcv.imresize(
                results['img'], results['scale'], return_scale=True)
        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
                                dtype=np.float32)
        results['img'] = img  # 更新图像
        results['img_shape'] = img.shape  # resize后图像尺寸
        results['pad_shape'] = img.shape  # result添加pad_shape,pad后图像尺寸, 后续会被更新
        results['scale_factor'] = scale_factor # result添加scale_factor, ndarray, shape [4,]
        results['keep_ratio'] = self.keep_ratio # result添加keep_ratio

    def _resize_seg(self, results):
        """Resize semantic segmentation map with ``results['scale']``."""
        for key in results.get('seg_fields', []):
            if self.keep_ratio:  # 满足
                gt_seg = mmcv.imrescale(  # 标注将被resize到scale内且尽可能的大
                    results[key], results['scale'], interpolation='nearest')
            else:
                gt_seg = mmcv.imresize(
                    results[key], results['scale'], interpolation='nearest')
            results[key] = gt_seg  # 更新标注

    def __call__(self, results):

        if 'scale' not in results:  # 满足
            self._random_scale(results)
        self._resize_img(results)  # resize图像
        self._resize_seg(results)  # resize标注
        return results
RandomCrop类位于mmseg/datasets/pipelines/transforms.py,如下
class RandomCrop(object):

    def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
        assert crop_size[0] > 0 and crop_size[1] > 0
        self.crop_size = crop_size
        self.cat_max_ratio = cat_max_ratio
        self.ignore_index = ignore_index

    def get_crop_bbox(self, img):

        margin_h = max(img.shape[0] - self.crop_size[0], 0)
        margin_w = max(img.shape[1] - self.crop_size[1], 0)
        offset_h = np.random.randint(0, margin_h + 1)  # 随机选取crop的左上初始点
        offset_w = np.random.randint(0, margin_w + 1)
        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]  # 获取裁减区域
        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]

        return crop_y1, crop_y2, crop_x1, crop_x2  

    def crop(self, img, crop_bbox):

        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
        return img

    def __call__(self, results):

        img = results['img']
        crop_bbox = self.get_crop_bbox(img)  # 获取crop区域
        if self.cat_max_ratio < 1.:  # 满足
            # Repeat 10 times
            for _ in range(10):
                seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox)  # crop标注
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]

                # len(cnt) > 1确保crop区域内存在前景  
                # np.max(cnt) / np.sum(cnt) < self.cat_max_ratio 确保前景中的某个类别的像素占所有前景像素比例小于cat_max_ratio
                # 为了保证crop后前景类别区域尽可能多  
                if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < self.cat_max_ratio:
                    break
                crop_bbox = self.get_crop_bbox(img)  # 重新获取crop区域

        # crop the image
        img = self.crop(img, crop_bbox)  # crop图像
        img_shape = img.shape
        results['img'] = img  # 更新图像
        results['img_shape'] = img_shape  # 更新为crop后的图像尺寸

        # crop semantic seg
        for key in results.get('seg_fields', []):
            results[key] = self.crop(results[key], crop_bbox)  # crop标注

        return resultsRandomFlip类位于mmseg/datasets/pipelines/transforms.py,如下
class RandomFlip(object):

    @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
    def __init__(self, prob=None, direction='horizontal'):
        self.prob = prob
        self.direction = direction
        if prob is not None:
            assert prob >= 0 and prob <= 1
        assert direction in ['horizontal', 'vertical']

    def __call__(self, results):

        if 'flip' not in results:  # 满足
            flip = True if np.random.rand() < self.prob else False
            results['flip'] = flip  # 添加是否flip信息
        if 'flip_direction' not in results:
            results['flip_direction'] = self.direction  # 添加flip方向信息 默认为horizontal
        if results['flip']:
            # flip image
            results['img'] = mmcv.imflip(  # flip图像
                results['img'], direction=results['flip_direction'])

            # flip segs
            for key in results.get('seg_fields', []):
                # use copy() to make numpy stride positive
                results[key] = mmcv.imflip(  # flip标注
                    results[key], direction=results['flip_direction']).copy()
        return results
PhotoMetricDistortion类位于mmseg/datasets/pipelines/transforms.py,如下
class PhotoMetricDistortion(object):

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(self, img, alpha=1, beta=0):

        img = img.astype(np.float32) * alpha + beta
        img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

    def brightness(self, img):

        if random.randint(2):  # 返回0或1
            return self.convert(
                img,
                beta=random.uniform(-self.brightness_delta,  # 调整亮度
                                    self.brightness_delta))
        return img

    def contrast(self, img):

        if random.randint(2):
            return self.convert(
                img,
                alpha=random.uniform(self.contrast_lower, self.contrast_upper))  # 调整对比度
        return img

    def saturation(self, img):

        if random.randint(2):
            img = mmcv.bgr2hsv(img)
            img[:, :, 1] = self.convert(
                img[:, :, 1],
                alpha=random.uniform(self.saturation_lower,
                                     self.saturation_upper))  # 调整饱和度
            img = mmcv.hsv2bgr(img)
        return img

    def hue(self, img):

        if random.randint(2):
            img = mmcv.bgr2hsv(img)
            img[:, :,
                0] = (img[:, :, 0].astype(int) +
                      random.randint(-self.hue_delta, self.hue_delta)) % 180  # 调整色相
            img = mmcv.hsv2bgr(img)
        return img

    def __call__(self, results):
      
        img = results['img']  # 获取图像
        # random brightness
        img = self.brightness(img)  # 调整亮度

        # mode == 0 --> do random contrast first
        # mode == 1 --> do random contrast last
        mode = random.randint(2)
        if mode == 1:
            img = self.contrast(img) # 调整对比度

        # random saturation
        img = self.saturation(img)  # 调整饱和度

        # random hue
        img = self.hue(img)  # 调整色相

        # random contrast
        if mode == 0:
            img = self.contrast(img)

        results['img'] = img
        return resultsNormalize类位于mmseg/datasets/pipelines/transforms.py,如下
class Normalize(object):

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):

        results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
                                          self.to_rgb)  # 对图像进行normalize
        results['img_norm_cfg'] = dict(  # 更新img_norm_cfg信息
            mean=self.mean, std=self.std, to_rgb=self.to_rgb)
        return resultsPad类位于
class Pad(object):

    def __init__(self,
                 size=None,
                 size_divisor=None,
                 pad_val=0,
                 seg_pad_val=255):
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val
        # only one of size and size_divisor should be valid
        assert size is not None or size_divisor is not None
        assert size is None or size_divisor is None

    def _pad_img(self, results):

        if self.size is not None:  # 满足
            padded_img = mmcv.impad(  # 对img pad 0,pad到crop size
                results['img'], shape=self.size, pad_val=self.pad_val)
        elif self.size_divisor is not None:
            padded_img = mmcv.impad_to_multiple(
                results['img'], self.size_divisor, pad_val=self.pad_val)
        results['img'] = padded_img  # 更新图像
        results['pad_shape'] = padded_img.shape  # pad后图像的shape
        results['pad_fixed_size'] = self.size  # pad到固定尺寸尺寸
        results['pad_size_divisor'] = self.size_divisor  # None

    def _pad_seg(self, results):

        for key in results.get('seg_fields', []):
            results[key] = mmcv.impad(  # 对标注pad 255,pad到crop size
                results[key],
                shape=results['pad_shape'][:2],
                pad_val=self.seg_pad_val)

    def __call__(self, results):
        
        self._pad_img(results)  # pad图像
        self._pad_seg(results)  # pad标注
        return resultsDefaultFormatBundle类位于mmseg/datasets/pipelines/formatting.py,如下
class DefaultFormatBundle(object):

    def __call__(self, results):

        if 'img' in results:  # 满足
            img = results['img']  # 获取图像
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            img = np.ascontiguousarray(img.transpose(2, 0, 1))  # [H,W,C]--> [C,H,W]
            results['img'] = DC(to_tensor(img), stack=True) # 转为tensor, 并封装为DataContainer
        if 'gt_semantic_seg' in results:
            # convert to long
            results['gt_semantic_seg'] = DC(
                to_tensor(results['gt_semantic_seg'][None,
                                                     ...].astype(np.int64)),
                stack=True)  # 新增channel维度,转为tensor, 并封装为DataContainer
        return results
Collect类位于mmseg/datasets/pipelines/formatting.py,如下
class Collect(object):

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_filename', 'ori_shape',
                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
                            'flip_direction', 'img_norm_cfg')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):

        data = {}
        img_meta = {}  # 用于存储图像一些基本信息
        for key in self.meta_keys:
            img_meta[key] = results[key]
        data['img_metas'] = DC(img_meta, cpu_only=True)  # 被封装进DataContainer
        for key in self.keys:
            data[key] = results[key] # 获取与任务相关的信息
        return data经过上述一系列处理 ,__getitem__方法最终输出一个字典。Dict[str,DataContainer]内部含有3个元素img_metas:Dict[str, *]、img:Tensor [3, H, W]和gt_semantic_seg:Tensor [1,H,W]
回复

使用道具 举报

0

主题

2

帖子

3

积分

新手上路

Rank: 1

积分
3
发表于 2025-3-27 00:25:05 | 显示全部楼层
楼主呀,,,您太有才了。。。
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Archiver| 手机版| 小黑屋| 零维护

GMT+8, 2025-4-8 15:16 , Processed in 0.112188 second(s), 23 queries .

Powered by Discuz! X3.4

Copyright © 2020, LianLian.

快速回复 返回顶部 返回列表