Skip to content

Run on Google Colab #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
__pycache__
*/__pycache__
*/__pycache__
.vscode/settings.json
14 changes: 7 additions & 7 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

def parse_arguments():
parser = argparse.ArgumentParser(description='Parameters to train your model.')
parser.add_argument('--imgs_folder', default='./data/DUTS/DUTS-TE/DUTS-TE-Image', help='Path to folder containing images', type=str)
parser.add_argument('--model_path', default='/home/tarasha/Projects/sairajk/saliency/SOD_2/models/0.7_wbce_w0-1_w1-1.12/best_epoch-138_acc-0.9107_loss-0.1300.pt', help='Path to model', type=str)
parser.add_argument('--imgs_folder', default='/content/test', help='Path to folder containing images', type=str)
parser.add_argument('--model_path', default='/content/drive/My Drive/Colab Notebooks/models/model_epoch-008_mae-0.1553_loss-0.3755.pth', help='Path to model', type=str)
parser.add_argument('--use_gpu', default=True, help='Whether to use GPU or not', type=bool)
parser.add_argument('--img_size', default=256, help='Image size to be used', type=int)
parser.add_argument('--bs', default=24, help='Batch Size for testing', type=int)
Expand Down Expand Up @@ -59,9 +59,9 @@ def run_inference(args):
pred_masks_round = np.squeeze(pred_masks.round().cpu().numpy(), axis=(0, 1))

print('Image :', batch_idx)
cv2.imshow('Input Image', img_np)
cv2.imshow('Generated Saliency Mask', pred_masks_raw)
cv2.imshow('Rounded-off Saliency Mask', pred_masks_round)
cv2.imwrite('/content/test/'+'Input Image'+ str(batch_idx) + '.png', img_np)
cv2.imwrite('/content/test/'+'Generated Saliency Mask'+ str(batch_idx) + '.png', pred_masks_raw)
cv2.imwrite('/content/test/'+'Rounded-off Saliency Mask'+ str(batch_idx) + '.png', pred_masks_round)

key = cv2.waitKey(0)
if key == ord('q'):
Expand Down Expand Up @@ -101,5 +101,5 @@ def calculate_mae(args):

if __name__ == '__main__':
rt_args = parse_arguments()
calculate_mae(rt_args)
run_inference(rt_args)
#calculate_mae(rt_args)
run_inference(rt_args)
154 changes: 154 additions & 0 deletions pyramid_train&infer.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "pyramid train.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true,
"mount_file_id": "1v4ZKD8onuSrchfAXnvz1RE5IyNpn1Qe6",
"authorship_tag": "ABX9TyNthtxYesufB5XHFxbMXj2E",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/ParthPatel-ES/Pyramid-SD_PyTorch/blob/master/pyramid_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eartF9D1Mpvc"
},
"source": [
"**Pyramid-SD** Train"
]
},
{
"cell_type": "code",
"metadata": {
"id": "fdmwMZUEMf48",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d9483748-2036-4d1b-a375-a280315d156e"
},
"source": [
"!git clone https://github.com/parth15041995/Pyramid-SD_PyTorch.git"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Cloning into 'Pyramid-SD_PyTorch'...\n",
"remote: Enumerating objects: 79, done.\u001b[K\n",
"remote: Counting objects: 100% (79/79), done.\u001b[K\n",
"remote: Compressing objects: 100% (66/66), done.\u001b[K\n",
"remote: Total 79 (delta 24), reused 27 (delta 8), pack-reused 0\u001b[K\n",
"Unpacking objects: 100% (79/79), done.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6GjSmDGuMocn",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2c334305-476f-4caf-c8cf-0e1b0fc670e3"
},
"source": [
"cd /content/Pyramid-SD_PyTorch"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"/content/Pyramid-SD_PyTorch\n"
],
"name": "stdout"
}
]
},
{
"source": [
"Download the DUTS dataset."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"metadata": {
"id": "iDN207WOYEyF"
},
"source": [
"!cp '/content/drive/My Drive/DUTS.zip' /content/Pyramid-SD_PyTorch"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3wGdelkJLURT"
},
"source": [
"import zipfile\n",
"with zipfile.ZipFile('DUTS.zip', 'r') as zip_ref:\n",
" zip_ref.extractall('./data')"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4eSjb56GL62g"
},
"source": [
"%%shell\n",
"python train.py --epochs 1 --n_worker 192"
],
"execution_count": null,
"outputs": []
},
{
"source": [
"Download the pre-trained model. The link is in the README.md."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"metadata": {
"id": "pFe9TS-KsUxb"
},
"source": [
"%%shell\n",
"python inference.py --model_path '/content/drive/MyDrive/Colab Notebooks/models/best-model_epoch-204_mae-0.0505_loss-0.1370.pth' --imgs_folder '/content/Pyramid-SD_PyTorch/data/DUTS/DUTS-TE'"
],
"execution_count": null,
"outputs": []
}
]
}
19 changes: 11 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from src.dataloader import SODLoader
from src.model import SODModel
from src.loss import EdgeSaliencyLoss

