diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
index 9bcd361f00092..c482b177714a3 100644
--- a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
+++ b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
@@ -17,10 +17,12 @@
# under the License.
from __future__ import annotations
+import base64
import copy
import datetime
import importlib
import itertools
+import json
import logging
import uuid
from collections.abc import Collection, Iterable, Mapping
@@ -411,8 +413,6 @@ def _validate_jwt(self, id_token, jwks):
return claims
def _get_authentik_token_info(self, id_token):
- me = jwt.decode(id_token, options={"verify_signature": False})
-
verify_signature = self.oauth_remotes["authentik"].client_kwargs.get("verify_signature", True)
if verify_signature:
# Validate the token using authentik certificate
@@ -426,7 +426,9 @@ def _get_authentik_token_info(self, id_token):
else:
# Return the token info without validating
log.warning("JWT token is not validated!")
- return me
+ _parts = id_token.split(".")
+ _payload = _parts[1] + "=" * (-len(_parts[1]) % 4)
+ return json.loads(base64.urlsafe_b64decode(_payload))
raise FabException("OAuth signature verify failed")
@@ -1453,7 +1455,7 @@ def add_register_user(
register_user.password = hashed_password
else:
register_user.password = self._hash_password(password)
- register_user.registration_hash = str(uuid.uuid1())
+ register_user.registration_hash = str(uuid.uuid4())
try:
self.session.add(register_user)
self.session.commit()
@@ -2383,7 +2385,9 @@ def _decode_and_validate_azure_jwt(self, id_token: str) -> dict[str, str]:
claims.validate()
return claims
- return jwt.decode(id_token, options={"verify_signature": False})
+ _parts = id_token.split(".")
+ _payload = _parts[1] + "=" * (-len(_parts[1]) % 4)
+ return json.loads(base64.urlsafe_b64decode(_payload))
def _ldap_bind_indirect(self, ldap, con) -> None:
"""
@@ -2418,10 +2422,12 @@ def _search_ldap(self, ldap, con, username):
raise ValueError("AUTH_LDAP_SEARCH must be set")
# build the filter string for the LDAP search
+ # escape username to prevent LDAP injection attacks
+ escaped_username = ldap.filter.escape_filter_chars(username)
if self.auth_ldap_search_filter:
- filter_str = f"(&{self.auth_ldap_search_filter}({self.auth_ldap_uid_field}={username}))"
+ filter_str = f"(&{self.auth_ldap_search_filter}({self.auth_ldap_uid_field}={escaped_username}))"
else:
- filter_str = f"({self.auth_ldap_uid_field}={username})"
+ filter_str = f"({self.auth_ldap_uid_field}={escaped_username})"
# build what fields to request in the LDAP search
request_fields = [
@@ -2491,7 +2497,11 @@ def _ldap_get_nested_groups(self, ldap, con, user_dn) -> list[str]:
"""
log.debug("Nested groups for LDAP enabled.")
# filter for microsoft active directory only
- nested_groups_filter_str = f"(&(objectCategory=Group)(member:1.2.840.113556.1.4.1941:={user_dn}))"
+ # escape user_dn to prevent LDAP injection attacks
+ escaped_user_dn = ldap.filter.escape_filter_chars(user_dn)
+ nested_groups_filter_str = (
+ "(&(objectCategory=Group)(member:1.2.840.113556.1.4.1941:=" + escaped_user_dn + "))"
+ )
nested_groups_request_fields = ["cn"]
nested_groups_search_result = con.search_s(
@@ -2596,7 +2606,7 @@ def _get_all_non_dag_permissions(self) -> dict[tuple[str, str], Permission]:
def _cli_safe_flash(text: str, level: str) -> None:
"""Show a flash in a web context or prints a message if not."""
if has_request_context():
- flash(Markup(text), level)
+ flash(escape(text), level)
else:
getattr(log, level)(text.replace("
", "\n").replace("", "*").replace("", "*"))