refine swin transformer, fix 1x1, update results
zhouyu committed Oct 10, 2023
1 parent b579739 commit b4feb66
Showing 7 changed files with 104 additions and 47 deletions.
2 changes: 2 additions & 0 deletions training/benchmarks/swin_transformer/pytorch/config/_base.py
@@ -4,6 +4,8 @@
vendor: str = None
# model name
name: str = "swin_transformer"
cudnn_benchmark: bool = False
cudnn_deterministic: bool = True

# -----------------------------------------------------------------------------
# Data settings
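The two cudnn flags are only declared in the benchmark config here; this hunk does not show where they are consumed. A minimal sketch of how such flags are typically applied through torch.backends.cudnn (the helper name apply_cudnn_config is illustrative, not part of this commit):

```python
import torch


def apply_cudnn_config(config):
    # cudnn_benchmark=True lets cuDNN autotune convolution algorithms per input shape;
    # cudnn_deterministic=True forces deterministic kernels for reproducible runs.
    torch.backends.cudnn.benchmark = config.cudnn_benchmark
    torch.backends.cudnn.deterministic = config.cudnn_deterministic
```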
@@ -16,10 +16,11 @@
from .cached_image_folder import CachedImageFolder
from .samplers import SubsetRandomSampler

from driver import dist_pytorch

try:
from torchvision.transforms import InterpolationMode


def _pil_interp(method):
if method == 'bicubic':
return InterpolationMode.BICUBIC
@@ -31,7 +32,6 @@ def _pil_interp(method):
# default bilinear, do we want to allow nearest?
return InterpolationMode.BILINEAR


import timm.data.transforms as timm_transforms

timm_transforms._pil_interp = _pil_interp
@@ -43,52 +43,62 @@ def build_loader(config):
# config.defrost()
dataset_train, config.model_num_classes = build_dataset(is_train=True, config=config)
# config.freeze()
print(f"local rank {config.local_rank} / global rank {dist.get_rank()} successfully build train dataset")

# bugfix for single-card training
if dist_pytorch.is_dist_avail_and_initialized():
print(f"local rank {config.local_rank} / global rank {dist.get_rank()} successfully build train dataset")
dataset_val, _ = build_dataset(is_train=False, config=config)
print(f"local rank {config.local_rank} / global rank {dist.get_rank()} successfully build val dataset")

num_tasks = dist.get_world_size()
global_rank = dist.get_rank()
# bugfix for single-card training
if dist_pytorch.is_dist_avail_and_initialized():
print(f"local rank {config.local_rank} / global rank {dist.get_rank()} successfully build val dataset")

num_tasks = dist.get_world_size() if dist_pytorch.is_dist_avail_and_initialized() else 1
global_rank = dist.get_rank() if dist_pytorch.is_dist_avail_and_initialized() else 0
if config.data_zip_mode and config.data_cache_mode == 'part':
indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
indices = np.arange(global_rank, len(dataset_train), num_tasks)
sampler_train = SubsetRandomSampler(indices)
else:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)

if config.test_sequential:
if config.test_sequential or num_tasks == 1:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
else:
sampler_val = torch.utils.data.distributed.DistributedSampler(
dataset_val, shuffle=config.test_shuffle
)
dataset_val, shuffle=config.test_shuffle)

data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
dataset_train,
sampler=sampler_train,
batch_size=config.train_batch_size,
num_workers=config.data_num_workers,
pin_memory=config.data_pin_memory,
drop_last=True,
)

data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
dataset_val,
sampler=sampler_val,
batch_size=config.train_batch_size,
shuffle=False,
num_workers=config.data_num_workers,
pin_memory=config.data_pin_memory,
drop_last=False
)
drop_last=False)

# setup mixup / cutmix
mixup_fn = None
mixup_active = config.aug_mixup > 0 or config.aug_cutmix > 0. or config.aug_cutmix_minmax is not None
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=config.aug_mixup, cutmix_alpha=config.aug_cutmix, cutmix_minmax=config.aug_cutmix_minmax,
prob=config.aug_mixup_prob, switch_prob=config.aug_mixup_switch_prob, mode=config.aug_mixup_mode,
label_smoothing=config.model_label_smoothing, num_classes=config.model_num_classes)
mixup_fn = Mixup(mixup_alpha=config.aug_mixup,
cutmix_alpha=config.aug_cutmix,
cutmix_minmax=config.aug_cutmix_minmax,
prob=config.aug_mixup_prob,
switch_prob=config.aug_mixup_switch_prob,
mode=config.aug_mixup_mode,
label_smoothing=config.model_label_smoothing,
num_classes=config.model_num_classes)

return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn

@@ -100,8 +110,12 @@ def build_dataset(is_train, config):
if config.data_zip_mode:
ann_file = prefix + "_map.txt"
prefix = prefix + ".zip@/"
dataset = CachedImageFolder(config.data_dir, ann_file, prefix, transform,
cache_mode=config.data_cache_mode if is_train else 'part')
dataset = CachedImageFolder(
config.data_dir,
ann_file,
prefix,
transform,
cache_mode=config.data_cache_mode if is_train else 'part')
else:
root = os.path.join(config.data_dir, prefix)
dataset = datasets.ImageFolder(root, transform=transform)
@@ -119,8 +133,10 @@ def build_transform(is_train, config):
transform = create_transform(
input_size=config.data_img_size,
is_training=True,
color_jitter=config.aug_color_jitter if config.aug_color_jitter > 0 else None,
auto_augment=config.aug_auto_augment if config.aug_auto_augment != 'none' else None,
color_jitter=config.aug_color_jitter
if config.aug_color_jitter > 0 else None,
auto_augment=config.aug_auto_augment
if config.aug_auto_augment != 'none' else None,
re_prob=config.aug_reprob,
re_mode=config.aug_remode,
re_count=config.aug_recount,
Expand All @@ -129,23 +145,26 @@ def build_transform(is_train, config):
if not resize_im:
# replace RandomResizedCropAndInterpolation with
# RandomCrop
transform.transforms[0] = transforms.RandomCrop(config.data_img_size, padding=4)
transform.transforms[0] = transforms.RandomCrop(
config.data_img_size, padding=4)
return transform

t = []
if resize_im:
if config.test_crop:
size = int((256 / 224) * config.data_img_size)
t.append(
transforms.Resize(size, interpolation=_pil_interp(config.data_interpolation)),
transforms.Resize(size,
interpolation=_pil_interp(
config.data_interpolation)),
# to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(config.data_img_size))
else:
t.append(
transforms.Resize((config.data_img_size, config.data_img_size),
interpolation=_pil_interp(config.data_interpolation))
)
transforms.Resize(
(config.data_img_size, config.data_img_size),
interpolation=_pil_interp(config.data_interpolation)))

t.append(transforms.ToTensor())
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
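The single-card ("1x1") fixes in this data-loader hunk all rely on the same guard: query the rank and world size only when a process group is initialized, and otherwise fall back to a one-process view. A condensed sketch of that pattern, written against torch.distributed directly rather than the repository's dist_pytorch helper:

```python
import torch.distributed as dist


def rank_and_world_size():
    # Single-card fallback: act as rank 0 in a world of size 1
    # when torch.distributed has not been initialized.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
```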
14 changes: 10 additions & 4 deletions training/benchmarks/swin_transformer/pytorch/run_pretraining.py
@@ -78,10 +78,11 @@ def main() -> Tuple[Any, Any]:

dist_pytorch.barrier(config.vendor)
model_driver.event(Event.TRAIN_START)
raw_train_start_time = logger.previous_log_time

epoch = -1
max_accuracy = 0.0

train_start_time = time.time()

for epoch in range(config.train_start_epoch, config.train_epochs):
training_state.epoch = epoch
@@ -102,10 +103,8 @@ def main() -> Tuple[Any, Any]:

end_training_state = trainer.detect_training_status(training_state)
model_driver.event(Event.TRAIN_END)
raw_train_end_time = logger.previous_log_time

training_state.raw_train_time = (raw_train_end_time -
raw_train_start_time) / 1e+3
training_state.raw_train_time = time.time() - train_start_time

return config, training_state

@@ -131,5 +130,12 @@ def main() -> Tuple[Any, Any]:
"final_acc5": state.eval_acc5,
"raw_train_time": state.raw_train_time,
"init_time": state.init_time,
"train_no_eval_time": state.no_eval_time,
"pure_training_computing_time": state.pure_compute_time,
"throughput(ips)_raw": state.num_trained_samples / state.raw_train_time,
"throughput(ips)_no_eval":
state.num_trained_samples / state.no_eval_time,
"throughput(ips)_pure_compute":
state.num_trained_samples / state.pure_compute_time,
}
logger.log(Event.FINISHED, message=finished_info, stacklevel=0)
11 changes: 8 additions & 3 deletions training/benchmarks/swin_transformer/pytorch/train/trainer.py
@@ -1,11 +1,9 @@
import os
import sys
import time
import datetime

import torch
from torch.types import Device
from timm.utils import accuracy, AverageMeter
from timm.utils import AverageMeter

from driver import Driver, Event, dist_pytorch
from train.training_state import TrainingState
@@ -27,6 +25,7 @@ def train_one_epoch(self, model, criterion, dataloader, optimizer, epoch, mixup_

model.train()
optimizer.zero_grad()
no_eval_start_time = time.time()

num_steps = len(dataloader)
batch_time = AverageMeter()
@@ -41,6 +40,9 @@ def train_one_epoch(self, model, criterion, dataloader, optimizer, epoch, mixup_
state.global_steps += 1
samples = samples.cuda(non_blocking=True)
targets = targets.cuda(non_blocking=True)
state.num_trained_samples += samples.size(0) * self.config.n_device

pure_compute_start_time = time.time()

if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets)
@@ -70,6 +72,8 @@ def train_one_epoch(self, model, criterion, dataloader, optimizer, epoch, mixup_
end = time.time()
state.loss = loss_meter.val

state.pure_compute_time += time.time() - pure_compute_start_time

other_state = dict()
if state.global_steps % self.config.gradient_accumulation_steps == 0:
step_end_time = time.time()
@@ -91,6 +95,7 @@ def train_one_epoch(self, model, criterion, dataloader, optimizer, epoch, mixup_
loss=state.loss)

epoch_time = time.time() - start
state.no_eval_time += time.time() - no_eval_start_time
if config.local_rank == 0:
print("EPOCH {} training takes {}".format(epoch, datetime.timedelta(seconds=int(epoch_time))))

@@ -15,7 +15,7 @@ class TrainingState:
acc5: float = 0.0
batch_time: float = 0.0
max_accuracy: float = 0.0

eval_loss: float = 0.0
eval_acc1: float = 0.0
eval_acc5: float = 0.0
@@ -27,6 +27,8 @@

init_time = 0
raw_train_time = 0
no_eval_time = 0
pure_compute_time = 0

def status(self):
if self.converged:
@@ -54,8 +56,9 @@ def to_dict(self, **kwargs):
state_dict[var_name] = value

exclude = [
"eval_loss", "acc1", "acc5", "max_accuracy", "eval_acc1", "eval_acc5", "skipped_steps",
"converged", "init_time", "raw_train_time", "batch_time"
"eval_loss", "acc1", "acc5", "max_accuracy", "eval_acc1",
"eval_acc5", "skipped_steps", "converged", "init_time",
"raw_train_time", "batch_time"
]
for exkey in exclude:
if exkey in state_dict:
9 changes: 6 additions & 3 deletions training/benchmarks/swin_transformer/pytorch/utils/utils.py
@@ -5,16 +5,19 @@
# Written by Ze Liu
# --------------------------------------------------------

import os
import torch
import torch.distributed as dist
from torch._six import inf


def reduce_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= dist.get_world_size()
# bugfix for 1x1 training
if dist.is_available() and dist.is_initialized():
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= dist.get_world_size()
else:
return tensor
return rt


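With the guard above, reduce_tensor returns its input unchanged when no process group exists, so the same metric-averaging call works for both distributed and 1x1 runs. A small standalone usage sketch (not taken from the repository):

```python
import torch
import torch.distributed as dist


def reduce_tensor(tensor):
    # Average across ranks when distributed; otherwise pass the tensor through.
    if dist.is_available() and dist.is_initialized():
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= dist.get_world_size()
        return rt
    return tensor


loss = torch.tensor(0.37)
print(reduce_tensor(loss))  # on a 1x1 run this is simply `loss`
```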
33 changes: 26 additions & 7 deletions training/nvidia/swin_transformer-pytorch/README.md
@@ -18,10 +18,29 @@


### Results
| Training resources | Config file | Runtime (s) | Target accuracy | Converged accuracy | Steps | Performance (samples/s) |
| ------------------ | --------------- | ----------- | --------------- | ------------------ | ------ | ----------------------- |
| Single node, 1 GPU | config_A100x1x1 | | | | | |
| Single node, 2 GPUs | config_A100x1x2 | | | | | |
| Single node, 4 GPUs | config_A100x1x4 | | | | | |
| Single node, 8 GPUs | config_A100x1x8 | 109571.12 | 81.00 | 81.12 | 187500 | 3505.07 |
| Two nodes, 8 GPUs each | config_A100x2x8 | | | | | |
* Common metrics

| Metric | Value | Notes |
| -------------- | --------------------------------------------- | ------------------------------------------- |
| Task type | Image Classification && Semantic Segmentation | |
| Model | swin_transformer | |
| Dataset | Imagenet2012 1K | |
| Data precision | precision, see "Performance metrics" | one of fp32/amp/fp16/tf32 |
| Hyperparameter changes | fix_hp, see "Performance metrics" | special hyperparameters needed to saturate the hardware when measuring throughput |
| Hardware device | nvidia A100 | |
| Device memory usage | mem, see "Performance metrics" | commonly called "GPU memory", in GiB |
| End-to-end time | e2e_time, see "Performance metrics" | total time, including Perf initialization etc. |
| Overall throughput | p_whole, see "Performance metrics" | actual number of training samples divided by total time (performance_whole) |
| Training throughput | p_train, see "Performance metrics" | excludes the evaluation time at the end of each epoch |
| **Compute throughput** | **p_core, see "Performance metrics"** | additionally excludes data-IO time (p_core > p_train > p_whole) |
| Training result | final_acc1, see "Performance metrics" | top-1 validation accuracy |
| Additional modifications | | |

* Performance metrics

| Config | precision | fix_hp | e2e_time | p_whole | p_train | p_core | final_acc1 | mem |
| --------------------------------- | --------- | ------ | -------- | ------- | ------- | ------ | ---------- | --------- |
| A100 single node, 8 GPUs (1x8) | amp | / | 109832 | 3410 | 3481 | 3511 | 81.12 | 28.9/40.0 |
| A100 single node, 8 GPUs (1x8) | amp | bs=384 | | 3457 | 3535 | 3573 | | 37.6/40.0 |
| A100 single node, 1 GPU (1x1) | amp | bs=384 | | 451 | 457 | 458 | | 36.0/40.0 |
| A100 two nodes, 8 GPUs each (2x8) | amp | bs=384 | | 6733 | 6947 | 7073 | | 39.6/40.0 |
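
The throughput columns are derived from the timing fields this commit adds to `run_pretraining.py` and `TrainingState`: `num_trained_samples` divided by `raw_train_time`, `no_eval_time`, and `pure_compute_time` (with p_whole additionally measured against the overall time). A minimal sketch of that relationship, mirroring the `finished_info` block:

```python
def logged_throughputs(num_trained_samples, raw_train_time, no_eval_time, pure_compute_time):
    # pure_compute_time <= no_eval_time <= raw_train_time, so the resulting
    # rates are ordered: pure-compute >= no-eval >= raw (p_core > p_train > p_whole).
    return {
        "throughput(ips)_raw": num_trained_samples / raw_train_time,
        "throughput(ips)_no_eval": num_trained_samples / no_eval_time,
        "throughput(ips)_pure_compute": num_trained_samples / pure_compute_time,
    }
```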
