Skip to content

Commit b6c9f83

Browse files
committed
fix toArrow
1 parent 6c13790 commit b6c9f83

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies = [
1717
[dependency-groups]
1818
dev = [
1919
"ipykernel>=6.29.5",
20+
"pyarrow-stubs>=19.4",
2021
"pyspark>=4.0.0",
2122
"pytest>=8.0.0",
2223
"pytest-dotenv>=0.5.2",

tests/test_huggingface_writer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import os
22
import uuid
33

4+
import pyarrow as pa
45
import pytest
56
from pyspark.sql import DataFrame, SparkSession
7+
from pyspark.sql.pandas.types import to_arrow_schema
68
from pyspark.testing import assertDataFrameEqual
79
from pytest_mock import MockerFixture
810

@@ -114,7 +116,7 @@ def test_revision(repo, random_df, api):
114116
)
115117

116118

117-
def test_max_bytes_per_file(spark, mocker: MockerFixture):
119+
def test_max_bytes_per_file(spark: SparkSession, mocker: MockerFixture):
118120
from pyspark_huggingface.huggingface_sink import HuggingFaceDatasetsWriter
119121

120122
repo = "user/test"
@@ -128,5 +130,9 @@ def test_max_bytes_per_file(spark, mocker: MockerFixture):
128130
token="token",
129131
max_bytes_per_file=1,
130132
)
131-
writer.write(iter(df.toArrow().to_batches(max_chunksize=1)))
133+
# Don't use toArrow() because it's not available in pyspark 3.x
134+
arrow_table = pa.Table.from_pylist(
135+
[row.asDict() for row in df.collect()], schema=to_arrow_schema(df.schema)
136+
)
137+
writer.write(iter(arrow_table.to_batches(max_chunksize=1)))
132138
assert api.preupload_lfs_files.call_count == 10

0 commit comments

Comments (0)