|
71 | 71 | " # Pull required modules from this example\n",
|
72 | 72 | " !git clone -b main https://github.com/zenml-io/zenml\n",
|
73 | 73 | " !cp -r zenml/examples/quickstart/* .\n",
|
74 | | - " !rm -rf zenml\n" |
| 74 | + " !rm -rf zenml" |
75 | 75 | ]
|
76 | 76 | },
|
77 | 77 | {
|
|
84 | 84 | "!zenml integration install sklearn -y\n",
|
85 | 85 | "\n",
|
86 | 86 | "import IPython\n",
|
| 87 | + "\n", |
87 | 88 | "IPython.Application.instance().kernel.do_shutdown(restart=True)"
|
88 | 89 | ]
|
89 | 90 | },
|
|
145 | 146 | "outputs": [],
|
146 | 147 | "source": [
|
147 | 148 | "# Do the imports at the top\n",
|
148 | | - "from typing_extensions import Annotated\n", |
149 | | - "from sklearn.datasets import load_breast_cancer\n", |
150 | | - "\n", |
151 | 149 | "import random\n",
|
152 | | - "import pandas as pd\n", |
153 | | - "from zenml import step, pipeline, Model, get_step_context\n", |
154 | | - "from zenml.client import Client\n", |
155 | | - "from zenml.logger import get_logger\n", |
| 150 | + "from typing import List, Optional\n", |
156 | 151 | "from uuid import UUID\n",
|
157 | 152 | "\n",
|
158 | | - "from typing import Optional, List\n", |
159 | | - "\n", |
160 | | - "from zenml import pipeline\n", |
161 | | - "\n", |
| 153 | + "import pandas as pd\n", |
| 154 | + "from sklearn.datasets import load_breast_cancer\n", |
162 | 155 | "from steps import (\n",
|
163 | 156 | " data_loader,\n",
|
164 | 157 | " data_preprocessor,\n",
|
165 | 158 | " data_splitter,\n",
|
| 159 | + " inference_preprocessor,\n", |
166 | 160 | " model_evaluator,\n",
|
167 | | - "            inference_preprocessor\n", |
168 | 161 | ")\n",
|
169 | | - "\n", |
| 162 | + "from typing_extensions import Annotated\n", |
| 163 | + "from zenml import Model, get_step_context, pipeline, step\n", |
| 164 | + "from zenml.client import Client\n", |
170 | 165 | "from zenml.logger import get_logger\n",
|
171 | 166 | "\n",
|
172 | 167 | "logger = get_logger(__name__)\n",
|
|
205 | 200 | "@step\n",
|
206 | 201 | "def data_loader_simplified(\n",
|
207 | 202 | " random_state: int, is_inference: bool = False, target: str = \"target\"\n",
|
208 | | - ") -> Annotated[pd.DataFrame, \"dataset\"]:  # We name the dataset \n", |
| 203 | + ") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset\n", |
209 | 204 | " \"\"\"Dataset reader step.\"\"\"\n",
|
210 | 205 | " dataset = load_breast_cancer(as_frame=True)\n",
|
211 | 206 | " inference_size = int(len(dataset.target) * 0.05)\n",
|
|
218 | 213 | " dataset.drop(inference_subset.index, inplace=True)\n",
|
219 | 214 | " dataset.reset_index(drop=True, inplace=True)\n",
|
220 | 215 | " logger.info(f\"Dataset with {len(dataset)} records loaded!\")\n",
|
221 | | - "    return dataset\n" |
| 216 | + " return dataset" |
222 | 217 | ]
|
223 | 218 | },
|
224 | 219 | {
|
|
291 | 286 | " normalize: Optional[bool] = None,\n",
|
292 | 287 | " drop_columns: Optional[List[str]] = None,\n",
|
293 | 288 | " target: Optional[str] = \"target\",\n",
|
294 | | - "    random_state: int = 17\n", |
| 289 | + " random_state: int = 17,\n", |
295 | 290 | "):\n",
|
296 | 291 | " \"\"\"Feature engineering pipeline.\"\"\"\n",
|
297 | 292 | " # Link all the steps together by calling them and passing the output\n",
|
|
402 | 397 | "from zenml.environment import Environment\n",
|
403 | 398 | "from zenml.zen_stores.rest_zen_store import RestZenStore\n",
|
404 | 399 | "\n",
|
405 | | - "\n", |
406 | 400 | "if not isinstance(client.zen_store, RestZenStore):\n",
|
407 | 401 | " # Only spin up a local Dashboard in case you aren't already connected to a remote server\n",
|
408 | 402 | " if Environment.in_google_colab():\n",
|
|
479 | 473 | "outputs": [],
|
480 | 474 | "source": [
|
481 | 475 | "# Get artifact version from our run\n",
|
482 | | - "dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\"dataset_trn\"] \n", |
| 476 | + "dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\n", |
| 477 | + " \"dataset_trn\"\n", |
| 478 | + "]\n", |
483 | 479 | "\n",
|
484 | 480 | "# Get latest version from client directly\n",
|
485 | 481 | "dataset_trn_artifact_version = client.get_artifact_version(\"dataset_trn\")\n",
|
|
498 | 494 | "source": [
|
499 | 495 | "# Fetch the rest of the artifacts\n",
|
500 | 496 | "dataset_tst_artifact_version = client.get_artifact_version(\"dataset_tst\")\n",
|
501 | | - "preprocessing_pipeline_artifact_version = client.get_artifact_version(\"preprocess_pipeline\")" |
| 497 | + "preprocessing_pipeline_artifact_version = client.get_artifact_version(\n", |
| 498 | + " \"preprocess_pipeline\"\n", |
| 499 | + ")" |
502 | 500 | ]
|
503 | 501 | },
|
504 | 502 | {
|
|
576 | 574 | "def model_trainer(\n",
|
577 | 575 | " dataset_trn: pd.DataFrame,\n",
|
578 | 576 | " model_type: str = \"sgd\",\n",
|
579 | | - ") -> Annotated[ClassifierMixin, ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True)]:\n", |
| 577 | + ") -> Annotated[\n", |
| 578 | + " ClassifierMixin, ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True)\n", |
| 579 | + "]:\n", |
580 | 580 | " \"\"\"Configure and train a model on the training dataset.\"\"\"\n",
|
581 | 581 | " target = \"target\"\n",
|
582 | 582 | " if model_type == \"sgd\":\n",
|
583 | 583 | " model = SGDClassifier()\n",
|
584 | 584 | " elif model_type == \"rf\":\n",
|
585 | 585 | " model = RandomForestClassifier()\n",
|
586 | 586 | " else:\n",
|
587 | | - "        raise ValueError(f\"Unknown model type {model_type}\") \n", |
| 587 | + " raise ValueError(f\"Unknown model type {model_type}\")\n", |
588 | 588 | "\n",
|
589 | 589 | " logger.info(f\"Training model {model}...\")\n",
|
590 | 590 | "\n",
|
591 | 591 | " model.fit(\n",
|
592 | 592 | " dataset_trn.drop(columns=[target]),\n",
|
593 | 593 | " dataset_trn[target],\n",
|
594 | 594 | " )\n",
|
595 | | - "    return model\n" |
| 595 | + " return model" |
596 | 596 | ]
|
597 | 597 | },
|
598 | 598 | {
|
|
630 | 630 | " min_train_accuracy: float = 0.0,\n",
|
631 | 631 | " min_test_accuracy: float = 0.0,\n",
|
632 | 632 | "):\n",
|
633 | | - "    \"\"\"Model training pipeline.\"\"\" \n", |
| 633 | + " \"\"\"Model training pipeline.\"\"\"\n", |
634 | 634 | " if train_dataset_id is None or test_dataset_id is None:\n",
|
635 | | - "        # If we dont pass the IDs, this will run the feature engineering pipeline \n", |
| 635 | + " # If we dont pass the IDs, this will run the feature engineering pipeline\n", |
636 | 636 | " dataset_trn, dataset_tst = feature_engineering()\n",
|
637 | 637 | " else:\n",
|
638 | 638 | " # Load the datasets from an older pipeline\n",
|
639 | 639 | " dataset_trn = client.get_artifact_version(name_id_or_prefix=train_dataset_id)\n",
|
640 | | - "        dataset_tst = client.get_artifact_version(name_id_or_prefix=test_dataset_id) \n", |
| 640 | + " dataset_tst = client.get_artifact_version(name_id_or_prefix=test_dataset_id)\n", |
641 | 641 | "\n",
|
642 | 642 | " trained_model = model_trainer(\n",
|
643 | 643 | " dataset_trn=dataset_trn,\n",
|
|
676 | 676 | "training(\n",
|
677 | 677 | " model_type=\"rf\",\n",
|
678 | 678 | " train_dataset_id=dataset_trn_artifact_version.id,\n",
|
679 | | - "    test_dataset_id=dataset_tst_artifact_version.id\n", |
| 679 | + " test_dataset_id=dataset_tst_artifact_version.id,\n", |
680 | 680 | ")\n",
|
681 | 681 | "\n",
|
682 | 682 | "rf_run = client.get_pipeline(\"training\").last_run"
|
|
693 | 693 | "sgd_run = training(\n",
|
694 | 694 | " model_type=\"sgd\",\n",
|
695 | 695 | " train_dataset_id=dataset_trn_artifact_version.id,\n",
|
696 | | - "    test_dataset_id=dataset_tst_artifact_version.id\n", |
| 696 | + " test_dataset_id=dataset_tst_artifact_version.id,\n", |
697 | 697 | ")\n",
|
698 | 698 | "\n",
|
699 | 699 | "sgd_run = client.get_pipeline(\"training\").last_run"
|
|
717 | 717 | "outputs": [],
|
718 | 718 | "source": [
|
719 | 719 | "# The evaluator returns a float value with the accuracy\n",
|
720 | | - "rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\"model_evaluator\"].output.load()" |
| 720 | + "rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\n", |
| 721 | + " \"model_evaluator\"\n", |
| 722 | + "].output.load()" |
721 | 723 | ]
|
722 | 724 | },
|
723 | 725 | {
|
|
776 | 778 | "training_configured(\n",
|
777 | 779 | " model_type=\"sgd\",\n",
|
778 | 780 | " train_dataset_id=dataset_trn_artifact_version.id,\n",
|
779 | | - "    test_dataset_id=dataset_tst_artifact_version.id\n", |
| 781 | + " test_dataset_id=dataset_tst_artifact_version.id,\n", |
780 | 782 | ")"
|
781 | 783 | ]
|
782 | 784 | },
|
|
798 | 800 | "training_configured(\n",
|
799 | 801 | " model_type=\"rf\",\n",
|
800 | 802 | " train_dataset_id=dataset_trn_artifact_version.id,\n",
|
801 | | - "    test_dataset_id=dataset_tst_artifact_version.id\n", |
| 803 | + " test_dataset_id=dataset_tst_artifact_version.id,\n", |
802 | 804 | ")"
|
803 | 805 | ]
|
804 | 806 | },
|
|
848 | 850 | "rf_zenml_model_version = client.get_model_version(\"breast_cancer_classifier\", \"rf\")\n",
|
849 | 851 | "\n",
|
850 | 852 | "# We can now load our classifier directly as well\n",
|
851 | | - "random_forest_classifier = rf_zenml_model_version.get_artifact(\"sklearn_classifier\").load()\n", |
| 853 | + "random_forest_classifier = rf_zenml_model_version.get_artifact(\n", |
| 854 | + " \"sklearn_classifier\"\n", |
| 855 | + ").load()\n", |
852 | 856 | "\n",
|
853 | 857 | "random_forest_classifier"
|
854 | 858 | ]
|
|
956 | 960 | "\n",
|
957 | 961 | " predictions = pd.Series(predictions, name=\"predicted\")\n",
|
958 | 962 | "\n",
|
959 | | - "    return predictions\n" |
| 963 | + " return predictions" |
960 | 964 | ]
|
961 | 965 | },
|
962 | 966 | {
|
|
983 | 987 | " random_state = 42\n",
|
984 | 988 | " target = \"target\"\n",
|
985 | 989 | "\n",
|
986 | | - "    df_inference = data_loader(\n", |
987 | | - "        random_state=random_state, is_inference=True\n", |
988 | | - "    )\n", |
| 990 | + " df_inference = data_loader(random_state=random_state, is_inference=True)\n", |
989 | 991 | " df_inference = inference_preprocessor(\n",
|
990 | 992 | " dataset_inf=df_inference,\n",
|
991 | 993 | " # We use the preprocess pipeline from the feature engineering pipeline\n",
|
992 | | - "        preprocess_pipeline=client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id),\n", |
| 994 | + " preprocess_pipeline=client.get_artifact_version(\n", |
| 995 | + " name_id_or_prefix=preprocess_pipeline_id\n", |
| 996 | + " ),\n", |
993 | 997 | " target=target,\n",
|
994 | 998 | " )\n",
|
995 | 999 | " inference_predict(\n",
|
996 | 1000 | " dataset_inf=df_inference,\n",
|
997 | | - "    )\n" |
| 1001 | + " )" |
998 | 1002 | ]
|
999 | 1003 | },
|
1000 | 1004 | {
|
|
1018 | 1022 | "# Let's add some metadata to the model to make it identifiable\n",
|
1019 | 1023 | "pipeline_settings[\"model\"] = Model(\n",
|
1020 | 1024 | " name=\"breast_cancer_classifier\",\n",
|
1021 | | - "    version=\"production\",  # We can pass in the stage name here!\n", |
| 1025 | + " version=\"production\", # We can pass in the stage name here!\n", |
1022 | 1026 | " license=\"Apache 2.0\",\n",
|
1023 | 1027 | " description=\"A breast cancer classifier\",\n",
|
1024 | 1028 | " tags=[\"breast_cancer\", \"classifier\"],\n",
|
|
1039 | 1043 | "# Let's run it again to make sure we have two versions\n",
|
1040 | 1044 | "# We need to pass in the ID of the preprocessing done in the feature engineering pipeline\n",
|
1041 | 1045 | "# in order to avoid training-serving skew\n",
|
1042 | | - "inference_configured(\n", |
1043 | | - "    preprocess_pipeline_id=preprocessing_pipeline_artifact_version.id\n", |
1044 | | - ")" |
| 1046 | + "inference_configured(preprocess_pipeline_id=preprocessing_pipeline_artifact_version.id)" |
1045 | 1047 | ]
|
1046 | 1048 | },
|
1047 | 1049 | {
|
|
1061 | 1063 | "outputs": [],
|
1062 | 1064 | "source": [
|
1063 | 1065 | "# Fetch production model\n",
|
1064 | | - "production_model_version = client.get_model_version(\"breast_cancer_classifier\", \"production\")\n", |
| 1066 | + "production_model_version = client.get_model_version(\n", |
| 1067 | + " \"breast_cancer_classifier\", \"production\"\n", |
| 1068 | + ")\n", |
1065 | 1069 | "\n",
|
1066 | 1070 | "# Get the predictions artifact\n",
|
1067 | 1071 | "production_model_version.get_artifact(\"predictions\").load()"
|
|