Refactor most code, isolate authlib somehow
This commit is contained in:
parent
38b983c2a5
commit
c5bb4f4319
10 changed files with 183 additions and 218 deletions
|
@ -1,29 +0,0 @@
|
|||
from datetime import datetime, timedelta
|
||||
from collections import OrderedDict
|
||||
|
||||
from .models import User
|
||||
|
||||
time_keys = set(("iat", "exp", "auth_time", "updated_at"))
|
||||
|
||||
|
||||
def pretty_details(user: User, now: datetime) -> OrderedDict:
|
||||
details = OrderedDict()
|
||||
# breakpoint()
|
||||
for key in sorted(time_keys):
|
||||
try:
|
||||
dt = datetime.fromtimestamp(user.userinfo[key])
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
else:
|
||||
td = now - dt
|
||||
td = timedelta(days=td.days, seconds=td.seconds)
|
||||
if td.days < 0:
|
||||
ptd = f"in {-td} h:m:s"
|
||||
else:
|
||||
ptd = f"{td} h:m:s ago"
|
||||
details[key] = f"{user.userinfo[key]} - {dt} ({ptd})"
|
||||
for key in sorted(user.userinfo):
|
||||
if key in time_keys:
|
||||
continue
|
||||
details[key] = user.userinfo[key]
|
||||
return details
|
43
src/oidc_test/auth_provider.py
Normal file
43
src/oidc_test/auth_provider.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from typing import Any
|
||||
from jwt import decode
|
||||
import logging
|
||||
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
|
||||
from .settings import AuthProviderSettings, settings
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
||||
class Provider(AuthProviderSettings):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
authlib_client: StarletteOAuth2App = StarletteOAuth2App(None)
|
||||
|
||||
def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]:
|
||||
"""Decode the token with signature check"""
|
||||
if settings.debug_token:
|
||||
decoded = decode(
|
||||
token,
|
||||
self.get_public_key(),
|
||||
algorithms=[self.signature_alg],
|
||||
audience=["account", "oidc-test", "oidc-test-web"],
|
||||
options={
|
||||
"verify_signature": False,
|
||||
"verify_aud": False,
|
||||
}, # not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
logger.debug(str(decoded))
|
||||
return decode(
|
||||
token,
|
||||
self.get_public_key(),
|
||||
algorithms=[self.signature_alg],
|
||||
audience=["account", "oidc-test", "oidc-test-web"],
|
||||
options={
|
||||
"verify_signature": verify_signature,
|
||||
}, # not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
|
||||
|
||||
providers: dict[str, Provider] = {}
|
|
@ -1,4 +1,3 @@
|
|||
import re
|
||||
from typing import Union, Annotated
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
@ -14,7 +13,8 @@ from authlib.oauth2.auth import OAuth2Token
|
|||
|
||||
from .models import User
|
||||
from .database import db, TokenNotInDb, UserNotInDB
|
||||
from .settings import oidc_providers_settings
|
||||
from .settings import settings
|
||||
from .auth_provider import providers, Provider
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
@ -35,11 +35,13 @@ async def fetch_token(name, request):
|
|||
# return token.to_token()
|
||||
|
||||
|
||||
async def update_token(name, token, refresh_token=None, access_token=None):
|
||||
async def update_token(
|
||||
provider_id, token, refresh_token: str | None = None, access_token: str | None = None
|
||||
):
|
||||
"""Update the token in the database"""
|
||||
oidc_provider_settings = oidc_providers_settings[name]
|
||||
sid: str = oidc_provider_settings.decode(token["id_token"])["sid"]
|
||||
item = await db.get_token(oidc_provider_settings, sid)
|
||||
provider = providers[provider_id]
|
||||
sid: str = provider.decode(token["id_token"])["sid"]
|
||||
item = await db.get_token(provider, sid)
|
||||
# update old token
|
||||
item["access_token"] = token["access_token"]
|
||||
item["refresh_token"] = token["refresh_token"]
|
||||
|
@ -54,10 +56,12 @@ authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_t
|
|||
|
||||
|
||||
def init_providers():
|
||||
# Add oidc providers to authlib from the settings
|
||||
for id, provider in oidc_providers_settings.items():
|
||||
"""Add oidc providers to authlib from the settings
|
||||
and build the providers dict"""
|
||||
for provider_settings in settings.auth.providers:
|
||||
provider = Provider(**provider_settings.model_dump())
|
||||
authlib_oauth.register(
|
||||
name=id,
|
||||
name=provider.id,
|
||||
server_metadata_url=provider.openid_configuration,
|
||||
client_kwargs={
|
||||
"scope": " ".join(
|
||||
|
@ -74,6 +78,8 @@ def init_providers():
|
|||
update_token=update_token,
|
||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
)
|
||||
provider.authlib_client = getattr(authlib_oauth, provider.id)
|
||||
providers[provider.id] = provider
|
||||
|
||||
|
||||
init_providers()
|
||||
|
@ -82,33 +88,41 @@ init_providers()
|
|||
async def get_providers_info():
|
||||
# Get the public key:
|
||||
async with AsyncClient() as client:
|
||||
for provider_settings in oidc_providers_settings.values():
|
||||
if provider_settings.info_url:
|
||||
provider_info = await client.get(provider_settings.url)
|
||||
provider_settings.info = provider_info.json()
|
||||
for provider in providers.values():
|
||||
if provider.info_url:
|
||||
provider_info = await client.get(provider.url)
|
||||
provider.info = provider_info.json()
|
||||
|
||||
|
||||
def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None:
|
||||
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
|
||||
"""Return the oidc_provider from a request object, from the session.
|
||||
It can be used in Depends()"""
|
||||
if (oidc_provider_id := request.session.get("oidc_provider_id")) is None:
|
||||
return
|
||||
try:
|
||||
return getattr(authlib_oauth, str(oidc_provider_id))
|
||||
except AttributeError:
|
||||
if (auth_provider_id := request.session.get("auth_provider_id")) is None:
|
||||
return
|
||||
return getattr(authlib_oauth, str(auth_provider_id), None)
|
||||
|
||||
|
||||
def get_oidc_provider(request: Request) -> StarletteOAuth2App:
|
||||
if (oidc_provider := get_oidc_provider_or_none(request)) is None:
|
||||
if oidc_provider is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No provider")
|
||||
else:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
def get_auth_provider_client(request: Request) -> StarletteOAuth2App:
|
||||
if (oidc_provider := get_auth_provider_client_or_none(request)) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
else:
|
||||
return oidc_provider
|
||||
|
||||
|
||||
def get_auth_provider_or_none(request: Request) -> Provider | None:
|
||||
"""Return the oidc_provider settings from a request object, from the session.
|
||||
It can be used in Depends()"""
|
||||
if (auth_provider_id := request.session.get("auth_provider_id")) is None:
|
||||
return
|
||||
return providers.get(auth_provider_id)
|
||||
|
||||
|
||||
def get_auth_provider(request: Request) -> Provider:
|
||||
if (provider := get_auth_provider_or_none(request)) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
return provider
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> User:
|
||||
"""Get the current user from a request object.
|
||||
Also validates the token expiration time.
|
||||
|
@ -120,11 +134,11 @@ async def get_current_user(request: Request) -> User:
|
|||
user = await db.get_user(user_sub)
|
||||
## Check if the token is expired
|
||||
if token.is_expired():
|
||||
oidc_provider = get_oidc_provider(request=request)
|
||||
provider = get_auth_provider(request=request)
|
||||
## Ask a new refresh token from the provider
|
||||
logger.info(f"Token expired for user {user.name}")
|
||||
try:
|
||||
userinfo = await oidc_provider.fetch_access_token(
|
||||
userinfo = await provider.authlib_client.fetch_access_token(
|
||||
refresh_token=token.get("refresh_token")
|
||||
)
|
||||
assert userinfo is not None
|
||||
|
@ -150,14 +164,12 @@ async def get_token(request: Request) -> OAuth2Token:
|
|||
"""Return the token from the session.
|
||||
Can be used in Depends()"""
|
||||
try:
|
||||
oidc_provider_settings = oidc_providers_settings[
|
||||
request.session.get("oidc_provider_id", "")
|
||||
]
|
||||
provider = providers[request.session.get("auth_provider_id", "")]
|
||||
except KeyError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider")
|
||||
try:
|
||||
return await db.get_token(
|
||||
oidc_provider_settings,
|
||||
provider,
|
||||
request.session.get("sid"),
|
||||
)
|
||||
except (TokenNotInDb, InvalidKeyError, DecodeError) as err:
|
||||
|
@ -219,7 +231,7 @@ async def get_user_from_token(
|
|||
"Request headers must have a 'auth_provider' field",
|
||||
)
|
||||
try:
|
||||
auth_provider_settings = oidc_providers_settings[auth_provider_id]
|
||||
provider = providers[auth_provider_id]
|
||||
except KeyError:
|
||||
if auth_provider_id == "":
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider")
|
||||
|
@ -230,7 +242,7 @@ async def get_user_from_token(
|
|||
if token == "":
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
|
||||
try:
|
||||
payload = auth_provider_settings.decode(token)
|
||||
payload = provider.decode(token)
|
||||
except ExpiredSignatureError as err:
|
||||
logger.info(f"Expired signature: {err}")
|
||||
raise HTTPException(
|
||||
|
@ -261,7 +273,7 @@ async def get_user_from_token(
|
|||
user = await db.add_user(
|
||||
sub=payload["sub"],
|
||||
user_info=payload,
|
||||
oidc_provider=getattr(authlib_oauth, auth_provider_id),
|
||||
auth_provider=getattr(authlib_oauth, auth_provider_id),
|
||||
access_token=token,
|
||||
)
|
||||
return user
|
||||
|
|
|
@ -2,11 +2,10 @@
|
|||
|
||||
import logging
|
||||
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||
|
||||
from .settings import OIDCProvider, oidc_providers_settings
|
||||
from .models import User, Role
|
||||
from .auth_provider import Provider, providers
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
@ -29,23 +28,23 @@ class Database:
|
|||
self,
|
||||
sub: str,
|
||||
user_info: dict,
|
||||
oidc_provider: StarletteOAuth2App,
|
||||
auth_provider: Provider,
|
||||
access_token: str,
|
||||
access_token_decoded: dict | None = None,
|
||||
) -> User:
|
||||
if access_token_decoded is None:
|
||||
assert oidc_provider.name is not None
|
||||
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
|
||||
access_token_decoded = oidc_provider_settings.decode(access_token)
|
||||
assert auth_provider.name is not None
|
||||
provider = providers[auth_provider.id]
|
||||
access_token_decoded = provider.decode(access_token)
|
||||
user_info["auth_provider_id"] = auth_provider.id
|
||||
user = User(**user_info)
|
||||
user.userinfo = user_info
|
||||
user.oidc_provider = oidc_provider
|
||||
user.access_token = access_token
|
||||
user.access_token_decoded = access_token_decoded
|
||||
# user.access_token = access_token
|
||||
# user.access_token_decoded = access_token_decoded
|
||||
# Add roles provided in the access token
|
||||
roles = set()
|
||||
try:
|
||||
r = access_token_decoded["resource_access"][oidc_provider.client_id]["roles"]
|
||||
r = access_token_decoded["resource_access"][auth_provider.client_id]["roles"]
|
||||
roles.update(r)
|
||||
except KeyError:
|
||||
pass
|
||||
|
@ -66,19 +65,19 @@ class Database:
|
|||
raise UserNotInDB
|
||||
return self.users[sub]
|
||||
|
||||
async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None:
|
||||
async def add_token(self, provider: Provider, token: OAuth2Token) -> None:
|
||||
"""Store a token using as key the sid (auth provider's session id)
|
||||
in the id_token"""
|
||||
assert isinstance(oidc_provider_settings, OIDCProvider)
|
||||
assert isinstance(provider, Provider)
|
||||
sid = token["userinfo"]["sid"]
|
||||
self.tokens[sid] = token
|
||||
|
||||
async def get_token(
|
||||
self,
|
||||
oidc_provider_settings: OIDCProvider,
|
||||
provider: Provider,
|
||||
sid: str | None,
|
||||
) -> OAuth2Token:
|
||||
assert isinstance(oidc_provider_settings, OIDCProvider)
|
||||
assert isinstance(provider, Provider)
|
||||
if sid is None:
|
||||
raise TokenNotInDb
|
||||
try:
|
||||
|
|
|
@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles
|
|||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from jwt import InvalidTokenError, PyJWTError
|
||||
from jwt import PyJWTError
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from authlib.integrations.base_client import OAuthError
|
||||
|
@ -26,11 +26,12 @@ from authlib.oauth2.rfc6749 import OAuth2Token
|
|||
# from fastapi.security import OpenIdConnect
|
||||
# from pkce import generate_code_verifier, generate_pkce_pair
|
||||
|
||||
from .settings import settings, oidc_providers_settings
|
||||
from .settings import settings
|
||||
from .auth_provider import Provider, providers
|
||||
from .models import User
|
||||
from .auth_utils import (
|
||||
get_oidc_provider,
|
||||
get_oidc_provider_or_none,
|
||||
get_auth_provider,
|
||||
get_auth_provider_or_none,
|
||||
get_current_user_or_none,
|
||||
authlib_oauth,
|
||||
get_providers_info,
|
||||
|
@ -38,7 +39,6 @@ from .auth_utils import (
|
|||
get_token,
|
||||
update_token,
|
||||
)
|
||||
from .auth_misc import pretty_details
|
||||
from .database import TokenNotInDb, db
|
||||
from .resource_server import resource_server
|
||||
|
||||
|
@ -78,51 +78,30 @@ app.mount("/resource", resource_server, name="resource_server")
|
|||
async def home(
|
||||
request: Request,
|
||||
user: Annotated[User, Depends(get_current_user_or_none)],
|
||||
oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)],
|
||||
provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
|
||||
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
|
||||
) -> HTMLResponse:
|
||||
now = datetime.now()
|
||||
if oidc_provider and (
|
||||
(
|
||||
oidc_provider_settings := oidc_providers_settings.get(
|
||||
request.session.get("oidc_provider_id", "")
|
||||
)
|
||||
)
|
||||
is not None
|
||||
):
|
||||
resources = oidc_provider_settings.resources
|
||||
else:
|
||||
resources = []
|
||||
oidc_provider_settings = None
|
||||
|
||||
context = {
|
||||
"settings": settings.model_dump(),
|
||||
"user": user,
|
||||
"now": now,
|
||||
"oidc_provider": oidc_provider,
|
||||
"oidc_provider_settings": oidc_provider_settings,
|
||||
"resources": resources,
|
||||
"now": datetime.now(),
|
||||
"auth_provider": provider,
|
||||
}
|
||||
if token is None:
|
||||
if provider is None or token is None:
|
||||
context["access_token"] = None
|
||||
context["id_token_parsed"] = None
|
||||
context["access_token_parsed"] = None
|
||||
context["refresh_token_parsed"] = None
|
||||
context["resources"] = None
|
||||
else:
|
||||
context["access_token"] = token["access_token"]
|
||||
assert oidc_provider is not None
|
||||
assert oidc_provider.name is not None
|
||||
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
|
||||
access_token_parsed = oidc_provider_settings.decode(
|
||||
token["access_token"], verify_signature=False
|
||||
)
|
||||
context["resources"] = provider.resources
|
||||
access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
|
||||
context["access_token_scope"] = access_token_parsed["scope"]
|
||||
# context["id_token_parsed"] = pretty_details(user, now)
|
||||
context["id_token_parsed"] = oidc_provider_settings.decode(
|
||||
token["id_token"], verify_signature=False
|
||||
)
|
||||
context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False)
|
||||
context["access_token_parsed"] = access_token_parsed
|
||||
context["refresh_token_parsed"] = oidc_provider_settings.decode(
|
||||
context["refresh_token_parsed"] = provider.decode(
|
||||
token["refresh_token"], verify_signature=False
|
||||
)
|
||||
return templates.TemplateResponse(name="home.html", request=request, context=context)
|
||||
|
@ -131,20 +110,20 @@ async def home(
|
|||
# Endpoints for the login / authorization process
|
||||
|
||||
|
||||
@app.get("/login/{oidc_provider_id}")
|
||||
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||
@app.get("/login/{auth_provider_id}")
|
||||
async def login(request: Request, auth_provider_id: str) -> RedirectResponse:
|
||||
"""Login with the provider id, giving the browser a redirect to its authorize page.
|
||||
The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url
|
||||
The provider is expected to send the browser back to our own /auth/{auth_provider_id} url
|
||||
with the token.
|
||||
"""
|
||||
redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id)
|
||||
redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id)
|
||||
try:
|
||||
provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
|
||||
provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
|
||||
except AttributeError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
# if (
|
||||
# code_challenge_method := oidc_providers_settings[
|
||||
# oidc_provider_id
|
||||
# code_challenge_method := providers[
|
||||
# auth_provider_id
|
||||
# ].code_challenge_method
|
||||
# ) is not None:
|
||||
# #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method)
|
||||
|
@ -164,30 +143,30 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
|
||||
|
||||
|
||||
@app.get("/auth/{oidc_provider_id}")
|
||||
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||
@app.get("/auth/{auth_provider_id}")
|
||||
async def auth(request: Request, auth_provider_id: str) -> RedirectResponse:
|
||||
"""Decrypt the auth token, store it to the session (cookie based)
|
||||
and response to the browser with a redirect to a "welcome user" page.
|
||||
"""
|
||||
try:
|
||||
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
|
||||
authlib_client: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
|
||||
except AttributeError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
try:
|
||||
token: OAuth2Token = await oidc_provider.authorize_access_token(request)
|
||||
token: OAuth2Token = await authlib_client.authorize_access_token(request)
|
||||
except OAuthError as error:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
|
||||
# Remember the oidc_provider in the session
|
||||
# Remember the authlib_client in the session
|
||||
# logger.info(f"Scope: {token['scope']}")
|
||||
request.session["oidc_provider_id"] = oidc_provider_id
|
||||
request.session["auth_provider_id"] = auth_provider_id
|
||||
#
|
||||
# One could process the full decoded token which contains extra information
|
||||
# eg for updates. Here we are only interested in roles
|
||||
#
|
||||
if userinfo := token.get("userinfo"):
|
||||
# Remember the oidc_provider in the session
|
||||
request.session["oidc_provider_id"] = oidc_provider_id
|
||||
# User id (sub) given by oidc provider
|
||||
# Remember the authlib_client in the session
|
||||
request.session["auth_provider_id"] = auth_provider_id
|
||||
# User id (sub) given by auth provider
|
||||
sub = userinfo["sub"]
|
||||
# Build and remember the user in the session
|
||||
request.session["user_sub"] = sub
|
||||
|
@ -196,7 +175,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
user = await db.add_user(
|
||||
sub,
|
||||
user_info=userinfo,
|
||||
oidc_provider=oidc_provider,
|
||||
auth_provider=providers[auth_provider_id],
|
||||
access_token=token["access_token"],
|
||||
)
|
||||
except PyJWTError as err:
|
||||
|
@ -208,48 +187,41 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
# Add the provider session id to the session
|
||||
request.session["sid"] = userinfo["sid"]
|
||||
# Add the token to the db because it is used for logout
|
||||
assert oidc_provider.name is not None
|
||||
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
|
||||
await db.add_token(oidc_provider_settings, token)
|
||||
provider = providers[auth_provider_id]
|
||||
await db.add_token(provider, token)
|
||||
# Send the user to the home: (s)he is authenticated
|
||||
return RedirectResponse(url=request.url_for("home"))
|
||||
else:
|
||||
# Not sure if it's correct to redirect to plain login
|
||||
# if no userinfo is provided
|
||||
return RedirectResponse(url=request.url_for("login", oidc_provider_id=oidc_provider_id))
|
||||
return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id))
|
||||
|
||||
|
||||
@app.get("/account")
|
||||
async def account(
|
||||
request: Request,
|
||||
provider: Annotated[Provider, Depends(get_auth_provider)],
|
||||
) -> RedirectResponse:
|
||||
if (
|
||||
oidc_provider_settings := oidc_providers_settings.get(
|
||||
request.session.get("oidc_provider_id", "")
|
||||
)
|
||||
) is None:
|
||||
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings")
|
||||
return RedirectResponse(f"{oidc_provider_settings.account_url_template}")
|
||||
"""Redirect to the auth provider account management,
|
||||
if account_url_template is in the provider's settings"""
|
||||
return RedirectResponse(f"{provider.account_url_template}")
|
||||
|
||||
|
||||
@app.get("/logout")
|
||||
async def logout(
|
||||
request: Request,
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
provider: Annotated[Provider, Depends(get_auth_provider)],
|
||||
) -> RedirectResponse:
|
||||
# Clear session
|
||||
request.session.pop("user_sub", None)
|
||||
# Get provider's endpoint
|
||||
if (provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")) is None:
|
||||
logger.warn(f"Cannot find end_session_endpoint for provider {oidc_provider.name}")
|
||||
if (
|
||||
provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint")
|
||||
) is None:
|
||||
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
|
||||
return RedirectResponse(request.url_for("non_compliant_logout"))
|
||||
post_logout_uri = request.url_for("home")
|
||||
oidc_provider_settings = oidc_providers_settings.get(
|
||||
request.session.get("oidc_provider_id", "")
|
||||
)
|
||||
assert oidc_provider_settings is not None
|
||||
try:
|
||||
token = await db.get_token(oidc_provider_settings, request.session.pop("sid", None))
|
||||
token = await db.get_token(provider, request.session.pop("sid", None))
|
||||
except TokenNotInDb:
|
||||
logger.warn("No session in db for the token or no token")
|
||||
return RedirectResponse(request.url_for("home"))
|
||||
|
@ -270,30 +242,30 @@ async def logout(
|
|||
@app.get("/non-compliant-logout")
|
||||
async def non_compliant_logout(
|
||||
request: Request,
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
provider: Annotated[StarletteOAuth2App, Depends(get_auth_provider)],
|
||||
):
|
||||
"""A page for non-compliant OAuth2 servers that we cannot log out."""
|
||||
# Clear the remain of the session
|
||||
request.session.pop("oidc_provider_id", None)
|
||||
request.session.pop("auth_provider_id", None)
|
||||
return templates.TemplateResponse(
|
||||
name="non_compliant_logout.html",
|
||||
request=request,
|
||||
context={"oidc_provider": oidc_provider, "home_url": request.url_for("home")},
|
||||
context={"oidc_provider": provider, "home_url": request.url_for("home")},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/refresh")
|
||||
async def refresh(
|
||||
request: Request,
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
provider: Annotated[Provider, Depends(get_auth_provider)],
|
||||
token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
) -> RedirectResponse:
|
||||
"""Manually refresh token"""
|
||||
new_token = await oidc_provider.fetch_access_token(
|
||||
new_token = await provider.authlib_client.fetch_access_token(
|
||||
refresh_token=token["refresh_token"],
|
||||
grant_type="refresh_token",
|
||||
)
|
||||
await update_token(oidc_provider.name, new_token)
|
||||
await update_token(provider.id, new_token)
|
||||
return RedirectResponse(url=request.url_for("home"))
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ from pydantic import (
|
|||
EmailStr,
|
||||
ConfigDict,
|
||||
)
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
@ -37,7 +36,7 @@ class User(UserBase):
|
|||
userinfo: dict = {}
|
||||
access_token: str | None = None
|
||||
access_token_decoded: dict[str, Any] | None = None
|
||||
oidc_provider: StarletteOAuth2App | None = None
|
||||
auth_provider_id: str
|
||||
|
||||
@computed_field
|
||||
@cached_property
|
||||
|
@ -56,11 +55,10 @@ class User(UserBase):
|
|||
|
||||
def decode_access_token(self, verify_signature: bool = True):
|
||||
assert self.access_token is not None
|
||||
assert self.oidc_provider is not None
|
||||
assert self.oidc_provider.name is not None
|
||||
from .settings import oidc_providers_settings
|
||||
assert self.auth_provider_id is not None
|
||||
from .auth_provider import providers
|
||||
|
||||
return oidc_providers_settings[self.oidc_provider.name].decode(
|
||||
return providers[self.auth_provider_id].decode(
|
||||
self.access_token, verify_signature=verify_signature
|
||||
)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ from datetime import datetime
|
|||
from typing import Annotated
|
||||
import logging
|
||||
|
||||
from authlib.jose import Key
|
||||
from httpx import AsyncClient
|
||||
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
|
||||
from fastapi import FastAPI, HTTPException, Depends, status
|
||||
|
@ -15,10 +16,9 @@ from .models import User
|
|||
from .auth_utils import (
|
||||
get_user_from_token,
|
||||
UserWithRole,
|
||||
# get_oidc_provider,
|
||||
# get_token,
|
||||
)
|
||||
from .settings import settings
|
||||
from .auth_provider import providers
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
@ -128,7 +128,10 @@ async def get_resource(resource_id: str, user: User) -> dict:
|
|||
"""
|
||||
Resource processing: build an informative rely as a simple showcase
|
||||
"""
|
||||
pname = getattr(user.oidc_provider, "name", "?")
|
||||
try:
|
||||
pname = providers[user.auth_provider_id].name
|
||||
except KeyError:
|
||||
pname = "?"
|
||||
resp = {
|
||||
"hello": f"Hi {user.name} from an OAuth resource provider",
|
||||
"comment": f"I received a request for '{resource_id}' "
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
from os import environ
|
||||
import string
|
||||
import random
|
||||
from typing import Type, Tuple, Any
|
||||
from typing import Type, Tuple
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from jwt import decode
|
||||
from pydantic import BaseModel, computed_field, AnyUrl
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
|
@ -17,8 +15,6 @@ from starlette.requests import Request
|
|||
|
||||
from .models import User
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
"""A resource with an URL that can be accessed with an OAuth2 access token"""
|
||||
|
@ -27,8 +23,8 @@ class Resource(BaseModel):
|
|||
name: str
|
||||
|
||||
|
||||
class OIDCProvider(BaseModel):
|
||||
"""OIDC provider, can also be a resource server"""
|
||||
class AuthProviderSettings(BaseModel):
|
||||
"""Auth provider, can also be a resource server"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
@ -79,30 +75,6 @@ class OIDCProvider(BaseModel):
|
|||
-----END PUBLIC KEY-----
|
||||
"""
|
||||
|
||||
def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]:
|
||||
"""Decode the token with signature check"""
|
||||
if settings.debug_token:
|
||||
decoded = decode(
|
||||
token,
|
||||
self.get_public_key(),
|
||||
algorithms=[self.signature_alg],
|
||||
audience=["account", "oidc-test", "oidc-test-web"],
|
||||
options={
|
||||
"verify_signature": False,
|
||||
"verify_aud": False,
|
||||
}, # not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
logger.debug(str(decoded))
|
||||
return decode(
|
||||
token,
|
||||
self.get_public_key(),
|
||||
algorithms=[self.signature_alg],
|
||||
audience=["account", "oidc-test", "oidc-test-web"],
|
||||
options={
|
||||
"verify_signature": verify_signature,
|
||||
}, # not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
|
||||
|
||||
class ResourceProvider(BaseModel):
|
||||
id: str
|
||||
|
@ -111,9 +83,9 @@ class ResourceProvider(BaseModel):
|
|||
resources: list[Resource] = []
|
||||
|
||||
|
||||
class OIDCSettings(BaseModel):
|
||||
class AuthSettings(BaseModel):
|
||||
show_session_details: bool = False
|
||||
providers: list[OIDCProvider] = []
|
||||
providers: list[AuthProviderSettings] = []
|
||||
swagger_provider: str = ""
|
||||
|
||||
|
||||
|
@ -128,7 +100,7 @@ class Settings(BaseSettings):
|
|||
|
||||
model_config = SettingsConfigDict(env_nested_delimiter="__")
|
||||
|
||||
oidc: OIDCSettings = OIDCSettings()
|
||||
auth: AuthSettings = AuthSettings()
|
||||
resource_providers: list[ResourceProvider] = []
|
||||
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
||||
log: bool = False
|
||||
|
@ -163,8 +135,3 @@ class Settings(BaseSettings):
|
|||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
oidc_providers_settings: dict[str, OIDCProvider] = dict(
|
||||
[(provider.id, provider) for provider in settings.oidc.providers]
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
<link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet">
|
||||
<script src="{{ url_for('static', path='/utils.js') }}"></script>
|
||||
</head>
|
||||
<body onload="checkPerms('links-to-check', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">
|
||||
<body onload="checkPerms('links-to-check', '{{ access_token }}', '{{ auth_provider.id }}')">
|
||||
<h1>OIDC-test - FastAPI client</h1>
|
||||
{% block content %}
|
||||
{% endblock %}
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
<div class="login-box">
|
||||
<p class="description">Log in with:</p>
|
||||
<table class="providers">
|
||||
{% for provider in settings.oidc.providers %}
|
||||
{% for provider in settings.auth.providers %}
|
||||
<tr class="provider">
|
||||
<td>
|
||||
<a class="link" href="login/{{ provider.id }}"><div>{{ provider.name }}</div></a>
|
||||
|
@ -32,7 +32,7 @@
|
|||
<div>{{ user.email }}</div>
|
||||
<div>
|
||||
<span>Provider:</span>
|
||||
{{ oidc_provider_settings.name }}
|
||||
{{ auth_provider.name }}
|
||||
</div>
|
||||
{% if user.roles %}
|
||||
<div>
|
||||
|
@ -50,9 +50,9 @@
|
|||
{% endfor %}
|
||||
</div>
|
||||
{% endif %}
|
||||
{% if oidc_provider_settings.account_url_template %}
|
||||
{% if auth_provider.account_url_template %}
|
||||
<button
|
||||
onclick="location.href='{{ oidc_provider_settings.get_account_url(request, user) }}'"
|
||||
onclick="location.href='{{ auth_provider.get_account_url(request, user) }}'"
|
||||
class="account">
|
||||
Account management
|
||||
</button>
|
||||
|
@ -67,21 +67,21 @@
|
|||
Resources validated by scope:
|
||||
</p>
|
||||
<div class="links-to-check">
|
||||
<button resource-id="time" onclick="get_resource('time', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Time</button>
|
||||
<button resource-id="bs" onclick="get_resource('bs', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">BS</button>
|
||||
<button resource-id="time" onclick="get_resource('time', '{{ access_token }}', '{{ auth_provider.id }}')">Time</button>
|
||||
<button resource-id="bs" onclick="get_resource('bs', '{{ access_token }}', '{{ auth_provider.id }}')">BS</button>
|
||||
</div>
|
||||
<p>
|
||||
Resources validated by role:
|
||||
</p>
|
||||
<div class="links-to-check">
|
||||
<button resource-id="public" onclick="get_resource('public', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Public</button>
|
||||
<button resource-id="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth protected content</button>
|
||||
<button resource-id="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole protected content</button>
|
||||
<button resource-id="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole or barrole protected content</button>
|
||||
<button resource-id="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + barrole protected content</button>
|
||||
<button resource-id="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole and barrole protected content</button>
|
||||
<button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Using FastAPI Depends</button>
|
||||
<!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Introspect token (401 expected)</button>-->
|
||||
<button resource-id="public" onclick="get_resource('public', '{{ access_token }}', '{{ auth_provider.id }}')">Public</button>
|
||||
<button resource-id="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ auth_provider.id }}')">Auth protected content</button>
|
||||
<button resource-id="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole protected content</button>
|
||||
<button resource-id="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole or barrole protected content</button>
|
||||
<button resource-id="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + barrole protected content</button>
|
||||
<button resource-id="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole and barrole protected content</button>
|
||||
<button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ auth_provider.id }}')">Using FastAPI Depends</button>
|
||||
<!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ auth_provider.id }}')">Introspect token (401 expected)</button>-->
|
||||
</div>
|
||||
<div class="resourceResult">
|
||||
<div id="resource" class="resource"></div>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue