diff --git a/src/oidc_test/auth_provider.py b/src/oidc_test/auth/provider.py similarity index 94% rename from src/oidc_test/auth_provider.py rename to src/oidc_test/auth/provider.py index e50241c..dab4764 100644 --- a/src/oidc_test/auth_provider.py +++ b/src/oidc_test/auth/provider.py @@ -2,14 +2,13 @@ from json import JSONDecodeError from typing import Any from jwt import decode import logging -from collections import OrderedDict from pydantic import ConfigDict from authlib.integrations.starlette_client.apps import StarletteOAuth2App from httpx import AsyncClient -from .settings import AuthProviderSettings, settings -from .models import User +from ..settings import AuthProviderSettings, settings +from ..models import User logger = logging.getLogger("oidc-test") @@ -90,6 +89,3 @@ class Provider(AuthProviderSettings): def get_session_key(self, userinfo): return userinfo[self.session_key] - - -providers: OrderedDict[str, Provider] = OrderedDict() diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth/utils.py similarity index 96% rename from src/oidc_test/auth_utils.py rename to src/oidc_test/auth/utils.py index 8cd5028..0623186 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth/utils.py @@ -10,15 +10,15 @@ from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError # from authlib.oauth1.auth import OAuthToken from authlib.oauth2.auth import OAuth2Token -from .models import User -from .database import db, TokenNotInDb, UserNotInDB -from .settings import settings -from .auth_provider import providers, Provider +from .provider import Provider + +from ..models import User +from ..database import db, TokenNotInDb, UserNotInDB +from ..settings import settings +from ..auth_providers import providers logger = logging.getLogger("oidc-test") -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") - async def fetch_token(name, request): assert name is not None @@ -51,9 +51,6 @@ async def update_token( # await item.save() -authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) - - def init_providers(): """Add oidc providers to authlib from the settings and build the providers dict""" @@ -86,7 +83,8 @@ def init_providers(): providers[provider.id] = provider -init_providers() +authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: @@ -245,7 +243,7 @@ async def get_user_from_token( except ExpiredSignatureError: raise HTTPException( status.HTTP_401_UNAUTHORIZED, - "Expired signature (refresh not implemented yet)", + "Expired signature (token refresh not implemented yet)", ) except InvalidKeyError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") @@ -263,12 +261,12 @@ async def get_user_from_token( user.access_token = token except UserNotInDB: logger.info( - f"User {user_id} not found in DB, creating it (real apps can behave differently" + f"User {user_id} not found in DB, creating it (real apps can behave differently)" ) user = await db.add_user( sub=payload["sub"], user_info=payload, - auth_provider=getattr(authlib_oauth, auth_provider_id), + auth_provider=providers[auth_provider_id], access_token=token, ) return user diff --git a/src/oidc_test/auth_providers.py b/src/oidc_test/auth_providers.py new file mode 100644 index 0000000..45f4de6 --- /dev/null +++ b/src/oidc_test/auth_providers.py @@ -0,0 +1,5 @@ +from collections import OrderedDict + +from .auth.provider import Provider + +providers: OrderedDict[str, Provider] = OrderedDict() diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 659fd13..4704f9b 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -5,8 +5,10 @@ import logging from authlib.oauth2.rfc6749 import OAuth2Token from jwt import PyJWTError +from .auth.provider import Provider + from .models import User, Role -from .auth_provider import Provider, providers +from .auth_providers import providers logger = logging.getLogger("oidc-test") diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 4018997..f37339d 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -26,10 +26,8 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .settings import settings -from .auth_provider import NoPublicKey, Provider, providers -from .models import User -from .auth_utils import ( +from .auth.provider import NoPublicKey, Provider +from .auth.utils import ( get_auth_provider, get_auth_provider_or_none, get_current_user_or_none, @@ -38,6 +36,11 @@ from .auth_utils import ( get_token, update_token, ) + +from .auth.utils import init_providers +from .settings import settings +from .auth_providers import providers +from .models import User from .database import TokenNotInDb, db from .resource_server import resource_server @@ -49,6 +52,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates") @asynccontextmanager async def lifespan(app: FastAPI): assert app is not None + init_providers() for provider in list(providers.values()): try: await provider.get_info() diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 7c5250b..7b6fd0e 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -56,7 +56,7 @@ class User(UserBase): def decode_access_token(self, verify_signature: bool = True): assert self.access_token is not None, "no access_token" assert self.auth_provider_id is not None, "no auth_provider_id" - from .auth_provider import providers + from .auth_providers import providers return providers[self.auth_provider_id].decode( self.access_token, verify_signature=verify_signature diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index f7f0433..15084bc 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -12,14 +12,16 @@ from fastapi.middleware.cors import CORSMiddleware # from authlib.integrations.starlette_client.apps import StarletteOAuth2App # from authlib.oauth2.rfc6749 import OAuth2Token -from .models import User -from .auth_utils import ( +from .auth.provider import Provider +from .auth.utils import ( get_token_or_none, get_user_from_token, UserWithRole, ) + +from .auth_providers import providers from .settings import settings -from .auth_provider import providers, Provider +from .models import User logger = logging.getLogger("oidc-test") diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 86a2b6b..f3ac8f3 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -13,8 +13,6 @@ from pydantic_settings import ( ) from starlette.requests import Request -from .models import User - class Resource(BaseModel): """A resource with an URL that can be accessed with an OAuth2 access token""" @@ -57,7 +55,7 @@ class AuthProviderSettings(BaseModel): def token_url(self) -> str: return "auth/" + self.id - def get_account_url(self, request: Request, user: User) -> str | None: + def get_account_url(self, request: Request, user: dict) -> str | None: if self.account_url_template: if not (self.url.endswith("/") or self.account_url_template.startswith("/")): sep = "/" diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 367ea99..e163a68 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -184,11 +184,10 @@ hr { font-family: monospace; } -.resource { +.resourceResult { padding: 0.5em; display: flex; gap: 0.5em; - flex-direction: column; width: fit-content; align-items: center; margin: 5px auto; diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index da513c9..93d0bc6 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -51,7 +51,7 @@ {% endif %} {% if auth_provider.account_url_template %}