Module contrib.classification.plot_logits
Generates plots demonstrating SAE feature specificity in image classification.
This module creates visualizations showing how different feature interventions affect class logits in a controlled manner. It plots the relationship between intervention magnitudes and their effects on class predictions, demonstrating that features are semantically meaningful and independent.
The main plotting function generates a three-panel figure:
- Left panel: Effect of feature A intervention on classes A, B and C
- Middle panel: Effect of feature B intervention on classes A, B and C
- Right panel: Effect of feature C intervention on classes A, B and C
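The data behind the three panels can be sketched in plain Python. This is a toy model, not the module's implementation. It assumes a hypothetical linear response in which intervening on a feature with magnitude m shifts each class logit by m times an effect coefficient, with a large coefficient on the matching class and near-zero coefficients elsewhere. The names `EFFECTS`, `panel_series` and `build_figure_data` and all coefficient values are invented for illustration.

```python
# Hypothetical effect coefficients (invented, not measured):
# EFFECTS[feature][cls] is the assumed logit change per unit of magnitude.
# A "specific" feature moves its own class and barely touches the others.
EFFECTS: dict[str, dict[str, float]] = {
    "feature_a": {"class_a": 1.0, "class_b": 0.05, "class_c": 0.0},
    "feature_b": {"class_a": 0.0, "class_b": 1.0, "class_c": 0.05},
    "feature_c": {"class_a": 0.05, "class_b": 0.0, "class_c": 1.0},
}

# Panel order follows the description: left = A, middle = B, right = C.
PANELS = ("feature_a", "feature_b", "feature_c")


def panel_series(feature: str, magnitudes: list[float]) -> dict[str, list[float]]:
    """Assumed logit shift for every class at each intervention magnitude."""
    return {
        cls: [m * coef for m in magnitudes]
        for cls, coef in EFFECTS[feature].items()
    }


def build_figure_data(magnitudes: list[float]) -> list[dict[str, list[float]]]:
    """One entry per panel; each entry maps class name -> y-values to plot."""
    return [panel_series(feature, magnitudes) for feature in PANELS]


if __name__ == "__main__":
    for feature, series in zip(PANELS, build_figure_data([-10.0, 0.0, 10.0])):
        print(feature, series)
```

In a specificity plot of this kind, each panel should show one steep line (the matching class) and flat lines for the other two classes.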
Classes
class Config (magnitude_range: tuple[float, float] = (-10.0, 10.0),
n_points: int = 50,
figsize: tuple[float, float] = (18.0, 5.0),
class_colors: dict[str, str] = <factory>,
show_confidence: bool = True,
dpi: int = 300)
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
    magnitude_range: tuple[float, float] = (-10.0, 10.0)
    """Range for intervention magnitudes, from minimum to maximum value. Usually kept within [-10, 10] following Anthropic's work. Values outside this range may create artifacts."""

    n_points: int = 50
    """Number of evenly spaced points to sample within magnitude range. Higher values create smoother plots but increase computation time."""

    figsize: tuple[float, float] = (18.0, 5.0)
    """Figure dimensions in inches (width, height). Default size is optimized for a three-panel figure."""

    class_colors: dict[str, str] = dataclasses.field(
        default_factory=lambda: {
            "class_a": "#FF0000",  # Red
            "class_b": "#4169E1",  # Royal Blue
        }
    )
    """Color mapping for different classes. Uses hex color codes for consistency across plotting backends."""

    show_confidence: bool = True
    """Whether to show confidence intervals around trend lines."""

    dpi: int = 300
    """Dots per inch for saved figures. Higher values create larger files but better resolution."""
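Because `Config` is declared with `frozen=True`, its fields cannot be reassigned after construction; a variant is derived with the standard-library `dataclasses.replace` instead. The sketch below re-declares a trimmed copy of the class (docstrings omitted) so it runs on its own; the override values are arbitrary examples.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
    # Trimmed copy of Config (docstrings omitted) so this snippet is standalone.
    magnitude_range: tuple[float, float] = (-10.0, 10.0)
    n_points: int = 50
    figsize: tuple[float, float] = (18.0, 5.0)
    class_colors: dict[str, str] = dataclasses.field(
        default_factory=lambda: {"class_a": "#FF0000", "class_b": "#4169E1"}
    )
    show_confidence: bool = True
    dpi: int = 300


base = Config()

# Frozen dataclasses reject attribute assignment.
try:
    base.n_points = 100  # type: ignore[misc]
except dataclasses.FrozenInstanceError:
    print("Config is immutable")

# Derive a faster preview config instead: fewer points, no confidence bands.
quick = dataclasses.replace(base, n_points=20, show_confidence=False)
print(quick.n_points, quick.show_confidence)
```

Immutability means a single `Config` instance can be shared across plotting calls without one call silently changing settings for another.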
Class variables
var class_colors : dict[str, str]
Color mapping for different classes. Uses hex color codes for consistency across plotting backends.
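The default mapping defines colors only for `class_a` and `class_b`, so a plot that includes a third class needs an extra entry. A minimal validation sketch follows; `validate_class_colors` is a hypothetical helper (not part of this module) that checks every required class has a six-digit `#RRGGBB` code, and `"#228B22"` is an arbitrary example color.

```python
import re

# Six-digit hex codes such as "#4169E1", the format the defaults use.
HEX_COLOR = re.compile(r"#[0-9A-Fa-f]{6}")


def validate_class_colors(class_colors: dict[str, str], required: list[str]) -> dict[str, str]:
    """Hypothetical check: every required class has a well-formed hex color."""
    missing = [cls for cls in required if cls not in class_colors]
    if missing:
        raise KeyError(f"no color configured for: {missing}")
    bad = {cls: c for cls, c in class_colors.items() if not HEX_COLOR.fullmatch(c)}
    if bad:
        raise ValueError(f"not #RRGGBB hex codes: {bad}")
    return class_colors


defaults = {"class_a": "#FF0000", "class_b": "#4169E1"}
extended = {**defaults, "class_c": "#228B22"}  # arbitrary example color
validate_class_colors(extended, ["class_a", "class_b", "class_c"])
```

`fullmatch` is used rather than `match` so that trailing characters (for example `"#FF0000aa"`) are rejected.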
var dpi : int
Dots per inch for saved figures. Higher values create larger files but better resolution.
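For raster formats such as PNG, the pixel dimensions of a saved figure are the `figsize` values multiplied by `dpi`. The arithmetic below shows what the defaults imply. Both helpers are hypothetical, and the RGBA byte count is an uncompressed upper-bound illustration, not the size of the file the module actually writes (PNG compression usually makes it much smaller).

```python
def pixel_size(figsize: tuple[float, float], dpi: int) -> tuple[int, int]:
    """Raster size in pixels: inches times dots per inch, per axis."""
    width_in, height_in = figsize
    return round(width_in * dpi), round(height_in * dpi)


def raw_rgba_bytes(figsize: tuple[float, float], dpi: int) -> int:
    """Uncompressed RGBA size at 4 bytes per pixel (an upper-bound illustration)."""
    width_px, height_px = pixel_size(figsize, dpi)
    return width_px * height_px * 4


print(pixel_size((18.0, 5.0), 300))      # (5400, 1500)
print(raw_rgba_bytes((18.0, 5.0), 300))  # 32400000
print(pixel_size((18.0, 5.0), 100))      # (1800, 500)
```

Pixel count grows with the square of `dpi`, so tripling it from 100 to 300 multiplies the raw pixel data by nine.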
var figsize : tuple[float, float]
Figure dimensions in inches (width, height). Default size is optimized for a three-panel figure.
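Dividing the width by the number of panels shows why 18 by 5 inches suits a single row of three panels: each panel gets roughly 6 by 5 inches, a near-square plotting area. The helper below is hypothetical and ignores inter-panel spacing and margins, so the real axes are somewhat smaller.

```python
def panel_size(figsize: tuple[float, float], n_panels: int = 3) -> tuple[float, float]:
    """Approximate size of each panel in one row of equal-width subplots.

    Ignores spacing between subplots and figure margins.
    """
    width_in, height_in = figsize
    return width_in / n_panels, height_in


print(panel_size((18.0, 5.0)))     # (6.0, 5.0)
print(panel_size((18.0, 5.0), 2))  # (9.0, 5.0)
```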
var magnitude_range : tuple[float, float]
Range for intervention magnitudes, from minimum to maximum value. Usually kept within [-10, 10] following Anthropic's work. Values outside this range may create artifacts.
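The recommended bound can be enforced softly: reject an inverted range outright, but only warn when the range exceeds [-10, 10], since larger magnitudes are discouraged rather than forbidden. `check_magnitude_range` is a hypothetical validator sketched under that assumption; the module itself may not perform any such check.

```python
import warnings

# Recommended bound taken from the docstring: stay within [-10, 10].
RECOMMENDED = (-10.0, 10.0)


def check_magnitude_range(magnitude_range: tuple[float, float]) -> tuple[float, float]:
    """Hypothetical validator: require min < max, warn outside [-10, 10]."""
    lo, hi = magnitude_range
    if lo >= hi:
        raise ValueError(f"expected min < max, got {magnitude_range}")
    if lo < RECOMMENDED[0] or hi > RECOMMENDED[1]:
        warnings.warn(
            f"magnitude_range {magnitude_range} exceeds {RECOMMENDED}; "
            "interventions this large may create artifacts",
            stacklevel=2,
        )
    return magnitude_range


check_magnitude_range((-10.0, 10.0))  # within the recommended bound: no warning
check_magnitude_range((-25.0, 25.0))  # emits a UserWarning
```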
var n_points : int
Number of evenly spaced points to sample within magnitude range. Higher values create smoother plots but increase computation time.
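Sampling `n_points` evenly spaced magnitudes with both endpoints included matches the default behavior of `numpy.linspace(lo, hi, n_points)`. The docstring does not say whether the endpoints are included, so that is an assumption here. A dependency-free sketch of the same spacing: with the defaults, 50 points over [-10, 10] are spaced 20/49 (about 0.408) apart.

```python
def linspace(start: float, stop: float, n: int) -> list[float]:
    """n evenly spaced values from start to stop, both endpoints included."""
    if n < 2:
        raise ValueError("need at least two points to span a range")
    step = (stop - start) / (n - 1)
    # Compute each point from start instead of accumulating step, so rounding
    # error does not build up; pin the final point to stop exactly.
    return [start + i * step for i in range(n - 1)] + [stop]


magnitudes = linspace(-10.0, 10.0, 50)
print(len(magnitudes), magnitudes[0], magnitudes[-1])  # 50 -10.0 10.0
print(round(magnitudes[1] - magnitudes[0], 4))         # 0.4082
```

Each extra point costs one more intervention per feature and class, which is where the computation-time trade-off comes from.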
var show_confidence : bool
Whether to show confidence intervals around trend lines.
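The docstring does not say how the confidence band is computed. A common choice, assumed here, is to take the mean across samples at each magnitude and draw a normal-approximation 95% interval (mean plus or minus 1.96 standard errors). The helper `mean_with_ci` and the sample values are hypothetical.

```python
import math
import statistics


def mean_with_ci(samples: list[float], z: float = 1.96) -> tuple[float, float, float]:
    """Mean and a normal-approximation confidence interval (z=1.96 ~ 95%).

    Returns (mean, lower, upper); assumes roughly normal sampling error.
    """
    if len(samples) < 2:
        raise ValueError("need at least two samples for a standard error")
    mean = statistics.fmean(samples)
    sem = statistics.stdev(samples) / math.sqrt(len(samples))
    return mean, mean - z * sem, mean + z * sem


# Hypothetical logit shifts for one class at one magnitude, across images.
shifts = [2.0, 4.0, 3.0, 5.0, 1.0]
mean, lower, upper = mean_with_ci(shifts)
print(f"{mean:.2f} [{lower:.2f}, {upper:.2f}]")
```

Repeating this at every sampled magnitude gives per-point lower and upper values, which could then be shaded around the trend line with matplotlib's `Axes.fill_between`.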