Resource server: read the required scope in access token
This commit is contained in:
parent
f910834736
commit
b3e19b3e40
4 changed files with 129 additions and 37 deletions
|
@ -1,4 +1,4 @@
|
||||||
from typing import Union, Annotated, Tuple
|
from typing import Union, Annotated
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -18,10 +18,13 @@ from .settings import settings, OIDCProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
oidc_providers_settings: dict[str, OIDCProvider] = dict([(provider.id, provider) for provider in settings.oidc.providers])
|
oidc_providers_settings: dict[str, OIDCProvider] = dict(
|
||||||
|
[(provider.id, provider) for provider in settings.oidc.providers]
|
||||||
|
)
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
|
||||||
def fetch_token(name, request):
|
def fetch_token(name, request):
|
||||||
breakpoint()
|
breakpoint()
|
||||||
...
|
...
|
||||||
|
@ -61,8 +64,10 @@ def init_providers():
|
||||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
init_providers()
|
init_providers()
|
||||||
|
|
||||||
|
|
||||||
async def get_providers_info():
|
async def get_providers_info():
|
||||||
# Get the public key:
|
# Get the public key:
|
||||||
async with AsyncClient() as client:
|
async with AsyncClient() as client:
|
||||||
|
@ -174,18 +179,35 @@ async def get_user_from_token(
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> User:
|
) -> User:
|
||||||
if (auth_provider_id := request.headers.get("auth_provider")) is None:
|
if (auth_provider_id := request.headers.get("auth_provider")) is None:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field")
|
raise HTTPException(
|
||||||
if (auth_provider_settings := oidc_providers_settings.get(auth_provider_id)) is None:
|
status.HTTP_401_UNAUTHORIZED,
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'")
|
"Request headers must have a 'auth_provider' field",
|
||||||
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
|
)
|
||||||
await oidc_provider.load_server_metadata()
|
if (
|
||||||
|
auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
|
||||||
|
) is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
||||||
|
)
|
||||||
if (key := auth_provider_settings.get_public_key()) is None:
|
if (key := auth_provider_settings.get_public_key()) is None:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown")
|
raise HTTPException(
|
||||||
|
status.HTTP_401_UNAUTHORIZED,
|
||||||
|
f"Key for provider '{auth_provider_id}' unknown",
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test")
|
payload = decode(
|
||||||
|
token,
|
||||||
|
key=key,
|
||||||
|
algorithms=["RS256"],
|
||||||
|
audience="oidc-test",
|
||||||
|
options={"verify_signature": not settings.insecure.skip_verify_signature},
|
||||||
|
)
|
||||||
except ExpiredSignatureError as err:
|
except ExpiredSignatureError as err:
|
||||||
logger.info(f"Expired signature: {err}")
|
logger.info(f"Expired signature: {err}")
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)")
|
raise HTTPException(
|
||||||
|
status.HTTP_401_UNAUTHORIZED,
|
||||||
|
"Expired signature (refresh not implemented yet)",
|
||||||
|
)
|
||||||
except InvalidKeyError as err:
|
except InvalidKeyError as err:
|
||||||
logger.info(f"Invalid key: {err}")
|
logger.info(f"Invalid key: {err}")
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
|
||||||
|
@ -193,16 +215,20 @@ async def get_user_from_token(
|
||||||
logger.info("Cannot decode token, see below")
|
logger.info("Cannot decode token, see below")
|
||||||
logger.exception(err)
|
logger.exception(err)
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
|
||||||
if (user_id := payload.get('sub')) is None:
|
if (user_id := payload.get("sub")) is None:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found")
|
raise HTTPException(
|
||||||
|
status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
user = await db.get_user(user_id)
|
user = await db.get_user(user_id)
|
||||||
except UserNotInDB:
|
except UserNotInDB:
|
||||||
logger.info(f"User {user_id} not found in DB, creating it (real apps can behave differently")
|
logger.info(
|
||||||
|
f"User {user_id} not found in DB, creating it (real apps can behave differently"
|
||||||
|
)
|
||||||
user = await db.add_user(
|
user = await db.add_user(
|
||||||
sub=payload['sub'],
|
sub=payload["sub"],
|
||||||
user_info=payload,
|
user_info=payload,
|
||||||
oidc_provider=getattr(authlib_oauth, auth_provider_id),
|
oidc_provider=getattr(authlib_oauth, auth_provider_id),
|
||||||
user_info_from_endpoint={}
|
user_info_from_endpoint={},
|
||||||
)
|
)
|
||||||
return user
|
return user
|
||||||
|
|
|
@ -53,16 +53,14 @@ origins = [
|
||||||
"https://philo.ydns.eu/",
|
"https://philo.ydns.eu/",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
await get_providers_info()
|
await get_providers_info()
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(title="OIDC auth test", lifespan=lifespan)
|
||||||
title="OIDC auth test",
|
|
||||||
lifespan=lifespan
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
|
@ -284,7 +282,6 @@ async def non_compliant_logout(
|
||||||
@app.get("/resource/{id}")
|
@app.get("/resource/{id}")
|
||||||
async def get_resource_(
|
async def get_resource_(
|
||||||
id: str,
|
id: str,
|
||||||
request: Request,
|
|
||||||
# user: Annotated[User, Depends(get_current_user)],
|
# user: Annotated[User, Depends(get_current_user)],
|
||||||
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||||
# token: Annotated[OAuth2Token, Depends(get_token)],
|
# token: Annotated[OAuth2Token, Depends(get_token)],
|
||||||
|
@ -294,7 +291,7 @@ async def get_resource_(
|
||||||
return JSONResponse(await get_resource(id, user))
|
return JSONResponse(await get_resource(id, user))
|
||||||
|
|
||||||
|
|
||||||
# Routes for test
|
# Routes for RBAC based tests
|
||||||
|
|
||||||
|
|
||||||
@app.get("/public")
|
@app.get("/public")
|
||||||
|
|
|
@ -1,21 +1,76 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import logging
|
||||||
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from jwt import ExpiredSignatureError, InvalidKeyError, decode
|
||||||
|
|
||||||
from .models import User
|
from .models import User
|
||||||
|
from .auth_utils import oidc_providers_settings
|
||||||
|
from .settings import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def get_resource(id: str, user: User) -> dict:
|
async def get_resource(id: str, user: User) -> dict:
|
||||||
pname = getattr(user.oidc_provider, "name", "?")
|
pname = getattr(user.oidc_provider, "name", "?")
|
||||||
resp = {
|
resp = {
|
||||||
"hello": f"Hi {user.name} from an OAuth resource provider.",
|
"hello": f"Hi {user.name} from an OAuth resource provider.",
|
||||||
"comment": f"I received a request for '{id}' with an access token signed by {pname}."
|
"comment": f"I received a request for '{id}' with an access token signed by {pname}.",
|
||||||
}
|
}
|
||||||
|
scope = f"get:{id}"
|
||||||
|
user_scopes = user.userinfo["scope"].split(" ")
|
||||||
|
if scope in user_scopes:
|
||||||
if id == "time":
|
if id == "time":
|
||||||
resp["time"] = datetime.now().strftime("%c")
|
resp["time"] = datetime.now().strftime("%c")
|
||||||
elif id == "bs":
|
elif id == "bs":
|
||||||
async with AsyncClient() as client:
|
async with AsyncClient() as client:
|
||||||
bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
|
bs = await client.get(
|
||||||
resp['bs'] = bs.json().get("phrase", "Sorry, i am out of BS today.")
|
"https://corporatebs-generator.sameerkumar.website/"
|
||||||
|
)
|
||||||
|
resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
|
||||||
else:
|
else:
|
||||||
resp['sorry'] = f"I don't known how to give '{id}' but i know corporate bs."
|
resp["sorry"] = f"I don't known how to give '{id}' but i know corporate bs."
|
||||||
|
else:
|
||||||
|
resp["sorry"] = (
|
||||||
|
f"I don't serve the ressource {id} to you because"
|
||||||
|
"there is no scope {scope} in the access token,"
|
||||||
|
)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
# assert user.oidc_provider is not None
|
||||||
|
### Get some info (TODO: refactor)
|
||||||
|
# if (auth_provider_id := user.oidc_provider.name) is None:
|
||||||
|
# raise HTTPException(
|
||||||
|
# status.HTTP_401_UNAUTHORIZED,
|
||||||
|
# "Request headers must have a 'auth_provider' field",
|
||||||
|
# )
|
||||||
|
# if (
|
||||||
|
# auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
|
||||||
|
# ) is None:
|
||||||
|
# raise HTTPException(
|
||||||
|
# status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
||||||
|
# )
|
||||||
|
# if (key := auth_provider_settings.get_public_key()) is None:
|
||||||
|
# raise HTTPException(
|
||||||
|
# status.HTTP_401_UNAUTHORIZED,
|
||||||
|
# f"Key for provider '{auth_provider_id}' unknown",
|
||||||
|
# )
|
||||||
|
# logger.warn(f"refresh with scope {scope}")
|
||||||
|
# breakpoint()
|
||||||
|
# refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope)
|
||||||
|
### Decode the new token
|
||||||
|
# try:
|
||||||
|
# payload = decode(
|
||||||
|
# refreshed_auth_info["access_token"],
|
||||||
|
# key=key,
|
||||||
|
# algorithms=["RS256"],
|
||||||
|
# audience="account",
|
||||||
|
# options={"verify_signature": not settings.insecure.skip_verify_signature},
|
||||||
|
# )
|
||||||
|
# except ExpiredSignatureError as err:
|
||||||
|
# logger.info(f"Expired signature: {err}")
|
||||||
|
# raise HTTPException(
|
||||||
|
# status.HTTP_401_UNAUTHORIZED,
|
||||||
|
# "Expired signature (refresh not implemented yet)",
|
||||||
|
# )
|
||||||
|
|
|
@ -36,8 +36,12 @@ class OIDCProvider(BaseModel):
|
||||||
hint: str = "No hint"
|
hint: str = "No hint"
|
||||||
resources: list[Resource] = []
|
resources: list[Resource] = []
|
||||||
account_url_template: str | None = None
|
account_url_template: str | None = None
|
||||||
info_url: str | None = None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
|
info_url: str | None = (
|
||||||
info: dict[str, str | int ] | None = None # Info fetched from info_url, eg. public key
|
None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
|
||||||
|
)
|
||||||
|
info: dict[str, str | int] | None = (
|
||||||
|
None # Info fetched from info_url, eg. public key
|
||||||
|
)
|
||||||
public_key: str | None = None
|
public_key: str | None = None
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
|
@ -68,7 +72,9 @@ class OIDCProvider(BaseModel):
|
||||||
|
|
||||||
def get_public_key(self) -> str | None:
|
def get_public_key(self) -> str | None:
|
||||||
"""Return the public key formatted for decoding token"""
|
"""Return the public key formatted for decoding token"""
|
||||||
public_key = self.public_key or (self.info is not None and self.info["public_key"])
|
public_key = self.public_key or (
|
||||||
|
self.info is not None and self.info["public_key"]
|
||||||
|
)
|
||||||
if public_key is None:
|
if public_key is None:
|
||||||
return None
|
return None
|
||||||
return f"""
|
return f"""
|
||||||
|
@ -77,6 +83,7 @@ class OIDCProvider(BaseModel):
|
||||||
-----END PUBLIC KEY-----
|
-----END PUBLIC KEY-----
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ResourceProvider(BaseModel):
|
class ResourceProvider(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
|
@ -90,15 +97,22 @@ class OIDCSettings(BaseModel):
|
||||||
swagger_provider: str = ""
|
swagger_provider: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class Insecure(BaseModel):
|
||||||
|
"""Warning: changing these defaults are only suitable for debugging"""
|
||||||
|
|
||||||
|
skip_verify_signature: bool = False
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""Settings wil be read from an .env file"""
|
"""Settings wil be read from an .env file"""
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_nested_delimiter="__")
|
||||||
|
|
||||||
oidc: OIDCSettings = OIDCSettings()
|
oidc: OIDCSettings = OIDCSettings()
|
||||||
resource_providers: list[ResourceProvider] = []
|
resource_providers: list[ResourceProvider] = []
|
||||||
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
||||||
log: bool = False
|
log: bool = False
|
||||||
|
insecure: Insecure = Insecure()
|
||||||
model_config = SettingsConfigDict(env_nested_delimiter="__")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def settings_customise_sources(
|
def settings_customise_sources(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue