Add fake db, properly deal with roles, improve types, etc
This commit is contained in:
parent
522b3465df
commit
5c3d54c3f2
6 changed files with 126 additions and 52 deletions
|
@ -1,16 +1,16 @@
|
|||
from typing import Union
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from .models import User
|
||||
from .database import db
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> User:
|
||||
auth_data = request.session.get("user")
|
||||
if auth_data is None:
|
||||
raise HTTPException(401, "Not authorized")
|
||||
return User(**auth_data)
|
||||
if (user_sub := request.session.get("user_sub")) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return db.get_user(user_sub)
|
||||
|
||||
|
||||
def get_current_user_or_none(request: Request) -> User | None:
|
||||
|
@ -20,11 +20,7 @@ def get_current_user_or_none(request: Request) -> User | None:
|
|||
return None
|
||||
|
||||
|
||||
def hasrole(
|
||||
required_roles: Union[str, list[str]] = [],
|
||||
roles_key: str = "roles",
|
||||
realm: str | None = "realm_access", # Keycloak standard for realm defined roles
|
||||
):
|
||||
def hasrole(required_roles: Union[str, list[str]] = []):
|
||||
required_roles_set: set[str]
|
||||
if isinstance(required_roles, str):
|
||||
required_roles_set = set([required_roles])
|
||||
|
@ -39,18 +35,9 @@ def hasrole(
|
|||
500,
|
||||
"Functions decorated with hasrole must have a request:Request argument",
|
||||
)
|
||||
if "user" not in request.session:
|
||||
raise HTTPException(401, "Not authorized")
|
||||
user = request.session["user"]
|
||||
try:
|
||||
if realm in user:
|
||||
roles = user[realm][roles_key]
|
||||
else:
|
||||
roles = user[roles_key]
|
||||
except KeyError:
|
||||
raise HTTPException(401, "Not authorized")
|
||||
if not any(required_roles_set.intersection(roles)):
|
||||
raise HTTPException(401, "Not authorized")
|
||||
user: User = get_current_user(request)
|
||||
if not any(required_roles_set.intersection(user.roles_as_set)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
|
24
src/oidc-test/database.py
Normal file
24
src/oidc-test/database.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
# Implement a fake in-memory database interface for demo purpose
|
||||
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
|
||||
from .models import User
|
||||
|
||||
|
||||
class Database:
|
||||
users: dict[str, User] = {}
|
||||
|
||||
# Last sessions for the user (key: users's subject id (sub))
|
||||
|
||||
async def add_user(
|
||||
self, sub: str, user_info: dict, oidc_provider: StarletteOAuth2App
|
||||
) -> User:
|
||||
user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider)
|
||||
self.users[sub] = user
|
||||
return user
|
||||
|
||||
def get_user(self, sub: str) -> User:
|
||||
return self.users[sub]
|
||||
|
||||
|
||||
db = Database()
|
|
@ -11,6 +11,7 @@ from authlib.integrations.starlette_client import OAuth, OAuthError
|
|||
from .settings import settings
|
||||
from .models import User
|
||||
from .auth_utils import hasrole, get_current_user_or_none, get_current_user
|
||||
from .database import db
|
||||
|
||||
templates = Jinja2Templates("src/templates")
|
||||
|
||||
|
@ -25,13 +26,15 @@ app.add_middleware(
|
|||
secret_key=settings.secret_key,
|
||||
)
|
||||
|
||||
# Add oidc providers from the settings
|
||||
# Add oidc providers to authlib from the settings
|
||||
authlib_oauth = OAuth()
|
||||
for provider in settings.oidc.providers:
|
||||
authlib_oauth.register(
|
||||
name=provider.name,
|
||||
server_metadata_url=provider.provider_url,
|
||||
client_kwargs={"scope": "openid email offline_access profile roles"},
|
||||
client_kwargs={
|
||||
"scope": "openid email offline_access profile roles",
|
||||
},
|
||||
client_id=provider.client_id,
|
||||
client_secret=provider.client_secret,
|
||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
|
@ -40,38 +43,54 @@ for provider in settings.oidc.providers:
|
|||
|
||||
@app.get("/login")
|
||||
async def login(request: Request, provider: str) -> RedirectResponse:
|
||||
redirect_uri = request.url_for("auth", provider=provider)
|
||||
redirect_uri = request.url_for("auth", oidc_provider_id=provider)
|
||||
try:
|
||||
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
|
||||
except AttributeError:
|
||||
raise HTTPException(500, "No such provider")
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
try:
|
||||
return await provider_.authorize_redirect(request, redirect_uri)
|
||||
except HTTPError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
|
||||
|
||||
|
||||
@app.get("/auth/{provider}")
|
||||
async def auth(request: Request, provider: str) -> RedirectResponse:
|
||||
@app.get("/auth/{oidc_provider_id}")
|
||||
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||
try:
|
||||
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
|
||||
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
|
||||
except AttributeError:
|
||||
raise HTTPException(500, "No such provider")
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
try:
|
||||
token = await provider_.authorize_access_token(request)
|
||||
token = await oidc_provider.authorize_access_token(request)
|
||||
except OAuthError as error:
|
||||
raise HTTPException(status_code=401, detail=error.error)
|
||||
user = token.get("userinfo")
|
||||
if user:
|
||||
request.session["user"] = dict(user)
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
|
||||
# Remember the oidc_provider in the session
|
||||
request.session["oidc_provider_id"] = oidc_provider_id
|
||||
#
|
||||
# One could process the full decoded token which contains extra information
|
||||
# eg for updates. Here we are only interested in roles
|
||||
#
|
||||
if userinfo := token.get("userinfo"):
|
||||
# sub given by oidc provider
|
||||
sub = userinfo["sub"]
|
||||
# Build and remember the user in the session
|
||||
request.session["user_sub"] = sub
|
||||
# Store the user in the database
|
||||
await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
|
||||
return RedirectResponse(url="/")
|
||||
else:
|
||||
# Not sure if it's correct to redirect to plain login (which is not implemented anyway)
|
||||
# if no userinfo is provided
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
|
||||
@app.get("/logout")
|
||||
async def logout(request: Request) -> RedirectResponse:
|
||||
request.session.pop("user", None)
|
||||
async def logout(
|
||||
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
|
||||
) -> RedirectResponse:
|
||||
# TODO: logout from oidc_provider
|
||||
# await user.oidc_provider.logout_redirect()
|
||||
request.session.pop("user_sub", None)
|
||||
return RedirectResponse(url="/")
|
||||
|
||||
|
||||
|
@ -84,7 +103,6 @@ async def home(
|
|||
context={
|
||||
"settings": settings.model_dump(),
|
||||
"user": user,
|
||||
"auth_data": request.session.get("user"),
|
||||
},
|
||||
name="index.html",
|
||||
)
|
||||
|
|
|
@ -1,11 +1,48 @@
|
|||
from pydantic import BaseModel, EmailStr, AnyHttpUrl
|
||||
from functools import cached_property
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel, EmailStr, AnyHttpUrl, Field, computed_field
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
|
||||
# from app.models import User
|
||||
|
||||
|
||||
class User(BaseModel, extra="ignore"):
|
||||
class Role(BaseModel, extra="ignore"):
|
||||
name: str
|
||||
|
||||
|
||||
class UserBase(BaseModel, extra="ignore"):
|
||||
|
||||
id: str | None = None
|
||||
name: str
|
||||
email: EmailStr | None = None
|
||||
picture: AnyHttpUrl | None = None
|
||||
roles: list[Role] = []
|
||||
|
||||
|
||||
class User(UserBase):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
sub: str = Field(
|
||||
description="""subject id of the user given by the oidc provider,
|
||||
also the key for the database 'table'""",
|
||||
)
|
||||
userinfo: dict = {}
|
||||
oidc_provider: StarletteOAuth2App | None = None
|
||||
|
||||
@classmethod
|
||||
def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
|
||||
user = cls(**userinfo)
|
||||
user.userinfo = userinfo
|
||||
user.oidc_provider = oidc_provider
|
||||
# Add roles if they are provided in the token
|
||||
if raw_ra := userinfo.get("realm_access"):
|
||||
if raw_roles := raw_ra.get("roles"):
|
||||
user.roles = [Role(name=raw_role) for raw_role in raw_roles]
|
||||
return user
|
||||
|
||||
@computed_field
|
||||
@cached_property
|
||||
def roles_as_set(self) -> set[str]:
|
||||
return set([role.name for role in self.roles])
|
||||
|
|
|
@ -15,7 +15,6 @@ class OIDCProvider(BaseModel):
|
|||
url: str = ""
|
||||
client_id: str = ""
|
||||
client_secret: str = ""
|
||||
is_swagger: bool = False
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
|
@ -33,13 +32,6 @@ class OIDCSettings(BaseModel):
|
|||
providers: list[OIDCProvider] = []
|
||||
swagger_provider: str = ""
|
||||
|
||||
def get_swagger_provider(self) -> OIDCProvider:
|
||||
for provider in self.providers:
|
||||
if provider.is_swagger:
|
||||
return provider
|
||||
else:
|
||||
raise UserWarning("Please define a provider for Swagger with id_swagger")
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Settings wil be read from an .env file"""
|
||||
|
|
|
@ -119,6 +119,10 @@
|
|||
.hasResponseStatus.status-401 {
|
||||
background-color: #ff000040;
|
||||
}
|
||||
.role {
|
||||
padding: 3px 6px;
|
||||
background-color: #44228840;
|
||||
}
|
||||
</style>
|
||||
<script>
|
||||
function checkHref(elem) {
|
||||
|
@ -143,10 +147,10 @@
|
|||
<h1>FastAPI test app for OIDC</h1>
|
||||
{% if not user %}
|
||||
<div class="login-box">
|
||||
<p>Not logged in</p>
|
||||
<p>Log in with one of these authentication providers:</p>
|
||||
<div class="login-toolbox">
|
||||
{% for provider in settings.oidc.providers %}
|
||||
<a href="login?provider={{ provider.name }}">Login with: {{ provider.name }}</a>
|
||||
<a href="login?provider={{ provider.name }}">{{ provider.name }}</a>
|
||||
{% else %}
|
||||
<span class="error">Cannot login: no oidc prodiver in settings.yaml</span>
|
||||
{% endfor %}
|
||||
|
@ -159,7 +163,19 @@
|
|||
{% if user.picture %}
|
||||
<img src="{{ user.picture }}" class="picture"></img>
|
||||
{% endif %}
|
||||
<p>{{ user.email }}</p>
|
||||
<div>{{ user.email }}</div>
|
||||
{% if user.roles %}
|
||||
<div>
|
||||
<span>Roles:</span>
|
||||
{% for role in user.roles %}
|
||||
<span class="role">{{ role.name }}</span>
|
||||
{% endfor %}
|
||||
</div>
|
||||
{% endif %}
|
||||
<div>
|
||||
<span>Provider:</span>
|
||||
{{ user.oidc_provider.name }}
|
||||
</div>
|
||||
<a href="logout" class="logout">Logout</a>
|
||||
</div>
|
||||
{% endif %}
|
||||
|
@ -180,7 +196,7 @@
|
|||
<div class="debug-auth">
|
||||
<p>Session details</p>
|
||||
<ul>
|
||||
{% for key, value in auth_data.items() %}
|
||||
{% for key, value in user.userinfo.items() %}
|
||||
<li>
|
||||
<span class="key">{{ key }}</span>: {{ value }}
|
||||
</li>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue