Files
CosScene/server/app/core/deps.py
T
2026-05-09 16:40:29 +08:00

70 lines
2.1 KiB
Python

from collections.abc import AsyncGenerator
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import decode_token
from app.db.session import async_session_factory
from app.models.user import User
# Bearer-token extractor shared by the authenticated dependencies below.
# ``tokenUrl`` is the login route advertised to OpenAPI clients; with the
# default ``auto_error=True`` FastAPI rejects a missing/non-Bearer
# Authorization header with 401 before the dependent function runs.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/login",
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session created by ``async_session_factory``.

    The session is used as an async context manager, so it is exited
    (and closed) once the consumer of this generator is finished with it.
    """
    session_context = async_session_factory()
    async with session_context as db_session:
        yield db_session
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve a bearer access token to the ``User`` it identifies.

    Args:
        token: Raw bearer token extracted by ``oauth2_scheme``.
        db: Database session supplied by ``get_db``.

    Returns:
        The ``User`` whose ``id`` equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the token
            cannot be decoded, is not of ``type == "access"``, has a missing
            or non-integer ``sub`` claim, or names a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = decode_token(token)
    except ValueError:
        raise credentials_exception from None

    subject = payload.get("sub")
    # Only access tokens may authenticate a request (e.g. not refresh tokens).
    if subject is None or payload.get("type") != "access":
        raise credentials_exception

    # Convert under a guard: a malformed ``sub`` must produce 401, not an
    # unhandled ValueError/TypeError surfacing as HTTP 500.
    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        raise credentials_exception from None

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise credentials_exception
    return user
async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Return the authenticated user, refusing accounts that are disabled.

    Args:
        current_user: User resolved by ``get_current_user``.

    Returns:
        ``current_user`` when its ``is_active`` flag is truthy.

    Raises:
        HTTPException: 403 when ``current_user.is_active`` is falsy.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="User account is disabled",
    )
async def get_optional_current_user(
    token: str | None = Depends(
        OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login", auto_error=False)
    ),
    db: AsyncSession = Depends(get_db),
) -> User | None:
    """Best-effort variant of ``get_current_user`` for optional authentication.

    The bearer scheme uses ``auto_error=False``, so a request without a
    bearer token reaches this function with ``token`` set to ``None``.

    Args:
        token: Raw bearer token, or ``None`` when the request has none.
        db: Database session supplied by ``get_db``.

    Returns:
        The matching ``User``, or ``None`` when there is no token, the token
        cannot be decoded, it is not an ``access`` token, its ``sub`` claim is
        missing or not an integer, or no user has that id.
    """
    if not token:
        return None

    try:
        payload = decode_token(token)
    except ValueError:
        return None

    subject = payload.get("sub")
    if subject is None or payload.get("type") != "access":
        return None

    # Guard the conversion so a malformed ``sub`` degrades to anonymous
    # (None) instead of raising an unhandled error (HTTP 500).
    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        return None

    result = await db.execute(select(User).where(User.id == user_id))
    # NOTE(review): unlike get_current_active_user, ``is_active`` is not
    # checked here, so a disabled account is still returned — confirm intended.
    return result.scalar_one_or_none()