Resource server: read the required scope in access token
This commit is contained in:
parent
f910834736
commit
b3e19b3e40
4 changed files with 129 additions and 37 deletions
src/oidc_test
|
@ -1,4 +1,4 @@
|
|||
from typing import Union, Annotated, Tuple
|
||||
from typing import Union, Annotated
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
|
@ -18,10 +18,13 @@ from .settings import settings, OIDCProvider
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
oidc_providers_settings: dict[str, OIDCProvider] = dict([(provider.id, provider) for provider in settings.oidc.providers])
|
||||
oidc_providers_settings: dict[str, OIDCProvider] = dict(
|
||||
[(provider.id, provider) for provider in settings.oidc.providers]
|
||||
)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
def fetch_token(name, request):
|
||||
breakpoint()
|
||||
...
|
||||
|
@ -61,8 +64,10 @@ def init_providers():
|
|||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
)
|
||||
|
||||
|
||||
init_providers()
|
||||
|
||||
|
||||
async def get_providers_info():
|
||||
# Get the public key:
|
||||
async with AsyncClient() as client:
|
||||
|
@ -174,18 +179,35 @@ async def get_user_from_token(
|
|||
request: Request,
|
||||
) -> User:
|
||||
if (auth_provider_id := request.headers.get("auth_provider")) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field")
|
||||
if (auth_provider_settings := oidc_providers_settings.get(auth_provider_id)) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'")
|
||||
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
|
||||
await oidc_provider.load_server_metadata()
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
"Request headers must have a 'auth_provider' field",
|
||||
)
|
||||
if (
|
||||
auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
|
||||
) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
||||
)
|
||||
if (key := auth_provider_settings.get_public_key()) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown")
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
f"Key for provider '{auth_provider_id}' unknown",
|
||||
)
|
||||
try:
|
||||
payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test")
|
||||
payload = decode(
|
||||
token,
|
||||
key=key,
|
||||
algorithms=["RS256"],
|
||||
audience="oidc-test",
|
||||
options={"verify_signature": not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
except ExpiredSignatureError as err:
|
||||
logger.info(f"Expired signature: {err}")
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)")
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
"Expired signature (refresh not implemented yet)",
|
||||
)
|
||||
except InvalidKeyError as err:
|
||||
logger.info(f"Invalid key: {err}")
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
|
||||
|
@ -193,16 +215,20 @@ async def get_user_from_token(
|
|||
logger.info("Cannot decode token, see below")
|
||||
logger.exception(err)
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
|
||||
if (user_id := payload.get('sub')) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found")
|
||||
if (user_id := payload.get("sub")) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found"
|
||||
)
|
||||
try:
|
||||
user = await db.get_user(user_id)
|
||||
except UserNotInDB:
|
||||
logger.info(f"User {user_id} not found in DB, creating it (real apps can behave differently")
|
||||
logger.info(
|
||||
f"User {user_id} not found in DB, creating it (real apps can behave differently"
|
||||
)
|
||||
user = await db.add_user(
|
||||
sub=payload['sub'],
|
||||
sub=payload["sub"],
|
||||
user_info=payload,
|
||||
oidc_provider=getattr(authlib_oauth, auth_provider_id),
|
||||
user_info_from_endpoint={}
|
||||
user_info_from_endpoint={},
|
||||
)
|
||||
return user
|
||||
|
|
|
@ -53,16 +53,14 @@ origins = [
|
|||
"https://philo.ydns.eu/",
|
||||
]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await get_providers_info()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="OIDC auth test",
|
||||
lifespan=lifespan
|
||||
)
|
||||
app = FastAPI(title="OIDC auth test", lifespan=lifespan)
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
|
@ -284,7 +282,6 @@ async def non_compliant_logout(
|
|||
@app.get("/resource/{id}")
|
||||
async def get_resource_(
|
||||
id: str,
|
||||
request: Request,
|
||||
# user: Annotated[User, Depends(get_current_user)],
|
||||
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
# token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
|
@ -294,7 +291,7 @@ async def get_resource_(
|
|||
return JSONResponse(await get_resource(id, user))
|
||||
|
||||
|
||||
# Routes for test
|
||||
# Routes for RBAC based tests
|
||||
|
||||
|
||||
@app.get("/public")
|
||||
|
|
|
@ -1,21 +1,76 @@
|
|||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
from httpx import AsyncClient
|
||||
from fastapi import HTTPException, status
|
||||
from jwt import ExpiredSignatureError, InvalidKeyError, decode
|
||||
|
||||
from .models import User
|
||||
from .auth_utils import oidc_providers_settings
|
||||
from .settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_resource(id: str, user: User) -> dict:
|
||||
pname = getattr(user.oidc_provider, "name", "?")
|
||||
resp = {
|
||||
"hello": f"Hi {user.name} from an OAuth resource provider.",
|
||||
"comment": f"I received a request for '{id}' with an access token signed by {pname}."
|
||||
"comment": f"I received a request for '{id}' with an access token signed by {pname}.",
|
||||
}
|
||||
scope = f"get:{id}"
|
||||
user_scopes = user.userinfo["scope"].split(" ")
|
||||
if scope in user_scopes:
|
||||
if id == "time":
|
||||
resp["time"] = datetime.now().strftime("%c")
|
||||
elif id == "bs":
|
||||
async with AsyncClient() as client:
|
||||
bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
|
||||
resp['bs'] = bs.json().get("phrase", "Sorry, i am out of BS today.")
|
||||
bs = await client.get(
|
||||
"https://corporatebs-generator.sameerkumar.website/"
|
||||
)
|
||||
resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
|
||||
else:
|
||||
resp['sorry'] = f"I don't known how to give '{id}' but i know corporate bs."
|
||||
|
||||
resp["sorry"] = f"I don't known how to give '{id}' but i know corporate bs."
|
||||
else:
|
||||
resp["sorry"] = (
|
||||
f"I don't serve the ressource {id} to you because"
|
||||
"there is no scope {scope} in the access token,"
|
||||
)
|
||||
return resp
|
||||
|
||||
# assert user.oidc_provider is not None
|
||||
### Get some info (TODO: refactor)
|
||||
# if (auth_provider_id := user.oidc_provider.name) is None:
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED,
|
||||
# "Request headers must have a 'auth_provider' field",
|
||||
# )
|
||||
# if (
|
||||
# auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
|
||||
# ) is None:
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
||||
# )
|
||||
# if (key := auth_provider_settings.get_public_key()) is None:
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED,
|
||||
# f"Key for provider '{auth_provider_id}' unknown",
|
||||
# )
|
||||
# logger.warn(f"refresh with scope {scope}")
|
||||
# breakpoint()
|
||||
# refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope)
|
||||
### Decode the new token
|
||||
# try:
|
||||
# payload = decode(
|
||||
# refreshed_auth_info["access_token"],
|
||||
# key=key,
|
||||
# algorithms=["RS256"],
|
||||
# audience="account",
|
||||
# options={"verify_signature": not settings.insecure.skip_verify_signature},
|
||||
# )
|
||||
# except ExpiredSignatureError as err:
|
||||
# logger.info(f"Expired signature: {err}")
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED,
|
||||
# "Expired signature (refresh not implemented yet)",
|
||||
# )
|
||||
|
|
|
@ -36,8 +36,12 @@ class OIDCProvider(BaseModel):
|
|||
hint: str = "No hint"
|
||||
resources: list[Resource] = []
|
||||
account_url_template: str | None = None
|
||||
info_url: str | None = None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
|
||||
info: dict[str, str | int ] | None = None # Info fetched from info_url, eg. public key
|
||||
info_url: str | None = (
|
||||
None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
|
||||
)
|
||||
info: dict[str, str | int] | None = (
|
||||
None # Info fetched from info_url, eg. public key
|
||||
)
|
||||
public_key: str | None = None
|
||||
|
||||
@computed_field
|
||||
|
@ -68,7 +72,9 @@ class OIDCProvider(BaseModel):
|
|||
|
||||
def get_public_key(self) -> str | None:
|
||||
"""Return the public key formatted for decoding token"""
|
||||
public_key = self.public_key or (self.info is not None and self.info["public_key"])
|
||||
public_key = self.public_key or (
|
||||
self.info is not None and self.info["public_key"]
|
||||
)
|
||||
if public_key is None:
|
||||
return None
|
||||
return f"""
|
||||
|
@ -77,6 +83,7 @@ class OIDCProvider(BaseModel):
|
|||
-----END PUBLIC KEY-----
|
||||
"""
|
||||
|
||||
|
||||
class ResourceProvider(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
@ -90,15 +97,22 @@ class OIDCSettings(BaseModel):
|
|||
swagger_provider: str = ""
|
||||
|
||||
|
||||
class Insecure(BaseModel):
|
||||
"""Warning: changing these defaults are only suitable for debugging"""
|
||||
|
||||
skip_verify_signature: bool = False
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Settings wil be read from an .env file"""
|
||||
|
||||
model_config = SettingsConfigDict(env_nested_delimiter="__")
|
||||
|
||||
oidc: OIDCSettings = OIDCSettings()
|
||||
resource_providers: list[ResourceProvider] = []
|
||||
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
||||
log: bool = False
|
||||
|
||||
model_config = SettingsConfigDict(env_nested_delimiter="__")
|
||||
insecure: Insecure = Insecure()
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue