diff --git a/pyproject.toml b/pyproject.toml index 4509e5b..980bcfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,10 +9,12 @@ dependencies = [ "authlib>=1.4.0", "cachetools>=5.5.0", "fastapi[standard]>=0.115.6", + "httpx>=0.28.1", "itsdangerous>=2.2.0", "passlib[bcrypt]>=1.7.4", "pkce>=1.0.3", "pydantic-settings>=2.7.1", + "pyjwt>=2.10.1", "python-jose[cryptography]>=3.3.0", "requests>=2.32.3", "sqlmodel>=0.0.22", diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 0e5156b..2b3d0fd 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -1,22 +1,63 @@ -from typing import Union +from typing import Union, Annotated, Tuple from functools import wraps import logging -from fastapi import HTTPException, Request, status +from fastapi import HTTPException, Request, Depends, status +from fastapi.security import OAuth2PasswordBearer from authlib.oauth2.rfc6749 import OAuth2Token from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App +from jwt import decode # from authlib.oauth1.auth import OAuthToken # from authlib.oauth2.auth import OAuth2Token from .models import User -from .database import db -from .settings import settings +from .database import db, UserNotInDB +from .settings import settings, OIDCProvider logger = logging.getLogger(__name__) -OIDC_PROVIDERS = set([provider.id for provider in settings.oidc.providers]) +oidc_providers_settings: dict[str, OIDCProvider] = dict([(provider.id, provider) for provider in settings.oidc.providers]) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +def fetch_token(name, request): + breakpoint() + ... + # if name in oidc_providers: + # model = OAuth2Token + # else: + # model = OAuthToken + + # token = model.find(name=name, user=request.user) + # return token.to_token() + + +def update_token(*args, **kwargs): + breakpoint() + ... + + +authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) + + +# Add oidc providers to authlib from the settings +for id, provider in oidc_providers_settings.items(): + authlib_oauth.register( + name=id, + server_metadata_url=provider.openid_configuration, + client_kwargs={ + "scope": "openid email offline_access profile", + }, + client_id=provider.client_id, + client_secret=provider.client_secret, + api_base_url=provider.url, + # For PKCE (not implemented yet): + # code_challenge_method="S256", + # fetch_token=fetch_token, + # update_token=update_token, + # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) + ) def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: """Return the oidc_provider from a request object, from the session. @@ -115,21 +156,34 @@ def get_token_info(token: dict) -> dict: return token_info -def fetch_token(name, request): - breakpoint() - ... - # if name in OIDC_PROVIDERS: - # model = OAuth2Token - # else: - # model = OAuthToken - - # token = model.find(name=name, user=request.user) - # return token.to_token() - - -def update_token(*args, **kwargs): - breakpoint() - ... - - -authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) +async def get_resource_user( + token: Annotated[str, Depends(oauth2_scheme)], + request: Request, +) -> User: + # TODO: decode token (ah!) + # See https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/#hash-and-verify-the-passwords + if (auth_provider_id := request.headers.get("auth_provider")) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field") + if (auth_provider := oidc_providers_settings.get(auth_provider_id)) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'") + if (key := auth_provider.get_key()) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown") + try: + payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test") + except Exception as err: + logger.info("Cannot decode token, see below") + logger.exception(err) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token") + if (user_id := payload.get('sub')) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found") + try: + user = await db.get_user(user_id) + except UserNotInDB: + logger.info(f"User {user_id} not found in DB, creating it (real apps can behave differently") + user = await db.add_user( + sub=payload['sub'], + user_info=payload, + oidc_provider=getattr(authlib_oauth, auth_provider_id), + user_info_from_endpoint={} + ) + return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 4b3f529..9d72081 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -10,6 +10,9 @@ from authlib.oauth2.rfc6749 import OAuth2Token logger = logging.getLogger(__name__) +class UserNotInDB(Exception): + pass + class Database: users: dict[str, User] = {} tokens: dict[str, OAuth2Token] = {} @@ -37,6 +40,8 @@ class Database: return user async def get_user(self, sub: str) -> User: + if sub not in self.users: + raise UserNotInDB return self.users[sub] async def add_token(self, token: OAuth2Token, user: User) -> None: diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 36d9d76..351cc2f 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -13,6 +13,7 @@ from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.templating import Jinja2Templates +from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError @@ -23,7 +24,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .settings import settings, OIDCProvider +from .settings import settings from .models import User from .auth_utils import ( get_oidc_provider, @@ -31,21 +32,37 @@ from .auth_utils import ( hasrole, get_current_user_or_none, get_current_user, + get_resource_user, authlib_oauth, get_token, + oidc_providers_settings, ) from .auth_misc import pretty_details from .database import db +from .resource_server import get_resource logger = logging.getLogger("uvicorn.error") templates = Jinja2Templates(Path(__file__).parent / "templates") +origins = [ + "https://tiptop:3002", + "https://philo.ydns.eu/", +] + app = FastAPI( title="OIDC auth test", ) +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + app.mount( "/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static" ) @@ -56,32 +73,6 @@ app.add_middleware( secret_key=settings.secret_key, ) -# Add oidc providers to authlib from the settings - -# fastapi_providers: dict[str, OpenIdConnect] = {} -oidc_providers_settings: dict[str, OIDCProvider] = {} - -for provider in settings.oidc.providers: - authlib_oauth.register( - name=provider.id, - server_metadata_url=provider.openid_configuration, - client_kwargs={ - "scope": "openid email offline_access profile", - }, - client_id=provider.client_id, - client_secret=provider.client_secret, - api_base_url=provider.url, - # For PKCE (not implemented yet): - # code_challenge_method="S256", - # fetch_token=fetch_token, - # update_token=update_token, - # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) - ) - # fastapi_providers[provider.id] = OpenIdConnect( - # openIdConnectUrl=provider.openid_configuration - # ) - oidc_providers_settings[provider.id] = provider - @app.get("/") async def home( @@ -281,43 +272,16 @@ async def non_compliant_logout( @app.get("/resource/{id}") -async def get_resource( +async def get_resource_( id: str, request: Request, - user: Annotated[User, Depends(get_current_user)], - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], - token: Annotated[OAuth2Token, Depends(get_token)], + # user: Annotated[User, Depends(get_current_user)], + # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + # token: Annotated[OAuth2Token, Depends(get_token)], + user: Annotated[User, Depends(get_resource_user)], ) -> JSONResponse: """Generic path for testing a resource provided by a provider""" - assert user is not None # Just to keep QA checks happy - if oidc_provider is None: - raise HTTPException( - status.HTTP_406_NOT_ACCEPTABLE, detail="No such oidc provider" - ) - if ( - provider := oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - ) is None: - raise HTTPException( - status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting" - ) - try: - resource = next(x for x in provider.resources if x.id == id) - except StopIteration: - raise HTTPException( - status.HTTP_406_NOT_ACCEPTABLE, detail="No such resource for this provider" - ) - if ( - response := await oidc_provider.get( - resource.url, - # headers={"Authorization": f"token {token['access_token']}"}, - token=token, - ) - ).is_success: - return JSONResponse(response.json()) - else: - raise HTTPException(status_code=response.status_code, detail=response.text) + return JSONResponse(await get_resource(id, user)) # Routes for test diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py new file mode 100644 index 0000000..a9dfe3a --- /dev/null +++ b/src/oidc_test/resource_server.py @@ -0,0 +1,21 @@ +from datetime import datetime +from httpx import AsyncClient + +from .models import User + +async def get_resource(id: str, user: User) -> dict: + pname = getattr(user.oidc_provider, "name", "?") + resp = { + "hello" : f"Hi {user.name} from an OAuth resource provider.", + "comment": f"I received a request for '{id}' with an access token signed by {pname}." + } + if id == "time": + resp["time"] = datetime.now().strftime("%c") + elif id == "bs": + async with AsyncClient() as client: + bs = await client.get("https://corporatebs-generator.sameerkumar.website/") + resp['bs'] = bs.json().get("phrase", "Sorry, i am out of BS today.") + else: + resp['sorry'] = f"I don't known how to give '{id}' but i know corporate bs." + + return resp diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 81d5099..c511f86 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -36,6 +36,7 @@ class OIDCProvider(BaseModel): hint: str = "No hint" resources: list[Resource] = [] account_url_template: str | None = None + key: str | None = None @computed_field @property @@ -63,6 +64,15 @@ class OIDCProvider(BaseModel): else: return None + def get_key(self) -> str | None: + """Return the public key formatted for """ + if self.key is None: + return None + return f""" + -----BEGIN PUBLIC KEY----- + {self.key} + -----END PUBLIC KEY----- + """ class ResourceProvider(BaseModel): id: str diff --git a/uv.lock b/uv.lock index 6ceb4ca..01b64de 100644 --- a/uv.lock +++ b/uv.lock @@ -488,10 +488,12 @@ dependencies = [ { name = "authlib" }, { name = "cachetools" }, { name = "fastapi", extra = ["standard"] }, + { name = "httpx" }, { name = "itsdangerous" }, { name = "passlib", extra = ["bcrypt"] }, { name = "pkce" }, { name = "pydantic-settings" }, + { name = "pyjwt" }, { name = "python-jose", extra = ["cryptography"] }, { name = "requests" }, { name = "sqlmodel" }, @@ -508,10 +510,12 @@ requires-dist = [ { name = "authlib", specifier = ">=1.4.0" }, { name = "cachetools", specifier = ">=5.5.0" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" }, + { name = "httpx", specifier = ">=0.28.1" }, { name = "itsdangerous", specifier = ">=2.2.0" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "pkce", specifier = ">=1.0.3" }, { name = "pydantic-settings", specifier = ">=2.7.1" }, + { name = "pyjwt", specifier = ">=2.10.1" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "requests", specifier = ">=2.32.3" }, { name = "sqlmodel", specifier = ">=0.0.22" }, @@ -694,6 +698,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997 }, +] + [[package]] name = "pytest" version = "8.3.4"