Remove OAuthToken from db (use authlib dict); basic OAuth2 service provider with Forgejo
Some checks failed
/ build (push) Failing after 13s
/ test (push) Successful in 4s

This commit is contained in:
phil 2025-01-18 06:20:44 +01:00
parent 21ccdad953
commit 2fe7536c53
10 changed files with 106 additions and 50 deletions

3
TODO Normal file
View file

@ -0,0 +1,3 @@
https://docs.authlib.org/en/latest/oauth/2/intro.html#intro-oauth2
https://www.keycloak.org/docs/latest/authorization_services/index.html

View file

@ -11,6 +11,7 @@ dependencies = [
"fastapi[standard]>=0.115.6", "fastapi[standard]>=0.115.6",
"itsdangerous>=2.2.0", "itsdangerous>=2.2.0",
"passlib[bcrypt]>=1.7.4", "passlib[bcrypt]>=1.7.4",
"pkce>=1.0.3",
"pydantic-settings>=2.7.1", "pydantic-settings>=2.7.1",
"python-jose[cryptography]>=3.3.0", "python-jose[cryptography]>=3.3.0",
"requests>=2.32.3", "requests>=2.32.3",

View file

@ -4,12 +4,13 @@ from datetime import datetime
import logging import logging
from fastapi import HTTPException, Request, status from fastapi import HTTPException, Request, status
from authlib.oauth2.rfc6749 import OAuth2Token
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
# from authlib.oauth1.auth import OAuthToken # from authlib.oauth1.auth import OAuthToken
# from authlib.oauth2.auth import OAuth2Token # from authlib.oauth2.auth import OAuth2Token
from .models import OAuth2Token, User from .models import User
from .database import db from .database import db
from .settings import settings from .settings import settings
@ -23,7 +24,7 @@ def get_provider(request: Request) -> StarletteOAuth2App:
It can be used in Depends()""" It can be used in Depends()"""
if (oidc_provider_id := request.session.get("oidc_provider_id")) is None: if (oidc_provider_id := request.session.get("oidc_provider_id")) is None:
raise HTTPException( raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, status.HTTP_503_SERVICE_UNAVAILABLE,
"Not logged in (no provider in session)", "Not logged in (no provider in session)",
) )
try: try:
@ -43,7 +44,7 @@ async def get_current_user(request: Request) -> User:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown")
user = await db.get_user(user_sub) user = await db.get_user(user_sub)
## Check if the token is expired ## Check if the token is expired
if token.expires_at < datetime.timestamp(datetime.now()): if token.is_expired():
oidc_provider = get_provider(request=request) oidc_provider = get_provider(request=request)
## Ask a new refresh token from the provider ## Ask a new refresh token from the provider
logger.info(f"Token expired for user {user.name}") logger.info(f"Token expired for user {user.name}")
@ -117,4 +118,10 @@ def update_token(*args, **kwargs):
... ...
async def get_token(request: Request) -> OAuth2Token:
if (token := await db.get_token(request.session.get("token"))) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
return token
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)

View file

@ -1,9 +1,11 @@
# Implement a fake in-memory database interface for demo purpose """Fake in-memory database interface for demo purpose"""
import logging import logging
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from .models import User, OAuth2Token, Role from .models import User, Role
from authlib.oauth2.rfc6749 import OAuth2Token
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,11 +39,11 @@ class Database:
async def get_user(self, sub: str) -> User: async def get_user(self, sub: str) -> User:
return self.users[sub] return self.users[sub]
async def add_token(self, token_dict: dict, user: User) -> None: async def add_token(self, token: OAuth2Token, user: User) -> None:
self.tokens[token_dict['id_token']] = OAuth2Token.from_dict(token_dict=token_dict, user=user) self.tokens[token["id_token"]] = token
async def get_token(self, name) -> OAuth2Token | None: async def get_token(self, id_token: str) -> OAuth2Token | None:
return self.tokens.get(name) return self.tokens.get(id_token)
db = Database() db = Database()

View file

@ -10,12 +10,15 @@ from urllib.parse import urlencode
from httpx import HTTPError from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.security import OpenIdConnect from fastapi.security import OpenIdConnect
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.starlette_client import OAuthError from authlib.integrations.base_client import OAuthError
from authlib.integrations.httpx_client import AsyncOAuth2Client
from authlib.oauth2.rfc6749 import OAuth2Token
from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings from .settings import settings
from .models import User from .models import User
@ -25,6 +28,7 @@ from .auth_utils import (
get_current_user_or_none, get_current_user_or_none,
get_current_user, get_current_user,
authlib_oauth, authlib_oauth,
get_token,
) )
from .auth_misc import pretty_details from .auth_misc import pretty_details
from .database import db from .database import db
@ -77,8 +81,7 @@ for provider in settings.oidc.providers:
@app.get("/login/{oidc_provider_id}") @app.get("/login/{oidc_provider_id}")
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
"""Login with the provider id, """Login with the provider id, giving the browser a redirect to its authorize page.
by giving the browser a redirect to its authorize page.
The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url
with the token. with the token.
""" """
@ -87,9 +90,20 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
provider_: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) provider_: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
except AttributeError: except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
if (
code_challenge_method := _providers[oidc_provider_id].code_challenge_method
) is not None:
client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method)
code_verifier = generate_code_verifier()
logger.debug("TODO: PKCE")
else:
code_verifier = None
try: try:
response = await provider_.authorize_redirect( response = await provider_.authorize_redirect(
request, redirect_uri, access_type="offline" request,
redirect_uri,
access_type="offline",
code_verifier=code_verifier,
) )
return response return response
except HTTPError: except HTTPError:
@ -106,7 +120,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
except AttributeError: except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try: try:
token = await oidc_provider.authorize_access_token(request) token: OAuth2Token = await oidc_provider.authorize_access_token(request)
except OAuthError as error: except OAuthError as error:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
# Remember the oidc_provider in the session # Remember the oidc_provider in the session
@ -166,7 +180,7 @@ async def logout(
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
return RedirectResponse(request.url_for("non_compliant_logout")) return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home") post_logout_uri = request.url_for("home")
if (id_token := await db.get_token(request.session.pop("token", None))) is None: if (token := await db.get_token(request.session.pop("token", None))) is None:
logger.warn("No session in db for the token") logger.warn("No session in db for the token")
return RedirectResponse(request.url_for("home")) return RedirectResponse(request.url_for("home"))
logout_url = ( logout_url = (
@ -175,7 +189,7 @@ async def logout(
+ urlencode( + urlencode(
{ {
"post_logout_redirect_uri": post_logout_uri, "post_logout_redirect_uri": post_logout_uri,
"id_token_hint": id_token.raw_id_token, "id_token_hint": token["id_token"],
"cliend_id": "oidc_local_test", "cliend_id": "oidc_local_test",
} }
) )
@ -260,6 +274,44 @@ async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>") return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
@app.get("/introspect")
async def get_introspect(
request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
token: Annotated[OAuth2Token, Depends(get_token)],
) -> JSONResponse:
if (
response := await provider.get(
provider.server_metadata["introspection_endpoint"],
token=token,
)
).is_success:
return response.json()
else:
raise HTTPException(status_code=response.status_code, detail=response.text)
@app.get("/oauth2-forgejo-test")
async def get_forgejo_user_info(
request: Request,
user: Annotated[User, Depends(get_current_user)],
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
token: Annotated[OAuth2Token, Depends(get_token)],
) -> HTMLResponse:
if (
response := await provider.get(
"/api/v1/user/repos",
# headers={"Authorization": f"token {token['access_token']}"},
token=token,
)
).is_success:
repos = response.json()
names = [repo["name"] for repo in repos]
return HTMLResponse(f"{user.name} has {len(repos)} repos: {', '.join(names)}")
else:
raise HTTPException(status_code=response.status_code, detail=response.text)
# @app.get("/fast_api_depends") # @app.get("/fast_api_depends")
# def fast_api_depends( # def fast_api_depends(
# token: Annotated[str, Depends(fastapi_providers["Keycloak"])] # token: Annotated[str, Depends(fastapi_providers["Keycloak"])]

View file

@ -1,8 +1,16 @@
from functools import cached_property from functools import cached_property
from typing import Self from typing import Self
from pydantic import computed_field, AnyHttpUrl, EmailStr, ConfigDict from pydantic import (
computed_field,
AnyHttpUrl,
EmailStr,
ConfigDict,
GetCoreSchemaHandler,
)
from pydantic_core import CoreSchema, core_schema
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.oauth2.rfc6749 import OAuth2Token as OAuth2Token_authlib
from sqlmodel import SQLModel, Field from sqlmodel import SQLModel, Field
@ -45,33 +53,3 @@ class User(UserBase):
@cached_property @cached_property
def roles_as_set(self) -> set[str]: def roles_as_set(self) -> set[str]:
return set([role.name for role in self.roles]) return set([role.name for role in self.roles])
class OAuth2Token(SQLModel):
name: str = Field(primary_key=True)
token_type: str # = Field(max_length=40)
access_token: str # = Field(max_length=2000)
raw_id_token: str
refresh_token: str # = Field(max_length=200)
expires_at: int # = PositiveIntegerField()
user: User # = ForeignKey(User)
def to_token(self):
return dict(
access_token=self.access_token,
token_type=self.token_type,
refresh_token=self.refresh_token,
expires_at=self.expires_at,
)
@classmethod
def from_dict(cls, token_dict: dict, user: User) -> Self:
return cls(
name=token_dict["access_token"],
access_token=token_dict["access_token"],
raw_id_token=token_dict["id_token"],
token_type=token_dict["token_type"],
refresh_token=token_dict["refresh_token"],
expires_at=token_dict["expires_at"],
user=user,
)

View file

@ -20,7 +20,7 @@ class OIDCProvider(BaseModel):
client_id: str client_id: str
client_secret: str = "" client_secret: str = ""
# For PKCE (not implemented yet) # For PKCE (not implemented yet)
# code_challenge_method: str | None = None code_challenge_method: str | None = None
hint: str = "No hint" hint: str = "No hint"
@computed_field @computed_field

View file

@ -120,7 +120,7 @@
if (xmlHttp.readyState == 4) { if (xmlHttp.readyState == 4) {
elem.classList.add("hasResponseStatus") elem.classList.add("hasResponseStatus")
elem.classList.add("status-" + xmlHttp.status) elem.classList.add("status-" + xmlHttp.status)
elem.title = "Response code: " + xmlHttp.status elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText
} }
} }
xmlHttp.open("GET", elem.href, true) // true for asynchronous xmlHttp.open("GET", elem.href, true) // true for asynchronous

View file

@ -109,6 +109,8 @@
<a href="protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a> <a href="protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a>
<a href="fast_api_depends" class="hidden">Using FastAPI Depends</a> <a href="fast_api_depends" class="hidden">Using FastAPI Depends</a>
<a href="other">Other</a> <a href="other">Other</a>
<a href="oauth2-forgejo-test">OAuth2 test (forgejo user info)</a>
<a href="introspect">Introspect token</a>
</div> </div>
{% if user_info_details %} {% if user_info_details %}
<div class="debug-auth"> <div class="debug-auth">

11
uv.lock generated
View file

@ -490,6 +490,7 @@ dependencies = [
{ name = "fastapi", extra = ["standard"] }, { name = "fastapi", extra = ["standard"] },
{ name = "itsdangerous" }, { name = "itsdangerous" },
{ name = "passlib", extra = ["bcrypt"] }, { name = "passlib", extra = ["bcrypt"] },
{ name = "pkce" },
{ name = "pydantic-settings" }, { name = "pydantic-settings" },
{ name = "python-jose", extra = ["cryptography"] }, { name = "python-jose", extra = ["cryptography"] },
{ name = "requests" }, { name = "requests" },
@ -509,6 +510,7 @@ requires-dist = [
{ name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" },
{ name = "itsdangerous", specifier = ">=2.2.0" }, { name = "itsdangerous", specifier = ">=2.2.0" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
{ name = "pkce", specifier = ">=1.0.3" },
{ name = "pydantic-settings", specifier = ">=2.7.1" }, { name = "pydantic-settings", specifier = ">=2.7.1" },
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
{ name = "requests", specifier = ">=2.32.3" }, { name = "requests", specifier = ">=2.32.3" },
@ -565,6 +567,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 }, { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 },
] ]
[[package]]
name = "pkce"
version = "1.0.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/29/ea/ddd845c2ec21bf1e8555c782b32dc39b82f0b12764feb9f73ccbb2470f13/pkce-1.0.3.tar.gz", hash = "sha256:9775fd76d8a743d39b87df38af1cd04a58c9b5a5242d5a6350ef343d06814ab6", size = 2757 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/15/51/52c22ec0812d25f5bf297a01153604bfa7bfa59ed66f6cd8345beb3c2b2a/pkce-1.0.3-py3-none-any.whl", hash = "sha256:55927e24c7d403b2491ebe182b95d9dcb1807643243d47e3879fbda5aad4471d", size = 3200 },
]
[[package]] [[package]]
name = "pluggy" name = "pluggy"
version = "1.5.0" version = "1.5.0"