graphtoolbox.interpretability.explain
Functions
| ale_scalar_importance | Reduce ALE curve(s) to a scalar importance. |
| compute_ALE | Compute accumulated local effects (ALE) for a single feature/group. |
| compute_ALE_avg_over_instants | Compute ALE curves averaged over a periodic cycle (e.g., 48 half-hours). |
| compute_ALE_per_node | Compute per-node ALE curves. |
| compute_ALE_per_node_avg_over_instants | Compute per-node ALE curves averaged over a periodic cycle. |
| compute_feature_importances_from_ALE | Compute scalar feature importances from ALE for groups and features. |
| get_group_feature_mats | Build feature matrices per node for a given group on the TEST split. |
| plot_ALE | Scatter + ALE curve plot for a single feature/group. |
| plot_ALE_avg | Plot averaged ALE curve (and optional ±1 std band). |
| plot_ALE_nodes | Plot small multiples of per-node ALE curves. |
| plot_explanation_graph | Visualize explanation graphs on a map for selected period(s). |
| plot_feature_importance_bar | Plot a horizontal bar chart of feature importances. |
Classes
| VisualizationConfig | Configuration for graph explanation visualization. |
- class graphtoolbox.interpretability.explain.VisualizationConfig(name: str = 'default', output_root: str = 'interpretability', basemap: Dict[str, Any] | None = None, positions: Dict[int, Tuple[float, float]] | None = None, pos_df: DataFrame | None = None, lon_col: str = 'LONGITUDE', lat_col: str = 'LATITUDE', node_id_col: str | None = None, grouping: Dict[str, Any] = <factory>, start_date_key: str = 'day_inf_test', date_freq: str = 'D', map_projection: str = 'merc', map_resolution: str = 'i', draw_coastlines: bool = True, draw_countries: bool = True, fillcontinents_color: str = 'gray', mapboundary_color: str = 'white', show_nodes: bool = True, show_labels: bool = True, node_size: int = 500, node_color: str = 'blue', node_alpha: float = 0.6, label_fontsize: int = 10, label_color: str = 'white', edge_cmap: str = 'rocket_r', edge_cmap_std: str = 'mako', edge_width_min: float = 1.0, edge_width_max: float = 5.0, edge_alpha: float = 1.0, connectionstyle: str = 'arc3,rad=0.1', edge_arrows: bool = False, arrowstyle: str = '-|>', arrowsize: int = 10, arrow_max_edges: int = 500, normalize_with_edge_weight: bool = True, save_dpi: int = 150, file_ext: str = 'pdf', subdir: str = 'explanation_graph', fontsize: int = 16, labelsize: int = 12)
Bases: object
Configuration for graph explanation visualization.
Fields
- name (str) – Run/dataset name used in figure titles and output paths.
- output_root (str) – Root directory where figures are saved.
- basemap (dict | None) – Explicit Basemap kwargs (projection, bounds). If None, inferred.
- positions (dict[int, tuple[float, float]] | None) – Explicit node positions {node_id: (lon, lat)}.
- pos_df (pandas.DataFrame | None) – DataFrame holding lon/lat columns (and optionally node id).
- lon_col (str) – Longitude column name in pos_df.
- lat_col (str) – Latitude column name in pos_df.
- node_id_col (str | None) – Column in pos_df specifying node IDs; defaults to row order.
- grouping (dict) – Time grouping specification (mode, ndays, labels, indices).
- start_date_key (str) – Key in data_kwargs for the test start date.
- date_freq (str) – Pandas frequency string to expand dates.
- map_projection (str) – Basemap projection (default 'merc').
- map_resolution (str) – Basemap resolution code.
- draw_coastlines (bool) – Draw coastlines if True.
- draw_countries (bool) – Draw country borders if True.
- fillcontinents_color (str) – Color used to fill continents.
- mapboundary_color (str) – Map boundary fill color.
- show_nodes (bool) – Display nodes.
- show_labels (bool) – Display node labels.
- node_size (int) – Node marker size.
- node_color (str) – Node color.
- node_alpha (float) – Node transparency.
- label_fontsize (int) – Node label font size.
- label_color (str) – Node label font color.
- edge_cmap (str) – Colormap name for importance (non-std modes).
- edge_cmap_std (str) – Colormap name when vis_mode == 'std'.
- edge_width_min (float) – Minimum edge line width.
- edge_width_max (float) – Maximum edge line width.
- edge_alpha (float) – Edge transparency.
- connectionstyle (str) – Matplotlib connection style for edges.
- edge_arrows (bool) – Draw directed edge arrows if True.
- arrowstyle (str) – Arrow style passed to NetworkX.
- arrowsize (int) – Arrow size.
- arrow_max_edges (int) – Max edges allowed for arrow rendering (performance guard).
- normalize_with_edge_weight (bool) – Modulate context importance by normalized edge weights.
- save_dpi (int) – DPI for saved figures.
- file_ext (str) – Output file extension.
- subdir (str) – Subdirectory under output_root.
- fontsize (int) – Global title font size.
- labelsize (int) – Axis tick label size.
- name: str = 'default'
- output_root: str = 'interpretability'
- basemap: Dict[str, Any] | None = None
- positions: Dict[int, Tuple[float, float]] | None = None
- pos_df: DataFrame | None = None
- lon_col: str = 'LONGITUDE'
- lat_col: str = 'LATITUDE'
- node_id_col: str | None = None
- grouping: Dict[str, Any]
- start_date_key: str = 'day_inf_test'
- date_freq: str = 'D'
- map_projection: str = 'merc'
- map_resolution: str = 'i'
- draw_coastlines: bool = True
- draw_countries: bool = True
- fillcontinents_color: str = 'gray'
- mapboundary_color: str = 'white'
- show_nodes: bool = True
- show_labels: bool = True
- node_size: int = 500
- node_color: str = 'blue'
- node_alpha: float = 0.6
- label_fontsize: int = 10
- label_color: str = 'white'
- edge_cmap: str = 'rocket_r'
- edge_cmap_std: str = 'mako'
- edge_width_min: float = 1.0
- edge_width_max: float = 5.0
- edge_alpha: float = 1.0
- connectionstyle: str = 'arc3,rad=0.1'
- edge_arrows: bool = False
- arrowstyle: str = '-|>'
- arrowsize: int = 10
- arrow_max_edges: int = 500
- normalize_with_edge_weight: bool = True
- save_dpi: int = 150
- file_ext: str = 'pdf'
- subdir: str = 'explanation_graph'
- fontsize: int = 16
- labelsize: int = 12
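The `<factory>` default shown for `grouping` suggests a dataclass field built with a default factory, so each instance gets its own dict. The sketch below illustrates that pattern and one possible way to accept either a dict or a config object. `MiniVizConfig` and `coerce_config` are hypothetical stand-ins, not part of graphtoolbox, and the empty-dict factory is an assumption because the actual factory content is not shown on this page.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Dict, Optional, Tuple


@dataclass
class MiniVizConfig:
    """Hypothetical stand-in mirroring a few VisualizationConfig fields."""
    name: str = "default"
    output_root: str = "interpretability"
    positions: Optional[Dict[int, Tuple[float, float]]] = None
    # Assumed factory: a fresh empty dict per instance, never a shared default.
    grouping: Dict[str, Any] = field(default_factory=dict)
    edge_width_min: float = 1.0
    edge_width_max: float = 5.0
    file_ext: str = "pdf"


def coerce_config(cfg):
    """Accept None, a dict of overrides, or a ready-made config (assumed behavior)."""
    if cfg is None:
        return MiniVizConfig()
    if isinstance(cfg, MiniVizConfig):
        return cfg
    known = {f.name for f in fields(MiniVizConfig)}
    unknown = set(cfg) - known
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return MiniVizConfig(**cfg)


# Positions are given as {node_id: (lon, lat)}, as in the field description.
cfg = coerce_config({"name": "run_a", "positions": {0: (2.35, 48.85)}})
other = coerce_config(None)
```

Unspecified fields keep their defaults, and the two instances do not share the same `grouping` dict.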
- graphtoolbox.interpretability.explain.plot_explanation_graph(all_explanations: dict, graph_dataset_test: GraphDataset, data_kwargs: dict, dataset: str = 'default', vis_mode: str = 'std', months_to_plot: list[str] = ['ALL'], edge_keep_ratio: float = 0.1, df_pos: DataFrame | None = None, viz_cfg: VisualizationConfig | dict | None = None)
Visualize explanation graphs on a map for selected period(s).
- Parameters:
all_explanations (dict[str, dict]) – Mapping from model name to a dict of explanation artifacts. Expected keys include:
- "mean": node or edge importance array/tensor
- "std": optional uncertainty array/tensor (same shape as "mean")
- "edge_masks": optional edge-level importance array/tensor
graph_dataset_test (GraphDataset | list[torch_geometric.data.Data]) – Test graphs (PyG Data objects). Must contain at least one element and be aligned with the explanation arrays.
data_kwargs (dict) – Auxiliary information for plotting. Must contain the key "day_inf_test" (start date of the test set); additional keys are accepted.
dataset (str) – Backward compatibility only. If viz_cfg is provided, its name is used instead.
vis_mode (str) – Visualization mode: "model", "context", or "std".
- "model": display model-driven importances
- "context": display context-driven importances if available
- "std": display uncertainty (std) when provided
months_to_plot (list[str]) – Backward compatibility. When viz_cfg.grouping is provided, it takes precedence over this argument.
edge_keep_ratio (float) – Fraction of highest-importance edges to keep in the rendering (e.g., 0.10 keeps the top 10% of edges).
df_pos (pandas.DataFrame | None) – Node positions as a DataFrame (backward compatibility). Prefer passing positions via viz_cfg.positions or viz_cfg.pos_df.
viz_cfg (VisualizationConfig | dict | None) – Visualization configuration. May be a dataclass or a dict. Typical fields include:
- name: run name to use in titles/outputs
- positions / pos_df: node positions
- grouping: time grouping specification (overrides months_to_plot)
- map kwargs: styling, CRS, background, etc.
- Returns:
A Matplotlib Figure, or a dictionary of Figures keyed by period/group, depending on the requested grouping. May return None when running non-interactive workflows that save plots to disk.
- Return type:
matplotlib.figure.Figure | dict[str, matplotlib.figure.Figure] | None
- Raises:
KeyError – If required keys (e.g., "day_inf_test") are missing in data_kwargs.
ValueError – If inputs are inconsistent (e.g., empty dataset, invalid edge_keep_ratio).
RuntimeError – If a map background cannot be created or positions are unresolved.
Notes
- viz_cfg.grouping supersedes months_to_plot when provided.
- Edge filtering is applied independently per plot based on importance.
- This function does not compute explanations; it only renders precomputed artifacts contained in all_explanations.
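To make the roles of `edge_keep_ratio`, `edge_width_min` and `edge_width_max` concrete, here is a minimal sketch of top-fraction edge filtering followed by a linear width mapping. Both helpers are hypothetical. The rounding rule and the linear scaling are assumptions, since this page does not state how the library filters edges or maps importance to line width.

```python
def keep_top_edges(importances, keep_ratio=0.1):
    """Return indices of the highest-importance edges, most important first."""
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError("keep_ratio must be in (0, 1]")
    if not importances:
        raise ValueError("no edges to filter")
    # Assumed rounding rule: nearest integer, but always keep at least one edge.
    n_keep = max(1, round(keep_ratio * len(importances)))
    order = sorted(range(len(importances)),
                   key=lambda i: importances[i], reverse=True)
    return order[:n_keep]


def scale_widths(values, w_min=1.0, w_max=5.0):
    """Map values linearly onto [w_min, w_max] (assumed mapping)."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        return [w_max] * len(values)  # all edges equally important
    return [w_min + (v - lo) / span * (w_max - w_min) for v in values]


# Illustrative importances for ten edges.
edge_importance = [0.05, 0.90, 0.10, 0.40, 0.75, 0.20, 0.01, 0.60, 0.30, 0.15]
kept = keep_top_edges(edge_importance, keep_ratio=0.3)
widths = scale_widths([edge_importance[i] for i in kept])
```

Because filtering is described as independent per plot, a helper like this would be applied once per period/group rather than once globally.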
- graphtoolbox.interpretability.explain.get_group_feature_mats(group_name, data, graph_dataset_train, graph_dataset_test, expand_dummies=True)
Build feature matrices per node for a given group on the TEST split.
Returns a mapping {feature_name: X} where X has shape [num_nodes, T] and its rows follow the graph_dataset_test.nodes order. If a listed feature is missing but dummy-encoded columns exist (e.g., 'feat_X_*'), those columns are included instead.
- Parameters:
group_name (str) – Name of the feature group to extract.
data (Any) – Data container with df_test and node identifiers (data.node_var).
graph_dataset_train (Any) – Training graph dataset (used to read feature_groups).
graph_dataset_test (Any) – Test graph dataset (used for node order).
expand_dummies (bool) – Whether to include dummy (one-hot) expanded columns when the base column is missing.
- Returns:
Mapping feature_name -> ndarray [num_nodes, T] of unscaled values.
- Return type:
dict[str, numpy.ndarray]
- Raises:
ValueError – If group_name is not present in feature_groups.
RuntimeWarning – If a requested feature is missing in df_test (when not using dummies).
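The dummy-column fallback described above can be pictured with a small prefix lookup. This is only a sketch under assumptions: `resolve_feature_columns` is a hypothetical helper, the `feature + "_"` prefix rule is inferred from the 'feat_X_*' example, and the column names are made up for illustration.

```python
import warnings


def resolve_feature_columns(feature, columns, expand_dummies=True):
    """Return the df_test columns that stand for `feature` (hypothetical helper)."""
    if feature in columns:
        return [feature]
    if expand_dummies:
        # Assumed naming rule for one-hot columns: '<feature>_<level>'.
        prefix = feature + "_"
        matches = [c for c in columns if c.startswith(prefix)]
        if matches:
            return matches
    warnings.warn(f"feature {feature!r} not found in df_test", RuntimeWarning)
    return []


# Illustrative column names only.
columns = ["temperature", "day_type_0", "day_type_1", "day_type_2"]
direct = resolve_feature_columns("temperature", columns)
dummies = resolve_feature_columns("day_type", columns)
```

A pure prefix match can also pick up unrelated columns that happen to share the prefix, which is worth keeping in mind when naming features.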
- graphtoolbox.interpretability.explain.compute_ALE_avg_over_instants(group_values, group_contrib, n_bins=20, period=48, align='start')
Compute ALE curves averaged over a periodic cycle (e.g., 48 half-hours).
Steps:
- Truncate the series to a multiple of period, keeping the window at the start or the end according to align.
- For each instant in the period, compute binned mean effects and center them.
- Average the centered curves across instants; report the mean and std envelopes.
- Parameters:
group_values (torch.Tensor) – Unscaled feature values, shape [num_nodes, T].
group_contrib (torch.Tensor) – Group contributions/effects, shape [num_nodes, T].
n_bins (int) – Number of bins for ALE discretization.
period (int) – Period length (e.g., 48 for half-hourly daily cycle).
align (str) – ‘start’ to use the first Tm samples, ‘end’ to use the last Tm.
- Returns:
(x_mid, ale_mean, ale_std) arrays of length n_bins.
- Return type:
tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If not enough timesteps or no finite values.
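The three steps above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's implementation: it uses nested lists instead of tensors, shared equal-width bins, count-weighted centering, and a population standard deviation across instants, none of which are confirmed by this page. The function name `ale_avg_over_instants` is hypothetical.

```python
import math
from statistics import fmean, pstdev


def ale_avg_over_instants(values, contrib, n_bins=20, period=48, align="start"):
    """values/contrib: per-node lists of length T (finite floats assumed)."""
    T = len(values[0])
    Tm = (T // period) * period          # largest multiple of period that fits
    if Tm == 0:
        raise ValueError("not enough timesteps for one full period")
    start = 0 if align == "start" else T - Tm
    window = range(start, start + Tm)

    # Shared equal-width bins over the whole window (an assumption).
    pooled = [row[t] for row in values for t in window]
    lo, hi = min(pooled), max(pooled)
    width = (hi - lo) / n_bins or 1.0    # guard against a constant feature

    curves = []                          # one centered curve per instant
    for p in range(period):
        sums, counts = [0.0] * n_bins, [0] * n_bins
        for xs, ys in zip(values, contrib):
            for t in window:
                if (t - start) % period != p:
                    continue
                k = min(int((xs[t] - lo) / width), n_bins - 1)
                sums[k] += ys[t]
                counts[k] += 1
        center = sum(sums) / sum(counts)  # count-weighted mean at this instant
        curves.append([s / c - center if c else math.nan
                       for s, c in zip(sums, counts)])

    x_mid = [lo + (k + 0.5) * width for k in range(n_bins)]
    ale_mean, ale_std = [], []
    for k in range(n_bins):
        col = [c[k] for c in curves if not math.isnan(c[k])]
        ale_mean.append(fmean(col) if col else math.nan)
        ale_std.append(pstdev(col) if col else math.nan)
    return x_mid, ale_mean, ale_std


# Two nodes, T=5, period=2: the last sample is dropped with align="start".
values = [[0.0, 1.0, 2.0, 3.0, 9.0], [1.0, 0.0, 3.0, 2.0, 9.0]]
contrib = [[0.0, 1.0, 2.0, 3.0, 9.0], [1.0, 0.0, 3.0, 2.0, 9.0]]
x_mid, ale_mean, ale_std = ale_avg_over_instants(
    values, contrib, n_bins=2, period=2, align="start")
```

With `align="end"` the same call would keep the last four samples instead, so the outlier at the end of each series would enter the bins.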
- graphtoolbox.interpretability.explain.plot_ALE_avg(x_mid, ale_mean, ale_std=None, label='feature', color='C0', smooth=True)
Plot averaged ALE curve (and optional ±1 std band).
- Parameters:
x_mid (numpy.ndarray) – Bin midpoints.
ale_mean (numpy.ndarray) – Mean centered ALE per bin.
ale_std (numpy.ndarray | None) – Optional std per bin across instants.
label (str) – Legend label and x-axis label for the feature.
color (str) – Base color for line and band.
smooth (bool) – Use spline smoothing if enough valid points.
- Returns:
None. Displays a Matplotlib figure.
- Return type:
None
- graphtoolbox.interpretability.explain.compute_ALE(group_values, group_contrib, n_bins=20)
Compute accumulated local effects (ALE) for a single feature/group.
- Parameters:
group_values (torch.Tensor) – Unscaled feature values, shape [num_nodes, T].
group_contrib (torch.Tensor) – Group contributions/effects, shape [num_nodes, T].
n_bins (int) – Number of bins for ALE discretization.
- Returns:
Flattened x values, y effects, bin midpoints, and centered ALE per bin.
- Return type:
tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If there are no finite values or all bins are empty.
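As a rough picture of the four return values, the sketch below flattens the inputs, averages the contributions per bin, and centers the resulting curve. It is a simplified binned-mean reading borrowed from the steps listed for compute_ALE_avg_over_instants, not the classical ALE definition that accumulates finite differences, and the library's actual computation is not shown on this page. Equal-width bins, count-weighted centering, and the name `binned_centered_effect` are all assumptions.

```python
import math


def binned_centered_effect(values, contrib, n_bins=20):
    """Return (x, y, x_mid, ale) from per-node lists; empty bins yield NaN."""
    # Flatten [num_nodes, T] into 1-D lists, dropping non-finite pairs.
    pairs = [(a, b)
             for xs, ys in zip(values, contrib)
             for a, b in zip(xs, ys)
             if math.isfinite(a) and math.isfinite(b)]
    if not pairs:
        raise ValueError("no finite values")
    x = [a for a, _ in pairs]
    y = [b for _, b in pairs]
    lo, hi = min(x), max(x)
    width = (hi - lo) / n_bins or 1.0    # guard against a constant feature
    sums, counts = [0.0] * n_bins, [0] * n_bins
    for a, b in pairs:
        k = min(int((a - lo) / width), n_bins - 1)  # max value goes in last bin
        sums[k] += b
        counts[k] += 1
    center = sum(sums) / sum(counts)     # count-weighted mean effect
    ale = [s / c - center if c else math.nan for s, c in zip(sums, counts)]
    x_mid = [lo + (k + 0.5) * width for k in range(n_bins)]
    return x, y, x_mid, ale


# Two nodes, four timesteps, with contributions equal to twice the feature.
values = [[0.0, 1.0, 2.0, 3.0], [0.5, 1.5, 2.5, 3.5]]
contrib = [[2.0 * v for v in row] for row in values]
x, y, x_mid, ale = binned_centered_effect(values, contrib, n_bins=2)
```

For a monotone effect like this one, the centered curve is negative in the low bin and positive in the high bin, and its count-weighted mean is zero.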
- graphtoolbox.interpretability.explain.plot_ALE(x, y, x_mid, ale, label='feature', color='C0', align_to_mean=True, demean_scatter=False)
Scatter + ALE curve plot for a single feature/group.
- Parameters:
x (numpy.ndarray) – Flattened feature values.
y (numpy.ndarray) – Flattened effects/contributions.
x_mid (numpy.ndarray) – Bin midpoints.
ale (numpy.ndarray) – Centered ALE values per bin.
label (str) – Legend label and x-axis label for the feature.
color (str) – Base color for scatter and lines.
align_to_mean (bool) – Shift ALE curve by mean(y) to overlay the scatter trend.
demean_scatter (bool) – Center scatter around 0 by subtracting mean(y).
- Returns:
None. Displays a Matplotlib figure.
- Return type:
None
- graphtoolbox.interpretability.explain.compute_ALE_per_node(group_values, group_contrib, n_bins=20, use_global_bins=True)
Compute per-node ALE curves.
- Parameters:
group_values (torch.Tensor) – Unscaled feature values per node, shape [num_nodes, T].
group_contrib (torch.Tensor) – Group contributions per node, shape [num_nodes, T].
n_bins (int) – Number of bins per node.
use_global_bins (bool) – Use shared global bins across nodes (recommended).
- Returns:
(x_mid, ale_mat, counts) where:
- x_mid: bin midpoints (global if use_global_bins)
- ale_mat: per-node centered ALE, shape [num_nodes, n_bins]
- counts: samples per bin, shape [num_nodes, n_bins]
- Return type:
tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If no finite feature values.
- graphtoolbox.interpretability.explain.compute_ALE_per_node_avg_over_instants(group_values, group_contrib, n_bins=20, period=48, align='start', use_global_bins=True)
Compute per-node ALE curves averaged over a periodic cycle.
- Parameters:
group_values (torch.Tensor) – Unscaled feature values per node, shape [num_nodes, T].
group_contrib (torch.Tensor) – Group contributions per node, shape [num_nodes, T].
n_bins (int) – Number of bins per node.
period (int) – Period length (e.g., 48 for half-hourly daily cycle).
align (str) – ‘start’ or ‘end’ alignment when truncating to multiple of period.
use_global_bins (bool) – Only global bins are supported in this averaged mode.
- Returns:
(x_mid, ale_nodes, counts_nodes) where:
- x_mid: bin midpoints (global)
- ale_nodes: averaged centered ALE per node, shape [N, n_bins]
- counts_nodes: total sample counts per node/bin, shape [N, n_bins]
- Return type:
tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If not enough timesteps or no finite values.
NotImplementedError – If use_global_bins is False.
- graphtoolbox.interpretability.explain.plot_ALE_nodes(x_mid, ale_mat, counts=None, max_cols=4, smooth=True, min_points_spline=4, title_prefix='ALE', node_labels=None)
Plot small multiples of per-node ALE curves.
- Parameters:
x_mid (numpy.ndarray) – Bin midpoints (global).
ale_mat (numpy.ndarray) – Per-node centered ALE values, shape [num_nodes, n_bins].
counts (numpy.ndarray | None) – Optional per-node sample counts per bin, same shape as ale_mat.
max_cols (int) – Maximum number of subplot columns.
smooth (bool) – Use spline smoothing when enough valid points.
min_points_spline (int) – Minimum valid points to enable smoothing.
title_prefix (str) – Title prefix for the figure.
node_labels (list[str] | None) – Optional labels for nodes; defaults to “Node i”.
- Returns:
Matplotlib Figure with the grid of ALE plots.
- Return type:
matplotlib.figure.Figure
- graphtoolbox.interpretability.explain.ale_scalar_importance(ale, counts=None, method='rms')
Reduce ALE curve(s) to a scalar importance.
- 'rms': root-mean-square (optionally weighted by counts)
- 'range': max-min range
- 'tv': total variation (L1 sum of absolute differences)
- Parameters:
ale (numpy.ndarray) – ALE curve(s), shape (n_bins,) or (n_nodes, n_bins).
counts (numpy.ndarray | None) – Optional weights per bin (same shape as ale or per row).
method (str) – Reduction method ('rms' | 'range' | 'tv').
- Returns:
Scalar importance.
- Return type:
float
- Raises:
ValueError – If method is unknown or input rank is unsupported.
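The three reductions can be written out for a single curve as below. This is a sketch, not the library's code: `scalar_importance` is a hypothetical name, it handles only the 1-D case, and skipping NaN bins is an assumption. How the library aggregates a 2-D (n_nodes, n_bins) input into one float is not specified here, so it is left out.

```python
import math


def scalar_importance(ale, counts=None, method="rms"):
    """Reduce one ALE curve (list of floats) to a scalar; NaN bins are skipped."""
    weights = counts if counts is not None else [1.0] * len(ale)
    kept = [(a, w) for a, w in zip(ale, weights) if not math.isnan(a)]
    if not kept:
        return math.nan
    vals = [a for a, _ in kept]
    if method == "rms":
        total_w = sum(w for _, w in kept)
        if total_w == 0:
            return math.nan
        # Weighted root-mean-square of the centered curve.
        return math.sqrt(sum(w * a * a for a, w in kept) / total_w)
    if method == "range":
        return max(vals) - min(vals)
    if method == "tv":
        # Total variation: sum of absolute differences between adjacent bins.
        return sum(abs(b - a) for a, b in zip(vals, vals[1:]))
    raise ValueError(f"unknown method: {method!r}")


curve = [-3.0, 0.0, 3.0, 0.0]
rms = scalar_importance(curve, method="rms")
weighted_rms = scalar_importance(curve, counts=[1, 0, 1, 0], method="rms")
value_range = scalar_importance(curve, method="range")
total_variation = scalar_importance(curve, method="tv")
```

'range' ignores everything but the two extremes, 'tv' rewards wiggly curves, and 'rms' with counts downweights sparsely populated bins, so the three can rank features differently.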
- graphtoolbox.interpretability.explain.compute_feature_importances_from_ALE(group_outputs, data, graph_dataset_train, graph_dataset_test, n_bins=20, mode='avg48', period=48, align='start', method='rms')
Compute scalar feature importances from ALE for groups and features.
- Modes:
‘global’: ALE over all nodes/times (compute_ALE)
‘avg48’: average ALE over a periodic cycle (compute_ALE_avg_over_instants)
‘per_node_avg48’: per-node ALE averaged over instants; aggregate across nodes
- Parameters:
group_outputs (dict[str, torch.Tensor]) – Mapping group_name -> contributions tensor [N, T_pred].
data (Any) – Data container providing df_test, node_var, and date index.
graph_dataset_train (Any) – Training dataset with dataset_kwargs[‘feature_groups’].
graph_dataset_test (Any) – Test dataset providing node ordering.
n_bins (int) – Number of ALE bins.
mode (str) – One of {'global', 'avg48', 'per_node_avg48'}.
period (int) – Cycle length for *_avg_over_instants modes.
align (str) – ‘start’ or ‘end’ alignment for windowing.
method (str) – Scalar reduction method for ALE ('rms' | 'range' | 'tv').
- Returns:
DataFrame with columns ['feature', 'group', 'importance'], sorted by importance in descending order.
- Return type:
pandas.DataFrame
- graphtoolbox.interpretability.explain.plot_feature_importance_bar(df, top_k=None, normalize='sum', figsize=None, color='C0')
Plot a horizontal bar chart of feature importances.
- Parameters:
df (pandas.DataFrame) – DataFrame with columns ['feature', 'importance'].
top_k (int | None) – Optional number of top features to display.
normalize (str | None) – ‘sum’ to normalize by total, ‘max’ by maximum, or None.
figsize (tuple[float, float] | None) – Figure size (width, height) in inches.
color (str) – Matplotlib color for bars.
- Returns:
The DataFrame used for plotting (after normalization/filtering), or None if empty.
- Return type:
pandas.DataFrame | None
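The effect of `normalize` and `top_k` on the plotted values can be sketched without pandas. The helper `prepare_importances` is hypothetical and the feature names are made up. Normalizing before truncating to `top_k` is an assumption; this page does not say in which order the library applies the two steps, and the order changes the result for 'sum'.

```python
def prepare_importances(rows, top_k=None, normalize="sum"):
    """rows: list of (feature, importance). Returns ranked pairs, or None if empty."""
    if not rows:
        return None
    ranked = sorted(rows, key=lambda r: r[1], reverse=True)
    if normalize == "sum":
        denom = sum(v for _, v in ranked)
    elif normalize == "max":
        denom = max(v for _, v in ranked)
    elif normalize is None:
        denom = 1.0
    else:
        raise ValueError(f"unknown normalize option: {normalize!r}")
    if denom == 0:
        denom = 1.0                      # avoid division by zero
    # Assumed order: normalize over all features, then keep the top_k rows.
    scaled = [(name, v / denom) for name, v in ranked]
    return scaled if top_k is None else scaled[:top_k]


# Illustrative importances only.
rows = [("temperature", 3.0), ("wind", 1.0), ("hour", 4.0)]
by_sum = prepare_importances(rows, top_k=2, normalize="sum")
by_max = prepare_importances(rows, normalize="max")
```

With 'sum' the bars read as shares of the total importance, while 'max' scales the largest bar to 1 and keeps the others relative to it.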