
Commit 109707e

updates for test speed improvements (#125)
* updates for test speed improvements
* updated tests
* updated tests
* updated tests
* updated tests
* reverted pytest changes - separate feature
* reverted pytest changes - separate feature
* reverted pytest changes - separate feature
* reverted pytest changes - separate feature
* changed partitioning to run more efficiently on github runner
* changed partitioning to run more efficiently on github runner
* changed partitioning to run more efficiently on github runner
* changed partitioning to run more efficiently on github runner
* changed partitioning to run more efficiently on github runner
* use as query name for spark instance
* changes in response to PR review feedback
* additional coverage
* reverted some review related changes as too much impact on tests
* test cleanup
* changes in response to PR review feedback post review changes post review changes
* corrected rebase issue
1 parent d6b1799 commit 109707e

8 files changed, +228 −259 lines changed

dbldatagen/data_generator.py

Lines changed: 17 additions & 20 deletions
@@ -10,10 +10,9 @@
 import re

 from pyspark.sql.types import LongType, IntegerType, StringType, StructType, StructField, DataType
-
+from .spark_singleton import SparkSingleton
 from .column_generation_spec import ColumnGenerationSpec
 from .datagen_constants import DEFAULT_RANDOM_SEED, RANDOM_SEED_FIXED, RANDOM_SEED_HASH_FIELD_NAME
-from .spark_singleton import SparkSingleton
 from .utils import ensure, topologicalSort, DataGenError, deprecated

 _OLD_MIN_OPTION = 'min'
@@ -31,7 +30,7 @@ class DataGenerator:
     :param rows: = amount of rows to generate
     :param startingId: = starting value for generated seed column
     :param randomSeed: = seed for random number generator
-    :param partitions: = number of partitions to generate
+    :param partitions: = number of partitions to generate, if not provided, uses `spark.sparkContext.defaultParallelism`
     :param verbose: = if `True`, generate verbose output
     :param batchSize: = UDF batch number of rows to pass via Apache Arrow to Pandas UDFs
     :param debug: = if set to True, output debug level of information
@@ -65,7 +64,18 @@ def __init__(self, sparkSession=None, name=None, randomSeedMethod=None,
         self._rowCount = rows
         self.starting_id = startingId
         self.__schema__ = None
-        self.partitions = partitions if partitions is not None else 10
+
+        if sparkSession is None:
+            sparkSession = SparkSingleton.getLocalInstance()
+
+        self.sparkSession = sparkSession
+
+        # if the active Spark session is stopped, you may end up with a valid SparkSession object but the underlying
+        # SparkContext will be invalid
+        assert sparkSession is not None, "Spark session not initialized"
+        assert sparkSession.sparkContext is not None, "Expecting spark session to have valid sparkContext"
+
+        self.partitions = partitions if partitions is not None else sparkSession.sparkContext.defaultParallelism

         # check for old versions of args
         if "starting_id" in kwargs:
@@ -121,20 +131,6 @@ def __init__(self, sparkSession=None, name=None, randomSeedMethod=None,
         self.withColumn(ColumnGenerationSpec.SEED_COLUMN, LongType(), nullable=False, implicit=True, omit=True)
         self._batchSize = batchSize

-        if sparkSession is None:
-            sparkSession = SparkSingleton.getInstance()
-
-        assert sparkSession is not None, "The spark session attribute must be initialized"
-
-        self.sparkSession = sparkSession
-        if sparkSession is None:
-            raise DataGenError("""Spark session not initialized
-
-            The spark session attribute must be initialized in the DataGenerator initialization
-
-            i.e DataGenerator(sparkSession=spark, name="test", ...)
-            """)
-
         # set up use of pandas udfs
         self._setupPandas(batchSize)

@@ -257,7 +253,7 @@ def explain(self, suppressOutput=False):

         output = ["", "Data generation plan", "====================",
                   f"spec=DateGenerator(name={self.name}, rows={self._rowCount}, startingId={self.starting_id}, partitions={self.partitions})"
-                  , ")", "", f"column build order: {self._buildOrder}", "", "build plan:"]
+                  , ")", "", f"column build order: {self._buildOrder}", "", "build plan:"]

         for plan_action in self._buildPlan:
             output.append(" ==> " + plan_action)
@@ -780,7 +776,8 @@ def _getBaseDataFrame(self, startId=0, streaming=False, options=None):
             df1 = df1.withColumnRenamed("id", ColumnGenerationSpec.SEED_COLUMN)

         else:
-            status = (f"Generating streaming data frame with ids from {startId} to {end_id} with {id_partitions} partitions")
+            status = (
+                f"Generating streaming data frame with ids from {startId} to {end_id} with {id_partitions} partitions")
             self.logger.info(status)
             self.executionHistory.append(status)

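Note: with the constructor change above, a DataGenerator built without an explicit `partitions` argument now falls back to `spark.sparkContext.defaultParallelism` instead of a hard-coded 10. A minimal sketch of the new behavior (the local[4] session setup below is an illustrative assumption, not part of this commit):

from pyspark.sql import SparkSession
import dbldatagen as dg

# illustrative local session; defaultParallelism for local[4] is 4
spark = (SparkSession.builder
         .master("local[4]")
         .appName("partitions_default_demo")
         .getOrCreate())

# no partitions argument supplied, so the generator picks up the session's default parallelism
ds = dg.DataGenerator(sparkSession=spark, name="partitions_default_demo", rows=1000)

assert ds.partitions == spark.sparkContext.defaultParallelism  # 4 under local[4]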

dbldatagen/spark_singleton.py

Lines changed: 16 additions & 7 deletions
@@ -10,7 +10,6 @@
 """

 import os
-import math
 import logging
 from pyspark.sql import SparkSession

@@ -28,17 +27,27 @@ def getInstance(cls):
         return SparkSession.builder.getOrCreate()

     @classmethod
-    def getLocalInstance(cls, appName="new Spark session"):
+    def getLocalInstance(cls, appName="new Spark session", useAllCores=True):
         """Create a machine local Spark instance for Datalib.
-        It uses 3/4 of the available cores for the spark session.
+        By default, it uses `n-1` cores of the available cores for the spark session,
+        where `n` is total cores available.

+        :param useAllCores: If `useAllCores` is True, then use all cores rather than `n-1` cores
         :returns: A Spark instance
         """
-        cpu_count = int(math.floor(os.cpu_count() * 0.75))
-        logging.info("cpu count: %d", cpu_count)
+        cpu_count = os.cpu_count()

-        return SparkSession.builder \
-            .master(f"local[{cpu_count}]") \
+        if useAllCores:
+            spark_core_count = cpu_count
+        else:
+            spark_core_count = cpu_count - 1
+
+        logging.info("Spark core count: %d", spark_core_count)
+
+        sparkSession = SparkSession.builder \
+            .master(f"local[{spark_core_count}]") \
             .appName(appName) \
             .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse") \
             .getOrCreate()
+
+        return sparkSession
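With the new `useAllCores` flag, test code can choose between using every core and leaving one core free for the host. A short usage sketch (the app name and core counts are illustrative; the comments assume `os.cpu_count()` reports the machine's core count):

import os
from dbldatagen.spark_singleton import SparkSingleton

# default (useAllCores=True): master is local[os.cpu_count()]
# useAllCores=False:          master is local[os.cpu_count() - 1], leaving one core free
spark = SparkSingleton.getLocalInstance(appName="unit tests", useAllCores=False)

print(spark.sparkContext.defaultParallelism)  # typically os.cpu_count() - 1 here
spark.stop()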
