Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
# under the License.
from __future__ import annotations

import base64
import copy
import datetime
import importlib
import itertools
import json
import logging
import uuid
from collections.abc import Collection, Iterable, Mapping
Expand Down Expand Up @@ -411,8 +413,6 @@ def _validate_jwt(self, id_token, jwks):
return claims

def _get_authentik_token_info(self, id_token):
me = jwt.decode(id_token, options={"verify_signature": False})

verify_signature = self.oauth_remotes["authentik"].client_kwargs.get("verify_signature", True)
if verify_signature:
# Validate the token using authentik certificate
Expand All @@ -426,7 +426,9 @@ def _get_authentik_token_info(self, id_token):
else:
# Return the token info without validating
log.warning("JWT token is not validated!")
return me
_parts = id_token.split(".")
_payload = _parts[1] + "=" * (-len(_parts[1]) % 4)
return json.loads(base64.urlsafe_b64decode(_payload))

raise FabException("OAuth signature verify failed")

Expand Down Expand Up @@ -1453,7 +1455,7 @@ def add_register_user(
register_user.password = hashed_password
else:
register_user.password = self._hash_password(password)
register_user.registration_hash = str(uuid.uuid1())
register_user.registration_hash = str(uuid.uuid4())
try:
self.session.add(register_user)
self.session.commit()
Expand Down Expand Up @@ -2383,7 +2385,9 @@ def _decode_and_validate_azure_jwt(self, id_token: str) -> dict[str, str]:
claims.validate()
return claims

return jwt.decode(id_token, options={"verify_signature": False})
_parts = id_token.split(".")
_payload = _parts[1] + "=" * (-len(_parts[1]) % 4)
return json.loads(base64.urlsafe_b64decode(_payload))

def _ldap_bind_indirect(self, ldap, con) -> None:
"""
Expand Down Expand Up @@ -2418,10 +2422,12 @@ def _search_ldap(self, ldap, con, username):
raise ValueError("AUTH_LDAP_SEARCH must be set")

# build the filter string for the LDAP search
# escape username to prevent LDAP injection attacks
escaped_username = ldap.filter.escape_filter_chars(username)
if self.auth_ldap_search_filter:
filter_str = f"(&{self.auth_ldap_search_filter}({self.auth_ldap_uid_field}={username}))"
filter_str = f"(&{self.auth_ldap_search_filter}({self.auth_ldap_uid_field}={escaped_username}))"
else:
filter_str = f"({self.auth_ldap_uid_field}={username})"
filter_str = f"({self.auth_ldap_uid_field}={escaped_username})"

# build what fields to request in the LDAP search
request_fields = [
Expand Down Expand Up @@ -2491,7 +2497,11 @@ def _ldap_get_nested_groups(self, ldap, con, user_dn) -> list[str]:
"""
log.debug("Nested groups for LDAP enabled.")
# filter for microsoft active directory only
nested_groups_filter_str = f"(&(objectCategory=Group)(member:1.2.840.113556.1.4.1941:={user_dn}))"
# escape user_dn to prevent LDAP injection attacks
escaped_user_dn = ldap.filter.escape_filter_chars(user_dn)
nested_groups_filter_str = (
"(&(objectCategory=Group)(member:1.2.840.113556.1.4.1941:=" + escaped_user_dn + "))"
)
nested_groups_request_fields = ["cn"]

nested_groups_search_result = con.search_s(
Expand Down Expand Up @@ -2596,7 +2606,7 @@ def _get_all_non_dag_permissions(self) -> dict[tuple[str, str], Permission]:
def _cli_safe_flash(text: str, level: str) -> None:
"""Show a flash in a web context or prints a message if not."""
if has_request_context():
flash(Markup(text), level)
flash(escape(text), level)
else:
getattr(log, level)(text.replace("<br>", "\n").replace("<b>", "*").replace("</b>", "*"))

Expand Down
Loading