diff --git a/pandera/backends/pyspark/builtin_checks.py b/pandera/backends/pyspark/builtin_checks.py index 4d30eb6f3..514f475c1 100644 --- a/pandera/backends/pyspark/builtin_checks.py +++ b/pandera/backends/pyspark/builtin_checks.py @@ -1,9 +1,10 @@ """PySpark implementation of built-in checks""" -from typing import Any, Iterable, TypeVar +import re +from typing import Any, Iterable, Optional, TypeVar import pyspark.sql.types as pst -from pyspark.sql.functions import col +from pyspark.sql.functions import col, length, lit import pandera.strategies as st from pandera.api.extensions import register_builtin_check @@ -328,3 +329,28 @@ def str_endswith(data: PysparkDataframeColumnObject, string: str) -> bool: """ cond = col(data.column_name).endswith(string) return data.dataframe.filter(~cond).limit(1).count() == 0 + + +@register_builtin_check( + error="str_length({min_value}, {max_value})", +) +@register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE)) +def str_length( + data: PysparkDataframeColumnObject, + min_value: Optional[int] = None, + max_value: Optional[int] = None, +) -> bool: + """Ensure that the length of strings in a column is within a specified range.""" + if min_value is None and max_value is None: + raise ValueError( + "Must provide at least one of 'min_value' and 'max_value'" + ) + + str_len = length(col(data.column_name)) + cond = lit(True) + if min_value is not None: + cond = cond & (str_len >= min_value) + if max_value is not None: + cond = cond & (str_len <= max_value) + + return data.dataframe.filter(~cond).limit(1).count() == 0