Skip to content

Commit eb5e2cf

Browse files
author
bscuser
committed
Add save and load knn functions
1 parent 94f1554 commit eb5e2cf

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

dislib/classification/knn/base.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,15 @@ def save_model(self, filepath, overwrite=True, save_format="json"):
160160
Format used to save the models.
161161
Examples
162162
--------
163-
>>> from dislib.recommendation import ALS
163+
>>> from dislib.classification import KNeighborsClassifier
164164
>>> import numpy as np
165165
>>> import dislib as ds
166166
>>> data = np.array([[0, 0, 5], [3, 0, 5], [3, 1, 2]])
167-
>>> ratings = csr_matrix(data)
167+
>>> y_data = np.array([2, 1, 1, 2, 0])
168168
>>> train = ds.array(x=ratings, block_size=(1, 1))
169-
>>> als = ALS(tol=0.01, random_state=666, n_f=5, verbose=False)
170-
>>> als.fit(train)
171-
>>> als.save_model("model_als")
169+
>>> knn = KNeighborsClassifier()
170+
>>> knn.fit(train)
171+
>>> knn.save_model("./model_KNN")
172172
"""
173173

174174
# Check overwrite
@@ -207,22 +207,22 @@ def load_model(self, filepath, load_format="json"):
207207
Format used to load the model.
208208
Examples
209209
--------
210-
>>> from dislib.regression import LinearRegression
210+
>>> from dislib.clasiffication import KNeighborsClassifier
211211
>>> import numpy as np
212212
>>> import dislib as ds
213213
>>> x_data = np.array([[1, 2], [2, 0], [3, 1], [4, 4], [5, 3]])
214-
>>> y_data = np.array([2, 1, 1, 2, 4.5])
214+
>>> y_data = np.array([2, 1, 1, 2, 0])
215215
>>> x_test_m = np.array([[3, 2], [4, 4], [1, 3]])
216216
>>> bn, bm = 2, 2
217217
>>> x = ds.array(x=x_data, block_size=(bn, bm))
218218
>>> y = ds.array(x=y_data, block_size=(bn, 1))
219219
>>> test_data_m = ds.array(x=x_test_m, block_size=(bn, bm))
220-
>>> reg = LinearRegression()
221-
>>> reg.fit(x, y)
222-
>>> reg.save_model("./model_LR")
223-
>>> reg_loaded = LinearRegression()
224-
>>> reg_loaded.load_model("./model_LR")
225-
>>> pred = reg_loaded.predict(test_data).collect()
220+
>>> knn = KNeighborsClassifier()
221+
>>> knn.fit(x, y)
222+
>>> knn.save_model("./model_KNN")
223+
>>> knn_loaded = KNeighborsClassifier()
224+
>>> knn_loaded.load_model("./model_KNN")
225+
>>> pred = knn_loaded.predict(test_data).collect()
226226
"""
227227
# Load model
228228
if load_format == "json":

0 commit comments

Comments
 (0)