Skip to content

Commit c080c19

Browse files
🚨 Fix mypy errors
1 parent 71e2b6d commit c080c19

File tree

11 files changed

+265
-244
lines changed

11 files changed

+265
-244
lines changed

fastapi_gcp_tasks/decorators.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
def task_default_options(**kwargs):
1+
from typing import Any, Callable, TypeVar
2+
3+
F = TypeVar("F", bound=Callable[..., Any])
4+
5+
6+
def task_default_options(**kwargs: Any) -> Callable[[F], F]:
27
"""Wrapper to set default options for a cloud task."""
38

4-
def wrapper(fn):
5-
fn._delay_options = kwargs
9+
def wrapper(fn: F) -> F:
10+
fn._delay_options = kwargs # type: ignore[attr-defined]
611
return fn
712

813
return wrapper

fastapi_gcp_tasks/delayed_route.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Standard Library Imports
2-
from typing import Callable
2+
from typing import Callable, Type
33

44
# Third Party Imports
55
from fastapi.routing import APIRoute
@@ -16,10 +16,10 @@ def DelayedRouteBuilder( # noqa: N802
1616
base_url: str,
1717
queue_path: str,
1818
task_create_timeout: float = 10.0,
19-
pre_create_hook: DelayedTaskHook = None,
20-
client=None,
21-
auto_create_queue=True,
22-
):
19+
pre_create_hook: DelayedTaskHook | None = None,
20+
client: tasks_v2.CloudTasksClient | None = None,
21+
auto_create_queue: bool = True,
22+
) -> Type[APIRoute]:
2323
"""
2424
Returns a Mixin that should be used to override route_class.
2525
@@ -57,11 +57,11 @@ def on_user_create(user_id: str, data: UserData):
5757
class TaskRouteMixin(APIRoute):
5858
def get_route_handler(self) -> Callable:
5959
original_route_handler = super().get_route_handler()
60-
self.endpoint.options = self.delay_options
61-
self.endpoint.delay = self.delay
60+
self.endpoint.options = self.delay_options # type: ignore[attr-defined]
61+
self.endpoint.delay = self.delay # type: ignore[attr-defined]
6262
return original_route_handler
6363

64-
def delay_options(self, **options) -> Delayer:
64+
def delay_options(self, **options: dict) -> Delayer:
6565
delay_opts = {
6666
"base_url": base_url,
6767
"queue_path": queue_path,
@@ -73,12 +73,13 @@ def delay_options(self, **options) -> Delayer:
7373
delay_opts |= self.endpoint._delay_options
7474
delay_opts |= options
7575

76+
# ignoring the type here because the dictionary values are unpacked
7677
return Delayer(
7778
route=self,
78-
**delay_opts,
79+
**delay_opts, # type: ignore[arg-type]
7980
)
8081

81-
def delay(self, **kwargs):
82+
def delay(self, **kwargs: dict) -> tasks_v2.Task:
8283
return self.delay_options().delay(**kwargs)
8384

8485
return TaskRouteMixin

fastapi_gcp_tasks/delayer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Standard Library Imports
22
import datetime
3+
from typing import Any, Iterable
34

45
# Third Party Imports
56
from fastapi.routing import APIRoute
@@ -38,7 +39,7 @@ def __init__(
3839
pre_create_hook: DelayedTaskHook,
3940
task_create_timeout: float = 10.0,
4041
countdown: int = 0,
41-
task_id: str = None,
42+
task_id: str | None = None,
4243
) -> None:
4344
super().__init__(route=route, base_url=base_url)
4445
self.queue_path = queue_path
@@ -50,7 +51,7 @@ def __init__(
5051
self.client = client
5152
self.pre_create_hook = pre_create_hook
5253

53-
def delay(self, **kwargs):
54+
def delay(self, **kwargs: Any) -> tasks_v2.Task:
5455
"""Delay a task on Cloud Tasks."""
5556
# Create http request
5657
request = tasks_v2.HttpRequest()
@@ -76,7 +77,7 @@ def delay(self, **kwargs):
7677

7778
return self.client.create_task(request=request, timeout=self.task_create_timeout)
7879

79-
def _schedule(self):
80+
def _schedule(self) -> timestamp_pb2.Timestamp | None:
8081
if self.countdown is None or self.countdown <= 0:
8182
return None
8283
d = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=self.countdown)
@@ -85,7 +86,7 @@ def _schedule(self):
8586
return timestamp
8687

8788

88-
def _task_method(methods):
89+
def _task_method(methods: Iterable[str]) -> tasks_v2.HttpMethod:
8990
method_map = {
9091
"POST": tasks_v2.HttpMethod.POST,
9192
"GET": tasks_v2.HttpMethod.GET,
@@ -102,4 +103,4 @@ def _task_method(methods):
102103
method = method_map.get(methods[0])
103104
if method is None:
104105
raise BadMethodError(f"Unknown method {methods[0]}")
105-
return method
106+
return tasks_v2.HttpMethod(method)

fastapi_gcp_tasks/dependencies.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
# Standard Library Imports
2-
import typing
32
from datetime import datetime
3+
from typing import Any, Callable
44

55
# Third Party Imports
66
from fastapi import Depends, Header, HTTPException
77

88

9-
def max_retries(count: int = 20):
10-
"""Raises a http exception (with status 200) after max retries are exhausted."""
9+
def max_retries(count: int = 20) -> Callable[[Any], bool]:
10+
"""Raises an http exception (with status 200) after max retries are exhausted."""
1111

1212
def retries_dep(meta: CloudTasksHeaders = Depends()) -> bool:
1313
# count starts from 0 so equality check is required
1414
if meta.retry_count >= count:
1515
raise HTTPException(status_code=200, detail="Max retries exhausted")
16+
return True
1617

1718
return retries_dep
1819

@@ -26,13 +27,13 @@ class CloudTasksHeaders:
2627

2728
def __init__(
2829
self,
29-
x_cloudtasks_taskretrycount: typing.Optional[int] = Header(0),
30-
x_cloudtasks_taskexecutioncount: typing.Optional[int] = Header(0),
31-
x_cloudtasks_queuename: typing.Optional[str] = Header(""),
32-
x_cloudtasks_taskname: typing.Optional[str] = Header(""),
33-
x_cloudtasks_tasketa: typing.Optional[float] = Header(0),
34-
x_cloudtasks_taskpreviousresponse: typing.Optional[int] = Header(0),
35-
x_cloudtasks_taskretryreason: typing.Optional[str] = Header(""),
30+
x_cloudtasks_taskretrycount: int = Header(0),
31+
x_cloudtasks_taskexecutioncount: int = Header(0),
32+
x_cloudtasks_queuename: str = Header(""),
33+
x_cloudtasks_taskname: str = Header(""),
34+
x_cloudtasks_tasketa: float = Header(0),
35+
x_cloudtasks_taskpreviousresponse: int = Header(0),
36+
x_cloudtasks_taskretryreason: str = Header(""),
3637
) -> None:
3738
self.retry_count = x_cloudtasks_taskretrycount
3839
self.execution_count = x_cloudtasks_taskexecutioncount

fastapi_gcp_tasks/hooks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Standard Library Imports
2-
from typing import Callable
2+
from typing import Any, Callable
33

44
# Third Party Imports
55
from google.cloud import scheduler_v1, tasks_v2
@@ -9,15 +9,15 @@
99
ScheduledHook = Callable[[scheduler_v1.CreateJobRequest], scheduler_v1.CreateJobRequest]
1010

1111

12-
def noop_hook(request):
12+
def noop_hook(request: Any) -> Any:
1313
"""Inspired by https://github.com/kelseyhightower/nocode."""
1414
return request
1515

