Skip to content

Commit 82b28bb

Browse files
committed
refactor and udpate
1 parent 5c2e359 commit 82b28bb

24 files changed

Lines changed: 2670 additions & 719 deletions

data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .build import build_loader

data/build.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# --------------------------------------------------------
2+
# RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
3+
# Github source: https://github.com/DingXiaoH/RepVGG
4+
# Licensed under The MIT License [see LICENSE for details]
5+
# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
6+
# --------------------------------------------------------
7+
import torch
8+
import numpy as np
9+
import torch.distributed as dist
10+
from torchvision import datasets, transforms
11+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12+
from timm.data import Mixup
13+
from timm.data import create_transform
14+
try:
15+
from timm.data.transforms import str_to_pil_interp as _pil_interp
16+
except:
17+
from timm.data.transforms import _pil_interp
18+
from .cached_image_folder import CachedImageFolder
19+
from .samplers import SubsetRandomSampler
20+
21+
22+
def build_loader(config):
23+
config.defrost()
24+
dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
25+
config.freeze()
26+
print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
27+
dataset_val, _ = build_dataset(is_train=False, config=config)
28+
print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
29+
30+
num_tasks = dist.get_world_size()
31+
global_rank = dist.get_rank()
32+
if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
33+
indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
34+
sampler_train = SubsetRandomSampler(indices)
35+
else:
36+
sampler_train = torch.utils.data.DistributedSampler(
37+
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
38+
)
39+
40+
if dataset_val is None:
41+
sampler_val = None
42+
else:
43+
indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) #TODO
44+
sampler_val = SubsetRandomSampler(indices)
45+
46+
data_loader_train = torch.utils.data.DataLoader(
47+
dataset_train, sampler=sampler_train,
48+
batch_size=config.DATA.BATCH_SIZE,
49+
num_workers=config.DATA.NUM_WORKERS,
50+
pin_memory=config.DATA.PIN_MEMORY,
51+
drop_last=True,
52+
)
53+
54+
if dataset_val is None:
55+
data_loader_val = None
56+
else:
57+
data_loader_val = torch.utils.data.DataLoader(
58+
dataset_val, sampler=sampler_val,
59+
batch_size=config.DATA.TEST_BATCH_SIZE,
60+
shuffle=False,
61+
num_workers=config.DATA.NUM_WORKERS,
62+
pin_memory=config.DATA.PIN_MEMORY,
63+
drop_last=False
64+
)
65+
66+
# setup mixup / cutmix
67+
mixup_fn = None
68+
mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
69+
if mixup_active:
70+
mixup_fn = Mixup(
71+
mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
72+
prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
73+
label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
74+
75+
return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
76+
77+
78+
def build_dataset(is_train, config):
79+
if config.DATA.DATASET == 'imagenet':
80+
transform = build_transform(is_train, config)
81+
prefix = 'train' if is_train else 'val'
82+
if config.DATA.ZIP_MODE:
83+
ann_file = prefix + "_map.txt"
84+
prefix = prefix + ".zip@/"
85+
dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
86+
cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
87+
else:
88+
import torchvision
89+
print('use raw ImageNet data')
90+
dataset = torchvision.datasets.ImageNet(root=config.DATA.DATA_PATH, split='train' if is_train else 'val', transform=transform)
91+
nb_classes = 1000
92+
93+
elif config.DATA.DATASET == 'cf100':
94+
mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
95+
std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]
96+
if is_train:
97+
transform = transforms.Compose([
98+
transforms.RandomCrop(32, padding=4),
99+
transforms.RandomHorizontalFlip(),
100+
transforms.ToTensor(),
101+
transforms.Normalize(mean, std)
102+
])
103+
dataset = datasets.CIFAR100(root=config.DATA.DATA_PATH, train=True, download=True, transform=transform)
104+
else:
105+
transform = transforms.Compose(
106+
[transforms.ToTensor(),
107+
transforms.Normalize(mean, std)])
108+
dataset = datasets.CIFAR100(root=config.DATA.DATA_PATH, train=False, download=True, transform=transform)
109+
nb_classes = 100
110+
111+
else:
112+
raise NotImplementedError("We only support ImageNet and CIFAR-100 now.")
113+
114+
return dataset, nb_classes
115+
116+
117+
def build_transform(is_train, config):
118+
resize_im = config.DATA.IMG_SIZE > 32
119+
if is_train:
120+
# this should always dispatch to transforms_imagenet_train
121+
122+
if config.AUG.PRESET is None:
123+
transform = create_transform(
124+
input_size=config.DATA.IMG_SIZE,
125+
is_training=True,
126+
color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
127+
auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
128+
re_prob=config.AUG.REPROB,
129+
re_mode=config.AUG.REMODE,
130+
re_count=config.AUG.RECOUNT,
131+
interpolation=config.DATA.INTERPOLATION,
132+
)
133+
print('=============================== original AUG! ', config.AUG.AUTO_AUGMENT)
134+
if not resize_im:
135+
# replace RandomResizedCropAndInterpolation with
136+
# RandomCrop
137+
transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
138+
139+
elif config.AUG.PRESET.strip() == 'raug15':
140+
from train.randaug import RandAugPolicy
141+
transform = transforms.Compose([
142+
transforms.RandomResizedCrop(config.DATA.IMG_SIZE),
143+
transforms.RandomHorizontalFlip(),
144+
RandAugPolicy(magnitude=15),
145+
transforms.ToTensor(),
146+
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
147+
])
148+
print('---------------------- RAND AUG 15 distortion!')
149+
150+
elif config.AUG.PRESET.strip() == 'weak':
151+
transform = transforms.Compose([
152+
transforms.RandomResizedCrop(config.DATA.IMG_SIZE),
153+
transforms.RandomHorizontalFlip(),
154+
transforms.ToTensor(),
155+
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
156+
])
157+
elif config.AUG.PRESET.strip() == 'none':
158+
transform = transforms.Compose([
159+
transforms.Resize(config.DATA.IMG_SIZE, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
160+
transforms.CenterCrop(config.DATA.IMG_SIZE),
161+
transforms.ToTensor(),
162+
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
163+
])
164+
else:
165+
raise ValueError('???' + config.AUG.PRESET)
166+
print(transform)
167+
return transform
168+
169+
t = []
170+
if resize_im:
171+
if config.TEST.CROP:
172+
size = int((256 / 224) * config.DATA.TEST_SIZE)
173+
t.append(transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
174+
# to maintain same ratio w.r.t. 224 images
175+
)
176+
t.append(transforms.CenterCrop(config.DATA.TEST_SIZE))
177+
else:
178+
# default for testing
179+
t.append(transforms.Resize(config.DATA.TEST_SIZE, interpolation=_pil_interp(config.DATA.INTERPOLATION)))
180+
t.append(transforms.CenterCrop(config.DATA.TEST_SIZE))
181+
t.append(transforms.ToTensor())
182+
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
183+
trans = transforms.Compose(t)
184+
return trans

0 commit comments

Comments
 (0)