Document how to use a GCP service account in Airflow with SkyPilot #6291

Merged (14 commits) · Jul 23, 2025 · Changes from 3 commits
115 changes: 86 additions & 29 deletions examples/airflow/README.md
@@ -2,7 +2,7 @@

In this guide, we show how a training workflow involving data preprocessing, training and evaluation can first be developed easily with SkyPilot, and then orchestrated in Airflow.

This example uses a remote SkyPilot API Server to manage shared state across invocations.

<!-- Source: https://docs.google.com/drawings/d/1Di_KIOlxQEUib_RhMKysXBc6u-5WW9FnbougVWRiGF0/edit?usp=sharing -->

@@ -11,17 +11,17 @@
</p>


**💡 Tip:** SkyPilot also supports defining and running pipelines without Airflow. Check out [Jobs Pipelines](https://skypilot.readthedocs.io/en/latest/examples/managed-jobs.html#job-pipelines) for more information.

## Why use SkyPilot with Airflow?
In AI workflows, **the transition from development to production is hard**.

Workflow development happens ad-hoc, with a lot of interaction required with the code and data. When moving this to an Airflow DAG in production, managing dependencies, environments and the infra requirements of the workflow gets complex. Porting the code to Airflow requires significant time to test and validate any changes, often requiring rewriting the code as Airflow operators.

**SkyPilot seamlessly bridges the dev -> production gap**.

SkyPilot can operate on any of your infra, allowing you to package and run the same code that you ran during development on a
production Airflow cluster. Behind the scenes, SkyPilot handles environment setup, dependency management, and infra orchestration, allowing you to focus on your code.
@@ -45,13 +45,13 @@
Once your API server is deployed, you will need to configure Airflow to use it. Set the `SKYPILOT_API_SERVER_ENDPOINT` variable in Airflow - it will be used by the `run_sky_task` function to send requests to the API server:

```bash
airflow variables set SKYPILOT_API_SERVER_ENDPOINT http://<skypilot-api-server-endpoint>
```
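
To confirm the variable is set, you can read it back with the Airflow CLI:

```bash
airflow variables get SKYPILOT_API_SERVER_ENDPOINT
```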

You can also use the Airflow web UI to set the variable:

<p align="center">
<img src="https://i.imgur.com/vjM0FtH.png" width="600">
<img src="https://i.imgur.com/SQPVxjl.png" width="600">
</p>


@@ -88,35 +88,42 @@ Once we have developed the tasks, we can seamlessly run them in Airflow.

1. **No changes required to our tasks -** we use the same YAMLs we wrote in the previous step to create an Airflow DAG in `sky_train_dag.py`.
2. **Airflow native logging** - SkyPilot logs are written to container stdout, which is captured as task logs in Airflow and displayed in the UI.
3. **Easy debugging** - If a task fails, you can independently run the task using `sky launch` to debug the issue. SkyPilot will recreate the environment in which the task failed.

Here's a snippet of the DAG declaration in [sky_train_dag.py](https://github.com/skypilot-org/skypilot/blob/master/examples/airflow/sky_train_dag.py):
```python
with DAG(dag_id='sky_train_dag', default_args=default_args,
         catchup=False) as dag:
    # Path to SkyPilot YAMLs. Can be a git repo or local directory.
    base_path = 'https://github.com/skypilot-org/mock-train-workflow.git'

    # Generate bucket UUID as first task
    bucket_uuid = generate_bucket_uuid()

    # Use the bucket_uuid from previous task
    common_envs = {
        'DATA_BUCKET_NAME': f"sky-data-demo-{bucket_uuid}",
        'DATA_BUCKET_STORE_TYPE': 's3',
    }

    preprocess_task = run_sky_task.override(task_id="data_preprocess")(
        base_path,
        'data_preprocessing.yaml',
        "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}",
        envs_override=common_envs)
    train_task = run_sky_task.override(task_id="train")(
        base_path,
        'train.yaml',
        "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}",
        envs_override=common_envs)
    eval_task = run_sky_task.override(task_id="eval")(
        base_path,
        'eval.yaml',
        "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}",
        envs_override=common_envs)

    # Define the workflow
    bucket_uuid >> preprocess_task >> train_task >> eval_task
```

Behind the scenes, `run_sky_task` uses Airflow's native Python operator to invoke the SkyPilot API. All SkyPilot API calls are made to the remote API server, which is configured using the `SKYPILOT_API_SERVER_ENDPOINT` variable.
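
The helper itself lives in [sky_train_dag.py](https://github.com/skypilot-org/skypilot/blob/master/examples/airflow/sky_train_dag.py). As a rough sketch of what such a task can look like (assuming a recent SkyPilot SDK where `sky.launch` returns a request ID, and a local `base_path`; the actual helper also handles cloning git repos and other details), it boils down to:

```python
# Hypothetical sketch of a run_sky_task-style helper; see sky_train_dag.py for
# the real implementation, which differs in details.
import os
import uuid

from airflow.decorators import task


@task
def run_sky_task(base_path: str, yaml_path: str, api_server_endpoint: str,
                 envs_override: dict = None, **kwargs):
    """Launch a SkyPilot task on the remote API server and stream its logs."""
    # Point the SkyPilot client at the remote API server configured in Airflow.
    os.environ['SKYPILOT_API_SERVER_ENDPOINT'] = api_server_endpoint
    import sky  # Import after setting the endpoint so the client picks it up.

    task_def = sky.Task.from_yaml(os.path.join(base_path, yaml_path))
    if envs_override:
        task_def.update_envs(envs_override)

    # down=True tears the cluster down once the job finishes; logs are streamed
    # to stdout, which Airflow captures as task logs.
    request_id = sky.launch(task_def,
                            cluster_name=f'train-{uuid.uuid4().hex[:8]}',
                            down=True)
    sky.stream_and_get(request_id)
```
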
@@ -143,25 +150,75 @@

1. Copy the DAG file to the Airflow DAGs directory.
```bash
cp sky_train_dag.py /path/to/airflow/dags
# If your Airflow is running on Kubernetes, you may use kubectl cp to copy the file to the pod
# kubectl cp sky_train_dag.py <airflow-pod-name>:/opt/airflow/dags
```
2. Run `airflow dags list` to confirm that the DAG is loaded.
3. Find the DAG in the Airflow UI (typically http://localhost:8080) and enable it. The UI may take a couple of minutes to reflect the changes. If the DAG is paused, force-unpause it with `airflow dags unpause sky_train_dag`.
4. Trigger the DAG from the Airflow UI using the `Trigger DAG` button.
5. Navigate to the run in the Airflow UI to see the DAG progress and logs of each task.
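
If you prefer the CLI to the web UI, the same trigger-and-monitor steps can be done with:

```bash
airflow dags unpause sky_train_dag
airflow dags trigger sky_train_dag
airflow dags list-runs -d sky_train_dag   # check run state
```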

If a task fails, SkyPilot will automatically tear down the cluster.

<p align="center">
<img src="https://i.imgur.com/TXn5eKI.png" width="800">
<img src="https://i.imgur.com/HH4y77d.png" width="800">
</p>
<p align="center">
<img src="https://i.imgur.com/D89N5xt.png" width="800">
<img src="https://i.imgur.com/EOR4lzV.png" width="800">
</p>

## Optional: Configure cloud accounts

By default, the SkyPilot task will run using the same cloud credentials that the SkyPilot API server has. This may not be ideal if we already have a dedicated service account for the task.

In this example, we'll use a GCP service account to demonstrate how to use custom credentials. Refer to the [GCP service account](https://docs.skypilot.co/en/latest/cloud-setup/cloud-permissions/gcp.html#gcp-service-account) guide for how to set up a service account.
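
If you still need to create a key for the service account, one way is via `gcloud` (the account name and project below are placeholders):

```bash
gcloud iam service-accounts keys create sky-sa-key.json \
  --iam-account=sky-sa@my-project.iam.gserviceaccount.com
```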

Once you have the JSON key for your service account, create an Airflow connection to store the credentials.

```bash
airflow connections add \
  --conn-type google_cloud_platform \
  --conn-extra "{\"keyfile_dict\": \"<YOUR_SERVICE_ACCOUNT_JSON_KEY>\"}" \
  skypilot_gcp_task
```

You can also use the Airflow web UI to add the connection:

<p align="center">
<img src="https://i.imgur.com/uFOs3G0.png" width="600">
</p>
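
Either way, you can confirm that the connection was registered:

```bash
airflow connections get skypilot_gcp_task
```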

Next, we will define `data_preprocessing_gcp_sa.yaml`, which makes small modifications to `data_preprocessing.yaml` so that it uses our GCP service account. The key changes are to mount the GCP service account JSON key on the SkyPilot cluster, and to activate it with `gcloud auth activate-service-account`.

We will also need a new task to read the GCP service account JSON key from our Airflow connection, and then change the preprocess task in our DAG to refer to this new YAML file.

```python
with DAG(dag_id='sky_train_dag', default_args=default_args,
         catchup=False) as dag:
    ...

    # Get GCP credentials
    gcp_service_account_json = get_gcp_service_account_json()

    ...

    preprocess_task = run_sky_task.override(task_id="data_preprocess")(
        base_path,
        'data_preprocessing_gcp_sa.yaml',
        "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}",
        gcp_service_account_json=gcp_service_account_json,
        envs_override=common_envs)

    ...
    bucket_uuid >> gcp_service_account_json >> preprocess_task >> train_task >> eval_task
```
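
The `get_gcp_service_account_json` task simply reads the key back out of the Airflow connection created earlier. A minimal sketch (the task in the example repo may differ) could be:

```python
from airflow.decorators import task
from airflow.hooks.base import BaseHook


@task
def get_gcp_service_account_json() -> str:
    """Return the service account JSON key stored in the skypilot_gcp_task connection."""
    conn = BaseHook.get_connection('skypilot_gcp_task')
    # keyfile_dict is the extra field we set when creating the connection.
    return conn.extra_dejson['keyfile_dict']
```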

## Future work: a native Airflow Executor built on SkyPilot

Currently, this example relies on a helper `run_sky_task` method to wrap the SkyPilot invocation in an Airflow `@task`, but in the future SkyPilot could provide a native Airflow Executor.
33 changes: 33 additions & 0 deletions examples/airflow/data_preprocessing_gcp_sa.yaml
@@ -0,0 +1,33 @@
resources:
  cpus: 1+

envs:
  DATA_BUCKET_NAME: sky-demo-data-test
  DATA_BUCKET_STORE_TYPE: s3
  GCP_SERVICE_ACCOUNT_JSON_PATH: null

file_mounts:
  /data:
    name: $DATA_BUCKET_NAME
    store: $DATA_BUCKET_STORE_TYPE
  /tmp/gcp-service-account.json: $GCP_SERVICE_ACCOUNT_JSON_PATH

setup: |
  echo "Setting up dependencies for data preprocessing..."

  curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-linux-x86_64.tar.gz
  tar -xf google-cloud-cli-linux-x86_64.tar.gz

  ./google-cloud-sdk/install.sh --quiet --path-update true
  source ~/.bashrc
  gcloud auth activate-service-account --key-file=/tmp/gcp-service-account.json

run: |
  echo "Running data preprocessing on behalf of $(gcloud auth list --filter=status:ACTIVE --format="value(account)")..."

  # Generate a few files with random data to simulate data preprocessing
  for i in {0..9}; do
    dd if=/dev/urandom of=/data/file_$i bs=1M count=10
  done

  echo "Data preprocessing completed, wrote to $DATA_BUCKET_NAME"