Skip to content

Commit d72ef19

Browse files
committed
Fix missing h2o/viya scenario on model import
1 parent ff83a94 commit d72ef19

File tree

1 file changed

+76
-63
lines changed

1 file changed

+76
-63
lines changed

src/sasctl/pzmm/writeScoreCode.py

Lines changed: 76 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from ..core import platform_version
88
from .._services.model_repository import ModelRepository as modelRepo
99

10+
1011
# %%
1112
class ScoreCode():
12-
13+
1314
@classmethod
1415
def writeScoreCode(cls, inputDF, targetDF, modelPrefix,
1516
predictMethod, modelFileName,
@@ -18,26 +19,26 @@ def writeScoreCode(cls, inputDF, targetDF, modelPrefix,
1819
otherVariable=False, model=None, isH2OModel=False, missingValues=False,
1920
scoreCAS=True):
2021
'''
21-
Writes a Python score code file based on training data used to generate the model
22-
pickle file. The Python file is included in the ZIP file that is imported or registered
23-
into the common model repository. The model can then be used by SAS applications,
22+
Writes a Python score code file based on training data used to generate the model
23+
pickle file. The Python file is included in the ZIP file that is imported or registered
24+
into the common model repository. The model can then be used by SAS applications,
2425
such as SAS Open Model Manager.
25-
26+
2627
The score code that is generated is designed to be a working template for any
27-
Python model, but is not guaranteed to work out of the box for scoring, publishing,
28+
Python model, but is not guaranteed to work out of the box for scoring, publishing,
2829
or validating the model.
29-
30-
Note that for categorical variables, the variable is split into the possible
31-
categorical values of the variable. Also, by default it does NOT include a catch-all
32-
[catVar]_Other variable to store any missing values or any values not found in the
33-
training data set. If you have missing values or values not included in your training
30+
31+
Note that for categorical variables, the variable is split into the possible
32+
categorical values of the variable. Also, by default it does NOT include a catch-all
33+
[catVar]_Other variable to store any missing values or any values not found in the
34+
training data set. If you have missing values or values not included in your training
3435
data set, you must set the OtherVariable option to True.
35-
36+
3637
Both the inputDF and targetDF dataframes have the following stipulations:
3738
* Column names must be a valid Python variable name.
3839
* For categorical columns, the values must be a valid Python variable name.
3940
If either of these conditions is broken, an exception is raised.
40-
41+
4142
Parameters
4243
----------
4344
inputDF : DataFrame
@@ -48,11 +49,11 @@ def writeScoreCode(cls, inputDF, targetDF, modelPrefix,
4849
The `DataFrame` object contains the training data for the target variable.
4950
modelPrefix : string
5051
The variable for the model name that is used when naming model files.
51-
(For example: hmeqClassTree + [Score.py || .pickle]).
52+
(For example: hmeqClassTree + [Score.py || .pickle]).
5253
predictMethod : string
5354
User-defined prediction method for score testing. This should be
54-
in a form such that the model and data input can be added using
55-
the format() command.
55+
in a form such that the model and data input can be added using
56+
the format() command.
5657
For example: '{}.predict_proba({})'.
5758
modelFileName : string
5859
Name of the model file that contains the model.
@@ -92,10 +93,11 @@ def writeScoreCode(cls, inputDF, targetDF, modelPrefix,
9293
Python score code wrapped in DS2 and prepared for CAS scoring or publishing.
9394
'dmcas_packagescorecode.sas' (for SAS Viya 3.5 models)
9495
Python score code wrapped in DS2 and prepared for SAS Microanalytic Service scoring or publishing.
95-
'''
96+
'''
97+
9698
# Call REST API to check SAS Viya version
9799
isViya35 = (platform_version() == '3.5')
98-
100+
99101
# Initialize modelID to remove unbound variable warnings
100102
modelID = None
101103

@@ -118,76 +120,86 @@ def upload_and_copy_score_resources(model, files):
118120
else:
119121
model = modelRepo.get_model(model)
120122
modelID = model['id']
121-
123+
122124
# From the input dataframe columns, create a list of input variables, then check for viability
123125
inputVarList = list(inputDF.columns)
124126
for name in inputVarList:
125127
if not str(name).isidentifier():
126128
raise SyntaxError('Invalid column name in inputDF. Columns must be ' +
127129
'valid as Python variables.')
128130
newVarList = list(inputVarList)
129-
inputDtypesList = list(inputDF.dtypes)
130-
131+
inputDtypesList = list(inputDF.dtypes)
132+
131133
# Set the location for the Python score file to be written, then open the file
132134
zPath = Path(pyPath)
133135
pyPath = Path(pyPath) / (modelPrefix + 'Score.py')
134136
with open(pyPath, 'w') as cls.pyFile:
135-
137+
136138
# For H2O models, include the necessary packages
137139
if isH2OModel:
138140
cls.pyFile.write('''\
139141
import h2o
140142
import gzip, shutil, os''')
141-
# Import math for imputation; pickle for serialized models; pandas for data management; numpy for computation
143+
# Import math for imputation; pickle for serialized models; pandas for data management; numpy for
144+
# computation
142145
cls.pyFile.write('''\n
143146
import math
144147
import pickle
145148
import pandas as pd
146149
import numpy as np''')
147-
# In SAS Viya 4.0 and SAS Open Model Manager, a settings.py file is generated that points to the resource location
150+
# In SAS Viya 4.0 and SAS Open Model Manager, a settings.py file is generated that points to the resource
151+
# location
148152
if not isViya35:
149153
cls.pyFile.write('''\n
150154
import settings''')
151-
155+
152156
# Use a global variable for the model in order to load from memory only once
153157
cls.pyFile.write('''\n\n
154158
global _thisModelFit''')
155-
159+
156160
# For H2O models, include the server initialization, or h2o.connect() call to use an H2O server
157161
if isH2OModel:
158162
cls.pyFile.write('''\n
159163
h2o.init()''')
160164

161165
# For each case of SAS Viya version and H2O model or not, load the model file as variable _thisModelFit
162-
if isViya35 and not isH2OModel:
166+
if isViya35 and isH2OModel:
163167
cls.pyFile.write('''\n
164-
with gzip.open('/models/resources/viya/{modelID}/{modelFileName}', 'r') as fileIn, open('/models/resources/viya/{modelID}/{modelZipFileName}', 'wb') as fileOut:
168+
with gzip.open('/models/resources/viya/{modelID}/{modelFileName}', 'r') as fileIn, open('/models/resources/viya/{
169+
modelID}/{modelZipFileName}', 'wb') as fileOut:
165170
shutil.copyfileobj(fileIn, fileOut)
166171
os.chmod('/models/resources/viya/{modelID}/{modelZipFileName}', 0o777)
167172
_thisModelFit = h2o.import_mojo('/models/resources/viya/{modelID}/{modelZipFileName}')'''.format(
168173
modelID=modelID,
169174
modelFileName=modelFileName,
170175
modelZipFileName=modelFileName[:-4] + 'zip'
171176
))
177+
elif isViya35 and not isH2OModel:
178+
cls.pyFile.write('''\n
179+
with open('/models/resources/viya/{modelID}/{modelFileName}', 'rb') as _pFile:
180+
_thisModelFit = pickle.load(_pFile)'''.format(modelID=modelID, modelFileName=modelFileName))
172181
elif not isViya35 and not isH2OModel:
173182
cls.pyFile.write('''\n
174183
with open(settings.pickle_path + '{modelFileName}', 'rb') as _pFile:
175184
_thisModelFit = pickle.load(_pFile)'''.format(modelFileName=modelFileName))
176185
elif not isViya35 and isH2OModel:
177186
cls.pyFile.write('''\n
178-
with gzip.open(settings.pickle_path + '{modelFileName}', 'r') as fileIn, open(settings.pickle_path + '{modelZipFileName}', 'wb') as fileOut:
187+
with gzip.open(settings.pickle_path + '{modelFileName}', 'r') as fileIn, open(settings.pickle_path + '{
188+
modelZipFileName}', 'wb') as fileOut:
179189
shutil.copyfileobj(fileIn, fileOut)
180190
os.chmod(settings.pickle_path + '{modelZipFileName}', 0o777)
181191
_thisModelFit = h2o.import_mojo(settings.pickle_path + '{modelZipFileName}')'''.format(modelFileName=modelFileName,
182-
modelZipFileName=modelFileName[:-4] + 'zip'
183-
))
184-
# Create the score function with variables from the input dataframe provided and create the output variable line for SAS Model Manager
192+
modelZipFileName=modelFileName[
193+
:-4] + 'zip'))
194+
# Create the score function with variables from the input dataframe provided and create the output
195+
# variable line for SAS Model Manager
185196
cls.pyFile.write('''\n
186197
def score{modelPrefix}({inputVarList}):
187198
"Output: {metrics}"'''.format(modelPrefix=modelPrefix,
188-
inputVarList=', '.join(inputVarList),
189-
metrics=', '.join(metrics)))
190-
# As a check for missing model variables, run a try/except block that reattempts to load the model in as a variable
199+
inputVarList=', '.join(inputVarList),
200+
metrics=', '.join(metrics)))
201+
# As a check for missing model variables, run a try/except block that reattempts to load the model in as
202+
# a variable
191203
cls.pyFile.write('''\n
192204
try:
193205
_thisModelFit
@@ -209,7 +221,7 @@ def score{modelPrefix}({inputVarList}):
209221
elif not isViya35 and isH2OModel:
210222
cls.pyFile.write('''
211223
_thisModelFit = h2o.import_mojo(settings.pickle_path + '{}')'''.format(modelFileName[:-4] + 'zip'))
212-
224+
213225
if missingValues:
214226
# For each input variable, impute for missing values based on variable dtype
215227
for i, dTypes in enumerate(inputDtypesList):
@@ -222,7 +234,7 @@ def score{modelPrefix}({inputVarList}):
222234
{inputVar} = {inputVarMode}
223235
except TypeError:
224236
{inputVar} = {inputVarMode}'''.format(inputVar=inputVarList[i],
225-
inputVarMode=float(list(inputDF[inputVarList[i]].mode())[0])))
237+
inputVarMode=float(list(inputDF[inputVarList[i]].mode())[0])))
226238
else:
227239
cls.pyFile.write('''\n
228240
try:
@@ -239,10 +251,10 @@ def score{modelPrefix}({inputVarList}):
239251
categoryStr = 'Other'\n'''.format(inputVar=inputVarList[i]))
240252

241253
tempVar = cls.splitStringColumn(inputDF[inputVarList[i]],
242-
otherVariable)
254+
otherVariable)
243255
newVarList.remove(inputVarList[i])
244256
newVarList.extend(tempVar)
245-
257+
246258
# For non-H2O models, insert the model into the provided predictMethod call
247259
if not isH2OModel:
248260
predictMethod = predictMethod.format('_thisModelFit', 'inputArray')
@@ -301,10 +313,10 @@ def score{modelPrefix}({inputVarList}):
301313
cls.pyFile.write('''\n
302314
{} = float(prediction[1][2])
303315
{} = prediction[1][0]'''.format(metrics[0], metrics[1]))
304-
316+
305317
cls.pyFile.write('''\n
306318
return({}, {})'''.format(metrics[0], metrics[1]))
307-
319+
308320
# For SAS Viya 3.5, the model is first registered to SAS Model Manager, then the model UUID can be
309321
# added to the score code and reuploaded to the model file contents
310322
if isViya35:
@@ -335,34 +347,34 @@ def score{modelPrefix}({inputVarList}):
335347
model = modelRepo.get_model(modelID)
336348
model['scoreCodeType'] = 'ds2MultiType'
337349
modelRepo.update_model(model)
338-
350+
339351
def splitStringColumn(cls, inputSeries, otherVariable):
340352
'''
341353
Splits a column of string values into a number of new variables equal
342354
to the number of unique values in the original column (excluding None
343355
values). It then writes to a file the statements that tokenize the newly
344356
defined variables.
345-
357+
346358
Here is an example: Given a series named strCol with values ['A', 'B', 'C',
347359
None, 'A', 'B', 'A', 'D'], designates the following new variables:
348360
strCol_A, strCol_B, strCol_D. It then writes the following to the file:
349361
strCol_A = np.where(val == 'A', 1.0, 0.0)
350362
strCol_B = np.where(val == 'B', 1.0, 0.0)
351363
strCol_D = np.where(val == 'D', 1.0, 0.0)
352-
364+
353365
Parameters
354366
---------------
355367
inputSeries : string series
356368
Series with the string dtype.
357369
cls.pyFile : file (class variable)
358370
Open python file to write into.
359-
371+
360372
Returns
361373
---------------
362374
newVarList : string list
363375
List of all new variable names split from unique values.
364376
'''
365-
377+
366378
uniqueValues = inputSeries.unique()
367379
uniqueValues = list(filter(None, uniqueValues))
368380
uniqueValues = [x for x in uniqueValues if str(x) != 'nan']
@@ -375,50 +387,50 @@ def splitStringColumn(cls, inputSeries, otherVariable):
375387
newVarList.append('{}_{}'.format(inputSeries.name, uniq))
376388
cls.pyFile.write('''
377389
{0} = np.where(categoryStr == '{1}', 1.0, 0.0)'''.format(newVarList[i], uniq))
378-
390+
379391
if ('Other' not in uniqueValues) and otherVariable:
380392
newVarList.append('{}_Other'.format(inputSeries.name))
381393
cls.pyFile.write('''
382394
{}_Other = np.where(categoryStr == 'Other', 1.0, 0.0)'''.format(inputSeries.name))
383-
395+
384396
return newVarList
385-
397+
386398
def checkIfBinary(inputSeries):
    '''
    Checks a pandas series to determine whether the values are binary or nominal.

    A series is treated as binary when its distinct non-missing values, cast to
    float, are exactly 0.0 and 1.0. Any other series is treated as nominal.

    Parameters
    ---------------
    inputSeries : float or int series
        A series with numeric values.

    Returns
    ---------------
    isBinary : boolean
        The returned value is True if the series values are binary, and False if the series values
        are nominal.
    '''

    # value_counts() drops missing values by default, so NaN entries do not
    # count toward the "exactly two distinct values" requirement.
    if inputSeries.value_counts().size != 2:
        return False

    # Compare against the VALUES of the series. The `in` operator applied to a
    # pandas Series tests the index labels instead, which made the previous
    # check depend on the index rather than on the data.
    distinctValues = set(inputSeries.dropna().astype('float').unique())
    isBinary = distinctValues == {float(0), float(1)}

    return isBinary
413-
425+
414426
def convertMAStoCAS(MASCode, modelId):
415-
'''Using the generated score.sas code from the Python wrapper API,
427+
'''Using the generated score.sas code from the Python wrapper API,
416428
convert the SAS Microanalytic Service based code to CAS compatible.
417429
418430
Parameters
419431
----------
420432
MASCode : str
421-
String representation of the packagescore.sas DS2 wrapper
433+
String representation of the packagescore.sas DS2 wrapper
422434
modelId : str or dict
423435
The name or id of the model, or a dictionary representation of
424436
the model
@@ -436,16 +448,17 @@ def convertMAStoCAS(MASCode, modelId):
436448
outputString = outputString + 'varchar(100) '
437449
else:
438450
outputString = outputString + 'double '
439-
outputString = outputString + outVar['name'] + ';\n'
451+
outputString = outputString + outVar['name'] + ';\n'
440452
start = MASCode.find('score(')
441453
finish = MASCode[start:].find(');')
442-
scoreVars = MASCode[start+6:start+finish]
443-
inputString = ' '.join([x for x in scoreVars.split(' ') if (x != 'double' and x != 'in_out' and x != 'varchar(100)')])
454+
scoreVars = MASCode[start + 6:start + finish]
455+
inputString = ' '.join(
456+
[x for x in scoreVars.split(' ') if (x != 'double' and x != 'in_out' and x != 'varchar(100)')])
444457
endBlock = 'method run();\n set SASEP.IN;\n score({});\nend;\nenddata;'.format(inputString)
445458
replaceStrings = {'package pythonScore / overwrite=yes;': 'data sasep.out;',
446459
'dcl int resultCode revision;': 'dcl double resultCode revision;\n' + outputString,
447460
'endpackage;': endBlock}
448461
replaceStrings = dict((re.escape(k), v) for k, v in replaceStrings.items())
449462
pattern = re.compile('|'.join(replaceStrings.keys()))
450463
casCode = pattern.sub(lambda m: replaceStrings[re.escape(m.group(0))], MASCode)
451-
return casCode
464+
return casCode

0 commit comments

Comments
 (0)