AI Security

Securitatea si confidentialitatea in Federated Learning - Implementare

Nicu Constantin
12 min de lectura
#federated-learning#ai-security#differential-privacy#secure-aggregation#privacy

Federated learning permite antrenamentul colaborativ al modelelor mentinand datele descentralizate. Acest ghid acopera implementarile de securitate si confidentialitate pentru construirea de sisteme de federated learning robuste.

Framework de Federated Learning

Construieste un sistem de federated learning securizat:

from dataclasses import dataclass
from typing import List, Dict, Optional, Callable
from enum import Enum
import numpy as np
import hashlib
from abc import ABC, abstractmethod
 
class AggregationStrategy(Enum):
    """Server-side strategies for combining client updates."""
    FEDAVG = "federated_averaging"   # sample-count-weighted averaging
    FEDPROX = "federated_proximal"   # clients add a proximal term; server still averages
    SCAFFOLD = "scaffold"            # currently aggregated like FedAvg (see aggregate())
 
@dataclass
class ModelUpdate:
    """One client's contribution for a single training round."""
    client_id: str                   # must match a registered client
    round_number: int                # must equal the server's current round
    weights: Dict[str, np.ndarray]   # layer name -> locally trained weights
    num_samples: int                 # local sample count; FedAvg weighting factor
    metrics: Dict[str, float]        # client-reported training metrics
    signature: str                   # integrity digest checked by the server
 
@dataclass
class GlobalModel:
    """Aggregated model state published after each round."""
    version: int                     # incremented on every aggregation
    weights: Dict[str, np.ndarray]   # layer name -> aggregated weights
    round_number: int                # round this model was produced in
    participating_clients: List[str] # clients whose updates were averaged
 
class FederatedLearningServer:
    """Coordinates federated rounds: registration, client selection,
    update validation and aggregation.

    NOTE(review): _verify_signature checks a SHA-256 digest over public
    fields only -- it is a stand-in, not a cryptographic signature.
    Replace with real public-key signatures before production use.
    """

    def __init__(
        self,
        initial_model: Dict[str, np.ndarray],
        aggregation_strategy: AggregationStrategy = AggregationStrategy.FEDAVG,
        min_clients: int = 3,
        client_fraction: float = 0.5
    ):
        """
        Args:
            initial_model: layer name -> initial weight arrays (round 0).
            aggregation_strategy: how buffered updates are combined.
            min_clients: minimum buffered updates required to aggregate.
            client_fraction: fraction of eligible clients picked per round.
        """
        self.global_model = GlobalModel(
            version=0,
            weights=initial_model,
            round_number=0,
            participating_clients=[]
        )
        self.aggregation_strategy = aggregation_strategy
        self.min_clients = min_clients
        self.client_fraction = client_fraction
        self.client_updates: List[ModelUpdate] = []
        self.registered_clients: Dict[str, Dict] = {}

    def register_client(self, client_id: str, public_key: str, metadata: Dict) -> bool:
        """Register a new client; returns False when the id is already taken."""
        if client_id in self.registered_clients:
            return False

        self.registered_clients[client_id] = {
            'public_key': public_key,
            'metadata': metadata,
            'rounds_participated': 0,
            'last_update': None,
            'reputation_score': 1.0  # starts fully trusted; lowered on bad updates
        }
        return True

    def select_clients(self, round_number: int) -> List[str]:
        """Select clients for the round, weighted by reputation score.

        Returns an empty list when no client passes the reputation bar
        (fix: previously this crashed on the empty eligible set).
        """
        eligible_clients = [
            cid for cid, info in self.registered_clients.items()
            if info['reputation_score'] > 0.5
        ]
        if not eligible_clients:
            return []

        num_to_select = max(
            self.min_clients,
            int(len(eligible_clients) * self.client_fraction)
        )

        # Reputation-weighted random selection without replacement.
        weights = np.array([
            self.registered_clients[c]['reputation_score']
            for c in eligible_clients
        ], dtype=float)
        weights = weights / weights.sum()

        selected = np.random.choice(
            eligible_clients,
            size=min(num_to_select, len(eligible_clients)),
            replace=False,
            p=weights
        )
        return list(selected)

    def receive_update(self, update: ModelUpdate) -> bool:
        """Validate and buffer a client update for the current round."""
        if update.client_id not in self.registered_clients:
            return False

        if update.round_number != self.global_model.round_number:
            return False

        # Fix: reject duplicates -- a client must not be averaged twice per round.
        if any(u.client_id == update.client_id for u in self.client_updates):
            return False

        # Check the integrity digest.
        if not self._verify_signature(update):
            return False

        # Screen for anomalies (model poisoning detection).
        if not self._validate_update(update):
            self._penalize_client(update.client_id, 0.1)
            return False

        self.client_updates.append(update)
        return True

    def _verify_signature(self, update: ModelUpdate) -> bool:
        """Check the update's integrity digest.

        NOTE(review): this recomputes a SHA-256 over public fields, so any
        party can forge it; it detects corruption, not tampering. A real
        deployment must verify a signature made with the client's key
        (the registered public_key is currently unused here).
        """
        client_info = self.registered_clients.get(update.client_id)
        if not client_info:
            return False

        data = f"{update.client_id}{update.round_number}{update.num_samples}"
        expected = hashlib.sha256(data.encode()).hexdigest()

        return update.signature == expected

    def _validate_update(self, update: ModelUpdate) -> bool:
        """Validate the update for anomalies (model poisoning detection)."""
        if update.num_samples < 1:
            return False

        for layer_name, weights in update.weights.items():
            # Reject non-finite values outright.
            if np.any(np.isnan(weights)) or np.any(np.isinf(weights)):
                return False

            # Statistical anomaly check against the current global layer.
            global_weights = self.global_model.weights.get(layer_name)
            if global_weights is not None:
                spread = np.std(global_weights)
                # Fix: skip the check for (near-)constant layers, e.g. a
                # freshly zero-initialized model, where std == 0 used to
                # reject every legitimate update.
                if spread > 1e-12:
                    diff = np.abs(weights - global_weights)
                    if np.mean(diff) > 10 * spread:
                        return False

        return True

    def _penalize_client(self, client_id: str, penalty: float):
        """Lower a client's reputation score, clamped at zero."""
        if client_id in self.registered_clients:
            current = self.registered_clients[client_id]['reputation_score']
            self.registered_clients[client_id]['reputation_score'] = max(0, current - penalty)

    def aggregate(self) -> GlobalModel:
        """Combine the buffered client updates into a new global model.

        Raises:
            ValueError: when fewer than min_clients updates are buffered.
        """
        if len(self.client_updates) < self.min_clients:
            raise ValueError(f"Not enough updates: {len(self.client_updates)} < {self.min_clients}")

        if self.aggregation_strategy == AggregationStrategy.FEDAVG:
            new_weights = self._federated_averaging()
        elif self.aggregation_strategy == AggregationStrategy.FEDPROX:
            new_weights = self._federated_proximal()
        else:
            # SCAFFOLD is not implemented separately; fall back to FedAvg.
            new_weights = self._federated_averaging()

        participating = [u.client_id for u in self.client_updates]

        self.global_model = GlobalModel(
            version=self.global_model.version + 1,
            weights=new_weights,
            round_number=self.global_model.round_number + 1,
            participating_clients=participating
        )

        # Update per-client participation statistics.
        for client_id in participating:
            self.registered_clients[client_id]['rounds_participated'] += 1

        # Clear the buffer for the next round.
        self.client_updates = []

        return self.global_model

    def _federated_averaging(self) -> Dict[str, np.ndarray]:
        """FedAvg: average layer-wise, weighted by each client's sample count."""
        total_samples = sum(u.num_samples for u in self.client_updates)

        new_weights = {}
        for layer_name in self.global_model.weights.keys():
            weighted_sum = np.zeros_like(self.global_model.weights[layer_name])

            for update in self.client_updates:
                weight = update.num_samples / total_samples
                weighted_sum += weight * update.weights[layer_name]

            new_weights[layer_name] = weighted_sum

        return new_weights

    def _federated_proximal(self, mu: float = 0.01) -> Dict[str, np.ndarray]:
        """FedProx aggregation.

        Server-side it equals FedAvg; the proximal term (strength mu) is
        applied by the clients during local training.
        """
        return self._federated_averaging()

Implementarea Differential Privacy

Adauga garantii de differential privacy:

from dataclasses import dataclass
from typing import Tuple
import numpy as np
 
@dataclass
class PrivacyBudget:
    # Target (epsilon, delta)-DP guarantee for the whole training run.
    epsilon: float
    delta: float
    # Running totals of privacy spent so far (accountant-style tracking).
    consumed_epsilon: float = 0.0
    consumed_delta: float = 0.0
 
class DifferentialPrivacy:
    """Central differential privacy for gradient updates.

    Gradients are clipped to a bounded global L2 norm, then perturbed with
    Gaussian noise calibrated to that bound; every application consumes a
    slice of the (epsilon, delta) budget.
    """

    def __init__(
        self,
        epsilon: float,
        delta: float,
        max_grad_norm: float = 1.0,
        noise_multiplier: float = 1.0
    ):
        # Target guarantee plus running consumption.
        self.budget = PrivacyBudget(epsilon=epsilon, delta=delta)
        # Clipping bound (the sensitivity of one release).
        self.max_grad_norm = max_grad_norm
        # Noise scale relative to the clipping bound.
        self.noise_multiplier = noise_multiplier

    def clip_gradients(self, gradients: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Scale the gradient dict so its global L2 norm is at most max_grad_norm."""
        squared_total = sum(np.sum(g ** 2) for g in gradients.values())
        total_norm = np.sqrt(squared_total)

        clip_factor = min(1.0, self.max_grad_norm / (total_norm + 1e-6))

        return {name: g * clip_factor for name, g in gradients.items()}

    def add_noise(
        self,
        gradients: Dict[str, np.ndarray],
        num_samples: int
    ) -> Dict[str, np.ndarray]:
        """Add calibrated Gaussian noise; sigma shrinks with the batch size."""
        sigma = self.noise_multiplier * self.max_grad_norm / num_samples

        return {
            name: g + np.random.normal(0, sigma, g.shape)
            for name, g in gradients.items()
        }

    def apply_dp(
        self,
        gradients: Dict[str, np.ndarray],
        num_samples: int
    ) -> Tuple[Dict[str, np.ndarray], float]:
        """Clip, add noise, and charge the privacy cost of this release.

        Returns:
            The noisy gradients and the epsilon spent by this step.
        """
        clipped = self.clip_gradients(gradients)
        noisy = self.add_noise(clipped, num_samples)

        # Account for the privacy spent by this release.
        epsilon_step = self._compute_privacy_cost(num_samples)
        self.budget.consumed_epsilon += epsilon_step

        return noisy, epsilon_step

    def _compute_privacy_cost(self, num_samples: int) -> float:
        """Simplified RDP-based epsilon cost for one noisy release."""
        q = 1.0 / num_samples  # sampling probability
        sigma = self.noise_multiplier

        alpha = 2  # fixed Renyi order for this simplified accountant
        rdp = alpha * q ** 2 / (2 * sigma ** 2)

        # Convert RDP at order alpha to an (epsilon, delta)-DP bound.
        epsilon = rdp + np.log(1 / self.budget.delta) / (alpha - 1)
        return epsilon

    def check_budget(self) -> bool:
        """Return True while some privacy budget remains."""
        return self.budget.consumed_epsilon < self.budget.epsilon

    def get_privacy_spent(self) -> Dict:
        """Report the budget, the amount consumed, and what is left."""
        return {
            'epsilon_budget': self.budget.epsilon,
            'epsilon_spent': self.budget.consumed_epsilon,
            'epsilon_remaining': self.budget.epsilon - self.budget.consumed_epsilon,
            'delta': self.budget.delta
        }
 
class LocalDifferentialPrivacy:
    """Client-side (local) differential privacy primitives.

    Noise is injected before data leaves the device, so the server never
    observes raw values.
    """

    def __init__(self, epsilon: float):
        # Privacy parameter shared by every mechanism below.
        self.epsilon = epsilon

    def randomized_response(self, bit: bool) -> bool:
        """Report the true bit with probability e^eps / (1 + e^eps), else flip it."""
        keep_probability = np.exp(self.epsilon) / (1 + np.exp(self.epsilon))
        flip = np.random.random() >= keep_probability
        return (not bit) if flip else bit

    def laplace_mechanism(self, value: float, sensitivity: float) -> float:
        """Laplace mechanism for a scalar with the given L1 sensitivity."""
        return value + np.random.laplace(0, sensitivity / self.epsilon)

    def privatize_vector(
        self,
        vector: np.ndarray,
        sensitivity: float
    ) -> np.ndarray:
        """Apply element-wise Laplace noise to a whole vector."""
        scale = sensitivity / self.epsilon
        return vector + np.random.laplace(0, scale, vector.shape)

Protocol de agregare securizata

Implementeaza agregare securizata pentru confidentialitate:

from dataclasses import dataclass
from typing import List, Dict, Tuple
import numpy as np
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
import secrets
 
@dataclass
class SecretShare:
    """One Shamir share of a client's masking secret."""
    client_id: str           # owner of the underlying secret
    share_index: int         # x-coordinate where the sharing polynomial is evaluated
    share_value: np.ndarray  # polynomial value at share_index
    commitment: bytes        # commitment bytes -- not verified anywhere in this snippet
 
class SecureAggregation:
    """Pairwise-masking secure aggregation (Bonawitz-style).

    Every pair of clients derives the *same* shared secret, expands it into
    a pseudo-random mask, and the two clients apply the mask with opposite
    signs; the masks cancel in the server's sum, hiding individual updates.

    FIX(review): the previous derivation hashed private_key + public_key,
    which is not symmetric between the two members of a pair, so the masks
    could never cancel. Keys are now genuine X25519 key pairs and the
    shared secret comes from a Diffie-Hellman exchange.
    """

    def __init__(self, threshold: int, num_clients: int):
        self.threshold = threshold  # minimum clients needed for reconstruction
        self.num_clients = num_clients
        self.client_keys: Dict[str, bytes] = {}          # client_id -> raw public key
        self.shares: Dict[str, List[SecretShare]] = {}   # client_id -> received shares

    def generate_client_keys(self, client_id: str) -> Tuple[bytes, bytes]:
        """Generate an X25519 key pair; registers and returns (private, public) raw bytes."""
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey

        key = X25519PrivateKey.generate()
        private_key = key.private_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PrivateFormat.Raw,
            encryption_algorithm=serialization.NoEncryption(),
        )
        public_key = self._derive_public_key(private_key)

        self.client_keys[client_id] = public_key
        return private_key, public_key

    def _derive_public_key(self, private_key: bytes) -> bytes:
        """Derive the raw X25519 public key from raw private-key bytes."""
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey

        public = X25519PrivateKey.from_private_bytes(private_key).public_key()
        return public.public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )

    def create_pairwise_masks(
        self,
        client_id: str,
        private_key: bytes,
        other_clients: List[str],
        vector_shape: Tuple
    ) -> Dict[str, np.ndarray]:
        """Create one mask per registered peer for secure aggregation."""
        masks = {}

        for other_id in other_clients:
            if other_id == client_id:
                continue  # no self-mask

            other_public_key = self.client_keys.get(other_id)
            if not other_public_key:
                continue  # peer has not registered a key yet

            # Both members of the pair derive the identical shared secret ...
            shared_secret = self._derive_shared_secret(private_key, other_public_key)

            # ... and therefore the same mask (opposite signs, see _generate_mask).
            masks[other_id] = self._generate_mask(
                shared_secret, vector_shape, client_id, other_id
            )

        return masks

    def _derive_shared_secret(self, private_key: bytes, public_key: bytes) -> bytes:
        """X25519 ECDH exchange followed by HKDF; symmetric for both peers."""
        from cryptography.hazmat.primitives.asymmetric.x25519 import (
            X25519PrivateKey,
            X25519PublicKey,
        )

        raw = X25519PrivateKey.from_private_bytes(private_key).exchange(
            X25519PublicKey.from_public_bytes(public_key)
        )
        hkdf = HKDF(
            algorithm=hashes.SHA256(),
            length=32,
            salt=None,
            info=b'shared_secret'
        )
        return hkdf.derive(raw)

    def _generate_mask(
        self,
        seed: bytes,
        shape: Tuple,
        client_a: str,
        client_b: str
    ) -> np.ndarray:
        """PRG-expand a shared seed into a mask; pair members get opposite signs."""
        # Deterministic ordering decides who adds and who subtracts the mask,
        # so the two contributions cancel in the aggregate.
        sign = 1 if client_a < client_b else -1

        # NOTE(review): only 4 seed bytes are used because RandomState takes
        # a 32-bit seed; acceptable for a demo but it narrows the seed space.
        rng = np.random.RandomState(int.from_bytes(seed[:4], 'big'))

        return rng.standard_normal(shape) * sign

    def mask_update(
        self,
        update: np.ndarray,
        masks: Dict[str, np.ndarray]
    ) -> np.ndarray:
        """Add every pairwise mask onto a copy of the update."""
        masked = update.copy()
        for mask in masks.values():
            masked += mask
        return masked

    def aggregate_masked_updates(
        self,
        masked_updates: Dict[str, np.ndarray],
        dropout_clients: List[str] = None
    ) -> np.ndarray:
        """Average the surviving masked updates; pairwise masks cancel in the sum.

        FIX(review): the mean is taken over the *included* clients only;
        previously dropouts were excluded from the sum but still counted in
        the denominator, biasing the result toward zero.

        NOTE(review): masks that surviving clients added for a dropped peer
        do not cancel here; full protocols recover them via the Shamir
        shares below.
        """
        if dropout_clients is None:
            dropout_clients = []

        included = [cid for cid in masked_updates if cid not in dropout_clients]

        aggregated = np.zeros_like(next(iter(masked_updates.values())))
        for cid in included:
            aggregated += masked_updates[cid]

        return aggregated / max(len(included), 1)

    def shamir_secret_share(
        self,
        secret: np.ndarray,
        num_shares: int,
        threshold: int
    ) -> List[Tuple[int, np.ndarray]]:
        """Split `secret` into `num_shares` shares; any `threshold` recover it.

        Works over the reals instead of a finite field, so it is numerically
        approximate and not information-theoretically hiding -- demo only.
        """
        # Random polynomial of degree threshold-1 with the secret as constant term.
        coefficients = [secret]
        for _ in range(threshold - 1):
            coefficients.append(np.random.randn(*secret.shape))

        shares = []
        for i in range(1, num_shares + 1):
            share = np.zeros_like(secret)
            for j, coef in enumerate(coefficients):
                share += coef * (i ** j)  # evaluate the polynomial at x = i
            shares.append((i, share))

        return shares

    def reconstruct_secret(
        self,
        shares: List[Tuple[int, np.ndarray]],
        threshold: int
    ) -> np.ndarray:
        """Recover the secret from >= threshold shares via Lagrange interpolation at x = 0.

        Raises:
            ValueError: when fewer than `threshold` shares are supplied.
        """
        if len(shares) < threshold:
            raise ValueError(f"Need at least {threshold} shares")

        shares = shares[:threshold]
        secret = np.zeros_like(shares[0][1])

        for i, (xi, yi) in enumerate(shares):
            # Lagrange basis polynomial evaluated at x = 0.
            numerator = 1.0
            denominator = 1.0

            for j, (xj, _) in enumerate(shares):
                if i != j:
                    numerator *= (0 - xj)
                    denominator *= (xi - xj)

            secret += yi * (numerator / denominator)

        return secret

