gisaf-backend/src/gisaf/security.py
2024-12-15 20:35:17 +01:00

167 lines
5 KiB
Python

from datetime import datetime, timedelta
import logging
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from passlib.exc import UnknownHashError
from pydantic import BaseModel
from sqlmodel.ext.asyncio.session import AsyncSession
from jose import JWTError, jwt, ExpiredSignatureError
from sqlmodel import select
from sqlalchemy.orm import selectinload
from gisaf.config import conf
from gisaf.database import db_session
from gisaf.models.authentication import User, UserRead
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# passlib context used by this module to hash and verify passwords.
# Two schemes are configured: sha256_crypt and bcrypt.
# NOTE(review): with deprecated="auto", passlib presumably treats every scheme
# except the first one as deprecated -- confirm against the passlib docs.
pwd_context = CryptContext(schemes=["sha256_crypt", "bcrypt"], deprecated="auto")
class Token(BaseModel):
    # Pydantic model with the two string fields describing an issued token.
    # (Documented with comments rather than a docstring on purpose: a class
    # docstring would become part of the generated schema description.)

    # The encoded token string.
    access_token: str
    # Label for the kind of token -- presumably "bearer"; confirm against the
    # endpoint that builds this model.
    token_type: str
# class User(BaseModel):
# username: str
# email: str | None = None
# full_name: str | None = None
# disabled: bool | None = None
# OAuth2 password-bearer scheme whose token endpoint is /api/token.
# auto_error=False: get_current_user below explicitly handles a token of None,
# so a missing token is treated as "anonymous" rather than as an error here.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/token", auto_error=False)
# Shared HTTP 401 response raised when a token cannot be validated or the
# user it names cannot be found.
credentials_exception = HTTPException(
    status_code=status.HTTP_401_UNAUTHORIZED,
    detail="Could not validate credentials",
    headers={"WWW-Authenticate": "Bearer"},
)
# Shared HTTP 401 response for expired tokens. Within this file its only use
# is in a commented-out line in get_current_user.
expired_exception = HTTPException(
    status_code=status.HTTP_401_UNAUTHORIZED,
    detail="Token has expired",
    headers={"WWW-Authenticate": "Bearer"},
)
def get_password_hash(password: str):
    """Return the hash of ``password`` produced by the module's passlib context."""
    hashed = pwd_context.hash(password)
    return hashed
async def delete_user(session: AsyncSession, username: str) -> None:
    """Delete the user named ``username`` and commit the deletion.

    Args:
        session: Async database session used for the lookup and the delete.
        username: Username of the account to remove.

    Raises:
        SystemExit: If no user with that username exists.
    """
    user_in_db: User | None = await get_user(session, username)
    if user_in_db is None:
        raise SystemExit(f"User {username} does not exist in the database")
    await session.delete(user_in_db)
    # Commit here, like enable_user() and create_user() do on the session they
    # are handed; without it the deletion is only pending and is lost unless
    # the caller happens to commit.
    await session.commit()
async def enable_user(session: AsyncSession, username: str, enable=True):
    """Enable (default) or disable the user named ``username`` and commit.

    Args:
        session: Async database session used for the lookup and the update.
        username: Username of the account to update.
        enable: When true the account is enabled; when false it is disabled.

    Raises:
        SystemExit: If no user with that username exists.
    """
    account = await get_user(session, username)
    if account is None:
        raise SystemExit(f"User {username} does not exist in the database")
    # The model stores the inverse flag ("disabled").
    should_disable = not enable
    account.disabled = should_disable  # type: ignore
    session.add(account)
    await session.commit()
async def create_user(
    session: AsyncSession,
    username: str,
    password: str,
    full_name: str,
    email: str,
    **kwargs,
) -> User:
    """Create the user ``username``, or update it if it already exists.

    For an existing user the full name, email and password are overwritten.
    A new user is created enabled. The password is always stored hashed, and
    the change is committed before returning. Extra keyword arguments are
    accepted but not used.

    Returns:
        The created or updated ``User``.
    """
    existing: User | None = await get_user(session, username)

    if existing is not None:
        # Already known: refresh the mutable fields in place.
        existing.full_name = full_name  # type: ignore
        existing.email = email  # type: ignore
        existing.password = get_password_hash(password)  # type: ignore
        await session.commit()
        return existing

    new_user = User(
        username=username,
        password=get_password_hash(password),
        full_name=full_name,
        email=email,
        disabled=False,
    )
    session.add(new_user)
    await session.commit()
    return new_user
async def get_user(session: AsyncSession, username: str) -> User | None:
    """Fetch the user named ``username`` with its roles eagerly loaded.

    Returns:
        The matching ``User``, or ``None`` when there is no match.
    """
    statement = select(User).where(User.username == username)
    # Load the roles relationship together with the user.
    statement = statement.options(selectinload(User.roles))
    result = await session.exec(statement)
    return result.one_or_none()
def verify_password(user: User, plain_password):
    """Check ``plain_password`` against the password stored for ``user``.

    The stored value is normally verified as a hash through the passlib
    context. If passlib does not recognise the stored value as a hash, it is
    assumed to be a plain-text password, a warning is logged, and the two
    values are compared directly.

    Args:
        user: The user whose stored password is checked.
        plain_password: The candidate password supplied by the caller.

    Returns:
        True when the password matches, False otherwise.
    """
    import hmac

    try:
        return pwd_context.verify(plain_password, user.password)
    except UnknownHashError:
        logger.warning(
            f"Password not encrypted in DB for {user.username}, assuming it is stored in plain text"
        )
        supplied = plain_password
        stored = user.password
        # compare_digest only accepts ASCII str or bytes-like values, so
        # normalise both sides to UTF-8 bytes first.
        if isinstance(supplied, str):
            supplied = supplied.encode("utf-8")
        if isinstance(stored, str):
            stored = stored.encode("utf-8")
        if not isinstance(supplied, (bytes, bytearray)) or not isinstance(
            stored, (bytes, bytearray)
        ):
            return False
        # Constant-time comparison: a plain ``==`` leaks, through timing, how
        # many leading characters of the secret matched.
        return hmac.compare_digest(supplied, stored)
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User | None:
    """Resolve the bearer token to a user.

    Returns:
        ``None`` when no token was supplied or when the token has expired;
        otherwise the ``User`` named by the token's ``sub`` claim.

    Raises:
        HTTPException: ``credentials_exception`` (401) when the token is
            otherwise invalid, has no subject, or names an unknown user.
    """
    if token is None:
        return None

    try:
        claims = jwt.decode(
            token, conf.crypto.secret, algorithms=[conf.crypto.algorithm]
        )
    except ExpiredSignatureError:
        # An expired session is treated like "no user", not as an error.
        # raise expired_exception
        decoded = jwt.get_unverified_claims(token)
        logger.debug(f"Session expired for user {decoded.get('sub')}")
        return None
    except JWTError:
        raise credentials_exception

    username: str = claims.get("sub", "")
    if username == "":
        raise credentials_exception

    async with db_session() as session:
        found = await get_user(session, username=username)
        if found is None:
            raise credentials_exception
        return found
async def authenticate_user(username: str, password: str):
    """Authenticate ``username`` with ``password``.

    Returns:
        The ``User`` when it exists, is not disabled and the password
        verifies; ``False`` in every other case.
    """
    async with db_session() as session:
        candidate = await get_user(session, username)
        if (
            candidate
            and not candidate.disabled
            and verify_password(candidate, password)
        ):
            return candidate
        return False
async def get_current_active_user(
    current_user: Annotated[UserRead, Depends(get_current_user)]
):
    """Return the current user unless that user is disabled.

    A ``None`` user (no usable token) is passed through unchanged.

    Raises:
        HTTPException: 400 "Inactive user" when the user is disabled.
    """
    if current_user is None or not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
def create_access_token(data: dict, expires_delta: timedelta):
    """Encode ``data`` as a signed JWT that expires after ``expires_delta``.

    Args:
        data: Claims to put in the token; the dict itself is not modified.
        expires_delta: Lifetime of the token, counted from now.

    Returns:
        The encoded token, signed with the configured secret and algorithm.
    """
    from datetime import timezone

    to_encode = data.copy()
    # Timezone-aware "now" in UTC: datetime.utcnow() is deprecated since
    # Python 3.12 and returns a naive value that is easy to misinterpret.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(
        to_encode, conf.crypto.secret, algorithm=conf.crypto.algorithm
    )
    return encoded_jwt