Skip to content

Commit c354c1b

Browse files
authored
Merge pull request #91 from w4terlaw/fix/90-regression-queue-function
fix: #90
2 parents 26cb9a1 + 39ffc9b commit c354c1b

File tree

5 files changed

+102
-72
lines changed

5 files changed

+102
-72
lines changed

taskiq_faststream/formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def dumps( # type: ignore[override]
2626
:param message: message to send.
2727
:return: Dumped message.
2828
"""
29-
labels = message.labels
29+
labels = message.labels.copy()
3030
labels.pop("schedule", None)
3131
labels.pop("schedule_id", None)
3232

taskiq_faststream/kicker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1+
from typing import Any
2+
13
from taskiq.kicker import AsyncKicker, _FuncParams, _ReturnType
4+
from taskiq.message import TaskiqMessage
25

36

47
class LabelRespectKicker(AsyncKicker[_FuncParams, _ReturnType]):
58
"""Patched kicker doesn't cast labels to str."""
9+
10+
def _prepare_message(self, *args: Any, **kwargs: Any) -> TaskiqMessage:
11+
msg = super()._prepare_message(*args, **kwargs)
12+
msg.labels = self.labels
13+
return msg

tests/messages.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from collections.abc import AsyncIterator, Iterator
2+
3+
message = "Hi!"
4+
5+
6+
def sync_callable_msg() -> str:
7+
return message
8+
9+
10+
async def async_callable_msg() -> str:
11+
return message
12+
13+
14+
async def async_generator_msg() -> AsyncIterator[str]:
15+
yield message
16+
17+
18+
def sync_generator_msg() -> Iterator[str]:
19+
yield message
20+
21+
22+
class _C:
23+
def __call__(self) -> str:
24+
return message
25+
26+
27+
class _AC:
28+
async def __call__(self) -> str:
29+
return message
30+
31+
32+
sync_callable_class_message = _C()
33+
async_callable_class_message = _AC()

tests/test_resolve_message.py

Lines changed: 29 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,34 @@
1-
from collections.abc import AsyncIterator, Iterator
1+
import typing
22

33
import pytest
4+
from faststream.types import SendableMessage
45

56
from taskiq_faststream.utils import resolve_msg
6-
7-
8-
@pytest.mark.anyio
9-
async def test_regular() -> None:
10-
async for m in resolve_msg("msg"):
11-
assert m == "msg"
12-
13-
14-
@pytest.mark.anyio
15-
async def test_sync_callable() -> None:
16-
async for m in resolve_msg(lambda: "msg"):
17-
assert m == "msg"
18-
19-
7+
from tests import messages
8+
9+
10+
@pytest.mark.parametrize(
11+
"msg",
12+
[
13+
messages.message, # regular msg
14+
messages.sync_callable_msg, # sync callable
15+
messages.async_callable_msg, # async callable
16+
messages.sync_generator_msg, # sync generator
17+
messages.async_generator_msg, # async generator
18+
messages.sync_callable_class_message, # sync callable class
19+
messages.async_callable_class_message, # async callable class
20+
],
21+
)
2022
@pytest.mark.anyio
21-
async def test_async_callable() -> None:
22-
async def gen_msg() -> str:
23-
return "msg"
24-
25-
async for m in resolve_msg(gen_msg):
26-
assert m == "msg"
27-
28-
29-
@pytest.mark.anyio
30-
async def test_sync_callable_class() -> None:
31-
class C:
32-
def __init__(self) -> None:
33-
pass
34-
35-
def __call__(self) -> str:
36-
return "msg"
37-
38-
async for m in resolve_msg(C()):
39-
assert m == "msg"
40-
41-
42-
@pytest.mark.anyio
43-
async def test_async_callable_class() -> None:
44-
class C:
45-
def __init__(self) -> None:
46-
pass
47-
48-
async def __call__(self) -> str:
49-
return "msg"
50-
51-
async for m in resolve_msg(C()):
52-
assert m == "msg"
53-
54-
55-
@pytest.mark.anyio
56-
async def test_async_generator() -> None:
57-
async def get_msg() -> AsyncIterator[str]:
58-
yield "msg"
59-
60-
async for m in resolve_msg(get_msg):
61-
assert m == "msg"
62-
63-
64-
@pytest.mark.anyio
65-
async def test_sync_generator() -> None:
66-
def get_msg() -> Iterator[str]:
67-
yield "msg"
68-
69-
async for m in resolve_msg(get_msg):
70-
assert m == "msg"
23+
async def test_resolve_msg(
24+
msg: typing.Union[
25+
None,
26+
SendableMessage,
27+
typing.Callable[[], SendableMessage],
28+
typing.Callable[[], typing.Awaitable[SendableMessage]],
29+
typing.Callable[[], typing.Generator[SendableMessage, None, None]],
30+
typing.Callable[[], typing.AsyncGenerator[SendableMessage, None]],
31+
],
32+
) -> None:
33+
async for m in resolve_msg(msg):
34+
assert m == messages.message

tests/testcase.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
import asyncio
2+
import typing
23
from datetime import datetime, timedelta, timezone
34
from typing import Any
45
from unittest.mock import MagicMock
56

67
import pytest
8+
from faststream.types import SendableMessage
79
from faststream.utils.functions import timeout_scope
810
from freezegun import freeze_time
9-
from taskiq import AsyncBroker, TaskiqScheduler
11+
from taskiq import AsyncBroker
1012
from taskiq.cli.scheduler.args import SchedulerArgs
1113
from taskiq.cli.scheduler.run import run_scheduler
1214
from taskiq.schedule_sources import LabelScheduleSource
1315

1416
from taskiq_faststream import BrokerWrapper, StreamScheduler
17+
from tests import messages
1518

1619

1720
@pytest.mark.anyio
@@ -54,7 +57,7 @@ async def handler(msg: str) -> None:
5457
task = asyncio.create_task(
5558
run_scheduler(
5659
SchedulerArgs(
57-
scheduler=TaskiqScheduler(
60+
scheduler=StreamScheduler(
5861
broker=taskiq_broker,
5962
sources=[LabelScheduleSource(taskiq_broker)],
6063
),
@@ -69,24 +72,44 @@ async def handler(msg: str) -> None:
6972
mock.assert_called_once_with("Hi!")
7073
task.cancel()
7174

75+
@pytest.mark.parametrize(
76+
"msg",
77+
[
78+
messages.message, # regular msg
79+
messages.sync_callable_msg, # sync callable
80+
messages.async_callable_msg, # async callable
81+
messages.sync_generator_msg, # sync generator
82+
messages.async_generator_msg, # async generator
83+
messages.sync_callable_class_message, # sync callable class
84+
messages.async_callable_class_message, # async callable class
85+
],
86+
)
7287
async def test_task_multiple_schedules_by_cron(
7388
self,
7489
subject: str,
7590
broker: Any,
7691
event: asyncio.Event,
92+
msg: typing.Union[
93+
None,
94+
SendableMessage,
95+
typing.Callable[[], SendableMessage],
96+
typing.Callable[[], typing.Awaitable[SendableMessage]],
97+
typing.Callable[[], typing.Generator[SendableMessage, None, None]],
98+
typing.Callable[[], typing.AsyncGenerator[SendableMessage, None]],
99+
],
77100
) -> None:
78101
"""Test cron runs twice via StreamScheduler."""
79102
received_message = []
80103

81104
@broker.subscriber(subject)
82-
async def handler(msg: str) -> None:
83-
received_message.append(msg)
105+
async def handler(message: str) -> None:
106+
received_message.append(message)
84107
event.set()
85108

86109
taskiq_broker = self.build_taskiq_broker(broker)
87110

88111
taskiq_broker.task(
89-
"Hi!",
112+
msg,
90113
**{self.subj_name: subject},
91114
schedule=[
92115
{
@@ -116,4 +139,6 @@ async def handler(msg: str) -> None:
116139

117140
task.cancel()
118141

119-
assert received_message == ["Hi!", "Hi!"], received_message
142+
assert received_message == [messages.message, messages.message], (
143+
received_message
144+
)

0 commit comments

Comments
 (0)