Remove OAuthToken from db (use authlib dict); basic OAuth2 service provider with Forgejo
This commit is contained in:
parent
21ccdad953
commit
2fe7536c53
10 changed files with 106 additions and 50 deletions
3
TODO
Normal file
3
TODO
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
https://docs.authlib.org/en/latest/oauth/2/intro.html#intro-oauth2
|
||||||
|
|
||||||
|
https://www.keycloak.org/docs/latest/authorization_services/index.html
|
|
@ -11,6 +11,7 @@ dependencies = [
|
||||||
"fastapi[standard]>=0.115.6",
|
"fastapi[standard]>=0.115.6",
|
||||||
"itsdangerous>=2.2.0",
|
"itsdangerous>=2.2.0",
|
||||||
"passlib[bcrypt]>=1.7.4",
|
"passlib[bcrypt]>=1.7.4",
|
||||||
|
"pkce>=1.0.3",
|
||||||
"pydantic-settings>=2.7.1",
|
"pydantic-settings>=2.7.1",
|
||||||
"python-jose[cryptography]>=3.3.0",
|
"python-jose[cryptography]>=3.3.0",
|
||||||
"requests>=2.32.3",
|
"requests>=2.32.3",
|
||||||
|
|
|
@ -4,12 +4,13 @@ from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import HTTPException, Request, status
|
from fastapi import HTTPException, Request, status
|
||||||
|
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||||
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
|
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
|
||||||
|
|
||||||
# from authlib.oauth1.auth import OAuthToken
|
# from authlib.oauth1.auth import OAuthToken
|
||||||
# from authlib.oauth2.auth import OAuth2Token
|
# from authlib.oauth2.auth import OAuth2Token
|
||||||
|
|
||||||
from .models import OAuth2Token, User
|
from .models import User
|
||||||
from .database import db
|
from .database import db
|
||||||
from .settings import settings
|
from .settings import settings
|
||||||
|
|
||||||
|
@ -23,7 +24,7 @@ def get_provider(request: Request) -> StarletteOAuth2App:
|
||||||
It can be used in Depends()"""
|
It can be used in Depends()"""
|
||||||
if (oidc_provider_id := request.session.get("oidc_provider_id")) is None:
|
if (oidc_provider_id := request.session.get("oidc_provider_id")) is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
"Not logged in (no provider in session)",
|
"Not logged in (no provider in session)",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
@ -43,7 +44,7 @@ async def get_current_user(request: Request) -> User:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown")
|
||||||
user = await db.get_user(user_sub)
|
user = await db.get_user(user_sub)
|
||||||
## Check if the token is expired
|
## Check if the token is expired
|
||||||
if token.expires_at < datetime.timestamp(datetime.now()):
|
if token.is_expired():
|
||||||
oidc_provider = get_provider(request=request)
|
oidc_provider = get_provider(request=request)
|
||||||
## Ask a new refresh token from the provider
|
## Ask a new refresh token from the provider
|
||||||
logger.info(f"Token expired for user {user.name}")
|
logger.info(f"Token expired for user {user.name}")
|
||||||
|
@ -117,4 +118,10 @@ def update_token(*args, **kwargs):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
async def get_token(request: Request) -> OAuth2Token:
|
||||||
|
if (token := await db.get_token(request.session.get("token"))) is None:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
# Implement a fake in-memory database interface for demo purpose
|
"""Fake in-memory database interface for demo purpose"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||||
|
|
||||||
from .models import User, OAuth2Token, Role
|
from .models import User, Role
|
||||||
|
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -37,11 +39,11 @@ class Database:
|
||||||
async def get_user(self, sub: str) -> User:
|
async def get_user(self, sub: str) -> User:
|
||||||
return self.users[sub]
|
return self.users[sub]
|
||||||
|
|
||||||
async def add_token(self, token_dict: dict, user: User) -> None:
|
async def add_token(self, token: OAuth2Token, user: User) -> None:
|
||||||
self.tokens[token_dict['id_token']] = OAuth2Token.from_dict(token_dict=token_dict, user=user)
|
self.tokens[token["id_token"]] = token
|
||||||
|
|
||||||
async def get_token(self, name) -> OAuth2Token | None:
|
async def get_token(self, id_token: str) -> OAuth2Token | None:
|
||||||
return self.tokens.get(name)
|
return self.tokens.get(id_token)
|
||||||
|
|
||||||
|
|
||||||
db = Database()
|
db = Database()
|
||||||
|
|
|
@ -10,12 +10,15 @@ from urllib.parse import urlencode
|
||||||
|
|
||||||
from httpx import HTTPError
|
from httpx import HTTPError
|
||||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from fastapi.security import OpenIdConnect
|
from fastapi.security import OpenIdConnect
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||||
from authlib.integrations.starlette_client import OAuthError
|
from authlib.integrations.base_client import OAuthError
|
||||||
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
||||||
|
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||||
|
from pkce import generate_code_verifier, generate_pkce_pair
|
||||||
|
|
||||||
from .settings import settings
|
from .settings import settings
|
||||||
from .models import User
|
from .models import User
|
||||||
|
@ -25,6 +28,7 @@ from .auth_utils import (
|
||||||
get_current_user_or_none,
|
get_current_user_or_none,
|
||||||
get_current_user,
|
get_current_user,
|
||||||
authlib_oauth,
|
authlib_oauth,
|
||||||
|
get_token,
|
||||||
)
|
)
|
||||||
from .auth_misc import pretty_details
|
from .auth_misc import pretty_details
|
||||||
from .database import db
|
from .database import db
|
||||||
|
@ -77,8 +81,7 @@ for provider in settings.oidc.providers:
|
||||||
|
|
||||||
@app.get("/login/{oidc_provider_id}")
|
@app.get("/login/{oidc_provider_id}")
|
||||||
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||||
"""Login with the provider id,
|
"""Login with the provider id, giving the browser a redirect to its authorize page.
|
||||||
by giving the browser a redirect to its authorize page.
|
|
||||||
The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url
|
The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url
|
||||||
with the token.
|
with the token.
|
||||||
"""
|
"""
|
||||||
|
@ -87,9 +90,20 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||||
provider_: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
|
provider_: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||||
|
if (
|
||||||
|
code_challenge_method := _providers[oidc_provider_id].code_challenge_method
|
||||||
|
) is not None:
|
||||||
|
client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method)
|
||||||
|
code_verifier = generate_code_verifier()
|
||||||
|
logger.debug("TODO: PKCE")
|
||||||
|
else:
|
||||||
|
code_verifier = None
|
||||||
try:
|
try:
|
||||||
response = await provider_.authorize_redirect(
|
response = await provider_.authorize_redirect(
|
||||||
request, redirect_uri, access_type="offline"
|
request,
|
||||||
|
redirect_uri,
|
||||||
|
access_type="offline",
|
||||||
|
code_verifier=code_verifier,
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
except HTTPError:
|
except HTTPError:
|
||||||
|
@ -106,7 +120,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||||
try:
|
try:
|
||||||
token = await oidc_provider.authorize_access_token(request)
|
token: OAuth2Token = await oidc_provider.authorize_access_token(request)
|
||||||
except OAuthError as error:
|
except OAuthError as error:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
|
||||||
# Remember the oidc_provider in the session
|
# Remember the oidc_provider in the session
|
||||||
|
@ -166,7 +180,7 @@ async def logout(
|
||||||
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
|
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
|
||||||
return RedirectResponse(request.url_for("non_compliant_logout"))
|
return RedirectResponse(request.url_for("non_compliant_logout"))
|
||||||
post_logout_uri = request.url_for("home")
|
post_logout_uri = request.url_for("home")
|
||||||
if (id_token := await db.get_token(request.session.pop("token", None))) is None:
|
if (token := await db.get_token(request.session.pop("token", None))) is None:
|
||||||
logger.warn("No session in db for the token")
|
logger.warn("No session in db for the token")
|
||||||
return RedirectResponse(request.url_for("home"))
|
return RedirectResponse(request.url_for("home"))
|
||||||
logout_url = (
|
logout_url = (
|
||||||
|
@ -175,7 +189,7 @@ async def logout(
|
||||||
+ urlencode(
|
+ urlencode(
|
||||||
{
|
{
|
||||||
"post_logout_redirect_uri": post_logout_uri,
|
"post_logout_redirect_uri": post_logout_uri,
|
||||||
"id_token_hint": id_token.raw_id_token,
|
"id_token_hint": token["id_token"],
|
||||||
"cliend_id": "oidc_local_test",
|
"cliend_id": "oidc_local_test",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -260,6 +274,44 @@ async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
|
||||||
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
|
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/introspect")
|
||||||
|
async def get_introspect(
|
||||||
|
request: Request,
|
||||||
|
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
|
||||||
|
token: Annotated[OAuth2Token, Depends(get_token)],
|
||||||
|
) -> JSONResponse:
|
||||||
|
if (
|
||||||
|
response := await provider.get(
|
||||||
|
provider.server_metadata["introspection_endpoint"],
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
).is_success:
|
||||||
|
return response.json()
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/oauth2-forgejo-test")
|
||||||
|
async def get_forgejo_user_info(
|
||||||
|
request: Request,
|
||||||
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
|
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
|
||||||
|
token: Annotated[OAuth2Token, Depends(get_token)],
|
||||||
|
) -> HTMLResponse:
|
||||||
|
if (
|
||||||
|
response := await provider.get(
|
||||||
|
"/api/v1/user/repos",
|
||||||
|
# headers={"Authorization": f"token {token['access_token']}"},
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
).is_success:
|
||||||
|
repos = response.json()
|
||||||
|
names = [repo["name"] for repo in repos]
|
||||||
|
return HTMLResponse(f"{user.name} has {len(repos)} repos: {', '.join(names)}")
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||||
|
|
||||||
|
|
||||||
# @app.get("/fast_api_depends")
|
# @app.get("/fast_api_depends")
|
||||||
# def fast_api_depends(
|
# def fast_api_depends(
|
||||||
# token: Annotated[str, Depends(fastapi_providers["Keycloak"])]
|
# token: Annotated[str, Depends(fastapi_providers["Keycloak"])]
|
||||||
|
|
|
@ -1,8 +1,16 @@
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
from pydantic import computed_field, AnyHttpUrl, EmailStr, ConfigDict
|
from pydantic import (
|
||||||
|
computed_field,
|
||||||
|
AnyHttpUrl,
|
||||||
|
EmailStr,
|
||||||
|
ConfigDict,
|
||||||
|
GetCoreSchemaHandler,
|
||||||
|
)
|
||||||
|
from pydantic_core import CoreSchema, core_schema
|
||||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||||
|
from authlib.oauth2.rfc6749 import OAuth2Token as OAuth2Token_authlib
|
||||||
from sqlmodel import SQLModel, Field
|
from sqlmodel import SQLModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,33 +53,3 @@ class User(UserBase):
|
||||||
@cached_property
|
@cached_property
|
||||||
def roles_as_set(self) -> set[str]:
|
def roles_as_set(self) -> set[str]:
|
||||||
return set([role.name for role in self.roles])
|
return set([role.name for role in self.roles])
|
||||||
|
|
||||||
|
|
||||||
class OAuth2Token(SQLModel):
|
|
||||||
name: str = Field(primary_key=True)
|
|
||||||
token_type: str # = Field(max_length=40)
|
|
||||||
access_token: str # = Field(max_length=2000)
|
|
||||||
raw_id_token: str
|
|
||||||
refresh_token: str # = Field(max_length=200)
|
|
||||||
expires_at: int # = PositiveIntegerField()
|
|
||||||
user: User # = ForeignKey(User)
|
|
||||||
|
|
||||||
def to_token(self):
|
|
||||||
return dict(
|
|
||||||
access_token=self.access_token,
|
|
||||||
token_type=self.token_type,
|
|
||||||
refresh_token=self.refresh_token,
|
|
||||||
expires_at=self.expires_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, token_dict: dict, user: User) -> Self:
|
|
||||||
return cls(
|
|
||||||
name=token_dict["access_token"],
|
|
||||||
access_token=token_dict["access_token"],
|
|
||||||
raw_id_token=token_dict["id_token"],
|
|
||||||
token_type=token_dict["token_type"],
|
|
||||||
refresh_token=token_dict["refresh_token"],
|
|
||||||
expires_at=token_dict["expires_at"],
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ class OIDCProvider(BaseModel):
|
||||||
client_id: str
|
client_id: str
|
||||||
client_secret: str = ""
|
client_secret: str = ""
|
||||||
# For PKCE (not implemented yet)
|
# For PKCE (not implemented yet)
|
||||||
# code_challenge_method: str | None = None
|
code_challenge_method: str | None = None
|
||||||
hint: str = "No hint"
|
hint: str = "No hint"
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
|
|
|
@ -120,7 +120,7 @@
|
||||||
if (xmlHttp.readyState == 4) {
|
if (xmlHttp.readyState == 4) {
|
||||||
elem.classList.add("hasResponseStatus")
|
elem.classList.add("hasResponseStatus")
|
||||||
elem.classList.add("status-" + xmlHttp.status)
|
elem.classList.add("status-" + xmlHttp.status)
|
||||||
elem.title = "Response code: " + xmlHttp.status
|
elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
xmlHttp.open("GET", elem.href, true) // true for asynchronous
|
xmlHttp.open("GET", elem.href, true) // true for asynchronous
|
||||||
|
|
|
@ -109,6 +109,8 @@
|
||||||
<a href="protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a>
|
<a href="protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a>
|
||||||
<a href="fast_api_depends" class="hidden">Using FastAPI Depends</a>
|
<a href="fast_api_depends" class="hidden">Using FastAPI Depends</a>
|
||||||
<a href="other">Other</a>
|
<a href="other">Other</a>
|
||||||
|
<a href="oauth2-forgejo-test">OAuth2 test (forgejo user info)</a>
|
||||||
|
<a href="introspect">Introspect token</a>
|
||||||
</div>
|
</div>
|
||||||
{% if user_info_details %}
|
{% if user_info_details %}
|
||||||
<div class="debug-auth">
|
<div class="debug-auth">
|
||||||
|
|
11
uv.lock
generated
11
uv.lock
generated
|
@ -490,6 +490,7 @@ dependencies = [
|
||||||
{ name = "fastapi", extra = ["standard"] },
|
{ name = "fastapi", extra = ["standard"] },
|
||||||
{ name = "itsdangerous" },
|
{ name = "itsdangerous" },
|
||||||
{ name = "passlib", extra = ["bcrypt"] },
|
{ name = "passlib", extra = ["bcrypt"] },
|
||||||
|
{ name = "pkce" },
|
||||||
{ name = "pydantic-settings" },
|
{ name = "pydantic-settings" },
|
||||||
{ name = "python-jose", extra = ["cryptography"] },
|
{ name = "python-jose", extra = ["cryptography"] },
|
||||||
{ name = "requests" },
|
{ name = "requests" },
|
||||||
|
@ -509,6 +510,7 @@ requires-dist = [
|
||||||
{ name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" },
|
{ name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" },
|
||||||
{ name = "itsdangerous", specifier = ">=2.2.0" },
|
{ name = "itsdangerous", specifier = ">=2.2.0" },
|
||||||
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
|
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
|
||||||
|
{ name = "pkce", specifier = ">=1.0.3" },
|
||||||
{ name = "pydantic-settings", specifier = ">=2.7.1" },
|
{ name = "pydantic-settings", specifier = ">=2.7.1" },
|
||||||
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
|
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
|
||||||
{ name = "requests", specifier = ">=2.32.3" },
|
{ name = "requests", specifier = ">=2.32.3" },
|
||||||
|
@ -565,6 +567,15 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 },
|
{ url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pkce"
|
||||||
|
version = "1.0.3"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/29/ea/ddd845c2ec21bf1e8555c782b32dc39b82f0b12764feb9f73ccbb2470f13/pkce-1.0.3.tar.gz", hash = "sha256:9775fd76d8a743d39b87df38af1cd04a58c9b5a5242d5a6350ef343d06814ab6", size = 2757 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/15/51/52c22ec0812d25f5bf297a01153604bfa7bfa59ed66f6cd8345beb3c2b2a/pkce-1.0.3-py3-none-any.whl", hash = "sha256:55927e24c7d403b2491ebe182b95d9dcb1807643243d47e3879fbda5aad4471d", size = 3200 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pluggy"
|
name = "pluggy"
|
||||||
version = "1.5.0"
|
version = "1.5.0"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue