Skip to content

Commit 1634706

Browse files
committed
check access token for group membership; fix #43
1 parent e706e86 commit 1634706

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

mlflow_oidc_auth/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self):
2626
self.OIDC_REDIRECT_URI = os.environ.get("OIDC_REDIRECT_URI", None)
2727
self.OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID", None)
2828
self.OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET", None)
29+
self.OIDC_AUDIENCE = os.environ.get("OIDC_AUDIENCE", None)
2930

3031
# session
3132
self.SESSION_TYPE = os.environ.get("SESSION_TYPE", "cachelib")

mlflow_oidc_auth/views/authentication.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import secrets
22

33
from flask import redirect, session, url_for
4+
import jwt
45

56
import mlflow_oidc_auth.utils as utils
67
from mlflow_oidc_auth.auth import get_oauth_instance
@@ -42,7 +43,13 @@ def callback():
4243

4344
user_groups = importlib.import_module(config.OIDC_GROUP_DETECTION_PLUGIN).get_user_groups(token["access_token"])
4445
else:
45-
user_groups = token["userinfo"][config.OIDC_GROUPS_ATTRIBUTE]
46+
group_attr = config.OIDC_GROUPS_ATTRIBUTE
47+
user_info = token["userinfo"]
48+
decoded_access_token = jwt.decode(token["access_token"], audience=config.OIDC_AUDIENCE, options={"verify_signature": False})
49+
if group_attr in decoded_access_token:
50+
user_groups = decoded_access_token[group_attr]
51+
if group_attr in user_info:
52+
user_groups = user_info[group_attr]
4653

4754
app.logger.debug(f"User groups: {user_groups}")
4855

0 commit comments

Comments
 (0)