Skip to content

Commit 43cc21c

Browse files
authored
Merge pull request #1909 from aboutcode-org/1884-api-group
Throttle API requests based on user permissions
2 parents 8801c90 + be5edc3 commit 43cc21c

File tree

12 files changed

+363
-106
lines changed

12 files changed

+363
-106
lines changed

vulnerabilities/admin.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99

1010
from django import forms
1111
from django.contrib import admin
12+
from django.contrib.admin.widgets import FilteredSelectMultiple
13+
from django.contrib.auth.admin import GroupAdmin as BasicGroupAdmin
14+
from django.contrib.auth.models import Group
15+
from django.contrib.auth.models import User
1216
from django.core.validators import validate_email
1317

1418
from vulnerabilities.models import ApiUser
@@ -97,3 +101,49 @@ def get_form(self, request, obj=None, **kwargs):
97101
defaults["form"] = self.add_form
98102
defaults.update(kwargs)
99103
return super().get_form(request, obj, **defaults)
104+
105+
106+
class GroupWithUsersForm(forms.ModelForm):
107+
users = forms.ModelMultipleChoiceField(
108+
queryset=User.objects.all(),
109+
required=False,
110+
widget=FilteredSelectMultiple("Users", is_stacked=False),
111+
label="Users",
112+
)
113+
114+
class Meta:
115+
model = Group
116+
fields = "__all__"
117+
118+
def __init__(self, *args, **kwargs):
119+
super().__init__(*args, **kwargs)
120+
self.fields["users"].label_from_instance = lambda user: (
121+
f"{user.username} | {user.email}" if user.email else user.username
122+
)
123+
if self.instance.pk:
124+
self.fields["users"].initial = self.instance.user_set.all()
125+
126+
def save(self, commit=True):
127+
group = super().save(commit=commit)
128+
self.save_m2m()
129+
group.user_set.set(self.cleaned_data["users"])
130+
return group
131+
132+
133+
admin.site.unregister(Group)
134+
135+
136+
@admin.register(Group)
137+
class GroupAdmin(admin.ModelAdmin):
138+
form = GroupWithUsersForm
139+
search_fields = ("name",)
140+
ordering = ("name",)
141+
filter_horizontal = ("permissions",)
142+
143+
def formfield_for_manytomany(self, db_field, request=None, **kwargs):
144+
if db_field.name == "permissions":
145+
qs = kwargs.get("queryset", db_field.remote_field.model.objects)
146+
# Avoid a major performance hit resolving permission names which
147+
# triggers a content_type load:
148+
kwargs["queryset"] = qs.select_related("content_type")
149+
return super().formfield_for_manytomany(db_field, request=request, **kwargs)

vulnerabilities/api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from rest_framework import viewsets
2323
from rest_framework.decorators import action
2424
from rest_framework.response import Response
25-
from rest_framework.throttling import AnonRateThrottle
2625

2726
from vulnerabilities.models import Alias
2827
from vulnerabilities.models import Exploit
@@ -34,7 +33,7 @@
3433
from vulnerabilities.models import get_purl_query_lookups
3534
from vulnerabilities.severity_systems import EPSS
3635
from vulnerabilities.severity_systems import SCORING_SYSTEMS
37-
from vulnerabilities.throttling import StaffUserRateThrottle
36+
from vulnerabilities.throttling import PermissionBasedUserRateThrottle
3837
from vulnerabilities.utils import get_severity_range
3938

4039

@@ -471,7 +470,7 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet):
471470
serializer_class = PackageSerializer
472471
filter_backends = (filters.DjangoFilterBackend,)
473472
filterset_class = PackageFilterSet
474-
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
473+
throttle_classes = [PermissionBasedUserRateThrottle]
475474

476475
def get_queryset(self):
477476
return super().get_queryset().with_is_vulnerable()
@@ -688,7 +687,7 @@ def get_queryset(self):
688687
serializer_class = VulnerabilitySerializer
689688
filter_backends = (filters.DjangoFilterBackend,)
690689
filterset_class = VulnerabilityFilterSet
691-
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
690+
throttle_classes = [PermissionBasedUserRateThrottle]
692691

693692

694693
class CPEFilterSet(filters.FilterSet):

vulnerabilities/api_extension.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from rest_framework.serializers import ModelSerializer
2424
from rest_framework.serializers import Serializer
2525
from rest_framework.serializers import ValidationError
26-
from rest_framework.throttling import AnonRateThrottle
2726

2827
from vulnerabilities.api import BaseResourceSerializer
2928
from vulnerabilities.models import Exploit
@@ -33,7 +32,7 @@
3332
from vulnerabilities.models import VulnerabilitySeverity
3433
from vulnerabilities.models import Weakness
3534
from vulnerabilities.models import get_purl_query_lookups
36-
from vulnerabilities.throttling import StaffUserRateThrottle
35+
from vulnerabilities.throttling import PermissionBasedUserRateThrottle
3736

3837

3938
class SerializerExcludeFieldsMixin:
@@ -259,7 +258,7 @@ class V2PackageViewSet(viewsets.ReadOnlyModelViewSet):
259258
lookup_field = "purl"
260259
filter_backends = (filters.DjangoFilterBackend,)
261260
filterset_class = V2PackageFilterSet
262-
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
261+
throttle_classes = [PermissionBasedUserRateThrottle]
263262

264263
def get_queryset(self):
265264
return super().get_queryset().with_is_vulnerable().prefetch_related("vulnerabilities")
@@ -345,7 +344,7 @@ class VulnerabilityViewSet(viewsets.ReadOnlyModelViewSet):
345344
lookup_field = "vulnerability_id"
346345
filter_backends = (filters.DjangoFilterBackend,)
347346
filterset_class = V2VulnerabilityFilterSet
348-
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
347+
throttle_classes = [PermissionBasedUserRateThrottle]
349348

350349
def get_queryset(self):
351350
"""
@@ -381,7 +380,7 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet):
381380
).distinct()
382381
serializer_class = V2VulnerabilitySerializer
383382
filter_backends = (filters.DjangoFilterBackend,)
384-
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
383+
throttle_classes = [PermissionBasedUserRateThrottle]
385384
filterset_class = CPEFilterSet
386385

387386
@action(detail=False, methods=["post"])
@@ -420,4 +419,4 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet):
420419
serializer_class = V2VulnerabilitySerializer
421420
filter_backends = (filters.DjangoFilterBackend,)
422421
filterset_class = AliasFilterSet
423-
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
422+
throttle_classes = [PermissionBasedUserRateThrottle]

vulnerabilities/api_v2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from rest_framework.permissions import BasePermission
2424
from rest_framework.response import Response
2525
from rest_framework.reverse import reverse
26+
from rest_framework.throttling import AnonRateThrottle
2627

2728
from vulnerabilities.models import AdvisoryReference
2829
from vulnerabilities.models import AdvisorySeverity
@@ -38,6 +39,7 @@
3839
from vulnerabilities.models import VulnerabilityReference
3940
from vulnerabilities.models import VulnerabilitySeverity
4041
from vulnerabilities.models import Weakness
42+
from vulnerabilities.throttling import PermissionBasedUserRateThrottle
4143

4244

4345
class WeaknessV2Serializer(serializers.ModelSerializer):
@@ -199,6 +201,7 @@ class VulnerabilityV2ViewSet(viewsets.ReadOnlyModelViewSet):
199201
queryset = Vulnerability.objects.all()
200202
serializer_class = VulnerabilityV2Serializer
201203
lookup_field = "vulnerability_id"
204+
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]
202205

203206
def get_queryset(self):
204207
queryset = super().get_queryset()
@@ -394,6 +397,7 @@ class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet):
394397
serializer_class = PackageV2Serializer
395398
filter_backends = (filters.DjangoFilterBackend,)
396399
filterset_class = PackageV2FilterSet
400+
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]
397401

398402
def get_queryset(self):
399403
queryset = super().get_queryset()
@@ -721,6 +725,7 @@ class CodeFixViewSet(viewsets.ReadOnlyModelViewSet):
721725

722726
queryset = CodeFix.objects.all()
723727
serializer_class = CodeFixSerializer
728+
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]
724729

725730
def get_queryset(self):
726731
"""
@@ -863,6 +868,7 @@ class PipelineScheduleV2ViewSet(CreateListRetrieveUpdateViewSet):
863868
serializer_class = PipelineScheduleAPISerializer
864869
lookup_field = "pipeline_id"
865870
lookup_value_regex = r"[\w.]+"
871+
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]
866872

867873
def get_serializer_class(self):
868874
if self.action == "create":
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Generated by Django 4.2.22 on 2025-07-01 11:59
2+
3+
from django.db import migrations
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
dependencies = [
9+
("vulnerabilities", "0094_advisoryalias_advisoryreference_advisoryseverity_and_more"),
10+
]
11+
12+
operations = [
13+
migrations.AlterModelOptions(
14+
name="apiuser",
15+
options={
16+
"permissions": [
17+
(
18+
"throttle_3_unrestricted",
19+
"Can make unlimited API requests without any throttling limits",
20+
),
21+
(
22+
"throttle_2_high",
23+
"Can make high number of API requests with minimal throttling",
24+
),
25+
(
26+
"throttle_1_medium",
27+
"Can make medium number of API requests with standard throttling",
28+
),
29+
(
30+
"throttle_0_low",
31+
"Can make low number of API requests with strict throttling",
32+
),
33+
]
34+
},
35+
),
36+
]

vulnerabilities/models.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from cwe2.mappings import xml_database_path
2929
from cwe2.weakness import Weakness as DBWeakness
3030
from django.contrib.auth import get_user_model
31+
from django.contrib.auth.models import Group
3132
from django.contrib.auth.models import UserManager
3233
from django.core import exceptions
3334
from django.core.exceptions import ValidationError
@@ -1489,14 +1490,30 @@ def _validate_username(self, email):
14891490

14901491

14911492
class ApiUser(UserModel):
1492-
"""
1493-
A User proxy model to facilitate simplified admin API user creation.
1494-
"""
1493+
"""A User proxy model to facilitate simplified admin API user creation."""
14951494