1616

17-
def chained_hook(*hooks):
17+
def chained_hook(*hooks: Callable[[Any], Any]) -> Callable[[Any], Any]:
1818
"""Call all hooks sequentially with the result from the previous hook."""
1919

20-
def chain(request):
20+
def chain(request: Any) -> Any:
2121
for hook in hooks:
2222
request = hook(request)
2323
return request

fastapi_gcp_tasks/requester.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Standard Library Imports
2-
from typing import Dict, List, Tuple
2+
from typing import Any, Dict, List, Tuple
33
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
44

55
# Third Party Imports
@@ -16,7 +16,7 @@
1616
import ujson as json
1717
except ImportError:
1818
# Standard Library Imports
19-
import json
19+
import json # type: ignore[no-redef]
2020

2121

2222
class Requester:
@@ -39,7 +39,7 @@ def __init__(
3939
self.route = route
4040
self.base_url = base_url.rstrip("/")
4141

42-
def _headers(self, *, values):
42+
def _headers(self, *, values: Dict[str, Any]) -> Dict[str, str]:
4343
headers = _err_val(request_params_to_args(self.route.dependant.header_params, values))
4444
cookies = _err_val(request_params_to_args(self.route.dependant.cookie_params, values))
4545
if len(cookies) > 0:
@@ -49,7 +49,7 @@ def _headers(self, *, values):
4949
# Always send string headers and skip all headers which are supposed to be sent by cloudtasks
5050
return {str(k): str(v) for (k, v) in headers.items() if not str(k).startswith("x_cloudtasks_")}
5151

52-
def _url(self, *, values):
52+
def _url(self, *, values: Dict[str, Any]) -> str:
5353
route = self.route
5454
path_values = _err_val(request_params_to_args(route.dependant.path_params, values))
5555
for name, converter in route.param_convertors.items():
@@ -80,11 +80,11 @@ def _url(self, *, values):
8080
url_parts[4] = urlencode(query)
8181
return urlunparse(url_parts)
8282

83-
def _body(self, *, values):
83+
def _body(self, *, values: Dict[str, Any]) -> bytes | None:
8484
body = None
8585
body_field = self.route.body_field
8686
if body_field and body_field.name:
87-
got_body = values.get(body_field.name, None)
87+
got_body = values.get(body_field.name)
8888
if got_body is None:
8989
if body_field.required:
9090
raise MissingParamError(name=body_field.name)
@@ -95,7 +95,7 @@ def _body(self, *, values):
9595
return body
9696

9797

98-
def _err_val(resp: Tuple[Dict, List[ErrorWrapper]]):
98+
def _err_val(resp: Tuple[Dict, List[ErrorWrapper]]) -> Dict:
9999
values, errors = resp
100100

101101
if len(errors) != 0:

fastapi_gcp_tasks/scheduled_route.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Standard Library Imports
2-
from typing import Callable
2+
from typing import Callable, Type
33

44
# Third Party Imports
55
from fastapi.routing import APIRoute
@@ -15,9 +15,9 @@ def ScheduledRouteBuilder( # noqa: N802
1515
base_url: str,
1616
location_path: str,
1717
job_create_timeout: float = 10.0,
18-
pre_create_hook: ScheduledHook = None,
19-
client=None,
20-
):
18+
pre_create_hook: ScheduledHook | None = None,
19+
client: scheduler_v1.CloudSchedulerClient | None = None,
20+
) -> Type[APIRoute]:
2121
"""
2222
Returns a Mixin that should be used to override route_class.
2323
@@ -47,10 +47,10 @@ def simple_scheduled_task():
4747
class ScheduledRouteMixin(APIRoute):
4848
def get_route_handler(self) -> Callable:
4949
original_route_handler = super().get_route_handler()
50-
self.endpoint.scheduler = self.scheduler_options
50+
self.endpoint.scheduler = self.scheduler_options # type: ignore[attr-defined]
5151
return original_route_handler
5252

53-
def scheduler_options(self, *, name, schedule, **options) -> Scheduler:
53+
def scheduler_options(self, *, name: str, schedule: str, **options: dict) -> Scheduler:
5454
scheduler_opts = {
5555
"base_url": base_url,
5656
"location_path": location_path,
@@ -60,6 +60,8 @@ def scheduler_options(self, *, name, schedule, **options) -> Scheduler:
6060
"name": name,
6161
"schedule": schedule,
6262
} | options
63-
return Scheduler(route=self, **scheduler_opts)
63+
64+
# ignoring the type here because the dictionary values are unpacked
65+
return Scheduler(route=self, **scheduler_opts) # type: ignore[arg-type]
6466

6567
return ScheduledRouteMixin

fastapi_gcp_tasks/scheduler.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Standard Library Imports
2+
from typing import Any, Iterable
23

34
# Third Party Imports
45
from fastapi.routing import APIRoute
@@ -41,7 +42,7 @@ def __init__(
4142
pre_create_hook: ScheduledHook,
4243
name: str = "",
4344
job_create_timeout: float = 10.0,
44-
retry_config: scheduler_v1.RetryConfig = None,
45+
retry_config: scheduler_v1.RetryConfig | None = None,
4546
time_zone: str = "UTC",
4647
force: bool = False,
4748
) -> None:
@@ -73,7 +74,7 @@ def __init__(
7374
self.pre_create_hook = pre_create_hook
7475
self.force = force
7576

76-
def schedule(self, **kwargs):
77+
def schedule(self, **kwargs: Any) -> None:
7778
"""Schedule a job on Cloud Scheduler."""
7879
# Create http request
7980
request = scheduler_v1.HttpTarget()
@@ -103,15 +104,15 @@ def schedule(self, **kwargs):
103104
self.delete()
104105
self.client.create_job(request=request, timeout=self.job_create_timeout)
105106

106-
def _has_changed(self, request: scheduler_v1.CreateJobRequest):
107+
def _has_changed(self, request: scheduler_v1.CreateJobRequest) -> bool:
107108
try:
108109
job = self.client.get_job(name=request.job.name)
109110
# Remove things that are either output only or GCP adds by default
110-
job.user_update_time = None
111-
job.state = None
111+
job.user_update_time = None # type: ignore[assignment]
112+
job.state = None # type: ignore[assignment]
112113
job.status = None
113-
job.last_attempt_time = None
114-
job.schedule_time = None
114+
job.last_attempt_time = None # type: ignore[assignment]
115+
job.schedule_time = None # type: ignore[assignment]
115116
del job.http_target.headers["User-Agent"]
116117
# Proto compare works directly with `__eq__`
117118
return job != request.job
@@ -120,7 +121,7 @@ def _has_changed(self, request: scheduler_v1.CreateJobRequest):
120121
return True
121122
return False
122123

123-
def delete(self):
124+
def delete(self) -> bool | Exception:
124125
"""Delete the job from the scheduler if it exists."""
125126
# We return true or exception because you could have the delete code on multiple instances
126127
try:
@@ -131,7 +132,7 @@ def delete(self):
131132
return ex
132133

133134

134-
def _scheduler_method(methods):
135+
def _scheduler_method(methods: Iterable[str]) -> scheduler_v1.HttpMethod:
135136
method_map = {
136137
"POST": scheduler_v1.HttpMethod.POST,
137138
"GET": scheduler_v1.HttpMethod.GET,
@@ -148,4 +149,4 @@ def _scheduler_method(methods):
148149
method = method_map.get(methods[0])
149150
if method is None:
150151
raise BadMethodError(f"Unknown method {methods[0]}")
151-
return method
152+
return scheduler_v1.HttpMethod(method)

fastapi_gcp_tasks/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
# Third Party Imports
2+
from typing import Any
3+
24
import grpc
35
from google.api_core.exceptions import AlreadyExists
46
from google.cloud import scheduler_v1, tasks_v2
57
from google.cloud.tasks_v2.services.cloud_tasks import transports
68

79

8-
def location_path(*, project: str, location: str, **ignored):
10+
def location_path(*, project: str, location: str) -> str:
911
"""Helper function to construct a location path for Cloud Scheduler."""
1012
return scheduler_v1.CloudSchedulerClient.common_location_path(project=project, location=location)
1113

1214

13-
def queue_path(*, project: str, location: str, queue: str):
15+
def queue_path(*, project: str, location: str, queue: str) -> str:
1416
"""Helper function to construct a queue path for Cloud Tasks."""
1517
return tasks_v2.CloudTasksClient.queue_path(project=project, location=location, queue=queue)
1618

@@ -19,8 +21,8 @@ def ensure_queue(
1921
*,
2022
client: tasks_v2.CloudTasksClient,
2123
path: str,
22-
**kwargs,
23-
):
24+
**kwargs: Any,
25+
) -> None:
2426
"""
2527
Helper function to ensure a Cloud Tasks queue exists.
2628
@@ -39,7 +41,7 @@ def ensure_queue(
3941
pass
4042

4143

42-
def emulator_client(*, host="localhost:8123"):
44+
def emulator_client(*, host: str = "localhost:8123") -> tasks_v2.CloudTasksClient:
4345
"""Helper function to create a CloudTasksClient from an emulator host."""
4446
channel = grpc.insecure_channel(host)
4547
transport = transports.CloudTasksGrpcTransport(channel=channel)

0 commit comments

Comments
 (0)