gisaf-backend/src/gisaf/security.py
phil 741050db89 Remove relative imports
Fix primary keys (optional)
Add baskets, importers, plugins, reactor
Add fake replacement for graphql defs (to_migrate)
Add typing marker (py.typed)
2023-12-25 15:50:45 +05:30

154 lines
4.9 KiB
Python

import hmac
import logging
from datetime import datetime, timedelta, timezone

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt, ExpiredSignatureError
from passlib.context import CryptContext
from passlib.exc import UnknownHashError
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlmodel.ext.asyncio.session import AsyncSession

from gisaf.config import conf
from gisaf.database import db_session
from gisaf.models.authentication import User, UserRead
# Module-level logger for authentication / security events.
logger = logging.getLogger(__name__)
# Password hashing context. bcrypt is the only configured scheme, so a stored
# value that passlib cannot identify as a bcrypt hash makes ``verify`` raise
# UnknownHashError (verify_password treats that case as a legacy plain-text
# password).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Pydantic model pairing an encoded access token with its token type.
# NOTE(review): presumably the response body of the login endpoint advertised
# as ``tokenUrl`` in oauth2_scheme, with token_type e.g. "bearer" — confirm.
class Token(BaseModel):
    access_token: str
    token_type: str
# Data extracted from a decoded JWT: ``username`` carries the token's ``sub``
# claim (populated in get_current_user).
class TokenData(BaseModel):
    username: str | None = None
# class User(BaseModel):
# username: str
# email: str | None = None
# full_name: str | None = None
# disabled: bool | None = None
# OAuth2 password-flow bearer token extractor. With auto_error=False a request
# without a bearer token yields None instead of an automatic 401, which is why
# get_current_user checks for ``token is None`` and returns None (anonymous).
# tokenUrl is the login URL advertised in the generated OpenAPI schema.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def get_password_hash(password: str):
    """Hash a clear-text password with the module's passlib context.

    :param password: the clear-text password.
    :return: the hash string produced by ``pwd_context.hash``, i.e. the value
        stored in ``User.password``.
    """
    hashed = pwd_context.hash(password)
    return hashed
async def delete_user(session: AsyncSession, username: str) -> None:
    """Delete the user named *username* and commit the change.

    :param session: open async database session.
    :param username: login name of the user to delete.
    :raises SystemExit: if no user with that name exists.
    """
    user_in_db: User | None = await get_user(session, username)
    if user_in_db is None:
        raise SystemExit(f'User {username} does not exist in the database')
    await session.delete(user_in_db)
    # Commit like enable_user / create_user do: without it the deletion is
    # discarded if the session is closed (rolled back) without a commit.
    await session.commit()
async def enable_user(session: AsyncSession, username: str, enable: bool = True):
    """Enable (or, with ``enable=False``, disable) a user and commit.

    :param session: open async database session.
    :param username: login name of the user to update.
    :param enable: True to clear the ``disabled`` flag, False to set it.
    :raises SystemExit: if no user with that name exists.
    """
    # get_user returns the ORM ``User`` (not ``UserRead``); annotating it
    # correctly removes the need for a ``type: ignore`` on the assignment.
    user_in_db: User | None = await get_user(session, username)
    if user_in_db is None:
        raise SystemExit(f'User {username} does not exist in the database')
    user_in_db.disabled = not enable
    session.add(user_in_db)
    await session.commit()
async def create_user(session: AsyncSession, username: str, password: str, full_name: str,
                      email: str, **kwargs):
    """Create *username*, or overwrite an existing user's details, then commit.

    A newly created user starts enabled (``disabled=False``). For an existing
    user only ``full_name``, ``email`` and the password hash are replaced; its
    ``disabled`` flag is left untouched. Extra keyword arguments are accepted
    and ignored.
    """
    existing: User | None = await get_user(session, username)
    if existing is not None:
        # Loaded through this session, so plain attribute updates are enough.
        existing.full_name = full_name  # type: ignore
        existing.email = email  # type: ignore
        existing.password = get_password_hash(password)  # type: ignore
    else:
        new_user = User(
            username=username,
            password=get_password_hash(password),
            full_name=full_name,
            email=email,
            disabled=False,
        )
        session.add(new_user)
    await session.commit()
async def get_user(
        session: AsyncSession,
        username: str) -> (User | None):
    """Fetch a user by login name, eagerly loading its ``roles`` relationship.

    :param session: open async database session.
    :param username: login name to look up.
    :return: the first matching ``User``, or None when no row matches.
    """
    statement = (
        select(User)
        .where(User.username == username)
        .options(selectinload(User.roles))
    )
    result = await session.exec(statement)
    return result.scalar()
def verify_password(user: User, plain_password: str):
    """Check *plain_password* against the password stored for *user*.

    Stored values that passlib cannot identify as a hash are treated as legacy
    plain-text passwords: a warning is logged and they are compared directly.

    :param user: user whose stored ``password`` is checked.
    :param plain_password: the clear-text password supplied by the client.
    :return: True if the password matches, False otherwise.
    """
    try:
        return pwd_context.verify(plain_password, user.password)
    except UnknownHashError:
        logger.warning('Password not encrypted in DB for %s, '
                       'assuming it is stored in plain text', user.username)
        # Constant-time comparison so response timing does not reveal how much
        # of the stored password matched. Encode first: compare_digest only
        # accepts str arguments that are pure ASCII.
        return hmac.compare_digest(plain_password.encode('utf-8'),
                                   user.password.encode('utf-8'))
async def get_current_user(
        token: str | None = Depends(oauth2_scheme)) -> UserRead | None:
    """Resolve the user identified by a bearer JWT (FastAPI dependency).

    Returns None when the request carries no bearer token, since oauth2_scheme
    is configured with ``auto_error=False``.

    :raises HTTPException: 401 "Token has expired" for an expired token;
        401 "Could not validate credentials" for an invalid token, a token
        without a usable string ``sub`` claim, or a ``sub`` matching no user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    expired_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Token has expired",
        headers={"WWW-Authenticate": "Bearer"},
    )
    if token is None:
        return None
    try:
        payload = jwt.decode(token, conf.crypto.secret,
                             algorithms=[conf.crypto.algorithm])
    except ExpiredSignatureError:
        # Must be caught before JWTError, of which it is a subclass.
        raise expired_exception
    except JWTError:
        raise credentials_exception
    username = payload.get("sub")
    # Reject a missing, null, empty or non-string subject up front. A non-string
    # value could otherwise fail TokenData validation and surface as a 500
    # instead of a 401.
    if not isinstance(username, str) or not username:
        raise credentials_exception
    token_data = TokenData(username=username)
    async with db_session() as session:
        user = await get_user(session, username=token_data.username)
    if user is None:
        raise credentials_exception
    # NOTE(review): the ``disabled`` flag (set by enable_user) is not checked
    # here, so tokens issued to disabled users keep working — confirm intent.
    return user
async def authenticate_user(username: str, password: str):
    """Check a username / password pair against the database.

    :param username: login name supplied by the client.
    :param password: clear-text password supplied by the client.
    :return: the ``User`` on success; False if the user does not exist or the
        password does not match.
    """
    async with db_session() as session:
        user = await get_user(session, username)
    if not user:
        # Spend about as long as a real bcrypt verification, so response time
        # does not reveal whether the username exists.
        pwd_context.dummy_verify()
        return False
    if not verify_password(user, password):
        return False
    # NOTE(review): ``user.disabled`` is not checked here — confirm that the
    # caller rejects disabled accounts.
    return user
# async def get_current_active_user(
# current_user: Annotated[UserRead, Depends(get_current_user)]):
# if current_user.disabled:
# raise HTTPException(status_code=400, detail="Inactive user")
# return current_user
def create_access_token(data: dict, expires_delta: timedelta):
    """Encode *data* as a signed JWT that expires *expires_delta* from now.

    A shallow copy of *data* is extended with an ``exp`` claim (overwriting any
    existing one); the caller's dict is not modified. The token is signed with
    ``conf.crypto.secret`` using ``conf.crypto.algorithm``.

    :return: the encoded JWT string.
    """
    to_encode = data.copy()
    # Timezone-aware UTC instead of the deprecated, naive datetime.utcnow().
    # python-jose converts datetime claims via utctimetuple(), so the encoded
    # ``exp`` value is the same.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode,
                      conf.crypto.secret,
                      algorithm=conf.crypto.algorithm)