Commit 96d5fc1

Document how to use a GCP service account in Airflow with SkyPilot (#6291)
* document how to use a GCP service account in Airflow with SkyPilot
* revert DATA_BUCKET_STORE_TYPE from gcs to s3
* remove unused import
* add missing newline
* assign var.value.SKYPILOT_API_SERVER_ENDPOINT to a variable
* add a note on airflow task virtual env
* fix link
* add code snippet for run_sky_task
* simplify code snippet
* simplify passing in SKYPILOT_API_SERVER_ENDPOINT
* fix readme
* simplify
* move import sky inside _run_sky_task
1 parent 1348997 commit 96d5fc1

File tree: 3 files changed (+236, -114 lines changed)

examples/airflow/README.md

Lines changed: 100 additions & 30 deletions

@@ -2,7 +2,7 @@

In this guide, we show how a training workflow involving data preprocessing, training and evaluation can first be developed easily with SkyPilot, and then orchestrated in Airflow.

This example uses a remote SkyPilot API Server to manage shared state across invocations.

<!-- Source: https://docs.google.com/drawings/d/1Di_KIOlxQEUib_RhMKysXBc6u-5WW9FnbougVWRiGF0/edit?usp=sharing -->

@@ -11,17 +11,17 @@ This example uses a remote SkyPilot API Server to manage shared state across inv

</p>

**💡 Tip:** SkyPilot also supports defining and running pipelines without Airflow. Check out [Jobs Pipelines](https://skypilot.readthedocs.io/en/latest/examples/managed-jobs.html#job-pipelines) for more information.

## Why use SkyPilot with Airflow?
In AI workflows, **the transition from development to production is hard**.

Workflow development happens ad hoc, with a lot of interaction required
with the code and data. When moving this to an Airflow DAG in production, managing dependencies, environments and the
infra requirements of the workflow gets complex. Porting the code to an Airflow DAG requires significant time to test and
validate any changes, often requiring re-writing the code as Airflow operators.

**SkyPilot seamlessly bridges the dev -> production gap**.

SkyPilot can operate on any of your infra, allowing you to package and run the same code that you ran during development on a
production Airflow cluster. Behind the scenes, SkyPilot handles environment setup, dependency management, and infra orchestration, allowing you to focus on your code.

@@ -45,13 +45,13 @@ Here's how you can use SkyPilot to take your dev workflows to production in Airf

Once your API server is deployed, you will need to configure Airflow to use it. Set the `SKYPILOT_API_SERVER_ENDPOINT` variable in Airflow - it will be used by the `run_sky_task` function to send requests to the API server:

```bash
airflow variables set SKYPILOT_API_SERVER_ENDPOINT http://<skypilot-api-server-endpoint>
```

You can also use the Airflow web UI to set the variable:

<p align="center">
  <img src="https://i.imgur.com/SQPVxjl.png" width="600">
</p>

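To double-check the configuration, you can read the variable back and probe the endpoint. The `/api/health` path below assumes a standard SkyPilot API server deployment; adjust it if yours differs:

```bash
# Read the variable back and check that the API server is reachable.
# (The /api/health path is an assumption for a standard deployment.)
airflow variables get SKYPILOT_API_SERVER_ENDPOINT
curl http://<skypilot-api-server-endpoint>/api/health
```
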
@@ -88,38 +88,58 @@ Once we have developed the tasks, we can seamlessly run them in Airflow.

1. **No changes required to our tasks -** we use the same YAMLs we wrote in the previous step to create an Airflow DAG in `sky_train_dag.py`.
2. **Airflow native logging** - SkyPilot logs are written to container stdout, which is captured as task logs in Airflow and displayed in the UI.
3. **Easy debugging** - If a task fails, you can independently run the task using `sky launch` to debug the issue. SkyPilot will recreate the environment in which the task failed. See the example command right after this list.
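
For instance, a failed preprocessing step can be reproduced by hand with a one-off launch along these lines (the cluster name and env values here are illustrative, not taken from the example):

```bash
# Re-run a single step outside Airflow to debug it (name and values are illustrative).
sky launch -c debug-preprocess data_preprocessing.yaml \
  --env DATA_BUCKET_NAME=sky-data-demo-debug \
  --env DATA_BUCKET_STORE_TYPE=s3
```
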
Here's a snippet of the DAG declaration in [sky_train_dag.py](https://github.com/skypilot-org/skypilot/blob/master/examples/airflow/sky_train_dag.py):
```python
with DAG(dag_id='sky_train_dag', default_args=default_args,
         catchup=False) as dag:
    # Path to SkyPilot YAMLs. Can be a git repo or local directory.
    base_path = 'https://github.com/skypilot-org/mock-train-workflow.git'

    # Generate bucket UUID as first task
    bucket_uuid = generate_bucket_uuid()

    # Use the bucket_uuid from previous task
    common_envs = {
        'DATA_BUCKET_NAME': f"sky-data-demo-{bucket_uuid}",
        'DATA_BUCKET_STORE_TYPE': 's3',
    }

    skypilot_api_server_endpoint = "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}"
    preprocess_task = run_sky_task.override(task_id="data_preprocess")(
        base_path,
        'data_preprocessing.yaml',
        skypilot_api_server_endpoint,
        envs_override=common_envs)
    train_task = run_sky_task.override(task_id="train")(
        base_path,
        'train.yaml',
        skypilot_api_server_endpoint,
        envs_override=common_envs)
    eval_task = run_sky_task.override(task_id="eval")(
        base_path,
        'eval.yaml',
        skypilot_api_server_endpoint,
        envs_override=common_envs)

    # Define the workflow
    bucket_uuid >> preprocess_task >> train_task >> eval_task
```

Behind the scenes, `run_sky_task` uses the Airflow-native [PythonVirtualenvOperator](https://airflow.apache.org/docs/apache-airflow-providers-standard/stable/operators/python.html#pythonvirtualenvoperator) (`@task.virtualenv`), which creates a Python virtual environment with `skypilot` installed. We need to run the task in a virtual environment because there is a dependency conflict between the latest `skypilot` and `airflow` Python packages.

```python
@task.virtualenv(
    python_version='3.11',
    requirements=['skypilot-nightly[gcp,aws,kubernetes]'],
    system_site_packages=False,
    templates_dict={
        'SKYPILOT_API_SERVER_ENDPOINT': "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}",
    })
def run_sky_task(...):
    ...
```
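
The body of `run_sky_task` is elided above. As a rough sketch (not the repo's implementation), the decorated function can look like the following; the git handling, cluster naming, and the `sky.launch`/`sky.stream_and_get` calls here assume the asynchronous SkyPilot client SDK, where `sky.launch` returns a request ID:

```python
# Decorated with @task.virtualenv as shown above. Illustrative sketch only --
# see sky_train_dag.py in the repo for the actual implementation.
def run_sky_task(base_path: str,
                 yaml_path: str,
                 skypilot_api_server_endpoint: str,
                 envs_override: dict = None,
                 **kwargs):
    import os
    import subprocess
    import tempfile
    import uuid

    # Point the SkyPilot client at the remote API server *before* importing sky.
    os.environ['SKYPILOT_API_SERVER_ENDPOINT'] = skypilot_api_server_endpoint

    import sky  # Imported inside the task so it resolves inside the virtualenv.

    # base_path can be a git repo URL or a local directory containing the YAMLs.
    if base_path.startswith(('http://', 'https://', 'git@')):
        workdir = tempfile.mkdtemp()
        subprocess.run(['git', 'clone', '--depth', '1', base_path, workdir],
                       check=True)
    else:
        workdir = base_path

    task = sky.Task.from_yaml(os.path.join(workdir, yaml_path))
    if envs_override:
        task.update_envs(envs_override)

    # Launch on a fresh cluster; down=True tears it down once the task finishes.
    cluster_name = f'sky-airflow-{uuid.uuid4().hex[:6]}'
    request_id = sky.launch(task, cluster_name=cluster_name, down=True)
    # Stream logs to stdout so they show up as Airflow task logs.
    sky.stream_and_get(request_id)
```
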
The task YAML files can be sourced in two ways:

@@ -143,25 +163,75 @@ All clusters are set to auto-down after the task is done, so no dangling cluster

1. Copy the DAG file to the Airflow DAGs directory.
   ```bash
   cp sky_train_dag.py /path/to/airflow/dags
   # If your Airflow is running on Kubernetes, you may use kubectl cp to copy the file to the pod
   # kubectl cp sky_train_dag.py <airflow-pod-name>:/opt/airflow/dags
   ```
2. Run `airflow dags list` to confirm that the DAG is loaded.
3. Find the DAG in the Airflow UI (typically http://localhost:8080) and enable it. The UI may take a couple of minutes to reflect the changes. If the DAG is paused, force-unpause it with `airflow dags unpause sky_train_dag`.
4. Trigger the DAG from the Airflow UI using the `Trigger DAG` button.
5. Navigate to the run in the Airflow UI to see the DAG progress and logs of each task.

If a task fails, SkyPilot will automatically tear down the cluster it launched.

<p align="center">
  <img src="https://i.imgur.com/HH4y77d.png" width="800">
</p>
<p align="center">
  <img src="https://i.imgur.com/EOR4lzV.png" width="800">
</p>

## Optional: Configure cloud accounts

By default, the SkyPilot task will run using the same cloud credentials that the SkyPilot
API Server has. This may not be ideal if you have
an existing service account for your task.

In this example, we'll use a GCP service account to demonstrate how to use custom credentials.
Refer to the [GCP service account](https://docs.skypilot.co/en/latest/cloud-setup/cloud-permissions/gcp.html#gcp-service-account)
guide for how to set one up.

Once you have the JSON key for your service account, create an Airflow
connection to store the credentials.

```bash
airflow connections add \
    --conn-type google_cloud_platform \
    --conn-extra "{\"keyfile_dict\": \"<YOUR_SERVICE_ACCOUNT_JSON_KEY>\"}" \
    skypilot_gcp_task
```

You can also use the Airflow web UI to add the connection:

<p align="center">
  <img src="https://i.imgur.com/uFOs3G0.png" width="600">
</p>

Next, we will define `data_preprocessing_gcp_sa.yaml`, which contains small modifications to `data_preprocessing.yaml` to use our GCP service account. The key changes are to mount the GCP service account JSON key on the SkyPilot cluster and to activate it with `gcloud auth activate-service-account` (see the full YAML at the end of this page).

We will also need a new task that reads the GCP service account JSON key from our Airflow connection (a sketch follows the snippet below), and we change the preprocess task in our DAG to refer to the new YAML file.

```python
with DAG(dag_id='sky_train_dag', default_args=default_args,
         catchup=False) as dag:
    ...

    # Get GCP credentials
    gcp_service_account_json = get_gcp_service_account_json()

    ...

    preprocess_task = run_sky_task.override(task_id="data_preprocess")(
        base_path,
        'data_preprocessing_gcp_sa.yaml',
        gcp_service_account_json=gcp_service_account_json,
        envs_override=common_envs)

    ...

    bucket_uuid >> gcp_service_account_json >> preprocess_task >> train_task >> eval_task
```
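
The `get_gcp_service_account_json` task used above is not shown in the snippet. A minimal sketch of one way to implement it, assuming the `skypilot_gcp_task` connection created earlier and Airflow's standard `BaseHook`, is:

```python
from airflow.decorators import task
from airflow.hooks.base import BaseHook


@task
def get_gcp_service_account_json() -> str:
    """Return the service account key stored in the `skypilot_gcp_task` connection."""
    conn = BaseHook.get_connection('skypilot_gcp_task')
    # The key was stored under `keyfile_dict` in the connection's extra JSON.
    return conn.extra_dejson['keyfile_dict']
```

One way to consume the returned key is to write it to a temporary file inside `run_sky_task` and pass that path to the YAML via the `GCP_SERVICE_ACCOUNT_JSON_PATH` env, so it gets mounted at `/tmp/gcp-service-account.json`; see `sky_train_dag.py` for how the example actually wires this up.
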
## Future work: a native Airflow Executor built on SkyPilot

Currently this example relies on a helper `run_sky_task` function to wrap the SkyPilot invocation in `@task`, but in the future SkyPilot could provide a native Airflow Executor.
examples/airflow/data_preprocessing_gcp_sa.yaml

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@

```yaml
resources:
  cpus: 1+

envs:
  DATA_BUCKET_NAME: sky-demo-data-test
  DATA_BUCKET_STORE_TYPE: s3
  GCP_SERVICE_ACCOUNT_JSON_PATH: null

file_mounts:
  /data:
    name: $DATA_BUCKET_NAME
    store: $DATA_BUCKET_STORE_TYPE
  /tmp/gcp-service-account.json: $GCP_SERVICE_ACCOUNT_JSON_PATH

setup: |
  echo "Setting up dependencies for data preprocessing..."

  curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-linux-x86_64.tar.gz
  tar -xf google-cloud-cli-linux-x86_64.tar.gz

  ./google-cloud-sdk/install.sh --quiet --path-update true
  source ~/.bashrc
  gcloud auth activate-service-account --key-file=/tmp/gcp-service-account.json

run: |
  echo "Running data preprocessing on behalf of $(gcloud auth list --filter=status:ACTIVE --format="value(account)")..."

  # Generate a few files with random data to simulate data preprocessing
  for i in {0..9}; do
    dd if=/dev/urandom of=/data/file_$i bs=1M count=10
  done

  echo "Data preprocessing completed, wrote to $DATA_BUCKET_NAME"
```

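To iterate on this YAML outside Airflow, a standalone launch along these lines should work; the key path and bucket name below are placeholders, and passing the key path via `--env` relies on the `$GCP_SERVICE_ACCOUNT_JSON_PATH` substitution in `file_mounts` shown above:

```bash
# Standalone test of the GCP service account variant (values are placeholders).
sky launch -c debug-preprocess-gcp data_preprocessing_gcp_sa.yaml \
  --env GCP_SERVICE_ACCOUNT_JSON_PATH=~/keys/my-sa-key.json \
  --env DATA_BUCKET_NAME=sky-demo-data-debug \
  --env DATA_BUCKET_STORE_TYPE=s3
```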