From 5e1667962d96c44df663ec22f8406963f03fa9c4 Mon Sep 17 00:00:00 2001 From: Huen Oh Date: Wed, 15 Mar 2017 12:01:01 +0900 Subject: [PATCH 1/2] correct : typo _change_ont_hot_label() -> _change_one_hot_label() --- dataset/mnist.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dataset/mnist.py b/dataset/mnist.py index bdea3ea4..a2f7d528 100644 --- a/dataset/mnist.py +++ b/dataset/mnist.py @@ -79,7 +79,7 @@ def init_mnist(): pickle.dump(dataset, f, -1) print("Done!") -def _change_ont_hot_label(X): +def _change_one_hot_label(X): T = np.zeros((X.size, 10)) for idx, row in enumerate(T): row[X[idx]] = 1 @@ -114,8 +114,8 @@ def load_mnist(normalize=True, flatten=True, one_hot_label=False): dataset[key] /= 255.0 if one_hot_label: - dataset['train_label'] = _change_ont_hot_label(dataset['train_label']) - dataset['test_label'] = _change_ont_hot_label(dataset['test_label']) + dataset['train_label'] = _change_one_hot_label(dataset['train_label']) + dataset['test_label'] = _change_one_hot_label(dataset['test_label']) if not flatten: for key in ('train_img', 'test_img'): From 28b06ff63435c4c08b1266ce35e736c8d7bffd15 Mon Sep 17 00:00:00 2001 From: huen oh Date: Fri, 17 Mar 2017 13:54:32 +0900 Subject: [PATCH 2/2] fix : Input to numerical_gradient() in main() 1. The input tuple of 'X' and 'Y' set should be transposed. 2. The 'grad', return value of numerical_gradient should be transposed It is intuitionally right. check with other function such as f(x0, x1) = x0**2 + x1*2 It doesn't work properly with the original code --- ch04/gradient_2d.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ch04/gradient_2d.py b/ch04/gradient_2d.py index cf222ebe..1cf2ef27 100644 --- a/ch04/gradient_2d.py +++ b/ch04/gradient_2d.py @@ -59,7 +59,8 @@ def tangent_line(f, x): X = X.flatten() Y = Y.flatten() - grad = numerical_gradient(function_2, np.array([X, Y]) ) + grad = numerical_gradient(function_2, np.array([X, Y]).T) + grad = grad.T plt.figure() plt.quiver(X, Y, -grad[0], -grad[1], angles="xy",color="#666666")#,headwidth=10,scale=40,color="#444444")