Container, bug fixes

This commit is contained in:
phil 2025-01-09 23:41:32 +01:00
parent 23f180e521
commit 57041e9233
14 changed files with 512 additions and 291 deletions

View file

@ -0,0 +1,29 @@
from datetime import datetime, timedelta
from collections import OrderedDict
from .models import User
time_keys = set(("iat", "exp", "auth_time", "updated_at"))
def pretty_details(user: User, now: datetime) -> OrderedDict:
details = OrderedDict()
# breakpoint()
for key in sorted(time_keys):
try:
dt = datetime.fromtimestamp(user.userinfo[key])
except (KeyError, TypeError):
pass
else:
td = now - dt
td = timedelta(days=td.days, seconds=td.seconds)
if td.days < 0:
ptd = f"in {-td} h:m:s"
else:
ptd = f"{td} h:m:s ago"
details[key] = f"{user.userinfo[key]} - {dt} ({ptd})"
for key in sorted(user.userinfo):
if key in time_keys:
continue
details[key] = user.userinfo[key]
return details

View file

@ -1,21 +1,68 @@
from typing import Union
from functools import wraps
from datetime import datetime
import logging
from fastapi import HTTPException, Request, status
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
from .models import User
# from authlib.oauth1.auth import OAuthToken
# from authlib.oauth2.auth import OAuth2Token
from .models import OAuth2Token, User
from .database import db
from .settings import settings
logger = logging.getLogger(__name__)
OIDC_PROVIDERS = set([provider.name for provider in settings.oidc.providers])
def get_current_user(request: Request) -> User:
def get_provider(request: Request) -> StarletteOAuth2App:
"""Return the oidc_provider from a request object, from the session.
It can be used in Depends()"""
if (oidc_provider_id := request.session.get("oidc_provider_id")) is None:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
"Not logged in (no provider in session)",
)
try:
return getattr(authlib_oauth, str(oidc_provider_id))
except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
async def get_current_user(request: Request) -> User:
"""Get the current user from a request object.
Also validates the token expiration time.
... TODO: complete about refresh token
"""
if (user_sub := request.session.get("user_sub")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
return db.get_user(user_sub)
if (token := await db.get_token(request.session["token"])) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown")
user = await db.get_user(user_sub)
## Check if the token is expired
if token.expires_at < datetime.timestamp(datetime.now()):
oidc_provider = get_provider(request=request)
## Ask a new refresh token from the provider
logger.info(f"Token expired for user {user.name}")
try:
userinfo = await oidc_provider.fetch_access_token(
refresh_token=token.access_token
)
except OAuthError as err:
logger.exception(err)
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh"
# )
return user
def get_current_user_or_none(request: Request) -> User | None:
async def get_current_user_or_none(request: Request) -> User | None:
try:
return get_current_user(request)
return await get_current_user(request)
except HTTPException:
return None
@ -35,7 +82,7 @@ def hasrole(required_roles: Union[str, list[str]] = []):
500,
"Functions decorated with hasrole must have a request:Request argument",
)
user: User = get_current_user(request)
user: User = await get_current_user(request)
if not any(required_roles_set.intersection(user.roles_as_set)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
return await func(request, *args, **kwargs)
@ -43,3 +90,31 @@ def hasrole(required_roles: Union[str, list[str]] = []):
return wrapper
return decorator
def get_token_info(token: dict) -> dict:
token_info = dict()
for key in token:
if key != "userinfo":
token_info[key] = token[key]
return token_info
def fetch_token(name, request):
breakpoint()
...
# if name in OIDC_PROVIDERS:
# model = OAuth2Token
# else:
# model = OAuthToken
# token = model.find(name=name, user=request.user)
# return token.to_token()
def update_token(*args, **kwargs):
breakpoint()
...
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)

View file

@ -2,11 +2,12 @@
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from .models import User
from .models import User, OAuth2Token
class Database:
users: dict[str, User] = {}
tokens: dict[str, OAuth2Token] = {}
# Last sessions for the user (key: users's subject id (sub))
@ -17,8 +18,17 @@ class Database:
self.users[sub] = user
return user
def get_user(self, sub: str) -> User:
async def get_user(self, sub: str) -> User:
return self.users[sub]
async def add_token(self, token_dict: dict, user: User) -> None:
# FIXME: The tokens are stored with the user.sub key, meaning that
# sessions logged in with different clients simultanously will
# interfer with ezach others.
self.tokens[user.sub] = OAuth2Token.from_dict(token_dict=token_dict, user=user)
async def get_token(self, name) -> OAuth2Token | None:
return self.tokens.get(name)
db = Database()

View file

@ -1,18 +1,35 @@
from typing import Annotated
from datetime import datetime
import logging
from urllib.parse import urlencode
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from fastapi.security import OpenIdConnect
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.starlette_client import OAuth, OAuthError
# authlib startlette integration does not support revocation: using requests
# from authlib.integrations.requests_client import OAuth2Session
from .settings import settings
from .models import User
from .auth_utils import hasrole, get_current_user_or_none, get_current_user
from .auth_utils import (
get_provider,
hasrole,
get_current_user_or_none,
get_current_user,
authlib_oauth,
)
from .auth_misc import pretty_details
from .database import db
# logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
templates = Jinja2Templates("src/templates")
@ -20,6 +37,7 @@ app = FastAPI(
title="OIDC auth test",
)
# SessionMiddleware is required by authlib
app.add_middleware(
SessionMiddleware,
@ -27,35 +45,58 @@ app.add_middleware(
)
# Add oidc providers to authlib from the settings
authlib_oauth = OAuth()
fastapi_providers = {}
_providers = {}
for provider in settings.oidc.providers:
authlib_oauth.register(
name=provider.name,
server_metadata_url=provider.provider_url,
server_metadata_url=provider.openid_configuration,
client_kwargs={
"scope": "openid email offline_access profile roles",
},
client_id=provider.client_id,
client_secret=provider.client_secret,
# fetch_token=fetch_token,
# update_token=update_token,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
fastapi_providers[provider.name] = OpenIdConnect(
openIdConnectUrl=provider.openid_configuration
)
_providers[provider.name] = provider
@app.get("/login")
async def login(request: Request, provider: str) -> RedirectResponse:
redirect_uri = request.url_for("auth", oidc_provider_id=provider)
# Endpoints for the login / authorization process
@app.get("/login/{oidc_provider_id}")
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
"""Login with the provider name,
by giving the browser a redirect to its authorize page.
After successful authentification, the provider replies with an encrypted
auth token that only we can decode and contains userinfo,
and a redirect to our own /auth/{oidc_provider_id} url
"""
redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id)
try:
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
provider_: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try:
return await provider_.authorize_redirect(request, redirect_uri)
return await provider_.authorize_redirect(
request, redirect_uri, access_type="offline"
)
except HTTPError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
@app.get("/auth/{oidc_provider_id}")
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
"""Decrypt the auth token, store it to the session (cookie based)
and response to the browser with a redirect to a "welcome user" page.
"""
try:
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
except AttributeError:
@ -76,35 +117,82 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
# Build and remember the user in the session
request.session["user_sub"] = sub
# Store the user in the database
await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
user = await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
request.session["token"] = userinfo["sub"]
await db.add_token(token, user)
return RedirectResponse(url="/")
else:
# Not sure if it's correct to redirect to plain login (which is not implemented anyway)
# Not sure if it's correct to redirect to plain login
# if no userinfo is provided
return RedirectResponse(url="/login")
redirect_uri = request.url_for("login", oidc_provider_id=oidc_provider_id)
return RedirectResponse(url=redirect_uri)
@app.get("/non-compliant-logout")
async def non_compliant_logout(
request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
):
"""A page for non-compliant OAuth2 servers that we cannot log out."""
return templates.TemplateResponse(
name="non_compliant_logout.html",
request=request,
context={"provider": provider},
)
@app.get("/logout")
async def logout(
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
) -> RedirectResponse:
# TODO: logout from oidc_provider
# await user.oidc_provider.logout_redirect()
# Clear session
request.session.pop("user_sub", None)
return RedirectResponse(url="/")
# Get provider's endpoint
if (
provider_logout_uri := provider.server_metadata.get("end_session_endpoint")
) is None:
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home")
if (id_token := await db.get_token(request.session["token"])) is None:
logger.warn("No session in db for the token")
return RedirectResponse(request.url_for("home"))
logout_url = (
provider_logout_uri
+ "?"
+ urlencode(
{
"post_logout_redirect_uri": post_logout_uri,
"id_token_hint": id_token.raw_id_token,
"cliend_id": "oidc_local_test",
}
)
)
return RedirectResponse(logout_url)
# Home URL
@app.get("/")
async def home(
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> HTMLResponse:
now = datetime.now()
return templates.TemplateResponse(
name="home.html",
request=request,
context={
"settings": settings.model_dump(),
"user": user,
"now": now,
"user_info_details": (
pretty_details(user, now)
if user and settings.oidc.show_session_details
else None
),
},
name="index.html",
)
@ -113,6 +201,9 @@ async def public() -> HTMLResponse:
return HTMLResponse("<h1>Not protected</h1>")
# Some URIs for testing the permissions
@app.get("/protected")
async def get_protected(
user: Annotated[User, Depends(get_current_user)]
@ -143,3 +234,40 @@ async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse
@hasrole(["foorole", "barrole"])
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
# @app.get("/fast_api_depends")
# def fast_api_depends(
# token: Annotated[str, Depends(fastapi_providers["Keycloak"])]
# ) -> HTMLResponse:
# return HTMLResponse("You're Authenticated")
def main():
from uvicorn import run
from argparse import ArgumentParser
parser = ArgumentParser(description=__doc__)
parser.add_argument(
"-l",
"--host",
type=str,
default="0.0.0.0",
help="Address to listen to (default: 0.0.0.0)",
)
parser.add_argument(
"-p", "--port", type=int, default=80, help="Port to listen to (default: 80)"
)
parser.add_argument(
"-v", "--version", action="store_true", help="Print version and exit"
)
args = parser.parse_args()
if args.version:
import sys
from importlib.metadata import version
print(version("sms_handler"))
sys.exit(0)
run(app, host=args.host, port=args.port)

View file

@ -1,19 +1,19 @@
from functools import cached_property
from typing import Self
from pydantic import BaseModel, EmailStr, AnyHttpUrl, Field, computed_field
from pydantic import computed_field, AnyHttpUrl, EmailStr
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
# from app.models import User
from sqlmodel import SQLModel, Field
class Role(BaseModel, extra="ignore"):
class Role(SQLModel, extra="ignore"):
name: str
class UserBase(BaseModel, extra="ignore"):
class UserBase(SQLModel, extra="ignore"):
id: str | None = None
sid: str | None = None
name: str
email: EmailStr | None = None
picture: AnyHttpUrl | None = None
@ -46,3 +46,33 @@ class User(UserBase):
@cached_property
def roles_as_set(self) -> set[str]:
return set([role.name for role in self.roles])
class OAuth2Token(SQLModel):
name: str = Field(primary_key=True)
token_type: str # = Field(max_length=40)
access_token: str # = Field(max_length=2000)
raw_id_token: str
refresh_token: str # = Field(max_length=200)
expires_at: int # = PositiveIntegerField()
user: User # = ForeignKey(User)
def to_token(self):
return dict(
access_token=self.access_token,
token_type=self.token_type,
refresh_token=self.refresh_token,
expires_at=self.expires_at,
)
@classmethod
def from_dict(cls, token_dict: dict, user: User) -> Self:
return cls(
name=token_dict["access_token"],
access_token=token_dict["access_token"],
raw_id_token=token_dict["id_token"],
token_type=token_dict["token_type"],
refresh_token=token_dict["refresh_token"],
expires_at=token_dict["expires_at"],
user=user,
)

View file

@ -18,7 +18,7 @@ class OIDCProvider(BaseModel):
@computed_field
@property
def provider_url(self) -> str:
def openid_configuration(self) -> str:
return self.url + "/.well-known/openid-configuration"
@computed_field