graphtoolbox.utils.visualizations

Functions

plot_all_graph_maps(graph_list, edge_index, ...)

Plot a grid of graph visualizations with varying edge weights.

plot_graph_map(edge_index, edge_weight, ...)

Plot a geographic graph with nodes and weighted edges using Basemap.

plot_losses(num_epochs, train_losses, val_losses)

Plot training and validation losses across epochs.

plot_nodes(true, pred, graph_dataset, **kwargs)

Plot true vs. predicted time series for all nodes in a graph.

graphtoolbox.utils.visualizations.plot_losses(num_epochs, train_losses, val_losses)

Plot training and validation losses across epochs.

Parameters:
  • num_epochs (int) – Total number of training epochs.

  • train_losses (list[float]) – List of training loss values per epoch.

  • val_losses (list[float]) – List of validation loss values per epoch.

Notes

The function displays the loss evolution and is typically used to diagnose convergence and potential overfitting.

Examples

>>> plot_losses(100, train_losses, val_losses)
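
A minimal sketch of what such a loss plot involves, assuming matplotlib and exactly one loss value per epoch in each list. `plot_losses_sketch` is a hypothetical name, not the library's implementation. Marking the epoch with the lowest validation loss is an added illustration of the convergence and overfitting check mentioned in the notes.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def plot_losses_sketch(num_epochs, train_losses, val_losses):
    """Hypothetical loss-curve plot; not the library's own code."""
    # Assumption: one loss value per epoch in each list.
    if len(train_losses) != num_epochs or len(val_losses) != num_epochs:
        raise ValueError("expected one loss value per epoch")
    epochs = range(1, num_epochs + 1)
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(epochs, train_losses, label="train")
    ax.plot(epochs, val_losses, label="validation")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.legend()
    # Epoch (1-based) with the lowest validation loss; a rising validation
    # curve after this point while training loss keeps falling suggests overfitting.
    best = min(range(num_epochs), key=lambda i: val_losses[i]) + 1
    ax.axvline(best, linestyle="--", color="gray")
    return fig, ax, best
```

For example, `plot_losses_sketch(4, [1.0, 0.6, 0.4, 0.3], [1.1, 0.7, 0.65, 0.8])` marks epoch 3, after which validation loss rises again.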

graphtoolbox.utils.visualizations.plot_nodes(true, pred, graph_dataset, **kwargs)

Plot true vs. predicted time series for all nodes in a graph.

Parameters:
  • true (torch.Tensor) – Ground-truth target values with shape [num_nodes, T].

  • pred (torch.Tensor) – Model predictions with shape [num_nodes, T].

  • graph_dataset (GraphDataset) – Dataset providing node metadata (names, coordinates, etc.).

  • nrows (int, optional) – Number of rows in subplot grid (default: 3).

  • ncols (int, optional) – Number of columns in subplot grid (default: 4).

  • figsize (tuple, optional) – Figure size in inches (default: (7*nrows, 3*ncols)).

Notes

  • Each subplot corresponds to a node’s time series.

  • Predictions and true values are plotted over time.

Examples

>>> plot_nodes(true, pred, dataset_val, nrows=2, ncols=3)
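
A minimal sketch of the subplot grid, assuming NumPy-convertible inputs (for torch tensors, call `.detach().cpu()` first) and matplotlib. The hypothetical `node_names` argument stands in for whatever metadata the real function reads from `graph_dataset`. In this sketch, nodes beyond `nrows * ncols` are simply not drawn and unused cells are hidden; the library may behave differently.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_nodes_sketch(true, pred, node_names=None, nrows=3, ncols=4, figsize=None):
    """Hypothetical sketch: one subplot per node, true vs. predicted series.

    `node_names` replaces the metadata the real function takes from `graph_dataset`.
    """
    true = np.asarray(true)
    pred = np.asarray(pred)
    if true.shape != pred.shape or true.ndim != 2:
        raise ValueError("true and pred must both have shape [num_nodes, T]")
    num_nodes, T = true.shape
    if figsize is None:
        figsize = (7 * nrows, 3 * ncols)  # mirrors the documented default
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    flat = axes.ravel()  # row-major order: node i goes to cell i
    t = np.arange(T)
    for i, ax in enumerate(flat):
        if i >= num_nodes:
            ax.set_visible(False)  # hide grid cells that have no node
            continue
        ax.plot(t, true[i], label="true")
        ax.plot(t, pred[i], label="pred", linestyle="--")
        ax.set_title(node_names[i] if node_names is not None else f"node {i}")
    flat[0].legend()
    return fig, axes
```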

graphtoolbox.utils.visualizations.plot_graph_map(edge_index: Tensor, edge_weight: Tensor, df_pos: DataFrame, ax)

Plot a geographic graph with nodes and weighted edges using Basemap.

Parameters:
  • edge_index (torch.Tensor) – Edge index tensor of shape [2, E].

  • edge_weight (torch.Tensor) – Edge weight tensor of shape [E].

  • df_pos (pandas.DataFrame) – DataFrame with node coordinates (LATITUDE, LONGITUDE).

  • ax (matplotlib.axes.Axes) – Axis on which to draw the map.

Notes

  • Nodes are placed using geographic coordinates.

  • Edge color intensity corresponds to connection weight.

  • Uses Basemap for cartographic rendering.

Examples

>>> fig, ax = plt.subplots(figsize=(8, 8))
>>> plot_graph_map(edge_index, edge_weight, df_pos, ax)
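
The sketch below shows the core idea without Basemap: nodes are drawn at their (longitude, latitude) on plain, unprojected axes, and each edge becomes a line segment whose color and width scale with its min-max normalized weight. Replacing Basemap's map projection with raw coordinates is a deliberate simplification; `plot_graph_map_sketch` and its `cmap` parameter are hypothetical. `df_pos` only needs `LATITUDE` and `LONGITUDE` columns, indexed so that row i is node i (an assumption).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection


def plot_graph_map_sketch(edge_index, edge_weight, df_pos, ax, cmap="Blues"):
    """Hypothetical sketch: nodes at (lon, lat), edges colored by weight, no Basemap."""
    edge_index = np.asarray(edge_index)
    edge_weight = np.asarray(edge_weight, dtype=float)
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        raise ValueError("edge_index must have shape [2, E]")
    if edge_weight.shape != (edge_index.shape[1],):
        raise ValueError("edge_weight must have shape [E]")
    lon = np.asarray(df_pos["LONGITUDE"], dtype=float)
    lat = np.asarray(df_pos["LATITUDE"], dtype=float)
    src, dst = edge_index
    # One segment per edge, shape [E, 2, 2]: [(lon_src, lat_src), (lon_dst, lat_dst)].
    segments = np.stack(
        [np.column_stack([lon[src], lat[src]]), np.column_stack([lon[dst], lat[dst]])],
        axis=1,
    )
    # Rescale weights to [0, 1] so colormap intensity tracks edge weight;
    # if all weights are equal, draw every edge at full intensity.
    w_min, w_max = edge_weight.min(), edge_weight.max()
    if w_max > w_min:
        norm = (edge_weight - w_min) / (w_max - w_min)
    else:
        norm = np.ones_like(edge_weight)
    lines = LineCollection(segments, colors=plt.get_cmap(cmap)(norm), linewidths=1 + 2 * norm)
    ax.add_collection(lines)
    ax.scatter(lon, lat, s=20, color="black", zorder=3)  # nodes drawn above edges
    ax.set_xlabel("longitude")
    ax.set_ylabel("latitude")
    ax.autoscale_view()
    return lines
```

Drawing all edges as a single `LineCollection` is typically much faster than calling `ax.plot` once per edge when the number of edges E is large.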

graphtoolbox.utils.visualizations.plot_all_graph_maps(graph_list, edge_index, df_pos, **kwargs)

Plot a grid of graph visualizations with varying edge weights.

Parameters:
  • graph_list (list[torch.Tensor]) – List of edge weight tensors to visualize.

  • edge_index (torch.Tensor) – Edge index shared across graphs.

  • df_pos (pandas.DataFrame) – Node coordinates for plotting.

  • nrows (int, optional) – Number of rows in subplot grid (default: 2).

  • ncols (int, optional) – Number of columns in subplot grid (computed automatically).

  • figsize (tuple, optional) – Overall figure size in inches.

Examples

>>> plot_all_graph_maps([W1, W2, W3], edge_index, df_pos, nrows=2)
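
A minimal sketch of the grid layout, assuming that "computed automatically" means `ncols = ceil(len(graph_list) / nrows)`. The `draw_panel` callback and the default figsize below are hypothetical; the real function presumably draws each panel the way `plot_graph_map` does.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def make_graph_grid(graph_list, draw_panel, nrows=2, ncols=None, figsize=None):
    """Hypothetical sketch: one panel per edge-weight tensor in an nrows x ncols grid.

    `draw_panel(weights, ax)` stands in for the per-panel map drawing.
    """
    n = len(graph_list)
    if n == 0:
        raise ValueError("graph_list is empty")
    if ncols is None:
        ncols = math.ceil(n / nrows)  # assumed rule for "computed automatically"
    if nrows * ncols < n:
        raise ValueError("grid too small for all graphs")
    if figsize is None:
        figsize = (4 * ncols, 4 * nrows)  # assumed default, not documented
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for i, ax in enumerate(axes.ravel()):  # row-major: graph i goes to cell i
        if i < n:
            draw_panel(graph_list[i], ax)
            ax.set_title(f"graph {i}")
        else:
            ax.set_visible(False)  # hide leftover cells
    return fig, axes
```

With three weight tensors and `nrows=2`, this gives a 2 x 2 grid whose last cell is hidden.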