5
5
# SPDX-License-Identifier: Apache-2.0
6
6
7
7
import copy
8
- from unittest import mock
9
8
10
9
import pytest
10
+ from six .moves import mock
11
11
from sasctl import current_session
12
12
from sasctl .services import model_repository as mr
13
13
@@ -22,56 +22,67 @@ def test_create_model():
22
22
with mock .patch ('sasctl.core.requests.Session.request' ):
23
23
current_session ('example.com' , USER , 'password' )
24
24
25
- TARGET = {'name' : MODEL_NAME ,
26
- 'projectId' : PROJECT_ID ,
27
- 'modeler' : USER ,
28
- 'description' : 'model description' ,
29
- 'function' : 'Classification' ,
30
- 'algorithm' : 'Dummy Algorithm' ,
31
- 'tool' : 'pytest' ,
32
- 'champion' : True ,
33
- 'role' : 'champion' ,
34
- 'immutable' : True ,
35
- 'retrainable' : True ,
36
- 'scoreCodeType' : None ,
37
- 'targetVariable' : None ,
38
- 'trainTable' : None ,
39
- 'classificationEventProbabilityVariableName' : None ,
40
- 'classificationTargetEventValue' : None ,
41
- 'location' : None ,
42
- 'properties' : [{'name' : 'custom1' , 'value' : 123 },
43
- {'name' : 'custom2' , 'value' : 'somevalue' }],
44
- 'inputVariables' : [],
45
- 'outputVariables' : [],
46
- 'version' : '2' }
25
+ TARGET = {
26
+ 'name' : MODEL_NAME ,
27
+ 'projectId' : PROJECT_ID ,
28
+ 'modeler' : USER ,
29
+ 'description' : 'model description' ,
30
+ 'function' : 'Classification' ,
31
+ 'algorithm' : 'Dummy Algorithm' ,
32
+ 'tool' : 'pytest' ,
33
+ 'champion' : True ,
34
+ 'role' : 'champion' ,
35
+ 'immutable' : True ,
36
+ 'retrainable' : True ,
37
+ 'scoreCodeType' : None ,
38
+ 'targetVariable' : None ,
39
+ 'trainTable' : None ,
40
+ 'classificationEventProbabilityVariableName' : None ,
41
+ 'classificationTargetEventValue' : None ,
42
+ 'location' : None ,
43
+ 'properties' : [
44
+ {'name' : 'custom1' , 'value' : 123 },
45
+ {'name' : 'custom2' , 'value' : 'somevalue' },
46
+ ],
47
+ 'inputVariables' : [],
48
+ 'outputVariables' : [],
49
+ 'version' : '2' ,
50
+ }
47
51
48
52
# Passed params should be set correctly
49
53
target = copy .deepcopy (TARGET )
50
- with mock .patch ('sasctl._services.model_repository.ModelRepository.get_project' ) as get_project :
51
- with mock .patch ('sasctl._services.model_repository.ModelRepository' '.get_model' ) as get_model :
52
- with mock .patch ('sasctl._services.model_repository.ModelRepository.post' ) as post :
54
+ with mock .patch (
55
+ 'sasctl._services.model_repository.ModelRepository.get_project'
56
+ ) as get_project :
57
+ with mock .patch (
58
+ 'sasctl._services.model_repository.ModelRepository' '.get_model'
59
+ ) as get_model :
60
+ with mock .patch (
61
+ 'sasctl._services.model_repository.ModelRepository.post'
62
+ ) as post :
53
63
get_project .return_value = {'id' : PROJECT_ID }
54
64
get_model .return_value = None
55
- _ = mr .create_model (MODEL_NAME ,
56
- PROJECT_NAME ,
57
- description = target ['description' ],
58
- function = target ['function' ],
59
- algorithm = target ['algorithm' ],
60
- tool = target ['tool' ],
61
- is_champion = True ,
62
- is_immutable = True ,
63
- is_retrainable = True ,
64
- properties = dict (custom1 = 123 , custom2 = 'somevalue' ))
65
+ _ = mr .create_model (
66
+ MODEL_NAME ,
67
+ PROJECT_NAME ,
68
+ description = target ['description' ],
69
+ function = target ['function' ],
70
+ algorithm = target ['algorithm' ],
71
+ tool = target ['tool' ],
72
+ is_champion = True ,
73
+ is_immutable = True ,
74
+ is_retrainable = True ,
75
+ properties = dict (custom1 = 123 , custom2 = 'somevalue' ),
76
+ )
65
77
assert post .call_count == 1
66
78
url , data = post .call_args
67
79
68
80
# dict isn't guaranteed to preserve order
69
81
# so k/v pairs of properties=dict() may be
70
82
# returned in a different order
71
- assert sorted (target ['properties' ],
72
- key = lambda d : d ['name' ]) \
73
- == sorted (data ['json' ]['properties' ],
74
- key = lambda d : d ['name' ])
83
+ assert sorted (target ['properties' ], key = lambda d : d ['name' ]) == sorted (
84
+ data ['json' ]['properties' ], key = lambda d : d ['name' ]
85
+ )
75
86
76
87
target .pop ('properties' )
77
88
data ['json' ].pop ('properties' )
@@ -80,12 +91,20 @@ def test_create_model():
80
91
# Model dict w/ parameters already specified should be allowed
81
92
# Explicit overrides should be respected.
82
93
target = copy .deepcopy (TARGET )
83
- with mock .patch ('sasctl._services.model_repository.ModelRepository.get_project' ) as get_project :
84
- with mock .patch ('sasctl._services.model_repository.ModelRepository' '.get_model' ) as get_model :
85
- with mock .patch ('sasctl._services.model_repository.ModelRepository.post' ) as post :
94
+ with mock .patch (
95
+ 'sasctl._services.model_repository.ModelRepository.get_project'
96
+ ) as get_project :
97
+ with mock .patch (
98
+ 'sasctl._services.model_repository.ModelRepository' '.get_model'
99
+ ) as get_model :
100
+ with mock .patch (
101
+ 'sasctl._services.model_repository.ModelRepository.post'
102
+ ) as post :
86
103
get_project .return_value = {'id' : PROJECT_ID }
87
104
get_model .return_value = None
88
- _ = mr .create_model (copy .deepcopy (target ), PROJECT_NAME , description = 'Updated Model' )
105
+ _ = mr .create_model (
106
+ copy .deepcopy (target ), PROJECT_NAME , description = 'Updated Model'
107
+ )
89
108
target ['description' ] = 'Updated Model'
90
109
assert post .call_count == 1
91
110
url , data = post .call_args
@@ -104,10 +123,12 @@ def test_copy_analytic_store():
104
123
105
124
MODEL_ID = 12345
106
125
# Intercept calls to lookup the model & call the "copyAnalyticStore" link
107
- with mock .patch ('sasctl._services.model_repository.ModelRepository'
108
- '.get_model' ) as get_model :
109
- with mock .patch ('sasctl._services.model_repository.ModelRepository'
110
- '.request_link' ) as request_link :
126
+ with mock .patch (
127
+ 'sasctl._services.model_repository.ModelRepository' '.get_model'
128
+ ) as get_model :
129
+ with mock .patch (
130
+ 'sasctl._services.model_repository.ModelRepository' '.request_link'
131
+ ) as request_link :
111
132
112
133
# Return a dummy Model with a static id
113
134
get_model .return_value = {'id' : MODEL_ID }
@@ -138,19 +159,17 @@ def test_get_model_by_name():
138
159
139
160
mock_responses = [
140
161
# First response is for list_items/list_models
141
- [
142
- {'id' : 12345 , 'name' : MODEL_NAME },
143
- {'id' : 67890 , 'name' : MODEL_NAME }
144
- ],
145
-
162
+ [{'id' : 12345 , 'name' : MODEL_NAME }, {'id' : 67890 , 'name' : MODEL_NAME }],
146
163
# Second response is mock GET for model details
147
- {'id' : 12345 , 'name' : MODEL_NAME }
164
+ {'id' : 12345 , 'name' : MODEL_NAME },
148
165
]
149
166
150
- with mock .patch ('sasctl._services.model_repository.ModelRepository.request' ) as request :
167
+ with mock .patch (
168
+ 'sasctl._services.model_repository.ModelRepository.request'
169
+ ) as request :
151
170
request .side_effect = mock_responses
152
171
153
172
with pytest .warns (Warning ):
154
173
result = mr .get_model (MODEL_NAME )
155
- assert result ['id' ]== 12345
156
- assert result ['name' ] == MODEL_NAME
174
+ assert result ['id' ] == 12345
175
+ assert result ['name' ] == MODEL_NAME
0 commit comments