import math

def parse_arguments():
parser = argparse.ArgumentParser(description='Parameters to train your model.')
Expand All @@ -24,9 +24,9 @@ def parse_arguments():
parser.add_argument('--wd', default=0., help='L2 Weight decay', type=float)
parser.add_argument('--img_size', default=256, help='Image size to be used for training', type=int)
parser.add_argument('--aug', default=True, help='Whether to use Image augmentation', type=bool)
parser.add_argument('--n_worker', default=2, help='Number of workers to use for loading data', type=int)
parser.add_argument('--n_worker', default=2, help='Number of threads to use for loading data (uses RAM)', type=int)
parser.add_argument('--test_interval', default=2, help='Number of epochs after which to test the weights', type=int)
parser.add_argument('--save_interval', default=None, help='Number of epochs after which to save the weights. If None, does not save', type=int)
parser.add_argument('--save_interval', default=2, help='Number of epochs after which to save the weights. If None, does not save', type=int)
parser.add_argument('--save_opt', default=False, help='Whether to save optimizer along with model weights or not', type=bool)
parser.add_argument('--log_interval', default=250, help='Logging interval (in #batches)', type=int)
parser.add_argument('--res_mod', default=None, help='Path to the model to resume from', type=str)
Expand Down Expand Up @@ -119,7 +119,9 @@ def train(self):
ca_act_reg))

# Validation
if epoch % self.test_interval == 0 or epoch % self.save_interval == 0:

if math.fmod(epoch, self.test_interval) == 0 or math.fmod(epoch, self.save_interval) == 0:

te_avg_loss, te_acc, te_pre, te_rec, te_mae = self.test()
mod_chkpt = {'epoch': epoch,
'test_mae' : float(te_mae),
Expand All @@ -140,21 +142,22 @@ def train(self):

# Save the best model
if te_mae < best_test_mae:

best_test_mae = te_mae
torch.save(mod_chkpt, self.model_path + 'weights/best-model_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
torch.save(mod_chkpt, self.model_path + '/weights/best-model_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
format(epoch, best_test_mae, te_avg_loss))
if self.save_opt:
torch.save(opt_chkpt, self.model_path + 'optimizers/best-opt_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
torch.save(opt_chkpt, self.model_path + '/optimizers/best-opt_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
format(epoch, best_test_mae, te_avg_loss))
print('Best Model Saved !!!\n')
continue

# Save model at regular intervals
if self.save_interval is not None and epoch % self.save_interval == 0:
torch.save(mod_chkpt, self.model_path + 'weights/model_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
torch.save(mod_chkpt, self.model_path + '/weights/model_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
format(epoch, te_mae, te_avg_loss))
if self.save_opt:
torch.save(opt_chkpt, self.model_path + 'optimizers/opt_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
torch.save(opt_chkpt, self.model_path + '/optimizers/opt_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
format(epoch, best_test_mae, te_avg_loss))
print('Model Saved !!!\n')
continue
Expand Down