Container, bug fixes
This commit is contained in:
parent
23f180e521
commit
57041e9233
14 changed files with 512 additions and 291 deletions
29
src/oidc_test/auth_misc.py
Normal file
29
src/oidc_test/auth_misc.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
from datetime import datetime, timedelta
|
||||
from collections import OrderedDict
|
||||
|
||||
from .models import User
|
||||
|
||||
time_keys = set(("iat", "exp", "auth_time", "updated_at"))
|
||||
|
||||
|
||||
def pretty_details(user: User, now: datetime) -> OrderedDict:
|
||||
details = OrderedDict()
|
||||
# breakpoint()
|
||||
for key in sorted(time_keys):
|
||||
try:
|
||||
dt = datetime.fromtimestamp(user.userinfo[key])
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
else:
|
||||
td = now - dt
|
||||
td = timedelta(days=td.days, seconds=td.seconds)
|
||||
if td.days < 0:
|
||||
ptd = f"in {-td} h:m:s"
|
||||
else:
|
||||
ptd = f"{td} h:m:s ago"
|
||||
details[key] = f"{user.userinfo[key]} - {dt} ({ptd})"
|
||||
for key in sorted(user.userinfo):
|
||||
if key in time_keys:
|
||||
continue
|
||||
details[key] = user.userinfo[key]
|
||||
return details
|
|
@ -1,21 +1,68 @@
|
|||
from typing import Union
|
||||
from functools import wraps
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
|
||||
|
||||
from .models import User
|
||||
# from authlib.oauth1.auth import OAuthToken
|
||||
# from authlib.oauth2.auth import OAuth2Token
|
||||
|
||||
from .models import OAuth2Token, User
|
||||
from .database import db
|
||||
from .settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OIDC_PROVIDERS = set([provider.name for provider in settings.oidc.providers])
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> User:
|
||||
def get_provider(request: Request) -> StarletteOAuth2App:
|
||||
"""Return the oidc_provider from a request object, from the session.
|
||||
It can be used in Depends()"""
|
||||
if (oidc_provider_id := request.session.get("oidc_provider_id")) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
"Not logged in (no provider in session)",
|
||||
)
|
||||
try:
|
||||
return getattr(authlib_oauth, str(oidc_provider_id))
|
||||
except AttributeError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> User:
|
||||
"""Get the current user from a request object.
|
||||
Also validates the token expiration time.
|
||||
... TODO: complete about refresh token
|
||||
"""
|
||||
if (user_sub := request.session.get("user_sub")) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return db.get_user(user_sub)
|
||||
if (token := await db.get_token(request.session["token"])) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown")
|
||||
user = await db.get_user(user_sub)
|
||||
## Check if the token is expired
|
||||
if token.expires_at < datetime.timestamp(datetime.now()):
|
||||
oidc_provider = get_provider(request=request)
|
||||
## Ask a new refresh token from the provider
|
||||
logger.info(f"Token expired for user {user.name}")
|
||||
try:
|
||||
userinfo = await oidc_provider.fetch_access_token(
|
||||
refresh_token=token.access_token
|
||||
)
|
||||
except OAuthError as err:
|
||||
logger.exception(err)
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh"
|
||||
# )
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def get_current_user_or_none(request: Request) -> User | None:
|
||||
async def get_current_user_or_none(request: Request) -> User | None:
|
||||
try:
|
||||
return get_current_user(request)
|
||||
return await get_current_user(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
@ -35,7 +82,7 @@ def hasrole(required_roles: Union[str, list[str]] = []):
|
|||
500,
|
||||
"Functions decorated with hasrole must have a request:Request argument",
|
||||
)
|
||||
user: User = get_current_user(request)
|
||||
user: User = await get_current_user(request)
|
||||
if not any(required_roles_set.intersection(user.roles_as_set)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await func(request, *args, **kwargs)
|
||||
|
@ -43,3 +90,31 @@ def hasrole(required_roles: Union[str, list[str]] = []):
|
|||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_token_info(token: dict) -> dict:
|
||||
token_info = dict()
|
||||
for key in token:
|
||||
if key != "userinfo":
|
||||
token_info[key] = token[key]
|
||||
return token_info
|
||||
|
||||
|
||||
def fetch_token(name, request):
|
||||
breakpoint()
|
||||
...
|
||||
# if name in OIDC_PROVIDERS:
|
||||
# model = OAuth2Token
|
||||
# else:
|
||||
# model = OAuthToken
|
||||
|
||||
# token = model.find(name=name, user=request.user)
|
||||
# return token.to_token()
|
||||
|
||||
|
||||
def update_token(*args, **kwargs):
|
||||
breakpoint()
|
||||
...
|
||||
|
||||
|
||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||
|
|
|
@ -2,11 +2,12 @@
|
|||
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
|
||||
from .models import User
|
||||
from .models import User, OAuth2Token
|
||||
|
||||
|
||||
class Database:
|
||||
users: dict[str, User] = {}
|
||||
tokens: dict[str, OAuth2Token] = {}
|
||||
|
||||
# Last sessions for the user (key: users's subject id (sub))
|
||||
|
||||
|
@ -17,8 +18,17 @@ class Database:
|
|||
self.users[sub] = user
|
||||
return user
|
||||
|
||||
def get_user(self, sub: str) -> User:
|
||||
async def get_user(self, sub: str) -> User:
|
||||
return self.users[sub]
|
||||
|
||||
async def add_token(self, token_dict: dict, user: User) -> None:
|
||||
# FIXME: The tokens are stored with the user.sub key, meaning that
|
||||
# sessions logged in with different clients simultanously will
|
||||
# interfer with ezach others.
|
||||
self.tokens[user.sub] = OAuth2Token.from_dict(token_dict=token_dict, user=user)
|
||||
|
||||
async def get_token(self, name) -> OAuth2Token | None:
|
||||
return self.tokens.get(name)
|
||||
|
||||
|
||||
db = Database()
|
||||
|
|
|
@ -1,18 +1,35 @@
|
|||
from typing import Annotated
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from httpx import HTTPError
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.security import OpenIdConnect
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||
|
||||
# authlib startlette integration does not support revocation: using requests
|
||||
# from authlib.integrations.requests_client import OAuth2Session
|
||||
|
||||
from .settings import settings
|
||||
from .models import User
|
||||
from .auth_utils import hasrole, get_current_user_or_none, get_current_user
|
||||
from .auth_utils import (
|
||||
get_provider,
|
||||
hasrole,
|
||||
get_current_user_or_none,
|
||||
get_current_user,
|
||||
authlib_oauth,
|
||||
)
|
||||
from .auth_misc import pretty_details
|
||||
from .database import db
|
||||
|
||||
# logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
templates = Jinja2Templates("src/templates")
|
||||
|
||||
|
||||
|
@ -20,6 +37,7 @@ app = FastAPI(
|
|||
title="OIDC auth test",
|
||||
)
|
||||
|
||||
|
||||
# SessionMiddleware is required by authlib
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
|
@ -27,35 +45,58 @@ app.add_middleware(
|
|||
)
|
||||
|
||||
# Add oidc providers to authlib from the settings
|
||||
authlib_oauth = OAuth()
|
||||
|
||||
fastapi_providers = {}
|
||||
_providers = {}
|
||||
|
||||
for provider in settings.oidc.providers:
|
||||
authlib_oauth.register(
|
||||
name=provider.name,
|
||||
server_metadata_url=provider.provider_url,
|
||||
server_metadata_url=provider.openid_configuration,
|
||||
client_kwargs={
|
||||
"scope": "openid email offline_access profile roles",
|
||||
},
|
||||
client_id=provider.client_id,
|
||||
client_secret=provider.client_secret,
|
||||
# fetch_token=fetch_token,
|
||||
# update_token=update_token,
|
||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
)
|
||||
fastapi_providers[provider.name] = OpenIdConnect(
|
||||
openIdConnectUrl=provider.openid_configuration
|
||||
)
|
||||
_providers[provider.name] = provider
|
||||
|
||||
|
||||
@app.get("/login")
|
||||
async def login(request: Request, provider: str) -> RedirectResponse:
|
||||
redirect_uri = request.url_for("auth", oidc_provider_id=provider)
|
||||
# Endpoints for the login / authorization process
|
||||
|
||||
|
||||
@app.get("/login/{oidc_provider_id}")
|
||||
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||
"""Login with the provider name,
|
||||
by giving the browser a redirect to its authorize page.
|
||||
After successful authentification, the provider replies with an encrypted
|
||||
auth token that only we can decode and contains userinfo,
|
||||
and a redirect to our own /auth/{oidc_provider_id} url
|
||||
"""
|
||||
redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id)
|
||||
try:
|
||||
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
|
||||
provider_: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
|
||||
except AttributeError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
try:
|
||||
return await provider_.authorize_redirect(request, redirect_uri)
|
||||
return await provider_.authorize_redirect(
|
||||
request, redirect_uri, access_type="offline"
|
||||
)
|
||||
except HTTPError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
|
||||
|
||||
|
||||
@app.get("/auth/{oidc_provider_id}")
|
||||
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||
"""Decrypt the auth token, store it to the session (cookie based)
|
||||
and response to the browser with a redirect to a "welcome user" page.
|
||||
"""
|
||||
try:
|
||||
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
|
||||
except AttributeError:
|
||||
|
@ -76,35 +117,82 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
# Build and remember the user in the session
|
||||
request.session["user_sub"] = sub
|
||||
# Store the user in the database
|
||||
await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
|
||||
user = await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
|
||||
request.session["token"] = userinfo["sub"]
|
||||
await db.add_token(token, user)
|
||||
return RedirectResponse(url="/")
|
||||
else:
|
||||
# Not sure if it's correct to redirect to plain login (which is not implemented anyway)
|
||||
# Not sure if it's correct to redirect to plain login
|
||||
# if no userinfo is provided
|
||||
return RedirectResponse(url="/login")
|
||||
redirect_uri = request.url_for("login", oidc_provider_id=oidc_provider_id)
|
||||
return RedirectResponse(url=redirect_uri)
|
||||
|
||||
|
||||
@app.get("/non-compliant-logout")
|
||||
async def non_compliant_logout(
|
||||
request: Request,
|
||||
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
|
||||
):
|
||||
"""A page for non-compliant OAuth2 servers that we cannot log out."""
|
||||
return templates.TemplateResponse(
|
||||
name="non_compliant_logout.html",
|
||||
request=request,
|
||||
context={"provider": provider},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/logout")
|
||||
async def logout(
|
||||
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
|
||||
request: Request,
|
||||
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
|
||||
) -> RedirectResponse:
|
||||
# TODO: logout from oidc_provider
|
||||
# await user.oidc_provider.logout_redirect()
|
||||
# Clear session
|
||||
request.session.pop("user_sub", None)
|
||||
return RedirectResponse(url="/")
|
||||
# Get provider's endpoint
|
||||
if (
|
||||
provider_logout_uri := provider.server_metadata.get("end_session_endpoint")
|
||||
) is None:
|
||||
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
|
||||
return RedirectResponse(request.url_for("non_compliant_logout"))
|
||||
post_logout_uri = request.url_for("home")
|
||||
if (id_token := await db.get_token(request.session["token"])) is None:
|
||||
logger.warn("No session in db for the token")
|
||||
return RedirectResponse(request.url_for("home"))
|
||||
logout_url = (
|
||||
provider_logout_uri
|
||||
+ "?"
|
||||
+ urlencode(
|
||||
{
|
||||
"post_logout_redirect_uri": post_logout_uri,
|
||||
"id_token_hint": id_token.raw_id_token,
|
||||
"cliend_id": "oidc_local_test",
|
||||
}
|
||||
)
|
||||
)
|
||||
return RedirectResponse(logout_url)
|
||||
|
||||
|
||||
# Home URL
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def home(
|
||||
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
|
||||
) -> HTMLResponse:
|
||||
now = datetime.now()
|
||||
return templates.TemplateResponse(
|
||||
name="home.html",
|
||||
request=request,
|
||||
context={
|
||||
"settings": settings.model_dump(),
|
||||
"user": user,
|
||||
"now": now,
|
||||
"user_info_details": (
|
||||
pretty_details(user, now)
|
||||
if user and settings.oidc.show_session_details
|
||||
else None
|
||||
),
|
||||
},
|
||||
name="index.html",
|
||||
)
|
||||
|
||||
|
||||
|
@ -113,6 +201,9 @@ async def public() -> HTMLResponse:
|
|||
return HTMLResponse("<h1>Not protected</h1>")
|
||||
|
||||
|
||||
# Some URIs for testing the permissions
|
||||
|
||||
|
||||
@app.get("/protected")
|
||||
async def get_protected(
|
||||
user: Annotated[User, Depends(get_current_user)]
|
||||
|
@ -143,3 +234,40 @@ async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse
|
|||
@hasrole(["foorole", "barrole"])
|
||||
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
|
||||
|
||||
|
||||
# @app.get("/fast_api_depends")
|
||||
# def fast_api_depends(
|
||||
# token: Annotated[str, Depends(fastapi_providers["Keycloak"])]
|
||||
# ) -> HTMLResponse:
|
||||
# return HTMLResponse("You're Authenticated")
|
||||
|
||||
|
||||
def main():
|
||||
from uvicorn import run
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--host",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="Address to listen to (default: 0.0.0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--port", type=int, default=80, help="Port to listen to (default: 80)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--version", action="store_true", help="Print version and exit"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.version:
|
||||
import sys
|
||||
from importlib.metadata import version
|
||||
|
||||
print(version("sms_handler"))
|
||||
sys.exit(0)
|
||||
|
||||
run(app, host=args.host, port=args.port)
|
||||
|
|
|
@ -1,19 +1,19 @@
|
|||
from functools import cached_property
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel, EmailStr, AnyHttpUrl, Field, computed_field
|
||||
from pydantic import computed_field, AnyHttpUrl, EmailStr
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
|
||||
# from app.models import User
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
class Role(BaseModel, extra="ignore"):
|
||||
class Role(SQLModel, extra="ignore"):
|
||||
name: str
|
||||
|
||||
|
||||
class UserBase(BaseModel, extra="ignore"):
|
||||
class UserBase(SQLModel, extra="ignore"):
|
||||
|
||||
id: str | None = None
|
||||
sid: str | None = None
|
||||
name: str
|
||||
email: EmailStr | None = None
|
||||
picture: AnyHttpUrl | None = None
|
||||
|
@ -46,3 +46,33 @@ class User(UserBase):
|
|||
@cached_property
|
||||
def roles_as_set(self) -> set[str]:
|
||||
return set([role.name for role in self.roles])
|
||||
|
||||
|
||||
class OAuth2Token(SQLModel):
|
||||
name: str = Field(primary_key=True)
|
||||
token_type: str # = Field(max_length=40)
|
||||
access_token: str # = Field(max_length=2000)
|
||||
raw_id_token: str
|
||||
refresh_token: str # = Field(max_length=200)
|
||||
expires_at: int # = PositiveIntegerField()
|
||||
user: User # = ForeignKey(User)
|
||||
|
||||
def to_token(self):
|
||||
return dict(
|
||||
access_token=self.access_token,
|
||||
token_type=self.token_type,
|
||||
refresh_token=self.refresh_token,
|
||||
expires_at=self.expires_at,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, token_dict: dict, user: User) -> Self:
|
||||
return cls(
|
||||
name=token_dict["access_token"],
|
||||
access_token=token_dict["access_token"],
|
||||
raw_id_token=token_dict["id_token"],
|
||||
token_type=token_dict["token_type"],
|
||||
refresh_token=token_dict["refresh_token"],
|
||||
expires_at=token_dict["expires_at"],
|
||||
user=user,
|
||||
)
|
||||
|
|
|
@ -18,7 +18,7 @@ class OIDCProvider(BaseModel):
|
|||
|
||||
@computed_field
|
||||
@property
|
||||
def provider_url(self) -> str:
|
||||
def openid_configuration(self) -> str:
|
||||
return self.url + "/.well-known/openid-configuration"
|
||||
|
||||
@computed_field
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue