
Commit 861fec3

Use uv project manager and add CI tests (#12)
* migrate to uv
* add ci test
* disable fail fast
* fix python version
* fix
* import explicitly
* fix python version
* fix toArrow
* lock
* add myself to author list and update readme

1 parent 3605eeb · commit 861fec3

7 files changed (+2167, -19 lines)

.github/workflows/ci.yml

Lines changed: 42 additions & 0 deletions
```diff
@@ -0,0 +1,42 @@
+name: CI
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ['3.9', '3.13']
+        packages: [['pyspark>=4.0.0'], ['pyspark==3.5.6', 'numpy<2.0.0']]
+        exclude:
+          - python-version: '3.13'
+            packages: ['pyspark==3.5.6', 'numpy<2.0.0']
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install uv
+        run: |
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+          echo "$HOME/.cargo/bin" >> $GITHUB_PATH
+
+      - name: Install dependencies
+        run: |
+          echo "${{ matrix.python-version }}" > .python-version
+          uv add --dev "${{ join(matrix.packages, '" "') }}"
+          uv sync
+
+      - name: Run tests
+        run: |
+          uv run pytest
```
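
The `join()` expression in the "Install dependencies" step is worth unpacking. A minimal sketch in plain Python (not part of the commit) of how it expands for the pyspark 3.x leg of the matrix:

```python
# GitHub Actions' join() concatenates list items with the given separator;
# the quotes around the ${{ }} expression in the run step complete the
# quoting, so each package spec ends up individually quoted.
packages = ["pyspark==3.5.6", "numpy<2.0.0"]  # one leg of the matrix
command = 'uv add --dev "' + '" "'.join(packages) + '"'
print(command)  # uv add --dev "pyspark==3.5.6" "numpy<2.0.0"
```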

.python-version

Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+3.9
```

README.md

Lines changed: 12 additions & 1 deletion
````diff
@@ -43,7 +43,7 @@ Save to Hugging Face:
 df.write.format("huggingface").save("username/my_dataset")
 # Or pass a token manually
 df.write.format("huggingface").option("token", "hf_xxx").save("username/my_dataset")
-```
+```
 
 ## Advanced
 
@@ -91,3 +91,14 @@ huggingface datasource enabled for pyspark 3.x.x (backport from pyspark 4)
 
 The import is only necessary on Spark 3.x to enable the backport.
 Spark 4 automatically imports `pyspark_huggingface` as soon as it is installed, and registers the "huggingface" data source.
+
+
+## Development
+
+[Install uv](https://docs.astral.sh/uv/getting-started/installation/) if not already done.
+
+Then, from the project root directory, sync dependencies and run tests.
+```
+uv sync
+uv run pytest
+```
````

pyproject.toml

Lines changed: 25 additions & 16 deletions
```diff
@@ -1,24 +1,33 @@
-[tool.poetry]
+[project]
 name = "pyspark_huggingface"
 version = "1.0.0"
 description = "A DataSource for reading and writing HuggingFace Datasets in Spark"
-authors = ["allisonwang-db <[email protected]>", "lhoestq <[email protected]>"]
-license = "Apache License 2.0"
+authors = [
+    {name = "allisonwang-db", email = "[email protected]"},
+    {name = "lhoestq", email = "[email protected]"},
+    {name = "wengh", email = "[email protected]"},
+]
+license = {text = "Apache License 2.0"}
 readme = "README.md"
-packages = [
-    { include = "pyspark_huggingface" },
+requires-python = ">=3.9"
+dependencies = [
+    "datasets>=3.2",
+    "huggingface-hub>=0.27.1",
 ]
 
-[tool.poetry.dependencies]
-python = "^3.9"
-datasets = "^3.2"
-huggingface_hub = "^0.27.1"
-
-[tool.poetry.group.dev.dependencies]
-pytest = "^8.0.0"
-pytest-dotenv = "^0.5.2"
-pytest-mock = "^3.14.0"
+[dependency-groups]
+dev = [
+    "ipykernel>=6.29.5",
+    "pyarrow-stubs>=19.4",
+    "pyspark>=4.0.0",
+    "pytest>=8.0.0",
+    "pytest-dotenv>=0.5.2",
+    "pytest-mock>=3.14.0",
+]
 
 [build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
+requires = ["uv_build>=0.7.3,<0.8"]
+build-backend = "uv_build"
+
+[tool.uv.build-backend]
+module-root = ""
```
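
One detail worth noting: setting `module-root = ""` tells `uv_build` to look for the `pyspark_huggingface` module at the repository root rather than under the default `src/` directory, matching the flat layout this project already used (previously expressed through Poetry's `packages` list).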

tests/test_huggingface.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -1,6 +1,9 @@
 import pytest
 from pyspark.sql import SparkSession
 
+import pyspark_huggingface  # noqa: F401
+
+
 @pytest.fixture
 def spark():
     spark = SparkSession.builder.getOrCreate()
```
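
The explicit import is what registers the data source on Spark 3.x; on Spark 4 the package is auto-imported, as the README notes. A minimal usage sketch (not part of the commit; the repo id is a placeholder):

```python
# Sketch: on Spark 3.x, importing the package registers the "huggingface"
# data source; the noqa marker silences the "unused import" lint warning.
import pyspark_huggingface  # noqa: F401
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# "username/my_dataset" is a placeholder dataset repo id.
df = spark.read.format("huggingface").load("username/my_dataset")
```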

tests/test_huggingface_writer.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -1,11 +1,15 @@
 import os
 import uuid
 
+import pyarrow as pa
 import pytest
 from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.pandas.types import to_arrow_schema
 from pyspark.testing import assertDataFrameEqual
 from pytest_mock import MockerFixture
 
+import pyspark_huggingface  # noqa: F401
+
 # ============== Fixtures & Helpers ==============
 
 
@@ -16,6 +20,8 @@ def spark():
 
 
 def token():
+    if "HF_TOKEN" not in os.environ:
+        pytest.skip("HF_TOKEN environment variable is not set")
     return os.environ["HF_TOKEN"]
 
 
@@ -110,7 +116,7 @@ def test_revision(repo, random_df, api):
     )
 
 
-def test_max_bytes_per_file(spark, mocker: MockerFixture):
+def test_max_bytes_per_file(spark: SparkSession, mocker: MockerFixture):
     from pyspark_huggingface.huggingface_sink import HuggingFaceDatasetsWriter
 
     repo = "user/test"
@@ -124,5 +130,9 @@ def test_max_bytes_per_file(spark, mocker: MockerFixture):
         token="token",
         max_bytes_per_file=1,
     )
-    writer.write(iter(df.toArrow().to_batches(max_chunksize=1)))
+    # Don't use toArrow() because it's not available in pyspark 3.x
+    arrow_table = pa.Table.from_pylist(
+        [row.asDict() for row in df.collect()], schema=to_arrow_schema(df.schema)
+    )
+    writer.write(iter(arrow_table.to_batches(max_chunksize=1)))
     assert api.preupload_lfs_files.call_count == 10
```
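
The `toArrow()` workaround reads naturally as a small standalone helper; a minimal sketch (not part of the commit, the helper name is mine), assuming the DataFrame is small enough to collect onto the driver:

```python
# Sketch of the pattern above: pyspark 4 has DataFrame.toArrow(), pyspark 3.x
# does not, so build the pyarrow Table manually from collected rows plus the
# schema converted via to_arrow_schema().
import pyarrow as pa
from pyspark.sql import DataFrame
from pyspark.sql.pandas.types import to_arrow_schema


def collect_to_arrow(df: DataFrame) -> pa.Table:
    """Collect a (small) Spark DataFrame into a pyarrow Table."""
    return pa.Table.from_pylist(
        [row.asDict() for row in df.collect()],
        schema=to_arrow_schema(df.schema),
    )
```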
