Skip to content

Commit fe1022b

Browse files
mdesmethashhar
authored andcommitted
Fix parsing authentication header
1 parent 4c57774 commit fe1022b

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

tests/unit/test_client.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -543,13 +543,14 @@ def test_oauth2_authentication_missing_headers(header, error):
543543
'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"',
544544
'Basic realm="Trino", Bearer realm="Trino", token_type="JWT", Bearer x_redirect_server="{redirect_server}", '
545545
'x_token_server="{token_server}"'
546+
'Bearer x_redirect_server="{redirect_server}",x_token_server="{token_server}",additional_challenge',
546547
])
547548
@httprettified
548549
def test_oauth2_header_parsing(header, sample_post_response_data):
549550
token = str(uuid.uuid4())
550551
challenge_id = str(uuid.uuid4())
551552

552-
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
553+
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}?role=test"
553554
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
554555

555556
# noinspection PyUnusedLocal

trino/auth.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -546,17 +546,17 @@ def _construct_cache_key(host: Optional[str], user: Optional[str]) -> Optional[s
546546

547547
@staticmethod
548548
def _parse_authenticate_header(header: str) -> Dict[str, str]:
549-
split_challenge = header.split(" ", 1)
550-
trimmed_challenge = split_challenge[1] if len(split_challenge) > 1 else ""
549+
logger.debug(f"Authentication header: {header}")
550+
components = header.split(",")
551551
auth_info_headers = {}
552552

553-
for item in trimmed_challenge.split(","):
554-
comps = item.split("=")
555-
if len(comps) == 2:
556-
key = comps[0].strip(' "')
557-
value = comps[1].strip(' "')
558-
if key:
559-
auth_info_headers[key.lower()] = value
553+
for component in components:
554+
component = component.strip()
555+
if "=" in component:
556+
key, value = component.split("=", 1)
557+
if value[0] == '"' and value[-1] == '"':
558+
value = value[1:-1]
559+
auth_info_headers[key.lower()] = value
560560
return auth_info_headers
561561

562562

0 commit comments

Comments
 (0)