graphtoolbox.interpretability.explain

Functions

ale_scalar_importance(ale[, counts, method])

Reduce ALE curve(s) to a scalar importance.

compute_ALE(group_values, group_contrib[, ...])

Compute accumulated local effects (ALE) for a single feature/group.

compute_ALE_avg_over_instants(group_values, ...)

Compute ALE curves averaged over a periodic cycle (e.g., 48 half-hours).

compute_ALE_per_node(group_values, group_contrib)

Compute per-node ALE curves.

compute_ALE_per_node_avg_over_instants(...)

Compute per-node ALE curves averaged over a periodic cycle.

compute_feature_importances_from_ALE(...[, ...])

Compute scalar feature importances from ALE for groups and features.

get_group_feature_mats(group_name, data, ...)

Build feature matrices per node for a given group on the TEST split.

plot_ALE(x, y, x_mid, ale[, label, color, ...])

Scatter + ALE curve plot for a single feature/group.

plot_ALE_avg(x_mid, ale_mean[, ale_std, ...])

Plot averaged ALE curve (and optional ±1 std band).

plot_ALE_nodes(x_mid, ale_mat[, counts, ...])

Plot small multiples of per-node ALE curves.

plot_explanation_graph(all_explanations, ...)

Visualize explanation graphs on a map for selected period(s).

plot_feature_importance_bar(df[, top_k, ...])

Plot a horizontal bar chart of feature importances.

Classes

VisualizationConfig(name, output_root, ...)

Configuration for graph explanation visualization.

class graphtoolbox.interpretability.explain.VisualizationConfig(name: str = 'default', output_root: str = 'interpretability', basemap: Dict[str, Any] | None = None, positions: Dict[int, Tuple[float, float]] | None = None, pos_df: DataFrame | None = None, lon_col: str = 'LONGITUDE', lat_col: str = 'LATITUDE', node_id_col: str | None = None, grouping: Dict[str, Any] = <factory>, start_date_key: str = 'day_inf_test', date_freq: str = 'D', map_projection: str = 'merc', map_resolution: str = 'i', draw_coastlines: bool = True, draw_countries: bool = True, fillcontinents_color: str = 'gray', mapboundary_color: str = 'white', show_nodes: bool = True, show_labels: bool = True, node_size: int = 500, node_color: str = 'blue', node_alpha: float = 0.6, label_fontsize: int = 10, label_color: str = 'white', edge_cmap: str = 'rocket_r', edge_cmap_std: str = 'mako', edge_width_min: float = 1.0, edge_width_max: float = 5.0, edge_alpha: float = 1.0, connectionstyle: str = 'arc3,rad=0.1', edge_arrows: bool = False, arrowstyle: str = '-|>', arrowsize: int = 10, arrow_max_edges: int = 500, normalize_with_edge_weight: bool = True, save_dpi: int = 150, file_ext: str = 'pdf', subdir: str = 'explanation_graph', fontsize: int = 16, labelsize: int = 12)[source]

Bases: object

Configuration for graph explanation visualization.

Fields

name : str

Run/dataset name used in figure titles and output paths.

output_root : str

Root directory where figures are saved.

basemap : dict | None

Explicit Basemap kwargs (projection, bounds). If None, inferred.

positions : dict[int, tuple[float, float]] | None

Explicit node positions {node_id: (lon, lat)}.

pos_df : pandas.DataFrame | None

DataFrame holding lon/lat columns (and optionally node id).

lon_col : str

Longitude column name in pos_df.

lat_col : str

Latitude column name in pos_df.

node_id_col : str | None

Column in pos_df specifying node IDs; defaults to row order.

grouping : dict

Time grouping specification (mode, ndays, labels, indices).

start_date_key : str

Key in data_kwargs for the test start date.

date_freq : str

Pandas frequency string to expand dates.

map_projection : str

Basemap projection (default 'merc').

map_resolution : str

Basemap resolution code.

draw_coastlines : bool

Draw coastlines if True.

draw_countries : bool

Draw country borders if True.

fillcontinents_color : str

Color used to fill continents.

mapboundary_color : str

Map boundary fill color.

show_nodes : bool

Display nodes.

show_labels : bool

Display node labels.

node_size : int

Node marker size.

node_color : str

Node color.

node_alpha : float

Node transparency.

label_fontsize : int

Node label font size.

label_color : str

Node label font color.

edge_cmap : str

Colormap name for importance (non-std modes).

edge_cmap_std : str

Colormap name when vis_mode == 'std'.

edge_width_min : float

Minimum edge line width.

edge_width_max : float

Maximum edge line width.

edge_alpha : float

Edge transparency.

connectionstyle : str

Matplotlib connection style for edges.

edge_arrows : bool

Draw directed edge arrows if True.

arrowstyle : str

Arrow style passed to NetworkX.

arrowsize : int

Arrow size.

arrow_max_edges : int

Max edges allowed for arrow rendering (performance guard).

normalize_with_edge_weight : bool

Modulate context importance by normalized edge weights.

save_dpi : int

DPI for saved figures.

file_ext : str

Output file extension.

subdir : str

Subdirectory under output_root.

fontsize : int

Global title font size.

labelsize : int

Axis tick label size.

name: str = 'default'
output_root: str = 'interpretability'
basemap: Dict[str, Any] | None = None
positions: Dict[int, Tuple[float, float]] | None = None
pos_df: DataFrame | None = None
lon_col: str = 'LONGITUDE'
lat_col: str = 'LATITUDE'
node_id_col: str | None = None
grouping: Dict[str, Any]
start_date_key: str = 'day_inf_test'
date_freq: str = 'D'
map_projection: str = 'merc'
map_resolution: str = 'i'
draw_coastlines: bool = True
draw_countries: bool = True
fillcontinents_color: str = 'gray'
mapboundary_color: str = 'white'
show_nodes: bool = True
show_labels: bool = True
node_size: int = 500
node_color: str = 'blue'
node_alpha: float = 0.6
label_fontsize: int = 10
label_color: str = 'white'
edge_cmap: str = 'rocket_r'
edge_cmap_std: str = 'mako'
edge_width_min: float = 1.0
edge_width_max: float = 5.0
edge_alpha: float = 1.0
connectionstyle: str = 'arc3,rad=0.1'
edge_arrows: bool = False
arrowstyle: str = '-|>'
arrowsize: int = 10
arrow_max_edges: int = 500
normalize_with_edge_weight: bool = True
save_dpi: int = 150
file_ext: str = 'pdf'
subdir: str = 'explanation_graph'
fontsize: int = 16
labelsize: int = 12
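
Example (a minimal construction sketch; the coordinates and node ids below are illustrative, not part of the library):

import pandas as pd
from graphtoolbox.interpretability.explain import VisualizationConfig

# Hypothetical station coordinates; LONGITUDE/LATITUDE match the default column names.
pos_df = pd.DataFrame({
    "LONGITUDE": [2.35, -1.55, 5.37],
    "LATITUDE": [48.85, 47.22, 43.30],
    "node_id": [0, 1, 2],
})

viz_cfg = VisualizationConfig(
    name="my_run",           # appears in figure titles and output paths
    pos_df=pos_df,
    node_id_col="node_id",
    file_ext="png",
    save_dpi=200,
)
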
graphtoolbox.interpretability.explain.plot_explanation_graph(all_explanations: dict, graph_dataset_test: GraphDataset, data_kwargs: dict, dataset: str = 'default', vis_mode: str = 'std', months_to_plot: list[str] = ['ALL'], edge_keep_ratio: float = 0.1, df_pos: DataFrame | None = None, viz_cfg: VisualizationConfig | dict | None = None)[source]

Visualize explanation graphs on a map for selected period(s).

Parameters:
  • all_explanations (dict[str, dict]) – Mapping from model name to a dict with explanation artifacts. Expected keys include:
    - "mean": node or edge importance array/tensor
    - "std": optional uncertainty array/tensor (same shape as "mean")
    - "edge_masks": optional edge-level importance array/tensor

  • graph_dataset_test (GraphDataset | list[torch_geometric.data.Data]) – Test graphs (PyG Data objects). Must contain at least one element and be aligned with the explanation arrays.

  • data_kwargs (dict) – Auxiliary information for plotting. Must contain the key “day_inf_test” (start date of the test set); additional keys are accepted.

  • dataset (str) – Backward compatibility only. If viz_cfg is provided, its name is used instead.

  • vis_mode (str) – Visualization mode: "model", "context", or "std".
    - "model": display model-driven importances
    - "context": display context-driven importances if available
    - "std": display uncertainty (std) when provided

  • months_to_plot (list[str]) – Backward compatibility. When viz_cfg.grouping is provided, it takes precedence over this argument.

  • edge_keep_ratio (float) – Fraction of highest-importance edges to keep in the rendering (e.g., 0.10 keeps top 10% edges).

  • df_pos (pandas.DataFrame | None) – Node positions as a DataFrame (backward compatibility). Prefer passing positions via viz_cfg.positions or viz_cfg.pos_df.

  • viz_cfg (VisualizationConfig | dict | None) – Visualization configuration. May be a dataclass or a dict. Typical fields include:
    - name: run name to use in titles/outputs
    - positions / pos_df: node positions
    - grouping: list of period keys to render (overrides months_to_plot)
    - map kwargs: styling, CRS, background, etc.

Returns:

A Matplotlib Figure, or a dictionary of Figures keyed by period/group, depending on the requested grouping. May return None when running non-interactive workflows that save plots to disk.

Return type:

matplotlib.figure.Figure | dict[str, matplotlib.figure.Figure] | None

Raises:
  • KeyError – If required keys (e.g., “day_inf_test”) are missing in data_kwargs.

  • ValueError – If inputs are inconsistent (e.g., empty dataset, invalid edge_keep_ratio).

  • RuntimeError – If a map background cannot be created or positions are unresolved.

Notes

  • viz_cfg.grouping supersedes months_to_plot when provided.

  • Edge filtering is applied independently per plot based on importance.

  • This function does not compute explanations; it only renders precomputed artifacts contained in all_explanations.
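
Example (a hedged sketch; graph_dataset_test, data_kwargs and the importance arrays are assumed to come from your own explanation pipeline, and viz_cfg from the configuration example above):

from graphtoolbox.interpretability.explain import plot_explanation_graph

# Precomputed artifacts per model: "mean" is required, "std"/"edge_masks" are optional.
all_explanations = {"gnn": {"mean": mean_importance, "std": std_importance}}

fig = plot_explanation_graph(
    all_explanations,
    graph_dataset_test,          # test GraphDataset (or list of PyG Data objects)
    data_kwargs,                 # must contain "day_inf_test"
    vis_mode="model",            # or "context" / "std"
    edge_keep_ratio=0.10,        # keep the top 10% most important edges
    viz_cfg=viz_cfg,
)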

graphtoolbox.interpretability.explain.get_group_feature_mats(group_name, data, graph_dataset_train, graph_dataset_test, expand_dummies=True)[source]

Build feature matrices per node for a given group on the TEST split.

Returns a mapping {feature_name: X} where X has shape [num_nodes, T] and columns are aligned to graph_dataset_test.nodes order. If a listed feature is missing but dummy-encoded columns exist (e.g., ‘feat_X_*’), they are included.

Parameters:
  • group_name (str) – Name of the feature group to extract.

  • data (Any) – Data container with df_test and node identifiers (data.node_var).

  • graph_dataset_train (Any) – Training graph dataset (used to read feature_groups).

  • graph_dataset_test (Any) – Test graph dataset (used for node order).

  • expand_dummies (bool) – Whether to include one-hot (dummy) expanded columns when the base column is missing.

Returns:

Mapping feature_name -> ndarray [num_nodes, T] of unscaled values.

Return type:

dict[str, numpy.ndarray]

Raises:
  • ValueError – If group_name is not present in feature_groups.

  • RuntimeWarning – If requested feature is missing in df_test (when not using dummies).
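
Example (sketch; the group name "weather" is illustrative, and valid names are the keys of the dataset's feature_groups):

from graphtoolbox.interpretability.explain import get_group_feature_mats

# {feature_name: array of shape [num_nodes, T]}, columns aligned to graph_dataset_test.nodes.
feat_mats = get_group_feature_mats(
    "weather", data, graph_dataset_train, graph_dataset_test, expand_dummies=True
)
for name, X in feat_mats.items():
    print(name, X.shape)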

graphtoolbox.interpretability.explain.compute_ALE_avg_over_instants(group_values, group_contrib, n_bins=20, period=48, align='start')[source]

Compute ALE curves averaged over a periodic cycle (e.g., 48 half-hours).

Steps:

  • Align/truncate to a multiple of period and (optionally) window at the start or end.

  • For each instant in the period, compute binned mean effects and center them.

  • Average the centered curves across instants; report mean and std envelopes.

Parameters:
  • group_values (torch.Tensor) – Unscaled feature values, shape [num_nodes, T].

  • group_contrib (torch.Tensor) – Group contributions/effects, shape [num_nodes, T].

  • n_bins (int) – Number of bins for ALE discretization.

  • period (int) – Period length (e.g., 48 for half-hourly daily cycle).

  • align (str) – ‘start’ to keep the earliest Tm samples, ‘end’ to keep the latest, where Tm is T truncated to a multiple of period.

Returns:

(x_mid, ale_mean, ale_std) arrays of length n_bins.

Return type:

tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]

Raises:

ValueError – If not enough timesteps or no finite values.
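
Example (minimal sketch with synthetic half-hourly tensors; shapes follow the documented [num_nodes, T] convention):

import torch
from graphtoolbox.interpretability.explain import compute_ALE_avg_over_instants

num_nodes, T = 5, 48 * 30                           # 30 days of half-hourly steps
group_values = 20.0 * torch.rand(num_nodes, T)      # e.g. a temperature-like feature
group_contrib = torch.sin(group_values / 5.0)       # toy contribution signal

x_mid, ale_mean, ale_std = compute_ALE_avg_over_instants(
    group_values, group_contrib, n_bins=20, period=48, align="start"
)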

graphtoolbox.interpretability.explain.plot_ALE_avg(x_mid, ale_mean, ale_std=None, label='feature', color='C0', smooth=True)[source]

Plot averaged ALE curve (and optional ±1 std band).

Parameters:
  • x_mid (numpy.ndarray) – Bin midpoints.

  • ale_mean (numpy.ndarray) – Mean centered ALE per bin.

  • ale_std (numpy.ndarray | None) – Optional std per bin across instants.

  • label (str) – Legend label and x-axis label for the feature.

  • color (str) – Base color for line and band.

  • smooth (bool) – Use spline smoothing if enough valid points.

Returns:

None. Displays a Matplotlib figure.

Return type:

None
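
Example (continuing the compute_ALE_avg_over_instants sketch above):

from graphtoolbox.interpretability.explain import plot_ALE_avg

plot_ALE_avg(x_mid, ale_mean, ale_std=ale_std, label="temperature", color="C0", smooth=True)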

graphtoolbox.interpretability.explain.compute_ALE(group_values, group_contrib, n_bins=20)[source]

Compute accumulated local effects (ALE) for a single feature/group.

Parameters:
  • group_values (torch.Tensor) – Unscaled feature values, shape [num_nodes, T].

  • group_contrib (torch.Tensor) – Group contributions/effects, shape [num_nodes, T].

  • n_bins (int) – Number of bins for ALE discretization.

Returns:

Flattened x values, y effects, bin midpoints, and centered ALE per bin.

Return type:

tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]

Raises:

ValueError – If there are no finite values or all bins are empty.
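
Example (synthetic sketch; in practice group_values typically come from get_group_feature_mats and group_contrib from a model's per-group outputs):

import torch
from graphtoolbox.interpretability.explain import compute_ALE

group_values = 10.0 * torch.rand(5, 500)     # [num_nodes, T]
group_contrib = torch.sqrt(group_values)     # toy monotone effect
x, y, x_mid, ale = compute_ALE(group_values, group_contrib, n_bins=20)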

graphtoolbox.interpretability.explain.plot_ALE(x, y, x_mid, ale, label='feature', color='C0', align_to_mean=True, demean_scatter=False)[source]

Scatter + ALE curve plot for a single feature/group.

Parameters:
  • x (numpy.ndarray) – Flattened feature values.

  • y (numpy.ndarray) – Flattened effects/contributions.

  • x_mid (numpy.ndarray) – Bin midpoints.

  • ale (numpy.ndarray) – Centered ALE values per bin.

  • label (str) – Legend label and x-axis label for the feature.

  • color (str) – Base color for scatter and lines.

  • align_to_mean (bool) – Shift ALE curve by mean(y) to overlay the scatter trend.

  • demean_scatter (bool) – Center scatter around 0 by subtracting mean(y).

Returns:

None. Displays a Matplotlib figure.

Return type:

None
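
Example (continuing the compute_ALE sketch above):

from graphtoolbox.interpretability.explain import plot_ALE

plot_ALE(x, y, x_mid, ale, label="load", color="C1", align_to_mean=True)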

graphtoolbox.interpretability.explain.compute_ALE_per_node(group_values, group_contrib, n_bins=20, use_global_bins=True)[source]

Compute per-node ALE curves.

Parameters:
  • group_values (torch.Tensor) – Unscaled feature values per node, shape [num_nodes, T].

  • group_contrib (torch.Tensor) – Group contributions per node, shape [num_nodes, T].

  • n_bins (int) – Number of bins per node.

  • use_global_bins (bool) – Use shared global bins across nodes (recommended).

Returns:

(x_mid, ale_mat, counts) where:
  • x_mid: midpoints (global if use_global_bins)
  • ale_mat: per-node centered ALE, shape [num_nodes, n_bins]
  • counts: samples per bin, shape [num_nodes, n_bins]

Return type:

tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]

Raises:

ValueError – If no finite feature values.
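
Example (sketch, reusing the synthetic tensors from the compute_ALE example):

from graphtoolbox.interpretability.explain import compute_ALE_per_node

x_mid, ale_mat, counts = compute_ALE_per_node(
    group_values, group_contrib, n_bins=20, use_global_bins=True
)
# ale_mat and counts both have shape [num_nodes, n_bins]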

graphtoolbox.interpretability.explain.compute_ALE_per_node_avg_over_instants(group_values, group_contrib, n_bins=20, period=48, align='start', use_global_bins=True)[source]

Compute per-node ALE curves averaged over a periodic cycle.

Parameters:
  • group_values (torch.Tensor) – Unscaled feature values per node, shape [num_nodes, T].

  • group_contrib (torch.Tensor) – Group contributions per node, shape [num_nodes, T].

  • n_bins (int) – Number of bins per node.

  • period (int) – Period length (e.g., 48 for half-hourly daily cycle).

  • align (str) – ‘start’ or ‘end’ alignment when truncating to a multiple of period.

  • use_global_bins (bool) – Only global bins are supported in this averaged mode.

Returns:

(x_mid, ale_nodes, counts_nodes) where:
  • x_mid: bin midpoints (global)
  • ale_nodes: averaged centered ALE per node, shape [N, n_bins]
  • counts_nodes: total sample counts per node/bin, shape [N, n_bins]

Return type:

tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]

Raises:
  • ValueError – If not enough timesteps or no finite values.

  • NotImplementedError – If use_global_bins is False.
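
Example (sketch, reusing the synthetic tensors from the compute_ALE example; only global bins are supported here):

from graphtoolbox.interpretability.explain import compute_ALE_per_node_avg_over_instants

x_mid, ale_nodes, counts_nodes = compute_ALE_per_node_avg_over_instants(
    group_values, group_contrib, n_bins=20, period=48, align="start", use_global_bins=True
)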

graphtoolbox.interpretability.explain.plot_ALE_nodes(x_mid, ale_mat, counts=None, max_cols=4, smooth=True, min_points_spline=4, title_prefix='ALE', node_labels=None)[source]

Plot small multiples of per-node ALE curves.

Parameters:
  • x_mid (numpy.ndarray) – Bin midpoints (global).

  • ale_mat (numpy.ndarray) – Per-node centered ALE values, shape [num_nodes, n_bins].

  • counts (numpy.ndarray | None) – Optional per-node sample counts per bin, same shape as ale_mat.

  • max_cols (int) – Maximum number of subplot columns.

  • smooth (bool) – Use spline smoothing when enough valid points.

  • min_points_spline (int) – Minimum valid points to enable smoothing.

  • title_prefix (str) – Title prefix for the figure.

  • node_labels (list[str] | None) – Optional labels for nodes; defaults to “Node i”.

Returns:

Matplotlib Figure with the grid of ALE plots.

Return type:

matplotlib.figure.Figure
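
Example (rendering the per-node curves from the previous sketch; the station labels are illustrative):

from graphtoolbox.interpretability.explain import plot_ALE_nodes

fig = plot_ALE_nodes(
    x_mid, ale_nodes, counts=counts_nodes, max_cols=4,
    title_prefix="ALE", node_labels=[f"Station {i}" for i in range(ale_nodes.shape[0])],
)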

graphtoolbox.interpretability.explain.ale_scalar_importance(ale, counts=None, method='rms')[source]

Reduce ALE curve(s) to a scalar importance.

Supported methods:

  • 'rms': root-mean-square (optionally weighted by counts)

  • 'range': max-min range

  • 'tv': total variation (L1 sum of absolute differences)

Parameters:
  • ale (numpy.ndarray) – ALE curve(s), shape (n_bins,) or (n_nodes, n_bins).

  • counts (numpy.ndarray | None) – Optional weights per bin (same shape as ale or per row).

  • method (str) – Reduction method ('rms' | 'range' | 'tv').

Returns:

Scalar importance.

Return type:

float

Raises:

ValueError – If method is unknown or input rank is unsupported.
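
Example (reducing the curves computed in the sketches above to scalars):

from graphtoolbox.interpretability.explain import ale_scalar_importance

imp_rms = ale_scalar_importance(ale_mean, method="rms")                       # single averaged curve
imp_tv = ale_scalar_importance(ale_nodes, counts=counts_nodes, method="tv")   # per-node matrix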

graphtoolbox.interpretability.explain.compute_feature_importances_from_ALE(group_outputs, data, graph_dataset_train, graph_dataset_test, n_bins=20, mode='avg48', period=48, align='start', method='rms')[source]

Compute scalar feature importances from ALE for groups and features.

Modes:
  • ‘global’: ALE over all nodes/times (compute_ALE)

  • ‘avg48’: average ALE over a periodic cycle (compute_ALE_avg_over_instants)

  • ‘per_node_avg48’: per-node ALE averaged over instants; aggregate across nodes

Parameters:
  • group_outputs (dict[str, torch.Tensor]) – Mapping group_name -> contributions tensor [N, T_pred].

  • data (Any) – Data container providing df_test, node_var, and date index.

  • graph_dataset_train (Any) – Training dataset with dataset_kwargs[‘feature_groups’].

  • graph_dataset_test (Any) – Test dataset providing node ordering.

  • n_bins (int) – Number of ALE bins.

  • mode (str) – One of {'global', 'avg48', 'per_node_avg48'}.

  • period (int) – Cycle length for *_avg_over_instants modes.

  • align (str) – ‘start’ or ‘end’ alignment for windowing.

  • method (str) – Scalar reduction method for ALE ('rms' | 'range' | 'tv').

Returns:

DataFrame with columns ['feature', 'group', 'importance'], sorted by importance in descending order.

Return type:

pandas.DataFrame
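
Example (hedged end-to-end sketch; group_outputs is assumed to hold per-group contribution tensors of shape [N, T_pred] produced by your model decomposition):

from graphtoolbox.interpretability.explain import compute_feature_importances_from_ALE

df_imp = compute_feature_importances_from_ALE(
    group_outputs, data, graph_dataset_train, graph_dataset_test,
    n_bins=20, mode="avg48", period=48, align="start", method="rms",
)
print(df_imp.head())   # columns: feature, group, importance (sorted descending)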

graphtoolbox.interpretability.explain.plot_feature_importance_bar(df, top_k=None, normalize='sum', figsize=None, color='C0')[source]

Plot a horizontal bar chart of feature importances.

Parameters:
  • df (pandas.DataFrame) – DataFrame with columns ['feature', 'importance'].

  • top_k (int | None) – Optional number of top features to display.

  • normalize (str | None) – ‘sum’ to normalize by total, ‘max’ by maximum, or None.

  • figsize (tuple[float, float] | None) – Figure size (width, height) in inches.

  • color (str) – Matplotlib color for bars.

Returns:

The DataFrame used for plotting (after normalization/filtering), or None if empty.

Return type:

pandas.DataFrame | None
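
Example (plotting the importances computed above, or any DataFrame with 'feature' and 'importance' columns):

from graphtoolbox.interpretability.explain import plot_feature_importance_bar

plotted = plot_feature_importance_bar(df_imp, top_k=15, normalize="sum", color="C0")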