Skip to content

Commit 6210c61

Browse files
authored
feat: Add generic OAuth provider (Chainlit#1784)
Signed-off-by: Adam Tao <tcx4c70@gmail.com>
1 parent 784b564 commit 6210c61

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

backend/chainlit/oauth_providers.py

+67
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,72 @@ async def get_user_info(self, token: str):
730730
return (kc_user, user)
731731

732732

733+
class GenericOAuthProvider(OAuthProvider):
734+
env = [
735+
"OAUTH_GENERIC_CLIENT_ID",
736+
"OAUTH_GENERIC_CLIENT_SECRET",
737+
"OAUTH_GENERIC_AUTH_URL",
738+
"OAUTH_GENERIC_TOKEN_URL",
739+
"OAUTH_GENERIC_USER_INFO_URL",
740+
"OAUTH_GENERIC_SCOPES",
741+
]
742+
id = os.environ.get("OAUTH_GENERIC_NAME", "generic")
743+
744+
def __init__(self):
745+
self.client_id = os.environ.get("OAUTH_GENERIC_CLIENT_ID")
746+
self.client_secret = os.environ.get("OAUTH_GENERIC_CLIENT_SECRET")
747+
self.authorize_url = os.environ.get("OAUTH_GENERIC_AUTH_URL")
748+
self.token_url = os.environ.get("OAUTH_GENERIC_TOKEN_URL")
749+
self.user_info_url = os.environ.get("OAUTH_GENERIC_USER_INFO_URL")
750+
self.scopes = os.environ.get("OAUTH_GENERIC_SCOPES")
751+
self.user_identifier = os.environ.get("OAUTH_GENERIC_USER_IDENTIFIER", "email")
752+
753+
self.authorize_params = {
754+
"scope": self.scopes,
755+
"response_type": "code",
756+
}
757+
758+
if prompt := self.get_prompt():
759+
self.authorize_params["prompt"] = prompt
760+
761+
async def get_token(self, code: str, url: str):
762+
payload = {
763+
"client_id": self.client_id,
764+
"client_secret": self.client_secret,
765+
"code": code,
766+
"grant_type": "authorization_code",
767+
"redirect_uri": url,
768+
}
769+
async with httpx.AsyncClient() as client:
770+
response = await client.post(self.token_url, data=payload)
771+
response.raise_for_status()
772+
json = response.json()
773+
token = json.get("access_token")
774+
if not token:
775+
raise httpx.HTTPStatusError(
776+
"Failed to get the access token",
777+
request=response.request,
778+
response=response,
779+
)
780+
return token
781+
782+
async def get_user_info(self, token: str):
783+
async with httpx.AsyncClient() as client:
784+
response = await client.get(
785+
self.user_info_url,
786+
headers={"Authorization": f"Bearer {token}"},
787+
)
788+
response.raise_for_status()
789+
server_user = response.json()
790+
user = User(
791+
identifier=server_user.get(self.user_identifier),
792+
metadata={
793+
"provider": self.id,
794+
},
795+
)
796+
return (server_user, user)
797+
798+
733799
providers = [
734800
GithubOAuthProvider(),
735801
GoogleOAuthProvider(),
@@ -741,6 +807,7 @@ async def get_user_info(self, token: str):
741807
AWSCognitoOAuthProvider(),
742808
GitlabOAuthProvider(),
743809
KeycloakOAuthProvider(),
810+
GenericOAuthProvider(),
744811
]
745812

746813

0 commit comments

Comments
 (0)