AI Security

AI API Security: Authentication, Authorization, and Rate Limiting

DeviDevs Team
12 min read
#AI API #security #authentication #rate limiting #authorization

AI API Security: Authentication, Authorization, and Rate Limiting

AI APIs present unique security challenges because of their high computational cost, their potential for abuse, and the sensitivity of the data they process. This guide covers authentication, authorization, rate limiting, and request validation for AI API deployments.

Authentication Mechanisms

API Key Authentication with Rotation

# api_key_auth.py
import secrets
import hashlib
from datetime import datetime, timedelta
from typing import Dict, Optional, List
from dataclasses import dataclass
 
@dataclass
class APIKey:
    """Stored API key record.

    Only a SHA-256 digest of the key is persisted; the plaintext key is handed
    to the caller exactly once, by ``APIKeyManager.generate_key``.
    """
    key_id: str
    key_hash: str  # hex SHA-256 digest of the full key
    name: str
    organization_id: str
    created_at: datetime  # naive UTC (datetime.utcnow())
    expires_at: Optional[datetime]  # naive UTC; None means the key never expires
    last_used: Optional[datetime]
    scopes: List[str]
    rate_limit_tier: str
    status: str  # active, revoked, expired


class APIKeyManager:
    """Issue, validate, rotate and revoke API keys for AI API access.

    ``key_store`` must provide ``save_key(record)``, ``find_by_hash(key_hash)``
    and ``find_by_id(key_id)``; the two finders return an ``APIKey`` or a
    falsy value when nothing matches.
    """

    def __init__(self, key_store):
        self.store = key_store
        self.key_prefix = "sk-"
        # Bytes of randomness passed to secrets.token_urlsafe (64 URL-safe chars).
        self.key_length = 48

    def generate_key(
        self,
        name: str,
        organization_id: str,
        scopes: List[str],
        rate_limit_tier: str = "standard",
        expires_in_days: Optional[int] = 365
    ) -> Dict:
        """Create and store a new API key.

        Args:
            name: Human-readable label for the key.
            organization_id: Owning organization.
            scopes: Scope strings granted to the key (copied into the record).
            rate_limit_tier: Rate-limit tier name stored with the key.
            expires_in_days: Lifetime in days; ``None`` or ``0`` means the key
                never expires.

        Returns:
            Dict with ``key_id``, the plaintext ``api_key`` (only ever returned
            here), ``expires_at`` (ISO string or None), ``scopes`` and a
            ``warning``.
        """
        random_part = secrets.token_urlsafe(self.key_length)
        full_key = f"{self.key_prefix}{random_part}"

        key_hash = self._hash_key(full_key)
        key_id = f"key_{secrets.token_hex(8)}"

        # One timestamp for both fields so expires_at - created_at is exactly
        # the requested lifetime; rotate_key derives the new key's lifetime
        # from that difference.
        now = datetime.utcnow()
        expires_at = now + timedelta(days=expires_in_days) if expires_in_days else None

        api_key = APIKey(
            key_id=key_id,
            key_hash=key_hash,
            name=name,
            organization_id=organization_id,
            created_at=now,
            expires_at=expires_at,
            last_used=None,
            scopes=list(scopes),  # copy so records never share one mutable list
            rate_limit_tier=rate_limit_tier,
            status="active"
        )

        # Store key (only hash is stored)
        self.store.save_key(api_key)

        return {
            "key_id": key_id,
            "api_key": full_key,  # Only returned once!
            "expires_at": expires_at.isoformat() if expires_at else None,
            "scopes": scopes,
            "warning": "Save this key securely. It cannot be retrieved again."
        }

    def validate_key(self, api_key: str) -> Dict:
        """Validate an API key and return its metadata.

        A key is valid only if it is found, its status is ``active`` and it
        has not passed ``expires_at``. A successful validation updates
        ``last_used``.

        Returns:
            ``{"valid": True, key_id, organization_id, scopes, rate_limit_tier}``
            or ``{"valid": False, "error": <reason>}``.
        """
        if not isinstance(api_key, str) or not api_key.startswith(self.key_prefix):
            return {"valid": False, "error": "Invalid key format"}

        key_record = self.store.find_by_hash(self._hash_key(api_key))
        if not key_record:
            return {"valid": False, "error": "Key not found"}

        if key_record.status == "revoked":
            return {"valid": False, "error": "Key has been revoked"}

        now = datetime.utcnow()
        if key_record.status == "expired" or (
            key_record.expires_at and now > key_record.expires_at
        ):
            return {"valid": False, "error": "Key has expired"}

        # Any other non-active status must not authenticate (fail closed).
        if key_record.status != "active":
            return {"valid": False, "error": "Key is not active"}

        key_record.last_used = now
        self.store.save_key(key_record)

        return {
            "valid": True,
            "key_id": key_record.key_id,
            "organization_id": key_record.organization_id,
            "scopes": key_record.scopes,
            "rate_limit_tier": key_record.rate_limit_tier
        }

    def rotate_key(self, key_id: str, grace_period_hours: int = 24) -> Dict:
        """Replace an active key with a new one, keeping the old key briefly valid.

        The new key gets the same validity period the old key was issued with
        (or no expiry if the old key had none). The old key stays usable for
        ``grace_period_hours``, but never beyond its own original expiry.

        Raises:
            ValueError: if the key does not exist, or is revoked, expired or
                otherwise not active (rotation must not resurrect dead keys).
        """
        old_key = self.store.find_by_id(key_id)
        if not old_key:
            raise ValueError("Key not found")

        now = datetime.utcnow()
        if old_key.status != "active" or (old_key.expires_at and now > old_key.expires_at):
            raise ValueError("Only active, unexpired keys can be rotated")

        expires_in_days = None
        if old_key.expires_at:
            lifetime = old_key.expires_at - old_key.created_at
            # Round to whole days; at least one day so the new key is usable.
            expires_in_days = max(1, round(lifetime.total_seconds() / 86400))

        new_key_result = self.generate_key(
            name=old_key.name,
            organization_id=old_key.organization_id,
            scopes=old_key.scopes,
            rate_limit_tier=old_key.rate_limit_tier,
            expires_in_days=expires_in_days
        )

        grace_end = now + timedelta(hours=grace_period_hours)
        if old_key.expires_at is None or grace_end < old_key.expires_at:
            old_key.expires_at = grace_end
            message = f"Old key will remain valid for {grace_period_hours} hours"
        else:
            # Already expiring sooner than the grace period: do not extend it.
            message = "Old key will remain valid until its original expiry"
        self.store.save_key(old_key)

        return {
            "new_key": new_key_result,
            "old_key_expires": old_key.expires_at.isoformat(),
            "message": message
        }

    def revoke_key(self, key_id: str, reason: str) -> bool:
        """Revoke a key immediately.

        Returns:
            True if the key existed and was revoked, False if not found.
        """
        key_record = self.store.find_by_id(key_id)
        if not key_record:
            return False

        key_record.status = "revoked"
        self.store.save_key(key_record)

        self._log_revocation(key_id, reason)

        return True

    def _hash_key(self, key: str) -> str:
        """Return the hex SHA-256 digest used to store and look up a key."""
        return hashlib.sha256(key.encode()).hexdigest()

    def _log_revocation(self, key_id: str, reason: str):
        """Audit hook for key revocation (placeholder; currently a no-op)."""
        pass
 
 
class JWTAuthenticator:
    """JWT-based authentication for AI APIs.

    With the default RS256 algorithm, tokens are signed with ``secret_key``
    (a PEM private key) and must be verified with the matching public key,
    passed as ``public_key``. To use a single shared secret for both signing
    and verification, pass ``algorithm="HS256"``.
    """

    def __init__(
        self,
        secret_key: str,
        issuer: str,
        public_key: Optional[str] = None,
        algorithm: str = "RS256"
    ):
        self.secret_key = secret_key
        self.issuer = issuer
        self.algorithm = algorithm
        # When None, verification falls back to secret_key, which is only
        # correct for symmetric algorithms such as HS256.
        self.public_key = public_key

    def generate_token(
        self,
        user_id: str,
        organization_id: str,
        scopes: List[str],
        expires_in: int = 3600
    ) -> str:
        """Create a signed access token.

        Args:
            user_id: Subject (``sub`` claim).
            organization_id: Stored in the custom ``org`` claim.
            scopes: Stored in the custom ``scopes`` claim.
            expires_in: Token lifetime in seconds.
        """
        import jwt

        now = datetime.utcnow()
        payload = {
            "sub": user_id,
            "org": organization_id,
            "scopes": scopes,
            "iat": now,
            "exp": now + timedelta(seconds=expires_in),
            "iss": self.issuer,
            "jti": secrets.token_hex(16)  # unique token id
        }

        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def validate_token(self, token: str) -> Dict:
        """Verify a token's signature, issuer, expiry and required claims.

        Returns:
            ``{"valid": True, user_id, organization_id, scopes, expires_at}``,
            where ``expires_at`` is a naive UTC datetime, or
            ``{"valid": False, "error": <reason>}``.
        """
        import jwt
        from datetime import timezone

        verification_key = self.public_key if self.public_key is not None else self.secret_key

        try:
            payload = jwt.decode(
                token,
                verification_key,
                algorithms=[self.algorithm],
                issuer=self.issuer,
                # Reject tokens missing claims this method reads below, and
                # tokens without an expiry.
                options={"require": ["exp", "iss", "sub", "org", "scopes"]}
            )
        except jwt.ExpiredSignatureError:
            return {"valid": False, "error": "Token expired"}
        except jwt.InvalidTokenError as e:
            return {"valid": False, "error": str(e)}

        return {
            "valid": True,
            "user_id": payload["sub"],
            "organization_id": payload["org"],
            "scopes": payload["scopes"],
            # exp is a UTC epoch; keep it naive UTC to match utcnow() elsewhere.
            "expires_at": datetime.fromtimestamp(payload["exp"], tz=timezone.utc).replace(tzinfo=None)
        }

Authorization Framework

Scope-Based Authorization

# authorization.py
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from typing import Dict, List, Optional, Set
 
class AIScope(Enum):
    """Permission scopes an API caller can hold.

    Each value is a "<resource>:<action>" string. Authorization checks compare
    a scope's ``.value`` against the caller's list of scope strings.
    """
    # Model access
    INFERENCE_READ = "inference:read"
    INFERENCE_WRITE = "inference:write"
    MODELS_LIST = "models:list"
    MODELS_DEPLOY = "models:deploy"
 
    # Data access
    EMBEDDINGS_CREATE = "embeddings:create"
    COMPLETIONS_CREATE = "completions:create"
    IMAGES_CREATE = "images:create"
    AUDIO_TRANSCRIBE = "audio:transcribe"
 
    # Fine-tuning
    FINETUNE_CREATE = "finetune:create"
    FINETUNE_READ = "finetune:read"
    FINETUNE_DELETE = "finetune:delete"
 
    # Admin
    ORGANIZATION_READ = "organization:read"
    ORGANIZATION_WRITE = "organization:write"
    USAGE_READ = "usage:read"
    KEYS_MANAGE = "keys:manage"
 
@dataclass
class Permission:
    """Permission definition: a scope bound to a resource type and its allowed actions.

    NOTE(review): nothing in this module constructs or reads Permission; the
    shape of ``conditions`` is presumably the same as the policy conditions
    evaluated by AuthorizationManager. Confirm before relying on it.
    """
    scope: AIScope  # scope that grants this permission
    resource_type: str
    actions: List[str]
    conditions: Dict
 
class AuthorizationManager:
    """Manage authorization for AI API requests."""
 
    def __init__(self):
        self.scope_hierarchy = self._build_scope_hierarchy()
        self.resource_policies = {}
 
    def _build_scope_hierarchy(self) -> Dict[str, Set[str]]:
        """Build scope hierarchy for inheritance."""
 
        return {
            "admin": {
                AIScope.ORGANIZATION_WRITE.value,
                AIScope.ORGANIZATION_READ.value,
                AIScope.USAGE_READ.value,
                AIScope.KEYS_MANAGE.value,
                AIScope.MODELS_DEPLOY.value
            },
            "developer": {
                AIScope.INFERENCE_WRITE.value,
                AIScope.INFERENCE_READ.value,
                AIScope.MODELS_LIST.value,
                AIScope.COMPLETIONS_CREATE.value,
                AIScope.EMBEDDINGS_CREATE.value,
                AIScope.IMAGES_CREATE.value,
                AIScope.AUDIO_TRANSCRIBE.value,
                AIScope.FINETUNE_CREATE.value,
                AIScope.FINETUNE_READ.value
            },
            "viewer": {
                AIScope.INFERENCE_READ.value,
                AIScope.MODELS_LIST.value,
                AIScope.USAGE_READ.value,
                AIScope.FINETUNE_READ.value
            }
        }
 
    def check_permission(
        self,
        user_scopes: List[str],
        required_scope: AIScope,
        resource: Optional[str] = None,
        context: Optional[Dict] = None
    ) -> Dict:
        """Check if user has required permission."""
 
        # Expand user scopes
        expanded_scopes = set()
        for scope in user_scopes:
            expanded_scopes.add(scope)
            if scope in self.scope_hierarchy:
                expanded_scopes.update(self.scope_hierarchy[scope])
 
        # Check if required scope is present
        if required_scope.value not in expanded_scopes:
            return {
                "allowed": False,
                "reason": f"Missing required scope: {required_scope.value}",
                "required": required_scope.value,
                "available": list(expanded_scopes)
            }
 
        # Check resource-specific policies
        if resource and resource in self.resource_policies:
            policy_result = self._evaluate_resource_policy(
                resource, user_scopes, context
            )
            if not policy_result["allowed"]:
                return policy_result
 
        return {
            "allowed": True,
            "scope": required_scope.value
        }
 
    def _evaluate_resource_policy(
        self,
        resource: str,
        user_scopes: List[str],
        context: Dict
    ) -> Dict:
        """Evaluate resource-specific policy."""
 
        policy = self.resource_policies.get(resource)
        if not policy:
            return {"allowed": True}
 
        # Check conditions
        for condition in policy.get("conditions", []):
            if not self._evaluate_condition(condition, context):
                return {
                    "allowed": False,
                    "reason": f"Condition not met: {condition['description']}"
                }
 
        return {"allowed": True}
 
    def _evaluate_condition(self, condition: Dict, context: Dict) -> bool:
        """Evaluate a single policy condition."""
 
        condition_type = condition.get("type")
 
        if condition_type == "time_window":
            # Check if request is within allowed time window
            from datetime import datetime
            now = datetime.utcnow().hour
            start = condition.get("start_hour", 0)
            end = condition.get("end_hour", 24)
            return start <= now < end
 
        elif condition_type == "ip_range":
            # Check if request IP is in allowed range
            import ipaddress
            client_ip = context.get("client_ip")
            allowed_ranges = condition.get("ranges", [])
            for range_str in allowed_ranges:
                if ipaddress.ip_address(client_ip) in ipaddress.ip_network(range_str):
                    return True
            return False
 
        elif condition_type == "model_access":
            # Check if user can access specific model
            requested_model = context.get("model")
            allowed_models = condition.get("models", [])
            return requested_model in allowed_models
 
        return True
 
 
def require_scope(scope: AIScope):
    """Decorator: reject a request with HTTP 403 unless the caller holds ``scope``.

    The caller's scopes are read from ``request.state.user_scopes``. If that
    attribute was never set, the request is treated as holding no scopes and
    is denied. The client IP and the JSON body's ``model`` field (when
    present) are passed as policy context.
    """
    # One manager per decorated endpoint rather than one per request.
    auth_manager = AuthorizationManager()

    def decorator(func):
        @wraps(func)
        async def wrapper(request, *args, **kwargs):
            from fastapi import HTTPException

            user_scopes = getattr(request.state, "user_scopes", None) or []

            # Request.json() is a coroutine; an empty or non-JSON body just
            # means no model was requested.
            try:
                body = await request.json()
            except ValueError:
                body = None
            model = body.get("model") if isinstance(body, dict) else None
            client_ip = request.client.host if request.client else None

            result = auth_manager.check_permission(
                user_scopes=user_scopes,
                required_scope=scope,
                context={
                    "client_ip": client_ip,
                    "model": model
                }
            )

            if not result["allowed"]:
                raise HTTPException(
                    status_code=403,
                    detail=result["reason"]
                )

            return await func(request, *args, **kwargs)
        return wrapper
    return decorator

Rate Limiting

Multi-Tier Rate Limiter

# rate_limiting.py
from dataclasses import dataclass
from typing import Dict, Optional
from datetime import datetime, timedelta
import asyncio
 
@dataclass
class RateLimitTier:
    """Rate limit tier configuration.

    The daily limits may be ``None``, meaning unlimited (the enterprise tier
    uses this). Per-minute and concurrency limits are always set.
    """
    name: str
    requests_per_minute: int
    requests_per_day: Optional[int]  # None = unlimited
    tokens_per_minute: int
    tokens_per_day: Optional[int]  # None = unlimited
    concurrent_requests: int
    burst_multiplier: float  # scales requests_per_minute only


class AIRateLimiter:
    """Fixed-window rate limiter for AI API requests, backed by Redis counters.

    Counters are bucketed per UTC minute and per UTC day. ``check_rate_limit``
    only reads them and ``record_request`` increments them, so callers must
    call both. Because the read and the increment are separate round-trips,
    concurrent requests can overshoot a limit slightly. Pass ``limit`` to
    ``acquire_concurrent_slot`` for an atomic concurrency check.
    """

    TIERS = {
        "free": RateLimitTier(
            name="free",
            requests_per_minute=3,
            requests_per_day=100,
            tokens_per_minute=10000,
            tokens_per_day=100000,
            concurrent_requests=1,
            burst_multiplier=1.0
        ),
        "standard": RateLimitTier(
            name="standard",
            requests_per_minute=60,
            requests_per_day=10000,
            tokens_per_minute=60000,
            tokens_per_day=1000000,
            concurrent_requests=5,
            burst_multiplier=1.5
        ),
        "professional": RateLimitTier(
            name="professional",
            requests_per_minute=300,
            requests_per_day=50000,
            tokens_per_minute=300000,
            tokens_per_day=5000000,
            concurrent_requests=20,
            burst_multiplier=2.0
        ),
        "enterprise": RateLimitTier(
            name="enterprise",
            requests_per_minute=1000,
            requests_per_day=None,  # Unlimited
            tokens_per_minute=1000000,
            tokens_per_day=None,  # Unlimited
            concurrent_requests=100,
            burst_multiplier=3.0
        )
    }

    def __init__(self, redis_client):
        self.redis = redis_client

    @staticmethod
    def _window_keys(organization_id: str, now: datetime) -> Dict[str, str]:
        """Return the Redis keys of the minute/day request and token buckets containing ``now``."""
        minute = now.strftime('%Y%m%d%H%M')
        day = now.strftime('%Y%m%d')
        prefix = f"ratelimit:{organization_id}"
        return {
            "minute": f"{prefix}:minute:{minute}",
            "day": f"{prefix}:day:{day}",
            "tokens_minute": f"{prefix}:tokens:minute:{minute}",
            "tokens_day": f"{prefix}:tokens:day:{day}",
        }

    @staticmethod
    def _concurrent_key(organization_id: str) -> str:
        """Return the Redis key holding the organization's in-flight request count."""
        return f"ratelimit:{organization_id}:concurrent"

    async def check_rate_limit(
        self,
        organization_id: str,
        tier_name: str,
        estimated_tokens: int = 0
    ) -> Dict:
        """Check, without recording, whether one more request fits within the limits.

        Unknown tier names fall back to the "free" tier; the response's
        ``tier`` field names the tier actually applied.

        Returns:
            ``{"allowed": True, "tier": ..., "remaining": {...}}`` or the
            dict built by ``_rate_limit_response``.
        """
        tier = self.TIERS.get(tier_name, self.TIERS["free"])
        now = datetime.utcnow()
        keys = self._window_keys(organization_id, now)

        pipe = self.redis.pipeline()
        pipe.get(keys["minute"])
        pipe.get(keys["day"])
        pipe.get(keys["tokens_minute"])
        pipe.get(keys["tokens_day"])
        pipe.get(self._concurrent_key(organization_id))
        results = await pipe.execute()

        (minute_count, day_count, token_minute_count,
         token_day_count, concurrent_count) = (int(value or 0) for value in results)

        burst_rpm = int(tier.requests_per_minute * tier.burst_multiplier)
        # Minute buckets roll over at the next wall-clock minute.
        seconds_until_minute_reset = 60 - now.second

        if minute_count >= burst_rpm:
            return self._rate_limit_response(
                "requests_per_minute",
                tier.requests_per_minute,
                seconds_until_minute_reset
            )

        if tier.requests_per_day and day_count >= tier.requests_per_day:
            return self._rate_limit_response(
                "requests_per_day",
                tier.requests_per_day,
                self._seconds_until_day_reset()
            )

        if token_minute_count + estimated_tokens > tier.tokens_per_minute:
            return self._rate_limit_response(
                "tokens_per_minute",
                tier.tokens_per_minute,
                seconds_until_minute_reset
            )

        if tier.tokens_per_day and token_day_count + estimated_tokens > tier.tokens_per_day:
            return self._rate_limit_response(
                "tokens_per_day",
                tier.tokens_per_day,
                self._seconds_until_day_reset()
            )

        if concurrent_count >= tier.concurrent_requests:
            return self._rate_limit_response(
                "concurrent_requests",
                tier.concurrent_requests,
                5  # Retry after 5 seconds
            )

        return {
            "allowed": True,
            "tier": tier.name,
            "remaining": {
                "requests_minute": burst_rpm - minute_count - 1,
                "requests_day": (tier.requests_per_day - day_count - 1) if tier.requests_per_day else None,
                "tokens_minute": tier.tokens_per_minute - token_minute_count - estimated_tokens,
                "tokens_day": (tier.tokens_per_day - token_day_count - estimated_tokens) if tier.tokens_per_day else None
            }
        }

    async def record_request(
        self,
        organization_id: str,
        tokens_used: int
    ):
        """Count one request and ``tokens_used`` tokens in the current minute and day buckets."""
        keys = self._window_keys(organization_id, datetime.utcnow())

        pipe = self.redis.pipeline()
        pipe.incr(keys["minute"])
        pipe.expire(keys["minute"], 120)  # 2 minute TTL
        pipe.incr(keys["day"])
        pipe.expire(keys["day"], 90000)  # 25 hour TTL
        pipe.incrby(keys["tokens_minute"], tokens_used)
        pipe.expire(keys["tokens_minute"], 120)
        pipe.incrby(keys["tokens_day"], tokens_used)
        pipe.expire(keys["tokens_day"], 90000)

        await pipe.execute()

    async def acquire_concurrent_slot(self, organization_id: str, limit: Optional[int] = None) -> bool:
        """Take one in-flight slot for the organization.

        Args:
            organization_id: Organization whose counter is incremented.
            limit: If given, the increment itself is the check: when it
                pushes the count above ``limit`` the slot is handed back and
                False is returned. Without it the slot is always taken.

        Returns:
            True if a slot was taken.
        """
        key = self._concurrent_key(organization_id)
        count = int(await self.redis.incr(key))
        # Safety TTL so slots leaked by crashed workers eventually disappear.
        await self.redis.expire(key, 300)
        if limit is not None and count > limit:
            await self.redis.decr(key)
            return False
        return True

    async def release_concurrent_slot(self, organization_id: str):
        """Give back a slot taken by ``acquire_concurrent_slot``.

        If the counter's TTL lapsed while the request was in flight, DECR
        recreates it at -1. A negative count would let extra requests past
        ``check_rate_limit``, so it is bumped back to zero. INCRBY is used
        instead of DELETE so that slots acquired concurrently are not lost.
        """
        key = self._concurrent_key(organization_id)
        remaining = int(await self.redis.decr(key))
        if remaining < 0:
            await self.redis.incrby(key, -remaining)

    def _rate_limit_response(
        self,
        limit_type: str,
        limit: int,
        retry_after: int
    ) -> Dict:
        """Build the "not allowed" result returned by ``check_rate_limit``."""

        return {
            "allowed": False,
            "error": "rate_limit_exceeded",
            "limit_type": limit_type,
            "limit": limit,
            "retry_after": retry_after,
            "message": f"Rate limit exceeded for {limit_type}. Retry after {retry_after} seconds."
        }

    def _seconds_until_day_reset(self) -> int:
        """Return the whole seconds until the next UTC midnight."""

        now = datetime.utcnow()
        tomorrow = now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
        return int((tomorrow - now).total_seconds())

Request Validation

AI-Specific Request Validation

# request_validation.py
from dataclasses import dataclass
from typing import List, Dict, Optional, Any
from enum import Enum
import re
 
class ValidationError(Exception):
    """A single request-validation failure.

    Args:
        field: Path of the offending field, e.g. ``messages[0].role``.
        message: Human-readable description.
        code: Machine-readable error code, e.g. ``missing_field``.
    """
    def __init__(self, field: str, message: str, code: str):
        self.field = field
        self.message = message
        self.code = code
        super().__init__(message)


class AIRequestValidator:
    """Validate chat/completion request bodies before they reach a model.

    All problems are collected and reported together rather than stopping at
    the first one. The blocked-pattern check is a coarse keyword filter: it
    is easy to evade and can reject benign text (e.g. "operating system:
    Linux"), so treat it as a first line of defense only.
    """

    # Content policy patterns (matched case-insensitively)
    BLOCKED_PATTERNS = [
        r'ignore\s+previous\s+instructions',
        r'you\s+are\s+now\s+',
        r'act\s+as\s+(if\s+)?',
        r'pretend\s+to\s+be',
        r'system\s*:\s*',
    ]

    MAX_PROMPT_LENGTH = 32000
    MAX_MESSAGES = 100

    # Used when config has no "allowed_models" entry.
    DEFAULT_ALLOWED_MODELS = [
        "gpt-4",
        "gpt-4-turbo",
        "gpt-3.5-turbo",
        "claude-3-opus",
        "claude-3-sonnet"
    ]
    VALID_ROLES = ["system", "user", "assistant", "function"]

    def __init__(self, config: Optional[Dict] = None):
        self.config = config or {}
        self.blocked_patterns = [
            re.compile(p, re.IGNORECASE)
            for p in self.BLOCKED_PATTERNS
        ]

    def validate_completion_request(self, request: Dict) -> Dict:
        """Validate a chat/completion request body.

        Returns:
            ``{"valid": True}`` or ``{"valid": False, "errors": [{"field",
            "message", "code"}, ...]}``.
        """
        if not isinstance(request, dict):
            return self._to_result([ValidationError(
                "request", "Request body must be a JSON object", "invalid_type"
            )])

        errors: List[ValidationError] = []

        if "model" not in request:
            errors.append(ValidationError("model", "Model is required", "missing_field"))
        else:
            model_error = self._validate_model(request["model"])
            if model_error:
                errors.append(model_error)

        if "messages" in request:
            errors.extend(self._validate_messages(request["messages"]))
        elif "prompt" in request:
            errors.extend(self._validate_prompt(request["prompt"]))
        else:
            errors.append(ValidationError(
                "messages",
                "Either messages or prompt is required",
                "missing_field"
            ))

        if "temperature" in request:
            temperature = request["temperature"]
            if not self._is_number(temperature):
                errors.append(ValidationError(
                    "temperature",
                    "Temperature must be a number",
                    "invalid_type"
                ))
            elif not 0 <= temperature <= 2:
                errors.append(ValidationError(
                    "temperature",
                    "Temperature must be between 0 and 2",
                    "out_of_range"
                ))

        if "max_tokens" in request:
            max_tokens = request["max_tokens"]
            # bool is a subclass of int, so exclude it explicitly.
            if not isinstance(max_tokens, int) or isinstance(max_tokens, bool):
                errors.append(ValidationError(
                    "max_tokens",
                    "max_tokens must be an integer",
                    "invalid_type"
                ))
            elif not 1 <= max_tokens <= 128000:
                errors.append(ValidationError(
                    "max_tokens",
                    "max_tokens must be between 1 and 128000",
                    "out_of_range"
                ))

        return self._to_result(errors)

    @staticmethod
    def _is_number(value: Any) -> bool:
        """Return True for int/float values, excluding bool."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    @staticmethod
    def _to_result(errors: List[ValidationError]) -> Dict:
        """Convert collected errors into the result dict returned to callers."""
        if not errors:
            return {"valid": True}
        return {
            "valid": False,
            "errors": [
                {"field": e.field, "message": e.message, "code": e.code}
                for e in errors
            ]
        }

    def _validate_model(self, model: str) -> Optional[ValidationError]:
        """Return an error unless ``model`` is in the configured allow-list."""
        allowed_models = self.config.get("allowed_models", self.DEFAULT_ALLOWED_MODELS)

        if model not in allowed_models:
            return ValidationError(
                "model",
                f"Model '{model}' is not available",
                "invalid_model"
            )

        return None

    def _validate_messages(self, messages: List[Dict]) -> List[ValidationError]:
        """Validate the structure, roles, length and content of chat messages."""
        if not isinstance(messages, list):
            return [ValidationError("messages", "Messages must be an array", "invalid_type")]

        errors = []

        if len(messages) > self.MAX_MESSAGES:
            errors.append(ValidationError(
                "messages",
                f"Maximum {self.MAX_MESSAGES} messages allowed",
                "too_many_messages"
            ))

        for i, message in enumerate(messages):
            if not isinstance(message, dict):
                errors.append(ValidationError(
                    f"messages[{i}]",
                    "Message must be an object",
                    "invalid_type"
                ))
                continue

            has_role = "role" in message
            if not has_role:
                errors.append(ValidationError(
                    f"messages[{i}].role",
                    "Role is required",
                    "missing_field"
                ))

            if "content" not in message:
                errors.append(ValidationError(
                    f"messages[{i}].content",
                    "Content is required",
                    "missing_field"
                ))
                continue

            # A missing role was already reported above; don't report it twice.
            if has_role and message["role"] not in self.VALID_ROLES:
                errors.append(ValidationError(
                    f"messages[{i}].role",
                    f"Invalid role. Must be one of: {self.VALID_ROLES}",
                    "invalid_role"
                ))

            content = message["content"]
            if not isinstance(content, str):
                errors.append(ValidationError(
                    f"messages[{i}].content",
                    "Content must be a string",
                    "invalid_type"
                ))
                continue

            if len(content) > self.MAX_PROMPT_LENGTH:
                errors.append(ValidationError(
                    f"messages[{i}].content",
                    f"Content exceeds maximum length of {self.MAX_PROMPT_LENGTH}",
                    "content_too_long"
                ))

            if self._check_blocked_patterns(content):
                errors.append(ValidationError(
                    f"messages[{i}].content",
                    "Content contains blocked patterns",
                    "blocked_content"
                ))

        return errors

    def _validate_prompt(self, prompt: str) -> List[ValidationError]:
        """Validate a plain prompt string: type, length and blocked patterns."""
        if not isinstance(prompt, str):
            return [ValidationError("prompt", "Prompt must be a string", "invalid_type")]

        errors = []

        if len(prompt) > self.MAX_PROMPT_LENGTH:
            errors.append(ValidationError(
                "prompt",
                f"Prompt exceeds maximum length of {self.MAX_PROMPT_LENGTH}",
                "content_too_long"
            ))

        if self._check_blocked_patterns(prompt):
            errors.append(ValidationError(
                "prompt",
                "Prompt contains blocked patterns",
                "blocked_content"
            ))

        return errors

    def _check_blocked_patterns(self, content: str) -> bool:
        """Return True if any blocked pattern occurs anywhere in ``content``."""
        return any(pattern.search(content) for pattern in self.blocked_patterns)

Security Middleware

Complete API Security Pipeline

# security_middleware.py
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.base import BaseHTTPMiddleware
import time
 
class AISecurityMiddleware(BaseHTTPMiddleware):
    """Authenticate, rate-limit and validate every AI API request.

    Per request (``/health`` and ``/ready`` excluded): API-key
    authentication, rate-limit check, recording of the admitted request,
    concurrent-slot accounting, POST body validation, then the downstream
    handler.

    Errors are returned as JSON responses rather than raised. An
    ``HTTPException`` raised from a ``BaseHTTPMiddleware`` does not go
    through FastAPI's exception handlers and would reach the client as a 500.

    NOTE(review): ``BaseHTTPMiddleware`` is defined in
    ``starlette.middleware.base``; FastAPI does not appear to ship a
    ``fastapi.middleware.base`` module. Verify the import used above.
    """

    def __init__(self, app: FastAPI, config: Dict):
        super().__init__(app)
        self.api_key_manager = APIKeyManager(config["key_store"])
        self.auth_manager = AuthorizationManager()
        self.rate_limiter = AIRateLimiter(config["redis"])
        self.request_validator = AIRequestValidator(config)

    async def dispatch(self, request: Request, call_next):
        from fastapi.responses import JSONResponse
        import logging

        start_time = time.time()

        # Skip health check endpoints
        if request.url.path in ["/health", "/ready"]:
            return await call_next(request)

        try:
            # 1. Authenticate
            auth_result = await self._authenticate(request)
            if not auth_result["valid"]:
                raise HTTPException(status_code=401, detail=auth_result["error"])

            organization_id = auth_result["organization_id"]
            request.state.organization_id = organization_id
            request.state.user_scopes = auth_result["scopes"]
            request.state.rate_limit_tier = auth_result["rate_limit_tier"]

            # 2. Check rate limits
            estimated_tokens = self._estimate_tokens(request)
            rate_result = await self.rate_limiter.check_rate_limit(
                organization_id=organization_id,
                tier_name=auth_result["rate_limit_tier"],
                estimated_tokens=estimated_tokens
            )

            if not rate_result["allowed"]:
                raise HTTPException(
                    status_code=429,
                    detail=rate_result["message"],
                    headers={
                        "Retry-After": str(rate_result["retry_after"]),
                        "X-RateLimit-Limit": str(rate_result.get("limit", "")),
                        "X-RateLimit-Reset": str(rate_result.get("retry_after", ""))
                    }
                )

            # Count the admitted request. Without this, the counters read by
            # check_rate_limit never grow and no limit is ever hit.
            await self.rate_limiter.record_request(organization_id, estimated_tokens)

            # 3. Acquire concurrent slot
            await self.rate_limiter.acquire_concurrent_slot(organization_id)

            try:
                # 4. Validate request
                if request.method == "POST":
                    # NOTE(review): whether the handler can re-read a body
                    # consumed here depends on the installed Starlette
                    # version. Confirm.
                    try:
                        body = await request.json()
                    except ValueError:
                        raise HTTPException(
                            status_code=400,
                            detail="Request body must be valid JSON"
                        ) from None
                    validation = self.request_validator.validate_completion_request(body)
                    if not validation["valid"]:
                        raise HTTPException(
                            status_code=400,
                            detail={"errors": validation["errors"]}
                        )

                # 5. Process request
                response = await call_next(request)

                # 6. Record usage
                duration = time.time() - start_time
                await self._record_usage(request, response, duration)

                response.headers["X-RateLimit-Remaining"] = str(
                    rate_result.get("remaining", {}).get("requests_minute", "")
                )

                return response

            finally:
                await self.rate_limiter.release_concurrent_slot(organization_id)

        except HTTPException as exc:
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=getattr(exc, "headers", None)
            )
        except Exception:
            logging.getLogger(__name__).exception("Unhandled error in AI security middleware")
            return JSONResponse(status_code=500, content={"detail": "Internal server error"})

    async def _authenticate(self, request: Request) -> Dict:
        """Validate the ``Authorization: Bearer <api key>`` header.

        Returns:
            The ``APIKeyManager.validate_key`` result, or
            ``{"valid": False, "error": ...}`` for a missing or non-Bearer header.
        """
        auth_header = request.headers.get("Authorization")

        if not auth_header:
            return {"valid": False, "error": "Missing authorization header"}

        if auth_header.startswith("Bearer "):
            api_key = auth_header[len("Bearer "):]
            return self.api_key_manager.validate_key(api_key)

        return {"valid": False, "error": "Invalid authorization format"}

    def _estimate_tokens(self, request: Request) -> int:
        """Return the token estimate used for rate limiting.

        Currently a fixed placeholder of 1000 tokens per request; the request
        contents are not inspected.
        """
        return 1000

    async def _record_usage(self, request: Request, response, duration: float):
        """Usage/billing tracking hook (placeholder; currently a no-op).

        Rate-limit counters are recorded separately in ``dispatch``.
        """
        pass

Conclusion

Securing AI APIs requires comprehensive measures:

  1. Strong Authentication - API keys with rotation and JWT tokens
  2. Fine-Grained Authorization - Scope-based access control
  3. Intelligent Rate Limiting - Multi-tier limits with burst handling
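  4. Request Validation - Content policy enforcement and input validation (pattern-based prompt filters are a coarse first line of defense, not a complete prompt-injection defense)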
  5. Complete Pipeline - Middleware integrating all security layers

By implementing these security measures, you can protect your AI APIs from abuse while providing a reliable service to legitimate users.

Weekly AI Security & Automation Digest

Get the latest on AI Security, workflow automation, secure integrations, and custom platform development delivered weekly.

No spam. Unsubscribe anytime.