Get more user info from provider's userinfo endpoint
This commit is contained in:
parent
724887e133
commit
b5f2e5b57b
3 changed files with 42 additions and 8 deletions
|
@ -1,8 +1,11 @@
|
|||
# Implement a fake in-memory database interface for demo purpose
|
||||
import logging
|
||||
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
|
||||
from .models import User, OAuth2Token
|
||||
from .models import User, OAuth2Token, Role
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
|
@ -12,9 +15,22 @@ class Database:
|
|||
# Last sessions for the user (key: users's subject id (sub))
|
||||
|
||||
async def add_user(
|
||||
self, sub: str, user_info: dict, oidc_provider: StarletteOAuth2App
|
||||
self,
|
||||
sub: str,
|
||||
user_info: dict,
|
||||
oidc_provider: StarletteOAuth2App,
|
||||
user_info_from_endpoint: dict,
|
||||
) -> User:
|
||||
user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider)
|
||||
try:
|
||||
raw_roles = user_info_from_endpoint["resource_access"][
|
||||
oidc_provider.client_id
|
||||
]["roles"]
|
||||
except Exception as err:
|
||||
logger.debug(f"Cannot read additional roles: {err}")
|
||||
raw_roles = []
|
||||
for raw_role in raw_roles:
|
||||
user.roles.append(Role(name=raw_role))
|
||||
self.users[sub] = user
|
||||
return user
|
||||
|
||||
|
|
|
@ -32,8 +32,7 @@ from .auth_utils import (
|
|||
from .auth_misc import pretty_details
|
||||
from .database import db
|
||||
|
||||
# logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
templates = Jinja2Templates(Path(__file__).parent / "templates")
|
||||
|
||||
|
@ -59,10 +58,11 @@ for provider in settings.oidc.providers:
|
|||
name=provider.id,
|
||||
server_metadata_url=provider.openid_configuration,
|
||||
client_kwargs={
|
||||
"scope": "openid email offline_access profile roles",
|
||||
"scope": "openid email offline_access profile",
|
||||
},
|
||||
client_id=provider.client_id,
|
||||
client_secret=provider.client_secret,
|
||||
api_base_url=provider.url,
|
||||
# fetch_token=fetch_token,
|
||||
# update_token=update_token,
|
||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
|
@ -111,6 +111,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
except OAuthError as error:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
|
||||
# Remember the oidc_provider in the session
|
||||
# logger.debug(f"Scope: {token['scope']}")
|
||||
request.session["oidc_provider_id"] = oidc_provider_id
|
||||
#
|
||||
# One could process the full decoded token which contains extra information
|
||||
|
@ -119,10 +120,26 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
if userinfo := token.get("userinfo"):
|
||||
# sub given by oidc provider
|
||||
sub = userinfo["sub"]
|
||||
# Get additional data from userinfo endpoint
|
||||
try:
|
||||
user_info_url = oidc_provider.server_metadata["userinfo_endpoint"]
|
||||
user_info_from_endpoint = (
|
||||
await oidc_provider.get(
|
||||
user_info_url, token=token, follow_redirects=True
|
||||
)
|
||||
).json()
|
||||
except Exception as err:
|
||||
logger.info(f"Cannot get userinfo from endpoint: {err}")
|
||||
user_info_from_endpoint = {}
|
||||
# Build and remember the user in the session
|
||||
request.session["user_sub"] = sub
|
||||
# Store the user in the database
|
||||
user = await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
|
||||
user = await db.add_user(
|
||||
sub,
|
||||
user_info=userinfo,
|
||||
oidc_provider=oidc_provider,
|
||||
user_info_from_endpoint=user_info_from_endpoint,
|
||||
)
|
||||
request.session["token"] = userinfo["sub"]
|
||||
await db.add_token(token, user)
|
||||
return RedirectResponse(url=request.url_for("home"))
|
||||
|
|
|
@ -19,7 +19,7 @@ class OIDCProvider(BaseModel):
|
|||
url: str
|
||||
client_id: str
|
||||
client_secret: str = ""
|
||||
hint: str = "Use your own credentials"
|
||||
hint: str = "No hint"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
|
@ -29,7 +29,7 @@ class OIDCProvider(BaseModel):
|
|||
@computed_field
|
||||
@property
|
||||
def token_url(self) -> str:
|
||||
return "auth/" + self.name
|
||||
return "auth/" + self.id
|
||||
|
||||
|
||||
class OIDCSettings(BaseModel):
|
||||
|
@ -43,6 +43,7 @@ class Settings(BaseSettings):
|
|||
|
||||
oidc: OIDCSettings = OIDCSettings()
|
||||
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
||||
log: bool = False
|
||||
|
||||
model_config = SettingsConfigDict(env_nested_delimiter="__")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue