Many updates, too log to list
This commit is contained in:
parent
5c3d54c3f2
commit
23f180e521
6 changed files with 0 additions and 0 deletions
0
src/oidc_test/__init__.py
Normal file
0
src/oidc_test/__init__.py
Normal file
45
src/oidc_test/auth_utils.py
Normal file
45
src/oidc_test/auth_utils.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
from typing import Union
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from .models import User
|
||||
from .database import db
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> User:
|
||||
if (user_sub := request.session.get("user_sub")) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return db.get_user(user_sub)
|
||||
|
||||
|
||||
def get_current_user_or_none(request: Request) -> User | None:
|
||||
try:
|
||||
return get_current_user(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
def hasrole(required_roles: Union[str, list[str]] = []):
|
||||
required_roles_set: set[str]
|
||||
if isinstance(required_roles, str):
|
||||
required_roles_set = set([required_roles])
|
||||
else:
|
||||
required_roles_set = set(required_roles)
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(request=None, *args, **kwargs):
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
500,
|
||||
"Functions decorated with hasrole must have a request:Request argument",
|
||||
)
|
||||
user: User = get_current_user(request)
|
||||
if not any(required_roles_set.intersection(user.roles_as_set)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
24
src/oidc_test/database.py
Normal file
24
src/oidc_test/database.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
# Implement a fake in-memory database interface for demo purpose
|
||||
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
|
||||
from .models import User
|
||||
|
||||
|
||||
class Database:
|
||||
users: dict[str, User] = {}
|
||||
|
||||
# Last sessions for the user (key: users's subject id (sub))
|
||||
|
||||
async def add_user(
|
||||
self, sub: str, user_info: dict, oidc_provider: StarletteOAuth2App
|
||||
) -> User:
|
||||
user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider)
|
||||
self.users[sub] = user
|
||||
return user
|
||||
|
||||
def get_user(self, sub: str) -> User:
|
||||
return self.users[sub]
|
||||
|
||||
|
||||
db = Database()
|
145
src/oidc_test/main.py
Normal file
145
src/oidc_test/main.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
from typing import Annotated
|
||||
|
||||
from httpx import HTTPError
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||
|
||||
from .settings import settings
|
||||
from .models import User
|
||||
from .auth_utils import hasrole, get_current_user_or_none, get_current_user
|
||||
from .database import db
|
||||
|
||||
templates = Jinja2Templates("src/templates")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="OIDC auth test",
|
||||
)
|
||||
|
||||
# SessionMiddleware is required by authlib
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=settings.secret_key,
|
||||
)
|
||||
|
||||
# Add oidc providers to authlib from the settings
|
||||
authlib_oauth = OAuth()
|
||||
for provider in settings.oidc.providers:
|
||||
authlib_oauth.register(
|
||||
name=provider.name,
|
||||
server_metadata_url=provider.provider_url,
|
||||
client_kwargs={
|
||||
"scope": "openid email offline_access profile roles",
|
||||
},
|
||||
client_id=provider.client_id,
|
||||
client_secret=provider.client_secret,
|
||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
)
|
||||
|
||||
|
||||
@app.get("/login")
|
||||
async def login(request: Request, provider: str) -> RedirectResponse:
|
||||
redirect_uri = request.url_for("auth", oidc_provider_id=provider)
|
||||
try:
|
||||
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
|
||||
except AttributeError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
try:
|
||||
return await provider_.authorize_redirect(request, redirect_uri)
|
||||
except HTTPError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
|
||||
|
||||
|
||||
@app.get("/auth/{oidc_provider_id}")
|
||||
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||
try:
|
||||
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
|
||||
except AttributeError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
try:
|
||||
token = await oidc_provider.authorize_access_token(request)
|
||||
except OAuthError as error:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
|
||||
# Remember the oidc_provider in the session
|
||||
request.session["oidc_provider_id"] = oidc_provider_id
|
||||
#
|
||||
# One could process the full decoded token which contains extra information
|
||||
# eg for updates. Here we are only interested in roles
|
||||
#
|
||||
if userinfo := token.get("userinfo"):
|
||||
# sub given by oidc provider
|
||||
sub = userinfo["sub"]
|
||||
# Build and remember the user in the session
|
||||
request.session["user_sub"] = sub
|
||||
# Store the user in the database
|
||||
await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
|
||||
return RedirectResponse(url="/")
|
||||
else:
|
||||
# Not sure if it's correct to redirect to plain login (which is not implemented anyway)
|
||||
# if no userinfo is provided
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
|
||||
@app.get("/logout")
|
||||
async def logout(
|
||||
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
|
||||
) -> RedirectResponse:
|
||||
# TODO: logout from oidc_provider
|
||||
# await user.oidc_provider.logout_redirect()
|
||||
request.session.pop("user_sub", None)
|
||||
return RedirectResponse(url="/")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def home(
|
||||
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
|
||||
) -> HTMLResponse:
|
||||
return templates.TemplateResponse(
|
||||
request=request,
|
||||
context={
|
||||
"settings": settings.model_dump(),
|
||||
"user": user,
|
||||
},
|
||||
name="index.html",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/public")
|
||||
async def public() -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Not protected</h1>")
|
||||
|
||||
|
||||
@app.get("/protected")
|
||||
async def get_protected(
|
||||
user: Annotated[User, Depends(get_current_user)]
|
||||
) -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Only authenticated users can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole")
|
||||
@hasrole("foorole")
|
||||
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Only users with foorole can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-barrole")
|
||||
@hasrole("barrole")
|
||||
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Protected by barrole</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole-and-barrole")
|
||||
@hasrole("barrole")
|
||||
@hasrole("foorole")
|
||||
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole-or-barrole")
|
||||
@hasrole(["foorole", "barrole"])
|
||||
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
|
48
src/oidc_test/models.py
Normal file
48
src/oidc_test/models.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
from functools import cached_property
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel, EmailStr, AnyHttpUrl, Field, computed_field
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
|
||||
# from app.models import User
|
||||
|
||||
|
||||
class Role(BaseModel, extra="ignore"):
|
||||
name: str
|
||||
|
||||
|
||||
class UserBase(BaseModel, extra="ignore"):
|
||||
|
||||
id: str | None = None
|
||||
name: str
|
||||
email: EmailStr | None = None
|
||||
picture: AnyHttpUrl | None = None
|
||||
roles: list[Role] = []
|
||||
|
||||
|
||||
class User(UserBase):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
sub: str = Field(
|
||||
description="""subject id of the user given by the oidc provider,
|
||||
also the key for the database 'table'""",
|
||||
)
|
||||
userinfo: dict = {}
|
||||
oidc_provider: StarletteOAuth2App | None = None
|
||||
|
||||
@classmethod
|
||||
def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
|
||||
user = cls(**userinfo)
|
||||
user.userinfo = userinfo
|
||||
user.oidc_provider = oidc_provider
|
||||
# Add roles if they are provided in the token
|
||||
if raw_ra := userinfo.get("realm_access"):
|
||||
if raw_roles := raw_ra.get("roles"):
|
||||
user.roles = [Role(name=raw_role) for raw_role in raw_roles]
|
||||
return user
|
||||
|
||||
@computed_field
|
||||
@cached_property
|
||||
def roles_as_set(self) -> set[str]:
|
||||
return set([role.name for role in self.roles])
|
63
src/oidc_test/settings.py
Normal file
63
src/oidc_test/settings.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
import string
|
||||
import random
|
||||
from typing import Type, Tuple
|
||||
|
||||
from pydantic import BaseModel, computed_field
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
|
||||
|
||||
class OIDCProvider(BaseModel):
|
||||
name: str = ""
|
||||
url: str = ""
|
||||
client_id: str = ""
|
||||
client_secret: str = ""
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def provider_url(self) -> str:
|
||||
return self.url + "/.well-known/openid-configuration"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def token_url(self) -> str:
|
||||
return "auth/" + self.name
|
||||
|
||||
|
||||
class OIDCSettings(BaseModel):
|
||||
show_session_details: bool = False
|
||||
providers: list[OIDCProvider] = []
|
||||
swagger_provider: str = ""
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Settings wil be read from an .env file"""
|
||||
|
||||
oidc: OIDCSettings = OIDCSettings()
|
||||
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
||||
|
||||
class Config:
|
||||
env_nested_delimiter = "__"
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: Type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (
|
||||
init_settings,
|
||||
env_settings,
|
||||
file_secret_settings,
|
||||
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
|
||||
dotenv_settings,
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
Loading…
Add table
Add a link
Reference in a new issue