gisaf-backend/src/gisaf/security.py

155 lines
4.8 KiB
Python
Raw Normal View History

2023-11-06 17:04:17 +05:30
from datetime import datetime, timedelta
import logging
2023-11-06 17:04:17 +05:30
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from passlib.exc import UnknownHashError
2023-11-06 17:04:17 +05:30
from pydantic import BaseModel
from sqlmodel.ext.asyncio.session import AsyncSession
from jose import JWTError, jwt, ExpiredSignatureError
2023-11-06 17:04:17 +05:30
from sqlalchemy import select
from sqlalchemy.orm import selectinload
2023-11-06 17:04:17 +05:30
from .config import conf
from .database import db_session
from .models.authentication import User, UserRead
2023-11-06 17:04:17 +05:30
# Module-level logger for this security module (used e.g. by verify_password).
logger = logging.getLogger(__name__)
2023-11-06 17:04:17 +05:30
# Passlib hashing context. bcrypt is the only configured scheme, so new hashes
# (get_password_hash) are bcrypt. Stored values passlib cannot identify are
# handled by the plain-text fallback in verify_password.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
2023-11-06 17:04:17 +05:30
# Access token plus its type. Presumably the response model of the token
# endpoint (token_type is typically "bearer") — NOTE(review): confirm with
# the route that returns it.
class Token(BaseModel):
    access_token: str
    token_type: str
# Data extracted from a decoded JWT: get_current_user fills `username` from
# the token's "sub" claim before looking the user up.
class TokenData(BaseModel):
    username: str | None = None
# class User(BaseModel):
# username: str
# email: str | None = None
# full_name: str | None = None
# disabled: bool | None = None
2023-11-06 17:04:17 +05:30
# Bearer-token extractor injected into get_current_user via Depends.
# auto_error=False: when a request carries no token this yields None instead
# of rejecting the request, and get_current_user then returns None.
# tokenUrl is the (relative) path advertised for obtaining a token —
# presumably the login route defined elsewhere.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
2023-11-06 17:04:17 +05:30
def get_password_hash(password: str):
    """Hash *password* with the module's bcrypt ``CryptContext``.

    The result is what ``create_user`` stores in ``User.password`` and what
    ``verify_password`` later checks against.
    """
    hashed = pwd_context.hash(password)
    return hashed
async def delete_user(session: AsyncSession, username: str) -> None:
    """Delete the user named *username* and commit the change.

    The sibling helpers ``enable_user`` and ``create_user`` commit their own
    changes. This one previously only staged the delete, which left
    persistence up to whoever eventually committed the session. It now
    commits as well. An extra commit by a caller is harmless.

    Args:
        session: open async database session.
        username: name of the user to remove.

    Raises:
        SystemExit: if no user with that name exists (these helpers appear
            to be meant for command-line use — NOTE(review): confirm).
    """
    user_in_db: User | None = await get_user(session, username)
    if user_in_db is None:
        raise SystemExit(f'User {username} does not exist in the database')
    await session.delete(user_in_db)
    await session.commit()
2023-11-06 17:04:17 +05:30
async def enable_user(session: AsyncSession, username: str, enable=True):
    """Enable or disable the account *username* and commit.

    Args:
        session: open async database session.
        username: name of the user to update.
        enable: True to enable the account (sets ``disabled=False``),
            False to disable it.

    Raises:
        SystemExit: if no user with that name exists.
    """
    # get_user returns the ``User`` ORM instance (not ``UserRead``), so that
    # is the accurate annotation for the object being mutated and re-added.
    user_in_db: User | None = await get_user(session, username)
    if user_in_db is None:
        raise SystemExit(f'User {username} does not exist in the database')
    user_in_db.disabled = not enable  # type: ignore
    session.add(user_in_db)
    await session.commit()
2023-11-06 17:04:17 +05:30
async def create_user(session: AsyncSession, username: str, password: str, full_name: str,
                      email: str, **kwargs):
    """Create user *username*, or overwrite an existing one's details, then commit.

    For an existing user, the full name, email and password hash are
    replaced; its ``disabled`` flag is left as is. Otherwise a new, enabled
    user is added to the session. Extra keyword arguments are accepted and
    ignored.
    """
    existing: User | None = await get_user(session, username)
    if existing is not None:
        existing.full_name = full_name  # type: ignore
        existing.email = email  # type: ignore
        existing.password = get_password_hash(password)  # type: ignore
    else:
        new_user = User(
            username=username,
            password=get_password_hash(password),
            full_name=full_name,
            email=email,
            disabled=False,
        )
        session.add(new_user)
    await session.commit()
async def get_user(
        session: AsyncSession,
        username: str) -> (User | None):
    """Return the user whose ``username`` equals *username*, or None.

    The user's ``roles`` relationship is loaded eagerly with ``selectinload``.
    The result's first column of the first row is returned (None when no
    row matches).
    """
    statement = (
        select(User)
        .where(User.username == username)
        .options(selectinload(User.roles))
    )
    result = await session.exec(statement)
    return result.scalar()
def verify_password(user: User, plain_password):
    """Check *plain_password* against the password stored for *user*.

    The stored value is normally a passlib hash and is verified with
    ``pwd_context``. If passlib cannot identify it as a hash
    (``UnknownHashError``), it is treated as a legacy plain-text password. A
    warning is logged and the values are compared directly.

    Returns:
        True when the password matches, False otherwise.
    """
    import hmac

    try:
        return pwd_context.verify(plain_password, user.password)
    except UnknownHashError:
        logger.warning('Password not encrypted in DB for %s, '
                       'assuming it is stored in plain text', user.username)
    # Constant-time comparison: a plain ``==`` can leak through response
    # timing how much of the stored password matched. Compare UTF-8 bytes
    # because hmac.compare_digest only accepts ASCII-only ``str`` values.
    # Assumes both values are str — NOTE(review): confirm the column type.
    return hmac.compare_digest(plain_password.encode('utf-8'),
                               user.password.encode('utf-8'))
async def get_current_user(
        token: str | None = Depends(oauth2_scheme)) -> UserRead | None:
    """Resolve the bearer *token* of the current request into a user.

    ``oauth2_scheme`` is built with ``auto_error=False``, so *token* is None
    when the request carries no token. In that case None is returned, which
    lets endpoints serve anonymous requests.

    Raises:
        HTTPException: 401 "Token has expired" when the JWT signature has
            expired. 401 "Could not validate credentials" in three cases:
            the token cannot be decoded, its "sub" claim is missing, empty or
            not a string, or no user with that name exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    expired_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Token has expired",
        headers={"WWW-Authenticate": "Bearer"},
    )
    if token is None:
        return None

    try:
        payload = jwt.decode(token, conf.crypto.secret,
                             algorithms=[conf.crypto.algorithm])
    # ExpiredSignatureError subclasses JWTError, so it must be caught first.
    except ExpiredSignatureError as err:
        raise expired_exception from err
    except JWTError as err:
        raise credentials_exception from err

    username = payload.get("sub")
    # Reject a missing, empty, null or non-string subject here. Otherwise
    # TokenData validation fails with a 500, or the lookup searches for a
    # NULL username.
    if not isinstance(username, str) or not username:
        raise credentials_exception
    token_data = TokenData(username=username)

    async with db_session() as session:
        user = await get_user(session, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user
2023-11-06 17:04:17 +05:30
async def authenticate_user(username: str, password: str):
    """Check a username/password pair against the database.

    Returns:
        The matching ``User`` when it exists and *password* verifies,
        otherwise ``False``.
    """
    async with db_session() as session:
        candidate = await get_user(session, username)
        if candidate and verify_password(candidate, password):
            return candidate
        return False
2023-11-06 17:04:17 +05:30
# async def get_current_active_user(
# current_user: Annotated[UserRead, Depends(get_current_user)]):
# if current_user.disabled:
# raise HTTPException(status_code=400, detail="Inactive user")
# return current_user
2023-11-06 17:04:17 +05:30
def create_access_token(data: dict, expires_delta: timedelta):
    """Return a signed JWT carrying *data* plus an ``exp`` claim.

    Args:
        data: claims to encode, e.g. ``{"sub": username}`` (the claim
            ``get_current_user`` reads back). It is not modified; a copy is
            encoded.
        expires_delta: token lifetime, counted from now.

    Returns:
        The encoded token, signed with ``conf.crypto.secret`` using
        ``conf.crypto.algorithm``.
    """
    from datetime import timezone

    to_encode = data.copy()
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12
    # and returns a naive value. The old naive value was already UTC, so the
    # encoded ``exp`` timestamp should be unchanged.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode,
                      conf.crypto.secret,
                      algorithm=conf.crypto.algorithm)