Interpretable Bean Leaf Disease Screening with Integrated Gradients

This project builds a lightweight computer-vision pipeline for screening bean leaf images and then audits the model with Integrated Gradients: it introduces the practical problem, trains a compact transfer-learning classifier, evaluates held-out performance, and checks whether the resulting saliency maps are stable and model-dependent.

Project goal. Classify bean leaf images into healthy, angular leaf spot, or bean rust, then explain one model prediction with an attribution map that highlights image regions most influential to the predicted class.

Why this matters. Leaf-disease classifiers can be useful as screening tools, but a high-level accuracy number is not enough. A model may appear accurate while relying on shortcuts such as image background, lighting, or camera artifacts. The interpretation workflow here asks whether the explanation focuses on leaf-relevant visual patterns and whether the attribution map changes when model parameters are disrupted.

Methods used. PyTorch, torchvision ResNet-18 transfer learning, HuggingFace beans dataset, Captum Integrated Gradients, confusion-matrix evaluation, attribution sensitivity analysis, and model-randomization sanity checks.

In [1]:
%%capture
%pip install -q datasets captum scikit-learn scipy
In [2]:
import contextlib
import copy
import json
import logging
import os
import random
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from IPython.display import display
from scipy.stats import spearmanr
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

from captum.attr import IntegratedGradients
from datasets import load_dataset

warnings.filterwarnings("ignore")
logging.getLogger("huggingface_hub.utils._http").setLevel(logging.ERROR)

@contextlib.contextmanager
def suppress_output():
    with open(os.devnull, "w") as devnull:
        with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
            yield

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

Reproducible Project Setup

The constants below centralize the choices that control the experiment: random seed, image size, number of epochs, learning rate, and output folders. Keeping these settings in one place makes the notebook easier to rerun, adapt, and publish as a reproducible portfolio project.

In [3]:
PROJECT_NAME = "bean_leaf_disease_integrated_gradients"
SEED = 479
IMAGE_SIZE = 128
N_EPOCHS = 3
LEARNING_RATE = 1e-3

OUTPUT_DIR = Path("portfolio_outputs")
FIGURE_DIR = OUTPUT_DIR / "figures"
TABLE_DIR = OUTPUT_DIR / "tables"
for directory in [OUTPUT_DIR, FIGURE_DIR, TABLE_DIR]:
    directory.mkdir(parents=True, exist_ok=True)

plt.rcParams.update({
    "figure.dpi": 120,
    "savefig.dpi": 180,
    "axes.spines.top": False,
    "axes.spines.right": False,
})

def sanitize_filename(text):
    safe = "".join(ch.lower() if ch.isalnum() else "_" for ch in str(text)).strip("_")
    return "_".join(part for part in safe.split("_") if part)[:80]

def save_table(df, filename):
    path = TABLE_DIR / filename
    df.to_csv(path, index=False)
    return path

def save_current_figure(filename):
    path = FIGURE_DIR / filename
    plt.gcf().savefig(path, bbox_inches="tight")
    return path

print(f"Project assets will be saved under: {OUTPUT_DIR.resolve()}")
Project assets will be saved under: /content/portfolio_outputs
In [4]:
def set_seed(seed=479):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_beans_dataset():
    try:
        raw_dataset = load_dataset("beans")
    except Exception as err:
        raise RuntimeError(
            "Could not load the HuggingFace 'beans' dataset. "
            "Run this notebook with internet access the first time, or make sure the dataset is cached locally."
        ) from err
    label_column = "labels" if "labels" in raw_dataset["train"].column_names else "label"
    raw_label_names = raw_dataset["train"].features[label_column].names
    label_names = [name.replace("_", " ") for name in raw_label_names]
    return raw_dataset, label_column, label_names

def make_beans_transforms(image_size=128):
    train_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN.flatten().tolist(), std=IMAGENET_STD.flatten().tolist()),
    ])
    eval_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN.flatten().tolist(), std=IMAGENET_STD.flatten().tolist()),
    ])
    return train_transform, eval_transform

class BeansTorchDataset(Dataset):
    def __init__(self, hf_split, label_column, transform=None):
        self.hf_split = hf_split
        self.label_column = label_column
        self.transform = transform

    def __len__(self):
        return len(self.hf_split)

    def __getitem__(self, idx):
        item = self.hf_split[int(idx)]
        image = item["image"].convert("RGB")
        label = int(item[self.label_column])
        if self.transform is not None:
            image = self.transform(image)
        return image, label

def summarize_split_counts(raw_dataset, label_column, label_names):
    rows = []
    for split_name, split_data in raw_dataset.items():
        labels = split_data[label_column]
        counts = pd.Series(labels).value_counts().sort_index()
        for label_id, label_name in enumerate(label_names):
            rows.append({"split": split_name, "class": label_name, "count": int(counts.get(label_id, 0))})
    return pd.DataFrame(rows)

def unnormalize_image(image_tensor):
    image = image_tensor.detach().cpu() * IMAGENET_STD + IMAGENET_MEAN
    image = image.clamp(0, 1)
    return image.permute(1, 2, 0).numpy()

def show_beans_examples(dataset, label_names, n_examples=6):
    n_classes = len(label_names)
    examples_per_class = max(1, int(np.ceil(n_examples / n_classes)))
    selected = {label_id: [] for label_id in range(n_classes)}

    for idx in range(len(dataset)):
        image, label = dataset[idx]
        if len(selected[label]) < examples_per_class:
            selected[label].append((image, label))
        if all(len(items) >= examples_per_class for items in selected.values()):
            break

    fig, axes = plt.subplots(n_classes, examples_per_class, figsize=(2.6 * examples_per_class, 2.5 * n_classes))
    if n_classes == 1 and examples_per_class == 1:
        axes = np.array([[axes]])
    elif n_classes == 1:
        axes = np.array([axes])
    elif examples_per_class == 1:
        axes = np.array([[ax] for ax in axes])

    for label_id in range(n_classes):
        for col in range(examples_per_class):
            ax = axes[label_id, col]
            if col < len(selected[label_id]):
                image, label = selected[label_id][col]
                ax.imshow(unnormalize_image(image))
                ax.set_title(label_names[label], fontsize=10)
            ax.axis("off")
    plt.suptitle("Example Beans images by class", y=1.02)
    plt.tight_layout()
    save_current_figure("01_example_beans_images.png")
    plt.show()

def make_data_loaders(train_dataset, val_dataset, test_dataset, device, train_batch_size=16, eval_batch_size=32):
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)
    return train_loader, val_loader, test_loader

Problem Framing and Dataset

Screening use case

The real-world task is lightweight bean leaf disease screening from images. Bean leaf disease can reduce crop yield, so a small image classifier could help flag images that appear healthy or show visible disease symptoms. This notebook frames the task as a three-class image classification problem:

$$ \text{bean leaf image} \mapsto \{\text{healthy}, \text{angular leaf spot}, \text{bean rust}\}. $$

The model should be treated as a prototype screening aid rather than a validated agricultural diagnostic system.

Explanation audience

The interpretation layer is intended for someone auditing the classifier before trusting its predictions: for example, a model developer, agronomist, extension worker, or farmer using the tool as a screening aid. The goal is to check whether the classifier relies on plausible leaf-local evidence instead of shortcuts such as lighting, background color, image borders, or camera artifacts.

Data source and preprocessing

The data are loaded from the HuggingFace Beans dataset. Each example contains a bean leaf image and a class label. The preprocessing pipeline resizes images to a common resolution, converts them to tensors, and normalizes them with ImageNet mean and standard deviation because the target model uses an ImageNet-pretrained ResNet-18 backbone. The training transform also applies small random flips and rotations as lightweight data augmentation.
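As a small sanity sketch of the normalization arithmetic, the block below uses NumPy in place of torchvision, and its 4×4 random image is purely illustrative. Normalize subtracts the per-channel ImageNet mean and divides by the per-channel standard deviation; unnormalize_image reverses that before display. The notebook's version also permutes the tensor to height × width × channel order for imshow, which this sketch omits.

```python
import numpy as np

# ImageNet channel statistics, the same values as IMAGENET_MEAN / IMAGENET_STD above.
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize(image_chw):
    """Per-channel (x - mean) / std, the arithmetic transforms.Normalize applies."""
    return (image_chw - MEAN) / STD

def unnormalize(image_chw):
    """Undo the normalization and clip to [0, 1] for display."""
    return np.clip(image_chw * STD + MEAN, 0.0, 1.0)

rng = np.random.default_rng(0)
image = rng.random((3, 4, 4))      # stand-in for a ToTensor() output with values in [0, 1)
normalized = normalize(image)
restored = unnormalize(normalized)
print(normalized.min(), normalized.max())  # typically spans negative and positive values
print(np.allclose(restored, image))        # the round trip recovers the original pixels
```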

The full training split is used for model fitting. The validation split is used to monitor accuracy after each epoch, and the test split is held out for final evaluation; neither receives data augmentation. The saliency analysis is then run on a held-out test image, which makes it a model-auditing step: the project is not only about predicting a disease class, but also about checking what image evidence the classifier appears to use on images it was not trained on.

In [5]:
set_seed(SEED)
device = get_device()

with suppress_output():
    raw_beans, label_column, label_names = load_beans_dataset()

train_transform, eval_transform = make_beans_transforms(image_size=IMAGE_SIZE)
train_dataset = BeansTorchDataset(raw_beans["train"], label_column, transform=train_transform)
val_dataset = BeansTorchDataset(raw_beans["validation"], label_column, transform=eval_transform)
test_dataset = BeansTorchDataset(raw_beans["test"], label_column, transform=eval_transform)

train_loader, val_loader, test_loader = make_data_loaders(train_dataset, val_dataset, test_dataset, device)
split_counts = summarize_split_counts(raw_beans, label_column, label_names)

display(split_counts)
save_table(split_counts, "01_split_counts.csv")
show_beans_examples(train_dataset, label_names, n_examples=6)
   split        class                count
0  train        angular leaf spot      345
1  train        bean rust              348
2  train        healthy                341
3  validation   angular leaf spot       44
4  validation   bean rust               45
5  validation   healthy                 44
6  test         angular leaf spot       43
7  test         bean rust               43
8  test         healthy                 42

Figure 1. Dataset snapshot. The table gives split-by-class counts, which are close to balanced in every split, and the image grid shows representative training examples for each class (with the random training augmentation applied). The model is trained on the full training split; validation is used to monitor training, and test is reserved for held-out evaluation.

In [6]:
# ----------------------------
# Model architecture and training utilities
# ----------------------------
def build_resnet18_beans_classifier(num_classes=3, freeze_backbone=True):
    try:
        weights = models.ResNet18_Weights.DEFAULT
        model = models.resnet18(weights=weights)
        weights_loaded = True
    except Exception as err:
        print("Pretrained weights were unavailable; using randomly initialized weights.")
        print("Reason:", err)
        model = models.resnet18(weights=None)
        weights_loaded = False

    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    model.backbone_frozen = bool(freeze_backbone and weights_loaded)

    if model.backbone_frozen:
        for name, param in model.named_parameters():
            param.requires_grad = name.startswith("fc.")
    elif freeze_backbone and not weights_loaded:
        print("Backbone left trainable because frozen random features would not be useful.")
    return model

def freeze_batchnorm_modules(model):
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.eval()

def count_trainable_parameters(model):
    return sum(param.numel() for param in model.parameters() if param.requires_grad)

def evaluate_classifier(model, data_loader, criterion, device, return_predictions=False):
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    y_true, y_pred = [], []
    with torch.inference_mode():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            predictions = logits.argmax(dim=1)
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (predictions == labels).sum().item()
            total += batch_size
            y_true.extend(labels.cpu().tolist())
            y_pred.extend(predictions.cpu().tolist())
    metrics = {"loss": loss_sum / total, "accuracy": correct / total}
    if return_predictions:
        return metrics, np.array(y_true), np.array(y_pred)
    return metrics

def train_one_epoch(model, train_loader, optimizer, criterion, device):
    model.train()
    if getattr(model, "backbone_frozen", False):
        freeze_batchnorm_modules(model)
    loss_sum, correct, total = 0.0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        predictions = logits.argmax(dim=1)
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        correct += (predictions == labels).sum().item()
        total += batch_size
    return {"loss": loss_sum / total, "accuracy": correct / total}

def train_classifier(model, train_loader, val_loader, device, n_epochs=3, lr=1e-3):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr)
    history = []
    for epoch in range(1, n_epochs + 1):
        train_metrics = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_metrics = evaluate_classifier(model, val_loader, criterion, device)
        history.append({
            "epoch": epoch,
            "train_loss": train_metrics["loss"],
            "train_accuracy": train_metrics["accuracy"],
            "val_loss": val_metrics["loss"],
            "val_accuracy": val_metrics["accuracy"],
        })
    return pd.DataFrame(history)

def plot_training_history(history_df):
    plt.figure(figsize=(6, 4))
    plt.plot(history_df["epoch"], history_df["train_accuracy"], marker="o", label="train")
    plt.plot(history_df["epoch"], history_df["val_accuracy"], marker="o", label="validation")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Training history")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    save_current_figure("02_training_history.png")
    plt.show()

def plot_confusion_matrix(y_true, y_pred, label_names):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(label_names))))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_names)
    fig, ax = plt.subplots(figsize=(6, 5))
    disp.plot(ax=ax, xticks_rotation=25, values_format="d", colorbar=False)
    ax.set_title("Test confusion matrix")
    plt.tight_layout()
    save_current_figure("03_test_confusion_matrix.png")
    plt.show()
    return cm

Model Design: Lightweight Transfer Learning

Architecture

The target model is a compact transfer-learning classifier. It uses a ResNet-18 convolutional backbone and replaces the original ImageNet output layer with a new linear layer for the three Beans classes:

$$ \text{image} \rightarrow \text{ResNet-18 convolutional features} \rightarrow \text{global average pooling} \rightarrow \text{Linear}(512, 3). $$

The output layer returns three logits, one for each class: angular leaf spot, bean rust, and healthy. This model does not use attention mechanisms. The ImageNet-pretrained backbone weights are frozen, and only the final fc classifier head is trained. BatchNorm running statistics are also kept fixed, so the pretrained feature extractor is treated as stable and Beans-specific optimization happens in the final classifier layer.
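The frozen-backbone design can be checked with simple arithmetic. The plain-Python sketch below (helper names are hypothetical) counts the parameters of the new Linear(512, 3) head and shows the running-mean update that PyTorch's BatchNorm applies in training mode with its default momentum of 0.1. That update is why freezing BatchNorm needs its own step: setting requires_grad=False stops gradient updates, but it does not stop running statistics from drifting during train-mode forward passes.

```python
def linear_parameter_count(in_features, out_features, bias=True):
    """Weights plus optional bias terms of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)

def batchnorm_running_mean_update(running_mean, batch_mean, momentum=0.1):
    """Running-mean update that BatchNorm performs in training mode (PyTorch convention)."""
    return (1.0 - momentum) * running_mean + momentum * batch_mean

# ResNet-18's pooled feature vector has 512 entries; the Beans head maps it to 3 logits.
head_params = linear_parameter_count(512, 3)
print(head_params)  # 1539, the trainable_parameters value in the model summary

# One train-mode forward pass moves a pretrained running mean of 0.0 toward a
# batch mean of 1.0, even if the layer's weights have requires_grad=False.
drifted = batchnorm_running_mean_update(running_mean=0.0, batch_mean=1.0)
print(drifted)  # 0.1
```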

This project uses 128×128 images instead of the larger ImageNet default size to keep runtime manageable on CPU. ResNet-18 can still accept this resolution because its convolutional layers and global average pooling do not require a fixed 224×224 input, though larger images could improve accuracy in a more compute-intensive version.
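To see why 128×128 inputs still work, the sketch below traces the spatial size through ResNet-18's downsampling layers with the standard convolution output-size formula. It assumes torchvision's ResNet-18 layout (a 7×7 stride-2 stem convolution, a 3×3 stride-2 max-pool, and stride-2 3×3 first convolutions in layer2 to layer4); the helper names are illustrative.

```python
def conv_output_size(size, kernel, stride, padding):
    """Output-size formula for a convolution or max-pool layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def resnet18_final_feature_size(input_size):
    """Spatial side length of the last ResNet-18 feature map for a square input."""
    size = conv_output_size(input_size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_output_size(size, kernel=3, stride=2, padding=1)        # maxpool
    for _ in range(3):                                                  # layer2, layer3, layer4
        size = conv_output_size(size, kernel=3, stride=2, padding=1)    # layer1 keeps the size
    return size

for side in (128, 224):
    s = resnet18_final_feature_size(side)
    print(f"{side}x{side} input -> {s}x{s}x512 feature map -> pooled to 512 features")
```

A 128×128 input ends in a 4×4×512 feature map rather than 7×7×512; global average pooling collapses either one to 512 features, so the same Linear(512, 3) head fits both.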

How explanations connect to the model

Because this project uses saliency maps rather than concept activation vectors or sparse autoencoders, there is no internal concept layer to select. Integrated Gradients is computed directly with respect to input pixels for one selected target class.

The later sanity check still analyzes the model layer by layer. The randomization cascade starts from the output layer and moves backward through high-level ResNet blocks, testing whether the saliency map depends on learned model parameters rather than only on image structure.

In [7]:
set_seed(SEED)
with suppress_output():
    model = build_resnet18_beans_classifier(num_classes=len(label_names), freeze_backbone=True)
model = model.to(device)

model_summary = pd.DataFrame([{
    "architecture": "ImageNet-pretrained ResNet-18 + Linear(512, 3)",
    "trainable_parameters": count_trainable_parameters(model),
    "backbone_weights_frozen": getattr(model, "backbone_frozen", False),
    "batchnorm_running_stats_fixed": True,
}])
display(model_summary)
save_table(model_summary, "02_model_summary.csv")

history = train_classifier(model, train_loader, val_loader, device, n_epochs=N_EPOCHS, lr=LEARNING_RATE)
display(history)
save_table(history, "03_training_history.csv")
plot_training_history(history)
   architecture                                     trainable_parameters   backbone_weights_frozen   batchnorm_running_stats_fixed
0  ImageNet-pretrained ResNet-18 + Linear(512, 3)   1539                   True                      True

   epoch   train_loss   train_accuracy   val_loss   val_accuracy
0  1       0.859978     0.590909         0.679890   0.699248
1  2       0.618878     0.741779         0.572208   0.729323
2  3       0.543465     0.785300         0.468062   0.819549

Figure 2. Training history. The plot and table summarize training and validation accuracy across the three epochs; validation accuracy rises from about 0.70 after the first epoch to about 0.82 after the third. This provides an audit trail for the lightweight transfer-learning setup and indicates that the 1,539-parameter classifier head learns useful Beans-specific signal without heavy GPU training.

In [8]:
criterion = nn.CrossEntropyLoss()
test_metrics, test_true, test_pred = evaluate_classifier(model, test_loader, criterion, device, return_predictions=True)

test_metrics_df = pd.DataFrame([test_metrics])
display(test_metrics_df)
save_table(test_metrics_df, "04_test_metrics.csv")
test_cm = plot_confusion_matrix(test_true, test_pred, label_names)


confusion_matrix_df = pd.DataFrame(test_cm, index=label_names, columns=label_names)
save_table(confusion_matrix_df.reset_index().rename(columns={"index": "true_label"}), "05_confusion_matrix.csv")
loss accuracy
0 0.557868 0.765625
Out[8]:
PosixPath('portfolio_outputs/tables/05_confusion_matrix.csv')

Figure 3. Held-out test performance. The metrics table reports final test-set loss and accuracy (76.6%, somewhat below the final validation accuracy of 82.0%), and the confusion matrix shows which classes are predicted correctly versus confused. The saliency maps below audit this specific trained model as a prototype; they do not make it a validated agricultural diagnostic system.

In [9]:
# ----------------------------
# Integrated Gradients explanation utilities
# ----------------------------
def predict_image(model, image_tensor, device):
    model.eval()
    with torch.inference_mode():
        logits = model(image_tensor.unsqueeze(0).to(device))
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu()
    return int(probs.argmax()), probs

def select_test_example(model, dataset, device, prefer_correct=True, max_search=100):
    fallback = None
    for idx in range(min(max_search, len(dataset))):
        image, true_label = dataset[idx]
        pred_label, probs = predict_image(model, image, device)
        if fallback is None:
            fallback = (idx, image, true_label, pred_label, probs)
        if (not prefer_correct) or (pred_label == true_label):
            return idx, image, true_label, pred_label, probs
    return fallback

def black_image_baseline_like(image_tensor):
    black = torch.zeros_like(image_tensor)
    normalized_black = (black - IMAGENET_MEAN) / IMAGENET_STD
    return normalized_black

def compute_integrated_gradients(model, image_tensor, target_class, device, n_steps=32):
    model.eval()
    ig = IntegratedGradients(model)
    input_tensor = image_tensor.unsqueeze(0).to(device).requires_grad_(True)
    baseline = black_image_baseline_like(image_tensor).unsqueeze(0).to(device)
    attr = ig.attribute(input_tensor, baselines=baseline, target=int(target_class), n_steps=int(n_steps))
    return attr.detach().cpu().squeeze(0)

def attribution_heatmap(attr_tensor):
    heatmap = attr_tensor.abs().mean(dim=0).numpy()
    heatmap = heatmap - heatmap.min()
    denom = heatmap.max()
    if denom > 0:
        heatmap = heatmap / denom
    return heatmap

def plot_saliency(image_tensor, attr_tensor, true_label, pred_label, label_names, title):
    image = unnormalize_image(image_tensor)
    heatmap = attribution_heatmap(attr_tensor)
    fig, axes = plt.subplots(1, 2, figsize=(8.5, 4))
    axes[0].imshow(image)
    axes[0].set_title(f"Image\ntrue: {label_names[true_label]}")
    axes[0].axis("off")
    axes[1].imshow(image)
    im = axes[1].imshow(heatmap, alpha=0.55, vmin=0, vmax=1)
    axes[1].set_title(f"Integrated Gradients\npred: {label_names[pred_label]}")
    axes[1].axis("off")
    cbar = fig.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
    cbar.set_label("Normalized |IG| attribution magnitude")
    fig.suptitle(title)
    plt.tight_layout()
    save_current_figure(f"04_saliency_{sanitize_filename(title)}.png")
    plt.show()

Local Explanation with Integrated Gradients

Method intuition

Integrated Gradients is a local explanation method: each attribution map explains one image and one target class by assigning an attribution value to each input pixel. A global summary could be made by averaging many local maps, but the core method used here is image-specific.

For image $x^*$, target class $k$, model $f_\theta$, and baseline image $x_0$, Integrated Gradients attributes pixels by accumulating gradients along a path from the baseline to the input:

$$ (x^* - x_0) \odot \int_0^1 \nabla_x f_{\theta,k}(x_0 + \alpha(x^* - x_0))\, d\alpha. $$

Here $f_{\theta,k}$ is the class-$k$ logit, $x_0$ is a black-image baseline after applying the same input normalization, and n_steps controls the number of interpolation points used to approximate the path integral.
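A minimal NumPy sketch of the formula, assuming a toy differentiable "logit" f(x) = softplus(w · x) whose gradient is known in closed form; it stands in for the ResNet-18 class logit, where Captum obtains gradients through autograd instead. The integral is approximated with a midpoint Riemann sum over n_steps points. Captum's default approximation method is Gauss-Legendre quadrature, so its numbers will differ slightly. The completeness property of Integrated Gradients says the attributions should sum to f(x*) - f(x_0), which gives a built-in accuracy check.

```python
import numpy as np

def toy_logit(x, w):
    """Toy differentiable 'logit': softplus(w . x) = log(1 + exp(w . x))."""
    return np.logaddexp(0.0, w @ x)

def toy_grad(x, w):
    """Analytic gradient of toy_logit: sigmoid(w . x) * w."""
    z = w @ x
    return w / (1.0 + np.exp(-z))

def integrated_gradients(x, baseline, w, n_steps=32):
    """Midpoint Riemann approximation of the IG path integral."""
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    grads = np.stack([toy_grad(baseline + a * (x - baseline), w) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)

rng = np.random.default_rng(479)
w = rng.normal(size=16) * 0.25   # small weights keep the toy logit well-conditioned
x = rng.normal(size=16)          # stand-in for the input x*
baseline = np.zeros_like(x)      # stand-in for the baseline x_0

attr = integrated_gradients(x, baseline, w, n_steps=32)
gap = attr.sum() - (toy_logit(x, w) - toy_logit(baseline, w))
print(f"completeness gap: {gap:.2e}")  # small: attributions nearly sum to f(x*) - f(x_0)
```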

Visualization design

The visualization below shows the original bean leaf image next to an Integrated Gradients saliency map for the predicted class. The heatmap uses absolute attribution magnitude averaged over color channels, so brighter regions indicate larger attribution magnitude, not necessarily positive evidence for the predicted class.
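The sketch below mirrors attribution_heatmap in NumPy on a hand-made 3×2×2 attribution tensor (illustrative values, not notebook output). One pixel pushes the class logit up and another pushes it down by the same amount; after taking absolute values and min-max scaling, both appear equally bright.

```python
import numpy as np

def attribution_heatmap_np(attr_chw):
    """NumPy mirror of attribution_heatmap: mean |attribution| over channels, min-max scaled."""
    heat = np.abs(attr_chw).mean(axis=0)
    heat = heat - heat.min()
    peak = heat.max()
    return heat / peak if peak > 0 else heat

attr = np.zeros((3, 2, 2))
attr[:, 0, 0] = 0.5    # pixel that raises the class logit
attr[:, 1, 1] = -0.5   # pixel that lowers the class logit by the same amount

heat = attribution_heatmap_np(attr)
print(heat)  # both pixels reach the maximum brightness; the sign is gone
```

Plotting the signed channel mean with a diverging colormap would be one way to keep the direction of evidence visible, at the cost of a busier figure.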

Interpretation guidance

For the selected example, the most useful reading is whether high-attribution regions align with the leaf blade and visible disease-relevant texture rather than image borders or background. If attribution is concentrated mostly on leaf-local regions, the explanation is more consistent with a plausible screening model. If the map emphasizes background or acquisition artifacts, that would be evidence of shortcut behavior.
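One way to make this reading less subjective is to measure what share of the attribution mass falls on the leaf. The sketch assumes a boolean leaf mask is available; the Beans data used here provides images and class labels, not segmentation masks, so such a mask would have to come from hand annotation or a separate segmentation step. The helper name and the 8×8 arrays are hypothetical and synthetic.

```python
import numpy as np

def attribution_mass_inside(heatmap, mask):
    """Share of total heatmap mass that falls inside a boolean region mask."""
    total = heatmap.sum()
    if total == 0:
        return 0.0
    return float(heatmap[mask].sum() / total)

# Hypothetical 8x8 example: the "leaf" occupies the central 4x4 block.
leaf_mask = np.zeros((8, 8), dtype=bool)
leaf_mask[2:6, 2:6] = True

on_leaf = np.where(leaf_mask, 1.0, 0.1)     # attribution concentrated on the leaf
on_border = np.where(leaf_mask, 0.1, 1.0)   # attribution concentrated on the background

on_leaf_share = attribution_mass_inside(on_leaf, leaf_mask)
on_border_share = attribution_mass_inside(on_border, leaf_mask)
print(f"leaf-focused map:       {on_leaf_share:.3f}")
print(f"background-focused map: {on_border_share:.3f}")
```

For reference, a perfectly uniform heatmap scores the mask's area fraction (16/64 = 0.25 here), so values well above that indicate leaf-concentrated attribution.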

A saliency map should not be treated as proof that the model has learned plant pathology. It shows sensitivity for one trained model, one image, and one target class. Bright regions may correspond to disease symptoms, but they can also reflect edges, texture, contrast, or correlated image artifacts.

In [10]:
example_idx, example_image, true_label, pred_label, probs = select_test_example(model, test_dataset, device, prefer_correct=True)

selected_example_summary = pd.DataFrame([{
    "test_index": example_idx,
    "true_label": label_names[true_label],
    "predicted_label": label_names[pred_label],
    "predicted_probability": float(probs[pred_label]),
}])
display(selected_example_summary)
save_table(selected_example_summary, "06_selected_example_summary.csv")

prediction_table = pd.DataFrame({"class": label_names, "predicted_probability": probs.numpy()})
display(prediction_table)
save_table(prediction_table, "07_selected_example_probabilities.csv")

attr_ig = compute_integrated_gradients(model, example_image, target_class=pred_label, device=device, n_steps=24)
plot_saliency(example_image, attr_ig, true_label, pred_label, label_names, title="Local explanation for one Beans image")
   test_index   true_label          predicted_label     predicted_probability
0  1            angular leaf spot   angular leaf spot   0.75268

   class               predicted_probability
0  angular leaf spot   0.752680
1  bean rust           0.210218
2  healthy             0.037102

Figure 4. Local Integrated Gradients explanation. The left panel shows the original image; the right panel overlays normalized absolute Integrated Gradients attribution magnitude for the predicted-class logit. Brighter regions indicate larger attribution magnitude, not necessarily positive evidence for the class. The prediction table above gives the model's class probabilities for this selected image.

In [11]:
# ----------------------------
# Integrated Gradients sensitivity utilities
# ----------------------------
def flatten_attribution(attr_tensor):
    return attr_tensor.abs().mean(dim=0).numpy().ravel()

def safe_spearman(a, b):
    statistic = spearmanr(a, b).statistic
    if np.isnan(statistic):
        return 0.0
    return float(statistic)

def integrated_gradients_sensitivity(model, image_tensor, target_class, device, n_steps_grid):
    attrs = {}
    for n_steps in n_steps_grid:
        attrs[int(n_steps)] = compute_integrated_gradients(model, image_tensor, target_class=target_class, device=device, n_steps=int(n_steps))
    reference_steps = max(n_steps_grid)
    reference_flat = flatten_attribution(attrs[reference_steps])
    rows = []
    for n_steps in n_steps_grid:
        current_flat = flatten_attribution(attrs[int(n_steps)])
        rows.append({"n_steps": int(n_steps), "spearman_vs_largest_n_steps": safe_spearman(current_flat, reference_flat)})
    return attrs, pd.DataFrame(rows)

def plot_nsteps_sensitivity(sensitivity_df):
    plt.figure(figsize=(6, 4))
    plt.plot(sensitivity_df["n_steps"], sensitivity_df["spearman_vs_largest_n_steps"], marker="o")
    plt.xlabel("Integrated Gradients n_steps")
    plt.ylabel("Spearman correlation vs. largest n_steps")
    plt.title("Integrated Gradients hyperparameter sensitivity")
    plt.ylim(0, 1.05)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    save_current_figure("05_ig_nsteps_sensitivity.png")
    plt.show()

def plot_nsteps_examples(image_tensor, attrs_by_steps, true_label, pred_label, label_names):
    for n_steps in [min(attrs_by_steps), max(attrs_by_steps)]:
        plot_saliency(image_tensor, attrs_by_steps[n_steps], true_label, pred_label, label_names, title=f"Integrated Gradients with n_steps={n_steps}")

Explanation Robustness: Integration-Step Sensitivity

Integrated Gradients has implementation choices that affect the output. The main numerical hyperparameter checked here is n_steps, which controls the number of interpolation points used to approximate the path integral from the baseline image to the actual image. A smaller value is faster but gives a rougher approximation; a larger value is slower but usually gives a more stable attribution map.
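The effect of n_steps can be seen on a one-dimensional toy path integral whose exact value is known. Along the straight path from 0 to z, the Integrated Gradients integrand for softplus is sigmoid(αz)·z, which integrates exactly to softplus(z) - softplus(0). The sketch below uses the same grid as the notebook's n_steps_grid with a midpoint rule; Captum's default Gauss-Legendre method converges differently, and on the real model Captum can report the analogous completeness gap when return_convergence_delta=True is passed to IntegratedGradients.attribute. The choice z = 3.0 is arbitrary.

```python
import numpy as np

def softplus(z):
    return np.logaddexp(0.0, z)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def midpoint_path_integral(z, n_steps):
    """Approximate the 1-D IG path integral of softplus from 0 to z with n_steps midpoints."""
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    # Along the path alpha * z, the gradient term is sigmoid(alpha * z) * z.
    return float(np.mean(sigmoid(alphas * z) * z))

z = 3.0
exact = softplus(z) - softplus(0.0)
errors = {n: abs(midpoint_path_integral(z, n) - exact) for n in (8, 16, 24, 48)}
for n, err in errors.items():
    print(f"n_steps={n:>2}: |approximation error| = {err:.2e}")
```

The error shrinks roughly with the square of the step size, which is why the small-n_steps maps in this notebook already correlate strongly with the largest-n_steps reference.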

Other important choices include the baseline image, the target class, and whether signed or absolute attribution values are visualized. This notebook uses a black-image baseline (all-zero pixels before ImageNet normalization, so the baseline tensor itself is not zero), explains the predicted class, and visualizes absolute attribution magnitude averaged across color channels.
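Because the baseline is defined in pixel space but applied in normalized space, it helps to check what it actually contains. This NumPy sketch reuses the notebook's ImageNet statistics to compute the normalized value of a black image in each channel (what black_image_baseline_like fills the tensor with) and shows that a literal all-zero tensor would instead correspond to the ImageNet mean color. The helper name is illustrative.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def normalized_constant_baseline(pixel_value):
    """Per-channel value a constant-color image takes after ImageNet normalization."""
    return (pixel_value - IMAGENET_MEAN) / IMAGENET_STD

black = normalized_constant_baseline(0.0)
print(np.round(black, 4))  # per-channel fill value of the black-image baseline

mean_color = normalized_constant_baseline(IMAGENET_MEAN)
print(mean_color)          # all zeros: a zero tensor is the mean color, not black
```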

The code below recomputes the same explanation with several n_steps values and compares each attribution map with the largest-step map using Spearman rank correlation. This check asks whether the main attribution pattern is reasonably stable under a simple numerical approximation change.
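Spearman correlation is the Pearson correlation of ranks, so it compares which pixels are ranked as most important rather than their raw magnitudes. A minimal NumPy sketch, assuming no tied values (the notebook's safe_spearman calls scipy.stats.spearmanr, which also handles ties); the arrays are synthetic stand-ins for flattened heatmaps.

```python
import numpy as np

def rank(values):
    """0-based ranks; assumes no ties, which is typical for continuous attribution maps."""
    order = np.argsort(values)
    ranks = np.empty(len(values), dtype=float)
    ranks[order] = np.arange(len(values))
    return ranks

def spearman_no_ties(a, b):
    """Spearman correlation computed as the Pearson correlation of the ranks."""
    return float(np.corrcoef(rank(a), rank(b))[0, 1])

rng = np.random.default_rng(0)
reference = rng.random(100)                       # stand-in for the largest-n_steps map
rescaled = 5.0 * reference ** 2                   # same pixel ordering, different magnitudes
noisy = reference + rng.normal(scale=0.3, size=100)

rho_rescaled = spearman_no_ties(rescaled, reference)
rho_noisy = spearman_no_ties(noisy, reference)
print(rho_rescaled)   # 1.0 up to floating-point error: rescaling does not change ranks
print(rho_noisy)      # lower: noise swaps some pixel rankings
```

This rank invariance is a reasonable fit here because small numerical changes in n_steps mostly rescale attribution values; a drop in Spearman correlation signals that the ordering of important pixels itself has changed.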

In [12]:
n_steps_grid = [8, 16, 24, 48]
ig_attrs_by_steps, sensitivity_df = integrated_gradients_sensitivity(model, example_image, target_class=pred_label, device=device, n_steps_grid=n_steps_grid)

display(sensitivity_df)
save_table(sensitivity_df, "08_ig_nsteps_sensitivity.csv")
plot_nsteps_sensitivity(sensitivity_df)
n_steps spearman_vs_largest_n_steps
0 8 0.957191
1 16 0.984672
2 24 0.992039
3 48 1.000000

Figure 5. Integration-step sensitivity. The plot reports Spearman rank correlation between each attribution map and the largest-n_steps reference map. The table above gives the exact correlations. Higher correlation means the attribution ranking is more stable relative to the largest-step reference, while lower correlation suggests the explanation is more sensitive to the numerical integration grid.

In [13]:
plot_nsteps_examples(example_image, ig_attrs_by_steps, true_label, pred_label, label_names)

Figure 6. Visual comparison of small and large n_steps maps. The smallest-step and largest-step maps provide a visual check of whether the main interpretation changes with the numerical integration grid. This tests sensitivity to n_steps; it does not test sensitivity to the baseline image or to signed versus absolute visualization.

In [14]:
# ----------------------------
# Model-randomization sanity-check utilities
# ----------------------------
def flatten_signed_attribution(attr_tensor):
    return attr_tensor.mean(dim=0).numpy().ravel()

def reset_module_weights(module, seed=479):
    torch.manual_seed(seed)
    for child in module.modules():
        if hasattr(child, "reset_parameters"):
            child.reset_parameters()

def model_randomization_check(model, image_tensor, target_class, device, layer_names, n_steps=16):
    # Reference attribution from the trained model.
    original_attr = compute_integrated_gradients(
        model, image_tensor, target_class=target_class, device=device, n_steps=n_steps
    )
    original_abs_flat = flatten_attribution(original_attr)
    original_signed_flat = flatten_signed_attribution(original_attr)
    # Randomize a copy so the trained model stays intact; resets accumulate across stages.
    randomized_model = copy.deepcopy(model).to(device)
    attrs = [("Original", original_attr)]
    rows = [{"randomized_through": "Original", "spearman_abs_vs_original": 1.0, "spearman_signed_vs_original": 1.0}]
    for step, layer_name in enumerate(layer_names, start=1):
        layer = randomized_model.get_submodule(layer_name)
        reset_module_weights(layer, seed=479 + step)  # distinct, reproducible seed per stage
        randomized_model.eval()  # keep BatchNorm in inference mode
        current_attr = compute_integrated_gradients(
            randomized_model, image_tensor, target_class=target_class, device=device, n_steps=n_steps
        )
        current_abs_flat = flatten_attribution(current_attr)
        current_signed_flat = flatten_signed_attribution(current_attr)
        attrs.append((f"through {layer_name}", current_attr))
        rows.append({
            "randomized_through": f"through {layer_name}",
            "spearman_abs_vs_original": safe_spearman(current_abs_flat, original_abs_flat),
            "spearman_signed_vs_original": safe_spearman(current_signed_flat, original_signed_flat),
        })
    return pd.DataFrame(rows), attrs

def plot_randomization_correlation(randomization_df):
    plt.figure(figsize=(8, 4))
    plt.plot(randomization_df["randomized_through"], randomization_df["spearman_abs_vs_original"], marker="o", label="absolute attribution")
    plt.plot(randomization_df["randomized_through"], randomization_df["spearman_signed_vs_original"], marker="o", label="signed attribution")
    plt.xticks(rotation=30, ha="right")
    plt.xlabel("Randomization stage")
    plt.ylabel("Spearman correlation vs. original map")
    plt.title("Model randomization sanity check")
    plt.ylim(-1.05, 1.05)
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    save_current_figure("07_model_randomization_sanity_check.png")
    plt.show()

Explanation Sanity Check: Model Randomization

The final audit uses a model-randomization sanity check for saliency maps. The purpose is to test whether the explanation depends on the trained model's learned weights rather than only on the input image. If the saliency map remains almost unchanged after model weights are randomized, then the method may be mostly reflecting input structure instead of what the model learned.

The check randomizes layers cumulatively, from the output side toward the input side, so each stage keeps all earlier randomizations:

$$ \texttt{fc} \rightarrow \texttt{layer4} \rightarrow \texttt{layer3} \rightarrow \texttt{layer2} \rightarrow \texttt{layer1} \rightarrow \texttt{conv1}. $$

After each randomization step, Integrated Gradients is recomputed for the same image and the same target class (the trained model's prediction). The randomized map is compared with the original map using Spearman rank correlation. A passing result should show correlation falling as more learned layers are randomized. The decline need not be monotone after every layer; the main evidence is a large drop from the original map to the randomized maps.

Because absolute attribution magnitude can preserve image-shaped structure even when attribution signs change, this notebook reports Spearman correlations for both absolute-magnitude maps and signed channel-averaged maps.
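The point about absolute magnitudes can be illustrated with a toy simulation. It assumes gradient-times-input attributions (a simplification of Integrated Gradients) and treats two independently drawn gradient vectors as stand-ins for two unrelated models. Both maps share the same input magnitudes, here a hypothetical mask of bright "leaf" pixels and dim "background" pixels, so their absolute values stay rank-correlated even though their signs are unrelated. Spearman correlation is computed as the Pearson correlation of ranks with NumPy, which assumes no ties; the notebook's own `safe_spearman` helper is not used, so this sketch stays self-contained.

```python
import numpy as np


def spearman_no_ties(a, b):
    # Spearman rho = Pearson correlation of ranks; valid when there are no ties,
    # which holds almost surely for continuous random draws.
    rank_a = np.argsort(np.argsort(a))
    rank_b = np.argsort(np.argsort(b))
    return float(np.corrcoef(rank_a, rank_b)[0, 1])


def abs_vs_signed_demo(n_pixels=2000, seed=0):
    rng = np.random.default_rng(seed)
    # Hypothetical input: half "leaf" pixels with large magnitude, half dim "background"
    x = np.where(np.arange(n_pixels) < n_pixels // 2, 1.0, 0.05)
    # Two unrelated gradient fields, e.g. a trained model vs. a randomized one
    grad_a = rng.normal(size=n_pixels)
    grad_b = rng.normal(size=n_pixels)
    attr_a = x * grad_a  # gradient-times-input attribution
    attr_b = x * grad_b
    signed_rho = spearman_no_ties(attr_a, attr_b)
    abs_rho = spearman_no_ties(np.abs(attr_a), np.abs(attr_b))
    return signed_rho, abs_rho


signed_rho, abs_rho = abs_vs_signed_demo()
print(f"signed Spearman: {signed_rho:+.3f}   absolute Spearman: {abs_rho:+.3f}")
```

In this setup the absolute-map correlation stays clearly positive while the signed correlation sits near zero, purely because both maps inherit the input's magnitude pattern. That is why a residual positive correlation in absolute maps after randomization is weaker evidence against the check than the same value in signed maps.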

In [15]:
cascade_layers = ["fc", "layer4", "layer3", "layer2", "layer1", "conv1"]
randomization_df, randomized_attrs = model_randomization_check(
    model, example_image, target_class=pred_label, device=device,
    layer_names=cascade_layers, n_steps=24,
)

display(randomization_df)
save_table(randomization_df, "09_model_randomization_check.csv")
plot_randomization_correlation(randomization_df)
   randomized_through  spearman_abs_vs_original  spearman_signed_vs_original
0            Original                  1.000000                     1.000000
1          through fc                  0.428613                     0.329991
2      through layer4                  0.234856                     0.026298
3      through layer3                  0.190487                     0.014381
4      through layer2                  0.220071                     0.007313
5      through layer1                  0.187073                     0.011533
6       through conv1                  0.163436                    -0.006403

Figure 7. Model-randomization sanity check. The plot reports Spearman correlation between each randomized Integrated Gradients map and the original trained-model map. The table above gives correlations for both absolute-magnitude maps and signed channel-averaged maps. A large drop after randomization is evidence that the explanation depends on model parameters rather than only on image structure.

Sanity-check interpretation

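This is the decision point for the explanation audit. If Spearman correlation drops sharply after the classifier head is randomized and stays low through the rest of the cascade, the explanation for the selected example passes the model-randomization check. If correlation remains high even after substantial randomization, the explanation should be treated cautiously because it may be dominated by image structure rather than learned model behavior. In this run, the signed correlation fell from 1.0 to 0.33 after randomizing fc and stayed between -0.01 and 0.03 from layer4 onward, while the absolute-magnitude correlation fell to 0.43 after fc and then ranged from 0.16 to 0.23. By this criterion the selected-example explanation passes; the small positive residual in the absolute maps is consistent with magnitude retaining some image-shaped structure.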
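The decision rule above can be written as a small check over the correlation table. The sketch below is illustrative only: the helper `passes_randomization_check` and its 0.5 threshold are hypothetical choices, not values fixed by this notebook, and the input numbers are copied from the Figure 7 table.

```python
# Correlations copied from the model-randomization table (Figure 7), in cascade order:
# fc, layer4, layer3, layer2, layer1, conv1.
ABS_RHO = [0.428613, 0.234856, 0.190487, 0.220071, 0.187073, 0.163436]
SIGNED_RHO = [0.329991, 0.026298, 0.014381, 0.007313, 0.011533, -0.006403]


def passes_randomization_check(randomized_rhos, max_rho=0.5):
    # Hypothetical rule: every randomized stage, starting with the classifier head,
    # must correlate with the original map (rho = 1.0) below max_rho.
    worst = max(randomized_rhos)
    return worst < max_rho, worst


for name, rhos in [("absolute", ABS_RHO), ("signed", SIGNED_RHO)]:
    ok, worst = passes_randomization_check(rhos)
    print(f"{name:>8}: max rho after randomization = {worst:.3f} -> {'pass' if ok else 'review'}")
```

With a stricter hypothetical threshold of 0.3, both variants would be flagged, because the maps after randomizing only `fc` still correlate at 0.43 (absolute) and 0.33 (signed); the threshold is a judgment call rather than a standard value.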