14961495
objects = ApiUserManager()
14971496

14981497
class Meta:
14991498
proxy = True
1499+
permissions = [
1500+
(
1501+
"throttle_3_unrestricted",
1502+
"Can make unlimited API requests without any throttling limits",
1503+
),
1504+
(
1505+
"throttle_2_high",
1506+
"Can make high number of API requests with minimal throttling",
1507+
),
1508+
(
1509+
"throttle_1_medium",
1510+
"Can make medium number of API requests with standard throttling",
1511+
),
1512+
(
1513+
"throttle_0_low",
1514+
"Can make low number of API requests with strict throttling",
1515+
),
1516+
]
15001517

15011518

15021519
class ChangeLog(models.Model):

vulnerabilities/tests/test_api.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import os
1212
from urllib.parse import quote
1313

14+
from django.core.cache import cache
1415
from django.test import TestCase
1516
from django.test import TransactionTestCase
1617
from django.test.client import RequestFactory
@@ -452,10 +453,8 @@ def add_aliases(vuln, aliases):
452453

453454
class APIPerformanceTest(TestCase):
454455
def setUp(self):
455-
self.user = ApiUser.objects.create_api_user(username="[email protected]")
456-
self.auth = f"Token {self.user.auth_token.key}"
456+
cache.clear()
457457
self.csrf_client = APIClient(enforce_csrf_checks=True)
458-
self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth)
459458

460459
# This setup creates the following data:
461460
# vulnerabilities: vul1, vul2, vul3
@@ -503,7 +502,7 @@ def setUp(self):
503502
set_as_fixing(package=self.pkg_2_13_2, vulnerability=self.vul1)
504503

505504
def test_api_packages_all_num_queries(self):
506-
with self.assertNumQueries(4):
505+
with self.assertNumQueries(3):
507506
# There are 4 queries:
508507
# 1. SAVEPOINT
509508
# 2. Authenticating user
@@ -519,22 +518,22 @@ def test_api_packages_all_num_queries(self):
519518
]
520519

521520
def test_api_packages_single_num_queries(self):
522-
with self.assertNumQueries(8):
521+
with self.assertNumQueries(7):
523522
self.csrf_client.get(f"/api/packages/{self.pkg_2_14_0_rc1.id}", format="json")
524523

525524
def test_api_packages_single_with_purl_in_query_num_queries(self):
526-
with self.assertNumQueries(9):
525+
with self.assertNumQueries(8):
527526
self.csrf_client.get(f"/api/packages/?purl={self.pkg_2_14_0_rc1.purl}", format="json")
528527

529528
def test_api_packages_single_with_purl_no_version_in_query_num_queries(self):
530-
with self.assertNumQueries(64):
529+
with self.assertNumQueries(63):
531530
self.csrf_client.get(
532531
f"/api/packages/?purl=pkg:maven/com.fasterxml.jackson.core/jackson-databind",
533532
format="json",
534533
)
535534

536535
def test_api_packages_bulk_search(self):
537-
with self.assertNumQueries(45):
536+
with self.assertNumQueries(44):
538537
packages = [self.pkg_2_12_6, self.pkg_2_12_6_1, self.pkg_2_13_1]
539538
purls = [p.purl for p in packages]
540539

@@ -547,7 +546,7 @@ def test_api_packages_bulk_search(self):
547546
).json()
548547

549548
def test_api_packages_with_lookup(self):
550-
with self.assertNumQueries(14):
549+
with self.assertNumQueries(13):
551550
data = {"purl": self.pkg_2_12_6.purl}
552551

553552
resp = self.csrf_client.post(
@@ -557,7 +556,7 @@ def test_api_packages_with_lookup(self):
557556
).json()
558557

559558
def test_api_packages_bulk_lookup(self):
560-
with self.assertNumQueries(45):
559+
with self.assertNumQueries(44):
561560
packages = [self.pkg_2_12_6, self.pkg_2_12_6_1, self.pkg_2_13_1]
562561
purls = [p.purl for p in packages]
563562

@@ -572,10 +571,8 @@ def test_api_packages_bulk_lookup(self):
572571

573572
class APITestCasePackage(TestCase):
574573
def setUp(self):
575-
self.user = ApiUser.objects.create_api_user(username="[email protected]")
576-
self.auth = f"Token {self.user.auth_token.key}"
574+
cache.clear()
577575
self.csrf_client = APIClient(enforce_csrf_checks=True)
578-
self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth)
579576

580577
# This setup creates the following data:
581578
# vulnerabilities: vul1, vul2, vul3
@@ -766,7 +763,7 @@ def test_api_with_wrong_namespace_filter(self):
766763
self.assertEqual(response["count"], 0)
767764

768765
def test_api_with_all_vulnerable_packages(self):
769-
with self.assertNumQueries(4):
766+
with self.assertNumQueries(3):
770767
# There are 4 queries:
771768
# 1. SAVEPOINT
772769
# 2. Authenticating user

0 commit comments

Comments
 (0)