14 changes: 8 additions & 6 deletions examples/data.py
@@ -3,6 +3,7 @@
 import random
 
 import torch
+from torch import Generator
 from torch.utils.data import DataLoader, Subset
 from torchvision import datasets, transforms
 
@@ -11,9 +12,10 @@
 PATH = os.path.dirname(os.path.abspath(__file__))
 
 
-def get_data(params, dataset_name, subset=None):
+def get_data(params, dataset_name, device, subset=None):
     load_test = 'train_all' in params and params['train_all']
     test_dataset = None
+    device = utils.get_device(device)
 
     # Loading the dataset and creating the data loaders and transforms
     transform = transforms.Compose([
@@ -41,12 +43,12 @@ def get_data(params, dataset_name, subset=None):
         dataset = Subset(dataset, random.sample(range(len(dataset)), subset))
 
     if load_test:
-        train_loader = DataLoader(dataset, batch_size=params['train_batch_size'], shuffle=True)
-        val_loader = DataLoader(test_dataset, batch_size=params['val_batch_size'], shuffle=False)
+        train_loader = DataLoader(dataset, batch_size=params['train_batch_size'], shuffle=True, generator=Generator(device=device))
+        val_loader = DataLoader(test_dataset, batch_size=params['val_batch_size'], shuffle=False, generator=Generator(device=device))
     else:
-        train_dataset, val_dataset = utils.split_dataset(dataset, val_split=params['val_split'])
-        train_loader = DataLoader(train_dataset, batch_size=params['train_batch_size'], shuffle=True)
-        val_loader = DataLoader(val_dataset, batch_size=params['val_batch_size'], shuffle=False)
+        train_dataset, val_dataset = utils.split_dataset(dataset, device, val_split=params['val_split'])
+        train_loader = DataLoader(train_dataset, batch_size=params['train_batch_size'], shuffle=True, generator=Generator(device=device))
+        val_loader = DataLoader(val_dataset, batch_size=params['val_batch_size'], shuffle=False, generator=Generator(device=device))
 
     # Analyze dataset
     data_batch = next(iter(train_loader))[0]
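Note on the new `generator=Generator(device=device)` arguments: once `torch.set_default_device('cuda')` is in effect (see `utils.get_device` further down), the `torch.randperm()` call inside `DataLoader`'s shuffling sampler allocates on the CUDA device while the sampler's implicit generator is still a CPU one, and the mismatch raises a runtime error. A minimal sketch of the failure mode and the fix — the names below are illustrative, not part of this PR, and assume a CUDA machine:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    torch.set_default_device('cuda')  # as utils.get_device() now does

    data = TensorDataset(torch.arange(10, dtype=torch.float32))

    # shuffle=True makes the sampler call torch.randperm(), which now allocates
    # on the default (CUDA) device; without a device-matched generator this
    # fails with an "Expected a 'cuda' device type for generator" error.
    loader = DataLoader(data, batch_size=2, shuffle=True,
                        generator=torch.Generator(device='cuda'))
    print(next(iter(loader)))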
14 changes: 9 additions & 5 deletions examples/hebbian.py
@@ -25,9 +25,11 @@
 PATH = os.path.dirname(os.path.abspath(__file__))
 
 
-def attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader, params):
+def attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader, params, device):
     # Metrics
-    UnitConvergence(model[0], learning_rule.norm).attach(trainer.engine, 'unit_conv')
+    device = utils.get_device(device)
+
+    UnitConvergence(model[0], learning_rule.norm, device=device).attach(trainer.engine, 'unit_conv')
 
     # Tqdm logger
     pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT)
@@ -107,6 +109,8 @@ def attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader, params):
 
 
 def main(args: Namespace, params: dict, dataset_name, run_postfix=""):
+    torch.set_default_device(args.device)
+    torch.set_default_dtype(torch.float64)
     # Creating an identifier for this run
     identifier = time.strftime("%Y%m%d-%H%M%S")
     run = '{}/heb/{}'.format(dataset_name, identifier)
@@ -128,7 +132,7 @@ def main(args: Namespace, params: dict, dataset_name, run_postfix=""):
     print("Device set to '{}'.".format(device))
 
     # Data loaders
-    train_loader, val_loader = data.get_data(params, dataset_name, subset=10000)
+    train_loader, val_loader = data.get_data(params, dataset_name, device, subset=10000)
 
     # Creating the learning rule, optimizer, evaluator and trainer
     learning_rule = KrotovsRule(delta=params['delta'], k=params['k'], norm=params['norm'], normalize=False)
@@ -143,7 +147,7 @@ def init_function(h_model):
         h_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(h_optimizer, 'max', verbose=True, patience=5,
                                                                     factor=0.5)
         h_trainer = SupervisedTrainer(model=h_model, optimizer=h_optimizer, criterion=h_criterion, device=device)
-
         # Tqdm logger
         h_pbar = ProgressBar(persist=False, bar_format=config.IGNITE_BAR_FORMAT)
         h_pbar.attach(h_trainer.engine, metric_names='all')
@@ -193,7 +197,7 @@ def init_function(h_model):
 
     # Handlers
     tb_logger = attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader,
-                                params)
+                                params, device)
 
     # Running the trainer
    trainer.run(train_loader=train_loader, epochs=params['epochs'])
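One hedged caveat on the two globals now set at the top of `main()`: `args.device` may be `None` when the user relies on auto-detection, and `torch.set_default_device(None)` resets the default rather than selecting a device. A defensive variant, as a sketch that assumes `utils.get_device` resolves `None` to 'cuda' or 'cpu' as it does in this PR:

    # Resolve the device first, then pin the globals so the model, the
    # DataLoader generators and ad-hoc tensors all agree.
    device = utils.get_device(args.device)
    torch.set_default_device(device)
    torch.set_default_dtype(torch.float64)  # PR-wide choice; torch defaults to float32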
14 changes: 6 additions & 8 deletions pytorch_hebbian/handlers/tensorboard_logger.py
@@ -5,7 +5,7 @@
 import numpy as np
 import torch
 import torchvision
-from ignite.contrib.handlers.base_logger import BaseHandler, BaseWeightsScalarHandler, BaseWeightsHistHandler
+from ignite.contrib.handlers.base_logger import BaseHandler, BaseWeightsScalarHandler, BaseWeightsHandler
 from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger
 from matplotlib import pyplot as plt
 
@@ -52,34 +52,32 @@ def __call__(self, engine, logger, event_name):
         )
 
 
-class WeightsHistHandler(BaseWeightsHistHandler):
+class WeightsHistHandler(BaseWeightsHandler):
     """Helper handler to log model's weights as histograms.
 
     Args:
         model (torch.nn.Module): model to log weights
         tag (str, optional): common title for all produced plots. For example, 'generator'
     """
 
-    def __init__(self, model, layer_names=None, tag=None):
+    def __init__(self, model, tag=None):
         super(WeightsHistHandler, self).__init__(model, tag=tag)
-        self.layer_names = layer_names
 
     def __call__(self, engine, logger, event_name):
         if not isinstance(logger, TensorboardLogger):
             raise RuntimeError("Handler 'WeightsHistHandler' works only with TensorboardLogger")
 
         global_step = engine.state.get_event_attrib_value(event_name)
         tag_prefix = "{}/".format(self.tag) if self.tag else ""
         for name, p in self.model.named_parameters():
-            if self.layer_names is not None:
-                if name.split('.')[0] not in self.layer_names:
-                    continue
+            if p.grad is None:
+                continue
 
             name = name.replace(".", "/")
             logger.writer.add_histogram(
                 tag="{}weights/{}".format(tag_prefix, name),
                 values=p.data.detach().cpu().numpy(),
                 global_step=global_step,
             )
 
 
@@ -183,7 +183,7 @@ def __call__(self, engine, logger, event_name):
         )
 
 
-class ActivationsHistHandler(BaseWeightsHistHandler):
+class ActivationsHistHandler(BaseWeightsHandler):
     """Helper handler to log model's activations as histograms.
 
     Args:
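Hedged context for the base-class swap above: recent pytorch-ignite releases replaced `BaseWeightsHistHandler` with the more general `BaseWeightsHandler`, so the rename tracks the upstream API. A usage sketch — the `trainer.engine` and `model` names follow the examples in this PR but are assumptions here:

    from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger
    from ignite.engine import Events

    tb_logger = TensorboardLogger(log_dir='logs/hebbian')
    tb_logger.attach(trainer.engine,
                     log_handler=WeightsHistHandler(model, tag='weights'),
                     event_name=Events.EPOCH_COMPLETED)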
5 changes: 3 additions & 2 deletions pytorch_hebbian/handlers/tqdm_logger.py
@@ -2,7 +2,8 @@
 import warnings
 
 import torch
-from ignite.contrib.handlers.base_logger import BaseOutputHandler, BaseLogger
+from ignite.contrib.handlers.base_logger import BaseLogger
+from ignite.contrib.handlers.mlflow_logger import BaseOutputHandler
 
 
 class OutputHandler(BaseOutputHandler):
@@ -27,7 +28,7 @@ def __call__(self, engine, logger, event_name):
         if not isinstance(logger, TqdmLogger):
             raise RuntimeError("Handler 'OutputHandler' works only with TqdmLogger")
 
-        metrics = self._setup_output_metrics(engine)
+        metrics = self._setup_output_metrics_state_attrs(engine)
 
         global_step = self.global_step_transform(engine, event_name)
2 changes: 1 addition & 1 deletion pytorch_hebbian/trainers.py
@@ -74,7 +74,6 @@ def __init__(self, model: torch.nn.Sequential, learning_rule: Union[LearningRule
                  optimizer: Optimizer, supervised_from: int = -1, freeze_layers: List[str] = None,
                  complete_forward: bool = False, single_forward: bool = False,
                  device: Optional[Union[str, torch.device]] = None):
-        device = utils.get_device(device)
        engine = self.create_hebbian_trainer(model, learning_rule, optimizer, device=device)
        self.supervised_from = supervised_from
        self.freeze_layers = freeze_layers
@@ -152,6 +151,7 @@ def _prepare_data(self, inputs, model, layer_index):
 
        x = x.view((x.shape[0], -1))
        self.logger.debug("Prepared inputs and weights with shapes {} and {}.".format(list(x.shape), list(w.shape)))
+
        return x, w
 
    def _prepare_data2(self, layer, layer_name):
14 changes: 10 additions & 4 deletions pytorch_hebbian/utils.py
@@ -8,6 +8,8 @@
 from matplotlib import pyplot as plt
 from torch.utils.data import random_split
 
+import traceback
+import sys
 # TODO: find better fix
 # https://stackoverflow.com/questions/27147300/matplotlib-tcl-asyncdelete-async-handler-deleted-by-the-wrong-thread
 matplotlib.use('Agg')
@@ -47,10 +49,11 @@ def extract_image_patches(x, kernel_size, stride=(1, 1), dilation=1, padding=0):
     return patches.view(-1, kernel_size[0], kernel_size[1])
 
 
-def split_dataset(dataset, val_split):
+def split_dataset(dataset, device, val_split):
     val_size = int(val_split * len(dataset))
     train_size = len(dataset) - val_size
-    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+    device = get_device(device)
+    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=torch.Generator(device=device))
 
     return train_dataset, val_dataset
 
@@ -122,19 +125,21 @@ def get_device(device=None):
     if device is None:
         if torch.cuda.is_available():
             device = 'cuda'
-            torch.set_default_tensor_type('torch.cuda.FloatTensor')
+            torch.set_default_device('cuda')
+            torch.set_default_dtype(torch.float64)
         else:
             device = 'cpu'
     elif device == 'cuda':
         if torch.cuda.is_available():
             device = 'cuda'
-            torch.set_default_tensor_type('torch.cuda.FloatTensor')
+            torch.set_default_device('cuda')
+            torch.set_default_dtype(torch.float64)
         else:
             device = 'cpu'
     else:
         device = 'cpu'
 
     if device == 'cuda':
         logger.info("CUDA device set to '{}'.".format(torch.cuda.get_device_name(0)))
 
     return device
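The generator handed to `random_split()` above also pins down reproducibility: seeding it fixes train/validation membership across runs. A hedged sketch reusing the names from `split_dataset` (the seed value is illustrative); note that CPU and CUDA generators draw different permutations, so a seeded split still differs between device types:

    # Seed the split so train/val membership is stable across runs.
    generator = torch.Generator(device=get_device(device))
    generator.manual_seed(42)
    train_dataset, val_dataset = random_split(
        dataset, [train_size, val_size], generator=generator)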
4 changes: 2 additions & 2 deletions requirements_base.txt
@@ -1,5 +1,5 @@
-tensorboard==2.3.0
-pytorch-ignite==0.4.*
+tensorboard
+pytorch-ignite
 matplotlib
 numpy
 Pillow
4 changes: 2 additions & 2 deletions requirements_gpu.txt
@@ -1,5 +1,5 @@
 -r requirements_base.txt
 --find-links https://download.pytorch.org/whl/torch_stable.html
-torch==1.6.0+cu101
-torchvision==0.7.0+cu101
+torch==2.2.1+cu118
+torchvision==0.17.1+cu118
 pynvml
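If the legacy torch_stable.html index ever stops carrying these builds, pip's per-CUDA-version index is the usual alternative; a hedged variant of the file with the same pins but a different index directive:

    -r requirements_base.txt
    --extra-index-url https://download.pytorch.org/whl/cu118
    torch==2.2.1+cu118
    torchvision==0.17.1+cu118
    pynvml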
8 changes: 4 additions & 4 deletions setup.py
@@ -20,10 +20,10 @@
     ],
     python_requires='>=3.6',
     install_requires=[
-        'torch==1.6.0',
-        'torchvision==0.7.0',
-        'tensorboard==2.3.0',
-        'pytorch-ignite==0.4.*',
+        'torch',
+        'torchvision',
+        'tensorboard',
+        'pytorch-ignite',
         'matplotlib',
         'numpy',
         'Pillow',