Skip to content

Commit 3bd4629

Browse files
committed
Clear throttle cache on ConsumerThrottleLimit.save
1 parent 3b8cbfe commit 3bd4629

File tree

4 files changed

+49
-4
lines changed

4 files changed

+49
-4
lines changed

main/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@
1212

1313

1414
DURATION_MAPPING = {"minute": 60, "hour": 3600, "day": 86400, "week": 604800}
15+
16+
CONSUMER_THROTTLES_KEY = "consumer_throttles"

main/consumer_throttles.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,12 @@
1212
from django.core.cache import cache as default_cache
1313
from django.core.exceptions import ImproperlyConfigured
1414

15-
from main.constants import DURATION_MAPPING
15+
from main.constants import CONSUMER_THROTTLES_KEY, DURATION_MAPPING
1616
from main.models import ConsumerThrottleLimit
1717

1818
log = logging.getLogger(__name__)
1919

2020

21-
CONSUMER_THROTTLES_KEY = "consumer_throttles"
22-
23-
2421
class AsyncBaseThrottle(ABC):
2522
"""
2623
Abstract class for throttling AsyncConsumer requests.

main/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
Classes related to models for main
33
"""
44

5+
from django.core.cache import cache
56
from django.db.models import CharField, DateTimeField, IntegerField, Model
67
from django.db.models.query import QuerySet
78

9+
from main.constants import CONSUMER_THROTTLES_KEY
810
from main.utils import now_in_utc
911

1012

@@ -77,3 +79,8 @@ def __str__(self):
7779
Auth {self.auth_limit}, \
7880
Anon {self.anon_limit} \
7981
per {self.interval}"
82+
83+
def save(self, **kwargs):
84+
"""Override save to reset the throttles cache"""
85+
cache.delete(CONSUMER_THROTTLES_KEY)
86+
return super().save(**kwargs)

main/models_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""main models tests"""
2+
3+
import pytest
4+
from asgiref.sync import sync_to_async
5+
6+
from main.consumer_throttles import UserScopedRateThrottle
7+
from main.factories import ConsumerThrottleLimitFactory
8+
9+
10+
@pytest.mark.django_db
11+
async def test_throttle_limit_save_reset_cache():
12+
"""Saving changes to the throttle limit model should reset the cache"""
13+
auth_limits = (20, 10)
14+
anon_limits = (10, 5)
15+
intervals = ("minute", "hour")
16+
17+
throttle_limit = await sync_to_async(ConsumerThrottleLimitFactory.create)(
18+
auth_limit=auth_limits[0],
19+
anon_limit=anon_limits[0],
20+
interval=intervals[0],
21+
)
22+
scoped_throttle = UserScopedRateThrottle()
23+
scoped_throttle.scope = throttle_limit.throttle_key
24+
assert await scoped_throttle.get_rate() == {
25+
"throttle_key": throttle_limit.throttle_key,
26+
"auth_limit": auth_limits[0],
27+
"anon_limit": anon_limits[0],
28+
"interval": intervals[0],
29+
}
30+
throttle_limit.auth_limit = auth_limits[1]
31+
throttle_limit.anon_limit = anon_limits[1]
32+
throttle_limit.interval = intervals[1]
33+
await throttle_limit.asave()
34+
assert await scoped_throttle.get_rate() == {
35+
"throttle_key": throttle_limit.throttle_key,
36+
"auth_limit": auth_limits[1],
37+
"anon_limit": anon_limits[1],
38+
"interval": intervals[1],
39+
}

0 commit comments

Comments
 (0)