Module contrib.classification.plot_logits

Generates plots demonstrating SAE feature specificity in image classification.

This module creates visualizations showing how different feature interventions affect class logits in a controlled manner. It plots the relationship between intervention magnitudes and their effects on class predictions, demonstrating that features are semantically meaningful and independent.

The main plotting function generates a three-panel figure:

- Left panel: Effect of feature A intervention on classes A, B and C
- Middle panel: Effect of feature B intervention on classes A, B and C
- Right panel: Effect of feature C intervention on classes A, B and C

Classes

class Config (magnitude_range: tuple[float, float] = (-10.0, 10.0),
n_points: int = 50,
figsize: tuple[float, float] = (18.0, 5.0),
class_colors: dict[str, str] = <factory>,
show_confidence: bool = True,
dpi: int = 300)


import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
    magnitude_range: tuple[float, float] = (-10.0, 10.0)
    """Range for intervention magnitudes, from minimum to maximum value. Usually kept within [-10, 10] following Anthropic's work. Values outside this range may create artifacts."""

    n_points: int = 50
    """Number of evenly spaced points to sample within magnitude range. Higher values create smoother plots but increase computation time."""

    figsize: tuple[float, float] = (18.0, 5.0)
    """Figure dimensions in inches (width, height). Default size is optimized for a three-panel figure."""

    class_colors: dict[str, str] = dataclasses.field(
        default_factory=lambda: {
            "class_a": "#FF0000",  # Red
            "class_b": "#4169E1",  # Royal Blue
        }
    )
    """Color mapping for different classes. Uses hex color codes for consistency across plotting backends."""

    show_confidence: bool = True
    """Whether to show confidence intervals around trend lines."""

    dpi: int = 300
    """Dots per inch for saved figures. Higher values create larger files but better resolution."""
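
Because Config is declared with frozen=True, its fields cannot be reassigned after construction. The sketch below shows one way to derive a modified configuration with the standard-library dataclasses.replace. To stay self-contained it uses a trimmed stand-in for Config with only two of the documented fields; the names default and draft are illustrative and not part of the module.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
    """Trimmed stand-in for the documented Config (two fields only)."""

    n_points: int = 50
    dpi: int = 300


default = Config()

# Frozen dataclasses reject attribute assignment after construction.
try:
    default.dpi = 72
except dataclasses.FrozenInstanceError:
    assignment_blocked = True
else:
    assignment_blocked = False

# dataclasses.replace builds a new instance with selected fields overridden;
# the original instance is left unchanged.
draft = dataclasses.replace(default, dpi=72, n_points=20)
```

The same pattern should apply to the full Config, for example to produce a low-resolution preview configuration without touching the defaults.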

Class variables

var class_colors : dict[str, str]

Color mapping for different classes. Uses hex color codes for consistency across plotting backends.

var dpi : int

Dots per inch for saved figures. Higher values create larger files but better resolution.

var figsize : tuple[float, float]

Figure dimensions in inches (width, height). Default size is optimized for a three-panel figure.
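
To make the trade-off between dpi and file size concrete, the following sketch computes the pixel dimensions implied by figsize and dpi. It assumes the figure is saved at exactly figsize inches with no cropping or padding applied at save time; pixel_size is a hypothetical helper, not a function of this module.

```python
def pixel_size(figsize, dpi):
    """Return (width_px, height_px) for a figure of `figsize` inches at `dpi`."""
    width_in, height_in = figsize
    return round(width_in * dpi), round(height_in * dpi)


# Documented defaults: 18 x 5 inches at 300 dpi.
default_px = pixel_size((18.0, 5.0), 300)  # (5400, 1500)

# A lower dpi shrinks both dimensions proportionally.
draft_px = pixel_size((18.0, 5.0), 100)  # (1800, 500)
```

Under that assumption, the pixel count grows with the square of dpi, so tripling dpi yields roughly nine times as many pixels.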

var magnitude_range : tuple[float, float]

Range for intervention magnitudes, from minimum to maximum value. Usually kept within [-10, 10] following Anthropic's work. Values outside this range may create artifacts.

var n_points : int

Number of evenly spaced points to sample within magnitude range. Higher values create smoother plots but increase computation time.
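
As a rough illustration of how magnitude_range and n_points could combine into a sampling grid, the sketch below builds evenly spaced magnitudes in plain Python. Including both endpoints is an assumption, and sample_magnitudes is a hypothetical helper; the module itself may build the grid differently, for instance with numpy.linspace.

```python
def sample_magnitudes(magnitude_range, n_points):
    """Return n_points evenly spaced values spanning magnitude_range, endpoints included."""
    low, high = magnitude_range
    if n_points < 2:
        raise ValueError("n_points must be at least 2 to span a range")
    step = (high - low) / (n_points - 1)
    return [low + i * step for i in range(n_points)]


# Documented defaults: 50 points across [-10, 10].
magnitudes = sample_magnitudes((-10.0, 10.0), 50)
spacing = magnitudes[1] - magnitudes[0]  # about 0.408 between neighbours
```

With the defaults, each additional point costs one more intervention evaluation per feature, which is why larger n_points values increase computation time.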

var show_confidence : bool

Whether to show confidence intervals around trend lines.