Skip to content

Commit 947f8ff

Browse files
committed
linting
1 parent 2b64250 commit 947f8ff

File tree

1 file changed

+47
-43
lines changed

1 file changed

+47
-43
lines changed

template/quickstart.ipynb

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
" # Pull required modules from this example\n",
7272
" !git clone -b main https://github.com/zenml-io/zenml\n",
7373
" !cp -r zenml/examples/quickstart/* .\n",
74-
" !rm -rf zenml\n"
74+
" !rm -rf zenml"
7575
]
7676
},
7777
{
@@ -84,6 +84,7 @@
8484
"!zenml integration install sklearn -y\n",
8585
"\n",
8686
"import IPython\n",
87+
"\n",
8788
"IPython.Application.instance().kernel.do_shutdown(restart=True)"
8889
]
8990
},
@@ -145,28 +146,22 @@
145146
"outputs": [],
146147
"source": [
147148
"# Do the imports at the top\n",
148-
"from typing_extensions import Annotated\n",
149-
"from sklearn.datasets import load_breast_cancer\n",
150-
"\n",
151149
"import random\n",
152-
"import pandas as pd\n",
153-
"from zenml import step, pipeline, Model, get_step_context\n",
154-
"from zenml.client import Client\n",
155-
"from zenml.logger import get_logger\n",
150+
"from typing import List, Optional\n",
156151
"from uuid import UUID\n",
157152
"\n",
158-
"from typing import Optional, List\n",
159-
"\n",
160-
"from zenml import pipeline\n",
161-
"\n",
153+
"import pandas as pd\n",
154+
"from sklearn.datasets import load_breast_cancer\n",
162155
"from steps import (\n",
163156
" data_loader,\n",
164157
" data_preprocessor,\n",
165158
" data_splitter,\n",
159+
" inference_preprocessor,\n",
166160
" model_evaluator,\n",
167-
" inference_preprocessor\n",
168161
")\n",
169-
"\n",
162+
"from typing_extensions import Annotated\n",
163+
"from zenml import Model, get_step_context, pipeline, step\n",
164+
"from zenml.client import Client\n",
170165
"from zenml.logger import get_logger\n",
171166
"\n",
172167
"logger = get_logger(__name__)\n",
@@ -205,7 +200,7 @@
205200
"@step\n",
206201
"def data_loader_simplified(\n",
207202
" random_state: int, is_inference: bool = False, target: str = \"target\"\n",
208-
") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset \n",
203+
") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset\n",
209204
" \"\"\"Dataset reader step.\"\"\"\n",
210205
" dataset = load_breast_cancer(as_frame=True)\n",
211206
" inference_size = int(len(dataset.target) * 0.05)\n",
@@ -218,7 +213,7 @@
218213
" dataset.drop(inference_subset.index, inplace=True)\n",
219214
" dataset.reset_index(drop=True, inplace=True)\n",
220215
" logger.info(f\"Dataset with {len(dataset)} records loaded!\")\n",
221-
" return dataset\n"
216+
" return dataset"
222217
]
223218
},
224219
{
@@ -291,7 +286,7 @@
291286
" normalize: Optional[bool] = None,\n",
292287
" drop_columns: Optional[List[str]] = None,\n",
293288
" target: Optional[str] = \"target\",\n",
294-
" random_state: int = 17\n",
289+
" random_state: int = 17,\n",
295290
"):\n",
296291
" \"\"\"Feature engineering pipeline.\"\"\"\n",
297292
" # Link all the steps together by calling them and passing the output\n",
@@ -402,7 +397,6 @@
402397
"from zenml.environment import Environment\n",
403398
"from zenml.zen_stores.rest_zen_store import RestZenStore\n",
404399
"\n",
405-
"\n",
406400
"if not isinstance(client.zen_store, RestZenStore):\n",
407401
" # Only spin up a local Dashboard in case you aren't already connected to a remote server\n",
408402
" if Environment.in_google_colab():\n",
@@ -479,7 +473,9 @@
479473
"outputs": [],
480474
"source": [
481475
"# Get artifact version from our run\n",
482-
"dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\"dataset_trn\"] \n",
476+
"dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\n",
477+
" \"dataset_trn\"\n",
478+
"]\n",
483479
"\n",
484480
"# Get latest version from client directly\n",
485481
"dataset_trn_artifact_version = client.get_artifact_version(\"dataset_trn\")\n",
@@ -498,7 +494,9 @@
498494
"source": [
499495
"# Fetch the rest of the artifacts\n",
500496
"dataset_tst_artifact_version = client.get_artifact_version(\"dataset_tst\")\n",
501-
"preprocessing_pipeline_artifact_version = client.get_artifact_version(\"preprocess_pipeline\")"
497+
"preprocessing_pipeline_artifact_version = client.get_artifact_version(\n",
498+
" \"preprocess_pipeline\"\n",
499+
")"
502500
]
503501
},
504502
{
@@ -576,23 +574,25 @@
576574
"def model_trainer(\n",
577575
" dataset_trn: pd.DataFrame,\n",
578576
" model_type: str = \"sgd\",\n",
579-
") -> Annotated[ClassifierMixin, ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True)]:\n",
577+
") -> Annotated[\n",
578+
" ClassifierMixin, ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True)\n",
579+
"]:\n",
580580
" \"\"\"Configure and train a model on the training dataset.\"\"\"\n",
581581
" target = \"target\"\n",
582582
" if model_type == \"sgd\":\n",
583583
" model = SGDClassifier()\n",
584584
" elif model_type == \"rf\":\n",
585585
" model = RandomForestClassifier()\n",
586586
" else:\n",
587-
" raise ValueError(f\"Unknown model type {model_type}\") \n",
587+
" raise ValueError(f\"Unknown model type {model_type}\")\n",
588588
"\n",
589589
" logger.info(f\"Training model {model}...\")\n",
590590
"\n",
591591
" model.fit(\n",
592592
" dataset_trn.drop(columns=[target]),\n",
593593
" dataset_trn[target],\n",
594594
" )\n",
595-
" return model\n"
595+
" return model"
596596
]
597597
},
598598
{
@@ -630,14 +630,14 @@
630630
" min_train_accuracy: float = 0.0,\n",
631631
" min_test_accuracy: float = 0.0,\n",
632632
"):\n",
633-
" \"\"\"Model training pipeline.\"\"\" \n",
633+
" \"\"\"Model training pipeline.\"\"\"\n",
634634
" if train_dataset_id is None or test_dataset_id is None:\n",
635-
" # If we dont pass the IDs, this will run the feature engineering pipeline \n",
635+
" # If we dont pass the IDs, this will run the feature engineering pipeline\n",
636636
" dataset_trn, dataset_tst = feature_engineering()\n",
637637
" else:\n",
638638
" # Load the datasets from an older pipeline\n",
639639
" dataset_trn = client.get_artifact_version(name_id_or_prefix=train_dataset_id)\n",
640-
" dataset_tst = client.get_artifact_version(name_id_or_prefix=test_dataset_id) \n",
640+
" dataset_tst = client.get_artifact_version(name_id_or_prefix=test_dataset_id)\n",
641641
"\n",
642642
" trained_model = model_trainer(\n",
643643
" dataset_trn=dataset_trn,\n",
@@ -676,7 +676,7 @@
676676
"training(\n",
677677
" model_type=\"rf\",\n",
678678
" train_dataset_id=dataset_trn_artifact_version.id,\n",
679-
" test_dataset_id=dataset_tst_artifact_version.id\n",
679+
" test_dataset_id=dataset_tst_artifact_version.id,\n",
680680
")\n",
681681
"\n",
682682
"rf_run = client.get_pipeline(\"training\").last_run"
@@ -693,7 +693,7 @@
693693
"sgd_run = training(\n",
694694
" model_type=\"sgd\",\n",
695695
" train_dataset_id=dataset_trn_artifact_version.id,\n",
696-
" test_dataset_id=dataset_tst_artifact_version.id\n",
696+
" test_dataset_id=dataset_tst_artifact_version.id,\n",
697697
")\n",
698698
"\n",
699699
"sgd_run = client.get_pipeline(\"training\").last_run"
@@ -717,7 +717,9 @@
717717
"outputs": [],
718718
"source": [
719719
"# The evaluator returns a float value with the accuracy\n",
720-
"rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\"model_evaluator\"].output.load()"
720+
"rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\n",
721+
" \"model_evaluator\"\n",
722+
"].output.load()"
721723
]
722724
},
723725
{
@@ -776,7 +778,7 @@
776778
"training_configured(\n",
777779
" model_type=\"sgd\",\n",
778780
" train_dataset_id=dataset_trn_artifact_version.id,\n",
779-
" test_dataset_id=dataset_tst_artifact_version.id\n",
781+
" test_dataset_id=dataset_tst_artifact_version.id,\n",
780782
")"
781783
]
782784
},
@@ -798,7 +800,7 @@
798800
"training_configured(\n",
799801
" model_type=\"rf\",\n",
800802
" train_dataset_id=dataset_trn_artifact_version.id,\n",
801-
" test_dataset_id=dataset_tst_artifact_version.id\n",
803+
" test_dataset_id=dataset_tst_artifact_version.id,\n",
802804
")"
803805
]
804806
},
@@ -848,7 +850,9 @@
848850
"rf_zenml_model_version = client.get_model_version(\"breast_cancer_classifier\", \"rf\")\n",
849851
"\n",
850852
"# We can now load our classifier directly as well\n",
851-
"random_forest_classifier = rf_zenml_model_version.get_artifact(\"sklearn_classifier\").load()\n",
853+
"random_forest_classifier = rf_zenml_model_version.get_artifact(\n",
854+
" \"sklearn_classifier\"\n",
855+
").load()\n",
852856
"\n",
853857
"random_forest_classifier"
854858
]
@@ -956,7 +960,7 @@
956960
"\n",
957961
" predictions = pd.Series(predictions, name=\"predicted\")\n",
958962
"\n",
959-
" return predictions\n"
963+
" return predictions"
960964
]
961965
},
962966
{
@@ -983,18 +987,18 @@
983987
" random_state = 42\n",
984988
" target = \"target\"\n",
985989
"\n",
986-
" df_inference = data_loader(\n",
987-
" random_state=random_state, is_inference=True\n",
988-
" )\n",
990+
" df_inference = data_loader(random_state=random_state, is_inference=True)\n",
989991
" df_inference = inference_preprocessor(\n",
990992
" dataset_inf=df_inference,\n",
991993
" # We use the preprocess pipeline from the feature engineering pipeline\n",
992-
" preprocess_pipeline=client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id),\n",
994+
" preprocess_pipeline=client.get_artifact_version(\n",
995+
" name_id_or_prefix=preprocess_pipeline_id\n",
996+
" ),\n",
993997
" target=target,\n",
994998
" )\n",
995999
" inference_predict(\n",
9961000
" dataset_inf=df_inference,\n",
997-
" )\n"
1001+
" )"
9981002
]
9991003
},
10001004
{
@@ -1018,7 +1022,7 @@
10181022
"# Lets add some metadata to the model to make it identifiable\n",
10191023
"pipeline_settings[\"model\"] = Model(\n",
10201024
" name=\"breast_cancer_classifier\",\n",
1021-
" version=\"production\", # We can pass in the stage name here!\n",
1025+
" version=\"production\", # We can pass in the stage name here!\n",
10221026
" license=\"Apache 2.0\",\n",
10231027
" description=\"A breast cancer classifier\",\n",
10241028
" tags=[\"breast_cancer\", \"classifier\"],\n",
@@ -1039,9 +1043,7 @@
10391043
"# Let's run it again to make sure we have two versions\n",
10401044
"# We need to pass in the ID of the preprocessing done in the feature engineering pipeline\n",
10411045
"# in order to avoid training-serving skew\n",
1042-
"inference_configured(\n",
1043-
" preprocess_pipeline_id=preprocessing_pipeline_artifact_version.id\n",
1044-
")"
1046+
"inference_configured(preprocess_pipeline_id=preprocessing_pipeline_artifact_version.id)"
10451047
]
10461048
},
10471049
{
@@ -1061,7 +1063,9 @@
10611063
"outputs": [],
10621064
"source": [
10631065
"# Fetch production model\n",
1064-
"production_model_version = client.get_model_version(\"breast_cancer_classifier\", \"production\")\n",
1066+
"production_model_version = client.get_model_version(\n",
1067+
" \"breast_cancer_classifier\", \"production\"\n",
1068+
")\n",
10651069
"\n",
10661070
"# Get the predictions artifact\n",
10671071
"production_model_version.get_artifact(\"predictions\").load()"

0 commit comments

Comments
 (0)