diff --git a/src/gisaf/geoapi.py b/src/gisaf/geoapi.py index 73ca32f..cade3b2 100644 --- a/src/gisaf/geoapi.py +++ b/src/gisaf/geoapi.py @@ -7,12 +7,14 @@ import logging from typing import Annotated from asyncio import CancelledError -from fastapi import (FastAPI, HTTPException, Response, Header, WebSocket, WebSocketDisconnect, +from fastapi import (Depends, FastAPI, HTTPException, Response, Header, WebSocket, WebSocketDisconnect, status, responses) +from gisaf.models.authentication import User from gisaf.redis_tools import store as redis_store from gisaf.live import live_server from gisaf.registry import registry +from gisaf.security import get_current_active_user, can_view logger = logging.getLogger(__name__) @@ -72,6 +74,7 @@ async def live_layer(store: str, websocket: WebSocket): @api.get('/{store_name}') async def get_geojson(store_name, + user: User = Depends(get_current_active_user), If_None_Match: Annotated[str | None, Header()] = None, simplify: Annotated[float | None, Header()] = 50.0, ): @@ -86,8 +89,10 @@ async def get_geojson(store_name, except KeyError: raise HTTPException(status.HTTP_404_NOT_FOUND) - if hasattr(model, 'viewable_role') and model.viewable_role: - await check_permission(request, model.viewable_role) + if hasattr(model, 'viewable_role'): + if not(user and user.can_view(model)): + logger.info(f'{user.username if user else "Anonymous"} tried to access {model}') + raise HTTPException(status.HTTP_401_UNAUTHORIZED) if await redis_store.has_channel(store_name): ## Live layers diff --git a/src/gisaf/models/authentication.py b/src/gisaf/models/authentication.py index 62d7875..7138dea 100644 --- a/src/gisaf/models/authentication.py +++ b/src/gisaf/models/authentication.py @@ -16,6 +16,7 @@ class UserRoleLink(SQLModel, table=True): class UserBase(SQLModel): username: str email: str + disabled: bool | None = False class User(UserBase, table=True): @@ -25,6 +26,12 @@ class User(UserBase, table=True): link_model=UserRoleLink) password: str | None = None + def can_view(self, model) -> bool: + if hasattr(model, 'viewable_role'): + return model.viewable_role in (role.name for role in self.roles) + else: + return True + class RoleBase(SQLModel): name: str = Field(unique=True) diff --git a/src/gisaf/security.py b/src/gisaf/security.py index eeacfc8..79011c7 100644 --- a/src/gisaf/security.py +++ b/src/gisaf/security.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta import logging +from typing import Annotated from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer @@ -130,18 +131,18 @@ async def get_current_user( async def authenticate_user(username: str, password: str): async with db_session() as session: user = await get_user(session, username) - if not user: + if not user or user.disabled: return False if not verify_password(user, password): return False return user -# async def get_current_active_user( -# current_user: Annotated[UserRead, Depends(get_current_user)]): -# if current_user.disabled: -# raise HTTPException(status_code=400, detail="Inactive user") -# return current_user +async def get_current_active_user( + current_user: Annotated[UserRead, Depends(get_current_user)]): + if current_user is not None and current_user.disabled: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user def create_access_token(data: dict, expires_delta: timedelta): @@ -151,4 +152,4 @@ def create_access_token(data: dict, expires_delta: timedelta): encoded_jwt = jwt.encode(to_encode, conf.crypto.secret, algorithm=conf.crypto.algorithm) - return encoded_jwt + return encoded_jwt \ No newline at end of file