"""Password hashing, user management helpers and JWT bearer authentication."""

from datetime import datetime, timedelta, timezone
import hmac
import logging

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from passlib.exc import UnknownHashError
from pydantic import BaseModel
from sqlmodel.ext.asyncio.session import AsyncSession
from jose import JWTError, jwt, ExpiredSignatureError
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from gisaf.config import conf
from gisaf.database import db_session
from gisaf.models.authentication import User, UserRead

logger = logging.getLogger(__name__)

# Hashing context used for both hashing new passwords and verifying them.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


class Token(BaseModel):
    """Access token together with its token type."""

    access_token: str
    token_type: str


class TokenData(BaseModel):
    """Username extracted from the ``sub`` claim of a decoded token."""

    username: str | None = None


# class User(BaseModel):
#     username: str
#     email: str | None = None
#     full_name: str | None = None
#     disabled: bool | None = None

# auto_error=False: get_current_user() handles a missing token (None) itself.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)


def get_password_hash(password: str) -> str:
    """Return the hash of *password* computed with the module's ``pwd_context``."""
    return pwd_context.hash(password)


async def delete_user(session: AsyncSession, username: str) -> None:
    """Mark the user named *username* for deletion in *session*.

    Raises:
        SystemExit: if no such user exists in the database.

    NOTE(review): unlike enable_user() and create_user(), this function does
    not commit the session -- presumably the caller commits; confirm.
    """
    user_in_db: User | None = await get_user(session, username)
    if user_in_db is None:
        raise SystemExit(f'User {username} does not exist in the database')
    await session.delete(user_in_db)


async def enable_user(session: AsyncSession, username: str, enable=True):
    """Enable (default) or disable the user named *username*, then commit.

    Args:
        session: database session used for the lookup and the commit.
        username: name of the user to update.
        enable: when true the user is enabled, otherwise disabled.

    Raises:
        SystemExit: if no such user exists in the database.
    """
    # get_user() returns a User (not a UserRead), so annotate it as such.
    user_in_db: User | None = await get_user(session, username)
    if user_in_db is None:
        raise SystemExit(f'User {username} does not exist in the database')
    user_in_db.disabled = not enable
    session.add(user_in_db)
    await session.commit()


async def create_user(session: AsyncSession, username: str,
                      password: str, full_name: str,
                      email: str, **kwargs):
    """Create the user *username*, or update it if it already exists.

    A new user is stored enabled (``disabled=False``). For an existing user
    the full name, email and password are overwritten. The password is always
    stored hashed. The session is committed in both cases.

    Extra keyword arguments are accepted and ignored.
    """
    user_in_db: User | None = await get_user(session, username)
    if user_in_db is None:
        user = User(
            username=username,
            password=get_password_hash(password),
            full_name=full_name,
            email=email,
            disabled=False,
        )
        session.add(user)
    else:
        user_in_db.full_name = full_name  # type: ignore
        user_in_db.email = email  # type: ignore
        user_in_db.password = get_password_hash(password)  # type: ignore
    await session.commit()


async def get_user(
        session: AsyncSession,
        username: str) -> (User | None):
    """Return the user named *username* with its roles loaded, or None."""
    query = (
        select(User)
        .where(User.username == username)
        .options(selectinload(User.roles))
    )
    data = await session.exec(query)
    return data.scalar()


def _constant_time_equals(left, right) -> bool:
    """Compare two values, in constant time when both are strings."""
    if not isinstance(left, str) or not isinstance(right, str):
        # Keep plain equality semantics for anything that is not a str pair.
        return left == right
    # compare_digest() rejects non-ASCII str, so compare the UTF-8 bytes.
    # "surrogatepass" keeps encoding total, so this equals str equality.
    return hmac.compare_digest(
        left.encode("utf-8", "surrogatepass"),
        right.encode("utf-8", "surrogatepass"),
    )


def verify_password(user: User, plain_password):
    """Check *plain_password* against the password stored for *user*.

    If the stored value is not a hash recognised by ``pwd_context``, a warning
    is logged and the stored value is compared as plain text.
    """
    try:
        return pwd_context.verify(plain_password, user.password)
    except UnknownHashError:
        logger.warning(
            'Password not encrypted in DB for %s, assuming it is stored in plain text',
            user.username,
        )
        # Constant-time comparison avoids leaking the stored secret via timing.
        return _constant_time_equals(plain_password, user.password)


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 response exception carrying the Bearer challenge header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
        token: str = Depends(oauth2_scheme)) -> UserRead | None:
    """Resolve the user identified by the bearer *token*.

    Returns:
        None when *token* is None, otherwise the user found in the database
        for the token's ``sub`` claim.

    Raises:
        HTTPException: 401 "Token has expired" for an expired token;
            401 "Could not validate credentials" when the token cannot be
            decoded, has no (or an empty) ``sub`` claim, or names a user
            that does not exist.
    """
    if token is None:
        return None

    credentials_exception = _unauthorized("Could not validate credentials")

    # Only the decoding can raise JWT errors; keep the try body minimal.
    try:
        payload = jwt.decode(token, conf.crypto.secret,
                             algorithms=[conf.crypto.algorithm])
    except ExpiredSignatureError as err:
        # Must come first: the original also tested it before JWTError.
        raise _unauthorized("Token has expired") from err
    except JWTError as err:
        raise credentials_exception from err

    username: str = payload.get("sub", '')
    if not username:
        raise credentials_exception
    token_data = TokenData(username=username)

    async with db_session() as session:
        user = await get_user(session, username=token_data.username)
        if user is None:
            raise credentials_exception
        return user


async def authenticate_user(username: str, password: str) -> User | bool:
    """Return the user if *username* and *password* match, else False."""
    async with db_session() as session:
        user = await get_user(session, username)
        if not user:
            return False
        if not verify_password(user, password):
            return False
        return user


# async def get_current_active_user(
#         current_user: Annotated[UserRead, Depends(get_current_user)]):
#     if current_user.disabled:
#         raise HTTPException(status_code=400, detail="Inactive user")
#     return current_user


def create_access_token(data: dict, expires_delta: timedelta):
    """Return a signed JWT holding *data* plus an ``exp`` claim.

    Args:
        data: claims to put in the token; the dict itself is not modified.
        expires_delta: validity period, counted from the current time.
    """
    to_encode = data.copy()
    # Timezone-aware "now": datetime.utcnow() is deprecated and naive.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode,
                             conf.crypto.secret,
                             algorithm=conf.crypto.algorithm)
    return encoded_jwt