From e3050d744bf1cf979500ed4bd57b44da68ba35d6 Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Thu, 17 Jul 2025 16:33:03 -0700 Subject: [PATCH 01/13] document how to use a GCP service account in Airflow with SkyPilot --- examples/airflow/README.md | 115 ++++++++--- .../airflow/data_preprocessing_gcp_sa.yaml | 33 +++ examples/airflow/sky_train_dag.py | 190 ++++++++++-------- 3 files changed, 228 insertions(+), 110 deletions(-) create mode 100644 examples/airflow/data_preprocessing_gcp_sa.yaml diff --git a/examples/airflow/README.md b/examples/airflow/README.md index b2bc1948409..a59d92e934c 100644 --- a/examples/airflow/README.md +++ b/examples/airflow/README.md @@ -2,7 +2,7 @@ In this guide, we show how a training workflow involving data preprocessing, training and evaluation can be first easily developed with SkyPilot, and then orchestrated in Airflow. -This example uses a remote SkyPilot API Server to manage shared state across invocations, and includes a failure callback to tear down the SkyPilot cluster on task failure. +This example uses a remote SkyPilot API Server to manage shared state across invocations. @@ -11,17 +11,17 @@ This example uses a remote SkyPilot API Server to manage shared state across inv

-**💡 Tip:** SkyPilot also supports defining and running pipelines without Airflow. Check out [Jobs Pipelines](https://skypilot.readthedocs.io/en/latest/examples/managed-jobs.html#job-pipelines) for more information. +**💡 Tip:** SkyPilot also supports defining and running pipelines without Airflow. Check out [Jobs Pipelines](https://skypilot.readthedocs.io/en/latest/examples/managed-jobs.html#job-pipelines) for more information. ## Why use SkyPilot with Airflow? -In AI workflows, **the transition from development to production is hard**. +In AI workflows, **the transition from development to production is hard**. -Workflow development happens ad-hoc, with a lot of interaction required -with the code and data. When moving this to an Airflow DAG in production, managing dependencies, environments and the -infra requirements of the workflow gets complex. Porting the code to an airflow requires significant time to test and +Workflow development happens ad-hoc, with a lot of interaction required +with the code and data. When moving this to an Airflow DAG in production, managing dependencies, environments and the +infra requirements of the workflow gets complex. Porting the code to an airflow requires significant time to test and validate any changes, often requiring re-writing the code as Airflow operators. -**SkyPilot seamlessly bridges the dev -> production gap**. +**SkyPilot seamlessly bridges the dev -> production gap**. SkyPilot can operate on any of your infra, allowing you to package and run the same code that you ran during development on a production Airflow cluster. Behind the scenes, SkyPilot handles environment setup, dependency management, and infra orchestration, allowing you to focus on your code. @@ -45,13 +45,13 @@ Here's how you can use SkyPilot to take your dev workflows to production in Airf Once your API server is deployed, you will need to configure Airflow to use it. Set the `SKYPILOT_API_SERVER_ENDPOINT` variable in Airflow - it will be used by the `run_sky_task` function to send requests to the API server: ```bash -airflow variables set SKYPILOT_API_SERVER_ENDPOINT https:// +airflow variables set SKYPILOT_API_SERVER_ENDPOINT http:// ``` You can also use the Airflow web UI to set the variable:

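Alternatively, you can set the variable programmatically, for example from a bootstrap script that runs wherever the Airflow metadata database is reachable. A minimal sketch using Airflow's `Variable` model (the endpoint URL is a placeholder):

```python
from airflow.models import Variable

# Equivalent to `airflow variables set ...`; writes to the Airflow metadata database.
Variable.set('SKYPILOT_API_SERVER_ENDPOINT', 'http://<api-server-endpoint>')
```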

@@ -88,35 +88,42 @@ Once we have developed the tasks, we can seamlessly run them in Airflow. 1. **No changes required to our tasks -** we use the same YAMLs we wrote in the previous step to create an Airflow DAG in `sky_train_dag.py`. 2. **Airflow native logging** - SkyPilot logs are written to container stdout, which is captured as task logs in Airflow and displayed in the UI. -3. **Easy debugging** - If a task fails, you can independently run the task using `sky launch` to debug the issue. SkyPilot will recreate the environment in which the task failed. +3. **Easy debugging** - If a task fails, you can independently run the task using `sky launch` to debug the issue. SkyPilot will recreate the environment in which the task failed. Here's a snippet of the DAG declaration in [sky_train_dag.py](https://github.com/skypilot-org/skypilot/blob/master/examples/airflow/sky_train_dag.py): ```python -with DAG(dag_id='sky_train_dag', - default_args=default_args, - schedule_interval=None, +with DAG(dag_id='sky_train_dag', default_args=default_args, catchup=False) as dag: # Path to SkyPilot YAMLs. Can be a git repo or local directory. base_path = 'https://github.com/skypilot-org/mock-train-workflow.git' # Generate bucket UUID as first task bucket_uuid = generate_bucket_uuid() - + # Use the bucket_uuid from previous task common_envs = { - 'DATA_BUCKET_NAME': f"sky-data-demo-{{{{ task_instance.xcom_pull(task_ids='generate_bucket_uuid') }}}}", - 'DATA_BUCKET_STORE_TYPE': 's3' + 'DATA_BUCKET_NAME': f"sky-data-demo-{bucket_uuid}", + 'DATA_BUCKET_STORE_TYPE': 'gcs', } - - preprocess = run_sky_task.override(task_id="data_preprocess")( - repo_url, 'data_preprocessing.yaml', envs_override=common_envs, git_branch='clientserver_example') + + preprocess_task = run_sky_task.override(task_id="data_preprocess")( + base_path, + 'data_preprocessing.yaml', + "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + envs_override=common_envs) train_task = run_sky_task.override(task_id="train")( - repo_url, 'train.yaml', envs_override=common_envs, git_branch='clientserver_example') + base_path, + 'train.yaml', + "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + envs_override=common_envs) eval_task = run_sky_task.override(task_id="eval")( - repo_url, 'eval.yaml', envs_override=common_envs, git_branch='clientserver_example') + base_path, + 'eval.yaml', + "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + envs_override=common_envs) # Define the workflow - bucket_uuid >> preprocess >> train_task >> eval_task + bucket_uuid >> preprocess_task >> train_task >> eval_task ``` Behind the scenes, the `run_sky_task` uses the Airflow native Python operator to invoke the SkyPilot API. All SkyPilot API calls are made to the remote API server, which is configured using the `SKYPILOT_API_SERVER_ENDPOINT` variable. @@ -143,25 +150,75 @@ All clusters are set to auto-down after the task is done, so no dangling cluster 1. Copy the DAG file to the Airflow DAGs directory. ```bash - cp sky_train_dag.py /path/to/airflow/dags + cp sky_train_dag.py /path/to/airflow/dags # If your Airflow is running on Kubernetes, you may use kubectl cp to copy the file to the pod - # kubectl cp sky_train_dag.py :/opt/airflow/dags + # kubectl cp sky_train_dag.py :/opt/airflow/dags ``` -2. Run `airflow dags list` to confirm that the DAG is loaded. +2. Run `airflow dags list` to confirm that the DAG is loaded. 3. Find the DAG in the Airflow UI (typically http://localhost:8080) and enable it. The UI may take a couple of minutes to reflect the changes. 
Force unpause the DAG if it is paused with `airflow dags unpause sky_train_dag` 4. Trigger the DAG from the Airflow UI using the `Trigger DAG` button. 5. Navigate to the run in the Airflow UI to see the DAG progress and logs of each task. -If a task fails, `task_failure_callback` will automatically tear down the SkyPilot cluster. - +If a task fails, SkyPilot will automatically tear down the SkyPilot cluster.
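
For example, if the `data_preprocess` task fails, you can reproduce it outside Airflow with the SkyPilot Python SDK. A sketch of what that could look like; the bucket and cluster names below are illustrative, so reuse the values printed in the failed run's logs:

```python
import sky

# Recreate the task the DAG launched, with the same env overrides.
task = sky.Task.from_yaml('data_preprocessing.yaml')
task.update_envs({
    'DATA_BUCKET_NAME': 'sky-data-demo-abcd',  # illustrative; copy from the failed run
    'DATA_BUCKET_STORE_TYPE': 's3',
})

# Launch on a throwaway cluster and stream the logs, mirroring run_sky_task.
request_id = sky.launch(task, cluster_name='debug-data-preprocess', down=True)
job_id, _ = sky.stream_and_get(request_id)
sky.tail_logs(cluster_name='debug-data-preprocess', job_id=job_id, follow=True)
```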

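Because every cluster in this example is launched with `down=True`, you can quickly confirm nothing is left running after a DAG run. A small sketch that queries the same API server via the SDK (it assumes `SKYPILOT_API_SERVER_ENDPOINT` is set in your environment):

```python
import sky

# List clusters known to the remote API server; this should be empty
# shortly after each task's cluster auto-downs.
clusters = sky.stream_and_get(sky.status())
print([record['name'] for record in clusters])
```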

+ +## Optional: Configure cloud accounts + +By default, the SkyPilot task runs with the same cloud credentials as the SkyPilot +API Server. This may not be ideal if you already have +a dedicated service account for your task. + +In this example, we'll use a GCP service account to demonstrate how to use custom credentials. +Refer to the [GCP service account](https://docs.skypilot.co/en/latest/cloud-setup/cloud-permissions/gcp.html#gcp-service-account) +guide for instructions on setting one up. + +Once you have the JSON key for your service account, create an Airflow +connection to store the credentials. + +```bash +airflow connections add \ + --conn-type google_cloud_platform \ + --conn-extra "{\"keyfile_dict\": \"<service account JSON key contents>\"}" \ + skypilot_gcp_task +``` + +You can also use the Airflow web UI to add the connection: + +

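Escaping the raw JSON key into `--conn-extra` by hand is error-prone, so here is a small sketch that builds the payload for you (it assumes the key was downloaded locally as `sky-sa-key.json`; the filename is illustrative):

```python
import json

# Wrap the key the way GoogleBaseHook expects it: a JSON string stored
# under the `keyfile_dict` field of the connection's extras.
with open('sky-sa-key.json', 'r', encoding='utf-8') as f:
    key_contents = f.read()

print(json.dumps({'keyfile_dict': key_contents}))
```

Pass the printed string as the `--conn-extra` value when running `airflow connections add`.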

+Next, we will define `data_preprocessing_gcp_sa.yaml`, which contains small modifications to `data_preprocessing.yaml` that will use our GCP service account. The key changes needed here are to mount the GCP service account JSON key to our SkyPilot cluster, and to activate it using `gcloud auth activate-service-account`. + +We will also need a new task to read the GCP service account JSON key from our Airflow connection, and then change the preprocess task in our DAG to refer to this new YAML file. + +```python +with DAG(dag_id='sky_train_dag', default_args=default_args, + catchup=False) as dag: + ... + + # Get GCP credentials + gcp_service_account_json = get_gcp_service_account_json() + + ... + + preprocess_task = run_sky_task.override(task_id="data_preprocess")( + base_path, + 'data_preprocessing_gcp_sa.yaml', + "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + gcp_service_account_json=gcp_service_account_json, + envs_override=common_envs) + + ... + bucket_uuid >> gcp_service_account_json >> preprocess_task >> train_task >> eval_task +``` + ## Future work: a native Airflow Executor built on SkyPilot Currently this example relies on a helper `run_sky_task` method to wrap SkyPilot invocation in @task, but in the future SkyPilot can provide a native Airflow Executor. diff --git a/examples/airflow/data_preprocessing_gcp_sa.yaml b/examples/airflow/data_preprocessing_gcp_sa.yaml new file mode 100644 index 00000000000..e178e709443 --- /dev/null +++ b/examples/airflow/data_preprocessing_gcp_sa.yaml @@ -0,0 +1,33 @@ +resources: + cpus: 1+ + +envs: + DATA_BUCKET_NAME: sky-demo-data-test + DATA_BUCKET_STORE_TYPE: s3 + GCP_SERVICE_ACCOUNT_JSON_PATH: null + +file_mounts: + /data: + name: $DATA_BUCKET_NAME + store: $DATA_BUCKET_STORE_TYPE + /tmp/gcp-service-account.json: $GCP_SERVICE_ACCOUNT_JSON_PATH + +setup: | + echo "Setting up dependencies for data preprocessing..." + + curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-linux-x86_64.tar.gz + tar -xf google-cloud-cli-linux-x86_64.tar.gz + + ./google-cloud-sdk/install.sh --quiet --path-update true + source ~/.bashrc + gcloud auth activate-service-account --key-file=/tmp/gcp-service-account.json + +run: | + echo "Running data preprocessing on behalf of $(gcloud auth list --filter=status:ACTIVE --format="value(account)")..." 
+ + # Generate few files with random data to simulate data preprocessing + for i in {0..9}; do + dd if=/dev/urandom of=/data/file_$i bs=1M count=10 + done + + echo "Data preprocessing completed, wrote to $DATA_BUCKET_NAME" diff --git a/examples/airflow/sky_train_dag.py b/examples/airflow/sky_train_dag.py index 9814e8ecae0..c83e3f635d2 100644 --- a/examples/airflow/sky_train_dag.py +++ b/examples/airflow/sky_train_dag.py @@ -1,40 +1,31 @@ +import json import os +from typing import Optional import uuid from airflow import DAG from airflow.decorators import task -from airflow.models import Variable -from airflow.utils.dates import days_ago -import yaml +from airflow.exceptions import AirflowNotFoundException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +import pendulum default_args = { 'owner': 'airflow', - 'start_date': days_ago(1), + 'start_date': pendulum.today('UTC').add(days=-1,) } -# Unique bucket name for this DAG run -DATA_BUCKET_NAME = str(uuid.uuid4())[:4] - -def task_failure_callback(context): - """Callback to shut down SkyPilot cluster on task failure.""" - cluster_name = context['task_instance'].xcom_pull( - task_ids=context['task_instance'].task_id, key='cluster_name') - if cluster_name: - print( - f"Task failed or was cancelled. Shutting down SkyPilot cluster: {cluster_name}" - ) - import sky - down_request = sky.down(cluster_name) - sky.stream_and_get(down_request) - - -@task(on_failure_callback=task_failure_callback) +@task.virtualenv( + python_version='3.11', + requirements=['skypilot-nightly[gcp,aws,kubernetes]'], + system_site_packages=False, +) def run_sky_task(base_path: str, yaml_path: str, + skypilot_api_server_endpoint: str, + gcp_service_account_json: Optional[str] = None, envs_override: dict = None, - git_branch: str = None, - **kwargs): + git_branch: str = None): """Generic function to run a SkyPilot task. This is a blocking call that runs the SkyPilot task and streams the logs. 
@@ -44,19 +35,71 @@ def run_sky_task(base_path: str, Args: base_path: Base path (local directory or git repo URL) yaml_path: Path to the YAML file (relative to base_path) + skypilot_api_server_endpoint: SkyPilot API server endpoint + gcp_service_account_json: GCP service account JSON-encoded string envs_override: Dictionary of environment variables to override in the task config git_branch: Optional branch name to checkout (only used if base_path is a git repo) """ + import json + import os import subprocess import tempfile + import uuid + + import yaml + + import sky + + def _run_sky_task(yaml_path: str, envs_override: dict): + """Internal helper to run the sky task after directory setup.""" + with open(os.path.expanduser(yaml_path), 'r', encoding='utf-8') as f: + task_config = yaml.safe_load(f) + + # Initialize envs if not present + if 'envs' not in task_config: + task_config['envs'] = {} + + # Update the envs with the override values + # task.update_envs() is not used here, see https://github.com/skypilot-org/skypilot/issues/4363 + task_config['envs'].update(envs_override) + + task = sky.Task.from_yaml_config(task_config) + cluster_uuid = str(uuid.uuid4())[:4] + task_name = os.path.splitext(os.path.basename(yaml_path))[0] + cluster_name = f'{task_name}-{cluster_uuid}' + + print(f"Starting SkyPilot task with cluster: {cluster_name}") + + launch_request_id = sky.launch(task, + cluster_name=cluster_name, + down=True) + job_id, _ = sky.stream_and_get(launch_request_id) + # TODO(romilb): In the future, we can use deferrable tasks to avoid blocking + # the worker while waiting for cluster to start. + + # Stream the logs for airflow logging + sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True) + + # Terminate the cluster after the task is done + down_id = sky.down(cluster_name) + sky.stream_and_get(down_id) + + return cluster_name - # Set the SkyPilot API server endpoint from Airflow Variables - endpoint = Variable.get('SKYPILOT_API_SERVER_ENDPOINT', None) - if not endpoint: - raise ValueError('SKYPILOT_API_SERVER_ENDPOINT is not set in airflow.') - os.environ['SKYPILOT_API_SERVER_ENDPOINT'] = endpoint + # Set the SkyPilot API server endpoint + os.environ['SKYPILOT_API_SERVER_ENDPOINT'] = skypilot_api_server_endpoint + sky.api_login(skypilot_api_server_endpoint) original_cwd = os.getcwd() + + # Write GCP service account JSON to a temporary file, + # which will be mounted to the SkyPilot cluster. 
+ if gcp_service_account_json: + with tempfile.NamedTemporaryFile(delete=False, + suffix='.json') as temp_file: + temp_file.write(gcp_service_account_json.encode('utf-8')) + envs_override['GCP_SERVICE_ACCOUNT_JSON_PATH'] = temp_file.name + try: # Handle git repos vs local paths if base_path.startswith(('http://', 'https://', 'git://')): @@ -76,91 +119,76 @@ def run_sky_task(base_path: str, os.chdir(temp_dir) # Run the sky task - return _run_sky_task(full_yaml_path, envs_override or {}, - kwargs) + return _run_sky_task(full_yaml_path, envs_override or {}) else: full_yaml_path = os.path.join(base_path, yaml_path) os.chdir(base_path) # Run the sky task - return _run_sky_task(full_yaml_path, envs_override or {}, kwargs) + return _run_sky_task(full_yaml_path, envs_override or {}) finally: os.chdir(original_cwd) -def _run_sky_task(yaml_path: str, envs_override: dict, kwargs: dict): - """Internal helper to run the sky task after directory setup.""" - import sky - - with open(os.path.expanduser(yaml_path), 'r', encoding='utf-8') as f: - task_config = yaml.safe_load(f) - - # Initialize envs if not present - if 'envs' not in task_config: - task_config['envs'] = {} - - # Update the envs with the override values - # task.update_envs() is not used here, see https://github.com/skypilot-org/skypilot/issues/4363 - task_config['envs'].update(envs_override) - - task = sky.Task.from_yaml_config(task_config) - cluster_uuid = str(uuid.uuid4())[:4] - task_name = os.path.splitext(os.path.basename(yaml_path))[0] - cluster_name = f'{task_name}-{cluster_uuid}' - kwargs['ti'].xcom_push(key='cluster_name', - value=cluster_name) # For failure callback - - launch_request_id = sky.launch(task, cluster_name=cluster_name, down=True) - job_id, _ = sky.stream_and_get(launch_request_id) - # TODO(romilb): In the future, we can use deferrable tasks to avoid blocking - # the worker while waiting for cluster to start. - - # Stream the logs for airflow logging - sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True) - - # Terminate the cluster after the task is done - down_id = sky.down(cluster_name) - sky.stream_and_get(down_id) - - @task -def generate_bucket_uuid(**context): +def generate_bucket_uuid(): + """Generate a unique bucket UUID for this DAG run.""" + import uuid bucket_uuid = str(uuid.uuid4())[:4] return bucket_uuid -with DAG(dag_id='sky_train_dag', - default_args=default_args, - schedule_interval=None, +@task +def get_gcp_service_account_json() -> Optional[str]: + """Fetch GCP credentials from Airflow connection.""" + try: + hook = GoogleBaseHook(gcp_conn_id='skypilot_gcp_task') + status, message = hook.test_connection() + print(f"GCP connection status: {status}, message: {message}") + except AirflowNotFoundException: + print("GCP connection not found, skipping") + return None + conn = hook.get_connection(hook.gcp_conn_id) + service_account_json = conn.extra_dejson.get('keyfile_dict') + return service_account_json + + +with DAG(dag_id='sky_train_dag', default_args=default_args, catchup=False) as dag: # Path to SkyPilot YAMLs. Can be a git repo or local directory. 
base_path = 'https://github.com/skypilot-org/mock-train-workflow.git' # Generate bucket UUID as first task - # See https://stackoverflow.com/questions/55748050/generating-uuid-and-use-it-across-airflow-dag bucket_uuid = generate_bucket_uuid() + # Get GCP credentials (if available) + gcp_service_account_json = get_gcp_service_account_json() + # Use the bucket_uuid from previous task common_envs = { - 'DATA_BUCKET_NAME': f"sky-data-demo-{{{{ task_instance.xcom_pull(task_ids='generate_bucket_uuid') }}}}", - 'DATA_BUCKET_STORE_TYPE': 's3' + 'DATA_BUCKET_NAME': f"sky-data-demo-{bucket_uuid}", + 'DATA_BUCKET_STORE_TYPE': 'gcs', } - preprocess = run_sky_task.override(task_id="data_preprocess")( + preprocess_task = run_sky_task.override(task_id="data_preprocess")( base_path, + # Or data_preprocessing_gcp_sa.yaml if you want to use a custom GCP service account 'data_preprocessing.yaml', - envs_override=common_envs, - git_branch='clientserver_example') + "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + gcp_service_account_json=gcp_service_account_json, + envs_override=common_envs) train_task = run_sky_task.override(task_id="train")( base_path, 'train.yaml', - envs_override=common_envs, - git_branch='clientserver_example') + "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + gcp_service_account_json=None, + envs_override=common_envs) eval_task = run_sky_task.override(task_id="eval")( base_path, 'eval.yaml', - envs_override=common_envs, - git_branch='clientserver_example') + "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + gcp_service_account_json=None, + envs_override=common_envs) # Define the workflow - bucket_uuid >> preprocess >> train_task >> eval_task + bucket_uuid >> gcp_service_account_json >> preprocess_task >> train_task >> eval_task From 79c7776a80219fc66431ff31438f955ebd27238c Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Thu, 17 Jul 2025 16:40:51 -0700 Subject: [PATCH 02/13] revert DATA_BUCKET_STORE_TYPE from gcs to s3 --- examples/airflow/README.md | 2 +- examples/airflow/sky_train_dag.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/airflow/README.md b/examples/airflow/README.md index a59d92e934c..0aeace8dcb0 100644 --- a/examples/airflow/README.md +++ b/examples/airflow/README.md @@ -103,7 +103,7 @@ with DAG(dag_id='sky_train_dag', default_args=default_args, # Use the bucket_uuid from previous task common_envs = { 'DATA_BUCKET_NAME': f"sky-data-demo-{bucket_uuid}", - 'DATA_BUCKET_STORE_TYPE': 'gcs', + 'DATA_BUCKET_STORE_TYPE': 's3', } preprocess_task = run_sky_task.override(task_id="data_preprocess")( diff --git a/examples/airflow/sky_train_dag.py b/examples/airflow/sky_train_dag.py index c83e3f635d2..56d481d0e74 100644 --- a/examples/airflow/sky_train_dag.py +++ b/examples/airflow/sky_train_dag.py @@ -167,7 +167,7 @@ def get_gcp_service_account_json() -> Optional[str]: # Use the bucket_uuid from previous task common_envs = { 'DATA_BUCKET_NAME': f"sky-data-demo-{bucket_uuid}", - 'DATA_BUCKET_STORE_TYPE': 'gcs', + 'DATA_BUCKET_STORE_TYPE': 's3', } preprocess_task = run_sky_task.override(task_id="data_preprocess")( From ed207b90c06ddc85901bca3dbb5b1bf1ec561534 Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Thu, 17 Jul 2025 16:44:51 -0700 Subject: [PATCH 03/13] remove unused import --- examples/airflow/sky_train_dag.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/airflow/sky_train_dag.py b/examples/airflow/sky_train_dag.py index 56d481d0e74..ebd060e093b 100644 --- a/examples/airflow/sky_train_dag.py +++ 
b/examples/airflow/sky_train_dag.py @@ -133,7 +133,6 @@ def _run_sky_task(yaml_path: str, envs_override: dict): @task def generate_bucket_uuid(): """Generate a unique bucket UUID for this DAG run.""" - import uuid bucket_uuid = str(uuid.uuid4())[:4] return bucket_uuid From fe83e97c280130005f112a2c5d3a5980030b0881 Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Mon, 21 Jul 2025 10:32:04 -0700 Subject: [PATCH 04/13] add missing newline --- examples/airflow/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/airflow/README.md b/examples/airflow/README.md index 0aeace8dcb0..58311e9f27c 100644 --- a/examples/airflow/README.md +++ b/examples/airflow/README.md @@ -216,6 +216,7 @@ with DAG(dag_id='sky_train_dag', default_args=default_args, envs_override=common_envs) ... + bucket_uuid >> gcp_service_account_json >> preprocess_task >> train_task >> eval_task ``` From d40cc2b5d25957eae3cc3885cc27c020a8ade6e4 Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Mon, 21 Jul 2025 14:31:31 -0700 Subject: [PATCH 05/13] assign var.value.SKYPILOT_API_SERVER_ENDPOINT to a variable --- examples/airflow/README.md | 9 +++++---- examples/airflow/sky_train_dag.py | 7 ++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/airflow/README.md b/examples/airflow/README.md index 58311e9f27c..066fa179496 100644 --- a/examples/airflow/README.md +++ b/examples/airflow/README.md @@ -106,20 +106,21 @@ with DAG(dag_id='sky_train_dag', default_args=default_args, 'DATA_BUCKET_STORE_TYPE': 's3', } + skypilot_api_server_endpoint = "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}" preprocess_task = run_sky_task.override(task_id="data_preprocess")( base_path, 'data_preprocessing.yaml', - "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + skypilot_api_server_endpoint, envs_override=common_envs) train_task = run_sky_task.override(task_id="train")( base_path, 'train.yaml', - "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + skypilot_api_server_endpoint, envs_override=common_envs) eval_task = run_sky_task.override(task_id="eval")( base_path, 'eval.yaml', - "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + skypilot_api_server_endpoint, envs_override=common_envs) # Define the workflow @@ -211,7 +212,7 @@ with DAG(dag_id='sky_train_dag', default_args=default_args, preprocess_task = run_sky_task.override(task_id="data_preprocess")( base_path, 'data_preprocessing_gcp_sa.yaml', - "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + skypilot_api_server_endpoint, gcp_service_account_json=gcp_service_account_json, envs_override=common_envs) diff --git a/examples/airflow/sky_train_dag.py b/examples/airflow/sky_train_dag.py index ebd060e093b..3779be89915 100644 --- a/examples/airflow/sky_train_dag.py +++ b/examples/airflow/sky_train_dag.py @@ -169,23 +169,24 @@ def get_gcp_service_account_json() -> Optional[str]: 'DATA_BUCKET_STORE_TYPE': 's3', } + skypilot_api_server_endpoint = "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}" preprocess_task = run_sky_task.override(task_id="data_preprocess")( base_path, # Or data_preprocessing_gcp_sa.yaml if you want to use a custom GCP service account 'data_preprocessing.yaml', - "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + skypilot_api_server_endpoint, gcp_service_account_json=gcp_service_account_json, envs_override=common_envs) train_task = run_sky_task.override(task_id="train")( base_path, 'train.yaml', - "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + skypilot_api_server_endpoint, gcp_service_account_json=None, envs_override=common_envs) eval_task = 
run_sky_task.override(task_id="eval")( base_path, 'eval.yaml', - "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + skypilot_api_server_endpoint, gcp_service_account_json=None, envs_override=common_envs) From dbc14cc3bdbd35be5653fc9ff23f5d66dd10b22b Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Tue, 22 Jul 2025 12:51:50 -0700 Subject: [PATCH 06/13] add a note on airflow task virtual env --- examples/airflow/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/airflow/README.md b/examples/airflow/README.md index 066fa179496..6115a19a94f 100644 --- a/examples/airflow/README.md +++ b/examples/airflow/README.md @@ -127,7 +127,9 @@ with DAG(dag_id='sky_train_dag', default_args=default_args, bucket_uuid >> preprocess_task >> train_task >> eval_task ``` -Behind the scenes, the `run_sky_task` uses the Airflow native Python operator to invoke the SkyPilot API. All SkyPilot API calls are made to the remote API server, which is configured using the `SKYPILOT_API_SERVER_ENDPOINT` variable. +Behind the scenes, the `run_sky_task` uses the Airflow native [PythonVirtualenvOperator](https://airflow.apache.org/docs/apache-airflow-providers-standard/stable/operators/python.html) (@task.virtualenv), which creates a Python virtual environment with `skypilot` installed. All SkyPilot API calls are made to the remote API server, which is configured using the `SKYPILOT_API_SERVER_ENDPOINT` variable. + +Note: We need to run the task in a virtual environment as there's a dependency conflict between the latest `skypilot` and `airflow` Python package. The task YAML files can be sourced in two ways: From a84c14d35aa9401dc9e4397b7d2d6897c3a99810 Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Tue, 22 Jul 2025 12:54:09 -0700 Subject: [PATCH 07/13] fix link --- examples/airflow/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/airflow/README.md b/examples/airflow/README.md index 6115a19a94f..4fca565ac0f 100644 --- a/examples/airflow/README.md +++ b/examples/airflow/README.md @@ -127,7 +127,7 @@ with DAG(dag_id='sky_train_dag', default_args=default_args, bucket_uuid >> preprocess_task >> train_task >> eval_task ``` -Behind the scenes, the `run_sky_task` uses the Airflow native [PythonVirtualenvOperator](https://airflow.apache.org/docs/apache-airflow-providers-standard/stable/operators/python.html) (@task.virtualenv), which creates a Python virtual environment with `skypilot` installed. All SkyPilot API calls are made to the remote API server, which is configured using the `SKYPILOT_API_SERVER_ENDPOINT` variable. +Behind the scenes, the `run_sky_task` uses the Airflow native [PythonVirtualenvOperator](https://airflow.apache.org/docs/apache-airflow-providers-standard/stable/operators/python.html#pythonvirtualenvoperator) (@task.virtualenv), which creates a Python virtual environment with `skypilot` installed. All SkyPilot API calls are made to the remote API server, which is configured using the `SKYPILOT_API_SERVER_ENDPOINT` variable. Note: We need to run the task in a virtual environment as there's a dependency conflict between the latest `skypilot` and `airflow` Python package. 
From 59bbb6dfc80060faa43338fb2a541c655e6b752e Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Tue, 22 Jul 2025 13:03:14 -0700 Subject: [PATCH 08/13] add code snippet for run_sky_task --- examples/airflow/README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/examples/airflow/README.md b/examples/airflow/README.md index 4fca565ac0f..1be27274fe0 100644 --- a/examples/airflow/README.md +++ b/examples/airflow/README.md @@ -131,6 +131,21 @@ Behind the scenes, the `run_sky_task` uses the Airflow native [PythonVirtualenvO Note: We need to run the task in a virtual environment as there's a dependency conflict between the latest `skypilot` and `airflow` Python package. +```python +@task.virtualenv( + python_version='3.11', + requirements=['skypilot-nightly[gcp,aws,kubernetes]'], + system_site_packages=False, +) +def run_sky_task(base_path: str, + yaml_path: str, + skypilot_api_server_endpoint: str, + gcp_service_account_json: Optional[str] = None, + envs_override: dict = None, + git_branch: str = None): + ... +``` + The task YAML files can be sourced in two ways: 1. **From a Git repository** (as shown above): From 6bff16a331fe3688090efc3d7d8395cfedb81454 Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Tue, 22 Jul 2025 13:06:53 -0700 Subject: [PATCH 09/13] simplify code snippet --- examples/airflow/README.md | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/examples/airflow/README.md b/examples/airflow/README.md index 1be27274fe0..0368fa6a656 100644 --- a/examples/airflow/README.md +++ b/examples/airflow/README.md @@ -137,13 +137,8 @@ Note: We need to run the task in a virtual environment as there's a dependency c requirements=['skypilot-nightly[gcp,aws,kubernetes]'], system_site_packages=False, ) -def run_sky_task(base_path: str, - yaml_path: str, - skypilot_api_server_endpoint: str, - gcp_service_account_json: Optional[str] = None, - envs_override: dict = None, - git_branch: str = None): - ... +def run_sky_task(...): + ... ``` The task YAML files can be sourced in two ways: From 61c3da96c4dbdce4329af334782d1fdd5e609009 Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Tue, 22 Jul 2025 13:26:40 -0700 Subject: [PATCH 10/13] simplify passing in SKYPILOT_API_SERVER_ENDPOINT --- examples/airflow/README.md | 8 ++++---- examples/airflow/sky_train_dag.py | 19 ++++++++----------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/examples/airflow/README.md b/examples/airflow/README.md index 0368fa6a656..e11c2c9a3da 100644 --- a/examples/airflow/README.md +++ b/examples/airflow/README.md @@ -127,16 +127,16 @@ with DAG(dag_id='sky_train_dag', default_args=default_args, bucket_uuid >> preprocess_task >> train_task >> eval_task ``` -Behind the scenes, the `run_sky_task` uses the Airflow native [PythonVirtualenvOperator](https://airflow.apache.org/docs/apache-airflow-providers-standard/stable/operators/python.html#pythonvirtualenvoperator) (@task.virtualenv), which creates a Python virtual environment with `skypilot` installed. All SkyPilot API calls are made to the remote API server, which is configured using the `SKYPILOT_API_SERVER_ENDPOINT` variable. - -Note: We need to run the task in a virtual environment as there's a dependency conflict between the latest `skypilot` and `airflow` Python package. 
+Behind the scenes, the `run_sky_task` uses the Airflow native [PythonVirtualenvOperator](https://airflow.apache.org/docs/apache-airflow-providers-standard/stable/operators/python.html#pythonvirtualenvoperator) (@task.virtualenv), which creates a Python virtual environment with `skypilot` installed. We need to run the task in a virtual environment as there's a dependency conflict between the latest `skypilot` and `airflow` Python package. ```python @task.virtualenv( python_version='3.11', requirements=['skypilot-nightly[gcp,aws,kubernetes]'], system_site_packages=False, -) + templates_dict={ + 'SKYPILOT_API_SERVER_ENDPOINT': "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + }) def run_sky_task(...): ... ``` diff --git a/examples/airflow/sky_train_dag.py b/examples/airflow/sky_train_dag.py index 3779be89915..3294174be60 100644 --- a/examples/airflow/sky_train_dag.py +++ b/examples/airflow/sky_train_dag.py @@ -19,13 +19,15 @@ python_version='3.11', requirements=['skypilot-nightly[gcp,aws,kubernetes]'], system_site_packages=False, -) + templates_dict={ + 'SKYPILOT_API_SERVER_ENDPOINT': "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}", + }) def run_sky_task(base_path: str, yaml_path: str, - skypilot_api_server_endpoint: str, gcp_service_account_json: Optional[str] = None, envs_override: dict = None, - git_branch: str = None): + git_branch: str = None, + **kwargs): """Generic function to run a SkyPilot task. This is a blocking call that runs the SkyPilot task and streams the logs. @@ -35,12 +37,10 @@ def run_sky_task(base_path: str, Args: base_path: Base path (local directory or git repo URL) yaml_path: Path to the YAML file (relative to base_path) - skypilot_api_server_endpoint: SkyPilot API server endpoint gcp_service_account_json: GCP service account JSON-encoded string envs_override: Dictionary of environment variables to override in the task config git_branch: Optional branch name to checkout (only used if base_path is a git repo) """ - import json import os import subprocess import tempfile @@ -87,8 +87,9 @@ def _run_sky_task(yaml_path: str, envs_override: dict): return cluster_name # Set the SkyPilot API server endpoint - os.environ['SKYPILOT_API_SERVER_ENDPOINT'] = skypilot_api_server_endpoint - sky.api_login(skypilot_api_server_endpoint) + if kwargs['templates_dict']: + os.environ['SKYPILOT_API_SERVER_ENDPOINT'] = kwargs['templates_dict'][ + 'SKYPILOT_API_SERVER_ENDPOINT'] original_cwd = os.getcwd() @@ -169,24 +170,20 @@ def get_gcp_service_account_json() -> Optional[str]: 'DATA_BUCKET_STORE_TYPE': 's3', } - skypilot_api_server_endpoint = "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}" preprocess_task = run_sky_task.override(task_id="data_preprocess")( base_path, # Or data_preprocessing_gcp_sa.yaml if you want to use a custom GCP service account 'data_preprocessing.yaml', - skypilot_api_server_endpoint, gcp_service_account_json=gcp_service_account_json, envs_override=common_envs) train_task = run_sky_task.override(task_id="train")( base_path, 'train.yaml', - skypilot_api_server_endpoint, gcp_service_account_json=None, envs_override=common_envs) eval_task = run_sky_task.override(task_id="eval")( base_path, 'eval.yaml', - skypilot_api_server_endpoint, gcp_service_account_json=None, envs_override=common_envs) From c4f18c7e0d6b5568418d40909b05514ca53d446b Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Tue, 22 Jul 2025 13:32:02 -0700 Subject: [PATCH 11/13] fix readme --- examples/airflow/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/airflow/README.md 
b/examples/airflow/README.md index e11c2c9a3da..69faf5728fc 100644 --- a/examples/airflow/README.md +++ b/examples/airflow/README.md @@ -224,7 +224,6 @@ with DAG(dag_id='sky_train_dag', default_args=default_args, preprocess_task = run_sky_task.override(task_id="data_preprocess")( base_path, 'data_preprocessing_gcp_sa.yaml', - skypilot_api_server_endpoint, gcp_service_account_json=gcp_service_account_json, envs_override=common_envs) From 2012ed4ac1cbb125fe1aa6120f29ef82c168ce24 Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Tue, 22 Jul 2025 13:38:20 -0700 Subject: [PATCH 12/13] simplify --- examples/airflow/sky_train_dag.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/examples/airflow/sky_train_dag.py b/examples/airflow/sky_train_dag.py index 3294174be60..6da5325af80 100644 --- a/examples/airflow/sky_train_dag.py +++ b/examples/airflow/sky_train_dag.py @@ -177,15 +177,10 @@ def get_gcp_service_account_json() -> Optional[str]: gcp_service_account_json=gcp_service_account_json, envs_override=common_envs) train_task = run_sky_task.override(task_id="train")( - base_path, - 'train.yaml', - gcp_service_account_json=None, - envs_override=common_envs) - eval_task = run_sky_task.override(task_id="eval")( - base_path, - 'eval.yaml', - gcp_service_account_json=None, - envs_override=common_envs) + base_path, 'train.yaml', envs_override=common_envs) + eval_task = run_sky_task.override(task_id="eval")(base_path, + 'eval.yaml', + envs_override=common_envs) # Define the workflow bucket_uuid >> gcp_service_account_json >> preprocess_task >> train_task >> eval_task From 836a93f6e5bf7501c17f4d33d7491c1df40384d4 Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja Date: Tue, 22 Jul 2025 13:50:29 -0700 Subject: [PATCH 13/13] move import sky inside _run_sky_task --- examples/airflow/sky_train_dag.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/airflow/sky_train_dag.py b/examples/airflow/sky_train_dag.py index 6da5325af80..b3ae9593782 100644 --- a/examples/airflow/sky_train_dag.py +++ b/examples/airflow/sky_train_dag.py @@ -48,9 +48,8 @@ def run_sky_task(base_path: str, import yaml - import sky - def _run_sky_task(yaml_path: str, envs_override: dict): + import sky """Internal helper to run the sky task after directory setup.""" with open(os.path.expanduser(yaml_path), 'r', encoding='utf-8') as f: task_config = yaml.safe_load(f)