Skip to content

Commit 0640fc1

Browse files
committed
add additional tests for data stream
1 parent b52e7ce commit 0640fc1

File tree

1 file changed

+103
-1
lines changed

1 file changed

+103
-1
lines changed

tests/test_data_stream.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def simple_data():
2525
def long_data():
2626
return pd.DataFrame(
2727
{
28+
"time": [0, 1, 2, 3, 4],
2829
"A": [1, 2, 3, 4, 5],
2930
"B": [5, 4, 3, 2, 1],
3031
}
@@ -33,7 +34,9 @@ def long_data():
3334

3435
@pytest.fixture
3536
def stationary_data():
36-
return pd.DataFrame({"A": [1, 1, 1, 1, 1], "B": [2, 2, 2, 2, 2]})
37+
return pd.DataFrame(
38+
{"time": [0, 1, 2, 3, 4], "A": [1, 1, 1, 1, 1], "B": [2, 2, 2, 2, 2]}
39+
)
3740

3841

3942
@pytest.fixture
@@ -48,6 +51,11 @@ def nan_data():
4851
return pd.DataFrame({"A": [None, None, None]})
4952

5053

54+
@pytest.fixture
55+
def no_valid_data():
56+
return pd.DataFrame({"time": [0, 1], "A": [1, 2]})
57+
58+
5159
# Test DataStream initialization
5260
# =============================================================================
5361

@@ -196,6 +204,14 @@ def test_trim_invalid_method(trim_data):
196204
ds.trim(column_name="A", method="invalid_method")
197205

198206

207+
def test_trim_missing_threshold(long_data):
208+
ds = DataStream(long_data)
209+
with pytest.raises(
210+
ValueError, match="Threshold must be specified for the 'threshold' method."
211+
):
212+
ds.trim(column_name="A", method="threshold")
213+
214+
199215
# Test Compute Statistics
200216
# =============================================================================
201217

@@ -453,3 +469,89 @@ def test_is_not_stationary(long_data):
453469
expected = {"A": False}
454470
print(not_stationary)
455471
assert not_stationary == expected
472+
473+
474+
# Test Head
475+
# =============================================================================
476+
477+
478+
def test_head(long_data):
479+
ds = DataStream(long_data)
480+
result = ds.head(5)
481+
expected = pd.DataFrame(
482+
{
483+
"time": [0, 1, 2, 3, 4],
484+
"A": [1, 2, 3, 4, 5],
485+
"B": [5, 4, 3, 2, 1],
486+
}
487+
)
488+
pd.testing.assert_frame_equal(result, expected)
489+
490+
491+
# Test Process Column
492+
# =============================================================================
493+
494+
495+
def test_process_column_missing_method(simple_data):
496+
ds = DataStream(simple_data)
497+
with pytest.raises(
498+
ValueError, match="Invalid method. Choose 'sliding' or 'non-overlapping'."
499+
):
500+
ds._process_column(column_data="A", estimated_window=1, method="invalid_method")
501+
502+
503+
# Test Find Steady State Std
504+
# =============================================================================
505+
506+
507+
def test_find_steady_state_std(trim_data):
508+
ds = DataStream(trim_data)
509+
result = ds.find_steady_state_std(data=ds.df, column_name="A", window_size=1)
510+
expected = 0
511+
assert result == expected
512+
513+
514+
def test_find_steady_state_std_non_robust(trim_data):
515+
ds = DataStream(trim_data)
516+
result = ds.find_steady_state_std(
517+
data=ds.df, column_name="A", window_size=2, robust=False
518+
)
519+
expected = 3
520+
assert result == expected
521+
522+
523+
def test_find_steady_state_not_valid(no_valid_data):
524+
ds = DataStream(no_valid_data)
525+
result = ds.find_steady_state_std(
526+
data=ds.df, column_name=["time", "A"], window_size=1
527+
)
528+
assert result is None
529+
530+
531+
# Test Find Steady State Rolling Variance
532+
# =============================================================================
533+
534+
535+
def test_find_steady_state_rolling_variance_stationary(stationary_data):
536+
ds = DataStream(stationary_data)
537+
result = ds.find_steady_state_rolling_variance(
538+
data=ds.df, column_name="A", window_size=3
539+
)
540+
print(result)
541+
assert result is None
542+
543+
544+
def test_find_steady_state_none_rolling_variance(long_data):
545+
ds = DataStream(long_data)
546+
result = ds.find_steady_state_rolling_variance(
547+
data=long_data, column_name="A", window_size=3, threshold=0.1
548+
)
549+
assert result is None
550+
551+
552+
def test_find_steady_state_rolling_variance_not_valid(no_valid_data):
553+
ds = DataStream(no_valid_data)
554+
result = ds.find_steady_state_rolling_variance(
555+
data=ds.df, column_name="A", window_size=1
556+
)
557+
assert result is None

0 commit comments

Comments
 (0)