
# DeepVisionXplain

ViT and CNN explainability comparison

Built with PyTorch Lightning, Hydra, and Weights & Biases.

## Project Description

A neural network training environment (including various MLOps tools) designed to compare the explainability of CNNs (using class activation maps) and ViTs (using attention rollout). The project is based on DeepTrainer.
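For reference, attention rollout combines per-layer attention matrices into a single token-to-token relevance map. Below is a minimal NumPy sketch of the idea (following Abnar & Zuidema, 2020), not the project's actual implementation; the function name and the uniform-attention toy input are illustrative only:

```python
import numpy as np

def attention_rollout(attentions: list) -> np.ndarray:
    """Combine per-layer, head-averaged attention maps into one relevance map.

    attentions: list of (tokens, tokens) matrices with rows summing to 1.
    Residual connections are modeled by mixing in the identity matrix.
    """
    n = attentions[0].shape[0]
    rollout = np.eye(n)
    for attn in attentions:
        attn = 0.5 * attn + 0.5 * np.eye(n)             # account for residual stream
        attn = attn / attn.sum(axis=-1, keepdims=True)  # re-normalize rows
        rollout = attn @ rollout                         # propagate through the layer
    return rollout

# toy example: 3 layers of uniform attention over 4 tokens
layers = [np.full((4, 4), 0.25) for _ in range(3)]
rollout = attention_rollout(layers)
```

Each row of `rollout` stays a probability distribution over input tokens; for a ViT, the CLS-token row reshaped over the patch grid gives the explanation heatmap.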

## Conda installation

```bash
# clone project
git clone https://github.com/lus105/DeepVisionXplain.git
# change directory
cd DeepVisionXplain
# update conda
conda update -n base conda
# create conda environment and install dependencies
conda env create -f environment.yaml -n DeepVisionXplain
# activate conda environment
conda activate DeepVisionXplain
```

## Quickstart

Train a model with the default configuration (to verify the environment is set up correctly):

```bash
# train on CPU (MNIST dataset)
python src/train.py trainer=cpu
# train on GPU (MNIST dataset)
python src/train.py trainer=gpu
```

## Modified ViT architecture

## Modified CNN architecture

Two CNN backbones were trained for experimentation, each tapped at a full-size and a downscaled feature layer:

| Backbone | Variant | Tapped layer | Output shape |
| --- | --- | --- | --- |
| efficientnet_v2_s | full size | features.7 | [1, 1280, 7, 7] |
| efficientnet_v2_s | downscaled | features.6.0.block.0 | [1, 960, 14, 14] |
| mobilenet_v3_large | full size | features.16 | [1, 960, 7, 7] |
| mobilenet_v3_large | downscaled | features.13.block.0 | [1, 672, 14, 14] |
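A class activation map from such a tapped layer is the classifier-weighted sum of its feature maps. The following is a hedged NumPy sketch of that computation, not the repository's code; the function name and the random toy inputs (shaped like the full-size efficientnet_v2_s tap with a hypothetical 10-class head) are assumptions for illustration:

```python
import numpy as np

def class_activation_map(features: np.ndarray, fc_weights: np.ndarray,
                         class_idx: int) -> np.ndarray:
    """Compute a CAM as the classifier-weighted sum of feature maps.

    features:   (C, H, W) activations from the tapped layer
                (e.g. 1280 x 7 x 7 for efficientnet_v2_s features.7).
    fc_weights: (num_classes, C) weights of the final linear layer.
    Returns a (H, W) heatmap normalized to [0, 1].
    """
    # weighted sum over the channel axis -> (H, W)
    cam = np.tensordot(fc_weights[class_idx], features, axes=([0], [0]))
    cam = np.maximum(cam, 0.0)          # keep positive evidence only
    cam -= cam.min()
    if cam.max() > 0:
        cam /= cam.max()
    return cam

# toy shapes matching the full-size efficientnet_v2_s tap
features = np.random.rand(1280, 7, 7).astype(np.float32)
weights = np.random.rand(10, 1280).astype(np.float32)
cam = class_activation_map(features, weights, class_idx=3)
```

The downscaled taps (14 x 14) trade feature depth for a finer-grained heatmap before upsampling to the input resolution.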

Train a CNN/ViT model:

```bash
# train cnn
python src/train.py experiment=train_cnn
# train vit
python src/train.py experiment=train_vit
```

Train a CNN/ViT model with hyperparameter search (Optuna):

```bash
# train cnn
python src/train.py hparams_search=cnn_optuna experiment=train_cnn
# train vit
python src/train.py hparams_search=vit_optuna experiment=train_vit
```

Run explainability segmentation evaluation for all models:

```bat
scripts\eval_segmentation.bat
```

## Resources

## References