Continue refactor
All checks were successful
/ build (push) Successful in 6s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-02-10 14:14:32 +01:00
parent 496ce016e3
commit e56be3c378
10 changed files with 38 additions and 34 deletions

View file

@ -2,14 +2,13 @@ from json import JSONDecodeError
from typing import Any
from jwt import decode
import logging
from collections import OrderedDict
from pydantic import ConfigDict
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from httpx import AsyncClient
from .settings import AuthProviderSettings, settings
from .models import User
from ..settings import AuthProviderSettings, settings
from ..models import User
logger = logging.getLogger("oidc-test")
@ -90,6 +89,3 @@ class Provider(AuthProviderSettings):
def get_session_key(self, userinfo):
return userinfo[self.session_key]
providers: OrderedDict[str, Provider] = OrderedDict()

View file

@ -10,15 +10,15 @@ from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
# from authlib.oauth1.auth import OAuthToken
from authlib.oauth2.auth import OAuth2Token
from .models import User
from .database import db, TokenNotInDb, UserNotInDB
from .settings import settings
from .auth_provider import providers, Provider
from .provider import Provider
from ..models import User
from ..database import db, TokenNotInDb, UserNotInDB
from ..settings import settings
from ..auth_providers import providers
logger = logging.getLogger("oidc-test")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
async def fetch_token(name, request):
assert name is not None
@ -51,9 +51,6 @@ async def update_token(
# await item.save()
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
def init_providers():
"""Add oidc providers to authlib from the settings
and build the providers dict"""
@ -86,7 +83,8 @@ def init_providers():
providers[provider.id] = provider
init_providers()
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
@ -245,7 +243,7 @@ async def get_user_from_token(
except ExpiredSignatureError:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
"Expired signature (refresh not implemented yet)",
"Expired signature (token refresh not implemented yet)",
)
except InvalidKeyError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
@ -263,12 +261,12 @@ async def get_user_from_token(
user.access_token = token
except UserNotInDB:
logger.info(
f"User {user_id} not found in DB, creating it (real apps can behave differently"
f"User {user_id} not found in DB, creating it (real apps can behave differently)"
)
user = await db.add_user(
sub=payload["sub"],
user_info=payload,
auth_provider=getattr(authlib_oauth, auth_provider_id),
auth_provider=providers[auth_provider_id],
access_token=token,
)
return user

View file

@ -0,0 +1,5 @@
from collections import OrderedDict
from .auth.provider import Provider
providers: OrderedDict[str, Provider] = OrderedDict()

View file

@ -5,8 +5,10 @@ import logging
from authlib.oauth2.rfc6749 import OAuth2Token
from jwt import PyJWTError
from .auth.provider import Provider
from .models import User, Role
from .auth_provider import Provider, providers
from .auth_providers import providers
logger = logging.getLogger("oidc-test")

View file

@ -26,10 +26,8 @@ from authlib.oauth2.rfc6749 import OAuth2Token
# from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings
from .auth_provider import NoPublicKey, Provider, providers
from .models import User
from .auth_utils import (
from .auth.provider import NoPublicKey, Provider
from .auth.utils import (
get_auth_provider,
get_auth_provider_or_none,
get_current_user_or_none,
@ -38,6 +36,11 @@ from .auth_utils import (
get_token,
update_token,
)
from .auth.utils import init_providers
from .settings import settings
from .auth_providers import providers
from .models import User
from .database import TokenNotInDb, db
from .resource_server import resource_server
@ -49,6 +52,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates")
@asynccontextmanager
async def lifespan(app: FastAPI):
assert app is not None
init_providers()
for provider in list(providers.values()):
try:
await provider.get_info()

View file

@ -56,7 +56,7 @@ class User(UserBase):
def decode_access_token(self, verify_signature: bool = True):
assert self.access_token is not None, "no access_token"
assert self.auth_provider_id is not None, "no auth_provider_id"
from .auth_provider import providers
from .auth_providers import providers
return providers[self.auth_provider_id].decode(
self.access_token, verify_signature=verify_signature

View file

@ -12,14 +12,16 @@ from fastapi.middleware.cors import CORSMiddleware
# from authlib.integrations.starlette_client.apps import StarletteOAuth2App
# from authlib.oauth2.rfc6749 import OAuth2Token
from .models import User
from .auth_utils import (
from .auth.provider import Provider
from .auth.utils import (
get_token_or_none,
get_user_from_token,
UserWithRole,
)
from .auth_providers import providers
from .settings import settings
from .auth_provider import providers, Provider
from .models import User
logger = logging.getLogger("oidc-test")

View file

@ -13,8 +13,6 @@ from pydantic_settings import (
)
from starlette.requests import Request
from .models import User
class Resource(BaseModel):
"""A resource with an URL that can be accessed with an OAuth2 access token"""
@ -57,7 +55,7 @@ class AuthProviderSettings(BaseModel):
def token_url(self) -> str:
return "auth/" + self.id
def get_account_url(self, request: Request, user: User) -> str | None:
def get_account_url(self, request: Request, user: dict) -> str | None:
if self.account_url_template:
if not (self.url.endswith("/") or self.account_url_template.startswith("/")):
sep = "/"

View file

@ -184,11 +184,10 @@ hr {
font-family: monospace;
}
.resource {
.resourceResult {
padding: 0.5em;
display: flex;
gap: 0.5em;
flex-direction: column;
width: fit-content;
align-items: center;
margin: 5px auto;

View file

@ -51,7 +51,7 @@
{% endif %}
{% if auth_provider.account_url_template %}
<button
onclick="location.href='{{ auth_provider.get_account_url(request, user) }}'"
onclick="location.href='{{ auth_provider.get_account_url(request, user.model_dump()) }}'"
class="account">
Account management
</button>