| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282 |
- """
- Authentication and security utilities.
- Provides password hashing, token encryption, and session management.
- """
- from datetime import datetime, timedelta
- from typing import Optional
- from fastapi import Depends, HTTPException, status, Request, Response
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy import select, func
- import bcrypt
- from cryptography.fernet import Fernet
- from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
- import base64
- import os
- from app.models import User
- from app.database import get_db
- from app.config import get_settings
# Name of the cookie that carries the signed session token.
SESSION_COOKIE_NAME = "session"
# Session lifetime in seconds (30 days). Used as the cookie max-age and as the
# default signature age limit when verifying session tokens.
SESSION_MAX_AGE = 30 * 24 * 60 * 60
def get_password_hash(password: str) -> str:
    """Return a bcrypt hash of *password*, decoded to a UTF-8 string for storage.

    A new random salt is generated on every call, so hashing the same password
    twice produces different strings; compare with ``verify_password`` instead
    of string equality.
    """
    # NOTE(review): bcrypt only looks at the first 72 bytes of input; depending
    # on the installed bcrypt version, longer inputs are truncated or rejected
    # with ValueError — confirm which applies here.
    return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored bcrypt hash.

    Args:
        plain_password: Password as supplied by the user.
        hashed_password: Stored hash, as produced by ``get_password_hash``.

    Returns:
        True if the password matches the hash, False otherwise. A missing or
        unparseable stored hash counts as a non-match, so a corrupt row rejects
        the login instead of surfacing as a server error.
    """
    if not hashed_password:
        # Empty/None hash can never match; avoid AttributeError / bcrypt errors.
        return False
    try:
        return bcrypt.checkpw(
            plain_password.encode("utf-8"), hashed_password.encode("utf-8")
        )
    except ValueError:
        # bcrypt raises ValueError (e.g. "Invalid salt") for hashes it cannot parse.
        return False
def get_fernet_key() -> bytes:
    """
    Return the Fernet key used to encrypt and decrypt stored API tokens.

    The key is derived deterministically from ``settings.secret_key``. Its
    encoded bytes are space-padded or truncated to exactly 32 bytes and then
    URL-safe base64 encoded, which is the key format ``Fernet`` accepts.
    """
    raw_secret = get_settings().secret_key.encode()
    # Force exactly 32 bytes: short secrets are padded with b" ", long ones cut.
    key_material = raw_secret.ljust(32)[:32]
    # SECURITY NOTE(review): this is not a real key-derivation function. Only
    # the first 32 bytes of the secret matter, and a short secret yields a
    # low-entropy, space-padded key. Switching to a hash/HKDF-based derivation
    # would make every previously encrypted token undecryptable, so it needs a
    # migration path (e.g. MultiFernet that still accepts this legacy key).
    return base64.urlsafe_b64encode(key_material)
def encrypt_token(token: str) -> str:
    """Encrypt *token* with the application's Fernet key and return the ciphertext as text."""
    ciphertext = Fernet(get_fernet_key()).encrypt(token.encode())
    return ciphertext.decode()
def decrypt_token(encrypted_token: str) -> str:
    """Decrypt a ciphertext produced by ``encrypt_token`` back to the plain token.

    Errors from ``Fernet.decrypt`` (e.g. a ciphertext made with a different key)
    are not caught here and propagate to the caller.
    """
    cipher = Fernet(get_fernet_key())
    plaintext = cipher.decrypt(encrypted_token.encode())
    return plaintext.decode()
def get_serializer() -> URLSafeTimedSerializer:
    """Build the timestamped, URL-safe serializer that signs session tokens.

    It is keyed with ``settings.secret_key`` and no explicit salt is passed.
    """
    # NOTE(review): with no purpose-specific salt, tokens from any other
    # URLSafeTimedSerializer built on the same secret (and default salt) will
    # also verify here. Adding a salt is safer but invalidates existing sessions.
    secret = get_settings().secret_key
    return URLSafeTimedSerializer(secret)
def create_session_token(user_id: int) -> str:
    """Return a signed, timestamped token whose payload is ``{"user_id": user_id}``."""
    payload = {"user_id": user_id}
    return get_serializer().dumps(payload)
def verify_session_token(token: str, max_age: int = SESSION_MAX_AGE) -> Optional[int]:
    """
    Verify a signed session token and return the user id it carries.

    Args:
        token: Token as produced by ``create_session_token``.
        max_age: Maximum accepted token age in seconds.

    Returns:
        The payload's ``user_id``, or None when the signature is invalid, the
        token has expired, or the signed payload is not a dict.
    """
    serializer = get_serializer()
    try:
        data = serializer.loads(token, max_age=max_age)
    except (BadSignature, SignatureExpired):
        return None
    # A validly signed payload of another shape (e.g. from a different
    # serializer sharing the same secret) would otherwise raise AttributeError
    # on .get() and surface as a 500 instead of a rejected session.
    if not isinstance(data, dict):
        return None
    return data.get("user_id")
def set_session_cookie(response: Response, user_id: int, *, secure: bool = False):
    """
    Issue a signed session cookie for *user_id* on *response*.

    The cookie has these attributes:
      - HttpOnly
      - SameSite=Lax
      - scoped to path "/"
      - lives for SESSION_MAX_AGE seconds

    Args:
        response: Response that receives the Set-Cookie header.
        user_id: Id embedded in the signed session token.
        secure: When True, add the Secure attribute so browsers send the cookie
            only over HTTPS. Defaults to False to keep the previous behavior;
            HTTPS deployments should pass True.
    """
    token = create_session_token(user_id)
    response.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=token,
        max_age=SESSION_MAX_AGE,
        path="/",
        httponly=True,
        samesite="lax",
        secure=secure,
    )
def clear_session_cookie(response: Response):
    """Tell the client to discard the session cookie.

    Uses the same cookie name and path ("/") the session cookie is set with,
    so the browser matches and removes the existing cookie.
    """
    response.delete_cookie(path="/", key=SESSION_COOKIE_NAME)
async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Resolve the logged-in user from the session cookie.

    Reads the session cookie, verifies the signed token, then loads the user
    with that id, provided the user is active.

    Raises:
        HTTPException: 401 if the cookie is missing, the token is invalid or
            expired, or no active user has the embedded id.
    """
    def unauthorized(reason: str) -> HTTPException:
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=reason
        )

    token = request.cookies.get(SESSION_COOKIE_NAME)
    if not token:
        raise unauthorized("Not authenticated")

    user_id = verify_session_token(token)
    if user_id is None:
        raise unauthorized("Invalid or expired session")

    # `== True` is deliberate: SQLAlchemy turns it into a SQL comparison.
    query = select(User).where(User.id == user_id, User.is_active == True)  # noqa: E712
    user = (await db.execute(query)).scalar_one_or_none()
    if not user:
        raise unauthorized("User not found or inactive")
    return user
async def get_current_user_optional(
    request: Request,
    db: AsyncSession = Depends(get_db)
) -> Optional[User]:
    """
    Variant of ``get_current_user`` for routes where login is optional.

    Any HTTPException raised while authenticating is swallowed; the caller
    gets None instead.
    """
    try:
        user = await get_current_user(request, db)
    except HTTPException:
        return None
    return user
async def get_current_admin(
    request: Request,
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Resolve the logged-in user and require administrator rights.

    Raises:
        HTTPException: 401 (from ``get_current_user``) when not authenticated,
            or 403 when the user's ``is_admin`` flag is falsy.
    """
    user = await get_current_user(request, db)
    if user.is_admin:
        return user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin access required"
    )
async def authenticate_user(
    db: AsyncSession,
    username: str,
    password: str
) -> Optional[User]:
    """
    Authenticate a user by username and password.

    The username match is case-insensitive. On success the user's
    ``last_login`` is updated and the session is committed.

    Args:
        db: Database session.
        username: Login name as entered.
        password: Plain-text password to check.

    Returns:
        The matching active User when the password is correct, otherwise None.
        Unknown usernames, wrong passwords and inactive accounts all return the
        same None.
    """
    # Find user by username (case-insensitive)
    result = await db.execute(
        select(User).where(func.lower(User.username) == func.lower(username))
    )
    user = result.scalar_one_or_none()
    if not user:
        # Spend comparable bcrypt work on unknown usernames so response time
        # does not reveal whether an account exists. A fixed dummy input is
        # hashed so an unusual password cannot make this path fail.
        get_password_hash("timing-equalization-dummy")
        return None
    if not verify_password(password, user.hashed_password):
        return None
    if not user.is_active:
        return None
    # NOTE(review): naive local timestamp, same as create_user; confirm the
    # column's timezone expectations before switching to aware UTC.
    user.last_login = datetime.now()
    await db.commit()
    return user
async def _user_exists(db: AsyncSession, *criteria) -> bool:
    """Return True if at least one User row satisfies all *criteria*."""
    # Select only the id and stop at one row: cheap, and never raises
    # MultipleResultsFound even if case-variant duplicates already exist.
    result = await db.execute(select(User.id).where(*criteria).limit(1))
    return result.first() is not None


async def create_user(
    db: AsyncSession,
    username: str,
    email: str,
    password: str,
    abs_url: str,
    abs_api_token: str,
    display_name: Optional[str] = None
) -> User:
    """
    Create a new user account.

    The password is stored as a bcrypt hash and the API token is stored
    Fernet-encrypted. The very first account created becomes an admin.

    Args:
        db: Database session; the new user is committed and refreshed.
        username: Login name; must be unique ignoring case.
        email: Email address; must be unique ignoring case.
        password: Plain-text password to hash.
        abs_url: Stored as given on the user.
        abs_api_token: Plain API token; encrypted before storage.
        display_name: Optional display name; defaults to *username*.

    Returns:
        The persisted User.

    Raises:
        HTTPException: 400 if the username or email is already registered.
    """
    if await _user_exists(db, func.lower(User.username) == func.lower(username)):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered"
        )
    # Compare emails case-insensitively, consistent with the username check.
    if await _user_exists(db, func.lower(User.email) == func.lower(email)):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )
    # NOTE(review): check-then-insert is not atomic; concurrent registrations
    # can race past these checks (and both become admin when the table is
    # empty) unless DB-level unique constraints back them up.
    result = await db.execute(select(func.count(User.id)))
    is_first_user = result.scalar() == 0
    user = User(
        username=username,
        email=email,
        hashed_password=get_password_hash(password),
        abs_url=abs_url,
        abs_api_token=encrypt_token(abs_api_token),
        display_name=display_name or username,
        created_at=datetime.now(),
        is_active=True,
        is_admin=is_first_user  # First user becomes admin
    )
    db.add(user)
    await db.commit()
    await db.refresh(user)
    return user
|