From 3dfc7e3fc64075450d4d3d5cdc5a3320e089ae32 Mon Sep 17 00:00:00 2001 From: AmirMasoud Nourollah Date: Sun, 31 Jul 2022 17:41:40 +0430 Subject: [PATCH] Project refactoring for a new python format string and some simpler syntax. Signed-off-by: AmirMasoud Nourollah --- core/base_dataset.py | 24 +++-- core/base_model.py | 81 ++++++++------- core/base_network.py | 83 ++++++++------- core/logger.py | 39 ++++--- core/praser.py | 71 ++++++------- core/util.py | 73 ++++++------- data/__init__.py | 10 +- data/dataset.py | 132 ++++++++++++------------ data/util/auto_augment.py | 95 ++++++++--------- data/util/mask.py | 32 +++--- eval.py | 8 +- models/guided_diffusion_modules/nn.py | 5 +- models/guided_diffusion_modules/unet.py | 87 ++++++++-------- models/loss.py | 3 +- models/metric.py | 8 +- models/model.py | 124 +++++++++++----------- models/network.py | 54 +++++----- models/sr3_modules/unet.py | 66 ++++++------ preprocess/mirflickr25k_preprocess.py | 38 ++++--- run.py | 60 +++++------ 20 files changed, 550 insertions(+), 543 deletions(-) diff --git a/core/base_dataset.py b/core/base_dataset.py index b0e6a29..c364d18 100755 --- a/core/base_dataset.py +++ b/core/base_dataset.py @@ -9,15 +9,17 @@ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', ] + def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + def make_dataset(dir): if os.path.isfile(dir): - images = [i for i in np.genfromtxt(dir, dtype=np.str, encoding='utf-8')] + images = list(np.genfromtxt(dir, dtype=np.str, encoding='utf-8')) else: images = [] - assert os.path.isdir(dir), '%s is not a valid directory' % dir + assert os.path.isdir(dir), f'{dir} is not a valid directory' for root, _, fnames in sorted(os.walk(dir)): for fname in sorted(fnames): if is_image_file(fname): @@ -26,23 +28,25 @@ def make_dataset(dir): return images + def pil_loader(path): return Image.open(path).convert('RGB') + class BaseDataset(data.Dataset): - def __init__(self, data_root, image_size=[256, 256], loader=pil_loader): + def __init__(self, data_root, image_size=None, loader=pil_loader): + if image_size is None: + image_size = [256, 256] self.imgs = make_dataset(data_root) - self.tfs = transforms.Compose([ - transforms.Resize((image_size[0], image_size[1])), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) - ]) + self.tfs = transforms.Compose([transforms.Resize((image_size[0], image_size[1])), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201))]) + self.loader = loader def __getitem__(self, index): path = self.imgs[index] - img = self.tfs(self.loader(path)) - return img + return self.tfs(self.loader(path)) # return image def __len__(self): return len(self.imgs) diff --git a/core/base_model.py b/core/base_model.py index a76e370..6bae645 100755 --- a/core/base_model.py +++ b/core/base_model.py @@ -6,11 +6,12 @@ import torch import torch.nn as nn - import core.util as Util + CustomResult = collections.namedtuple('CustomResult', 'name result') -class BaseModel(): + +class BaseModel: def __init__(self, opt, phase_loader, val_loader, metrics, logger, writer): """ init model with basic input, which are from __init__(**kwargs) function in inherited class """ self.opt = opt @@ -24,7 +25,7 @@ def __init__(self, opt, phase_loader, val_loader, metrics, logger, writer): ''' process record ''' self.batch_size = self.opt['datasets'][self.phase]['dataloader']['args']['batch_size'] self.epoch = 0 - self.iter = 0 + self.iter = 0 self.phase_loader = phase_loader self.val_loader = val_loader @@ -33,24 +34,24 @@ def __init__(self, opt, phase_loader, val_loader, metrics, logger, writer): ''' logger to log file, which only work on GPU 0. writer to tensorboard and result file ''' self.logger = logger self.writer = writer - self.results_dict = CustomResult([],[]) # {"name":[], "result":[]} + self.results_dict = CustomResult([], []) # {"name":[], "result":[]} def train(self): while self.epoch <= self.opt['train']['n_epoch'] and self.iter <= self.opt['train']['n_iter']: self.epoch += 1 if self.opt['distributed']: ''' sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas use a different random ordering for each epoch ''' - self.phase_loader.sampler.set_epoch(self.epoch) + self.phase_loader.sampler.set_epoch(self.epoch) train_log = self.train_step() - ''' save logged informations into log dict ''' + ''' save logged informations into log dict ''' train_log.update({'epoch': self.epoch, 'iters': self.iter}) - ''' print logged informations to the screen and tensorboard ''' + ''' print logged informations to the screen and tensorboard ''' for key, value in train_log.items(): self.logger.info('{:5s}: {}\t'.format(str(key), value)) - + if self.epoch % self.opt['train']['save_checkpoint_epoch'] == 0: self.logger.info('Saving the self at the end of epoch {:.0f}'.format(self.epoch)) self.save_everything() @@ -79,26 +80,26 @@ def val_step(self): def test_step(self): pass - + def print_network(self, network): """ print network structure, only work on GPU 0 """ - if self.opt['global_rank'] !=0: + if self.opt['global_rank'] != 0: return - if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): + if isinstance(network, (nn.DataParallel, nn.parallel.DistributedDataParallel)): network = network.module - s, n = str(network), sum(map(lambda x: x.numel(), network.parameters())) - net_struc_str = '{}'.format(network.__class__.__name__) + net_struc_str = f'{network.__class__.__name__}' self.logger.info('Network structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + self.logger.info(s) def save_network(self, network, network_label): """ save network structure, only work on GPU 0 """ - if self.opt['global_rank'] !=0: + if self.opt['global_rank'] != 0: return - save_filename = '{}_{}.pth'.format(self.epoch, network_label) + save_filename = f'{self.epoch}_{network_label}.pth' save_path = os.path.join(self.opt['path']['checkpoint'], save_filename) - if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): + if isinstance(network, (nn.DataParallel, nn.parallel.DistributedDataParallel)): network = network.module state_dict = network.state_dict() for key, param in state_dict.items(): @@ -107,54 +108,62 @@ def save_network(self, network, network_label): def load_network(self, network, network_label, strict=True): if self.opt['path']['resume_state'] is None: - return + return self.logger.info('Beign loading pretrained model [{:s}] ...'.format(network_label)) - model_path = "{}_{}.pth".format(self. opt['path']['resume_state'], network_label) - + model_path = f"{self.opt['path']['resume_state']}_{network_label}.pth" if not os.path.exists(model_path): self.logger.warning('Pretrained model in [{:s}] is not existed, Skip it'.format(model_path)) - return + return self.logger.info('Loading pretrained model from [{:s}] ...'.format(model_path)) - if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): + if isinstance(network, (nn.DataParallel, nn.parallel.DistributedDataParallel)): network = network.module - network.load_state_dict(torch.load(model_path, map_location = lambda storage, loc: Util.set_device(storage)), strict=strict) + network.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: Util.set_device(storage)), + strict=strict) def save_training_state(self): """ saves training state during training, only work on GPU 0 """ - if self.opt['global_rank'] !=0: + if self.opt['global_rank'] != 0: return - assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.' + assert isinstance(self.optimizers, list) and isinstance(self.schedulers, + list), 'optimizers and schedulers must be a list.' + state = {'epoch': self.epoch, 'iter': self.iter, 'schedulers': [], 'optimizers': []} + for s in self.schedulers: state['schedulers'].append(s.state_dict()) for o in self.optimizers: state['optimizers'].append(o.state_dict()) - save_filename = '{}.state'.format(self.epoch) + save_filename = f'{self.epoch}.state' save_path = os.path.join(self.opt['path']['checkpoint'], save_filename) torch.save(state, save_path) def resume_training(self): """ resume the optimizers and schedulers for training, only work when phase is test or resume training enable """ - if self.phase!='train' or self. opt['path']['resume_state'] is None: + if self.phase != 'train' or self.opt['path']['resume_state'] is None: return self.logger.info('Beign loading training states'.format()) - assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.' - - state_path = "{}.state".format(self. opt['path']['resume_state']) - + assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), \ + 'optimizers and schedulers must be a list.' + + state_path = f"{self.opt['path']['resume_state']}.state" + if not os.path.exists(state_path): self.logger.warning('Training state in [{:s}] is not existed, Skip it'.format(state_path)) return self.logger.info('Loading training state for [{:s}] ...'.format(state_path)) - resume_state = torch.load(state_path, map_location = lambda storage, loc: self.set_device(storage)) - + resume_state = torch.load(state_path, map_location=lambda storage, loc: self.set_device(storage)) + resume_optimizers = resume_state['optimizers'] resume_schedulers = resume_state['schedulers'] - assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers {} != {}'.format(len(resume_optimizers), len(self.optimizers)) - assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers {} != {}'.format(len(resume_schedulers), len(self.schedulers)) + assert len(resume_optimizers) == len( + self.optimizers), f'Wrong lengths of optimizers {len(resume_optimizers)} != {len(self.optimizers)}' + + assert len(resume_schedulers) == len( + self.schedulers), f'Wrong lengths of schedulers {len(resume_schedulers)} != {len(self.schedulers)}' + for i, o in enumerate(resume_optimizers): self.optimizers[i].load_state_dict(o) for i, s in enumerate(resume_schedulers): @@ -164,8 +173,8 @@ def resume_training(self): self.iter = resume_state['iter'] def load_everything(self): - pass - + pass + @abstractmethod def save_everything(self): raise NotImplementedError('You must specify how to save your networks, optimizers and schedulers.') diff --git a/core/base_network.py b/core/base_network.py index bea2f6d..6dfe04e 100755 --- a/core/base_network.py +++ b/core/base_network.py @@ -1,48 +1,47 @@ import torch.nn as nn -class BaseNetwork(nn.Module): - def __init__(self, init_type='kaiming', gain=0.02): - super(BaseNetwork, self).__init__() - self.init_type = init_type - self.gain = gain - def init_weights(self): - """ - initialize network's weights - init_type: normal | xavier | kaiming | orthogonal - https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 - """ - - def init_func(m): - classname = m.__class__.__name__ - if classname.find('InstanceNorm2d') != -1: - if hasattr(m, 'weight') and m.weight is not None: - nn.init.constant_(m.weight.data, 1.0) - if hasattr(m, 'bias') and m.bias is not None: - nn.init.constant_(m.bias.data, 0.0) - elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): - if self.init_type == 'normal': - nn.init.normal_(m.weight.data, 0.0, self.gain) - elif self.init_type == 'xavier': - nn.init.xavier_normal_(m.weight.data, gain=self.gain) - elif self.init_type == 'xavier_uniform': - nn.init.xavier_uniform_(m.weight.data, gain=1.0) - elif self.init_type == 'kaiming': - nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') - elif self.init_type == 'orthogonal': - nn.init.orthogonal_(m.weight.data, gain=self.gain) - elif self.init_type == 'none': # uses pytorch's default init method - m.reset_parameters() - else: - raise NotImplementedError('initialization method [%s] is not implemented' % self.init_type) - if hasattr(m, 'bias') and m.bias is not None: - nn.init.constant_(m.bias.data, 0.0) - self.apply(init_func) - # propagate to children - for m in self.children(): - if hasattr(m, 'init_weights'): - m.init_weights(self.init_type, self.gain) +class BaseNetwork(nn.Module): + def __init__(self, init_type='kaiming', gain=0.02): + super(BaseNetwork, self).__init__() + self.init_type = init_type + self.gain = gain + + def init_weights(self): + """ + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + """ + def init_func(m): + classname = m.__class__.__name__ + if classname.find('InstanceNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + nn.init.constant_(m.weight.data, 1.0) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if self.init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, self.gain) + elif self.init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=self.gain) + elif self.init_type == 'xavier_uniform': + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + elif self.init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif self.init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=self.gain) + elif self.init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError(f'initialization method [{self.init_type}] is not implemented') + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) - \ No newline at end of file + self.apply(init_func) + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(self.init_type, self.gain) diff --git a/core/logger.py b/core/logger.py index a3b3368..23fa03b 100755 --- a/core/logger.py +++ b/core/logger.py @@ -7,10 +7,12 @@ import core.util as Util -class InfoLogger(): + +class InfoLogger: """ use logging to record log, only work on GPU 0 by judging global_rank """ + def __init__(self, opt): self.opt = opt self.rank = opt['global_rank'] @@ -21,23 +23,27 @@ def __init__(self, opt): self.infologger_ftns = {'info', 'warning', 'debug'} def __getattr__(self, name): - if self.rank != 0: # info only print on GPU 0. + if self.rank != 0: # info only print on GPU 0. def wrapper(info, *args, **kwargs): pass + return wrapper if name in self.infologger_ftns: print_info = getattr(self.logger, name, None) + def wrapper(info, *args, **kwargs): print_info(info, *args, **kwargs) + return wrapper - + @staticmethod def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False): """ set up logger """ l = logging.getLogger(logger_name) - formatter = logging.Formatter( - '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S') - log_file = os.path.join(root, '{}.log'.format(phase)) + formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', + datefmt='%y-%m-%d %H:%M:%S') + + log_file = os.path.join(root, f'{phase}.log') fh = logging.FileHandler(log_file, mode='a+') fh.setFormatter(formatter) l.setLevel(level) @@ -47,11 +53,13 @@ def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False): sh.setFormatter(formatter) l.addHandler(sh) -class VisualWriter(): + +class VisualWriter: """ use tensorboard to record visuals, support 'add_scalar', 'add_scalars', 'add_image', 'add_images', etc. funtion. Also integrated with save results function. """ + def __init__(self, opt, logger): log_dir = opt['path']['tb_logger'] self.result_dir = opt['path']['results'] @@ -61,7 +69,7 @@ def __init__(self, opt, logger): self.writer = None self.selected_module = "" - if enabled and self.rank==0: + if enabled and self.rank == 0: log_dir = str(log_dir) # Retrieve vizualization writer. @@ -77,8 +85,8 @@ def __init__(self, opt, logger): if not succeeded: message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ - "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ - "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." + "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ + "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." logger.warning(message) self.epoch = 0 @@ -108,16 +116,16 @@ def save_images(self, results): try: names = results['name'] outputs = Util.postprocess(results['result']) - for i in range(len(names)): + for i in range(len(names)): Image.fromarray(outputs[i]).save(os.path.join(result_path, names[i])) except: - raise NotImplementedError('You must specify the context of name and result in save_current_results functions of model.') + raise NotImplementedError( + 'You must specify the context of name and result in save_current_results functions of model.') def close(self): self.writer.close() print('Close the Tensorboard SummaryWriter.') - def __getattr__(self, name): """ If visualization is configured to use: @@ -127,12 +135,14 @@ def __getattr__(self, name): """ if name in self.tb_writer_ftns: add_data = getattr(self.writer, name, None) + def wrapper(tag, data, *args, **kwargs): if add_data is not None: # add phase(train/valid) tag if name not in self.tag_mode_exceptions: tag = '{}/{}'.format(self.phase, tag) add_data(tag, data, self.iter, *args, **kwargs) + return wrapper else: # default action for returning methods defined in this class, set_step() for instance. @@ -147,6 +157,7 @@ class LogTracker: """ record training numerical indicators. """ + def __init__(self, *keys, phase='train'): self.phase = phase self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) @@ -165,4 +176,4 @@ def avg(self, key): return self._data.average[key] def result(self): - return {'{}/{}'.format(self.phase, k):v for k, v in dict(self._data.average).items()} + return {f'{self.phase}/{k}': v for k, v in dict(self._data.average).items()} diff --git a/core/praser.py b/core/praser.py index 0106d33..f3ff504 100755 --- a/core/praser.py +++ b/core/praser.py @@ -5,21 +5,24 @@ from datetime import datetime from functools import partial import importlib -from types import FunctionType +from types import FunctionType import shutil -def init_obj(opt, logger, *args, default_file_name='default file', given_module=None, init_type='Network', **modify_kwargs): + + +def init_obj(opt, logger, *args, default_file_name='default file', given_module=None, init_type='Network', + **modify_kwargs): """ finds a function handle with the name given as 'name' in config, and returns the instance initialized with corresponding args. - """ - if opt is None or len(opt)<1: - logger.info('Option is None when initialize {}'.format(init_type)) + """ + if opt is None or len(opt) < 1: + logger.info(f'Option is None when initialize {init_type}') return None - + ''' default format is dict with name key ''' if isinstance(opt, str): opt = {'name': opt} - logger.warning('Config is a str, converts to a dict {}'.format(opt)) + logger.warning(f'Config is a str, converts to a dict {opt}') name = opt['name'] ''' name can be list, indicates the file and class name of function ''' @@ -32,17 +35,17 @@ def init_obj(opt, logger, *args, default_file_name='default file', given_module= module = given_module else: module = importlib.import_module(file_name) - + attr = getattr(module, class_name) kwargs = opt.get('args', {}) kwargs.update(modify_kwargs) ''' import class or function with args ''' - if isinstance(attr, type): + if isinstance(attr, type): ret = attr(*args, **kwargs) - ret.__name__ = ret.__class__.__name__ - elif isinstance(attr, FunctionType): + ret.__name__ = ret.__class__.__name__ + elif isinstance(attr, FunctionType): ret = partial(attr, *args, **kwargs) - ret.__name__ = attr.__name__ + ret.__name__ = attr.__name__ # ret = attr logger.info('{} [{:s}() form {:s}] is created.'.format(init_type, class_name, file_name)) except: @@ -57,6 +60,7 @@ def mkdirs(paths): for path in paths: os.makedirs(path, exist_ok=True) + def get_timestamp(): return datetime.now().strftime('%y%m%d_%H%M%S') @@ -66,22 +70,23 @@ def write_json(content, fname): with fname.open('wt') as handle: json.dump(content, handle, indent=4, sort_keys=False) + class NoneDict(dict): def __missing__(self, key): return None + def dict_to_nonedict(opt): """ convert to NoneDict, which return None for missing key. """ if isinstance(opt, dict): - new_opt = dict() - for key, sub_opt in opt.items(): - new_opt[key] = dict_to_nonedict(sub_opt) + new_opt = {key: dict_to_nonedict(sub_opt) for key, sub_opt in opt.items()} return NoneDict(**new_opt) elif isinstance(opt, list): return [dict_to_nonedict(sub_opt) for sub_opt in opt] else: return opt + def dict2str(opt, indent_l=1): """ dict to string for logger """ msg = '' @@ -94,6 +99,7 @@ def dict2str(opt, indent_l=1): msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' return msg + def parse(args): json_str = '' with open(args.config, 'r') as f: @@ -101,55 +107,42 @@ def parse(args): line = line.split('//')[0] + '\n' json_str += line opt = json.loads(json_str, object_pairs_hook=OrderedDict) - ''' replace the config context using args ''' opt['phase'] = args.phase if args.gpu_ids is not None: opt['gpu_ids'] = [int(id) for id in args.gpu_ids.split(',')] if args.batch is not None: opt['datasets'][opt['phase']]['dataloader']['args']['batch_size'] = args.batch - ''' set cuda environment ''' - if len(opt['gpu_ids']) > 1: - opt['distributed'] = True - else: - opt['distributed'] = False - + opt['distributed'] = len(opt['gpu_ids']) > 1 ''' update name ''' if args.debug: - opt['name'] = 'debug_{}'.format(opt['name']) + opt['name'] = f"debug_{opt['name']}" elif opt['finetune_norm']: - opt['name'] = 'finetune_{}'.format(opt['name']) + opt['name'] = f"finetune_{opt['name']}" else: - opt['name'] = '{}_{}'.format(opt['phase'], opt['name']) - + opt['name'] = f"{opt['phase']}_{opt['name']}" ''' set log directory ''' - experiments_root = os.path.join(opt['path']['base_dir'], '{}_{}'.format(opt['name'], get_timestamp())) - mkdirs(experiments_root) + experiments_root = os.path.join(opt['path']['base_dir'], f"{opt['name']}_{get_timestamp()}") + mkdirs(experiments_root) ''' save json ''' - write_json(opt, '{}/config.json'.format(experiments_root)) - + write_json(opt, f'{experiments_root}/config.json') ''' change folder relative hierarchy ''' opt['path']['experiments_root'] = experiments_root for key, path in opt['path'].items(): if 'resume' not in key and 'base' not in key and 'root' not in key: opt['path'][key] = os.path.join(experiments_root, path) mkdirs(opt['path'][key]) - ''' debug mode ''' if 'debug' in opt['name']: opt['train'].update(opt['debug']) - - ''' code backup ''' + ''' code backup ''' for name in os.listdir('.'): if name in ['config', 'models', 'core', 'slurm', 'data']: - shutil.copytree(name, os.path.join(opt['path']['code'], name), ignore=shutil.ignore_patterns("*.pyc", "__pycache__")) + shutil.copytree(name, os.path.join(opt['path']['code'], name), + ignore=shutil.ignore_patterns("*.pyc", "__pycache__")) + if '.py' in name or '.sh' in name: shutil.copy(name, opt['path']['code']) return dict_to_nonedict(opt) - - - - - diff --git a/core/util.py b/core/util.py index 4838fda..c1f904c 100755 --- a/core/util.py +++ b/core/util.py @@ -26,51 +26,52 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)): else: raise TypeError('Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) if out_type == np.uint8: - img_np = ((img_np+1) * 127.5).round() + img_np = ((img_np + 1) * 127.5).round() # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. return img_np.astype(out_type) + def postprocess(images): - return [tensor2img(image) for image in images] + return [tensor2img(image) for image in images] def set_seed(seed, gl_seed=0): - """ set random seed, gl_seed used in worker_init_fn function """ - if seed >=0 and gl_seed>=0: - seed += gl_seed - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - random.seed(seed) + """ set random seed, gl_seed used in worker_init_fn function """ + if seed >= 0 and gl_seed >= 0: + seed += gl_seed + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) - ''' change the deterministic and benchmark maybe cause uncertain convolution behavior. - speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html ''' - if seed >=0 and gl_seed>=0: # slower, more reproducible - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - else: # faster, less reproducible - torch.backends.cudnn.deterministic = False - torch.backends.cudnn.benchmark = True + ''' change the deterministic and benchmark maybe cause uncertain convolution behavior. + speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html ''' + if seed >= 0 and gl_seed >= 0: # slower, more reproducible + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: # faster, less reproducible + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True -def set_gpu(args, distributed=False, rank=0): - """ set parameter to gpu or ddp """ - if args is None: - return None - if distributed and isinstance(args, torch.nn.Module): - return DDP(args.cuda(), device_ids=[rank], output_device=rank, broadcast_buffers=True, find_unused_parameters=True) - else: - return args.cuda() - -def set_device(args, distributed=False, rank=0): - """ set parameter to gpu or cpu """ - if torch.cuda.is_available(): - if isinstance(args, list): - return (set_gpu(item, distributed, rank) for item in args) - elif isinstance(args, dict): - return {key:set_gpu(args[key], distributed, rank) for key in args} - else: - args = set_gpu(args, distributed, rank) - return args +def set_gpu(args, distributed=False, rank=0): + """ set parameter to gpu or ddp """ + if args is None: + return None + if distributed and isinstance(args, torch.nn.Module): + return DDP(args.cuda(), device_ids=[rank], output_device=rank, broadcast_buffers=True, + find_unused_parameters=True) + else: + return args.cuda() +def set_device(args, distributed=False, rank=0): + """ set parameter to gpu or cpu """ + if torch.cuda.is_available(): + if isinstance(args, list): + return (set_gpu(item, distributed, rank) for item in args) + elif isinstance(args, dict): + return {key: set_gpu(args[key], distributed, rank) for key in args} + else: + args = set_gpu(args, distributed, rank) + return args diff --git a/data/__init__.py b/data/__init__.py index 3dce15f..d7becba 100755 --- a/data/__init__.py +++ b/data/__init__.py @@ -51,9 +51,9 @@ def define_dataset(logger, opt): dataloder_opt = opt['datasets'][opt['phase']]['dataloader'] valid_split = dataloder_opt.get('validation_split', 0) - + ''' divide validation dataset, valid_split==0 when phase is test or validation_split is 0. ''' - if valid_split > 0.0 or 'debug' in opt['name']: + if valid_split > 0.0 or 'debug' in opt['name']: if isinstance(valid_split, int): assert valid_split < data_len, "Validation set size is configured to be larger than entire dataset." valid_len = valid_split @@ -61,10 +61,10 @@ def define_dataset(logger, opt): valid_len = int(data_len * valid_split) data_len -= valid_len phase_dataset, val_dataset = subset_split(dataset=phase_dataset, lengths=[data_len, valid_len], generator=Generator().manual_seed(opt['seed'])) - - logger.info('Dataset for {} have {} samples.'.format(opt['phase'], data_len)) + + logger.info(f"Dataset for {opt['phase']} have {data_len} samples.") if opt['phase'] == 'train': - logger.info('Dataset for {} have {} samples.'.format('val', valid_len)) + logger.info(f'Dataset for val have {valid_len} samples.') return phase_dataset, val_dataset def subset_split(dataset, lengths, generator): diff --git a/data/dataset.py b/data/dataset.py index a36ee75..8e26282 100755 --- a/data/dataset.py +++ b/data/dataset.py @@ -12,37 +12,41 @@ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', ] + def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + def make_dataset(dir): if os.path.isfile(dir): - images = [i for i in np.genfromtxt(dir, dtype=np.str, encoding='utf-8')] + images = list(np.genfromtxt(dir, dtype=np.str, encoding='utf-8')) else: images = [] - assert os.path.isdir(dir), '%s is not a valid directory' % dir + assert os.path.isdir(dir), f'{dir} is not a valid directory' for root, _, fnames in sorted(os.walk(dir)): for fname in sorted(fnames): if is_image_file(fname): path = os.path.join(root, fname) images.append(path) - return images + def pil_loader(path): return Image.open(path).convert('RGB') + class InpaintDataset(data.Dataset): - def __init__(self, data_root, mask_config={}, data_len=-1, image_size=[256, 256], loader=pil_loader): + def __init__(self, data_root, mask_config=None, data_len=-1, image_size=None, loader=pil_loader): + if mask_config is None: + mask_config = {} + if image_size is None: + image_size = [256, 256] imgs = make_dataset(data_root) - if data_len > 0: - self.imgs = imgs[:int(data_len)] - else: - self.imgs = imgs + self.imgs = imgs[:int(data_len)] if data_len > 0 else imgs self.tfs = transforms.Compose([ - transforms.Resize((image_size[0], image_size[1])), - transforms.ToTensor(), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5,0.5, 0.5]) + transforms.Resize((image_size[0], image_size[1])), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) self.loader = loader self.mask_config = mask_config @@ -50,29 +54,28 @@ def __init__(self, data_root, mask_config={}, data_len=-1, image_size=[256, 256] self.image_size = image_size def __getitem__(self, index): - ret = {} path = self.imgs[index] img = self.tfs(self.loader(path)) mask = self.get_mask() - cond_image = img*(1. - mask) + mask*torch.randn_like(img) - mask_img = img*(1. - mask) + mask - - ret['gt_image'] = img - ret['cond_image'] = cond_image - ret['mask_image'] = mask_img - ret['mask'] = mask - ret['path'] = path.rsplit("/")[-1].rsplit("\\")[-1] - return ret + cond_image = img * (1.0 - mask) + mask * torch.randn_like(img) + mask_img = img * (1.0 - mask) + mask + return { + 'gt_image': img, + 'cond_image': cond_image, + 'mask_image': mask_img, + 'mask': mask, + 'path': path.rsplit("/")[-1].rsplit("\\")[-1] + } def __len__(self): return len(self.imgs) - def get_mask(self): + def get_mask(self): # sourcery skip: remove-pass-elif if self.mask_mode == 'bbox': mask = bbox2mask(self.image_size, random_bbox()) elif self.mask_mode == 'center': h, w = self.image_size - mask = bbox2mask(self.image_size, (h//4, w//4, h//2, w//2)) + mask = bbox2mask(self.image_size, (h // 4, w // 4, h // 2, w // 2)) elif self.mask_mode == 'irregular': mask = get_irregular_mask(self.image_size) elif self.mask_mode == 'free_form': @@ -86,51 +89,49 @@ def get_mask(self): else: raise NotImplementedError( f'Mask mode {self.mask_mode} has not been implemented.') - return torch.from_numpy(mask).permute(2,0,1) + return torch.from_numpy(mask).permute(2, 0, 1) class UncroppingDataset(data.Dataset): - def __init__(self, data_root, mask_config={}, data_len=-1, image_size=[256, 256], loader=pil_loader): + def __init__(self, data_root, mask_config=None, data_len=-1, image_size=None, loader=pil_loader): + if mask_config is None: + mask_config = {} + if image_size is None: + image_size = [256, 256] imgs = make_dataset(data_root) - if data_len > 0: - self.imgs = imgs[:int(data_len)] - else: - self.imgs = imgs - self.tfs = transforms.Compose([ - transforms.Resize((image_size[0], image_size[1])), - transforms.ToTensor(), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5,0.5, 0.5]) - ]) + self.imgs = imgs[:int(data_len)] if data_len > 0 else imgs + self.tfs = transforms.Compose([transforms.Resize((image_size[0], image_size[1])), transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) + self.loader = loader self.mask_config = mask_config self.mask_mode = self.mask_config['mask_mode'] self.image_size = image_size def __getitem__(self, index): - ret = {} path = self.imgs[index] img = self.tfs(self.loader(path)) mask = self.get_mask() - cond_image = img*(1. - mask) + mask*torch.randn_like(img) - mask_img = img*(1. - mask) + mask - - ret['gt_image'] = img - ret['cond_image'] = cond_image - ret['mask_image'] = mask_img - ret['mask'] = mask - ret['path'] = path.rsplit("/")[-1].rsplit("\\")[-1] - return ret + cond_image = img * (1.0 - mask) + mask * torch.randn_like(img) + mask_img = img * (1.0 - mask) + mask + return { + 'gt_image': img, + 'cond_image': cond_image, + 'mask_image': mask_img, + 'mask': mask, + 'path': path.rsplit("/")[-1].rsplit("\\")[-1] + } def __len__(self): return len(self.imgs) - def get_mask(self): + def get_mask(self): # sourcery skip: remove-pass-elif if self.mask_mode == 'manual': mask = bbox2mask(self.image_size, self.mask_config['shape']) - elif self.mask_mode == 'fourdirection' or self.mask_mode == 'onedirection': + elif self.mask_mode in ['fourdirection', 'onedirection']: mask = bbox2mask(self.image_size, random_cropping_bbox(mask_mode=self.mask_mode)) elif self.mask_mode == 'hybrid': - if np.random.randint(0,2)<1: + if np.random.randint(0, 2) < 1: mask = bbox2mask(self.image_size, random_cropping_bbox(mask_mode='onedirection')) else: mask = bbox2mask(self.image_size, random_cropping_bbox(mask_mode='fourdirection')) @@ -139,38 +140,33 @@ def get_mask(self): else: raise NotImplementedError( f'Mask mode {self.mask_mode} has not been implemented.') - return torch.from_numpy(mask).permute(2,0,1) + return torch.from_numpy(mask).permute(2, 0, 1) class ColorizationDataset(data.Dataset): - def __init__(self, data_root, data_flist, data_len=-1, image_size=[224, 224], loader=pil_loader): + def __init__(self, data_root, data_flist, data_len=-1, image_size=None, loader=pil_loader): + if image_size is None: + image_size = [224, 224] self.data_root = data_root flist = make_dataset(data_flist) - if data_len > 0: - self.flist = flist[:int(data_len)] - else: - self.flist = flist - self.tfs = transforms.Compose([ - transforms.Resize((image_size[0], image_size[1])), - transforms.ToTensor(), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5,0.5, 0.5]) - ]) + self.flist = flist[:int(data_len)] if data_len > 0 else flist + self.tfs = transforms.Compose([transforms.Resize((image_size[0], image_size[1])), transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) + self.loader = loader self.image_size = image_size def __getitem__(self, index): - ret = {} - file_name = str(self.flist[index]).zfill(5) + '.png' + file_name = f'{str(self.flist[index]).zfill(5)}.png' + img = self.tfs(self.loader(f'{self.data_root}/color/{file_name}')) - img = self.tfs(self.loader('{}/{}/{}'.format(self.data_root, 'color', file_name))) - cond_image = self.tfs(self.loader('{}/{}/{}'.format(self.data_root, 'gray', file_name))) + cond_image = self.tfs(self.loader(f'{self.data_root}/gray/{file_name}')) - ret['gt_image'] = img - ret['cond_image'] = cond_image - ret['path'] = file_name - return ret + return { + 'gt_image': img, + 'cond_image': cond_image, + 'path': file_name + } def __len__(self): return len(self.flist) - - diff --git a/data/util/auto_augment.py b/data/util/auto_augment.py index f01822c..7628991 100755 --- a/data/util/auto_augment.py +++ b/data/util/auto_augment.py @@ -107,24 +107,23 @@ def transform_matrix_offset_center(matrix, x, y): o_y = float(y) / 2 + 0.5 offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) - transform_matrix = offset_matrix @ matrix @ reset_matrix - return transform_matrix + return offset_matrix @ matrix @ reset_matrix # transform_matrix def shear_x(img, magnitude): img = np.array(img) magnitudes = np.linspace(-0.3, 0.3, 11) - transform_matrix = np.array([[1, random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]), 0], + transform_matrix = np.array([[1, random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1]), 0], [0, 1, 0], [0, 0, 1]]) transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) affine_matrix = transform_matrix[:2, :2] offset = transform_matrix[:2, 2] img = np.stack([ndimage.interpolation.affine_transform( - img[:, :, c], - affine_matrix, - offset) for c in range(img.shape[2])], axis=2) + img[:, :, c], + affine_matrix, + offset) for c in range(img.shape[2])], axis=2) img = Image.fromarray(img) return img @@ -134,51 +133,53 @@ def shear_y(img, magnitude): magnitudes = np.linspace(-0.3, 0.3, 11) transform_matrix = np.array([[1, 0, 0], - [random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]), 1, 0], + [random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1]), 1, 0], [0, 0, 1]]) transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) affine_matrix = transform_matrix[:2, :2] offset = transform_matrix[:2, 2] img = np.stack([ndimage.interpolation.affine_transform( - img[:, :, c], - affine_matrix, - offset) for c in range(img.shape[2])], axis=2) + img[:, :, c], + affine_matrix, + offset) for c in range(img.shape[2])], axis=2) img = Image.fromarray(img) return img def translate_x(img, magnitude): img = np.array(img) - magnitudes = np.linspace(-150/331, 150/331, 11) + magnitudes = np.linspace(-150 / 331, 150 / 331, 11) transform_matrix = np.array([[1, 0, 0], - [0, 1, img.shape[1]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])], + [0, 1, + img.shape[1] * random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])], [0, 0, 1]]) transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) affine_matrix = transform_matrix[:2, :2] offset = transform_matrix[:2, 2] img = np.stack([ndimage.interpolation.affine_transform( - img[:, :, c], - affine_matrix, - offset) for c in range(img.shape[2])], axis=2) + img[:, :, c], + affine_matrix, + offset) for c in range(img.shape[2])], axis=2) img = Image.fromarray(img) return img def translate_y(img, magnitude): img = np.array(img) - magnitudes = np.linspace(-150/331, 150/331, 11) + magnitudes = np.linspace(-150 / 331, 150 / 331, 11) - transform_matrix = np.array([[1, 0, img.shape[0]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])], - [0, 1, 0], - [0, 0, 1]]) + transform_matrix = np.array( + [[1, 0, img.shape[0] * random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])], + [0, 1, 0], + [0, 0, 1]]) transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) affine_matrix = transform_matrix[:2, :2] offset = transform_matrix[:2, 2] img = np.stack([ndimage.interpolation.affine_transform( - img[:, :, c], - affine_matrix, - offset) for c in range(img.shape[2])], axis=2) + img[:, :, c], + affine_matrix, + offset) for c in range(img.shape[2])], axis=2) img = Image.fromarray(img) return img @@ -186,7 +187,7 @@ def translate_y(img, magnitude): def rotate(img, magnitude): img = np.array(img) magnitudes = np.linspace(-30, 30, 11) - theta = np.deg2rad(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) + theta = np.deg2rad(random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])) transform_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) @@ -194,9 +195,9 @@ def rotate(img, magnitude): affine_matrix = transform_matrix[:2, :2] offset = transform_matrix[:2, 2] img = np.stack([ndimage.interpolation.affine_transform( - img[:, :, c], - affine_matrix, - offset) for c in range(img.shape[2])], axis=2) + img[:, :, c], + affine_matrix, + offset) for c in range(img.shape[2])], axis=2) img = Image.fromarray(img) return img @@ -218,65 +219,57 @@ def equalize(img, magnitude): def solarize(img, magnitude): magnitudes = np.linspace(0, 256, 11) - img = ImageOps.solarize(img, random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) + img = ImageOps.solarize(img, random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])) return img def posterize(img, magnitude): magnitudes = np.linspace(4, 8, 11) - img = ImageOps.posterize(img, int(round(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])))) + img = ImageOps.posterize(img, int(round(random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])))) return img def contrast(img, magnitude): magnitudes = np.linspace(0.1, 1.9, 11) - img = ImageEnhance.Contrast(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) + img = ImageEnhance.Contrast(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])) return img def color(img, magnitude): magnitudes = np.linspace(0.1, 1.9, 11) - img = ImageEnhance.Color(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) + img = ImageEnhance.Color(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])) return img def brightness(img, magnitude): magnitudes = np.linspace(0.1, 1.9, 11) - img = ImageEnhance.Brightness(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) + img = ImageEnhance.Brightness(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])) return img def sharpness(img, magnitude): magnitudes = np.linspace(0.1, 1.9, 11) - img = ImageEnhance.Sharpness(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) + img = ImageEnhance.Sharpness(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1])) return img def cutout(org_img, magnitude=None): - - magnitudes = np.linspace(0, 60/331, 11) - + magnitudes = np.linspace(0, 60 / 331, 11) img = np.copy(org_img) mask_val = img.mean() - if magnitude is None: mask_size = 16 else: - mask_size = int(round(img.shape[0]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]))) - top = np.random.randint(0 - mask_size//2, img.shape[0] - mask_size) - left = np.random.randint(0 - mask_size//2, img.shape[1] - mask_size) + mask_size = int(round(img.shape[0] * random.uniform(magnitudes[magnitude], magnitudes[magnitude + 1]))) + + top = np.random.randint(0 - mask_size // 2, img.shape[0] - mask_size) + left = np.random.randint(0 - mask_size // 2, img.shape[1] - mask_size) bottom = top + mask_size right = left + mask_size - - if top < 0: - top = 0 - if left < 0: - left = 0 - + top = max(top, 0) + left = max(left, 0) img[top:bottom, left:right, :].fill(mask_val) - img = Image.fromarray(img) - return img @@ -290,16 +283,16 @@ def __call__(self, img): mask_val = img.mean() - top = np.random.randint(0 - self.length//2, img.shape[0] - self.length) - left = np.random.randint(0 - self.length//2, img.shape[1] - self.length) + top = np.random.randint(0 - self.length // 2, img.shape[0] - self.length) + left = np.random.randint(0 - self.length // 2, img.shape[1] - self.length) bottom = top + self.length right = left + self.length - top = 0 if top < 0 else top + top = max(top, 0) left = 0 if left < 0 else top img[top:bottom, left:right, :] = mask_val img = Image.fromarray(img) - return img \ No newline at end of file + return img diff --git a/data/util/mask.py b/data/util/mask.py index 04ca008..b2fe6e5 100755 --- a/data/util/mask.py +++ b/data/util/mask.py @@ -6,33 +6,28 @@ from PIL import Image, ImageDraw -def random_cropping_bbox(img_shape=(256,256), mask_mode='onedirection'): +def random_cropping_bbox(img_shape=(256, 256), mask_mode='onedirection'): h, w = img_shape if mask_mode == 'onedirection': _type = np.random.randint(0, 4) if _type == 0: - top, left, height, width = 0, 0, h, w//2 + top, left, height, width = 0, 0, h, w // 2 elif _type == 1: - top, left, height, width = 0, 0, h//2, w + top, left, height, width = 0, 0, h // 2, w elif _type == 2: - top, left, height, width = h//2, 0, h//2, w + top, left, height, width = h // 2, 0, h // 2, w elif _type == 3: - top, left, height, width = 0, w//2, h, w//2 + top, left, height, width = 0, w // 2, h, w // 2 else: - target_area = (h*w)//2 - width = np.random.randint(target_area//h, w) - height = target_area//width - if h==height: - top = 0 - else: - top = np.random.randint(0, h-height) - if w==width: - left = 0 - else: - left = np.random.randint(0, w-width) + target_area = (h * w) // 2 + width = np.random.randint(target_area // h, w) + height = target_area // width + top = 0 if h == height else np.random.randint(0, h - height) + left = 0 if w == width else np.random.randint(0, w - width) return (top, left, height, width) -def random_bbox(img_shape=(256,256), max_bbox_shape=(128, 128), max_bbox_delta=40, min_margin=20): + +def random_bbox(img_shape=(256, 256), max_bbox_shape=(128, 128), max_bbox_delta=40, min_margin=20): """Generate a random bbox for the mask on a given image. In our implementation, the max value cannot be obtained since we use @@ -63,7 +58,7 @@ def random_bbox(img_shape=(256,256), max_bbox_shape=(128, 128), max_bbox_delta=4 max_bbox_delta = (max_bbox_delta, max_bbox_delta) if not isinstance(min_margin, tuple): min_margin = (min_margin, min_margin) - + img_h, img_w = img_shape[:2] max_mask_h, max_mask_w = max_bbox_shape max_delta_h, max_delta_w = max_bbox_delta @@ -130,6 +125,7 @@ def brush_stroke_mask(img_shape, brush_width=(12, 40), max_loops=4, dtype='uint8'): + # sourcery skip: merge-list-append, move-assign-in-block """Generate free-form mask. The method of generating free-form mask is in the following paper: diff --git a/eval.py b/eval.py index a075799..4c0e25b 100755 --- a/eval.py +++ b/eval.py @@ -7,12 +7,12 @@ parser = argparse.ArgumentParser() parser.add_argument('-s', '--src', type=str, help='Ground truth images directory') parser.add_argument('-d', '--dst', type=str, help='Generate images directory') - + ''' parser configs ''' args = parser.parse_args() fid_score = fid.compute_fid(args.src, args.dst) is_mean, is_std = inception_score(BaseDataset(args.dst), cuda=True, batch_size=8, resize=True, splits=10) - - print('FID: {}'.format(fid_score)) - print('IS:{} {}'.format(is_mean, is_std)) \ No newline at end of file + + print(f'FID: {fid_score}') + print(f'IS:{is_mean} {is_std}') diff --git a/models/guided_diffusion_modules/nn.py b/models/guided_diffusion_modules/nn.py index 60b83c9..b2447bc 100755 --- a/models/guided_diffusion_modules/nn.py +++ b/models/guided_diffusion_modules/nn.py @@ -4,7 +4,7 @@ import math import numpy as np -import torch +import torch import torch.nn as nn @@ -48,7 +48,6 @@ def normalization(channels): return GroupNorm32(32, channels) - def checkpoint(func, inputs, params, flag): """ Evaluate a function without caching intermediate activations, allowing for @@ -135,4 +134,4 @@ def gamma_embedding(gammas, dim, max_period=10000): embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding \ No newline at end of file + return embedding diff --git a/models/guided_diffusion_modules/unet.py b/models/guided_diffusion_modules/unet.py index e4e70ea..be98112 100755 --- a/models/guided_diffusion_modules/unet.py +++ b/models/guided_diffusion_modules/unet.py @@ -13,10 +13,12 @@ gamma_embedding ) + class SiLU(nn.Module): def forward(self, x): return x * torch.sigmoid(x) + class EmbedBlock(nn.Module): """ Any module where forward() takes embeddings as a second argument. @@ -28,6 +30,7 @@ def forward(self, x, emb): Apply the module to `x` given `emb` embeddings. """ + class EmbedSequential(nn.Sequential, EmbedBlock): """ A sequential module that passes embeddings to the children that @@ -36,12 +39,10 @@ class EmbedSequential(nn.Sequential, EmbedBlock): def forward(self, x, emb): for layer in self: - if isinstance(layer, EmbedBlock): - x = layer(x, emb) - else: - x = layer(x) + x = layer(x, emb) if isinstance(layer, EmbedBlock) else layer(x) return x + class Upsample(nn.Module): """ An upsampling layer with an optional convolution. @@ -65,6 +66,7 @@ def forward(self, x): x = self.conv(x) return x + class Downsample(nn.Module): """ A downsampling layer with an optional convolution. @@ -107,16 +109,16 @@ class ResBlock(EmbedBlock): """ def __init__( - self, - channels, - emb_channels, - dropout, - out_channel=None, - use_conv=False, - use_scale_shift_norm=False, - use_checkpoint=False, - up=False, - down=False, + self, + channels, + emb_channels, + dropout, + out_channel=None, + use_conv=False, + use_scale_shift_norm=False, + use_checkpoint=False, + up=False, + down=False, ): super().__init__() self.channels = channels @@ -202,6 +204,7 @@ def _forward(self, x, emb): h = self.out_layers(h) return self.skip_connection(x) + h + class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. @@ -210,12 +213,12 @@ class AttentionBlock(nn.Module): """ def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, ): super().__init__() self.channels = channels @@ -223,7 +226,7 @@ def __init__( self.num_heads = num_heads else: assert ( - channels % num_head_channels == 0 + channels % num_head_channels == 0 ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" self.num_heads = channels // num_head_channels self.use_checkpoint = use_checkpoint @@ -315,6 +318,7 @@ def forward(self, qkv): def count_flops(model, _x, y): return count_flops_attn(model, _x, y) + class UNet(nn.Module): """ The full UNet model with attention and embedding. @@ -343,24 +347,24 @@ class UNet(nn.Module): """ def __init__( - self, - image_size, - in_channel, - inner_channel, - out_channel, - res_blocks, - attn_res, - dropout=0, - channel_mults=(1, 2, 4, 8), - conv_resample=True, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=True, - resblock_updown=True, - use_new_attention_order=False, + self, + image_size, + in_channel, + inner_channel, + out_channel, + res_blocks, + attn_res, + dropout=0, + channel_mults=(1, 2, 4, 8), + conv_resample=True, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=True, + resblock_updown=True, + use_new_attention_order=False, ): super().__init__() @@ -544,6 +548,7 @@ def forward(self, x, gammas): h = h.type(x.dtype) return self.out(h) + if __name__ == '__main__': b, c, h, w = 3, 6, 64, 64 timsteps = 100 @@ -556,5 +561,5 @@ def forward(self, x, gammas): attn_res=[8] ) x = torch.randn((b, c, h, w)) - emb = torch.ones((b, )) - out = model(x, emb) \ No newline at end of file + emb = torch.ones((b,)) + out = model(x, emb) diff --git a/models/loss.py b/models/loss.py index e793672..ed86230 100755 --- a/models/loss.py +++ b/models/loss.py @@ -43,6 +43,5 @@ def forward(self, input, target): logpt = logpt * Variable(at) loss = -1 * (1-pt)**self.gamma * logpt - if self.size_average: return loss.mean() - else: return loss.sum() + return loss.mean() if self.size_average else loss.sum() diff --git a/models/metric.py b/models/metric.py index a048db1..93a9b28 100755 --- a/models/metric.py +++ b/models/metric.py @@ -9,6 +9,7 @@ import numpy as np from scipy.stats import entropy + def mae(input, target): with torch.no_grad(): loss = nn.L1Loss() @@ -44,6 +45,7 @@ def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1): inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype) inception_model.eval() up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype) + def get_pred(x): if resize: x = up(x) @@ -58,13 +60,13 @@ def get_pred(x): batchv = Variable(batch) batch_size_i = batch.size()[0] - preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv) + preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv) # Now compute the mean kl-div split_scores = [] for k in range(splits): - part = preds[k * (N // splits): (k+1) * (N // splits), :] + part = preds[k * (N // splits): (k + 1) * (N // splits), :] py = np.mean(part, axis=0) scores = [] for i in range(part.shape[0]): @@ -72,4 +74,4 @@ def get_pred(x): scores.append(entropy(pyx, py)) split_scores.append(np.exp(np.mean(scores))) - return np.mean(split_scores), np.std(split_scores) \ No newline at end of file + return np.mean(split_scores), np.std(split_scores) diff --git a/models/model.py b/models/model.py index 9423240..a7496cc 100755 --- a/models/model.py +++ b/models/model.py @@ -3,18 +3,21 @@ from core.base_model import BaseModel from core.logger import LogTracker import copy -class EMA(): + + +class EMA: def __init__(self, beta=0.9999): super().__init__() self.beta = beta + def update_model_average(self, ma_model, current_model): for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): old_weight, up_weight = ma_params.data, current_params.data ma_params.data = self.update_average(old_weight, up_weight) + def update_average(self, old, new): - if old is None: - return new - return old * self.beta + (1 - self.beta) * new + return new if old is None else old * self.beta + (1 - self.beta) * new + class Palette(BaseModel): def __init__(self, networks, losses, sample_num, task, optimizers, ema_scheduler=None, **kwargs): @@ -30,7 +33,7 @@ def __init__(self, networks, losses, sample_num, task, optimizers, ema_scheduler self.EMA = EMA(beta=self.ema_scheduler['ema_decay']) else: self.ema_scheduler = None - + ''' networks can be a list, and must convert by self.set_device function if using multiple GPU. ''' self.netG = self.set_device(self.netG, distributed=self.opt['distributed']) if self.ema_scheduler is not None: @@ -39,7 +42,7 @@ def __init__(self, networks, losses, sample_num, task, optimizers, ema_scheduler self.optG = torch.optim.Adam(list(filter(lambda p: p.requires_grad, self.netG.parameters())), **optimizers[0]) self.optimizers.append(self.optG) - self.resume_training() + self.resume_training() if self.opt['distributed']: self.netG.module.set_loss(self.loss_fn) @@ -55,7 +58,7 @@ def __init__(self, networks, losses, sample_num, task, optimizers, ema_scheduler self.sample_num = sample_num self.task = task - + def set_input(self, data): ''' must use set_device in tensor ''' self.cond_image = self.set_device(data.get('cond_image')) @@ -64,41 +67,38 @@ def set_input(self, data): self.mask_image = data.get('mask_image') self.path = data['path'] self.batch_size = len(data['path']) - + def get_current_visuals(self, phase='train'): dict = { - 'gt_image': (self.gt_image.detach()[:].float().cpu()+1)/2, - 'cond_image': (self.cond_image.detach()[:].float().cpu()+1)/2, + 'gt_image': (self.gt_image.detach()[:].float().cpu() + 1) / 2, + 'cond_image': (self.cond_image.detach()[:].float().cpu() + 1) / 2 } - if self.task in ['inpainting','uncropping']: - dict.update({ + + if self.task in ['inpainting', 'uncropping']: + dict |= { 'mask': self.mask.detach()[:].float().cpu(), - 'mask_image': (self.mask_image+1)/2, - }) + 'mask_image': (self.mask_image + 1) / 2 + } + if phase != 'train': - dict.update({ - 'output': (self.output.detach()[:].float().cpu()+1)/2 - }) + dict['output'] = (self.output.detach()[:].float().cpu() + 1) / 2 return dict def save_current_results(self): ret_path = [] ret_result = [] for idx in range(self.batch_size): - ret_path.append('GT_{}'.format(self.path[idx])) + ret_path.append(f'GT_{self.path[idx]}') ret_result.append(self.gt_image[idx].detach().float().cpu()) - - ret_path.append('Process_{}'.format(self.path[idx])) + ret_path.append(f'Process_{self.path[idx]}') ret_result.append(self.visuals[idx::self.batch_size].detach().float().cpu()) - - ret_path.append('Out_{}'.format(self.path[idx])) - ret_result.append(self.visuals[idx-self.batch_size].detach().float().cpu()) - - if self.task in ['inpainting','uncropping']: - ret_path.extend(['Mask_{}'.format(name) for name in self.path]) + ret_path.append(f'Out_{self.path[idx]}') + ret_result.append(self.visuals[idx - self.batch_size].detach().float().cpu()) + if self.task in ['inpainting', 'uncropping']: + ret_path.extend([f'Mask_{name}' for name in self.path]) ret_result.extend(self.mask_image) - self.results_dict = self.results_dict._replace(name=ret_path, result=ret_result) + return self.results_dict._asdict() def train_step(self): @@ -110,7 +110,6 @@ def train_step(self): loss = self.netG(self.gt_image, self.cond_image, mask=self.mask) loss.backward() self.optG.step() - self.iter += self.batch_size self.writer.set_iter(self.epoch, self.iter, phase='train') self.train_metrics.update(self.loss_fn.__name__, loss.item()) @@ -120,14 +119,13 @@ def train_step(self): self.writer.add_scalar(key, value) for key, value in self.get_current_visuals().items(): self.writer.add_images(key, value) - if self.ema_scheduler is not None: - if self.iter > self.ema_scheduler['ema_start'] and self.iter % self.ema_scheduler['ema_iter'] == 0: - self.EMA.update_model_average(self.netG_EMA, self.netG) - + if self.ema_scheduler is not None and self.iter > self.ema_scheduler['ema_start'] and self.iter % \ + self.ema_scheduler['ema_iter'] == 0: + self.EMA.update_model_average(self.netG_EMA, self.netG) for scheduler in self.schedulers: scheduler.step() return self.train_metrics.result() - + def val_step(self): self.netG.eval() self.val_metrics.reset() @@ -135,21 +133,22 @@ def val_step(self): for val_data in tqdm.tqdm(self.val_loader): self.set_input(val_data) if self.opt['distributed']: - if self.task in ['inpainting','uncropping']: - self.output, self.visuals = self.netG.module.restoration(self.cond_image, y_t=self.cond_image, - y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num) - else: - self.output, self.visuals = self.netG.module.restoration(self.cond_image, sample_num=self.sample_num) + self.output, self.visuals = self.netG.module.restoration(self.cond_image, y_t=self.cond_image, + y_0=self.gt_image, mask=self.mask, + sample_num=self.sample_num) if self.task in [ + 'inpainting', 'uncropping'] else self.netG.module.restoration(self.cond_image, + sample_num=self.sample_num) + + elif self.task in ['inpainting', 'uncropping']: + self.output, self.visuals = self.netG.restoration(self.cond_image, y_t=self.cond_image, + y_0=self.gt_image, mask=self.mask, + sample_num=self.sample_num) + else: - if self.task in ['inpainting','uncropping']: - self.output, self.visuals = self.netG.restoration(self.cond_image, y_t=self.cond_image, - y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num) - else: - self.output, self.visuals = self.netG.restoration(self.cond_image, sample_num=self.sample_num) - + self.output, self.visuals = self.netG.restoration(self.cond_image, sample_num=self.sample_num) + self.iter += self.batch_size self.writer.set_iter(self.epoch, self.iter, phase='val') - for met in self.metrics: key = met.__name__ value = met(self.gt_image, self.output) @@ -158,7 +157,6 @@ def val_step(self): for key, value in self.get_current_visuals(phase='val').items(): self.writer.add_images(key, value) self.writer.save_images(self.save_current_results()) - return self.val_metrics.result() def test(self): @@ -168,18 +166,20 @@ def test(self): for phase_data in tqdm.tqdm(self.phase_loader): self.set_input(phase_data) if self.opt['distributed']: - if self.task in ['inpainting','uncropping']: - self.output, self.visuals = self.netG.module.restoration(self.cond_image, y_t=self.cond_image, - y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num) - else: - self.output, self.visuals = self.netG.module.restoration(self.cond_image, sample_num=self.sample_num) + self.output, self.visuals = self.netG.module.restoration(self.cond_image, y_t=self.cond_image, + y_0=self.gt_image, mask=self.mask, + sample_num=self.sample_num) if self.task in [ + 'inpainting', 'uncropping'] else self.netG.module.restoration(self.cond_image, + sample_num=self.sample_num) + + elif self.task in ['inpainting', 'uncropping']: + self.output, self.visuals = self.netG.restoration(self.cond_image, y_t=self.cond_image, + y_0=self.gt_image, mask=self.mask, + sample_num=self.sample_num) + else: - if self.task in ['inpainting','uncropping']: - self.output, self.visuals = self.netG.restoration(self.cond_image, y_t=self.cond_image, - y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num) - else: - self.output, self.visuals = self.netG.restoration(self.cond_image, sample_num=self.sample_num) - + self.output, self.visuals = self.netG.restoration(self.cond_image, sample_num=self.sample_num) + self.iter += self.batch_size self.writer.set_iter(self.epoch, self.iter, phase='test') for met in self.metrics: @@ -190,12 +190,10 @@ def test(self): for key, value in self.get_current_visuals(phase='test').items(): self.writer.add_images(key, value) self.writer.save_images(self.save_current_results()) - test_log = self.test_metrics.result() - ''' save logged informations into log dict ''' + ''' save logged informations into log dict ''' test_log.update({'epoch': self.epoch, 'iters': self.iter}) - - ''' print logged informations to the screen and tensorboard ''' + ''' print logged informations to the screen and tensorboard ''' for key, value in test_log.items(): self.logger.info('{:5s}: {}\t'.format(str(key), value)) @@ -207,7 +205,7 @@ def load_networks(self): netG_label = self.netG.__class__.__name__ self.load_network(network=self.netG, network_label=netG_label, strict=False) if self.ema_scheduler is not None: - self.load_network(network=self.netG_EMA, network_label=netG_label+'_ema', strict=False) + self.load_network(network=self.netG_EMA, network_label=f'{netG_label}_ema', strict=False) def save_everything(self): """ load pretrained model and training state. """ @@ -217,5 +215,5 @@ def save_everything(self): netG_label = self.netG.__class__.__name__ self.save_network(network=self.netG, network_label=netG_label) if self.ema_scheduler is not None: - self.save_network(network=self.netG_EMA, network_label=netG_label+'_ema') + self.save_network(network=self.netG_EMA, network_label=f'{netG_label}_ema') self.save_training_state() diff --git a/models/network.py b/models/network.py index 24c2bd3..c01a107 100755 --- a/models/network.py +++ b/models/network.py @@ -5,6 +5,8 @@ import numpy as np from tqdm import tqdm from core.base_network import BaseNetwork + + class Network(BaseNetwork): def __init__(self, unet, beta_schedule, module_name='sr3', **kwargs): super(Network, self).__init__(**kwargs) @@ -12,7 +14,7 @@ def __init__(self, unet, beta_schedule, module_name='sr3', **kwargs): from .sr3_modules.unet import UNet elif module_name == 'guided_diffusion': from .guided_diffusion_modules.unet import UNet - + self.denoise_fn = UNet(**unet) self.beta_schedule = beta_schedule @@ -28,7 +30,7 @@ def set_new_noise_schedule(self, device=torch.device('cuda'), phase='train'): timesteps, = betas.shape self.num_timesteps = int(timesteps) - + gammas = np.cumprod(alphas, axis=0) gammas_prev = np.append(1., gammas[:-1]) @@ -46,14 +48,14 @@ def set_new_noise_schedule(self, device=torch.device('cuda'), phase='train'): def predict_start_from_noise(self, y_t, t, noise): return ( - extract(self.sqrt_recip_gammas, t, y_t.shape) * y_t - - extract(self.sqrt_recipm1_gammas, t, y_t.shape) * noise + extract(self.sqrt_recip_gammas, t, y_t.shape) * y_t - + extract(self.sqrt_recipm1_gammas, t, y_t.shape) * noise ) def q_posterior(self, y_0_hat, y_t, t): posterior_mean = ( - extract(self.posterior_mean_coef1, t, y_t.shape) * y_0_hat + - extract(self.posterior_mean_coef2, t, y_t.shape) * y_t + extract(self.posterior_mean_coef1, t, y_t.shape) * y_0_hat + + extract(self.posterior_mean_coef2, t, y_t.shape) * y_t ) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, y_t.shape) return posterior_mean, posterior_log_variance_clipped @@ -61,7 +63,7 @@ def q_posterior(self, y_0_hat, y_t, t): def p_mean_variance(self, y_t, t, clip_denoised: bool, y_cond=None): noise_level = extract(self.gammas, t, x_shape=(1, 1)).to(y_t.device) y_0_hat = self.predict_start_from_noise( - y_t, t=t, noise=self.denoise_fn(torch.cat([y_cond, y_t], dim=1), noise_level)) + y_t, t=t, noise=self.denoise_fn(torch.cat([y_cond, y_t], dim=1), noise_level)) if clip_denoised: y_0_hat.clamp_(-1., 1.) @@ -73,31 +75,30 @@ def p_mean_variance(self, y_t, t, clip_denoised: bool, y_cond=None): def q_sample(self, y_0, sample_gammas, noise=None): noise = default(noise, lambda: torch.randn_like(y_0)) return ( - sample_gammas.sqrt() * y_0 + - (1 - sample_gammas).sqrt() * noise + sample_gammas.sqrt() * y_0 + + (1 - sample_gammas).sqrt() * noise ) @torch.no_grad() def p_sample(self, y_t, t, clip_denoised=True, y_cond=None): model_mean, model_log_variance = self.p_mean_variance( y_t=y_t, t=t, clip_denoised=clip_denoised, y_cond=y_cond) - noise = torch.randn_like(y_t) if any(t>0) else torch.zeros_like(y_t) + noise = torch.randn_like(y_t) if any(t > 0) else torch.zeros_like(y_t) return model_mean + noise * (0.5 * model_log_variance).exp() @torch.no_grad() def restoration(self, y_cond, y_t=None, y_0=None, mask=None, sample_num=8): b, *_ = y_cond.shape - assert self.num_timesteps > sample_num, 'num_timesteps must greater than sample_num' - sample_inter = (self.num_timesteps//sample_num) - + + sample_inter = self.num_timesteps // sample_num y_t = default(y_t, lambda: torch.randn_like(y_cond)) ret_arr = y_t - for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): + for i in tqdm(reversed(range(self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): t = torch.full((b,), i, device=y_cond.device, dtype=torch.long) y_t = self.p_sample(y_t, t, y_cond=y_cond) if mask is not None: - y_t = y_0*(1.-mask) + mask*y_t + y_t = y_0 * (1.0 - mask) + mask * y_t if i % sample_inter == 0: ret_arr = torch.cat([ret_arr, y_t], dim=0) return y_t, ret_arr @@ -106,9 +107,9 @@ def forward(self, y_0, y_cond=None, mask=None, noise=None): # sampling from p(gammas) b, *_ = y_0.shape t = torch.randint(1, self.num_timesteps, (b,), device=y_0.device).long() - gamma_t1 = extract(self.gammas, t-1, x_shape=(1, 1)) + gamma_t1 = extract(self.gammas, t - 1, x_shape=(1, 1)) sqrt_gamma_t2 = extract(self.gammas, t, x_shape=(1, 1)) - sample_gammas = (sqrt_gamma_t2-gamma_t1) * torch.rand((b, 1), device=y_0.device) + gamma_t1 + sample_gammas = (sqrt_gamma_t2 - gamma_t1) * torch.rand((b, 1), device=y_0.device) + gamma_t1 sample_gammas = sample_gammas.view(b, -1) noise = default(noise, lambda: torch.randn_like(y_0)) @@ -116,28 +117,30 @@ def forward(self, y_0, y_cond=None, mask=None, noise=None): y_0=y_0, sample_gammas=sample_gammas.view(-1, 1, 1, 1), noise=noise) if mask is not None: - noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy*mask+(1.-mask)*y_0], dim=1), sample_gammas) - loss = self.loss_fn(mask*noise, mask*noise_hat) + noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy * mask + (1. - mask) * y_0], dim=1), sample_gammas) + return self.loss_fn(mask * noise, mask * noise_hat) else: noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy], dim=1), sample_gammas) - loss = self.loss_fn(noise, noise_hat) - return loss + return self.loss_fn(noise, noise_hat) # gaussian diffusion trainer class def exists(x): return x is not None + def default(val, d): if exists(val): return val return d() if isfunction(d) else d -def extract(a, t, x_shape=(1,1,1,1)): + +def extract(a, t, x_shape=(1, 1, 1, 1)): b, *_ = t.shape out = a.gather(-1, t) return out.reshape(b, *((1,) * (len(x_shape) - 1))) + # beta_schedule function def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac): betas = linear_end * np.ones(n_timestep, dtype=np.float64) @@ -146,6 +149,7 @@ def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac): linear_start, linear_end, warmup_time, dtype=np.float64) return betas + def make_beta_schedule(schedule, n_timestep, linear_start=1e-6, linear_end=1e-2, cosine_s=8e-3): if schedule == 'quad': betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, @@ -166,8 +170,8 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-6, linear_end=1e-2, 1, n_timestep, dtype=np.float64) elif schedule == "cosine": timesteps = ( - torch.arange(n_timestep + 1, dtype=torch.float64) / - n_timestep + cosine_s + torch.arange(n_timestep + 1, dtype=torch.float64) / + n_timestep + cosine_s ) alphas = timesteps / (1 + cosine_s) * math.pi / 2 alphas = torch.cos(alphas).pow(2) @@ -177,5 +181,3 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-6, linear_end=1e-2, else: raise NotImplementedError(schedule) return betas - - diff --git a/models/sr3_modules/unet.py b/models/sr3_modules/unet.py index 28844a0..9f7f216 100755 --- a/models/sr3_modules/unet.py +++ b/models/sr3_modules/unet.py @@ -3,19 +3,20 @@ from torch import nn from inspect import isfunction + class UNet(nn.Module): def __init__( - self, - in_channel=6, - out_channel=3, - inner_channel=32, - norm_groups=32, - channel_mults=(1, 2, 4, 8, 8), - attn_res=(8), - res_blocks=3, - dropout=0, - with_noise_level_emb=True, - image_size=128 + self, + in_channel=6, + out_channel=3, + inner_channel=32, + norm_groups=32, + channel_mults=(1, 2, 4, 8, 8), + attn_res=(8), + res_blocks=3, + dropout=0, + with_noise_level_emb=True, + image_size=128 ): super().__init__() @@ -41,21 +42,24 @@ def __init__( is_last = (ind == num_mults - 1) use_attn = (now_res in attn_res) channel_mult = inner_channel * channel_mults[ind] - for _ in range(0, res_blocks): + for _ in range(res_blocks): downs.append(ResnetBlocWithAttn( - pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn)) + pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, + dropout=dropout, with_attn=use_attn)) feat_channels.append(channel_mult) pre_channel = channel_mult if not is_last: downs.append(Downsample(pre_channel)) feat_channels.append(pre_channel) - now_res = now_res//2 + now_res = now_res // 2 self.downs = nn.ModuleList(downs) self.mid = nn.ModuleList([ - ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, + ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, + norm_groups=norm_groups, dropout=dropout, with_attn=True), - ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, + ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, + norm_groups=norm_groups, dropout=dropout, with_attn=False) ]) @@ -64,43 +68,33 @@ def __init__( is_last = (ind < 1) use_attn = (now_res in attn_res) channel_mult = inner_channel * channel_mults[ind] - for _ in range(0, res_blocks+1): + for _ in range(res_blocks + 1): ups.append(ResnetBlocWithAttn( - pre_channel+feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, - dropout=dropout, with_attn=use_attn)) + pre_channel + feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, + norm_groups=norm_groups, + dropout=dropout, with_attn=use_attn)) pre_channel = channel_mult if not is_last: ups.append(Upsample(pre_channel)) - now_res = now_res*2 + now_res = now_res * 2 self.ups = nn.ModuleList(ups) self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups) def forward(self, x, time): - t = self.noise_level_mlp(time) if exists( - self.noise_level_mlp) else None - + t = self.noise_level_mlp(time) if exists(self.noise_level_mlp) else None feats = [] for layer in self.downs: - if isinstance(layer, ResnetBlocWithAttn): - x = layer(x, t) - else: - x = layer(x) + x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x) feats.append(x) - for layer in self.mid: - if isinstance(layer, ResnetBlocWithAttn): - x = layer(x, t) - else: - x = layer(x) - + x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x) for layer in self.ups: if isinstance(layer, ResnetBlocWithAttn): x = layer(torch.cat((x, feats.pop()), dim=1), t) else: x = layer(x) - return self.final_conv(x) @@ -123,7 +117,7 @@ def __init__(self, in_channels, out_channels, use_affine_level=False): super(FeatureWiseAffine, self).__init__() self.use_affine_level = use_affine_level self.noise_func = nn.Sequential( - nn.Linear(in_channels, out_channels*(1+self.use_affine_level)) + nn.Linear(in_channels, out_channels * (1 + self.use_affine_level)) ) def forward(self, x, noise_embed): @@ -236,7 +230,7 @@ def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dr def forward(self, x, time_emb): x = self.res_block(x, time_emb) - if(self.with_attn): + if (self.with_attn): x = self.attn(x) return x diff --git a/preprocess/mirflickr25k_preprocess.py b/preprocess/mirflickr25k_preprocess.py index 2433fba..511f909 100755 --- a/preprocess/mirflickr25k_preprocess.py +++ b/preprocess/mirflickr25k_preprocess.py @@ -3,15 +3,16 @@ from sklearn.model_selection import train_test_split import cv2 + def convert_abl(ab, l): """ convert AB and L to RGB """ l = np.expand_dims(l, axis=3) lab = np.concatenate([l, ab], axis=3) - if len(lab.shape)==4: - image_color, image_l = [], [] + if len(lab.shape) == 4: + image_color, image_l = [], [] for _color, _l in zip(lab, l): - out = cv2.cvtColor(_color.astype('uint8'), cv2.COLOR_LAB2RGB) - out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR) + out = cv2.cvtColor(_color.astype('uint8'), cv2.COLOR_LAB2RGB) + out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR) image_color.append(out) image_l.append(cv2.cvtColor(_l.astype('uint8'), cv2.COLOR_GRAY2RGB)) image_color = np.array(image_color) @@ -21,35 +22,38 @@ def convert_abl(ab, l): image_l = cv2.cvtColor(l.astype('uint8'), cv2.COLOR_GRAY2RGB) return image_color, image_l + def load_data(home): - ab1 = np.load(os.path.join(home,"ab/ab", "ab1.npy")) + ab1 = np.load(os.path.join(home, "ab/ab", "ab1.npy")) ab2 = np.load(os.path.join(home, "ab/ab", "ab2.npy")) - ab3 = np.load(os.path.join(home,"ab/ab", "ab3.npy")) + ab3 = np.load(os.path.join(home, "ab/ab", "ab3.npy")) ab = np.concatenate([ab1, ab2, ab3], axis=0) - l = np.load(os.path.join(home,"l/gray_scale.npy")) + l = np.load(os.path.join(home, "l/gray_scale.npy")) return ab, l + if __name__ == '__main__': - home = './' # path saved .npy + home = './' # path saved .npy flist_save_path = './flist' - image_save_path = './images' # images save path + image_save_path = './images' # images save path all_color, all_l = load_data(home) image_color, image_l = convert_abl(all_color, all_l) - - color_save_path, gray_save_path = '{}/color'.format(image_save_path), '{}/gray'.format(image_save_path) + + color_save_path, gray_save_path = f'{image_save_path}/color', f'{image_save_path}/gray' + os.makedirs(color_save_path, exist_ok=True) os.makedirs(gray_save_path, exist_ok=True) for i in range(image_color.shape[0]): - cv2.imwrite('{}/{}.png'.format(color_save_path, str(i).zfill(5)), image_color[i]) + cv2.imwrite(f'{color_save_path}/{str(i).zfill(5)}.png', image_color[i]) for i in range(image_l.shape[0]): - cv2.imwrite('{}/{}.png'.format(gray_save_path, str(i).zfill(5)), image_l[i]) - + cv2.imwrite(f'{gray_save_path}/{str(i).zfill(5)}.png', image_l[i]) + os.makedirs(flist_save_path, exist_ok=True) arr = np.random.permutation(25000) - with open('{}/train.flist'.format(flist_save_path), 'w') as f: + with open(f'{flist_save_path}/train.flist', 'w') as f: for item in arr[:24000]: print(str(item).zfill(5), file=f) - with open('{}/test.flist'.format(flist_save_path), 'w') as f: + with open(f'{flist_save_path}/test.flist', 'w') as f: for item in arr[24000:]: - print(str(item).zfill(5), file=f) \ No newline at end of file + print(str(item).zfill(5), file=f) diff --git a/run.py b/run.py index 4ced338..38b80b1 100755 --- a/run.py +++ b/run.py @@ -10,19 +10,20 @@ from data import define_dataloader from models import create_model, define_network, define_loss, define_metric + def main_worker(gpu, ngpus_per_node, opt): """ threads running on each GPU """ if 'local_rank' not in opt: opt['local_rank'] = opt['global_rank'] = gpu if opt['distributed']: torch.cuda.set_device(int(opt['local_rank'])) - print('using GPU {} for training'.format(int(opt['local_rank']))) - torch.distributed.init_process_group(backend = 'nccl', - init_method = opt['init_method'], - world_size = opt['world_size'], - rank = opt['global_rank'], - group_name='mtorch' - ) + print(f"using GPU {int(opt['local_rank'])} for training") + torch.distributed.init_process_group(backend='nccl', + init_method=opt['init_method'], + world_size=opt['world_size'], + rank=opt['global_rank'], + group_name='mtorch' + ) '''set seed and and cuDNN environment ''' torch.backends.cudnn.enabled = True warnings.warn('You have chosen to use cudnn for accleration. torch.backends.cudnn.enabled=True') @@ -30,11 +31,11 @@ def main_worker(gpu, ngpus_per_node, opt): ''' set logger ''' phase_logger = InfoLogger(opt) - phase_writer = VisualWriter(opt, phase_logger) - phase_logger.info('Create the log file in directory {}.\n'.format(opt['path']['experiments_root'])) + phase_writer = VisualWriter(opt, phase_logger) + phase_logger.info(f"Create the log file in directory {opt['path']['experiments_root']}.\n") '''set networks and dataset''' - phase_loader, val_loader = define_dataloader(phase_logger, opt) # val_loader is None if phase is test. + phase_loader, val_loader = define_dataloader(phase_logger, opt) # val_loader is None if phase is test. networks = [define_network(phase_logger, opt, item_opt) for item_opt in opt['model']['which_networks']] ''' set metrics, loss, optimizer and schedulers ''' @@ -42,17 +43,17 @@ def main_worker(gpu, ngpus_per_node, opt): losses = [define_loss(phase_logger, item_opt) for item_opt in opt['model']['which_losses']] model = create_model( - opt = opt, - networks = networks, - phase_loader = phase_loader, - val_loader = val_loader, - losses = losses, - metrics = metrics, - logger = phase_logger, - writer = phase_writer + opt=opt, + networks=networks, + phase_loader=phase_loader, + val_loader=val_loader, + losses=losses, + metrics=metrics, + logger=phase_logger, + writer=phase_writer ) - phase_logger.info('Begin model {}.'.format(opt['phase'])) + phase_logger.info(f"Begin model {opt['phase']}.") try: if opt['phase'] == 'train': model.train() @@ -60,12 +61,13 @@ def main_worker(gpu, ngpus_per_node, opt): model.test() finally: phase_writer.close() - - + + if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-c', '--config', type=str, default='config/colorization_mirflickr25k.json', help='JSON file for configuration') - parser.add_argument('-p', '--phase', type=str, choices=['train','test'], help='Run train or test', default='train') + parser.add_argument('-c', '--config', type=str, default='config/colorization_mirflickr25k.json', + help='JSON file for configuration') + parser.add_argument('-p', '--phase', type=str, choices=['train', 'test'], help='Run train or test', default='train') parser.add_argument('-b', '--batch', type=int, default=None, help='Batch size in every gpu') parser.add_argument('-gpu', '--gpu_ids', type=str, default=None) parser.add_argument('-d', '--debug', action='store_true') @@ -74,19 +76,19 @@ def main_worker(gpu, ngpus_per_node, opt): ''' parser configs ''' args = parser.parse_args() opt = Praser.parse(args) - + ''' cuda devices ''' gpu_str = ','.join(str(x) for x in opt['gpu_ids']) os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str - print('export CUDA_VISIBLE_DEVICES={}'.format(gpu_str)) + print(f'export CUDA_VISIBLE_DEVICES={gpu_str}') ''' use DistributedDataParallel(DDP) and multiprocessing for multi-gpu training''' # [Todo]: multi GPU on multi machine if opt['distributed']: - ngpus_per_node = len(opt['gpu_ids']) # or torch.cuda.device_count() + ngpus_per_node = len(opt['gpu_ids']) # or torch.cuda.device_count() opt['world_size'] = ngpus_per_node - opt['init_method'] = 'tcp://127.0.0.1:'+ args.port + opt['init_method'] = f'tcp://127.0.0.1:{args.port}' mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt)) else: - opt['world_size'] = 1 - main_worker(0, 1, opt) \ No newline at end of file + opt['world_size'] = 1 + main_worker(0, 1, opt)