This commit is contained in:
phil 2025-01-02 03:30:18 +01:00
parent fa4032c0e6
commit 78ce6fd01a

View file

@ -174,9 +174,11 @@ async def current_user(request: Request, token: str | None = Depends(fastapi_oau
async def login(request: Request, provider: str) -> RedirectResponse: async def login(request: Request, provider: str) -> RedirectResponse:
redirect_uri = request.url_for("auth", provider=provider) redirect_uri = request.url_for("auth", provider=provider)
try: try:
return await getattr(authlib_oauth, provider).authorize_redirect( provider_ = getattr(authlib_oauth, provider)
request, redirect_uri except AttributeError:
) raise HTTPException(500, "")
try:
return await provider_.authorize_redirect(request, redirect_uri)
except HTTPError: except HTTPError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
@ -184,17 +186,18 @@ async def login(request: Request, provider: str) -> RedirectResponse:
@app.get("/auth/{provider}") @app.get("/auth/{provider}")
async def auth(request: Request, provider: str) -> RedirectResponse: async def auth(request: Request, provider: str) -> RedirectResponse:
try: try:
token = await getattr(authlib_oauth, provider).authorize_access_token(request) provider_ = getattr(authlib_oauth, provider)
except AttributeError:
raise HTTPException(500, "")
try:
token = await provider_.authorize_access_token(request)
except OAuthError as error: except OAuthError as error:
return HTMLResponse(f"<h1>{error.error}</h1>") raise HTTPException(status_code=401, detail=error.error)
user = token.get("userinfo") user = token.get("userinfo")
if user: if user:
request.session["user"] = dict(user) request.session["user"] = dict(user)
return RedirectResponse(url=request.session.pop("next", "/")) return RedirectResponse(url="/")
else:
return RedirectResponse(url="/login") return RedirectResponse(url="/login")