AI Model Explainability for Security: Understanding Black Box Decisions
As AI systems increasingly make critical decisions in security contexts - from threat detection to access control - the ability to explain these decisions becomes essential for trust, compliance, and debugging. This guide covers practical techniques for making AI models interpretable and auditable.
Why Explainability Matters for Security
In security applications, explainability serves multiple critical purposes:
- Compliance: Regulations such as GDPR Article 22 restrict solely automated decision-making and are widely interpreted as requiring meaningful explanations of automated decisions
- Debugging: Understanding why a model flagged (or missed) a threat
- Trust: Security teams need confidence in AI recommendations
- Adversarial Defense: Identifying if models are being manipulated
- Audit Trail: Documenting decision rationale for investigations
SHAP (SHapley Additive exPlanations)
Implementing SHAP for Security Models
# shap_security_explanations.py
"""
SHAP-based explanations for security ML models.
"""
import shap
import numpy as np
import pandas as pd
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
import json
@dataclass
class FeatureExplanation:
"""Explanation for a single feature's contribution."""
feature_name: str
feature_value: Any
shap_value: float
contribution_direction: str # 'increases_risk' or 'decreases_risk'
relative_importance: float
@dataclass
class PredictionExplanation:
"""Complete explanation for a prediction."""
prediction: str
confidence: float
base_value: float
feature_explanations: List[FeatureExplanation]
top_contributors: List[str]
summary: str
class SecurityModelExplainer:
"""
Generate human-readable explanations for security model predictions.
"""
def __init__(self, model, feature_names: List[str], class_names: List[str]):
self.model = model
self.feature_names = feature_names
self.class_names = class_names
self.explainer = None
self.background_data = None
def initialize_explainer(
self,
background_data: np.ndarray,
explainer_type: str = "tree"
):
"""
Initialize SHAP explainer with background data.
Args:
background_data: Representative sample for computing expectations
explainer_type: 'tree', 'kernel', 'deep', or 'linear'
"""
self.background_data = background_data
if explainer_type == "tree":
self.explainer = shap.TreeExplainer(self.model)
elif explainer_type == "kernel":
self.explainer = shap.KernelExplainer(
self.model.predict_proba,
shap.sample(background_data, 100)
)
elif explainer_type == "deep":
self.explainer = shap.DeepExplainer(self.model, background_data)
elif explainer_type == "linear":
self.explainer = shap.LinearExplainer(self.model, background_data)
else:
raise ValueError(f"Unknown explainer type: {explainer_type}")
def explain_prediction(
self,
instance: np.ndarray,
top_k: int = 5
) -> PredictionExplanation:
"""
Generate detailed explanation for a single prediction.
Args:
instance: Single instance to explain (1D array)
top_k: Number of top contributing features to highlight
"""
if self.explainer is None:
raise RuntimeError("Explainer not initialized. Call initialize_explainer first.")
# Get prediction
if hasattr(self.model, 'predict_proba'):
proba = self.model.predict_proba(instance.reshape(1, -1))[0]
predicted_class = np.argmax(proba)
confidence = proba[predicted_class]
else:
prediction = self.model.predict(instance.reshape(1, -1))[0]
predicted_class = int(prediction)
confidence = 1.0
        # Get SHAP values (the return format differs across SHAP versions)
        shap_values = self.explainer.shap_values(instance.reshape(1, -1))
        if isinstance(shap_values, list):
            # Older SHAP: list with one array per class
            shap_vals = shap_values[predicted_class][0]
        elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
            # Newer SHAP: array shaped (n_samples, n_features, n_classes)
            shap_vals = shap_values[0, :, predicted_class]
        else:
            shap_vals = shap_values[0]
# Get base value
if hasattr(self.explainer, 'expected_value'):
if isinstance(self.explainer.expected_value, np.ndarray):
base_value = self.explainer.expected_value[predicted_class]
else:
base_value = self.explainer.expected_value
else:
base_value = 0.5
# Create feature explanations
feature_explanations = []
abs_shap = np.abs(shap_vals)
max_abs = np.max(abs_shap) if np.max(abs_shap) > 0 else 1
for i, (name, value, shap_val) in enumerate(
zip(self.feature_names, instance, shap_vals)
):
feature_explanations.append(FeatureExplanation(
feature_name=name,
feature_value=value,
shap_value=float(shap_val),
contribution_direction='increases_risk' if shap_val > 0 else 'decreases_risk',
relative_importance=float(abs_shap[i] / max_abs)
))
# Sort by absolute SHAP value
feature_explanations.sort(key=lambda x: abs(x.shap_value), reverse=True)
# Get top contributors
top_contributors = [
f.feature_name for f in feature_explanations[:top_k]
]
# Generate summary
summary = self._generate_summary(
predicted_class,
confidence,
feature_explanations[:top_k]
)
return PredictionExplanation(
prediction=self.class_names[predicted_class],
confidence=float(confidence),
base_value=float(base_value),
feature_explanations=feature_explanations,
top_contributors=top_contributors,
summary=summary
)
def _generate_summary(
self,
predicted_class: int,
confidence: float,
top_features: List[FeatureExplanation]
) -> str:
"""Generate human-readable summary of the prediction."""
class_name = self.class_names[predicted_class]
# Build explanation text
if class_name.lower() in ['malicious', 'threat', 'anomaly', 'attack']:
summary = f"This instance was classified as {class_name} "
summary += f"with {confidence:.1%} confidence. "
else:
summary = f"This instance was classified as {class_name} "
summary += f"({confidence:.1%} confidence). "
if top_features:
increasing = [f for f in top_features if f.contribution_direction == 'increases_risk']
decreasing = [f for f in top_features if f.contribution_direction == 'decreases_risk']
if increasing:
factors = [f"'{f.feature_name}' = {f.feature_value}" for f in increasing[:3]]
summary += f"Key risk factors: {', '.join(factors)}. "
if decreasing:
factors = [f"'{f.feature_name}' = {f.feature_value}" for f in decreasing[:2]]
summary += f"Mitigating factors: {', '.join(factors)}."
return summary
def explain_batch(
self,
instances: np.ndarray,
output_format: str = "dataframe"
) -> Any:
"""
Explain multiple predictions.
Args:
instances: Multiple instances to explain
output_format: 'dataframe', 'json', or 'explanations'
"""
explanations = []
for instance in instances:
explanations.append(self.explain_prediction(instance))
if output_format == "explanations":
return explanations
elif output_format == "dataframe":
rows = []
for i, exp in enumerate(explanations):
row = {
'instance_id': i,
'prediction': exp.prediction,
'confidence': exp.confidence,
'summary': exp.summary
}
for j, contributor in enumerate(exp.top_contributors):
row[f'top_factor_{j+1}'] = contributor
rows.append(row)
return pd.DataFrame(rows)
        elif output_format == "json":
            return [self._explanation_to_dict(exp) for exp in explanations]
        else:
            raise ValueError(f"Unknown output format: {output_format}")
def _explanation_to_dict(self, exp: PredictionExplanation) -> dict:
"""Convert explanation to dictionary for JSON serialization."""
return {
'prediction': exp.prediction,
'confidence': exp.confidence,
'base_value': exp.base_value,
'summary': exp.summary,
'top_contributors': exp.top_contributors,
'feature_explanations': [
{
'feature': f.feature_name,
'value': f.feature_value if not isinstance(f.feature_value, np.floating) else float(f.feature_value),
'shap_value': f.shap_value,
'direction': f.contribution_direction,
'importance': f.relative_importance
}
for f in exp.feature_explanations
]
}
def generate_audit_report(
self,
instance: np.ndarray,
instance_id: str,
analyst_notes: Optional[str] = None
) -> dict:
"""
Generate a formal audit report for compliance purposes.
"""
explanation = self.explain_prediction(instance)
report = {
'report_type': 'AI_DECISION_EXPLANATION',
'instance_id': instance_id,
'generated_at': pd.Timestamp.now().isoformat(),
'model_info': {
'model_type': type(self.model).__name__,
'num_features': len(self.feature_names),
'num_classes': len(self.class_names)
},
'decision': {
'classification': explanation.prediction,
'confidence_score': explanation.confidence,
'base_probability': explanation.base_value
},
'explanation': {
'summary': explanation.summary,
'top_contributing_factors': [
{
'rank': i + 1,
'feature': f.feature_name,
'observed_value': f.feature_value if not isinstance(f.feature_value, np.floating) else float(f.feature_value),
'contribution': f.shap_value,
'direction': f.contribution_direction
}
for i, f in enumerate(explanation.feature_explanations[:10])
]
},
'analyst_notes': analyst_notes,
'compliance_statement': (
"This explanation was generated using SHAP (SHapley Additive exPlanations) "
"to provide transparency into the AI model's decision-making process, "
"in accordance with GDPR Article 22 right to explanation requirements."
)
}
return report
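Before looking at the threat-specific subclass below, here is a minimal usage sketch. The feature names, synthetic data, and toy RandomForest are illustrative assumptions rather than part of the module above; any scikit-learn-style classifier with predict_proba should work the same way.

# Illustrative usage sketch (hypothetical features and synthetic data)
feature_names = ['failed_attempts', 'request_rate', 'payload_entropy', 'dst_port']
class_names = ['Benign', 'Malicious']

rng = np.random.default_rng(42)
X = rng.random((500, len(feature_names)))
y = (X[:, 0] + X[:, 2] > 1.0).astype(int)  # toy labeling rule, for demonstration only
model = RandomForestClassifier(n_estimators=50, random_state=42).fit(X, y)

explainer = SecurityModelExplainer(model, feature_names, class_names)
explainer.initialize_explainer(background_data=X[:100], explainer_type="tree")

explanation = explainer.explain_prediction(X[0], top_k=3)
print(explanation.summary)
print(explanation.top_contributors)

audit = explainer.generate_audit_report(X[0], instance_id="demo-001")
print(json.dumps(audit, indent=2, default=str))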
class ThreatDetectionExplainer(SecurityModelExplainer):
"""
Specialized explainer for threat detection models.
"""
def __init__(self, model, feature_names: List[str]):
super().__init__(
model,
feature_names,
class_names=['Benign', 'Malicious']
)
# Define feature categories for better explanations
self.feature_categories = {
'network': ['src_port', 'dst_port', 'protocol', 'packet_size', 'flow_duration'],
'behavioral': ['request_rate', 'unique_destinations', 'failed_attempts'],
'content': ['payload_entropy', 'suspicious_patterns', 'encoding_type'],
'temporal': ['hour_of_day', 'day_of_week', 'time_since_last']
}
def explain_threat(
self,
instance: np.ndarray,
include_recommendations: bool = True
) -> dict:
"""
Generate threat-specific explanation with recommendations.
"""
explanation = self.explain_prediction(instance)
result = {
'threat_assessment': {
'is_threat': explanation.prediction == 'Malicious',
'confidence': explanation.confidence,
'severity': self._calculate_severity(explanation)
},
'explanation': explanation.summary,
'contributing_factors': []
}
# Categorize contributing factors
for feat_exp in explanation.feature_explanations[:10]:
category = self._get_feature_category(feat_exp.feature_name)
result['contributing_factors'].append({
'factor': feat_exp.feature_name,
'category': category,
'value': feat_exp.feature_value,
'impact': 'High' if feat_exp.relative_importance > 0.7 else
'Medium' if feat_exp.relative_importance > 0.3 else 'Low',
'direction': feat_exp.contribution_direction
})
if include_recommendations and explanation.prediction == 'Malicious':
result['recommendations'] = self._generate_recommendations(
explanation.feature_explanations[:5]
)
return result
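
    # Illustrative call (hypothetical model and feature vector):
    #   tde = ThreatDetectionExplainer(model, feature_names)
    #   tde.initialize_explainer(background_data=X_train[:200])
    #   assessment = tde.explain_threat(flow_features, include_recommendations=True)
    #   assessment['threat_assessment']['severity']   # e.g. 'High'
    #   assessment['recommendations']                 # e.g. ['Consider implementing rate limiting', ...]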
def _calculate_severity(self, explanation: PredictionExplanation) -> str:
"""Calculate threat severity based on confidence and factors."""
if explanation.prediction == 'Benign':
return 'None'
if explanation.confidence > 0.9:
return 'Critical'
elif explanation.confidence > 0.7:
return 'High'
elif explanation.confidence > 0.5:
return 'Medium'
else:
return 'Low'
def _get_feature_category(self, feature_name: str) -> str:
"""Get category for a feature."""
for category, features in self.feature_categories.items():
if any(f in feature_name.lower() for f in features):
return category
return 'other'
def _generate_recommendations(
self,
top_factors: List[FeatureExplanation]
) -> List[str]:
"""Generate security recommendations based on threat factors."""
recommendations = []
for factor in top_factors:
if factor.contribution_direction != 'increases_risk':
continue
name = factor.feature_name.lower()
if 'port' in name:
recommendations.append(
f"Review firewall rules for port {factor.feature_value}"
)
elif 'rate' in name or 'frequency' in name:
recommendations.append(
"Consider implementing rate limiting"
)
elif 'entropy' in name:
recommendations.append(
"Inspect payload for potential obfuscation or encryption"
)
elif 'failed' in name:
recommendations.append(
"Review authentication logs for brute force attempts"
)
elif 'destination' in name:
recommendations.append(
"Investigate communication with unusual destinations"
)
        return recommendations[:5]  # Limit to top 5 recommendations

LIME (Local Interpretable Model-agnostic Explanations)
# lime_security_explanations.py
"""
LIME-based explanations for security models.
"""
import lime
import lime.lime_tabular
import lime.lime_text
import numpy as np
from typing import List, Dict, Callable, Optional
import re
class SecurityTextExplainer:
"""
LIME-based explanations for text-based security models
(e.g., phishing detection, malware classification).
"""
def __init__(
self,
classifier_fn: Callable,
class_names: List[str]
):
"""
Args:
classifier_fn: Function that takes list of texts and returns probabilities
class_names: Names of classification classes
"""
self.classifier_fn = classifier_fn
self.class_names = class_names
self.explainer = lime.lime_text.LimeTextExplainer(
class_names=class_names,
split_expression=r'\W+',
bow=True
)
def explain_text(
self,
text: str,
num_features: int = 10,
num_samples: int = 5000
) -> dict:
"""
Explain a text classification prediction.
Args:
text: Text to explain
num_features: Number of features in explanation
num_samples: Number of perturbed samples to generate
"""
        # Get the model's prediction first so LIME explains that specific label
        proba = self.classifier_fn([text])[0]
        predicted_class = int(np.argmax(proba))

        explanation = self.explainer.explain_instance(
            text,
            self.classifier_fn,
            labels=(predicted_class,),
            num_features=num_features,
            num_samples=num_samples
        )

        # Extract word weights for the predicted class
        feature_weights = explanation.as_list(label=predicted_class)
# Highlight important words in text
highlighted_text = self._highlight_text(text, feature_weights)
return {
'text': text,
'prediction': self.class_names[predicted_class],
'confidence': float(proba[predicted_class]),
'explanation': {
'important_words': [
{
'word': word,
'weight': float(weight),
'direction': 'suspicious' if weight > 0 else 'benign'
}
for word, weight in feature_weights
],
'highlighted_text': highlighted_text
},
'summary': self._generate_text_summary(
self.class_names[predicted_class],
feature_weights
)
}
def _highlight_text(
self,
text: str,
feature_weights: List[tuple]
) -> str:
"""Add HTML highlighting to important words."""
highlighted = text
for word, weight in sorted(feature_weights, key=lambda x: -abs(x[1])):
if weight > 0:
color = 'red'
else:
color = 'green'
pattern = re.compile(re.escape(word), re.IGNORECASE)
highlighted = pattern.sub(
f'<span style="background-color: {color}; padding: 2px;">{word}</span>',
highlighted
)
return highlighted
def _generate_text_summary(
self,
prediction: str,
feature_weights: List[tuple]
) -> str:
"""Generate natural language summary."""
suspicious_words = [w for w, weight in feature_weights if weight > 0][:3]
benign_words = [w for w, weight in feature_weights if weight < 0][:3]
summary = f"The text was classified as '{prediction}'. "
if suspicious_words:
summary += f"Suspicious indicators: {', '.join(suspicious_words)}. "
if benign_words:
summary += f"Benign indicators: {', '.join(benign_words)}."
return summary
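SecurityTextExplainer only needs a callable that maps a list of texts to class probabilities. The sketch below wraps a hypothetical scikit-learn TF-IDF plus logistic regression pipeline as that callable; the pipeline, training texts, and class names are assumptions used purely for illustration.

# Illustrative usage sketch (hypothetical pipeline and training data)
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

texts = [
    "Your account has been suspended, verify your password here",   # phishing-style
    "Meeting notes attached for tomorrow's project review",          # benign
] * 50
labels = [1, 0] * 50

pipeline = make_pipeline(TfidfVectorizer(), LogisticRegression(max_iter=1000))
pipeline.fit(texts, labels)

text_explainer = SecurityTextExplainer(
    classifier_fn=pipeline.predict_proba,   # list of texts -> (n, 2) probability matrix
    class_names=['Benign', 'Phishing']
)
result = text_explainer.explain_text("Urgent: verify your password to avoid suspension")
print(result['prediction'], round(result['confidence'], 3))
for item in result['explanation']['important_words'][:5]:
    print(item['word'], round(item['weight'], 3), item['direction'])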
class SecurityTabularExplainer:
"""
LIME-based explanations for tabular security data.
"""
def __init__(
self,
training_data: np.ndarray,
feature_names: List[str],
class_names: List[str],
categorical_features: Optional[List[int]] = None
):
self.feature_names = feature_names
self.class_names = class_names
self.explainer = lime.lime_tabular.LimeTabularExplainer(
training_data,
feature_names=feature_names,
class_names=class_names,
categorical_features=categorical_features or [],
mode='classification'
)
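
    # Illustrative setup (hypothetical data and model):
    #   tab_explainer = SecurityTabularExplainer(
    #       training_data=X_train,
    #       feature_names=['failed_attempts', 'request_rate', 'payload_entropy'],
    #       class_names=['Benign', 'Malicious'])
    #   tab_explainer.explain_instance(X_test[0], predict_fn=model.predict_proba)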
def explain_instance(
self,
instance: np.ndarray,
predict_fn: Callable,
num_features: int = 10
) -> dict:
"""
Explain a single prediction.
"""
        # Get the model's prediction first so LIME explains that specific label
        proba = predict_fn(instance.reshape(1, -1))[0]
        predicted_class = int(np.argmax(proba))

        explanation = self.explainer.explain_instance(
            instance,
            predict_fn,
            labels=(predicted_class,),
            num_features=num_features
        )

        # Extract feature contributions for the predicted label.
        # LIME returns human-readable conditions such as 'failed_attempts > 3.00'.
        feature_weights = explanation.as_list(label=predicted_class)

        return {
            'prediction': self.class_names[predicted_class],
            'confidence': float(proba[predicted_class]),
            # LIME's local surrogate prediction; its container type varies across versions
            'local_prediction': float(np.ravel(
                explanation.local_pred[predicted_class]
                if isinstance(explanation.local_pred, dict)
                else explanation.local_pred
            )[0]),
            'feature_contributions': [
                {
                    'feature': feature,
                    'condition': feature,
                    'weight': float(weight),
                    'direction': 'increases_risk' if weight > 0 else 'decreases_risk'
                }
                for feature, weight in feature_weights
            ],
            'intercept': float(explanation.intercept[predicted_class])
        }

Attention Visualization for Transformers
# attention_visualization.py
"""
Attention visualization for transformer-based security models.
"""
import torch
import numpy as np
from typing import List, Dict, Tuple, Optional
from transformers import AutoTokenizer, AutoModel
import matplotlib.pyplot as plt
import seaborn as sns
class AttentionExplainer:
"""
Visualize and explain attention patterns in transformer models.
"""
def __init__(
self,
model_name: str,
device: str = 'cpu'
):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(
model_name,
output_attentions=True
).to(device)
self.device = device
def get_attention_weights(
self,
text: str
) -> Tuple[List[str], np.ndarray]:
"""
Extract attention weights for input text.
Returns:
tokens: List of tokens
attention: Attention weights [num_layers, num_heads, seq_len, seq_len]
"""
inputs = self.tokenizer(
text,
return_tensors='pt',
truncation=True,
max_length=512
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
# Get attention weights
attentions = outputs.attentions # tuple of tensors
# Stack all layers
attention_tensor = torch.stack(attentions) # [layers, batch, heads, seq, seq]
attention_np = attention_tensor.squeeze(1).cpu().numpy()
# Get tokens
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
return tokens, attention_np
def explain_with_attention(
self,
text: str,
layer: int = -1,
head: Optional[int] = None
) -> dict:
"""
Generate explanation based on attention patterns.
Args:
text: Input text
layer: Which layer to analyze (-1 for last)
head: Which attention head (None for average)
"""
tokens, attention = self.get_attention_weights(text)
# Select layer
layer_attention = attention[layer] # [heads, seq, seq]
# Average or select head
if head is None:
token_attention = layer_attention.mean(axis=0) # Average heads
else:
token_attention = layer_attention[head]
        # Attention paid by the [CLS] token (row 0) to the other tokens,
        # skipping the [CLS] and [SEP] positions themselves
        cls_attention = token_attention[0, 1:-1]
relevant_tokens = tokens[1:-1]
# Normalize
if cls_attention.sum() > 0:
cls_attention = cls_attention / cls_attention.sum()
# Find most attended tokens
top_indices = np.argsort(cls_attention)[::-1][:10]
return {
'tokens': relevant_tokens,
'attention_weights': cls_attention.tolist(),
'top_attended_tokens': [
{
'token': relevant_tokens[i],
'attention': float(cls_attention[i]),
'position': int(i)
}
for i in top_indices
],
'attention_matrix': token_attention.tolist()
}
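
    # Illustrative call (model name and input text are assumptions):
    #   attn_explainer = AttentionExplainer("bert-base-uncased")
    #   result = attn_explainer.explain_with_attention(
    #       "Please verify your account at this link", layer=-1)
    #   result['top_attended_tokens'][:3]   # tokens the [CLS] position attends to most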
def visualize_attention(
self,
text: str,
layer: int = -1,
head: Optional[int] = None,
save_path: Optional[str] = None
):
"""
Create attention heatmap visualization.
"""
tokens, attention = self.get_attention_weights(text)
if head is None:
attn_matrix = attention[layer].mean(axis=0)
title = f"Averaged Attention (Layer {layer})"
else:
attn_matrix = attention[layer, head]
title = f"Attention Head {head} (Layer {layer})"
# Truncate for visualization
max_tokens = 30
if len(tokens) > max_tokens:
tokens = tokens[:max_tokens]
attn_matrix = attn_matrix[:max_tokens, :max_tokens]
# Create heatmap
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
attn_matrix,
xticklabels=tokens,
yticklabels=tokens,
cmap='YlOrRd',
ax=ax
)
ax.set_title(title)
ax.set_xlabel('Key Tokens')
ax.set_ylabel('Query Tokens')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()
else:
plt.show()
def get_token_importance(
self,
text: str,
method: str = 'attention_rollout'
) -> Dict[str, float]:
"""
Calculate token importance using attention-based methods.
Args:
method: 'attention_rollout', 'attention_flow', or 'last_layer'
"""
tokens, attention = self.get_attention_weights(text)
if method == 'last_layer':
# Simple: use last layer attention from CLS
importance = attention[-1].mean(axis=0)[0]
elif method == 'attention_rollout':
# Attention rollout: multiply attention matrices
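            # Each layer's head-averaged attention is mixed 50/50 with the identity
            # matrix to approximate the residual connection, row-normalized, and then
            # matrix-multiplied across layers (the attention-rollout approach described
            # by Abnar & Zuidema, 2020).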
rollout = np.eye(attention.shape[-1])
for layer_attn in attention:
# Average heads
layer_avg = layer_attn.mean(axis=0)
# Add residual connection
layer_avg = 0.5 * layer_avg + 0.5 * np.eye(layer_avg.shape[0])
# Normalize
layer_avg = layer_avg / layer_avg.sum(axis=-1, keepdims=True)
# Multiply
rollout = rollout @ layer_avg
importance = rollout[0]
elif method == 'attention_flow':
# Simplified attention flow
importance = np.zeros(attention.shape[-1])
for layer_attn in attention:
importance += layer_attn.mean(axis=0)[0]
importance = importance / len(attention)
else:
raise ValueError(f"Unknown method: {method}")
# Create token-importance mapping
return {
token: float(imp)
for token, imp in zip(tokens, importance)
        }

Counterfactual Explanations
# counterfactual_explanations.py
"""
Counterfactual explanations for security decisions.
"""
import numpy as np
from typing import List, Dict, Optional, Callable, Tuple
from dataclasses import dataclass
from scipy.optimize import minimize
import copy
@dataclass
class Counterfactual:
"""A counterfactual explanation."""
original_instance: np.ndarray
counterfactual_instance: np.ndarray
original_prediction: str
counterfactual_prediction: str
changes: List[Dict]
distance: float
validity: bool
class CounterfactualExplainer:
"""
Generate counterfactual explanations showing minimal changes
to flip a prediction.
"""
def __init__(
self,
predict_fn: Callable,
feature_names: List[str],
feature_ranges: Dict[str, Tuple[float, float]],
categorical_features: Optional[List[str]] = None,
immutable_features: Optional[List[str]] = None
):
"""
Args:
predict_fn: Function returning predicted class
feature_names: Names of features
feature_ranges: Valid ranges for each feature
categorical_features: Features that are categorical
immutable_features: Features that cannot be changed
"""
self.predict_fn = predict_fn
self.feature_names = feature_names
self.feature_ranges = feature_ranges
self.categorical_features = set(categorical_features or [])
self.immutable_features = set(immutable_features or [])
# Feature indices
self.feature_idx = {name: i for i, name in enumerate(feature_names)}
def generate_counterfactual(
self,
instance: np.ndarray,
target_class: int,
max_changes: int = 5,
distance_weight: float = 0.1
) -> Optional[Counterfactual]:
"""
Find minimal changes to achieve target prediction.
"""
original_pred = self.predict_fn(instance.reshape(1, -1))[0]
# If already target class, no counterfactual needed
if original_pred == target_class:
return None
# Optimization to find counterfactual
def objective(x):
# Get prediction probability for target class
proba = self._get_proba(x, target_class)
# Distance from original
distance = self._calculate_distance(instance, x)
# Penalty for too many changes
num_changes = np.sum(np.abs(x - instance) > 1e-6)
change_penalty = max(0, num_changes - max_changes) * 10
# Combine objectives
return -proba + distance_weight * distance + change_penalty
# Generate multiple starting points
best_cf = None
best_score = float('inf')
for _ in range(10): # Multiple random restarts
# Start with perturbed version of original
x0 = instance + np.random.randn(len(instance)) * 0.1
# Clip to valid ranges
x0 = self._clip_to_ranges(x0)
# Optimize
result = minimize(
objective,
x0,
method='L-BFGS-B',
                bounds=self._get_bounds(instance)
)
# Check if valid counterfactual
cf_pred = self.predict_fn(result.x.reshape(1, -1))[0]
if cf_pred == target_class and result.fun < best_score:
best_cf = result.x
best_score = result.fun
if best_cf is None:
return None
# Create counterfactual object
changes = self._identify_changes(instance, best_cf)
return Counterfactual(
original_instance=instance.copy(),
counterfactual_instance=best_cf,
original_prediction=str(original_pred),
counterfactual_prediction=str(target_class),
changes=changes,
distance=self._calculate_distance(instance, best_cf),
validity=True
)
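
    # Illustrative call (feature names, ranges, and predict_fn are assumptions):
    #   cf_explainer = CounterfactualExplainer(
    #       predict_fn=model.predict,
    #       feature_names=['failed_attempts', 'request_rate', 'payload_entropy'],
    #       feature_ranges={'failed_attempts': (0, 100), 'request_rate': (0, 1000),
    #                       'payload_entropy': (0.0, 8.0)},
    #       immutable_features=['payload_entropy'])
    #   cf = cf_explainer.generate_counterfactual(flagged_instance, target_class=0)
    #   if cf is not None:
    #       print(cf.changes)   # minimal changes that would flip the decision to benign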
    def _get_proba(self, x: np.ndarray, target_class: int) -> float:
        """Get prediction probability (or hard-label match) for the target class."""
        try:
            pred = np.asarray(self.predict_fn(x.reshape(1, -1)))
            if pred.ndim > 1:
                # predict_fn returned class probabilities
                return float(pred[0, target_class])
            # predict_fn returned hard class labels
            return float(pred[0] == target_class)
        except Exception:
            return 0.0
def _calculate_distance(
self,
x1: np.ndarray,
x2: np.ndarray
) -> float:
"""Calculate normalized distance between instances."""
distance = 0.0
for i, name in enumerate(self.feature_names):
if name in self.immutable_features:
continue
min_val, max_val = self.feature_ranges.get(name, (0, 1))
range_size = max_val - min_val
if range_size > 0:
# Normalized absolute difference
distance += abs(x1[i] - x2[i]) / range_size
else:
distance += abs(x1[i] - x2[i])
return distance / len(self.feature_names)
def _clip_to_ranges(self, x: np.ndarray) -> np.ndarray:
"""Clip values to valid ranges."""
clipped = x.copy()
for i, name in enumerate(self.feature_names):
min_val, max_val = self.feature_ranges.get(name, (-np.inf, np.inf))
clipped[i] = np.clip(clipped[i], min_val, max_val)
return clipped
    def _get_bounds(self, instance: np.ndarray) -> List[Tuple[float, float]]:
        """Get optimization bounds for each feature."""
        bounds = []
        for i, name in enumerate(self.feature_names):
            if name in self.immutable_features:
                # Pin immutable features to their original values
                bounds.append((float(instance[i]), float(instance[i])))
            else:
                bounds.append(self.feature_ranges.get(name, (None, None)))
        return bounds
def _identify_changes(
self,
original: np.ndarray,
counterfactual: np.ndarray,
threshold: float = 1e-6
) -> List[Dict]:
"""Identify which features changed."""
changes = []
for i, name in enumerate(self.feature_names):
diff = abs(original[i] - counterfactual[i])
if diff > threshold:
changes.append({
'feature': name,
'original_value': float(original[i]),
'counterfactual_value': float(counterfactual[i]),
'change': float(counterfactual[i] - original[i]),
'change_percentage': float(diff / abs(original[i])) * 100 if original[i] != 0 else float('inf')
})
# Sort by magnitude of change
changes.sort(key=lambda x: abs(x['change']), reverse=True)
return changes
def generate_explanation(
self,
instance: np.ndarray,
target_class: int = 0
) -> str:
"""Generate natural language counterfactual explanation."""
cf = self.generate_counterfactual(instance, target_class)
if cf is None:
return "No counterfactual found - the instance may already be classified as the target class."
explanation = f"To change the prediction from '{cf.original_prediction}' to '{cf.counterfactual_prediction}', "
if len(cf.changes) == 0:
explanation += "no changes are needed."
elif len(cf.changes) == 1:
change = cf.changes[0]
explanation += f"change '{change['feature']}' from {change['original_value']:.2f} to {change['counterfactual_value']:.2f}."
else:
explanation += "make the following changes: "
change_strs = []
for change in cf.changes[:5]: # Limit to top 5
change_strs.append(
f"'{change['feature']}' from {change['original_value']:.2f} to {change['counterfactual_value']:.2f}"
)
explanation += "; ".join(change_strs) + "."
        return explanation

Conclusion
AI explainability is essential for security applications where understanding model decisions can mean the difference between catching a threat and missing an attack. Key techniques covered include:
- SHAP for feature-level importance with game-theoretic foundations
- LIME for local, model-agnostic explanations
- Attention visualization for transformer-based models
- Counterfactual explanations showing what changes would alter predictions
These techniques enable security teams to trust, debug, and comply with regulations when deploying AI systems in critical security contexts.