Skip to content

Commit b46f8d3

Browse files
committed
Simplify training code
1 parent 48ec280 commit b46f8d3

10 files changed

+50
-138
lines changed

.ipynb_checkpoints/CIFAR10-checkpoint.ipynb

Lines changed: 25 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
},
1919
{
2020
"cell_type": "code",
21-
"execution_count": 1,
21+
"execution_count": null,
2222
"metadata": {},
2323
"outputs": [],
2424
"source": [
@@ -34,7 +34,7 @@
3434
"\n",
3535
"from tqdm import tqdm as pbar\n",
3636
"from torch.utils.tensorboard import SummaryWriter\n",
37-
"from models import *"
37+
"from cifar10_models import *"
3838
]
3939
},
4040
{
@@ -46,7 +46,7 @@
4646
},
4747
{
4848
"cell_type": "code",
49-
"execution_count": 2,
49+
"execution_count": null,
5050
"metadata": {},
5151
"outputs": [],
5252
"source": [
@@ -65,16 +65,16 @@
6565
" transforms.ToTensor(),\n",
6666
" transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])\n",
6767
" \n",
68-
"# transform_validation = transforms.Compose([transforms.ToTensor(),\n",
69-
"# transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])\n",
68+
" transform_validation = transforms.Compose([transforms.ToTensor(),\n",
69+
" transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])\n",
7070
" \n",
7171
" transform_validation = transforms.Compose([transforms.ToTensor()])\n",
7272
" \n",
7373
" trainset = torchvision.datasets.CIFAR10(root=params['path'], train=True, transform=transform_train)\n",
7474
" testset = torchvision.datasets.CIFAR10(root=params['path'], train=False, transform=transform_validation)\n",
7575
" \n",
76-
" trainloader = torch.utils.data.DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, num_workers=params['num_workers'])\n",
77-
" testloader = torch.utils.data.DataLoader(testset, batch_size=params['batch_size'], shuffle=False, num_workers=params['num_workers'])\n",
76+
" trainloader = torch.utils.data.DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, num_workers=4)\n",
77+
" testloader = torch.utils.data.DataLoader(testset, batch_size=params['batch_size'], shuffle=False, num_workers=4)\n",
7878
" return trainloader, testloader"
7979
]
8080
},
@@ -87,25 +87,22 @@
8787
},
8888
{
8989
"cell_type": "code",
90-
"execution_count": 3,
90+
"execution_count": null,
9191
"metadata": {},
9292
"outputs": [],
9393
"source": [
9494
"def train_model(model, params):\n",
9595
" \n",
9696
" writer = SummaryWriter('runs/' + params['description'])\n",
9797
" model = model.to(params['device'])\n",
98-
" optimizer = optim.SGD(model.parameters(), lr=params['max_lr'], weight_decay=params['weight_decay'], momentum=0.9, nesterov=True)\n",
98+
" optimizer = optim.AdamW(model.parameters())\n",
9999
" total_updates = params['num_epochs']*len(params['train_loader'])\n",
100-
" scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_updates, eta_min=params['min_lr'])\n",
101100
" \n",
102101
" criterion = nn.CrossEntropyLoss()\n",
103102
" best_accuracy = test_model(model, params)\n",
104103
" best_model = copy.deepcopy(model.state_dict())\n",
105104
" \n",
106105
" for epoch in pbar(range(params['num_epochs'])):\n",
107-
" scheduler.step()\n",
108-
" \n",
109106
" # Each epoch has a training and validation phase\n",
110107
" for phase in ['train', 'validation']:\n",
111108
" \n",
@@ -153,7 +150,7 @@
153150
" \n",
154151
" # Write best weights to disk\n",
155152
" if epoch % params['check_point'] == 0 or epoch == params['num_epochs']-1:\n",
156-
" torch.save(best_model, params['state_dict_path'] + params['description'] + '.pt')\n",
153+
" torch.save(best_model, params['description'] + '.pt')\n",
157154
" \n",
158155
" final_accuracy = test_model(model, params)\n",
159156
" writer.add_text('Final_Accuracy', str(final_accuracy), 0)\n",
@@ -169,7 +166,7 @@
169166
},
170167
{
171168
"cell_type": "code",
172-
"execution_count": 4,
169+
"execution_count": null,
173170
"metadata": {},
174171
"outputs": [],
175172
"source": [
@@ -203,23 +200,11 @@
203200
},
204201
{
205202
"cell_type": "code",
206-
"execution_count": 5,
203+
"execution_count": null,
207204
"metadata": {},
208-
"outputs": [
209-
{
210-
"data": {
211-
"text/plain": [
212-
"<All keys matched successfully>"
213-
]
214-
},
215-
"execution_count": 5,
216-
"metadata": {},
217-
"output_type": "execute_result"
218-
}
219-
],
205+
"outputs": [],
220206
"source": [
221-
"model = resnet18()\n",
222-
"model.load_state_dict(torch.load('/tmp/checkpoint_12000.pth'))"
207+
"model = resnet18()"
223208
]
224209
},
225210
{
@@ -231,74 +216,45 @@
231216
},
232217
{
233218
"cell_type": "code",
234-
"execution_count": 6,
219+
"execution_count": null,
235220
"metadata": {},
236-
"outputs": [
237-
{
238-
"name": "stdout",
239-
"output_type": "stream",
240-
"text": [
241-
"Using cuda:2\n"
242-
]
243-
}
244-
],
221+
"outputs": [],
245222
"source": [
246223
"# Train on cuda if available\n",
247-
"device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n",
224+
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
248225
"print(\"Using\", device)"
249226
]
250227
},
251228
{
252229
"cell_type": "code",
253-
"execution_count": 7,
230+
"execution_count": null,
254231
"metadata": {},
255232
"outputs": [],
256233
"source": [
257-
"data_params = {'path': '/raid/data/pytorch_dataset/cifar10',\n",
258-
" 'batch_size': 256, 'num_workers': 4}\n",
234+
"data_params = {'path': '/raid/data/pytorch_dataset/cifar10', 'batch_size': 256}\n",
259235
"\n",
260236
"train_loader, validation_loader = make_dataloaders(data_params)\n",
261237
"\n",
262-
"train_params = {'description': 'ResNet18',\n",
238+
"train_params = {'description': 'Test',\n",
263239
" 'num_epochs': 300,\n",
264-
" 'max_lr': 5e-2, 'min_lr': 1e-5, 'weight_decay': 1e-3,\n",
265240
" 'check_point': 50, 'device': device,\n",
266-
" 'state_dict_path': 'trained_models/',\n",
267241
" 'train_loader': train_loader, 'validation_loader': validation_loader}"
268242
]
269243
},
270244
{
271245
"cell_type": "code",
272-
"execution_count": 8,
246+
"execution_count": null,
273247
"metadata": {},
274248
"outputs": [],
275249
"source": [
276-
"# train_model(model, train_params)"
250+
"train_model(model, train_params)"
277251
]
278252
},
279253
{
280254
"cell_type": "code",
281-
"execution_count": 9,
255+
"execution_count": null,
282256
"metadata": {},
283-
"outputs": [
284-
{
285-
"name": "stderr",
286-
"output_type": "stream",
287-
"text": [
288-
"100%|██████████| 40/40 [00:01<00:00, 23.87it/s]\n"
289-
]
290-
},
291-
{
292-
"data": {
293-
"text/plain": [
294-
"0.7538"
295-
]
296-
},
297-
"execution_count": 9,
298-
"metadata": {},
299-
"output_type": "execute_result"
300-
}
301-
],
257+
"outputs": [],
302258
"source": [
303259
"test_model(model, train_params)"
304260
]
@@ -327,7 +283,7 @@
327283
"name": "python",
328284
"nbconvert_exporter": "python",
329285
"pygments_lexer": "ipython3",
330-
"version": "3.7.4"
286+
"version": "3.7.5"
331287
}
332288
},
333289
"nbformat": 4,

0 commit comments

Comments
 (0)