
Commit 3cd3e9a

wip
1 parent f99882d · commit 3cd3e9a

7 files changed: 92 additions & 8 deletions

README.md

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ details of use and many examples.
 
 Release notes and details of the latest changes for this specific release
 can be found in the GitHub repository
-[here](https://github.com/databrickslabs/dbldatagen/blob/release/v0.4.0/CHANGELOG.md)
+[here](https://github.com/databrickslabs/dbldatagen/blob/release/v0.4.0post1/CHANGELOG.md)
 
 # Installation
 

dbldatagen/_version.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def get_version(version):
     return version_info
 
 
-__version__ = "0.4.0"  # DO NOT EDIT THIS DIRECTLY! It is managed by bumpversion
+__version__ = "0.4.0post1"  # DO NOT EDIT THIS DIRECTLY! It is managed by bumpversion
 __version_info__ = get_version(__version__)
 
 
dbldatagen/data_generator.py

Lines changed: 12 additions & 3 deletions
@@ -226,12 +226,21 @@ def _setupPandas(self, pandasBatchSize):
         self.logger.info("Spark version: %s", self.sparkSession.version)
         if str(self.sparkSession.version).startswith("3"):
             self.logger.info("Using spark 3.x")
-            self.sparkSession.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
+            try:
+                self.sparkSession.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
+            except Exception:  # pylint: disable=broad-exception-caught
+                pass
         else:
-            self.sparkSession.conf.set("spark.sql.execution.arrow.enabled", "true")
+            try:
+                self.sparkSession.conf.set("spark.sql.execution.arrow.enabled", "true")
+            except Exception:  # pylint: disable=broad-exception-caught
+                pass
 
         if self._batchSize is not None:
-            self.sparkSession.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", self._batchSize)
+            try:
+                self.sparkSession.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", self._batchSize)
+            except Exception:  # pylint: disable=broad-exception-caught
+                pass
 
     def _setupLogger(self):
         """Set up logging

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
 author = 'Databricks Inc'
 
 # The full version, including alpha/beta/rc tags
-release = "0.4.0"  # DO NOT EDIT THIS DIRECTLY! It is managed by bumpversion
+release = "0.4.0post1"  # DO NOT EDIT THIS DIRECTLY! It is managed by bumpversion
 
 # -- General configuration ---------------------------------------------------
 
python/.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.0
+current_version = 0.4.0post1
 commit = False
 tag = False
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+){0,1}(?P<release>\D*)(?P<build>\d*)
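The `parse` expression shown above has to accept suffixed versions such as `0.4.0post1`. A quick illustrative check of how that regex decomposes the new version string, using Python's `re` purely for demonstration (bumpversion performs the equivalent parsing internally):

```python
import re

# Same pattern as the `parse` setting in python/.bumpversion.cfg
PATTERN = r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+){0,1}(?P<release>\D*)(?P<build>\d*)"

match = re.match(PATTERN, "0.4.0post1")
print(match.groupdict())
# expected: {'major': '0', 'minor': '4', 'patch': '0', 'release': 'post', 'build': '1'}
```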

setup.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
 
 setuptools.setup(
     name="dbldatagen",
-    version="0.4.0",
+    version="0.4.0post1",
     author="Ronan Stokes, Databricks",
     description="Databricks Labs - PySpark Synthetic Data Generator",
     long_description=long_description,

tests/test_serverless.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+import pytest
+
+import dbldatagen as dg
+
+
+class TestSimulatedServerless:
+    """Serverless operation and other forms of shared spark cloud operation often have restrictions on what
+    features may be used.
+
+    In this set of tests, we'll simulate some of the common restrictions found in Databricks serverless and shared
+    environments to ensure that common operations still work.
+
+    Serverless operations have some of the following restrictions:
+
+    - Spark config settings cannot be written
+
+    """
+
+    @pytest.fixture(scope="class")
+    def serverlessSpark(self):
+        from unittest.mock import MagicMock
+
+        sparkSession = dg.SparkSingleton.getLocalInstance("unit tests")
+
+        oldSetMethod = sparkSession.conf.set
+        oldGetMethod = sparkSession.conf.get
+        sparkSession.conf.set = MagicMock(
+            side_effect=ValueError("Setting value prohibited in simulated serverless env."))
+        sparkSession.conf.get = MagicMock(
+            side_effect=ValueError("Getting value prohibited in simulated serverless env."))
+
+        yield sparkSession
+
+        sparkSession.conf.set = oldSetMethod
+        sparkSession.conf.get = oldGetMethod
+
+    def test_basic_data(self, serverlessSpark):
+        from pyspark.sql.types import FloatType, IntegerType, StringType
+
+        row_count = 1000 * 100
+        column_count = 10
+        testDataSpec = (
+            dg.DataGenerator(serverlessSpark, name="test_data_set1", rows=row_count, partitions=4)
+            .withIdOutput()
+            .withColumn(
+                "r",
+                FloatType(),
+                expr="floor(rand() * 350) * (86400 + 3600)",
+                numColumns=column_count,
+            )
+            .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
+            .withColumn("code2", "integer", minValue=0, maxValue=10, random=True)
+            .withColumn("code3", StringType(), values=["online", "offline", "unknown"])
+            .withColumn(
+                "code4", StringType(), values=["a", "b", "c"], random=True, percentNulls=0.05
+            )
+            .withColumn(
+                "code5", "string", values=["a", "b", "c"], random=True, weights=[9, 1, 1]
+            )
+        )
+
+        dfTestData = testDataSpec.build()
+
+    @pytest.mark.parametrize("providerName, providerOptions", [
+        ("basic/user", {"rows": 50, "partitions": 4, "random": False, "dummyValues": 0}),
+        ("basic/user", {"rows": 100, "partitions": -1, "random": True, "dummyValues": 0})
+    ])
+    def test_basic_user_table_retrieval(self, providerName, providerOptions, serverlessSpark):
+        ds = dg.Datasets(serverlessSpark, providerName).get(**providerOptions)
+        assert ds is not None, f"""expected to get dataset specification for provider `{providerName}`
+            with options: {providerOptions}
+            """
+        df = ds.build()
+
+        assert df.count() >= 0
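The `serverlessSpark` fixture above simulates the restriction by swapping `conf.set` and `conf.get` for `MagicMock`s that raise, so any code path that still writes Spark configuration fails loudly. The same idea can be packaged as a reusable context manager; the sketch below is illustrative only, and `simulated_serverless` is a hypothetical name that is not part of dbldatagen or its test suite.

```python
from contextlib import contextmanager
from unittest.mock import MagicMock


@contextmanager
def simulated_serverless(sparkSession):
    """Temporarily make spark.conf.set / spark.conf.get raise, as a restricted environment might."""
    oldSet, oldGet = sparkSession.conf.set, sparkSession.conf.get
    sparkSession.conf.set = MagicMock(side_effect=ValueError("Setting value prohibited"))
    sparkSession.conf.get = MagicMock(side_effect=ValueError("Getting value prohibited"))
    try:
        yield sparkSession
    finally:
        # Restore the real conf methods once the block exits.
        sparkSession.conf.set, sparkSession.conf.get = oldSet, oldGet


# Usage (assuming an existing SparkSession named `spark`):
# with simulated_serverless(spark) as restrictedSpark:
#     df = dg.Datasets(restrictedSpark, "basic/user").get(rows=50).build()
```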
