Continue refactor
This commit is contained in:
parent
496ce016e3
commit
e56be3c378
10 changed files with 38 additions and 34 deletions
|
@ -2,14 +2,13 @@ from json import JSONDecodeError
|
|||
from typing import Any
|
||||
from jwt import decode
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from httpx import AsyncClient
|
||||
|
||||
from .settings import AuthProviderSettings, settings
|
||||
from .models import User
|
||||
from ..settings import AuthProviderSettings, settings
|
||||
from ..models import User
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
@ -90,6 +89,3 @@ class Provider(AuthProviderSettings):
|
|||
|
||||
def get_session_key(self, userinfo):
|
||||
return userinfo[self.session_key]
|
||||
|
||||
|
||||
providers: OrderedDict[str, Provider] = OrderedDict()
|
|
@ -10,15 +10,15 @@ from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
|
|||
# from authlib.oauth1.auth import OAuthToken
|
||||
from authlib.oauth2.auth import OAuth2Token
|
||||
|
||||
from .models import User
|
||||
from .database import db, TokenNotInDb, UserNotInDB
|
||||
from .settings import settings
|
||||
from .auth_provider import providers, Provider
|
||||
from .provider import Provider
|
||||
|
||||
from ..models import User
|
||||
from ..database import db, TokenNotInDb, UserNotInDB
|
||||
from ..settings import settings
|
||||
from ..auth_providers import providers
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
async def fetch_token(name, request):
|
||||
assert name is not None
|
||||
|
@ -51,9 +51,6 @@ async def update_token(
|
|||
# await item.save()
|
||||
|
||||
|
||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||
|
||||
|
||||
def init_providers():
|
||||
"""Add oidc providers to authlib from the settings
|
||||
and build the providers dict"""
|
||||
|
@ -86,7 +83,8 @@ def init_providers():
|
|||
providers[provider.id] = provider
|
||||
|
||||
|
||||
init_providers()
|
||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
|
||||
|
@ -245,7 +243,7 @@ async def get_user_from_token(
|
|||
except ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
"Expired signature (refresh not implemented yet)",
|
||||
"Expired signature (token refresh not implemented yet)",
|
||||
)
|
||||
except InvalidKeyError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
|
||||
|
@ -263,12 +261,12 @@ async def get_user_from_token(
|
|||
user.access_token = token
|
||||
except UserNotInDB:
|
||||
logger.info(
|
||||
f"User {user_id} not found in DB, creating it (real apps can behave differently"
|
||||
f"User {user_id} not found in DB, creating it (real apps can behave differently)"
|
||||
)
|
||||
user = await db.add_user(
|
||||
sub=payload["sub"],
|
||||
user_info=payload,
|
||||
auth_provider=getattr(authlib_oauth, auth_provider_id),
|
||||
auth_provider=providers[auth_provider_id],
|
||||
access_token=token,
|
||||
)
|
||||
return user
|
5
src/oidc_test/auth_providers.py
Normal file
5
src/oidc_test/auth_providers.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
from .auth.provider import Provider
|
||||
|
||||
providers: OrderedDict[str, Provider] = OrderedDict()
|
|
@ -5,8 +5,10 @@ import logging
|
|||
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||
from jwt import PyJWTError
|
||||
|
||||
from .auth.provider import Provider
|
||||
|
||||
from .models import User, Role
|
||||
from .auth_provider import Provider, providers
|
||||
from .auth_providers import providers
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
|
|
@ -26,10 +26,8 @@ from authlib.oauth2.rfc6749 import OAuth2Token
|
|||
# from fastapi.security import OpenIdConnect
|
||||
# from pkce import generate_code_verifier, generate_pkce_pair
|
||||
|
||||
from .settings import settings
|
||||
from .auth_provider import NoPublicKey, Provider, providers
|
||||
from .models import User
|
||||
from .auth_utils import (
|
||||
from .auth.provider import NoPublicKey, Provider
|
||||
from .auth.utils import (
|
||||
get_auth_provider,
|
||||
get_auth_provider_or_none,
|
||||
get_current_user_or_none,
|
||||
|
@ -38,6 +36,11 @@ from .auth_utils import (
|
|||
get_token,
|
||||
update_token,
|
||||
)
|
||||
|
||||
from .auth.utils import init_providers
|
||||
from .settings import settings
|
||||
from .auth_providers import providers
|
||||
from .models import User
|
||||
from .database import TokenNotInDb, db
|
||||
from .resource_server import resource_server
|
||||
|
||||
|
@ -49,6 +52,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates")
|
|||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
assert app is not None
|
||||
init_providers()
|
||||
for provider in list(providers.values()):
|
||||
try:
|
||||
await provider.get_info()
|
||||
|
|
|
@ -56,7 +56,7 @@ class User(UserBase):
|
|||
def decode_access_token(self, verify_signature: bool = True):
|
||||
assert self.access_token is not None, "no access_token"
|
||||
assert self.auth_provider_id is not None, "no auth_provider_id"
|
||||
from .auth_provider import providers
|
||||
from .auth_providers import providers
|
||||
|
||||
return providers[self.auth_provider_id].decode(
|
||||
self.access_token, verify_signature=verify_signature
|
||||
|
|
|
@ -12,14 +12,16 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
# from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
# from authlib.oauth2.rfc6749 import OAuth2Token
|
||||
|
||||
from .models import User
|
||||
from .auth_utils import (
|
||||
from .auth.provider import Provider
|
||||
from .auth.utils import (
|
||||
get_token_or_none,
|
||||
get_user_from_token,
|
||||
UserWithRole,
|
||||
)
|
||||
|
||||
from .auth_providers import providers
|
||||
from .settings import settings
|
||||
from .auth_provider import providers, Provider
|
||||
from .models import User
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
|
|
@ -13,8 +13,6 @@ from pydantic_settings import (
|
|||
)
|
||||
from starlette.requests import Request
|
||||
|
||||
from .models import User
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
"""A resource with an URL that can be accessed with an OAuth2 access token"""
|
||||
|
@ -57,7 +55,7 @@ class AuthProviderSettings(BaseModel):
|
|||
def token_url(self) -> str:
|
||||
return "auth/" + self.id
|
||||
|
||||
def get_account_url(self, request: Request, user: User) -> str | None:
|
||||
def get_account_url(self, request: Request, user: dict) -> str | None:
|
||||
if self.account_url_template:
|
||||
if not (self.url.endswith("/") or self.account_url_template.startswith("/")):
|
||||
sep = "/"
|
||||
|
|
|
@ -184,11 +184,10 @@ hr {
|
|||
font-family: monospace;
|
||||
}
|
||||
|
||||
.resource {
|
||||
.resourceResult {
|
||||
padding: 0.5em;
|
||||
display: flex;
|
||||
gap: 0.5em;
|
||||
flex-direction: column;
|
||||
width: fit-content;
|
||||
align-items: center;
|
||||
margin: 5px auto;
|
||||
|
|
|
@ -51,7 +51,7 @@
|
|||
{% endif %}
|
||||
{% if auth_provider.account_url_template %}
|
||||
<button
|
||||
onclick="location.href='{{ auth_provider.get_account_url(request, user) }}'"
|
||||
onclick="location.href='{{ auth_provider.get_account_url(request, user.model_dump()) }}'"
|
||||
class="account">
|
||||
Account management
|
||||
</button>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue