Initial project commit
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
DATABASE_URL: str
|
||||
DATABASE_URL_SYNC: str
|
||||
REDIS_URL: str = "redis://localhost:6379/0"
|
||||
SECRET_KEY: str
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 43200
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
|
||||
|
||||
STORAGE_BACKEND: str = "local"
|
||||
LOCAL_STORAGE_PATH: str = "./uploads"
|
||||
S3_ENDPOINT: str = ""
|
||||
S3_ACCESS_KEY: str = ""
|
||||
S3_SECRET_KEY: str = ""
|
||||
S3_BUCKET: str = "ciyuan-viewfinder"
|
||||
S3_REGION: str = ""
|
||||
S3_PUBLIC_URL: str = ""
|
||||
|
||||
TENCENT_MAP_KEY: str = ""
|
||||
|
||||
SENTRY_DSN: str = ""
|
||||
LOG_JSON: bool = False
|
||||
LOG_LEVEL: str = "INFO"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,69 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.security import decode_token
|
||||
from app.db.session import async_session_factory
|
||||
from app.models.user import User
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
user_id: str | None = payload.get("sub")
|
||||
if user_id is None or payload.get("type") != "access":
|
||||
raise credentials_exception
|
||||
except ValueError:
|
||||
raise credentials_exception
|
||||
|
||||
result = await db.execute(select(User).where(User.id == int(user_id)))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User account is disabled",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_optional_current_user(
|
||||
token: str | None = Depends(OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login", auto_error=False)),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User | None:
|
||||
if not token:
|
||||
return None
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None or payload.get("type") != "access":
|
||||
return None
|
||||
except ValueError:
|
||||
return None
|
||||
result = await db.execute(select(User).where(User.id == int(user_id)))
|
||||
return result.scalar_one_or_none()
|
||||
@@ -0,0 +1,40 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
log_entry = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
"module": record.module,
|
||||
"function": record.funcName,
|
||||
"line": record.lineno,
|
||||
}
|
||||
if record.exc_info and record.exc_info[1]:
|
||||
log_entry["exception"] = self.formatException(record.exc_info)
|
||||
return json.dumps(log_entry, ensure_ascii=False)
|
||||
|
||||
|
||||
def setup_logging(json_format: bool = False, level: int = logging.INFO):
|
||||
root = logging.getLogger()
|
||||
root.setLevel(level)
|
||||
|
||||
for handler in root.handlers[:]:
|
||||
root.removeHandler(handler)
|
||||
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
if json_format:
|
||||
handler.setFormatter(JSONFormatter())
|
||||
else:
|
||||
handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s %(levelname)s [%(name)s] %(message)s"
|
||||
))
|
||||
root.addHandler(handler)
|
||||
|
||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
@@ -0,0 +1,49 @@
|
||||
import logging
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Sliding window rate limiter backed by Redis.
|
||||
|
||||
Usage as FastAPI dependency:
|
||||
@router.post("/action", dependencies=[Depends(RateLimiter(times=5, seconds=60))])
|
||||
"""
|
||||
|
||||
def __init__(self, times: int = 10, seconds: int = 60):
|
||||
self.times = times
|
||||
self.seconds = seconds
|
||||
|
||||
async def __call__(self, request: Request) -> None:
|
||||
redis = getattr(request.app.state, "redis", None)
|
||||
if redis is None:
|
||||
return
|
||||
|
||||
identifier = self._get_identifier(request)
|
||||
key = f"rl:{request.url.path}:{identifier}"
|
||||
|
||||
try:
|
||||
pipe = redis.pipeline()
|
||||
pipe.incr(key)
|
||||
pipe.expire(key, self.seconds)
|
||||
results = await pipe.execute()
|
||||
current = results[0]
|
||||
except RedisError:
|
||||
logger.warning("Rate limiter skipped because Redis is unavailable", exc_info=True)
|
||||
return
|
||||
|
||||
if current > self.times:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Too many requests, please try again later",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_identifier(request: Request) -> str:
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return request.client.host if request.client else "unknown"
|
||||
@@ -0,0 +1,46 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: str | int, expires_delta: timedelta | None = None
|
||||
) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
expire = now + (
|
||||
expires_delta
|
||||
if expires_delta
|
||||
else timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
)
|
||||
to_encode = {"sub": str(subject), "exp": expire, "type": "access"}
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(subject: str | int) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
expire = now + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
to_encode = {"sub": str(subject), "exp": expire, "type": "refresh"}
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except JWTError as exc:
|
||||
raise ValueError("Invalid token") from exc
|
||||
return payload
|
||||
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config as BotoConfig
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class StorageBackend(ABC):
|
||||
@abstractmethod
|
||||
def upload(self, file_data: bytes, filename: str, content_type: str = "") -> str:
|
||||
"""Upload file and return its public URL."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete a file by key/path."""
|
||||
|
||||
@abstractmethod
|
||||
def generate_presigned_url(self, key: str, content_type: str = "", expires: int = 3600) -> str:
|
||||
"""Generate a presigned upload URL (S3-like backends only)."""
|
||||
|
||||
|
||||
class LocalStorageBackend(StorageBackend):
|
||||
def __init__(self, storage_path: str, serve_url_prefix: str = "/uploads"):
|
||||
self.storage_path = Path(storage_path)
|
||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||
self.serve_url_prefix = serve_url_prefix.rstrip("/")
|
||||
|
||||
def _make_key(self, filename: str) -> str:
|
||||
ext = Path(filename).suffix
|
||||
return f"{uuid.uuid4().hex}{ext}"
|
||||
|
||||
def upload(self, file_data: bytes, filename: str, content_type: str = "") -> str:
|
||||
key = self._make_key(filename)
|
||||
dest = self.storage_path / key
|
||||
dest.write_bytes(file_data)
|
||||
return f"{self.serve_url_prefix}/{key}"
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
path = key.lstrip("/")
|
||||
if path.startswith("uploads/"):
|
||||
path = path[len("uploads/"):]
|
||||
target = self.storage_path / path
|
||||
if target.exists():
|
||||
target.unlink()
|
||||
|
||||
def generate_presigned_url(self, key: str, content_type: str = "", expires: int = 3600) -> str:
|
||||
raise NotImplementedError("Local storage does not support presigned URLs")
|
||||
|
||||
|
||||
class S3StorageBackend(StorageBackend):
|
||||
"""Compatible with MinIO, Aliyun OSS, Tencent COS, and AWS S3."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
access_key: str,
|
||||
secret_key: str,
|
||||
bucket: str,
|
||||
region: str = "",
|
||||
public_url: str = "",
|
||||
):
|
||||
self.bucket = bucket
|
||||
self.public_url = public_url.rstrip("/") if public_url else ""
|
||||
kwargs: dict = {
|
||||
"endpoint_url": endpoint,
|
||||
"aws_access_key_id": access_key,
|
||||
"aws_secret_access_key": secret_key,
|
||||
"config": BotoConfig(signature_version="s3v4"),
|
||||
}
|
||||
if region:
|
||||
kwargs["region_name"] = region
|
||||
self.client = boto3.client("s3", **kwargs)
|
||||
|
||||
try:
|
||||
self.client.head_bucket(Bucket=bucket)
|
||||
except Exception:
|
||||
try:
|
||||
self.client.create_bucket(Bucket=bucket)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _make_key(self, filename: str) -> str:
|
||||
ext = Path(filename).suffix
|
||||
return f"images/{uuid.uuid4().hex}{ext}"
|
||||
|
||||
def _get_url(self, key: str) -> str:
|
||||
if self.public_url:
|
||||
return f"{self.public_url}/{key}"
|
||||
return f"{self.client.meta.endpoint_url}/{self.bucket}/{key}"
|
||||
|
||||
def upload(self, file_data: bytes, filename: str, content_type: str = "") -> str:
|
||||
key = self._make_key(filename)
|
||||
extra = {}
|
||||
if content_type:
|
||||
extra["ContentType"] = content_type
|
||||
self.client.put_object(Bucket=self.bucket, Key=key, Body=file_data, **extra)
|
||||
return self._get_url(key)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self.client.delete_object(Bucket=self.bucket, Key=key)
|
||||
|
||||
def generate_presigned_url(self, key: str, content_type: str = "", expires: int = 3600) -> str:
|
||||
params: dict = {"Bucket": self.bucket, "Key": key}
|
||||
if content_type:
|
||||
params["ContentType"] = content_type
|
||||
return self.client.generate_presigned_url(
|
||||
"put_object", Params=params, ExpiresIn=expires
|
||||
)
|
||||
|
||||
|
||||
def get_storage_backend() -> StorageBackend:
|
||||
backend = settings.STORAGE_BACKEND.lower()
|
||||
if backend == "s3":
|
||||
return S3StorageBackend(
|
||||
endpoint=settings.S3_ENDPOINT,
|
||||
access_key=settings.S3_ACCESS_KEY,
|
||||
secret_key=settings.S3_SECRET_KEY,
|
||||
bucket=settings.S3_BUCKET,
|
||||
region=settings.S3_REGION,
|
||||
public_url=settings.S3_PUBLIC_URL,
|
||||
)
|
||||
return LocalStorageBackend(
|
||||
storage_path=settings.LOCAL_STORAGE_PATH,
|
||||
)
|
||||
|
||||
|
||||
storage: StorageBackend | None = None
|
||||
|
||||
|
||||
def init_storage() -> StorageBackend:
|
||||
global storage
|
||||
storage = get_storage_backend()
|
||||
return storage
|
||||
Reference in New Issue
Block a user