Apararea impotriva model poisoning

Implementeaza aparari impotriva atacurilor de model poisoning:

from dataclasses import dataclass
from typing import List, Dict, Tuple
import numpy as np
from scipy import stats
 
@dataclass
class ClientAnalysis:
    """Outcome of poisoning analysis for one client's update."""
    client_id: str
    is_malicious: bool      # True when anomaly_score exceeds the threshold
    anomaly_score: float    # weighted combination of the four detector scores
    detection_method: str   # detector with the highest raw score
 
class ModelPoisoningDefense:
    """Detects and mitigates model poisoning in federated updates.

    Combines four anomaly detectors (Krum distance, cosine direction,
    update magnitude, per-client history consistency) and offers robust
    aggregation rules (trimmed mean, coordinate-wise median, Krum).
    """

    def __init__(self, detection_threshold: float = 2.0):
        self.detection_threshold = detection_threshold
        # client_id -> chronological list of flattened update vectors
        self.client_history: Dict[str, List[np.ndarray]] = {}

    def analyze_updates(
        self,
        updates: List['ModelUpdate'],
        global_weights: Dict[str, np.ndarray]
    ) -> List[ClientAnalysis]:
        """Score every update with all detectors and flag likely poisoners."""
        analyses = []

        for update in updates:
            krum_score = self._krum_score(update, updates)
            cosine_score = self._cosine_similarity_score(update, global_weights)
            magnitude_score = self._magnitude_score(update, global_weights)
            history_score = self._history_consistency_score(update)

            # Weighted combination of the individual detectors.
            anomaly_score = (
                0.3 * krum_score +
                0.3 * cosine_score +
                0.2 * magnitude_score +
                0.2 * history_score
            )

            is_malicious = anomaly_score > self.detection_threshold
            # NOTE(review): the primary detector is picked by raw (unweighted)
            # score; consider weighting contributions for consistency.
            detection_method = self._get_primary_detector(
                krum_score, cosine_score, magnitude_score, history_score
            )

            analyses.append(ClientAnalysis(
                client_id=update.client_id,
                is_malicious=is_malicious,
                anomaly_score=anomaly_score,
                detection_method=detection_method
            ))

            # Record this update for future history-consistency checks.
            self.client_history.setdefault(update.client_id, []).append(
                self._flatten_weights(update.weights)
            )

        return analyses

    def _krum_score(self, target: 'ModelUpdate', all_updates: List['ModelUpdate']) -> float:
        """Multi-Krum anomaly score for `target`, normalised by the cohort median.

        Fix: flattened vectors are computed once per update -- previously
        every pairwise distance re-flattened both updates, an O(n^2)
        blow-up in flatten work for the median normalisation.
        """
        flat = [(u.client_id, self._flatten_weights(u.weights)) for u in all_updates]
        target_vec = self._flatten_weights(target.weights)

        distances = sorted(
            np.linalg.norm(target_vec - vec)
            for cid, vec in flat if cid != target.client_id
        )
        if not distances:
            return 0.0

        # Sum of distances to the closest neighbours.
        n_neighbors = min(len(distances), len(all_updates) - 2)
        krum_score = sum(distances[:n_neighbors])

        # Normalise by the median of every client's own Krum score.
        per_client_scores = []
        for cid_a, vec_a in flat:
            pairwise = sorted(
                np.linalg.norm(vec_a - vec_b)
                for cid_b, vec_b in flat if cid_a != cid_b
            )
            per_client_scores.append(sum(pairwise[:n_neighbors]))
        median_score = np.median(per_client_scores)

        return krum_score / (median_score + 1e-6)

    def _cosine_similarity_score(
        self,
        update: 'ModelUpdate',
        global_weights: Dict[str, np.ndarray]
    ) -> float:
        """Penalise updates pointing strongly opposite to the global weight vector."""
        update_vec = self._flatten_weights(update.weights)
        global_vec = self._flatten_weights(global_weights)

        # The movement this client proposes, relative to the global model.
        direction = update_vec - global_vec

        norm = np.linalg.norm(direction)
        if norm < 1e-6:
            return 0.0  # no movement, nothing suspicious

        cos_sim = np.dot(direction, global_vec) / (norm * np.linalg.norm(global_vec) + 1e-6)

        # Only strongly negative alignment counts as anomalous.
        return max(0, -cos_sim)

    def _magnitude_score(
        self,
        update: 'ModelUpdate',
        global_weights: Dict[str, np.ndarray]
    ) -> float:
        """Z-score of the update magnitude against a heuristic expectation."""
        update_vec = self._flatten_weights(update.weights)
        global_vec = self._flatten_weights(global_weights)

        magnitude = np.linalg.norm(update_vec - global_vec)

        # Heuristic: a benign update moves about 10% of the global norm.
        expected_magnitude = np.linalg.norm(global_vec) * 0.1
        std_magnitude = expected_magnitude * 0.5

        return abs(magnitude - expected_magnitude) / (std_magnitude + 1e-6)

    def _history_consistency_score(self, update: 'ModelUpdate') -> float:
        """Z-score of the current step size against the client's own history."""
        history = self.client_history.get(update.client_id, [])
        if len(history) < 2:
            return 0.0  # not enough history to judge

        current = self._flatten_weights(update.weights)

        # Step sizes between consecutive historical updates.
        step_sizes = [
            np.linalg.norm(history[i] - history[i - 1])
            for i in range(1, len(history))
        ]
        current_step = np.linalg.norm(current - history[-1])

        mean_step = np.mean(step_sizes)
        std_step = np.std(step_sizes) + 1e-6

        return abs(current_step - mean_step) / std_step

    def _flatten_weights(self, weights: Dict[str, np.ndarray]) -> np.ndarray:
        """Flatten a weight dict into a single 1-D vector (dict order)."""
        return np.concatenate([w.flatten() for w in weights.values()])

    def _get_primary_detector(self, *scores) -> str:
        """Name of the detector with the highest raw score."""
        methods = ['krum', 'cosine', 'magnitude', 'history']
        return methods[np.argmax(scores)]

    def robust_aggregation(
        self,
        updates: List['ModelUpdate'],
        global_weights: Dict[str, np.ndarray],
        method: str = 'trimmed_mean'
    ) -> Dict[str, np.ndarray]:
        """Aggregate updates with a poisoning-resistant rule.

        Raises:
            ValueError: for an unknown method, or an empty update list
                (fix: previously an empty list crashed with IndexError).
        """
        if not updates:
            raise ValueError("No updates to aggregate")

        if method == 'trimmed_mean':
            return self._trimmed_mean_aggregation(updates, trim_ratio=0.1)
        if method == 'median':
            return self._median_aggregation(updates)
        if method == 'krum':
            return self._krum_aggregation(updates)
        raise ValueError(f"Unknown method: {method}")

    def _trimmed_mean_aggregation(
        self,
        updates: List['ModelUpdate'],
        trim_ratio: float
    ) -> Dict[str, np.ndarray]:
        """Coordinate-wise trimmed mean across client updates."""
        n_trim = int(len(updates) * trim_ratio)
        # Fix: never trim away every value (previously possible, yielding NaNs).
        n_trim = min(n_trim, (len(updates) - 1) // 2)

        aggregated = {}
        for layer_name in updates[0].weights.keys():
            stacked = np.array([u.weights[layer_name] for u in updates])

            # Sort along the client axis and drop the extremes.
            stacked = np.sort(stacked, axis=0)
            if n_trim > 0:
                stacked = stacked[n_trim:-n_trim]

            aggregated[layer_name] = np.mean(stacked, axis=0)

        return aggregated

    def _median_aggregation(self, updates: List['ModelUpdate']) -> Dict[str, np.ndarray]:
        """Coordinate-wise median across client updates."""
        aggregated = {}
        for layer_name in updates[0].weights.keys():
            stacked = np.array([u.weights[layer_name] for u in updates])
            aggregated[layer_name] = np.median(stacked, axis=0)

        return aggregated

    def _krum_aggregation(self, updates: List['ModelUpdate']) -> Dict[str, np.ndarray]:
        """Return the weights of the update with the lowest Krum score."""
        min_score = float('inf')
        best_update = None

        for update in updates:
            score = self._krum_score(update, updates)
            if score < min_score:
                min_score = score
                best_update = update

        return best_update.weights if best_update else updates[0].weights

Concluzii

Federated learning securizat necesita mai multe straturi de protectie, incluzand differential privacy, agregare securizata si aparari impotriva model poisoning. Implementeaza verificarea clientilor si sisteme de reputatie pentru a urmari increderea. Foloseste differential privacy pentru a oferi garantii formale de confidentialitate. Implementeaza protocoale de agregare securizata pentru a preveni ca serverul sa vada actualizarile individuale. Adauga metode de agregare robusta pentru a te apara impotriva atacurilor de poisoning. Tine minte ca confidentialitatea si securitatea in federated learning implica compromisuri cu utilitatea modelului - calibreaza atent parametrii in functie de cerintele tale specifice.

Ai nevoie de ajutor cu conformitatea EU AI Act sau securitatea AI?

Programeaza o consultatie gratuita de 30 de minute. Fara obligatii.

Programeaza un Apel

Weekly AI Security & Automation Digest

Get the latest on AI Security, workflow automation, secure integrations, and custom platform development delivered weekly.

No spam. Unsubscribe anytime.