18 | 18 | },
19 | 19 | {
20 | 20 | "cell_type": "code",
21 |    | - "execution_count": 1,
   | 21 | + "execution_count": null,
22 | 22 | "metadata": {},
23 | 23 | "outputs": [],
24 | 24 | "source": [

34 | 34 | "\n",
35 | 35 | "from tqdm import tqdm as pbar\n",
36 | 36 | "from torch.utils.tensorboard import SummaryWriter\n",
37 |    | - "from models import *"
   | 37 | + "from cifar10_models import *"
38 | 38 | ]
39 | 39 | },
40 | 40 | {
46 | 46 | },
47 | 47 | {
48 | 48 | "cell_type": "code",
49 |    | - "execution_count": 2,
   | 49 | + "execution_count": null,
50 | 50 | "metadata": {},
51 | 51 | "outputs": [],
52 | 52 | "source": [

65 | 65 | "        transforms.ToTensor(),\n",
66 | 66 | "        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])\n",
67 | 67 | "    \n",
68 |    | - "#     transform_validation = transforms.Compose([transforms.ToTensor(),\n",
69 |    | - "#         transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])\n",
   | 68 | + "    transform_validation = transforms.Compose([transforms.ToTensor(),\n",
   | 69 | + "        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])\n",
70 | 70 | "    \n",
71 | 71 | "    transform_validation = transforms.Compose([transforms.ToTensor()])\n",
72 | 72 | "    \n",
73 | 73 | "    trainset = torchvision.datasets.CIFAR10(root=params['path'], train=True, transform=transform_train)\n",
74 | 74 | "    testset = torchvision.datasets.CIFAR10(root=params['path'], train=False, transform=transform_validation)\n",
75 | 75 | "    \n",
76 |    | - "    trainloader = torch.utils.data.DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, num_workers=params['num_workers'])\n",
77 |    | - "    testloader = torch.utils.data.DataLoader(testset, batch_size=params['batch_size'], shuffle=False, num_workers=params['num_workers'])\n",
   | 76 | + "    trainloader = torch.utils.data.DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, num_workers=4)\n",
   | 77 | + "    testloader = torch.utils.data.DataLoader(testset, batch_size=params['batch_size'], shuffle=False, num_workers=4)\n",
78 | 78 | "    return trainloader, testloader"
79 | 79 | ]
80 | 80 | },
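Note that the `transform_validation` with the CIFAR-10 normalization, un-commented on old lines 68-69, is immediately shadowed by the plain `ToTensor()` assignment on line 71, so validation images are never normalized with the training statistics. For reference, a minimal standalone sketch of the loader factory as it behaves after this hunk; the `RandomCrop`/`RandomHorizontalFlip` augmentations and `download=True` are assumptions, since the top of the function sits outside the visible diff:

```python
import torch
import torchvision
import torchvision.transforms as transforms

def make_dataloaders(params):
    # Train-time augmentation plus CIFAR-10 channel statistics (assumed augmentations).
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])])

    # Only this assignment takes effect: the normalized validation transform
    # added in the diff is overwritten before it is ever used.
    transform_validation = transforms.Compose([transforms.ToTensor()])

    trainset = torchvision.datasets.CIFAR10(root=params['path'], train=True,
                                            download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root=params['path'], train=False,
                                           download=True, transform=transform_validation)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=params['batch_size'],
                                              shuffle=True, num_workers=4)
    testloader = torch.utils.data.DataLoader(testset, batch_size=params['batch_size'],
                                             shuffle=False, num_workers=4)
    return trainloader, testloader
```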
87 | 87 | },
88 | 88 | {
89 | 89 | "cell_type": "code",
90 |    | - "execution_count": 3,
   | 90 | + "execution_count": null,
91 | 91 | "metadata": {},
92 | 92 | "outputs": [],
93 | 93 | "source": [
94 | 94 | "def train_model(model, params):\n",
95 | 95 | "    \n",
96 | 96 | "    writer = SummaryWriter('runs/' + params['description'])\n",
97 | 97 | "    model = model.to(params['device'])\n",
98 |    | - "    optimizer = optim.SGD(model.parameters(), lr=params['max_lr'], weight_decay=params['weight_decay'], momentum=0.9, nesterov=True)\n",
   | 98 | + "    optimizer = optim.AdamW(model.parameters())\n",
99 | 99 | "    total_updates = params['num_epochs']*len(params['train_loader'])\n",
100 |     | - "    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_updates, eta_min=params['min_lr'])\n",
101 | 100 | "    \n",
102 | 101 | "    criterion = nn.CrossEntropyLoss()\n",
103 | 102 | "    best_accuracy = test_model(model, params)\n",
104 | 103 | "    best_model = copy.deepcopy(model.state_dict())\n",
105 | 104 | "    \n",
106 | 105 | "    for epoch in pbar(range(params['num_epochs'])):\n",
107 |     | - "        scheduler.step()\n",
108 |     | - "        \n",
109 | 106 | "        # Each epoch has a training and validation phase\n",
110 | 107 | "        for phase in ['train', 'validation']:\n",
111 | 108 | "            \n",
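This hunk swaps SGD with Nesterov momentum and a cosine schedule for `AdamW` at its defaults (lr=1e-3, weight_decay=1e-2). Two side effects are worth flagging: `total_updates` on line 99 is still computed but no longer used, and the removed code stepped the scheduler once per epoch even though `T_max` was given in optimizer updates, so the schedule never came close to completing. A hedged sketch of how the schedule could be kept correctly alongside AdamW; `configure_optimizer` is a hypothetical helper, not code from the notebook:

```python
import torch.optim as optim

def configure_optimizer(model, params, keep_schedule=False):
    # AdamW at its defaults (lr=1e-3, weight_decay=1e-2), as in the new cell.
    optimizer = optim.AdamW(model.parameters())
    scheduler = None
    if keep_schedule:
        # T_max is measured in optimizer updates, so scheduler.step() belongs
        # after every batch, not once per epoch as the removed call did.
        total_updates = params['num_epochs'] * len(params['train_loader'])
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=total_updates, eta_min=params.get('min_lr', 1e-5))
    return optimizer, scheduler
```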
153 | 150 | "            \n",
154 | 151 | "        # Write best weights to disk\n",
155 | 152 | "        if epoch % params['check_point'] == 0 or epoch == params['num_epochs']-1:\n",
156 |     | - "            torch.save(best_model, params['state_dict_path'] + params['description'] + '.pt')\n",
    | 153 | + "            torch.save(best_model, params['description'] + '.pt')\n",
157 | 154 | "    \n",
158 | 155 | "    final_accuracy = test_model(model, params)\n",
159 | 156 | "    writer.add_text('Final_Accuracy', str(final_accuracy), 0)\n",
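Dropping the `'state_dict_path'` prefix means checkpoints now land in the notebook's working directory; the old code also required the `trained_models/` folder to exist beforehand or `torch.save` would raise. If a dedicated checkpoint folder is still wanted, a small sketch (`save_checkpoint` is a hypothetical helper):

```python
import os
import torch

def save_checkpoint(state_dict, directory, description):
    # Create the target folder on demand, then write '<description>.pt' into it.
    os.makedirs(directory, exist_ok=True)
    torch.save(state_dict, os.path.join(directory, description + '.pt'))
```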
169 | 166 | },
170 | 167 | {
171 | 168 | "cell_type": "code",
172 |     | - "execution_count": 4,
    | 169 | + "execution_count": null,
173 | 170 | "metadata": {},
174 | 171 | "outputs": [],
175 | 172 | "source": [
203 | 200 | },
204 | 201 | {
205 | 202 | "cell_type": "code",
206 |     | - "execution_count": 5,
    | 203 | + "execution_count": null,
207 | 204 | "metadata": {},
208 |     | - "outputs": [
209 |     | -  {
210 |     | -   "data": {
211 |     | -    "text/plain": [
212 |     | -     "<All keys matched successfully>"
213 |     | -    ]
214 |     | -   },
215 |     | -   "execution_count": 5,
216 |     | -   "metadata": {},
217 |     | -   "output_type": "execute_result"
218 |     | -  }
219 |     | - ],
    | 205 | + "outputs": [],
220 | 206 | "source": [
221 |     | - "model = resnet18()\n",
222 |     | - "model.load_state_dict(torch.load('/tmp/checkpoint_12000.pth'))"
    | 207 | + "model = resnet18()"
223 | 208 | ]
224 | 209 | },
225 | 210 | {
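The removed lines restored weights from a fixed checkpoint, and the cleared cell output (`<All keys matched successfully>`) was simply the return value of `load_state_dict`. If resuming is still wanted, it would look like the sketch below; the checkpoint path is the one from the removed line, and the explicit import assumes `resnet18` is exported by `cifar10_models`, as the star import at the top suggests:

```python
import torch
from cifar10_models import resnet18  # assumed to be exported by the repo's module

model = resnet18()

# Optional: resume from a saved state_dict rather than training from scratch.
state_dict = torch.load('/tmp/checkpoint_12000.pth', map_location='cpu')
model.load_state_dict(state_dict)  # returns <All keys matched successfully>
```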
231 | 216 | },
232 | 217 | {
233 | 218 | "cell_type": "code",
234 |     | - "execution_count": 6,
    | 219 | + "execution_count": null,
235 | 220 | "metadata": {},
236 |     | - "outputs": [
237 |     | -  {
238 |     | -   "name": "stdout",
239 |     | -   "output_type": "stream",
240 |     | -   "text": [
241 |     | -    "Using cuda:2\n"
242 |     | -   ]
243 |     | -  }
244 |     | - ],
    | 221 | + "outputs": [],
245 | 222 | "source": [
246 | 223 | "# Train on cuda if available\n",
247 |     | - "device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n",
    | 224 | + "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
248 | 225 | "print(\"Using\", device)"
249 | 226 | ]
250 | 227 | },
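Pinning `cuda:0` instead of `cuda:2` makes the notebook runnable on single-GPU machines. An equivalent sketch that also lets the index be overridden without editing the cell; `GPU_ID` is a hypothetical environment variable, not something the notebook reads:

```python
import os
import torch

# Fall back to GPU 0 unless GPU_ID is set; use the CPU when CUDA is absent.
gpu_id = os.environ.get('GPU_ID', '0')
device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
print('Using', device)
```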
251 | 228 | {
252 | 229 | "cell_type": "code",
253 |     | - "execution_count": 7,
    | 230 | + "execution_count": null,
254 | 231 | "metadata": {},
255 | 232 | "outputs": [],
256 | 233 | "source": [
257 |     | - "data_params = {'path': '/raid/data/pytorch_dataset/cifar10',\n",
258 |     | - "               'batch_size': 256, 'num_workers': 4}\n",
    | 234 | + "data_params = {'path': '/raid/data/pytorch_dataset/cifar10', 'batch_size': 256}\n",
259 | 235 | "\n",
260 | 236 | "train_loader, validation_loader = make_dataloaders(data_params)\n",
261 | 237 | "\n",
262 |     | - "train_params = {'description': 'ResNet18',\n",
    | 238 | + "train_params = {'description': 'Test',\n",
263 | 239 | "                'num_epochs': 300,\n",
264 |     | - "                'max_lr': 5e-2, 'min_lr': 1e-5, 'weight_decay': 1e-3,\n",
265 | 240 | "                'check_point': 50, 'device': device,\n",
266 |     | - "                'state_dict_path': 'trained_models/',\n",
267 | 241 | "                'train_loader': train_loader, 'validation_loader': validation_loader}"
268 | 242 | ]
269 | 243 | },
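The keys removed here track the earlier edits: `num_workers` is now hard-coded inside `make_dataloaders`, the learning-rate and weight-decay entries fell away with the move to AdamW defaults, and `'state_dict_path'` went with the checkpoint path change. Since `train_model` and `test_model` index these dicts directly, a stale config fails midway through a run; a small guard catches that up front (a hypothetical check, not part of the notebook):

```python
# Keys that train_model and test_model actually index after this change.
REQUIRED_KEYS = {'description', 'num_epochs', 'check_point', 'device',
                 'train_loader', 'validation_loader'}

missing = REQUIRED_KEYS - train_params.keys()
assert not missing, f'train_params is missing: {sorted(missing)}'
```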
270 | 244 | {
271 | 245 | "cell_type": "code",
272 |     | - "execution_count": 8,
    | 246 | + "execution_count": null,
273 | 247 | "metadata": {},
274 | 248 | "outputs": [],
275 | 249 | "source": [
276 |     | - "# train_model(model, train_params)"
    | 250 | + "train_model(model, train_params)"
277 | 251 | ]
278 | 252 | },
279 | 253 | {
280 | 254 | "cell_type": "code",
281 |     | - "execution_count": 9,
    | 255 | + "execution_count": null,
282 | 256 | "metadata": {},
283 |     | - "outputs": [
284 |     | -  {
285 |     | -   "name": "stderr",
286 |     | -   "output_type": "stream",
287 |     | -   "text": [
288 |     | -    "100%|██████████| 40/40 [00:01<00:00, 23.87it/s]\n"
289 |     | -   ]
290 |     | -  },
291 |     | -  {
292 |     | -   "data": {
293 |     | -    "text/plain": [
294 |     | -     "0.7538"
295 |     | -    ]
296 |     | -   },
297 |     | -   "execution_count": 9,
298 |     | -   "metadata": {},
299 |     | -   "output_type": "execute_result"
300 |     | -  }
301 |     | - ],
    | 257 | + "outputs": [],
302 | 258 | "source": [
303 | 259 | "test_model(model, train_params)"
304 | 260 | ]
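The cleared outputs here were a tqdm progress bar over 40 batches (10000 test images at batch size 256) and a 0.7538 accuracy from an earlier run. `test_model`'s body is collapsed out of this diff; judging from its call sites, it evaluates on `params['validation_loader']` and returns top-1 accuracy as a float. A sketch consistent with that contract, as an assumption rather than the notebook's actual code:

```python
import torch

def test_model(model, params):
    # Evaluate top-1 accuracy on the validation loader; no gradients needed.
    model = model.to(params['device']).eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in pbar(params['validation_loader']):
            inputs = inputs.to(params['device'])
            labels = labels.to(params['device'])
            predictions = model(inputs).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total
```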
327 | 283 | "name": "python",
328 | 284 | "nbconvert_exporter": "python",
329 | 285 | "pygments_lexer": "ipython3",
330 |     | - "version": "3.7.4"
    | 286 | + "version": "3.7.5"
331 | 287 | }
332 | 288 | },
333 | 289 | "nbformat": 4,