Continue refactor
This commit is contained in:
parent
496ce016e3
commit
e56be3c378
10 changed files with 38 additions and 34 deletions
|
@ -2,14 +2,13 @@ from json import JSONDecodeError
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from jwt import decode
|
from jwt import decode
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
from .settings import AuthProviderSettings, settings
|
from ..settings import AuthProviderSettings, settings
|
||||||
from .models import User
|
from ..models import User
|
||||||
|
|
||||||
logger = logging.getLogger("oidc-test")
|
logger = logging.getLogger("oidc-test")
|
||||||
|
|
||||||
|
@ -90,6 +89,3 @@ class Provider(AuthProviderSettings):
|
||||||
|
|
||||||
def get_session_key(self, userinfo):
|
def get_session_key(self, userinfo):
|
||||||
return userinfo[self.session_key]
|
return userinfo[self.session_key]
|
||||||
|
|
||||||
|
|
||||||
providers: OrderedDict[str, Provider] = OrderedDict()
|
|
|
@ -10,15 +10,15 @@ from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
|
||||||
# from authlib.oauth1.auth import OAuthToken
|
# from authlib.oauth1.auth import OAuthToken
|
||||||
from authlib.oauth2.auth import OAuth2Token
|
from authlib.oauth2.auth import OAuth2Token
|
||||||
|
|
||||||
from .models import User
|
from .provider import Provider
|
||||||
from .database import db, TokenNotInDb, UserNotInDB
|
|
||||||
from .settings import settings
|
from ..models import User
|
||||||
from .auth_provider import providers, Provider
|
from ..database import db, TokenNotInDb, UserNotInDB
|
||||||
|
from ..settings import settings
|
||||||
|
from ..auth_providers import providers
|
||||||
|
|
||||||
logger = logging.getLogger("oidc-test")
|
logger = logging.getLogger("oidc-test")
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_token(name, request):
|
async def fetch_token(name, request):
|
||||||
assert name is not None
|
assert name is not None
|
||||||
|
@ -51,9 +51,6 @@ async def update_token(
|
||||||
# await item.save()
|
# await item.save()
|
||||||
|
|
||||||
|
|
||||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
|
||||||
|
|
||||||
|
|
||||||
def init_providers():
|
def init_providers():
|
||||||
"""Add oidc providers to authlib from the settings
|
"""Add oidc providers to authlib from the settings
|
||||||
and build the providers dict"""
|
and build the providers dict"""
|
||||||
|
@ -86,7 +83,8 @@ def init_providers():
|
||||||
providers[provider.id] = provider
|
providers[provider.id] = provider
|
||||||
|
|
||||||
|
|
||||||
init_providers()
|
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
|
||||||
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
|
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
|
||||||
|
@ -245,7 +243,7 @@ async def get_user_from_token(
|
||||||
except ExpiredSignatureError:
|
except ExpiredSignatureError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_401_UNAUTHORIZED,
|
status.HTTP_401_UNAUTHORIZED,
|
||||||
"Expired signature (refresh not implemented yet)",
|
"Expired signature (token refresh not implemented yet)",
|
||||||
)
|
)
|
||||||
except InvalidKeyError:
|
except InvalidKeyError:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
|
||||||
|
@ -263,12 +261,12 @@ async def get_user_from_token(
|
||||||
user.access_token = token
|
user.access_token = token
|
||||||
except UserNotInDB:
|
except UserNotInDB:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {user_id} not found in DB, creating it (real apps can behave differently"
|
f"User {user_id} not found in DB, creating it (real apps can behave differently)"
|
||||||
)
|
)
|
||||||
user = await db.add_user(
|
user = await db.add_user(
|
||||||
sub=payload["sub"],
|
sub=payload["sub"],
|
||||||
user_info=payload,
|
user_info=payload,
|
||||||
auth_provider=getattr(authlib_oauth, auth_provider_id),
|
auth_provider=providers[auth_provider_id],
|
||||||
access_token=token,
|
access_token=token,
|
||||||
)
|
)
|
||||||
return user
|
return user
|
5
src/oidc_test/auth_providers.py
Normal file
5
src/oidc_test/auth_providers.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from .auth.provider import Provider
|
||||||
|
|
||||||
|
providers: OrderedDict[str, Provider] = OrderedDict()
|
|
@ -5,8 +5,10 @@ import logging
|
||||||
from authlib.oauth2.rfc6749 import OAuth2Token
|
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||||
from jwt import PyJWTError
|
from jwt import PyJWTError
|
||||||
|
|
||||||
|
from .auth.provider import Provider
|
||||||
|
|
||||||
from .models import User, Role
|
from .models import User, Role
|
||||||
from .auth_provider import Provider, providers
|
from .auth_providers import providers
|
||||||
|
|
||||||
logger = logging.getLogger("oidc-test")
|
logger = logging.getLogger("oidc-test")
|
||||||
|
|
||||||
|
|
|
@ -26,10 +26,8 @@ from authlib.oauth2.rfc6749 import OAuth2Token
|
||||||
# from fastapi.security import OpenIdConnect
|
# from fastapi.security import OpenIdConnect
|
||||||
# from pkce import generate_code_verifier, generate_pkce_pair
|
# from pkce import generate_code_verifier, generate_pkce_pair
|
||||||
|
|
||||||
from .settings import settings
|
from .auth.provider import NoPublicKey, Provider
|
||||||
from .auth_provider import NoPublicKey, Provider, providers
|
from .auth.utils import (
|
||||||
from .models import User
|
|
||||||
from .auth_utils import (
|
|
||||||
get_auth_provider,
|
get_auth_provider,
|
||||||
get_auth_provider_or_none,
|
get_auth_provider_or_none,
|
||||||
get_current_user_or_none,
|
get_current_user_or_none,
|
||||||
|
@ -38,6 +36,11 @@ from .auth_utils import (
|
||||||
get_token,
|
get_token,
|
||||||
update_token,
|
update_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .auth.utils import init_providers
|
||||||
|
from .settings import settings
|
||||||
|
from .auth_providers import providers
|
||||||
|
from .models import User
|
||||||
from .database import TokenNotInDb, db
|
from .database import TokenNotInDb, db
|
||||||
from .resource_server import resource_server
|
from .resource_server import resource_server
|
||||||
|
|
||||||
|
@ -49,6 +52,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates")
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
assert app is not None
|
assert app is not None
|
||||||
|
init_providers()
|
||||||
for provider in list(providers.values()):
|
for provider in list(providers.values()):
|
||||||
try:
|
try:
|
||||||
await provider.get_info()
|
await provider.get_info()
|
||||||
|
|
|
@ -56,7 +56,7 @@ class User(UserBase):
|
||||||
def decode_access_token(self, verify_signature: bool = True):
|
def decode_access_token(self, verify_signature: bool = True):
|
||||||
assert self.access_token is not None, "no access_token"
|
assert self.access_token is not None, "no access_token"
|
||||||
assert self.auth_provider_id is not None, "no auth_provider_id"
|
assert self.auth_provider_id is not None, "no auth_provider_id"
|
||||||
from .auth_provider import providers
|
from .auth_providers import providers
|
||||||
|
|
||||||
return providers[self.auth_provider_id].decode(
|
return providers[self.auth_provider_id].decode(
|
||||||
self.access_token, verify_signature=verify_signature
|
self.access_token, verify_signature=verify_signature
|
||||||
|
|
|
@ -12,14 +12,16 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||||
# from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
# from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||||
# from authlib.oauth2.rfc6749 import OAuth2Token
|
# from authlib.oauth2.rfc6749 import OAuth2Token
|
||||||
|
|
||||||
from .models import User
|
from .auth.provider import Provider
|
||||||
from .auth_utils import (
|
from .auth.utils import (
|
||||||
get_token_or_none,
|
get_token_or_none,
|
||||||
get_user_from_token,
|
get_user_from_token,
|
||||||
UserWithRole,
|
UserWithRole,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .auth_providers import providers
|
||||||
from .settings import settings
|
from .settings import settings
|
||||||
from .auth_provider import providers, Provider
|
from .models import User
|
||||||
|
|
||||||
logger = logging.getLogger("oidc-test")
|
logger = logging.getLogger("oidc-test")
|
||||||
|
|
||||||
|
|
|
@ -13,8 +13,6 @@ from pydantic_settings import (
|
||||||
)
|
)
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .models import User
|
|
||||||
|
|
||||||
|
|
||||||
class Resource(BaseModel):
|
class Resource(BaseModel):
|
||||||
"""A resource with an URL that can be accessed with an OAuth2 access token"""
|
"""A resource with an URL that can be accessed with an OAuth2 access token"""
|
||||||
|
@ -57,7 +55,7 @@ class AuthProviderSettings(BaseModel):
|
||||||
def token_url(self) -> str:
|
def token_url(self) -> str:
|
||||||
return "auth/" + self.id
|
return "auth/" + self.id
|
||||||
|
|
||||||
def get_account_url(self, request: Request, user: User) -> str | None:
|
def get_account_url(self, request: Request, user: dict) -> str | None:
|
||||||
if self.account_url_template:
|
if self.account_url_template:
|
||||||
if not (self.url.endswith("/") or self.account_url_template.startswith("/")):
|
if not (self.url.endswith("/") or self.account_url_template.startswith("/")):
|
||||||
sep = "/"
|
sep = "/"
|
||||||
|
|
|
@ -184,11 +184,10 @@ hr {
|
||||||
font-family: monospace;
|
font-family: monospace;
|
||||||
}
|
}
|
||||||
|
|
||||||
.resource {
|
.resourceResult {
|
||||||
padding: 0.5em;
|
padding: 0.5em;
|
||||||
display: flex;
|
display: flex;
|
||||||
gap: 0.5em;
|
gap: 0.5em;
|
||||||
flex-direction: column;
|
|
||||||
width: fit-content;
|
width: fit-content;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
margin: 5px auto;
|
margin: 5px auto;
|
||||||
|
|
|
@ -51,7 +51,7 @@
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% if auth_provider.account_url_template %}
|
{% if auth_provider.account_url_template %}
|
||||||
<button
|
<button
|
||||||
onclick="location.href='{{ auth_provider.get_account_url(request, user) }}'"
|
onclick="location.href='{{ auth_provider.get_account_url(request, user.model_dump()) }}'"
|
||||||
class="account">
|
class="account">
|
||||||
Account management
|
Account management
|
||||||
</button>
|
</button>
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue