AI Security

Federated Learning Security and Privacy Implementation

DeviDevs Team
12 min read
#federated-learning #ai-security #differential-privacy #secure-aggregation #privacy

Federated learning enables collaborative model training while keeping data decentralized. This guide covers security and privacy implementations for building robust federated learning systems.

Federated Learning Framework

Build a secure federated learning system:

from dataclasses import dataclass
from typing import List, Dict
from enum import Enum
import numpy as np
import hashlib
 
class AggregationStrategy(Enum):
    FEDAVG = "federated_averaging"
    FEDPROX = "federated_proximal"
    SCAFFOLD = "scaffold"
 
@dataclass
class ModelUpdate:
    client_id: str
    round_number: int
    weights: Dict[str, np.ndarray]
    num_samples: int
    metrics: Dict[str, float]
    signature: str
 
@dataclass
class GlobalModel:
    version: int
    weights: Dict[str, np.ndarray]
    round_number: int
    participating_clients: List[str]
 
class FederatedLearningServer:
    def __init__(
        self,
        initial_model: Dict[str, np.ndarray],
        aggregation_strategy: AggregationStrategy = AggregationStrategy.FEDAVG,
        min_clients: int = 3,
        client_fraction: float = 0.5
    ):
        self.global_model = GlobalModel(
            version=0,
            weights=initial_model,
            round_number=0,
            participating_clients=[]
        )
        self.aggregation_strategy = aggregation_strategy
        self.min_clients = min_clients
        self.client_fraction = client_fraction
        self.client_updates: List[ModelUpdate] = []
        self.registered_clients: Dict[str, Dict] = {}
 
    def register_client(self, client_id: str, public_key: str, metadata: Dict) -> bool:
        """Register a new client with verification."""
        if client_id in self.registered_clients:
            return False
 
        self.registered_clients[client_id] = {
            'public_key': public_key,
            'metadata': metadata,
            'rounds_participated': 0,
            'last_update': None,
            'reputation_score': 1.0
        }
        return True
 
    def select_clients(self, round_number: int) -> List[str]:
        """Select clients for the current round, weighted by reputation."""
        eligible_clients = [
            cid for cid, info in self.registered_clients.items()
            if info['reputation_score'] > 0.5
        ]

        if not eligible_clients:
            return []

        num_to_select = max(
            self.min_clients,
            int(len(eligible_clients) * self.client_fraction)
        )

        # Random selection with reputation weighting
        weights = np.array([
            self.registered_clients[c]['reputation_score']
            for c in eligible_clients
        ])
        weights = weights / weights.sum()
 
        selected = np.random.choice(
            eligible_clients,
            size=min(num_to_select, len(eligible_clients)),
            replace=False,
            p=weights
        )
 
        return list(selected)
 
    def receive_update(self, update: ModelUpdate) -> bool:
        """Receive and validate client update."""
        if update.client_id not in self.registered_clients:
            return False
 
        if update.round_number != self.global_model.round_number:
            return False
 
        # Verify signature
        if not self._verify_signature(update):
            return False
 
        # Validate update (check for anomalies)
        if not self._validate_update(update):
            self._penalize_client(update.client_id, 0.1)
            return False
 
        self.client_updates.append(update)
        return True
 
    def _verify_signature(self, update: ModelUpdate) -> bool:
        """Verify update integrity.

        Note: this recomputes a SHA-256 digest over the update metadata as a
        simplified stand-in. A production system should verify an asymmetric
        signature (e.g. Ed25519) over the full update, including the weights,
        using the client's registered public key.
        """
        client_info = self.registered_clients.get(update.client_id)
        if not client_info:
            return False

        data = f"{update.client_id}{update.round_number}{update.num_samples}"
        expected = hashlib.sha256(data.encode()).hexdigest()

        return update.signature == expected
 
    def _validate_update(self, update: ModelUpdate) -> bool:
        """Validate update for anomalies (model poisoning detection)."""
        if update.num_samples < 1:
            return False
 
        # Check for extreme weight values
        for layer_name, weights in update.weights.items():
            if np.any(np.isnan(weights)) or np.any(np.isinf(weights)):
                return False
 
            # Check for statistical anomalies
            global_weights = self.global_model.weights.get(layer_name)
            if global_weights is not None:
                diff = np.abs(weights - global_weights)
                if np.mean(diff) > 10 * np.std(global_weights):
                    return False
 
        return True
 
    def _penalize_client(self, client_id: str, penalty: float):
        """Reduce client reputation score."""
        if client_id in self.registered_clients:
            current = self.registered_clients[client_id]['reputation_score']
            self.registered_clients[client_id]['reputation_score'] = max(0, current - penalty)
 
    def aggregate(self) -> GlobalModel:
        """Aggregate client updates into new global model."""
        if len(self.client_updates) < self.min_clients:
            raise ValueError(f"Not enough updates: {len(self.client_updates)} < {self.min_clients}")
 
        if self.aggregation_strategy == AggregationStrategy.FEDAVG:
            new_weights = self._federated_averaging()
        elif self.aggregation_strategy == AggregationStrategy.FEDPROX:
            new_weights = self._federated_proximal()
        else:
            new_weights = self._federated_averaging()
 
        participating = [u.client_id for u in self.client_updates]
 
        self.global_model = GlobalModel(
            version=self.global_model.version + 1,
            weights=new_weights,
            round_number=self.global_model.round_number + 1,
            participating_clients=participating
        )
 
        # Update client stats
        for client_id in participating:
            self.registered_clients[client_id]['rounds_participated'] += 1
 
        # Clear updates for next round
        self.client_updates = []
 
        return self.global_model
 
    def _federated_averaging(self) -> Dict[str, np.ndarray]:
        """FedAvg aggregation."""
        total_samples = sum(u.num_samples for u in self.client_updates)
 
        new_weights = {}
        for layer_name in self.global_model.weights.keys():
            weighted_sum = np.zeros_like(self.global_model.weights[layer_name])
 
            for update in self.client_updates:
                weight = update.num_samples / total_samples
                weighted_sum += weight * update.weights[layer_name]
 
            new_weights[layer_name] = weighted_sum
 
        return new_weights
 
    def _federated_proximal(self, mu: float = 0.01) -> Dict[str, np.ndarray]:
        """FedProx aggregation.

        In FedProx the proximal term mu * ||w - w_global||^2 is applied during
        client-side local training; the server-side aggregation is the same
        weighted average as FedAvg.
        """
        return self._federated_averaging()
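
A minimal end-to-end sketch of one round, assuming the classes above. The one-layer model and the make_update helper are illustrative stand-ins (real clients would train locally and sign their updates with a proper key):

import hashlib
import numpy as np

# Toy one-layer model for illustration
initial_weights = {'dense': np.random.randn(4, 2)}
server = FederatedLearningServer(initial_weights, min_clients=3)

def make_update(client_id, round_number, weights, num_samples) -> ModelUpdate:
    # Illustrative: sign the same fields the server's check recomputes
    data = f"{client_id}{round_number}{num_samples}"
    signature = hashlib.sha256(data.encode()).hexdigest()
    return ModelUpdate(client_id, round_number, weights, num_samples, {}, signature)

for cid in ['client-a', 'client-b', 'client-c', 'client-d']:
    server.register_client(cid, public_key='stub-key', metadata={})

for cid in server.select_clients(round_number=0):
    # Simulate local training as a small perturbation of the global weights
    local = {'dense': initial_weights['dense'] + np.random.randn(4, 2) * 0.01}
    server.receive_update(make_update(cid, 0, local, num_samples=100))

new_global = server.aggregate()
print(f"Round {new_global.round_number}: "
      f"{len(new_global.participating_clients)} clients aggregated")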

Differential Privacy Implementation

Add differential privacy guarantees:

from dataclasses import dataclass
from typing import Dict, Tuple
import numpy as np
 
@dataclass
class PrivacyBudget:
    epsilon: float
    delta: float
    consumed_epsilon: float = 0.0
    consumed_delta: float = 0.0
 
class DifferentialPrivacy:
    def __init__(
        self,
        epsilon: float,
        delta: float,
        max_grad_norm: float = 1.0,
        noise_multiplier: float = 1.0
    ):
        self.budget = PrivacyBudget(epsilon=epsilon, delta=delta)
        self.max_grad_norm = max_grad_norm
        self.noise_multiplier = noise_multiplier
 
    def clip_gradients(self, gradients: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Clip gradients to bounded sensitivity."""
        total_norm = 0.0
        for grad in gradients.values():
            total_norm += np.sum(grad ** 2)
        total_norm = np.sqrt(total_norm)
 
        clip_factor = min(1.0, self.max_grad_norm / (total_norm + 1e-6))
 
        clipped = {}
        for name, grad in gradients.items():
            clipped[name] = grad * clip_factor
 
        return clipped
 
    def add_noise(
        self,
        gradients: Dict[str, np.ndarray],
        num_samples: int
    ) -> Dict[str, np.ndarray]:
        """Add calibrated Gaussian noise for DP."""
        sigma = self.noise_multiplier * self.max_grad_norm / num_samples
 
        noisy_gradients = {}
        for name, grad in gradients.items():
            noise = np.random.normal(0, sigma, grad.shape)
            noisy_gradients[name] = grad + noise
 
        return noisy_gradients
 
    def apply_dp(
        self,
        gradients: Dict[str, np.ndarray],
        num_samples: int
    ) -> Tuple[Dict[str, np.ndarray], float]:
        """Apply differential privacy to gradients."""
        # Clip gradients
        clipped = self.clip_gradients(gradients)
 
        # Add noise
        noisy = self.add_noise(clipped, num_samples)
 
        # Compute privacy cost for this step
        epsilon_step = self._compute_privacy_cost(num_samples)
        self.budget.consumed_epsilon += epsilon_step
 
        return noisy, epsilon_step
 
    def _compute_privacy_cost(self, num_samples: int) -> float:
        """Compute the per-step epsilon cost (simplified RDP accounting)."""
        # Simplified subsampling probability
        q = 1.0 / num_samples
        sigma = self.noise_multiplier

        # Approximate RDP of the subsampled Gaussian mechanism at a fixed
        # order alpha, then convert to (epsilon, delta)-DP. Production code
        # should use a full moments accountant (e.g. Opacus or TF Privacy).
        alpha = 2
        rdp = alpha * q ** 2 / (2 * sigma ** 2)

        epsilon = rdp + np.log(1 / self.budget.delta) / (alpha - 1)
        return epsilon
 
    def check_budget(self) -> bool:
        """Check if privacy budget is exhausted."""
        return self.budget.consumed_epsilon < self.budget.epsilon
 
    def get_privacy_spent(self) -> Dict:
        """Get current privacy expenditure."""
        return {
            'epsilon_budget': self.budget.epsilon,
            'epsilon_spent': self.budget.consumed_epsilon,
            'epsilon_remaining': self.budget.epsilon - self.budget.consumed_epsilon,
            'delta': self.budget.delta
        }
 
class LocalDifferentialPrivacy:
    """Client-side local differential privacy."""
 
    def __init__(self, epsilon: float):
        self.epsilon = epsilon
 
    def randomized_response(self, bit: bool) -> bool:
        """Randomized response mechanism for single bit."""
        p = np.exp(self.epsilon) / (1 + np.exp(self.epsilon))
 
        if np.random.random() < p:
            return bit
        else:
            return not bit
 
    def laplace_mechanism(self, value: float, sensitivity: float) -> float:
        """Laplace mechanism for numeric values."""
        scale = sensitivity / self.epsilon
        noise = np.random.laplace(0, scale)
        return value + noise
 
    def privatize_vector(
        self,
        vector: np.ndarray,
        sensitivity: float
    ) -> np.ndarray:
        """Apply LDP to a vector."""
        scale = sensitivity / self.epsilon
        noise = np.random.laplace(0, scale, vector.shape)
        return vector + noise
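
A short usage sketch for the client side, assuming the DifferentialPrivacy class above. The gradient dictionary is synthetic, and note that the simplified accountant is loose: the reported per-step epsilon is dominated by the log(1/delta) conversion term.

import numpy as np

dp = DifferentialPrivacy(epsilon=10.0, delta=1e-5, max_grad_norm=1.0, noise_multiplier=1.1)

# Synthetic gradients for a toy two-layer model
gradients = {
    'dense1': np.random.randn(8, 4) * 0.1,
    'dense2': np.random.randn(4, 1) * 0.1,
}

if dp.check_budget():
    private_grads, epsilon_step = dp.apply_dp(gradients, num_samples=1000)
    print(f"Spent epsilon={epsilon_step:.3f} this round")
    print(dp.get_privacy_spent())
else:
    print("Privacy budget exhausted - stop releasing updates")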

Secure Aggregation Protocol

Implement secure aggregation for privacy:

from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
import numpy as np
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import (
    X25519PrivateKey,
    X25519PublicKey,
)
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
 
@dataclass
class SecretShare:
    client_id: str
    share_index: int
    share_value: np.ndarray
    commitment: bytes
 
class SecureAggregation:
    def __init__(self, threshold: int, num_clients: int):
        self.threshold = threshold  # Minimum clients needed
        self.num_clients = num_clients
        self.client_keys: Dict[str, bytes] = {}
        self.shares: Dict[str, List[SecretShare]] = {}
 
    def generate_client_keys(self, client_id: str) -> Tuple[bytes, bytes]:
        """Generate an X25519 key pair for a client."""
        private_key = X25519PrivateKey.generate()
        private_bytes = private_key.private_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PrivateFormat.Raw,
            encryption_algorithm=serialization.NoEncryption()
        )
        public_bytes = private_key.public_key().public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw
        )

        self.client_keys[client_id] = public_bytes
        return private_bytes, public_bytes
 
    def create_pairwise_masks(
        self,
        client_id: str,
        private_key: bytes,
        other_clients: List[str],
        vector_shape: Tuple
    ) -> Dict[str, np.ndarray]:
        """Create pairwise masks for secure aggregation."""
        masks = {}
 
        for other_id in other_clients:
            if other_id == client_id:
                continue
 
            other_public_key = self.client_keys.get(other_id)
            if not other_public_key:
                continue
 
            # Derive shared secret
            shared_secret = self._derive_shared_secret(private_key, other_public_key)
 
            # Generate deterministic mask from shared secret
            mask = self._generate_mask(shared_secret, vector_shape, client_id, other_id)
            masks[other_id] = mask
 
        return masks
 
    def _derive_shared_secret(self, private_key: bytes, public_key: bytes) -> bytes:
        """Derive a shared secret via X25519 Diffie-Hellman.

        Both parties in a pair derive the same secret; this is what makes
        the pairwise masks cancel during aggregation. (Hashing a
        concatenation of key bytes would not give both sides the same
        value, so a real key exchange is required here.)
        """
        secret = X25519PrivateKey.from_private_bytes(private_key).exchange(
            X25519PublicKey.from_public_bytes(public_key)
        )
        hkdf = HKDF(
            algorithm=hashes.SHA256(),
            length=32,
            salt=None,
            info=b'shared_secret'
        )
        return hkdf.derive(secret)
 
    def _generate_mask(
        self,
        seed: bytes,
        shape: Tuple,
        client_a: str,
        client_b: str
    ) -> np.ndarray:
        """Generate PRG-based mask."""
        # Deterministic ordering
        if client_a < client_b:
            sign = 1
        else:
            sign = -1
 
        # Use seed to initialize RNG
        seed_int = int.from_bytes(seed[:4], 'big')
        rng = np.random.RandomState(seed_int)
 
        mask = rng.standard_normal(shape) * sign
        return mask
 
    def mask_update(
        self,
        update: np.ndarray,
        masks: Dict[str, np.ndarray]
    ) -> np.ndarray:
        """Mask update with pairwise masks."""
        masked = update.copy()
 
        for mask in masks.values():
            masked += mask
 
        return masked
 
    def aggregate_masked_updates(
        self,
        masked_updates: Dict[str, np.ndarray],
        dropout_clients: Optional[List[str]] = None
    ) -> np.ndarray:
        """Aggregate masked updates - the pairwise masks cancel out.

        If clients drop out after masking, their pairwise masks remain in
        the survivors' updates; the full protocol recovers them by
        reconstructing the dropped clients' mask seeds from Shamir shares
        (see shamir_secret_share below).
        """
        if dropout_clients is None:
            dropout_clients = []

        included = [
            update for client_id, update in masked_updates.items()
            if client_id not in dropout_clients
        ]

        aggregated = np.zeros_like(included[0])
        for update in included:
            aggregated += update

        return aggregated / len(included)
 
    def shamir_secret_share(
        self,
        secret: np.ndarray,
        num_shares: int,
        threshold: int
    ) -> List[Tuple[int, np.ndarray]]:
        """Create Shamir secret shares.

        Note: this version operates over floats for readability; a
        production implementation works over a finite field to avoid
        numerical error and information leakage.
        """
        # Random polynomial of degree threshold - 1 with the secret as
        # its constant term
        coefficients = [secret]
        for _ in range(threshold - 1):
            coefficients.append(np.random.randn(*secret.shape))

        # Evaluate the polynomial at x = 1..num_shares
        shares = []
        for i in range(1, num_shares + 1):
            share = np.zeros_like(secret)
            for j, coef in enumerate(coefficients):
                share += coef * (i ** j)
            shares.append((i, share))

        return shares
 
    def reconstruct_secret(
        self,
        shares: List[Tuple[int, np.ndarray]],
        threshold: int
    ) -> np.ndarray:
        """Reconstruct secret from Shamir shares."""
        if len(shares) < threshold:
            raise ValueError(f"Need at least {threshold} shares")
 
        shares = shares[:threshold]
        secret = np.zeros_like(shares[0][1])
 
        for i, (xi, yi) in enumerate(shares):
            # Lagrange basis polynomial
            numerator = 1.0
            denominator = 1.0
 
            for j, (xj, _) in enumerate(shares):
                if i != j:
                    numerator *= (0 - xj)
                    denominator *= (xi - xj)
 
            secret += yi * (numerator / denominator)
 
        return secret
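
A quick sanity check, assuming the SecureAggregation class above: three clients exchange X25519 keys, mask synthetic updates, and the server's aggregate matches the plain average because each pairwise mask enters the sum once with each sign.

import numpy as np

clients = ['alice', 'bob', 'carol']
secagg = SecureAggregation(threshold=2, num_clients=len(clients))

# Key exchange: every client registers a public key with the server
private_keys = {c: secagg.generate_client_keys(c)[0] for c in clients}

shape = (4,)
true_updates = {c: np.random.randn(*shape) for c in clients}

# Each client masks its own update with its pairwise masks
masked = {}
for c in clients:
    masks = secagg.create_pairwise_masks(c, private_keys[c], clients, shape)
    masked[c] = secagg.mask_update(true_updates[c], masks)

aggregated = secagg.aggregate_masked_updates(masked)
plain_mean = np.mean(list(true_updates.values()), axis=0)
assert np.allclose(aggregated, plain_mean), "masks did not cancel"
print("Secure aggregate matches plain average:", aggregated)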

Model Poisoning Defense

Implement defenses against model poisoning attacks:

from dataclasses import dataclass
from typing import List, Dict
import numpy as np
 
@dataclass
class ClientAnalysis:
    client_id: str
    is_malicious: bool
    anomaly_score: float
    detection_method: str
 
class ModelPoisoningDefense:
    def __init__(self, detection_threshold: float = 2.0):
        self.detection_threshold = detection_threshold
        self.client_history: Dict[str, List[np.ndarray]] = {}
 
    def analyze_updates(
        self,
        updates: List[ModelUpdate],
        global_weights: Dict[str, np.ndarray]
    ) -> List[ClientAnalysis]:
        """Analyze updates for potential poisoning."""
        analyses = []
 
        for update in updates:
            # Multiple detection methods
            krum_score = self._krum_score(update, updates)
            cosine_score = self._cosine_similarity_score(update, global_weights)
            magnitude_score = self._magnitude_score(update, global_weights)
            history_score = self._history_consistency_score(update)
 
            # Combined anomaly score
            anomaly_score = (
                0.3 * krum_score +
                0.3 * cosine_score +
                0.2 * magnitude_score +
                0.2 * history_score
            )
 
            is_malicious = anomaly_score > self.detection_threshold
            detection_method = self._get_primary_detector(
                krum_score, cosine_score, magnitude_score, history_score
            )
 
            analyses.append(ClientAnalysis(
                client_id=update.client_id,
                is_malicious=is_malicious,
                anomaly_score=anomaly_score,
                detection_method=detection_method
            ))
 
            # Update history
            update_vector = self._flatten_weights(update.weights)
            if update.client_id not in self.client_history:
                self.client_history[update.client_id] = []
            self.client_history[update.client_id].append(update_vector)
 
        return analyses
 
    def _krum_score(self, target: ModelUpdate, all_updates: List[ModelUpdate]) -> float:
        """Multi-Krum anomaly score."""
        target_vec = self._flatten_weights(target.weights)
 
        distances = []
        for update in all_updates:
            if update.client_id == target.client_id:
                continue
 
            vec = self._flatten_weights(update.weights)
            dist = np.linalg.norm(target_vec - vec)
            distances.append(dist)
 
        if not distances:
            return 0.0
 
        # Score based on sum of nearest neighbors
        distances.sort()
        n_neighbors = min(len(distances), len(all_updates) - 2)
        krum_score = sum(distances[:n_neighbors])
 
        # Normalize
        median_score = np.median([
            sum(sorted([
                np.linalg.norm(self._flatten_weights(u1.weights) - self._flatten_weights(u2.weights))
                for u2 in all_updates if u1.client_id != u2.client_id
            ])[:n_neighbors])
            for u1 in all_updates
        ])
 
        return krum_score / (median_score + 1e-6)
 
    def _cosine_similarity_score(
        self,
        update: ModelUpdate,
        global_weights: Dict[str, np.ndarray]
    ) -> float:
        """Score based on cosine similarity with global model direction."""
        update_vec = self._flatten_weights(update.weights)
        global_vec = self._flatten_weights(global_weights)
 
        # Direction from global to update
        direction = update_vec - global_vec
 
        # Expected direction (simplified - use historical average)
        norm = np.linalg.norm(direction)
        if norm < 1e-6:
            return 0.0
 
        # Low similarity with expected direction = anomaly
        # For simplicity, check if update moves model significantly
        cos_sim = np.dot(direction, global_vec) / (norm * np.linalg.norm(global_vec) + 1e-6)
 
        # Very negative similarity is suspicious
        return max(0, -cos_sim)
 
    def _magnitude_score(
        self,
        update: ModelUpdate,
        global_weights: Dict[str, np.ndarray]
    ) -> float:
        """Score based on update magnitude."""
        update_vec = self._flatten_weights(update.weights)
        global_vec = self._flatten_weights(global_weights)
 
        diff = update_vec - global_vec
        magnitude = np.linalg.norm(diff)
 
        # Z-score based on expected magnitude
        expected_magnitude = np.linalg.norm(global_vec) * 0.1  # Heuristic
        std_magnitude = expected_magnitude * 0.5
 
        z_score = abs(magnitude - expected_magnitude) / (std_magnitude + 1e-6)
        return z_score
 
    def _history_consistency_score(self, update: ModelUpdate) -> float:
        """Score based on consistency with client's history."""
        history = self.client_history.get(update.client_id, [])
 
        if len(history) < 2:
            return 0.0
 
        current = self._flatten_weights(update.weights)
 
        # Check if current update is consistent with history
        historical_changes = []
        for i in range(1, len(history)):
            change = np.linalg.norm(history[i] - history[i-1])
            historical_changes.append(change)
 
        if not historical_changes:
            return 0.0
 
        current_change = np.linalg.norm(current - history[-1])
 
        mean_change = np.mean(historical_changes)
        std_change = np.std(historical_changes) + 1e-6
 
        z_score = abs(current_change - mean_change) / std_change
        return z_score
 
    def _flatten_weights(self, weights: Dict[str, np.ndarray]) -> np.ndarray:
        """Flatten weight dict to single vector."""
        return np.concatenate([w.flatten() for w in weights.values()])
 
    def _get_primary_detector(self, *scores) -> str:
        """Get which detector contributed most to anomaly."""
        methods = ['krum', 'cosine', 'magnitude', 'history']
        max_idx = np.argmax(scores)
        return methods[max_idx]
 
    def robust_aggregation(
        self,
        updates: List[ModelUpdate],
        global_weights: Dict[str, np.ndarray],
        method: str = 'trimmed_mean'
    ) -> Dict[str, np.ndarray]:
        """Perform robust aggregation resistant to poisoning."""
        if method == 'trimmed_mean':
            return self._trimmed_mean_aggregation(updates, trim_ratio=0.1)
        elif method == 'median':
            return self._median_aggregation(updates)
        elif method == 'krum':
            return self._krum_aggregation(updates)
        else:
            raise ValueError(f"Unknown method: {method}")
 
    def _trimmed_mean_aggregation(
        self,
        updates: List[ModelUpdate],
        trim_ratio: float
    ) -> Dict[str, np.ndarray]:
        """Trimmed mean aggregation."""
        n_trim = int(len(updates) * trim_ratio)
 
        aggregated = {}
        for layer_name in updates[0].weights.keys():
            layer_updates = np.array([u.weights[layer_name] for u in updates])
 
            # Sort and trim along client axis
            sorted_updates = np.sort(layer_updates, axis=0)
            if n_trim > 0:
                trimmed = sorted_updates[n_trim:-n_trim]
            else:
                trimmed = sorted_updates
 
            aggregated[layer_name] = np.mean(trimmed, axis=0)
 
        return aggregated
 
    def _median_aggregation(self, updates: List[ModelUpdate]) -> Dict[str, np.ndarray]:
        """Coordinate-wise median aggregation."""
        aggregated = {}
        for layer_name in updates[0].weights.keys():
            layer_updates = np.array([u.weights[layer_name] for u in updates])
            aggregated[layer_name] = np.median(layer_updates, axis=0)
 
        return aggregated
 
    def _krum_aggregation(self, updates: List[ModelUpdate]) -> Dict[str, np.ndarray]:
        """Select update with minimum Krum score."""
        min_score = float('inf')
        best_update = None
 
        for update in updates:
            score = self._krum_score(update, updates)
            if score < min_score:
                min_score = score
                best_update = update
 
        return best_update.weights if best_update else updates[0].weights
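
A sketch of how the defense could slot into a server round, assuming the ModelUpdate dataclass from the first section; the fake_update helper, client names, and the scaled-up "attacker" update are illustrative.

import numpy as np

# Synthetic round: four honest clients plus one that submits a
# heavily scaled (poisoned) update
global_weights = {'dense': np.random.randn(4, 2)}

def fake_update(client_id: str, scale: float) -> ModelUpdate:
    weights = {'dense': global_weights['dense'] + np.random.randn(4, 2) * scale}
    return ModelUpdate(client_id, 0, weights, 100, {}, '')

updates = [fake_update(f'client-{i}', 0.01) for i in range(4)]
updates.append(fake_update('attacker', 5.0))  # large deviation

defense = ModelPoisoningDefense(detection_threshold=2.0)
analyses = defense.analyze_updates(updates, global_weights)

flagged = {a.client_id for a in analyses if a.is_malicious}
for a in analyses:
    print(f"{a.client_id}: score={a.anomaly_score:.2f}, "
          f"malicious={a.is_malicious} ({a.detection_method})")

# Aggregate only the updates that passed screening
clean = [u for u in updates if u.client_id not in flagged]
new_weights = defense.robust_aggregation(clean, global_weights, method='trimmed_mean')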

Conclusion

Secure federated learning requires multiple layers of protection: differential privacy for formal privacy guarantees, secure aggregation so the server never sees individual updates, and robust aggregation with anomaly detection to resist model poisoning. Client verification and reputation systems track trustworthiness over time. Remember that privacy and security in federated learning trade off against model utility; calibrate the parameters (privacy budget, clipping norms, detection thresholds) to your specific requirements.
