98 lines
2.8 KiB
Python
98 lines
2.8 KiB
Python
from fastapi import APIRouter, Depends, Query, status
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import func, select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.deps import get_current_active_user, get_db
|
|
from app.models.notification import Notification
|
|
from app.models.user import User
|
|
from app.schemas.common import PageResponse
|
|
|
|
# Router that the notification endpoints below register themselves on.
router = APIRouter()
|
|
|
|
|
|
class NotificationOut(BaseModel):
    """Response schema for a single notification.

    Only ``id``, ``type`` and ``title`` are required; every other field
    falls back to the default declared below when absent.
    """

    # Allow validation straight from attribute-bearing objects (e.g. the
    # ``Notification`` rows returned by the list endpoint), not only mappings.
    model_config = {"from_attributes": True}

    id: int
    type: str
    title: str
    content: str | None = None
    # Optional reference to a related object (a kind plus its id).
    ref_type: str | None = None
    ref_id: int | None = None
    is_read: bool = False
    # NOTE(review): typed as ``str``. If ``Notification.created_at`` is a
    # ``datetime`` attribute, Pydantic v2 will reject it rather than coerce it
    # to a string -- confirm against the ORM model.
    created_at: str | None = None
|
|
|
|
|
|
class UnreadCount(BaseModel):
    """Response schema wrapping a single integer tally of unread notifications."""

    count: int
|
|
|
|
|
|
@router.get("/", response_model=PageResponse[NotificationOut])
async def list_notifications(
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=20, ge=1, le=100),
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of the current user's notifications, newest first.

    Args:
        page: 1-based page number.
        page_size: Maximum number of rows in the page (1-100).
        current_user: Authenticated user whose notifications are listed.
        db: Async database session used for both queries.

    Returns:
        A ``PageResponse`` carrying the total number of notifications owned
        by the user and the rows belonging to the requested page.
    """
    # One ownership predicate shared by the count and the page query so the
    # two can never drift apart.
    owned_by_user = Notification.user_id == current_user.id

    count_stmt = select(func.count(Notification.id)).where(owned_by_user)
    # ``or 0`` guards against a ``None`` scalar.
    total = (await db.execute(count_stmt)).scalar() or 0

    page_stmt = (
        select(Notification)
        .where(owned_by_user)
        # ``created_at`` alone is not a total order: rows sharing a timestamp
        # could come back in a different order on each request, so consecutive
        # OFFSET/LIMIT pages might repeat or skip them. ``id`` (presumably the
        # primary key) breaks ties deterministically.
        .order_by(Notification.created_at.desc(), Notification.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    items = (await db.execute(page_stmt)).scalars().all()
    return PageResponse(total=total, items=items)
|
|
|
|
|
|
@router.get("/unread-count", response_model=UnreadCount)
async def get_unread_count(
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
):
    """Count the current user's notifications that are not yet marked read.

    Returns:
        An ``UnreadCount`` whose ``count`` is the number of matching rows.
    """
    unread_for_user = (
        Notification.user_id == current_user.id,
        Notification.is_read.is_(False),
    )
    stmt = select(func.count(Notification.id)).where(*unread_for_user)
    outcome = await db.execute(stmt)
    # ``or 0`` guards against a ``None`` scalar.
    unread = outcome.scalar() or 0
    return UnreadCount(count=unread)
|
|
|
|
|
|
@router.post("/read-all", status_code=status.HTTP_200_OK)
async def mark_all_read(
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
):
    """Flag every unread notification of the current user as read.

    Returns:
        The fixed success envelope ``{"code": 0, "message": "success"}``.
    """
    # Restrict the UPDATE to rows that are still unread so already-read
    # notifications are not rewritten.
    stmt = (
        update(Notification)
        .where(
            Notification.user_id == current_user.id,
            Notification.is_read.is_(False),
        )
        .values(is_read=True)
    )
    await db.execute(stmt)
    await db.commit()
    return {"code": 0, "message": "success"}
|
|
|
|
|
|
@router.post("/{notification_id}/read")
async def mark_read(
    notification_id: int,
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
):
    """Mark one of the current user's notifications as read.

    The change is issued as a single UPDATE scoped to both the notification
    id and the owning user, so a notification belonging to someone else is
    never touched. An id that does not exist, is owned by another user, or is
    already read matches no rows; the call is then a no-op and still reports
    success.

    Args:
        notification_id: Identifier of the notification to mark as read.
        current_user: Authenticated user; only their notifications match.
        db: Async database session.

    Returns:
        The fixed success envelope ``{"code": 0, "message": "success"}``.
    """
    # One statement instead of SELECT + attribute mutation: no row needs to
    # be loaded just to flip a boolean.
    stmt = (
        update(Notification)
        .where(
            Notification.id == notification_id,
            Notification.user_id == current_user.id,
            Notification.is_read.is_(False),
        )
        .values(is_read=True)
    )
    await db.execute(stmt)
    await db.commit()
    return {"code": 0, "message": "success"}
|