Fix error handling in resource server
This commit is contained in:
parent
f7ea132b7c
commit
17bf34a8a1
3 changed files with 78 additions and 59 deletions
|
@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
|
|||
class UserNotInDB(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Database:
|
||||
users: dict[str, User] = {}
|
||||
tokens: dict[str, OAuth2Token] = {}
|
||||
|
|
|
@ -56,7 +56,6 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
app = FastAPI(title="OIDC auth test", lifespan=lifespan)
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
|
|
|
@ -2,74 +2,93 @@ from datetime import datetime
|
|||
import logging
|
||||
|
||||
from httpx import AsyncClient
|
||||
from fastapi import HTTPException, status
|
||||
from jwt import ExpiredSignatureError, InvalidKeyError, decode
|
||||
|
||||
from .models import User
|
||||
from .auth_utils import oidc_providers_settings
|
||||
from .settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_resource(id: str, user: User) -> dict:
|
||||
async def get_resource(resource_id: str, user: User) -> dict:
|
||||
"""
|
||||
Resource processing: build an informative rely as a simple showcase
|
||||
"""
|
||||
pname = getattr(user.oidc_provider, "name", "?")
|
||||
resp = {
|
||||
"hello": f"Hi {user.name} from an OAuth resource provider.",
|
||||
"comment": f"I received a request for '{id}' with an access token signed by {pname}.",
|
||||
"hello": f"Hi {user.name} from an OAuth resource provider",
|
||||
"comment": f"I received a request for '{resource_id}' "
|
||||
+ f"with an access token signed by {pname}",
|
||||
}
|
||||
scope = f"get:{id}"
|
||||
user_scopes = user.userinfo["scope"].split(" ")
|
||||
if scope in user_scopes:
|
||||
if id == "time":
|
||||
resp["time"] = datetime.now().strftime("%c")
|
||||
elif id == "bs":
|
||||
async with AsyncClient() as client:
|
||||
bs = await client.get(
|
||||
"https://corporatebs-generator.sameerkumar.website/"
|
||||
)
|
||||
resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
|
||||
# For the demo, resource resource_id matches a scope get:resource_id,
|
||||
# but this has to be refined for production
|
||||
required_scope = f"get:{resource_id}"
|
||||
# Check if the required scope is in the scopes allowed in userinfo
|
||||
if "required_scope" in user.userinfo:
|
||||
user_scopes = user.userinfo["required_scope"].split(" ")
|
||||
if required_scope in user_scopes:
|
||||
await process(user, required_scope, resp)
|
||||
else:
|
||||
resp["sorry"] = f"I don't known how to give '{id}' but i know corporate bs."
|
||||
## For the showcase, giving a explanation.
|
||||
## Alternatively, raise HTTP_401_UNAUTHORIZED
|
||||
resp["sorry"] = (
|
||||
f"No scope {required_scope} in the access token "
|
||||
+ "but it is required for accessing this resource."
|
||||
)
|
||||
else:
|
||||
resp["sorry"] = (
|
||||
f"I don't serve the ressource {id} to you because there is no scope {scope} in the access token,"
|
||||
)
|
||||
resp["sorry"] = "There is no scope in id token"
|
||||
return resp
|
||||
|
||||
# assert user.oidc_provider is not None
|
||||
### Get some info (TODO: refactor)
|
||||
# if (auth_provider_id := user.oidc_provider.name) is None:
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED,
|
||||
# "Request headers must have a 'auth_provider' field",
|
||||
# )
|
||||
# if (
|
||||
# auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
|
||||
# ) is None:
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
||||
# )
|
||||
# if (key := auth_provider_settings.get_public_key()) is None:
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED,
|
||||
# f"Key for provider '{auth_provider_id}' unknown",
|
||||
# )
|
||||
# logger.warn(f"refresh with scope {scope}")
|
||||
# breakpoint()
|
||||
# refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope)
|
||||
### Decode the new token
|
||||
# try:
|
||||
# payload = decode(
|
||||
# refreshed_auth_info["access_token"],
|
||||
# key=key,
|
||||
# algorithms=["RS256"],
|
||||
# audience="account",
|
||||
# options={"verify_signature": not settings.insecure.skip_verify_signature},
|
||||
# )
|
||||
# except ExpiredSignatureError as err:
|
||||
# logger.info(f"Expired signature: {err}")
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED,
|
||||
# "Expired signature (refresh not implemented yet)",
|
||||
# )
|
||||
|
||||
async def process(user, resource_id, resp):
|
||||
"""
|
||||
Too simple to be serious.
|
||||
It's a good fit for a plugin architecture for production
|
||||
"""
|
||||
assert user is not None
|
||||
if resource_id == "time":
|
||||
resp["time"] = datetime.now().strftime("%c")
|
||||
elif resource_id == "bs":
|
||||
async with AsyncClient() as client:
|
||||
bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
|
||||
resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
|
||||
else:
|
||||
resp["sorry"] = (
|
||||
f"I don't known how to give '{resource_id}' but i know corporate bs."
|
||||
)
|
||||
|
||||
|
||||
# assert user.oidc_provider is not None
|
||||
### Get some info (TODO: refactor)
|
||||
# if (auth_provider_id := user.oidc_provider.name) is None:
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED,
|
||||
# "Request headers must have a 'auth_provider' field",
|
||||
# )
|
||||
# if (
|
||||
# auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
|
||||
# ) is None:
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
||||
# )
|
||||
# if (key := auth_provider_settings.get_public_key()) is None:
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED,
|
||||
# f"Key for provider '{auth_provider_id}' unknown",
|
||||
# )
|
||||
# logger.warn(f"refresh with scope {scope}")
|
||||
# breakpoint()
|
||||
# refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope)
|
||||
### Decode the new token
|
||||
# try:
|
||||
# payload = decode(
|
||||
# refreshed_auth_info["access_token"],
|
||||
# key=key,
|
||||
# algorithms=["RS256"],
|
||||
# audience="account",
|
||||
# options={"verify_signature": not settings.insecure.skip_verify_signature},
|
||||
# )
|
||||
# except ExpiredSignatureError as err:
|
||||
# logger.info(f"Expired signature: {err}")
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED,
|
||||
# "Expired signature (refresh not implemented yet)",
|
||||
# )
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue