From a2ef278fa2d3a134ba2e3df4cc5baf13c3ce0b7d Mon Sep 17 00:00:00 2001
From: paisano
Date: Mon, 26 Feb 2024 07:09:46 -0800
Subject: [PATCH 1/2] Refactor to work with current PyTorch/Ignite

Replace the deprecated torch.set_default_tensor_type with
torch.set_default_device/torch.set_default_dtype, thread the device
through the example data loaders and pass device-aware generators to
DataLoader and random_split, follow the renamed Ignite base-logger
API (BaseWeightsHandler, _setup_output_metrics_state_attrs), and
unpin the dependencies.

---
 examples/data.py                        | 14 ++++++++------
 examples/hebbian.py                     | 14 +++++++++-----
 .../handlers/tensorboard_logger.py      | 13 ++++---------
 pytorch_hebbian/handlers/tqdm_logger.py |  5 +++--
 pytorch_hebbian/trainers.py             |  2 +-
 pytorch_hebbian/utils.py                | 12 ++++++++----
 setup.py                                |  8 ++++----
 7 files changed, 37 insertions(+), 31 deletions(-)

diff --git a/examples/data.py b/examples/data.py
index 68e6be8..a81b474 100644
--- a/examples/data.py
+++ b/examples/data.py
@@ -3,6 +3,7 @@
 import random
 
 import torch
+from torch import Generator
 from torch.utils.data import DataLoader, Subset
 from torchvision import datasets, transforms
 
@@ -11,9 +12,10 @@
 PATH = os.path.dirname(os.path.abspath(__file__))
 
 
-def get_data(params, dataset_name, subset=None):
+def get_data(params, dataset_name, device, subset=None):
     load_test = 'train_all' in params and params['train_all']
     test_dataset = None
+    device = utils.get_device(device)
 
     # Loading the dataset and creating the data loaders and transforms
     transform = transforms.Compose([
@@ -41,12 +43,12 @@ def get_data(params, dataset_name, subset=None):
         dataset = Subset(dataset, random.sample(range(len(dataset)), subset))
 
     if load_test:
-        train_loader = DataLoader(dataset, batch_size=params['train_batch_size'], shuffle=True)
-        val_loader = DataLoader(test_dataset, batch_size=params['val_batch_size'], shuffle=False)
+        train_loader = DataLoader(dataset, batch_size=params['train_batch_size'], shuffle=True, generator=Generator(device=device))
+        val_loader = DataLoader(test_dataset, batch_size=params['val_batch_size'], shuffle=False, generator=Generator(device=device))
     else:
-        train_dataset, val_dataset = utils.split_dataset(dataset, val_split=params['val_split'])
-        train_loader = DataLoader(train_dataset, batch_size=params['train_batch_size'], shuffle=True)
-        val_loader = DataLoader(val_dataset, batch_size=params['val_batch_size'], shuffle=False)
+        train_dataset, val_dataset = utils.split_dataset(dataset, device, val_split=params['val_split'])
+        train_loader = DataLoader(train_dataset, batch_size=params['train_batch_size'], shuffle=True, generator=Generator(device=device))
+        val_loader = DataLoader(val_dataset, batch_size=params['val_batch_size'], shuffle=False, generator=Generator(device=device))
 
     # Analyze dataset
     data_batch = next(iter(train_loader))[0]
diff --git a/examples/hebbian.py b/examples/hebbian.py
index b3c5add..0cb818c 100644
--- a/examples/hebbian.py
+++ b/examples/hebbian.py
@@ -25,9 +25,11 @@
 PATH = os.path.dirname(os.path.abspath(__file__))
 
 
-def attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader, params):
+def attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader, params, device):
     # Metrics
-    UnitConvergence(model[0], learning_rule.norm).attach(trainer.engine, 'unit_conv')
+    device = utils.get_device(device)
+
+    UnitConvergence(model[0], learning_rule.norm, device=device).attach(trainer.engine, 'unit_conv')
 
     # Tqdm logger
     pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT)
@@ -107,6 +109,8 @@ def attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, tr
 
 
 def main(args: Namespace, params: dict, dataset_name, run_postfix=""):
+    torch.set_default_device(args.device)
+    torch.set_default_dtype(torch.float64)
     # Creating an identifier for this run
     identifier = time.strftime("%Y%m%d-%H%M%S")
     run = '{}/heb/{}'.format(dataset_name, identifier)
@@ -128,7 +132,7 @@ def main(args: Namespace, params: dict, dataset_name, run_postfix=""):
     print("Device set to '{}'.".format(device))
 
     # Data loaders
-    train_loader, val_loader = data.get_data(params, dataset_name, subset=10000)
+    train_loader, val_loader = data.get_data(params, dataset_name, device, subset=10000)
 
     # Creating the learning rule, optimizer, evaluator and trainer
     learning_rule = KrotovsRule(delta=params['delta'], k=params['k'], norm=params['norm'], normalize=False)
@@ -143,7 +147,7 @@ def init_function(h_model):
         h_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(h_optimizer, 'max', verbose=True, patience=5,
                                                                     factor=0.5)
         h_trainer = SupervisedTrainer(model=h_model, optimizer=h_optimizer, criterion=h_criterion, device=device)
-
+
         # Tqdm logger
         h_pbar = ProgressBar(persist=False, bar_format=config.IGNITE_BAR_FORMAT)
         h_pbar.attach(h_trainer.engine, metric_names='all')
@@ -193,7 +197,7 @@ def init_function(h_model):
 
     # Handlers
     tb_logger = attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader,
-                                params)
+                                params, device)
 
     # Running the trainer
     trainer.run(train_loader=train_loader, epochs=params['epochs'])
diff --git a/pytorch_hebbian/handlers/tensorboard_logger.py b/pytorch_hebbian/handlers/tensorboard_logger.py
index 4360515..4d258bd 100644
--- a/pytorch_hebbian/handlers/tensorboard_logger.py
+++ b/pytorch_hebbian/handlers/tensorboard_logger.py
@@ -5,7 +5,7 @@
 import numpy as np
 import torch
 import torchvision
-from ignite.contrib.handlers.base_logger import BaseHandler, BaseWeightsScalarHandler, BaseWeightsHistHandler
+from ignite.contrib.handlers.base_logger import BaseHandler, BaseWeightsScalarHandler, BaseWeightsHandler
 from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger
 from matplotlib import pyplot as plt
 
@@ -52,7 +52,7 @@ def __call__(self, engine, logger, event_name):
         )
 
 
-class WeightsHistHandler(BaseWeightsHistHandler):
+class WeightsHistHandler(BaseWeightsHandler):
     """Helper handler to log model's weights as histograms.
 
     Args:
@@ -60,24 +60,19 @@ class WeightsHistHandler(BaseWeightsHistHandler):
         tag (str, optional): common title for all produced plots. For example, 'generator'
     """
 
-    def __init__(self, model, layer_names=None, tag=None):
+    def __init__(self, model, tag=None):
         super(WeightsHistHandler, self).__init__(model, tag=tag)
-        self.layer_names = layer_names
 
     def __call__(self, engine, logger, event_name):
         if not isinstance(logger, TensorboardLogger):
             raise RuntimeError("Handler 'WeightsHistHandler' works only with TensorboardLogger")
 
         global_step = engine.state.get_event_attrib_value(event_name)
         tag_prefix = "{}/".format(self.tag) if self.tag else ""
         for name, p in self.model.named_parameters():
-            if self.layer_names is not None:
-                if name.split('.')[0] not in self.layer_names:
-                    continue
-
             name = name.replace(".", "/")
             logger.writer.add_histogram(
                 tag="{}weights/{}".format(tag_prefix, name),
                 values=p.data.detach().cpu().numpy(),
                 global_step=global_step,
             )
@@ -183,7 +178,7 @@ def __call__(self, engine, logger, event_name):
         )
 
 
-class ActivationsHistHandler(BaseWeightsHistHandler):
+class ActivationsHistHandler(BaseWeightsHandler):
     """Helper handler to log model's activations as histograms.
 
     Args:
diff --git a/pytorch_hebbian/handlers/tqdm_logger.py b/pytorch_hebbian/handlers/tqdm_logger.py
index 49aa556..9daacd1 100644
--- a/pytorch_hebbian/handlers/tqdm_logger.py
+++ b/pytorch_hebbian/handlers/tqdm_logger.py
@@ -2,7 +2,8 @@
 import warnings
 
 import torch
-from ignite.contrib.handlers.base_logger import BaseOutputHandler, BaseLogger
+from ignite.contrib.handlers.base_logger import BaseLogger
+from ignite.contrib.handlers.mlflow_logger import BaseOutputHandler
 
 
 class OutputHandler(BaseOutputHandler):
@@ -27,7 +28,7 @@ def __call__(self, engine, logger, event_name):
         if not isinstance(logger, TqdmLogger):
             raise RuntimeError("Handler 'OutputHandler' works only with TqdmLogger")
 
-        metrics = self._setup_output_metrics(engine)
+        metrics = self._setup_output_metrics_state_attrs(engine)
 
         global_step = self.global_step_transform(engine, event_name)
 
diff --git a/pytorch_hebbian/trainers.py b/pytorch_hebbian/trainers.py
index 6adcf66..0f4d760 100644
--- a/pytorch_hebbian/trainers.py
+++ b/pytorch_hebbian/trainers.py
@@ -74,7 +74,6 @@ def __init__(self, model: torch.nn.Sequential, learning_rule: Union[LearningRule
                  optimizer: Optimizer, supervised_from: int = -1, freeze_layers: List[str] = None,
                  complete_forward: bool = False, single_forward: bool = False,
                  device: Optional[Union[str, torch.device]] = None):
-        device = utils.get_device(device)
         engine = self.create_hebbian_trainer(model, learning_rule, optimizer, device=device)
         self.supervised_from = supervised_from
         self.freeze_layers = freeze_layers
@@ -152,6 +151,7 @@ def _prepare_data(self, inputs, model, layer_index):
         x = x.view((x.shape[0], -1))
 
         self.logger.debug("Prepared inputs and weights with shapes {} and {}.".format(list(x.shape), list(w.shape)))
+
         return x, w
 
     def _prepare_data2(self, layer, layer_name):
diff --git a/pytorch_hebbian/utils.py b/pytorch_hebbian/utils.py
index ccb281a..d461c94 100644
--- a/pytorch_hebbian/utils.py
+++ b/pytorch_hebbian/utils.py
@@ -47,10 +47,11 @@ def extract_image_patches(x, kernel_size, stride=(1, 1), dilation=1, padding=0):
     return patches.view(-1, kernel_size[0], kernel_size[1])
 
 
-def split_dataset(dataset, val_split):
+def split_dataset(dataset, device, val_split):
     val_size = int(val_split * len(dataset))
     train_size = len(dataset) - val_size
-    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+    device = get_device(device)
+    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=torch.Generator(device=device))
 
     return train_dataset, val_dataset
 
@@ -122,13 +123,15 @@ def get_device(device=None):
     if device is None:
         if torch.cuda.is_available():
             device = 'cuda'
-            torch.set_default_tensor_type('torch.cuda.FloatTensor')
+            torch.set_default_device('cuda')
+            torch.set_default_dtype(torch.float64)
         else:
             device = 'cpu'
     elif device == 'cuda':
         if torch.cuda.is_available():
             device = 'cuda'
-            torch.set_default_tensor_type('torch.cuda.FloatTensor')
+            torch.set_default_device('cuda')
+            torch.set_default_dtype(torch.float64)
         else:
             device = 'cpu'
     else:
@@ -136,5 +139,6 @@ def get_device(device=None):
 
     if device == 'cuda':
         logger.info("CUDA device set to '{}'.".format(torch.cuda.get_device_name(0)))
 
+
     return device
diff --git a/setup.py b/setup.py
index c543d4f..93659d3 100644
--- a/setup.py
+++ b/setup.py
@@ -20,10 +20,10 @@
     ],
     python_requires='>=3.6',
     install_requires=[
-        'torch==1.6.0',
-        'torchvision==0.7.0',
-        'tensorboard==2.3.0',
-        'pytorch-ignite==0.4.*',
+        'torch',
+        'torchvision',
+        'tensorboard',
+        'pytorch-ignite',
         'matplotlib',
         'numpy',
         'Pillow',

From c388628dcb899ff1047529eea53994838fa12b58 Mon Sep 17 00:00:00 2001
From: paisano
Date: Wed, 28 Feb 2024 23:05:51 -0800
Subject: [PATCH 2/2] Update requirements to the newest working versions.

---
 requirements_base.txt | 4 ++--
 requirements_gpu.txt  | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/requirements_base.txt b/requirements_base.txt
index 0eac9ad..4ec177c 100644
--- a/requirements_base.txt
+++ b/requirements_base.txt
@@ -1,5 +1,5 @@
-tensorboard==2.3.0
-pytorch-ignite==0.4.*
+tensorboard
+pytorch-ignite
 matplotlib
 numpy
 Pillow
diff --git a/requirements_gpu.txt b/requirements_gpu.txt
index 0852269..03d5d15 100644
--- a/requirements_gpu.txt
+++ b/requirements_gpu.txt
@@ -1,5 +1,5 @@
 -r requirements_base.txt
 --find-links https://download.pytorch.org/whl/torch_stable.html
-torch==1.6.0+cu101
-torchvision==0.7.0+cu101
+torch==2.2.1+cu118
+torchvision==0.17.1+cu118
 pynvml