|
1 | 1 | import logging
|
2 |
| -import unittest |
| 2 | +import sys |
3 | 3 | from datetime import timedelta
|
| 4 | +import pytest |
4 | 5 |
|
5 | 6 | from dbldatagen import ensure, mkBoundsList, coalesce_values, deprecated, SparkSingleton, \
|
6 | 7 | parse_time_interval, DataGenError
|
7 | 8 |
|
8 | 9 | spark = SparkSingleton.getLocalInstance("unit tests")
|
9 | 10 |
|
10 | 11 |
|
11 |
| -class TestUtils(unittest.TestCase): |
| 12 | +class TestUtils: |
12 | 13 | x = 1
|
13 | 14 |
|
14 |
| - def setUp(self): |
15 |
| - print("setting up") |
16 |
| - FORMAT = '%(asctime)-15s %(message)s' |
17 |
| - logging.basicConfig(format=FORMAT) |
18 |
| - |
19 |
| - @classmethod |
20 |
| - def setUpClass(cls): |
21 |
| - pass |
| 15 | + @pytest.fixture(autouse=True) |
| 16 | + def setupLogger(self): |
| 17 | + self.logger = logging.getLogger("TestUtils") |
22 | 18 |
|
23 | 19 | @deprecated("testing deprecated")
|
24 | 20 | def testDeprecatedMethod(self):
|
25 | 21 | pass
|
26 | 22 |
|
27 |
| - @unittest.expectedFailure |
28 | 23 | def test_ensure(self):
|
29 |
| - ensure(1 == 2, "Expected error") |
| 24 | + with pytest.raises(Exception): |
| 25 | + ensure(1 == 2, "Expected error") |
30 | 26 |
|
31 |
| - def testMkBoundsList1(self): |
| 27 | + def test_mkBoundsList1(self): |
32 | 28 | """ Test utils mkBoundsList"""
|
33 | 29 | test = mkBoundsList(None, 1)
|
34 | 30 |
|
35 |
| - self.assertEqual(len(test), 2) |
| 31 | + assert len(test) == 2 |
36 | 32 |
|
37 | 33 | test2 = mkBoundsList(None, [1, 1])
|
38 | 34 |
|
39 |
| - self.assertEqual(len(test2), 2) |
| 35 | + assert len(test2) == 2 |
40 | 36 |
|
41 |
| - def testCoalesce(self): |
| 37 | + @pytest.mark.parametrize("test_input,expected", |
| 38 | + [ |
| 39 | + ([None, 1], 1), |
| 40 | + ([2, 1], 2), |
| 41 | + ([3, None, 1], 3), |
| 42 | + ([None, None, None], None), |
| 43 | + ]) |
| 44 | + def test_coalesce(self, test_input, expected): |
42 | 45 | """ Test utils coalesce function"""
|
43 |
| - result = coalesce_values(None, 1) |
44 |
| - |
45 |
| - self.assertEqual(result, 1) |
46 |
| - |
47 |
| - result2 = coalesce_values(3, None, 1) |
48 |
| - |
49 |
| - self.assertEqual(result2, 3) |
50 |
| - |
51 |
| - result3 = coalesce_values(None, None, None) |
52 |
| - |
53 |
| - self.assertIsNone(result3) |
54 |
| - |
55 |
| - def testParseTimeInterval1(self): |
56 |
| - interval = parse_time_interval("1 hours") |
57 |
| - self.assertEqual(timedelta(hours=1), interval) |
| 46 | + result = coalesce_values(*test_input) |
| 47 | + assert result == expected |
| 48 | + |
| 49 | + @pytest.mark.parametrize("test_input,expected", |
| 50 | + [ |
| 51 | + ("1 hours, minutes = 2", timedelta(hours=1, minutes=2)), |
| 52 | + ("4 days, 1 hours, 2 minutes", timedelta(days=4, hours=1, minutes=2)), |
| 53 | + ("days=4, hours=1, minutes=2", timedelta(days=4, hours=1, minutes=2)), |
| 54 | + ("1 hours, 2 seconds", timedelta(hours=1, seconds=2)), |
| 55 | + ("1 hours, 2 minutes", timedelta(hours=1, minutes=2)), |
| 56 | + ("1 hours", timedelta(hours=1)), |
| 57 | + ("1 hour", timedelta(hours=1)), |
| 58 | + ("1 hour, 1 second", timedelta(hours=1, seconds=1)), |
| 59 | + ("1 hour, 10 milliseconds", timedelta(hours=1, milliseconds=10)), |
| 60 | + ("1 hour, 10 microseconds", timedelta(hours=1, microseconds=10)), |
| 61 | + ("1 year, 4 weeks", timedelta(weeks=56)) |
| 62 | + ]) |
| 63 | + def testParseTimeInterval2b(self, test_input, expected): |
| 64 | + interval = parse_time_interval(test_input) |
| 65 | + assert expected == interval |
58 | 66 |
|
59 |
| - def testParseTimeInterval2(self): |
60 |
| - interval = parse_time_interval("1 hours, 2 seconds") |
61 |
| - self.assertEqual(timedelta(hours=1, seconds=2), interval) |
62 |
| - |
63 |
| - def testParseTimeInterval3(self): |
64 |
| - interval = parse_time_interval("1 hours, 2 minutes") |
65 |
| - self.assertEqual(timedelta(hours=1, minutes=2), interval) |
66 |
| - |
67 |
| - def testParseTimeInterval4(self): |
68 |
| - interval = parse_time_interval("4 days, 1 hours, 2 minutes") |
69 |
| - self.assertEqual(timedelta(days=4, hours=1, minutes=2), interval) |
70 |
| - |
71 |
| - def testParseTimeInterval1a(self): |
72 |
| - interval = parse_time_interval("hours=1") |
73 |
| - self.assertEqual(timedelta(hours=1), interval) |
74 |
| - |
75 |
| - def testParseTimeInterval2a(self): |
76 |
| - interval = parse_time_interval("hours=1, seconds = 2") |
77 |
| - self.assertEqual(timedelta(hours=1, seconds=2), interval) |
| 67 | + def testDatagenExceptionObject(self): |
| 68 | + testException = DataGenError("testing") |
78 | 69 |
|
79 |
| - def testParseTimeInterval3a(self): |
80 |
| - interval = parse_time_interval("1 hours, minutes = 2") |
81 |
| - self.assertEqual(timedelta(hours=1, minutes=2), interval) |
| 70 | + assert testException is not None |
82 | 71 |
|
83 |
| - def testParseTimeInterval4a(self): |
84 |
| - interval = parse_time_interval("days=4, hours=1, minutes=2") |
85 |
| - self.assertEqual(timedelta(days=4, hours=1, minutes=2), interval) |
| 72 | + assert type(repr(testException)) is str |
| 73 | + self.logger.info(repr(testException)) |
86 | 74 |
|
87 |
| - def testDatagenExceptionObject(self): |
88 |
| - testException = DataGenError("testing") |
| 75 | + assert type(str(testException)) is str |
| 76 | + self.logger.info(str(testException)) |
89 | 77 |
|
90 |
| - self.assertIsNotNone(testException) |
91 | 78 |
|
92 |
| - print("error has repr", repr(testException)) |
93 |
| - print("error has str", str(testException)) |
94 | 79 |
|
95 | 80 | # run the tests
|
96 | 81 | # if __name__ == '__main__':
|
|
0 commit comments