- These links should get different response codes depending on the authorization: -
-Session details
--
- {% for key, value in user.userinfo.items() %}
-
- - {{ key }}: {{ value }} - - {% endfor %} -
diff --git a/.containerignore b/.containerignore new file mode 100644 index 0000000..df107bc --- /dev/null +++ b/.containerignore @@ -0,0 +1,2 @@ +.venv +settings.yaml diff --git a/.python-version b/.python-version index 2c07333..24ee5b1 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.11 +3.13 diff --git a/Containerfile b/Containerfile new file mode 100644 index 0000000..dfd3160 --- /dev/null +++ b/Containerfile @@ -0,0 +1,14 @@ +FROM docker.io/library/python:alpine + +COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/ + +COPY . /src + +# Sync the project into a new environment, using the frozen lockfile +WORKDIR /src + +RUN uv sync --frozen --no-cache && uv pip install --system . + +#ENV PATH="/src/.venv/bin:$PATH" + +CMD ["oidc-test", "--port", "80"] diff --git a/pyproject.toml b/pyproject.toml index ca69015..ca7b5eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "fastapi-oidc-test" version = "0.1.0" description = "Add your description here" readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.13" dependencies = [ "authlib>=1.4.0", "cachetools>=5.5.0", @@ -12,10 +12,22 @@ dependencies = [ "passlib[bcrypt]>=1.7.4", "pydantic-settings>=2.7.1", "python-jose[cryptography]>=3.3.0", + "requests>=2.32.3", + "sqlmodel>=0.0.22", ] -[tool.uv.sources] -fastapi = { path = "../fastapi", editable = true } +[project.scripts] +oidc-test = "oidc_test.main:main" [dependency-groups] dev = ["ipdb>=0.13.13"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/oidc_test"] + +[tool.uv] +package = true diff --git a/src/oidc_test/auth_misc.py b/src/oidc_test/auth_misc.py new file mode 100644 index 0000000..a4e9ea3 --- /dev/null +++ b/src/oidc_test/auth_misc.py @@ -0,0 +1,29 @@ +from datetime import datetime, timedelta +from collections import OrderedDict + +from .models import User + +time_keys = set(("iat", "exp", "auth_time", "updated_at")) + + +def pretty_details(user: User, now: datetime) -> OrderedDict: + details = OrderedDict() + # breakpoint() + for key in sorted(time_keys): + try: + dt = datetime.fromtimestamp(user.userinfo[key]) + except (KeyError, TypeError): + pass + else: + td = now - dt + td = timedelta(days=td.days, seconds=td.seconds) + if td.days < 0: + ptd = f"in {-td} h:m:s" + else: + ptd = f"{td} h:m:s ago" + details[key] = f"{user.userinfo[key]} - {dt} ({ptd})" + for key in sorted(user.userinfo): + if key in time_keys: + continue + details[key] = user.userinfo[key] + return details diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 668bfd8..60f1e02 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -1,21 +1,68 @@ from typing import Union from functools import wraps +from datetime import datetime +import logging from fastapi import HTTPException, Request, status +from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from .models import User +# from authlib.oauth1.auth import OAuthToken +# from authlib.oauth2.auth import OAuth2Token + +from .models import OAuth2Token, User from .database import db +from .settings import settings + +logger = logging.getLogger(__name__) + +OIDC_PROVIDERS = set([provider.name for provider in settings.oidc.providers]) -def get_current_user(request: Request) -> User: +def get_provider(request: Request) -> StarletteOAuth2App: + """Return the oidc_provider from a request object, from the session. + It can be used in Depends()""" + if (oidc_provider_id := request.session.get("oidc_provider_id")) is None: + raise HTTPException( + status.HTTP_500_INTERNAL_SERVER_ERROR, + "Not logged in (no provider in session)", + ) + try: + return getattr(authlib_oauth, str(oidc_provider_id)) + except AttributeError: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") + + +async def get_current_user(request: Request) -> User: + """Get the current user from a request object. + Also validates the token expiration time. + ... TODO: complete about refresh token + """ if (user_sub := request.session.get("user_sub")) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED) - return db.get_user(user_sub) + if (token := await db.get_token(request.session["token"])) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown") + user = await db.get_user(user_sub) + ## Check if the token is expired + if token.expires_at < datetime.timestamp(datetime.now()): + oidc_provider = get_provider(request=request) + ## Ask a new refresh token from the provider + logger.info(f"Token expired for user {user.name}") + try: + userinfo = await oidc_provider.fetch_access_token( + refresh_token=token.access_token + ) + except OAuthError as err: + logger.exception(err) + # raise HTTPException( + # status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh" + # ) + + return user -def get_current_user_or_none(request: Request) -> User | None: +async def get_current_user_or_none(request: Request) -> User | None: try: - return get_current_user(request) + return await get_current_user(request) except HTTPException: return None @@ -35,7 +82,7 @@ def hasrole(required_roles: Union[str, list[str]] = []): 500, "Functions decorated with hasrole must have a request:Request argument", ) - user: User = get_current_user(request) + user: User = await get_current_user(request) if not any(required_roles_set.intersection(user.roles_as_set)): raise HTTPException(status.HTTP_401_UNAUTHORIZED) return await func(request, *args, **kwargs) @@ -43,3 +90,31 @@ def hasrole(required_roles: Union[str, list[str]] = []): return wrapper return decorator + + +def get_token_info(token: dict) -> dict: + token_info = dict() + for key in token: + if key != "userinfo": + token_info[key] = token[key] + return token_info + + +def fetch_token(name, request): + breakpoint() + ... + # if name in OIDC_PROVIDERS: + # model = OAuth2Token + # else: + # model = OAuthToken + + # token = model.find(name=name, user=request.user) + # return token.to_token() + + +def update_token(*args, **kwargs): + breakpoint() + ... + + +authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index fb0b167..1aae7cc 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -2,11 +2,12 @@ from authlib.integrations.starlette_client.apps import StarletteOAuth2App -from .models import User +from .models import User, OAuth2Token class Database: users: dict[str, User] = {} + tokens: dict[str, OAuth2Token] = {} # Last sessions for the user (key: users's subject id (sub)) @@ -17,8 +18,17 @@ class Database: self.users[sub] = user return user - def get_user(self, sub: str) -> User: + async def get_user(self, sub: str) -> User: return self.users[sub] + async def add_token(self, token_dict: dict, user: User) -> None: + # FIXME: The tokens are stored with the user.sub key, meaning that + # sessions logged in with different clients simultanously will + # interfer with ezach others. + self.tokens[user.sub] = OAuth2Token.from_dict(token_dict=token_dict, user=user) + + async def get_token(self, name) -> OAuth2Token | None: + return self.tokens.get(name) + db = Database() diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 832b101..ec6ec01 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -1,18 +1,35 @@ from typing import Annotated +from datetime import datetime +import logging +from urllib.parse import urlencode from httpx import HTTPError -from fastapi import Depends, FastAPI, HTTPException, Request, status +from fastapi import Depends, FastAPI, HTTPException, Request, Response, status from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates +from fastapi.security import OpenIdConnect from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client import OAuth, OAuthError +# authlib startlette integration does not support revocation: using requests +# from authlib.integrations.requests_client import OAuth2Session + from .settings import settings from .models import User -from .auth_utils import hasrole, get_current_user_or_none, get_current_user +from .auth_utils import ( + get_provider, + hasrole, + get_current_user_or_none, + get_current_user, + authlib_oauth, +) +from .auth_misc import pretty_details from .database import db +# logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + templates = Jinja2Templates("src/templates") @@ -20,6 +37,7 @@ app = FastAPI( title="OIDC auth test", ) + # SessionMiddleware is required by authlib app.add_middleware( SessionMiddleware, @@ -27,35 +45,58 @@ app.add_middleware( ) # Add oidc providers to authlib from the settings -authlib_oauth = OAuth() + +fastapi_providers = {} +_providers = {} + for provider in settings.oidc.providers: authlib_oauth.register( name=provider.name, - server_metadata_url=provider.provider_url, + server_metadata_url=provider.openid_configuration, client_kwargs={ "scope": "openid email offline_access profile roles", }, client_id=provider.client_id, client_secret=provider.client_secret, + # fetch_token=fetch_token, + # update_token=update_token, # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) ) + fastapi_providers[provider.name] = OpenIdConnect( + openIdConnectUrl=provider.openid_configuration + ) + _providers[provider.name] = provider -@app.get("/login") -async def login(request: Request, provider: str) -> RedirectResponse: - redirect_uri = request.url_for("auth", oidc_provider_id=provider) +# Endpoints for the login / authorization process + + +@app.get("/login/{oidc_provider_id}") +async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: + """Login with the provider name, + by giving the browser a redirect to its authorize page. + After successful authentification, the provider replies with an encrypted + auth token that only we can decode and contains userinfo, + and a redirect to our own /auth/{oidc_provider_id} url + """ + redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id) try: - provider_: StarletteOAuth2App = getattr(authlib_oauth, provider) + provider_: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: - return await provider_.authorize_redirect(request, redirect_uri) + return await provider_.authorize_redirect( + request, redirect_uri, access_type="offline" + ) except HTTPError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") @app.get("/auth/{oidc_provider_id}") async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: + """Decrypt the auth token, store it to the session (cookie based) + and response to the browser with a redirect to a "welcome user" page. + """ try: oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) except AttributeError: @@ -76,35 +117,82 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: # Build and remember the user in the session request.session["user_sub"] = sub # Store the user in the database - await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider) + user = await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider) + request.session["token"] = userinfo["sub"] + await db.add_token(token, user) return RedirectResponse(url="/") else: - # Not sure if it's correct to redirect to plain login (which is not implemented anyway) + # Not sure if it's correct to redirect to plain login # if no userinfo is provided - return RedirectResponse(url="/login") + redirect_uri = request.url_for("login", oidc_provider_id=oidc_provider_id) + return RedirectResponse(url=redirect_uri) + + +@app.get("/non-compliant-logout") +async def non_compliant_logout( + request: Request, + provider: Annotated[StarletteOAuth2App, Depends(get_provider)], +): + """A page for non-compliant OAuth2 servers that we cannot log out.""" + return templates.TemplateResponse( + name="non_compliant_logout.html", + request=request, + context={"provider": provider}, + ) @app.get("/logout") async def logout( - request: Request, user: Annotated[User, Depends(get_current_user_or_none)] + request: Request, + provider: Annotated[StarletteOAuth2App, Depends(get_provider)], ) -> RedirectResponse: - # TODO: logout from oidc_provider - # await user.oidc_provider.logout_redirect() + # Clear session request.session.pop("user_sub", None) - return RedirectResponse(url="/") + # Get provider's endpoint + if ( + provider_logout_uri := provider.server_metadata.get("end_session_endpoint") + ) is None: + logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") + return RedirectResponse(request.url_for("non_compliant_logout")) + post_logout_uri = request.url_for("home") + if (id_token := await db.get_token(request.session["token"])) is None: + logger.warn("No session in db for the token") + return RedirectResponse(request.url_for("home")) + logout_url = ( + provider_logout_uri + + "?" + + urlencode( + { + "post_logout_redirect_uri": post_logout_uri, + "id_token_hint": id_token.raw_id_token, + "cliend_id": "oidc_local_test", + } + ) + ) + return RedirectResponse(logout_url) + + +# Home URL @app.get("/") async def home( request: Request, user: Annotated[User, Depends(get_current_user_or_none)] ) -> HTMLResponse: + now = datetime.now() return templates.TemplateResponse( + name="home.html", request=request, context={ "settings": settings.model_dump(), "user": user, + "now": now, + "user_info_details": ( + pretty_details(user, now) + if user and settings.oidc.show_session_details + else None + ), }, - name="index.html", ) @@ -113,6 +201,9 @@ async def public() -> HTMLResponse: return HTMLResponse("
Log in with one of these authentication providers:
-- These links should get different response codes depending on the authorization: -
-Session details
-