Continue refactor
All checks were successful
/ build (push) Successful in 6s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-02-10 14:14:32 +01:00
parent 496ce016e3
commit e56be3c378
10 changed files with 38 additions and 34 deletions

View file

@ -2,14 +2,13 @@ from json import JSONDecodeError
from typing import Any from typing import Any
from jwt import decode from jwt import decode
import logging import logging
from collections import OrderedDict
from pydantic import ConfigDict from pydantic import ConfigDict
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from httpx import AsyncClient from httpx import AsyncClient
from .settings import AuthProviderSettings, settings from ..settings import AuthProviderSettings, settings
from .models import User from ..models import User
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -90,6 +89,3 @@ class Provider(AuthProviderSettings):
def get_session_key(self, userinfo): def get_session_key(self, userinfo):
return userinfo[self.session_key] return userinfo[self.session_key]
providers: OrderedDict[str, Provider] = OrderedDict()

View file

@ -10,15 +10,15 @@ from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
# from authlib.oauth1.auth import OAuthToken # from authlib.oauth1.auth import OAuthToken
from authlib.oauth2.auth import OAuth2Token from authlib.oauth2.auth import OAuth2Token
from .models import User from .provider import Provider
from .database import db, TokenNotInDb, UserNotInDB
from .settings import settings from ..models import User
from .auth_provider import providers, Provider from ..database import db, TokenNotInDb, UserNotInDB
from ..settings import settings
from ..auth_providers import providers
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
async def fetch_token(name, request): async def fetch_token(name, request):
assert name is not None assert name is not None
@ -51,9 +51,6 @@ async def update_token(
# await item.save() # await item.save()
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
def init_providers(): def init_providers():
"""Add oidc providers to authlib from the settings """Add oidc providers to authlib from the settings
and build the providers dict""" and build the providers dict"""
@ -86,7 +83,8 @@ def init_providers():
providers[provider.id] = provider providers[provider.id] = provider
init_providers() authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
@ -245,7 +243,7 @@ async def get_user_from_token(
except ExpiredSignatureError: except ExpiredSignatureError:
raise HTTPException( raise HTTPException(
status.HTTP_401_UNAUTHORIZED, status.HTTP_401_UNAUTHORIZED,
"Expired signature (refresh not implemented yet)", "Expired signature (token refresh not implemented yet)",
) )
except InvalidKeyError: except InvalidKeyError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
@ -263,12 +261,12 @@ async def get_user_from_token(
user.access_token = token user.access_token = token
except UserNotInDB: except UserNotInDB:
logger.info( logger.info(
f"User {user_id} not found in DB, creating it (real apps can behave differently" f"User {user_id} not found in DB, creating it (real apps can behave differently)"
) )
user = await db.add_user( user = await db.add_user(
sub=payload["sub"], sub=payload["sub"],
user_info=payload, user_info=payload,
auth_provider=getattr(authlib_oauth, auth_provider_id), auth_provider=providers[auth_provider_id],
access_token=token, access_token=token,
) )
return user return user

View file

@ -0,0 +1,5 @@
from collections import OrderedDict
from .auth.provider import Provider
providers: OrderedDict[str, Provider] = OrderedDict()

View file

@ -5,8 +5,10 @@ import logging
from authlib.oauth2.rfc6749 import OAuth2Token from authlib.oauth2.rfc6749 import OAuth2Token
from jwt import PyJWTError from jwt import PyJWTError
from .auth.provider import Provider
from .models import User, Role from .models import User, Role
from .auth_provider import Provider, providers from .auth_providers import providers
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")

View file

@ -26,10 +26,8 @@ from authlib.oauth2.rfc6749 import OAuth2Token
# from fastapi.security import OpenIdConnect # from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair # from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings from .auth.provider import NoPublicKey, Provider
from .auth_provider import NoPublicKey, Provider, providers from .auth.utils import (
from .models import User
from .auth_utils import (
get_auth_provider, get_auth_provider,
get_auth_provider_or_none, get_auth_provider_or_none,
get_current_user_or_none, get_current_user_or_none,
@ -38,6 +36,11 @@ from .auth_utils import (
get_token, get_token,
update_token, update_token,
) )
from .auth.utils import init_providers
from .settings import settings
from .auth_providers import providers
from .models import User
from .database import TokenNotInDb, db from .database import TokenNotInDb, db
from .resource_server import resource_server from .resource_server import resource_server
@ -49,6 +52,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
assert app is not None assert app is not None
init_providers()
for provider in list(providers.values()): for provider in list(providers.values()):
try: try:
await provider.get_info() await provider.get_info()

View file

@ -56,7 +56,7 @@ class User(UserBase):
def decode_access_token(self, verify_signature: bool = True): def decode_access_token(self, verify_signature: bool = True):
assert self.access_token is not None, "no access_token" assert self.access_token is not None, "no access_token"
assert self.auth_provider_id is not None, "no auth_provider_id" assert self.auth_provider_id is not None, "no auth_provider_id"
from .auth_provider import providers from .auth_providers import providers
return providers[self.auth_provider_id].decode( return providers[self.auth_provider_id].decode(
self.access_token, verify_signature=verify_signature self.access_token, verify_signature=verify_signature

View file

@ -12,14 +12,16 @@ from fastapi.middleware.cors import CORSMiddleware
# from authlib.integrations.starlette_client.apps import StarletteOAuth2App # from authlib.integrations.starlette_client.apps import StarletteOAuth2App
# from authlib.oauth2.rfc6749 import OAuth2Token # from authlib.oauth2.rfc6749 import OAuth2Token
from .models import User from .auth.provider import Provider
from .auth_utils import ( from .auth.utils import (
get_token_or_none, get_token_or_none,
get_user_from_token, get_user_from_token,
UserWithRole, UserWithRole,
) )
from .auth_providers import providers
from .settings import settings from .settings import settings
from .auth_provider import providers, Provider from .models import User
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")

View file

@ -13,8 +13,6 @@ from pydantic_settings import (
) )
from starlette.requests import Request from starlette.requests import Request
from .models import User
class Resource(BaseModel): class Resource(BaseModel):
"""A resource with an URL that can be accessed with an OAuth2 access token""" """A resource with an URL that can be accessed with an OAuth2 access token"""
@ -57,7 +55,7 @@ class AuthProviderSettings(BaseModel):
def token_url(self) -> str: def token_url(self) -> str:
return "auth/" + self.id return "auth/" + self.id
def get_account_url(self, request: Request, user: User) -> str | None: def get_account_url(self, request: Request, user: dict) -> str | None:
if self.account_url_template: if self.account_url_template:
if not (self.url.endswith("/") or self.account_url_template.startswith("/")): if not (self.url.endswith("/") or self.account_url_template.startswith("/")):
sep = "/" sep = "/"

View file

@ -184,11 +184,10 @@ hr {
font-family: monospace; font-family: monospace;
} }
.resource { .resourceResult {
padding: 0.5em; padding: 0.5em;
display: flex; display: flex;
gap: 0.5em; gap: 0.5em;
flex-direction: column;
width: fit-content; width: fit-content;
align-items: center; align-items: center;
margin: 5px auto; margin: 5px auto;

View file

@ -51,7 +51,7 @@
{% endif %} {% endif %}
{% if auth_provider.account_url_template %} {% if auth_provider.account_url_template %}
<button <button
onclick="location.href='{{ auth_provider.get_account_url(request, user) }}'" onclick="location.href='{{ auth_provider.get_account_url(request, user.model_dump()) }}'"
class="account"> class="account">
Account management Account management
</button> </button>