Commit 86ab5ce

feature/svm (#79)
* fix naive bayes index, add svm support
* add `svm` chain
* add `svm chain` to pymilo router in `pymilo_func.py`
* add svm models' support to pymilo param
* add svm chain to `test_pymilo.py`
* add LinearSVC test
* add LinearSVR test
* add NuSVC test
* add NuSVR test
* add OneClassSVM test
* add SVC test
* add SVR test
* add svm models' test runner
* add exported svm models' folder to gitignore
* run tests + remove unused imports + remove trailing whitespaces
* remove `dual` parameter
* support for `numpy.intc` serialization added
* transporting `numpy.intc` data structure added
* `CHANGELOG.md` updated
* add relative import
* deep serialize `LabelBinarizer` ndarray
* `deep_deserialize_ndarray` function implemented
* `deep_serialize_ndarray` function implemented
* refactor pure primitive data type or numpy primitive data type deserialization into a function
* add deep ndarray serializer to `serialize_tuple`
* add deep ndarray serializer to `serialize_dict`
* add deep ndarray serializer to `serialize` function
* refactor `deserialize` function to have deep ndarray deserialization
* refactor primitive type deserialization
* refactor primitive type deserialization
* add deep ndarray deserializer to `get_deserialized_dict` function
* add & update docstring
* tighten the condition
* add pymilo bypass
* handle complex dtype str
* add pymilo bypass
* change `not (a in b)` to `a not in b`
* apply `autopep8.sh`
* refactor `check_str_in_iterable` function
* remove unused import
* replace `isinstance(X, dict) and Y in X` with `check_str_in_iterable(Y, X)`
* update `check_str_in_iterable` function
* apply `autopep8.sh`
* update function name
* `CHANGELOG.md` updated
* use `enumerate(x)` instead of `range(len(x))`
* use `isinstance` instead of `type`
* `CHANGELOG.md` updated
* remove unused import
* update IDs
* raise error & warning instead of print
* refactor `get_homogeneous_type` function
* use a predefined string imported from `pymilo_param.py` instead
* remove trailing whitespaces
1 parent 06716ac commit 86ab5ce

23 files changed, +524 -68 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -104,4 +104,5 @@ gen
 /tests/exported_decision_trees
 /tests/exported_clusterings
 /tests/exported_naive_bayes
+/tests/exported_svms
 /.VSCodeCounter

CHANGELOG.md

Lines changed: 25 additions & 1 deletion
@@ -6,7 +6,31 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 
 ## [Unreleased]
 ### Added
+- `deserialize_primitive_type` function in `GeneralDataStructureTransporter`
+- `is_deserialized_ndarray` function in `GeneralDataStructureTransporter`
+- `deep_deserialize_ndarray` function in `GeneralDataStructureTransporter`
+- `deep_serialize_ndarray` function in `GeneralDataStructureTransporter`
+- `SVR` model
+- `SVC` model
+- `One Class SVM` model
+- `NuSVR` model
+- `NuSVC` model
+- `Linear SVR` model
+- `Linear SVC` model
+- SVM models test runner
+- SVM chain
 ### Changed
+- `TreeTransporter` updated
+- `get_homogeneous_type` function in `util.py` updated
+- `GeneralDataStructureTransporter` updated to use deep ndarray serializer & deserializer
+- `check_str_in_iterable` updated
+- `Label Binarizer` Transporter updated
+- `Function` Transporter updated
+- `CFNode` Transporter updated
+- `Bisecting Tree` Transporter updated
+- Tests config modified
+- SVM params initialized in `pymilo_param`
+- SVM support added to `pymilo_func.py`
 - `SUPPORTED_MODELS.md` updated
 - `README.md` updated
 ## [0.5] - 2024-01-31
@@ -19,7 +43,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 - `Bernoulli Naive Bayes` model declared as `BernoulliNB` model
 - `Categorical Naive Bayes` model declared as `CategoricalNB` model
 - Naive Bayes models test runner
-- Naive Bayes chain
+- Naive Bayes chain
 ### Changed
 - `Transport` function of `AbstractTransporter` updated
 - fix the order of `CFNode` fields serialization in `CFNodeTransporter`

SUPPORTED_MODELS.md

Lines changed: 49 additions & 0 deletions
@@ -10,6 +10,7 @@
 * [Decision Trees](#scikit-learn-trees)
 * [Clustering Models](#scikit-learn-clustering)
 * [Naive Bayes](#scikit-learn-naivebayes)
+* [Support Vector Machine](#scikit-learn-svm)
 
 <h2 id="scikit-learn">Scikit-Learn</h2>
 <h3 id="scikit-learn-linear">Linear Models</h3>
@@ -396,3 +397,51 @@
 <td>>=0.5</td>
 </tr>
 </table>
+
+<h3 id="scikit-learn-svm">Support Vector Machine</h3>
+
+📚 <a href="https://scikit-learn.org/stable/modules/svm.html" target="_blank"><b>Models Document</b></a>
+
+
+<table>
+<tr align="center">
+<th>ID</th>
+<th>Model Name</th>
+<th>PyMilo Version</th>
+</tr>
+<tr align="center">
+<td>1</td>
+<td><b>Linear SVC</b></td>
+<td>>=0.6</td>
+</tr>
+<tr align="center">
+<td>2</td>
+<td><b>Linear SVR</b></td>
+<td>>=0.6</td>
+</tr>
+<tr align="center">
+<td>3</td>
+<td><b>NuSVC</b></td>
+<td>>=0.6</td>
+</tr>
+<tr align="center">
+<td>4</td>
+<td><b>NuSVR</b></td>
+<td>>=0.6</td>
+</tr>
+<tr align="center">
+<td>5</td>
+<td><b>One Class SVM</b></td>
+<td>>=0.6</td>
+</tr>
+<tr align="center">
+<td>6</td>
+<td><b>SVC</b></td>
+<td>>=0.6</td>
+</tr>
+<tr align="center">
+<td>7</td>
+<td><b>SVR</b></td>
+<td>>=0.6</td>
+</tr>
+</table>

pymilo/chains/svm_chain.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+# -*- coding: utf-8 -*-
+"""PyMilo chain for svm models."""
+from ..transporters.transporter import Command
+
+from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
+from ..transporters.randomstate_transporter import RandomStateTransporter
+
+from ..pymilo_param import SKLEARN_SVM_TABLE
+from ..exceptions.serialize_exception import PymiloSerializationException, SerilaizatoinErrorTypes
+from ..exceptions.deserialize_exception import PymiloDeserializationException, DeSerilaizatoinErrorTypes
+from traceback import format_exc
+
+SVM_CHAIN = {
+    "GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
+    "RandomStateTransporter": RandomStateTransporter(),
+}
+
+
+def is_svm(model):
+    """
+    Check if the input model is a sklearn's svm model.
+
+    :param model: is a string name of a svm or a sklearn object of it
+    :type model: any object
+    :return: check result as bool
+    """
+    if isinstance(model, str):
+        return model in SKLEARN_SVM_TABLE
+    else:
+        return type(model) in SKLEARN_SVM_TABLE.values()
+
+
+def transport_svm(request, command):
+    """
+    Return the transported (Serialized or Deserialized) model.
+
+    :param request: given svm to be transported
+    :type request: any object
+    :param command: command to specify whether the request should be serialized or deserialized
+    :type command: transporter.Command
+    :return: the transported request as a json string or sklearn svm model
+    """
+    _validate_input(request, command)
+
+    if command == Command.SERIALIZE:
+        try:
+            return serialize_svm(request)
+        except Exception as e:
+            raise PymiloSerializationException(
+                {
+                    'error_type': SerilaizatoinErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
+                    'error': {
+                        'Exception': repr(e),
+                        'Traceback': format_exc(),
+                    },
+                    'object': request,
+                })
+
+    elif command == Command.DESERIALZIE:
+        try:
+            return deserialize_svm(request)
+        except Exception as e:
+            raise PymiloDeserializationException(
+                {
+                    'error_type': SerilaizatoinErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
+                    'error': {
+                        'Exception': repr(e),
+                        'Traceback': format_exc()},
+                    'object': request})
+
+
+def serialize_svm(svm_object):
+    """
+    Return the serialized json string of the given svm model.
+
+    :param svm_object: given model to be get serialized
+    :type svm_object: any sklearn svm model
+    :return: the serialized json string of the given svm
+    """
+    for transporter in SVM_CHAIN:
+        SVM_CHAIN[transporter].transport(
+            svm_object, Command.SERIALIZE)
+    return svm_object.__dict__
+
+
+def deserialize_svm(svm):
+    """
+    Return the associated sklearn svm model of the given svm.
+
+    :param svm: given json string of a svm model to get deserialized to associated sklearn svm model
+    :type svm: obj
+    :return: associated sklearn svm model
+    """
+    raw_model = SKLEARN_SVM_TABLE[svm.type]()
+    data = svm.data
+
+    for transporter in SVM_CHAIN:
+        SVM_CHAIN[transporter].transport(
+            svm, Command.DESERIALZIE)
+    for item in data:
+        setattr(raw_model, item, data[item])
+    return raw_model
+
+
+def _validate_input(model, command):
+    """
+    Check if the provided inputs are valid in relation to each other.
+
+    :param model: a sklearn svm model or a json string of it, serialized through the pymilo export.
+    :type model: obj
+    :param command: command to specify whether the request should be serialized or deserialized
+    :type command: transporter.Command
+    :return: None
+    """
+    if command == Command.SERIALIZE:
+        if is_svm(model):
+            return
+        else:
+            raise PymiloSerializationException(
+                {
+                    'error_type': SerilaizatoinErrorTypes.INVALID_MODEL,
+                    'object': model
+                }
+            )
+    elif command == Command.DESERIALZIE:
+        if is_svm(model.type):
+            return
+        else:
+            raise PymiloDeserializationException(
+                {
+                    'error_type': DeSerilaizatoinErrorTypes.INVALID_MODEL,
+                    'object': model
+                }
+            )
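
The chain in `svm_chain.py` follows a simple pattern: check the model against a name-to-class table, run every transporter in the chain over the model's fields, then either hand back the field dictionary or rebuild an instance with `setattr`. The sketch below illustrates that pattern in isolation. It is not PyMilo code: `TupleTransporter`, `FakeSVC`, `FAKE_TABLE`, `is_supported`, `serialize` and `deserialize` are hypothetical stand-ins, and the tuple tagging is an invented example of what a transporter might do.

```python
from enum import Enum


class Command(Enum):
    SERIALIZE = 1
    DESERIALIZE = 2


class TupleTransporter:
    """Hypothetical transporter: tuples become tagged lists and back."""

    def transport(self, fields, command):
        for key, value in list(fields.items()):
            if command == Command.SERIALIZE and isinstance(value, tuple):
                fields[key] = {"__tuple__": list(value)}
            elif (command == Command.DESERIALIZE
                  and isinstance(value, dict) and "__tuple__" in value):
                fields[key] = tuple(value["__tuple__"])


class FakeSVC:
    """Stand-in for an estimator class; this is not sklearn.svm.SVC."""


# Assumed analogues of SKLEARN_SVM_TABLE and SVM_CHAIN.
FAKE_TABLE = {"FakeSVC": FakeSVC}
CHAIN = {"TupleTransporter": TupleTransporter()}


def is_supported(model):
    # Accept either a class name or an instance, as is_svm does.
    if isinstance(model, str):
        return model in FAKE_TABLE
    return type(model) in FAKE_TABLE.values()


def serialize(model):
    if not is_supported(model):
        raise ValueError("unsupported model")
    fields = dict(vars(model))  # copy, so the caller's model is untouched
    for transporter in CHAIN.values():
        transporter.transport(fields, Command.SERIALIZE)
    return {"type": type(model).__name__, "data": fields}


def deserialize(payload):
    if not is_supported(payload["type"]):
        raise ValueError("unsupported model type")
    fields = dict(payload["data"])
    for transporter in CHAIN.values():
        transporter.transport(fields, Command.DESERIALIZE)
    model = FAKE_TABLE[payload["type"]]()
    for name, value in fields.items():
        setattr(model, name, value)
    return model
```

One deliberate difference from the diff: the sketch copies the field dictionary before transporting it, whereas `serialize_svm` passes the model itself to each transporter and returns `svm_object.__dict__`.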

pymilo/pymilo_func.py

Lines changed: 5 additions & 0 deletions
@@ -8,6 +8,7 @@
 from .chains.decision_tree_chain import transport_decision_tree, is_decision_tree
 from .chains.clustering_chain import transport_clusterer, is_clusterer
 from .chains.naive_bayes_chain import transport_naive_bayes, is_naive_bayes
+from .chains.svm_chain import transport_svm, is_svm
 
 from .transporters.transporter import Command
 
@@ -39,6 +40,8 @@ def get_sklearn_data(model):
         return transport_clusterer(model, Command.SERIALIZE)
     elif is_naive_bayes(model):
         return transport_naive_bayes(model, Command.SERIALIZE)
+    elif is_svm(model):
+        return transport_svm(model, Command.SERIALIZE)
     else:
         return None
 
@@ -61,6 +64,8 @@ def to_sklearn_model(import_obj):
         return transport_clusterer(import_obj, Command.DESERIALZIE)
     elif is_naive_bayes(import_obj.type):
         return transport_naive_bayes(import_obj, Command.DESERIALZIE)
+    elif is_svm(import_obj.type):
+        return transport_svm(import_obj, Command.DESERIALZIE)
     else:
         return None
 
6671

pymilo/pymilo_obj.py

Lines changed: 8 additions & 3 deletions
@@ -2,12 +2,13 @@
 """PyMilo modules."""
 from .pymilo_func import get_sklearn_data, get_sklearn_version, to_sklearn_model
 from .utils.util import get_sklearn_type
-from .pymilo_param import PYMILO_VERSION
+from .pymilo_param import PYMILO_VERSION, PYMILO_VERSION_DOES_NOT_EXIST, UNEQUAL_PYMILO_VERSIONS
 import json
 
-from pymilo.exceptions.deserialize_exception import PymiloDeserializationException, DeSerilaizatoinErrorTypes
-from pymilo.exceptions.serialize_exception import PymiloSerializationException, SerilaizatoinErrorTypes
+from .exceptions.deserialize_exception import PymiloDeserializationException, DeSerilaizatoinErrorTypes
+from .exceptions.serialize_exception import PymiloSerializationException, SerilaizatoinErrorTypes
 from traceback import format_exc
+from warnings import warn
 
 
 class Export:
@@ -99,6 +100,10 @@ def __init__(self, file_adr, json_dump=None):
             else:
                 with open(file_adr, 'r') as fp:
                     serialized_model_obj = json.load(fp)
+            if "pymilo_version" not in serialized_model_obj:
+                raise Exception(PYMILO_VERSION_DOES_NOT_EXIST)
+            if not serialized_model_obj["pymilo_version"] == PYMILO_VERSION:
+                warn(UNEQUAL_PYMILO_VERSIONS, category=Warning)
             self.data = serialized_model_obj["data"]
             self.version = serialized_model_obj["sklearn_version"]
             self.type = serialized_model_obj["model_type"]
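
The version check added to `Import.__init__` treats a missing `pymilo_version` key as a hard error and a version mismatch as a warning only. A minimal standalone sketch of that behaviour follows. `load_checked` and `INSTALLED_VERSION` are hypothetical names, and the sketch swaps in `ValueError` and `UserWarning` where the diff uses `Exception` and `Warning`.

```python
import json
import warnings

# Assumption: stands in for PYMILO_VERSION from pymilo_param.py.
INSTALLED_VERSION = "0.5"


def load_checked(json_text, installed_version=INSTALLED_VERSION):
    """Parse a serialized model and check its version field."""
    obj = json.loads(json_text)
    if "pymilo_version" not in obj:
        # Missing key: the file is treated as corrupted, so fail hard.
        raise ValueError("`pymilo_version` is missing from the serialized model")
    if obj["pymilo_version"] != installed_version:
        # Mismatch: loading continues, but the caller is told about it.
        warnings.warn(
            "serialized with a different version than the installed one",
            category=UserWarning,
        )
    return obj
```

Using `!=` reads more directly than `not a == b`, and a specific warning category lets callers filter this warning without silencing every other one.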

pymilo/pymilo_param.py

Lines changed: 26 additions & 7 deletions
@@ -1,7 +1,20 @@
 # -*- coding: utf-8 -*-
 """Parameters and constants."""
+from sklearn.svm import SVR
+from sklearn.svm import SVC
+from sklearn.svm import OneClassSVM
+from sklearn.svm import NuSVR
+from sklearn.svm import NuSVC
+from sklearn.svm import LinearSVR
+from sklearn.svm import LinearSVC
+from sklearn.naive_bayes import CategoricalNB
+from sklearn.naive_bayes import BernoulliNB
+from sklearn.naive_bayes import ComplementNB
+from sklearn.naive_bayes import MultinomialNB
+from sklearn.naive_bayes import GaussianNB
 from sklearn.preprocessing import LabelBinarizer
 from numpy import uint8
+from numpy import intc
 from numpy import inf
 from numpy import float64
 from numpy import int32
@@ -77,7 +90,8 @@
 
 PYMILO_VERSION = "0.5"
 NOT_SUPPORTED = "NOT_SUPPORTED"
-
+PYMILO_VERSION_DOES_NOT_EXIST = "Corrupted JSON file, `pymilo_version` doesn't exist in this file."
+UNEQUAL_PYMILO_VERSIONS = "warning: Installed Pymilo version differes from pymilo version used to create the JSON file."
 
 glm_support = {
     'GammaRegressor': False,
@@ -110,12 +124,6 @@
 except BaseException:
     pass
 
-from sklearn.naive_bayes import GaussianNB
-from sklearn.naive_bayes import MultinomialNB
-from sklearn.naive_bayes import ComplementNB
-from sklearn.naive_bayes import BernoulliNB
-from sklearn.naive_bayes import CategoricalNB
-
 
 SKLEARN_LINEAR_MODEL_TABLE = {
     "LinearRegression": LinearRegression,
@@ -195,6 +203,15 @@
     "CategoricalNB": CategoricalNB,
 }
 
+SKLEARN_SVM_TABLE = {
+    "LinearSVC": LinearSVC,
+    "LinearSVR": LinearSVR,
+    "NuSVC": NuSVC,
+    "NuSVR": NuSVR,
+    "OneClassSVM": OneClassSVM,
+    "SVC": SVC,
+    "SVR": SVR,
+}
 
 KEYS_NEED_PREPROCESSING_BEFORE_DESERIALIZATION = {
     "_label_binarizer": LabelBinarizer,  # in Ridge Classifier
@@ -207,6 +224,7 @@
 }
 
 NUMPY_TYPE_DICT = {
+    "numpy.intc": intc,
     "numpy.int32": int32,
     "numpy.int64": int64,
     "numpy.float64": float64,
@@ -220,4 +238,5 @@
     "DECISION_TREE": "exported_decision_trees",
     "CLUSTERING": "exported_clusterings",
     "NAIVE_BAYES": "exported_naive_bayes",
+    "SVM": "exported_svms",
 }

pymilo/transporters/bisecting_tree_transporter.py

Lines changed: 2 additions & 3 deletions
@@ -4,7 +4,7 @@
 from ..pymilo_param import SKLEARN_CLUSTERING_TABLE, NOT_SUPPORTED
 from .transporter import AbstractTransporter
 from .general_data_structure_transporter import GeneralDataStructureTransporter
-from ..utils.util import is_iterable
+from ..utils.util import check_str_in_iterable
 
 bisecting_tree_support = SKLEARN_CLUSTERING_TABLE["BisectingKMeans"] != NOT_SUPPORTED
 if bisecting_tree_support:
@@ -114,7 +114,6 @@ def is_pymilo_serialized_bisecting_tree(psbt):
     :return: boolean
     """
     return (
-        is_iterable(psbt) and
-        "pymilo_model_type" in psbt and
+        check_str_in_iterable("pymilo_model_type", psbt) and
         psbt["pymilo_model_type"] == "_BisectingTree"
     )
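
The hunk above folds an "is it iterable, and does it contain this key" test into a single `check_str_in_iterable` call. The body of that helper is not part of this diff, so the following is only a guess at the kind of guard it provides. `check_str_in_iterable_sketch` is a hypothetical name, and rejecting plain strings is an assumption made here so that a substring match cannot slip through to the later `psbt["pymilo_model_type"]` lookup.

```python
def check_str_in_iterable_sketch(field, candidate):
    """Return True only if `candidate` is a non-string container holding `field`.

    Hypothetical illustration; not the implementation in pymilo/utils/util.py.
    """
    if isinstance(candidate, str):
        # `"a" in "abc"` is a substring test, which is not what we want here.
        return False
    try:
        return field in candidate
    except TypeError:
        # Non-iterable values such as ints or None do not support `in`.
        return False
```

With a guard of this shape the caller can chain `and candidate[field] == ...` without first checking the type of `candidate`.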
