Skip to content

Commit a366176

Browse files
committed
Refactor functions into separate module
Needed because otherwise during testing we end up in a circular import
1 parent 09e691b commit a366176

File tree

1 file changed

+5
-59
lines changed

1 file changed

+5
-59
lines changed

mlflow_oidc_auth/views/authentication.py

Lines changed: 5 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import fnmatch
21
import secrets
32

43
from flask import redirect, render_template, session, url_for
@@ -8,6 +7,7 @@
87
from mlflow_oidc_auth.auth import get_oauth_instance
98
from mlflow_oidc_auth.config import config
109
from mlflow_oidc_auth.user import create_user, populate_groups, update_user
10+
from mlflow_oidc_auth.token_utils import token_get_user_is_admin, token_get_user_groups
1111

1212

1313
def login():
@@ -31,62 +31,6 @@ def logout():
3131
return redirect("/")
3232

3333

34-
def get_user_groups(token: dict) -> list[str]:
35-
"""Retrieve the list of groups this user (based on the provided token) is a member of
36-
37-
Args:
38-
token: dictionary holding the oidc token information
39-
40-
Returns:
41-
list of all the groups this user is a member of
42-
"""
43-
user_groups = []
44-
45-
if config.OIDC_GROUP_DETECTION_PLUGIN:
46-
import importlib
47-
48-
user_groups = importlib.import_module(config.OIDC_GROUP_DETECTION_PLUGIN).get_user_groups(token["access_token"])
49-
else:
50-
user_groups = token["userinfo"][config.OIDC_GROUPS_ATTRIBUTE]
51-
52-
app.logger.debug(f"All user groups: {user_groups}")
53-
54-
# Now filter the user groups to keep only those matching the pattern or the ADMIN group
55-
user_groups = sorted(
56-
set(
57-
[
58-
x
59-
for p in config.OIDC_GROUP_FILTER_PATTERNS
60-
for x in [g for g in user_groups if (fnmatch.fnmatch(g, p) or (g == config.OIDC_ADMIN_GROUP_NAME))]
61-
]
62-
)
63-
)
64-
65-
app.logger.debug(f"Filtered user groups: {user_groups}")
66-
67-
return user_groups
68-
69-
70-
def get_is_admin(user_groups: list[str]):
71-
"""Check if the admin group is included in the user_groups. In that case
72-
it means that the user is an admin user
73-
74-
Args:
75-
user_groups (list[str]): list of the groups the current user belongs to
76-
77-
Returns:
78-
True if the admin group is in the list of the groups of the current user, False otherwise
79-
80-
"""
81-
is_admin = False
82-
83-
if config.OIDC_ADMIN_GROUP_NAME in user_groups:
84-
app.logger.debug(f"User is in admin group {config.OIDC_ADMIN_GROUP_NAME}")
85-
is_admin = True
86-
87-
return is_admin
88-
89-
9034
def callback():
9135
"""Validate the state to protect against CSRF"""
9236

@@ -107,8 +51,10 @@ def callback():
10751
display_name = token["userinfo"]["name"]
10852

10953
# Get groups and admin status
110-
user_groups = get_user_groups(token)
111-
is_admin = get_is_admin(user_groups)
54+
user_groups = token_get_user_groups(token)
55+
app.logger.debug(f"Filtered user groups the user belongs to: {user_groups}")
56+
is_admin = token_get_user_is_admin(user_groups)
57+
app.logger.debug(f"User is an admin user: {is_admin}")
11258

11359
# If there are no user_groups (including the admin group) that allow login to server, give 401
11460
if not len(user_groups):

0 commit comments

Comments
 (0)