Use experimental pydantic, sqlmodel 2 and sqlalchemy 2
JWT based user auth pydantic_settings conf
This commit is contained in:
parent
3355b9d716
commit
90091e8a25
14 changed files with 840 additions and 237 deletions
1
src/_version.py
Normal file
1
src/_version.py
Normal file
|
@ -0,0 +1 @@
|
|||
__version__ = '2023.3+d20231113'
|
45
src/api.py
45
src/api.py
|
@ -1,7 +1,12 @@
|
|||
import logging
|
||||
from datetime import timedelta
|
||||
from time import time
|
||||
from uuid import uuid1
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi import Depends, FastAPI, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
@ -17,22 +22,38 @@ from .models.category import (
|
|||
CategoryGroup, CategoryModelType,
|
||||
Category, CategoryRead
|
||||
)
|
||||
from .models.bootstrap import BootstrapData
|
||||
from .database import get_db_session, pandas_query
|
||||
from .security import (
|
||||
User, Token,
|
||||
authenticate_user, get_current_active_user, create_access_token,
|
||||
User as UserAuth,
|
||||
Token,
|
||||
authenticate_user, get_current_user, create_access_token,
|
||||
)
|
||||
from .config import conf
|
||||
|
||||
api = FastAPI()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
api = FastAPI()
|
||||
api.add_middleware(SessionMiddleware, secret_key=conf.crypto.secret)
|
||||
|
||||
db_session = Annotated[AsyncSession, Depends(get_db_session)]
|
||||
|
||||
@api.get("/nothing")
|
||||
async def get_nothing() -> str:
|
||||
return ''
|
||||
|
||||
@api.post("/token", response_model=Token)
|
||||
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
|
||||
@api.get('/bootstrap')
|
||||
async def bootstrap(
|
||||
user: Annotated[UserRead, Depends(get_current_user)]) -> BootstrapData:
|
||||
return BootstrapData(user=user)
|
||||
|
||||
|
||||
@api.post("/token")
|
||||
async def login_for_access_token(
|
||||
db_session: db_session,
|
||||
form_data: OAuth2PasswordRequestForm = Depends()
|
||||
) -> Token:
|
||||
user = await authenticate_user(form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
|
@ -40,16 +61,14 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(
|
|||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(
|
||||
minutes=conf.security['access_token_expire_minutes'])
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username},
|
||||
expires_delta=access_token_expires)
|
||||
expires_delta=timedelta(seconds=conf.crypto.expire))
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
@api.get("/users")
|
||||
async def get_users(
|
||||
*, db_session: AsyncSession = Depends(get_db_session)
|
||||
db_session: db_session,
|
||||
) -> list[UserRead]:
|
||||
query = select(User).options(selectinload(User.roles))
|
||||
data = await db_session.exec(query)
|
||||
|
@ -57,7 +76,7 @@ async def get_users(
|
|||
|
||||
@api.get("/roles")
|
||||
async def get_roles(
|
||||
*, db_session: AsyncSession = Depends(get_db_session)
|
||||
db_session: db_session,
|
||||
) -> list[RoleRead]:
|
||||
query = select(Role).options(selectinload(Role.users))
|
||||
data = await db_session.exec(query)
|
||||
|
@ -65,7 +84,7 @@ async def get_roles(
|
|||
|
||||
@api.get("/categories")
|
||||
async def get_categories(
|
||||
*, db_session: AsyncSession = Depends(get_db_session)
|
||||
db_session: db_session,
|
||||
) -> list[CategoryRead]:
|
||||
query = select(Category)
|
||||
data = await db_session.exec(query)
|
||||
|
@ -74,7 +93,7 @@ async def get_categories(
|
|||
|
||||
@api.get("/categories_p")
|
||||
async def get_categories_p(
|
||||
*, db_session: AsyncSession = Depends(get_db_session)
|
||||
db_session: db_session,
|
||||
) -> list[CategoryRead]:
|
||||
query = select(Category)
|
||||
df = await db_session.run_sync(pandas_query, query)
|
||||
|
|
|
@ -1,5 +1,14 @@
|
|||
from fastapi import Depends, FastAPI
|
||||
from .api import api
|
||||
from fastapi import FastAPI
|
||||
import logging
|
||||
|
||||
app = FastAPI()
|
||||
from .api import api
|
||||
from .config import conf
|
||||
|
||||
logging.basicConfig(level=conf.gisaf.debugLevel)
|
||||
|
||||
app = FastAPI(
|
||||
debug=True,
|
||||
title=conf.gisaf.title,
|
||||
version=conf.version,
|
||||
)
|
||||
app.mount('/v2', api)
|
324
src/config.py
324
src/config.py
|
@ -1,55 +1,303 @@
|
|||
from os import environ
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from typing import Any, Type, Tuple
|
||||
|
||||
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
||||
from sqlalchemy.orm.session import sessionmaker
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource
|
||||
from pydantic.v1.utils import deep_update
|
||||
from yaml import safe_load
|
||||
|
||||
from ._version import __version__
|
||||
#from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
||||
#from sqlalchemy.orm.session import sessionmaker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
ENV = environ.get('env', 'prod')
|
||||
|
||||
config_files = [
|
||||
Path(Path.cwd().root) / 'etc' / 'gisaf' / ENV,
|
||||
Path.home() / '.local' / 'gisaf' / ENV
|
||||
]
|
||||
|
||||
class Config:
|
||||
app: dict
|
||||
postgres: dict
|
||||
storage: dict
|
||||
map: dict
|
||||
security: dict
|
||||
class DashboardHome(BaseSettings):
|
||||
title: str
|
||||
content_file: str
|
||||
footer_file: str
|
||||
|
||||
class GisafConfig(BaseSettings):
|
||||
title: str
|
||||
windowTitle: str
|
||||
debugLevel: str
|
||||
dashboard_home: DashboardHome
|
||||
redirect: str = ''
|
||||
|
||||
class SpatialSysRef(BaseSettings):
|
||||
author: str
|
||||
ellps: str
|
||||
k: int
|
||||
lat_0: float
|
||||
lon_0: float
|
||||
no_defs: bool
|
||||
proj: str
|
||||
towgs84: str
|
||||
units: str
|
||||
x_0: float
|
||||
y_0: float
|
||||
|
||||
class RawSurvey(BaseSettings):
|
||||
spatial_sys_ref: SpatialSysRef
|
||||
srid: int
|
||||
|
||||
class Geo(BaseSettings):
|
||||
raw_survey: RawSurvey
|
||||
simplify_geom_factor: int
|
||||
srid: int
|
||||
srid_for_proj: int
|
||||
|
||||
class Flask(BaseSettings):
|
||||
secret_key: str
|
||||
debug: int
|
||||
|
||||
class MQTT(BaseSettings):
|
||||
broker: str
|
||||
|
||||
class GisafLive(BaseSettings):
|
||||
hostname: str
|
||||
port: int
|
||||
scheme: str
|
||||
redis: str
|
||||
mqtt: MQTT
|
||||
|
||||
class DefaultSurvey(BaseSettings):
|
||||
surveyor_id: int
|
||||
equipment_id: int
|
||||
|
||||
class Survey(BaseSettings):
|
||||
schema_raw: str
|
||||
schema: str
|
||||
default: DefaultSurvey
|
||||
|
||||
class Crypto(BaseSettings):
|
||||
secret: str
|
||||
algorithm: str
|
||||
expire: float
|
||||
|
||||
class DB(BaseSettings):
|
||||
uri: str
|
||||
host: str
|
||||
user: str
|
||||
db: str
|
||||
password: str
|
||||
debug: bool
|
||||
info: bool
|
||||
pool_size: int = 10
|
||||
max_overflow: int = 10
|
||||
|
||||
class Log(BaseSettings):
|
||||
level: str
|
||||
|
||||
class OGCAPILicense(BaseSettings):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
class OGCAPIProvider(BaseSettings):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
class OGCAPIServerContact(BaseSettings):
|
||||
name: str
|
||||
address: str
|
||||
city: str
|
||||
stateorprovince: str
|
||||
postalcode: int
|
||||
country: str
|
||||
email: str
|
||||
|
||||
class OGCAPIIdentification(BaseSettings):
|
||||
title: str
|
||||
description: str
|
||||
keywords: list[str]
|
||||
keywords_type: str
|
||||
terms_of_service: str
|
||||
url: str
|
||||
|
||||
class OGCAPIMetadata(BaseSettings):
|
||||
identification: OGCAPIIdentification
|
||||
license: OGCAPILicense
|
||||
provider: OGCAPIProvider
|
||||
contact: OGCAPIServerContact
|
||||
|
||||
class ServerBind(BaseSettings):
|
||||
host: str
|
||||
port: int
|
||||
|
||||
class OGCAPIServerMap(BaseSettings):
|
||||
url: str
|
||||
attribution: str
|
||||
|
||||
class OGCAPIServer(BaseSettings):
|
||||
bind: ServerBind
|
||||
url: str
|
||||
mimetype: str
|
||||
encoding: str
|
||||
language: str
|
||||
pretty_print: bool
|
||||
limit: int
|
||||
map: OGCAPIServerMap
|
||||
|
||||
class OGCAPI(BaseSettings):
|
||||
base_url: str
|
||||
bbox: list[float]
|
||||
log: Log
|
||||
metadata: OGCAPIMetadata
|
||||
server: OGCAPIServer
|
||||
|
||||
class Map(BaseSettings):
|
||||
tilesBaseDir: str
|
||||
tilesUseRequestUrl: bool
|
||||
tilesSpriteBaseDir: str
|
||||
tilesSpriteUrl: str
|
||||
tilesSpriteBaseUrl: str
|
||||
openMapTilesKey: str
|
||||
zoom: int
|
||||
pitch: int
|
||||
lat: float
|
||||
lng: float
|
||||
bearing: float
|
||||
style: str
|
||||
opacity: float
|
||||
attribution: str
|
||||
status: list[str]
|
||||
defaultStatus: list[str] # FIXME: should be str
|
||||
tagKeys: list[str]
|
||||
|
||||
class Measures(BaseSettings):
|
||||
defaultStore: str
|
||||
|
||||
class BasketDefault(BaseSettings):
|
||||
surveyor: str
|
||||
equipment: str
|
||||
project: str | None
|
||||
status: str
|
||||
store: str | None
|
||||
|
||||
class BasketOldDef(BaseSettings):
|
||||
base_dir: str
|
||||
|
||||
class Basket(BaseSettings):
|
||||
base_dir: str
|
||||
default: BasketDefault
|
||||
|
||||
class Plot(BaseSettings):
|
||||
maxDataSize: int
|
||||
|
||||
class Dashboard(BaseSettings):
|
||||
base_source_url: str
|
||||
base_storage_dir: str
|
||||
base_storage_url: str
|
||||
|
||||
class Widgets(BaseSettings):
|
||||
footer: str
|
||||
|
||||
class Admin(BaseSettings):
|
||||
basket: Basket
|
||||
|
||||
class Attachments(BaseSettings):
|
||||
base_dir: str
|
||||
|
||||
class Job(BaseSettings):
|
||||
id: str
|
||||
func: str
|
||||
trigger: str
|
||||
minutes: int | None = 0
|
||||
seconds: int | None = 0
|
||||
|
||||
class Config(BaseSettings):
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: Type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
||||
return env_settings, init_settings, file_secret_settings, config_file_settings
|
||||
|
||||
admin: Admin
|
||||
attachments: Attachments
|
||||
basket: BasketOldDef
|
||||
crypto: Crypto
|
||||
dashboard: Dashboard
|
||||
db: DB
|
||||
flask: Flask
|
||||
geo: Geo
|
||||
gisaf: GisafConfig
|
||||
gisaf_live: GisafLive
|
||||
jobs: list[Job]
|
||||
map: Map
|
||||
measures: Measures
|
||||
ogcapi: OGCAPI
|
||||
plot: Plot
|
||||
plugins: dict[str, dict[str, Any]]
|
||||
survey: Survey
|
||||
version: str
|
||||
engine: AsyncEngine
|
||||
session_maker: sessionmaker
|
||||
|
||||
def __init__(self) -> None:
|
||||
from ._version import __version__
|
||||
self.version = __version__
|
||||
weather_station: dict[str, dict[str, Any]]
|
||||
widgets: Widgets
|
||||
#engine: AsyncEngine
|
||||
#session_maker: sessionmaker
|
||||
|
||||
|
||||
conf = Config()
|
||||
def config_file_settings() -> dict[str, Any]:
|
||||
config: dict[str, Any] = {}
|
||||
for p in config_files:
|
||||
for suffix in {".yaml", ".yml"}:
|
||||
path = p.with_suffix(suffix)
|
||||
if not path.is_file():
|
||||
logger.info(f"No file found at `{path.resolve()}`")
|
||||
continue
|
||||
logger.debug(f"Reading config file `{path.resolve()}`")
|
||||
if path.suffix in {".yaml", ".yml"}:
|
||||
config = deep_update(config, load_yaml(path))
|
||||
else:
|
||||
logger.info(f"Unknown config file extension `{path.suffix}`")
|
||||
return config
|
||||
|
||||
|
||||
def set_app_config(app) -> None:
|
||||
raw_configs = []
|
||||
with open(Path(__file__).parent / 'defaults.yml') as cf:
|
||||
raw_configs.append(cf.read())
|
||||
for cf_path in (
|
||||
Path(Path.cwd().root) / 'etc' / 'gisaf' / ENV,
|
||||
Path.home() / '.local' / 'gisaf' / ENV
|
||||
):
|
||||
try:
|
||||
with open(cf_path.with_suffix('.yml')) as cf:
|
||||
raw_configs.append(cf.read())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
def load_yaml(path: Path) -> dict[str, Any]:
|
||||
with Path(path).open("r") as f:
|
||||
config = safe_load(f)
|
||||
if not isinstance(config, dict):
|
||||
raise TypeError(
|
||||
f"Config file has no top-level mapping: {path}"
|
||||
)
|
||||
return config
|
||||
|
||||
yaml_config = yaml.safe_load('\n'.join(raw_configs))
|
||||
|
||||
conf.app = yaml_config['app']
|
||||
conf.postgres = yaml_config['postgres']
|
||||
conf.storage = yaml_config['storage']
|
||||
conf.map = yaml_config['map']
|
||||
conf.security = yaml_config['security']
|
||||
# create_dirs()
|
||||
conf = Config(version=__version__)
|
||||
|
||||
# def set_app_config(app) -> None:
|
||||
# raw_configs = []
|
||||
# with open(Path(__file__).parent / 'defaults.yml') as cf:
|
||||
# raw_configs.append(cf.read())
|
||||
# for cf_path in (
|
||||
# Path(Path.cwd().root) / 'etc' / 'gisaf' / ENV,
|
||||
# Path.home() / '.local' / 'gisaf' / ENV
|
||||
# ):
|
||||
# try:
|
||||
# with open(cf_path.with_suffix('.yml')) as cf:
|
||||
# raw_configs.append(cf.read())
|
||||
# except FileNotFoundError:
|
||||
# pass
|
||||
|
||||
# yaml_config = safe_load('\n'.join(raw_configs))
|
||||
|
||||
# conf.app = yaml_config['app']
|
||||
# conf.postgres = yaml_config['postgres']
|
||||
# conf.storage = yaml_config['storage']
|
||||
# conf.map = yaml_config['map']
|
||||
# conf.security = yaml_config['security']
|
||||
# # create_dirs()
|
||||
|
||||
|
||||
# def create_dirs():
|
||||
|
@ -65,5 +313,5 @@ def set_app_config(app) -> None:
|
|||
# get_cache_dir().mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_cache_dir() -> Path:
|
||||
return Path(conf.storage['root_cache_path'])
|
||||
# def get_cache_dir() -> Path:
|
||||
# return Path(conf.storage['root_cache_path'])
|
|
@ -1,14 +1,28 @@
|
|||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from .config import conf
|
||||
|
||||
echo = False
|
||||
pg_url = "postgresql+asyncpg://avgis@localhost/avgis"
|
||||
|
||||
engine = create_async_engine(pg_url, echo=echo)
|
||||
engine = create_async_engine(
|
||||
pg_url,
|
||||
echo=echo,
|
||||
pool_size=conf.db.pool_size,
|
||||
max_overflow=conf.db.max_overflow,
|
||||
)
|
||||
|
||||
async def get_db_session():
|
||||
async def get_db_session() -> AsyncSession:
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
|
||||
@asynccontextmanager
|
||||
async def db_session() -> AsyncSession:
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from sqlmodel import Field, SQLModel, MetaData, Relationship
|
||||
|
||||
schema = 'gisaf_admin'
|
||||
metadata = MetaData(schema=schema)
|
||||
from .metadata import gisaf_admin
|
||||
|
||||
class UserRoleLink(SQLModel, table=True):
|
||||
metadata = metadata
|
||||
metadata = gisaf_admin
|
||||
__tablename__: str = 'roles_users'
|
||||
user_id: int | None = Field(
|
||||
default=None, foreign_key="user.id", primary_key=True
|
||||
|
@ -20,7 +19,7 @@ class UserBase(SQLModel):
|
|||
|
||||
|
||||
class User(UserBase, table=True):
|
||||
metadata = metadata
|
||||
metadata = gisaf_admin
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
roles: list["Role"] = Relationship(back_populates="users",
|
||||
link_model=UserRoleLink)
|
||||
|
@ -34,7 +33,7 @@ class RoleWithDescription(RoleBase):
|
|||
description: str | None
|
||||
|
||||
class Role(RoleWithDescription, table=True):
|
||||
metadata = metadata
|
||||
metadata = gisaf_admin
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
users: list[User] = Relationship(back_populates="roles",
|
||||
link_model=UserRoleLink)
|
||||
|
|
17
src/models/bootstrap.py
Normal file
17
src/models/bootstrap.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
from sqlmodel import Field, SQLModel, MetaData, JSON, TEXT, Relationship, Column
|
||||
from ..config import conf, Map, Measures, Geo
|
||||
from .authentication import UserRead
|
||||
|
||||
class Proj(SQLModel):
|
||||
srid: str
|
||||
srid_for_proj: str
|
||||
|
||||
class BootstrapData(SQLModel):
|
||||
version: str = conf.version
|
||||
title: str = conf.gisaf.title
|
||||
windowTitle: str = conf.gisaf.windowTitle
|
||||
map: Map = conf.map
|
||||
geo: Geo = conf.geo
|
||||
measures: Measures = conf.measures
|
||||
redirect: str = conf.gisaf.redirect
|
||||
user: UserRead | None = None
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Any
|
||||
from sqlmodel import Field, SQLModel, MetaData, JSON, TEXT, Relationship, Column
|
||||
from pydantic import computed_field
|
||||
|
||||
schema = 'gisaf_survey'
|
||||
metadata = MetaData(schema=schema)
|
||||
from .metadata import gisaf_survey
|
||||
|
||||
mapbox_type_mapping = {
|
||||
'Point': 'symbol',
|
||||
|
@ -11,7 +11,7 @@ mapbox_type_mapping = {
|
|||
}
|
||||
|
||||
class CategoryGroup(SQLModel, table=True):
|
||||
metadata = metadata
|
||||
metadata = gisaf_survey
|
||||
name: str = Field(min_length=4, max_length=4,
|
||||
default=None, primary_key=True)
|
||||
major: str
|
||||
|
@ -23,7 +23,7 @@ class CategoryGroup(SQLModel, table=True):
|
|||
|
||||
|
||||
class CategoryModelType(SQLModel, table=True):
|
||||
metadata = metadata
|
||||
metadata = gisaf_survey
|
||||
name: str = Field(default=None, primary_key=True)
|
||||
|
||||
class Admin:
|
||||
|
@ -32,8 +32,6 @@ class CategoryModelType(SQLModel, table=True):
|
|||
|
||||
|
||||
class CategoryBase(SQLModel):
|
||||
metadata = metadata
|
||||
|
||||
class Admin:
|
||||
menu = 'Other'
|
||||
flask_admin_model_view = 'CategoryModelView'
|
||||
|
@ -49,35 +47,39 @@ class CategoryBase(SQLModel):
|
|||
custom: bool | None
|
||||
auto_import: bool = True
|
||||
model_type: str = Field(max_length=50,
|
||||
foreign_key='CategoryModelType.name', default='Point')
|
||||
foreign_key='CategoryModelType.name',
|
||||
default='Point')
|
||||
long_name: str | None = Field(max_length=50)
|
||||
style: str | None = Field(sa_column=Column(TEXT))
|
||||
symbol: str | None = Field(max_length=1)
|
||||
mapbox_type_custom: str | None = Field(max_length=32)
|
||||
mapbox_paint: dict[str, Any] | None = Field(sa_column=Column(JSON, none_as_null=True))
|
||||
mapbox_layout: dict[str, Any] | None = Field(sa_column=Column(JSON, none_as_null=True))
|
||||
mapbox_paint: dict[str, Any] | None = Field(sa_column=Column(JSON(none_as_null=True)))
|
||||
mapbox_layout: dict[str, Any] | None = Field(sa_column=Column(JSON(none_as_null=True)))
|
||||
viewable_role: str | None
|
||||
extra: dict[str, Any] | None = Field(sa_column=Column(JSON, none_as_null=True))
|
||||
extra: dict[str, Any] | None = Field(sa_column=Column(JSON(none_as_null=True)))
|
||||
|
||||
|
||||
class Category(CategoryBase, table=True):
|
||||
metadata = gisaf_survey
|
||||
name: str = Field(default=None, primary_key=True)
|
||||
|
||||
|
||||
class CategoryRead(CategoryBase):
|
||||
name: str
|
||||
domain = 'V' # Survey
|
||||
domain: str = 'V' # Survey
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def layer_name(self):
|
||||
def layer_name(self) -> str:
|
||||
"""
|
||||
ISO compliant layer name (see ISO 13567)
|
||||
:return: str
|
||||
"""
|
||||
return '{self.domain}-{self.group:4s}-{self.minor_group_1:4s}-{self.minor_group_2:4s}-{self.status:1s}'.format(self=self)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def table_name(self):
|
||||
def table_name(self) -> str:
|
||||
"""
|
||||
Table name
|
||||
:return:
|
||||
|
@ -87,8 +89,9 @@ class CategoryRead(CategoryBase):
|
|||
else:
|
||||
return '{self.domain}_{self.group:4s}_{self.minor_group_1:4s}_{self.minor_group_2:4s}'.format(self=self)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def raw_survey_table_name(self):
|
||||
def raw_survey_table_name(self) -> str:
|
||||
"""
|
||||
Table name
|
||||
:return:
|
||||
|
@ -98,6 +101,7 @@ class CategoryRead(CategoryBase):
|
|||
else:
|
||||
return 'RAW_{self.domain}_{self.group:4s}_{self.minor_group_1:4s}_{self.minor_group_2:4s}'.format(self=self)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def mapbox_type(self):
|
||||
def mapbox_type(self) -> str:
|
||||
return self.mapbox_type_custom or mapbox_type_mapping[self.model_type]
|
||||
|
|
5
src/models/metadata.py
Normal file
5
src/models/metadata.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
from sqlmodel import MetaData
|
||||
|
||||
gisaf = MetaData(schema='gisaf')
|
||||
gisaf_survey = MetaData(schema='gisaf_survey')
|
||||
gisaf_admin= MetaData(schema='gisaf_admin')
|
9
src/models/tags.py
Normal file
9
src/models/tags.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
from typing import Any
|
||||
from sqlmodel import Field, SQLModel, MetaData, JSON, TEXT, Relationship, Column
|
||||
from pydantic import computed_field
|
||||
|
||||
from .metadata import gisaf
|
||||
from .models_base import GeoPointModel
|
||||
|
||||
class Tags(GeoPointModel, table=True):
|
||||
metadata = gisaf
|
164
src/security.py
164
src/security.py
|
@ -1,21 +1,26 @@
|
|||
from datetime import datetime, timedelta
|
||||
from passlib.context import CryptContext
|
||||
import logging
|
||||
from typing import Annotated
|
||||
#from passlib.context import CryptContext
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from passlib.context import CryptContext
|
||||
from passlib.exc import UnknownHashError
|
||||
from pydantic import BaseModel
|
||||
from jose import JWTError, jwt
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from jose import JWTError, jwt, ExpiredSignatureError
|
||||
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from .config import conf
|
||||
from .models.authentication import User as UserInDB
|
||||
from .database import db_session
|
||||
from .models.authentication import User, UserRead
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# openssl rand -hex 32
|
||||
# import secrets
|
||||
# SECRET_KEY = secrets.token_hex(32)
|
||||
ALGORITHM: str = "HS256"
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
|
@ -27,105 +32,118 @@ class TokenData(BaseModel):
|
|||
username: str | None = None
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
username: str
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
disabled: bool | None = None
|
||||
# class User(BaseModel):
|
||||
# username: str
|
||||
# email: str | None = None
|
||||
# full_name: str | None = None
|
||||
# disabled: bool | None = None
|
||||
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||
|
||||
|
||||
def get_password_hash(password: str):
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
async def delete_user(username):
|
||||
async with conf.session_maker.begin() as session:
|
||||
user_in_db: UserInDB | None = await get_user(username)
|
||||
if user_in_db is None:
|
||||
raise SystemExit(f'User {username} does not exist in the database')
|
||||
await session.delete(user_in_db)
|
||||
async def delete_user(session: AsyncSession, username: str) -> None:
|
||||
user_in_db: User | None = await get_user(session, username)
|
||||
if user_in_db is None:
|
||||
raise SystemExit(f'User {username} does not exist in the database')
|
||||
await session.delete(user_in_db)
|
||||
|
||||
|
||||
async def enable_user(username, enable=True):
|
||||
async with conf.session_maker.begin() as session:
|
||||
user_in_db: UserInDB | None = await get_user(username)
|
||||
if user_in_db is None:
|
||||
raise SystemExit(f'User {username} does not exist in the database')
|
||||
user_in_db.disabled = not enable # type: ignore
|
||||
session.add(user_in_db)
|
||||
await session.commit()
|
||||
async def enable_user(session: AsyncSession, username: str, enable=True):
|
||||
user_in_db: UserRead | None = await get_user(session, username)
|
||||
if user_in_db is None:
|
||||
raise SystemExit(f'User {username} does not exist in the database')
|
||||
user_in_db.disabled = not enable # type: ignore
|
||||
session.add(user_in_db)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def create_user(username: str, password: str, full_name: str,
|
||||
async def create_user(session: AsyncSession, username: str, password: str, full_name: str,
|
||||
email: str, **kwargs):
|
||||
async with conf.session_maker.begin() as session:
|
||||
user_in_db: UserInDB | None = await get_user(username)
|
||||
if user_in_db is None:
|
||||
user = UserInDB(
|
||||
username=username,
|
||||
password=get_password_hash(password),
|
||||
full_name=full_name,
|
||||
email=email,
|
||||
disabled=False
|
||||
)
|
||||
session.add(user)
|
||||
else:
|
||||
user_in_db.full_name = full_name # type: ignore
|
||||
user_in_db.email = email # type: ignore
|
||||
user_in_db.password = get_password_hash(password) # type: ignore
|
||||
await session.commit()
|
||||
user_in_db: User | None = await get_user(session, username)
|
||||
if user_in_db is None:
|
||||
user = User(
|
||||
username=username,
|
||||
password=get_password_hash(password),
|
||||
full_name=full_name,
|
||||
email=email,
|
||||
disabled=False
|
||||
)
|
||||
session.add(user)
|
||||
else:
|
||||
user_in_db.full_name = full_name # type: ignore
|
||||
user_in_db.email = email # type: ignore
|
||||
user_in_db.password = get_password_hash(password) # type: ignore
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def get_user(username: str) -> (UserInDB | None):
|
||||
async with conf.session_maker.begin() as session:
|
||||
req = await session.execute(select(UserInDB).where(UserInDB.username==username))
|
||||
return req.scalar()
|
||||
async def get_user(
|
||||
session: AsyncSession,
|
||||
username: str) -> (User | None):
|
||||
query = select(User).where(User.username==username).options(selectinload(User.roles))
|
||||
data = await session.exec(query)
|
||||
return data.scalar()
|
||||
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
def verify_password(user: User, plain_password):
|
||||
try:
|
||||
return pwd_context.verify(plain_password, user.password)
|
||||
except UnknownHashError:
|
||||
logger.warning(f'Password not encrypted in DB for {user.username}, assuming it is stored in plain text')
|
||||
return plain_password == user.password
|
||||
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme)) -> UserRead | None:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
expired_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
if token is None:
|
||||
return None
|
||||
try:
|
||||
payload = jwt.decode(token, conf.security['secret_key'], algorithms=[ALGORITHM])
|
||||
payload = jwt.decode(token, conf.crypto.secret,
|
||||
algorithms=[conf.crypto.algorithm])
|
||||
username: str = payload.get("sub", '')
|
||||
if username == '':
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except ExpiredSignatureError:
|
||||
raise expired_exception
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
user = await get_user(username=token_data.username) # type: ignore
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return User(username=user.username, # type: ignore
|
||||
email=user.email, # type: ignore
|
||||
full_name=user.full_name) # type: ignore
|
||||
async with db_session() as session:
|
||||
user = await get_user(session, username=token_data.username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
async def authenticate_user(username: str, password: str):
|
||||
user = await get_user(username)
|
||||
if not user:
|
||||
return False
|
||||
if not verify_password(password, user.password):
|
||||
return False
|
||||
return user
|
||||
async with db_session() as session:
|
||||
user = await get_user(session, username)
|
||||
if not user:
|
||||
return False
|
||||
if not verify_password(user, password):
|
||||
return False
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(current_user: User = Depends(get_current_user)):
|
||||
if current_user.disabled:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
# async def get_current_active_user(
|
||||
# current_user: Annotated[UserRead, Depends(get_current_user)]):
|
||||
# if current_user.disabled:
|
||||
# raise HTTPException(status_code=400, detail="Inactive user")
|
||||
# return current_user
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta):
|
||||
|
@ -133,6 +151,6 @@ def create_access_token(data: dict, expires_delta: timedelta):
|
|||
expire = datetime.utcnow() + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode,
|
||||
conf.security['secret_key'],
|
||||
algorithm=ALGORITHM)
|
||||
conf.crypto.secret,
|
||||
algorithm=conf.crypto.algorithm)
|
||||
return encoded_jwt
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue