graphtoolbox.utils.visualizations
Functions
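| plot_all_graph_maps(graph_list, edge_index, df_pos, **kwargs) | Plot a grid of graph visualizations with varying edge weights. |
| plot_graph_map(edge_index, edge_weight, df_pos, ax) | Plot a geographic graph with nodes and weighted edges using Basemap. |
| plot_losses(num_epochs, train_losses, val_losses) | Plot training and validation losses across epochs. |
| plot_nodes(true, pred, graph_dataset, **kwargs) | Plot true vs. predicted time series for all nodes in a graph. |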
- graphtoolbox.utils.visualizations.plot_losses(num_epochs, train_losses, val_losses)
Plot training and validation losses across epochs.
- Parameters:
num_epochs (int) – Total number of training epochs.
train_losses (list[float]) – List of training loss values per epoch.
val_losses (list[float]) – List of validation loss values per epoch.
Notes
The function displays the loss evolution and is typically used to diagnose convergence and potential overfitting.
Examples
>>> plot_losses(100, train_losses, val_losses)
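To make the diagnostic use concrete, here is a minimal matplotlib sketch of such a plot. It is not the graphtoolbox implementation: `plot_losses_sketch` is a hypothetical name, it assumes exactly one loss value per epoch, and the dashed marker at the lowest validation loss is an illustrative addition for spotting where validation loss stops improving.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_losses_sketch(num_epochs, train_losses, val_losses):
    """Hypothetical loss-curve plot; not the graphtoolbox implementation."""
    # Assumption: exactly one training and one validation value per epoch.
    if len(train_losses) != num_epochs or len(val_losses) != num_epochs:
        raise ValueError("expected one loss value per epoch")
    epochs = list(range(1, num_epochs + 1))
    # 1-based epoch with the lowest validation loss (illustrative addition).
    best_epoch = min(epochs, key=lambda e: val_losses[e - 1])

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(epochs, train_losses, label="train")
    ax.plot(epochs, val_losses, label="validation")
    ax.axvline(best_epoch, linestyle="--", color="gray",
               label=f"best val epoch ({best_epoch})")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    return fig, ax, best_epoch
```

A validation curve that rises after the marked epoch while the training curve keeps falling is the usual visual sign of overfitting.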
- graphtoolbox.utils.visualizations.plot_nodes(true, pred, graph_dataset, **kwargs)
Plot true vs. predicted time series for all nodes in a graph.
- Parameters:
true (torch.Tensor) – Ground-truth target values with shape [num_nodes, T].
pred (torch.Tensor) – Model predictions with shape [num_nodes, T].
graph_dataset (GraphDataset) – Dataset providing node metadata (names, coordinates, etc.).
nrows (int, optional) – Number of rows in the subplot grid (default: 3).
ncols (int, optional) – Number of columns in the subplot grid (default: 4).
figsize (tuple, optional) – Figure size in inches (default: (7*nrows, 3*ncols)).
Notes
Each subplot corresponds to a node’s time series.
Predictions and true values are plotted over time.
Examples
>>> plot_nodes(true, pred, dataset_val, nrows=2, ncols=3)
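A minimal sketch of the per-node grid layout described above. It assumes NumPy arrays rather than tensors (torch tensors could be converted with `.detach().cpu().numpy()`) and a plain list of node names in place of `GraphDataset` metadata. `plot_nodes_sketch` is a hypothetical helper, not the library function, and its figure size is an arbitrary choice.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_nodes_sketch(true, pred, node_names, nrows=3, ncols=4):
    """Hypothetical helper: one subplot per node, true vs. predicted over time.

    `true` and `pred` are array-likes of shape [num_nodes, T].
    """
    true = np.asarray(true)
    pred = np.asarray(pred)
    if true.shape != pred.shape:
        raise ValueError("true and pred must have the same [num_nodes, T] shape")
    num_nodes, T = true.shape
    if num_nodes > nrows * ncols:
        raise ValueError("grid too small for the number of nodes")

    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows), squeeze=False)
    t = np.arange(T)
    for i, ax in enumerate(axes.flat):
        if i >= num_nodes:
            ax.set_visible(False)  # hide panels beyond the last node
            continue
        ax.plot(t, true[i], label="true")
        ax.plot(t, pred[i], label="pred", linestyle="--")
        ax.set_title(node_names[i])
    axes[0, 0].legend()
    return fig, axes
```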
- graphtoolbox.utils.visualizations.plot_graph_map(edge_index: Tensor, edge_weight: Tensor, df_pos: DataFrame, ax)
Plot a geographic graph with nodes and weighted edges using Basemap.
- Parameters:
edge_index (torch.Tensor) – Edge index tensor of shape [2, E].
edge_weight (torch.Tensor) – Edge weight tensor of shape [E].
df_pos (pandas.DataFrame) – DataFrame with node coordinates (LATITUDE, LONGITUDE).
ax (matplotlib.axes.Axes) – Axis on which to draw the map.
Notes
Nodes are placed using geographic coordinates.
Edge color intensity corresponds to connection weight.
Uses Basemap for cartographic rendering.
Examples
>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots(figsize=(8, 8))
>>> plot_graph_map(edge_index, edge_weight, df_pos, ax)
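The edge coloring described in the notes can be sketched with a matplotlib `LineCollection` whose color array is the edge weights. This is a stand-in under stated assumptions, not the library's code: `draw_weighted_edges` is hypothetical, it takes NumPy arrays (for example `edge_index.numpy()`, `df_pos["LONGITUDE"]`, `df_pos["LATITUDE"]`), and it draws on plain longitude/latitude axes instead of a Basemap projection.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection


def draw_weighted_edges(edge_index, edge_weight, lon, lat, ax, cmap="viridis"):
    """Hypothetical helper: nodes at (lon, lat), each edge colored by its weight.

    edge_index: int array of shape [2, E]; edge_weight: float array of shape [E].
    Plain lon/lat axes are used here instead of a Basemap projection.
    """
    edge_index = np.asarray(edge_index)
    edge_weight = np.asarray(edge_weight, dtype=float)
    lon = np.asarray(lon, dtype=float)
    lat = np.asarray(lat, dtype=float)
    src, dst = edge_index  # source and target node indices

    # One segment per edge, shape [E, 2, 2]: [[x_src, y_src], [x_dst, y_dst]].
    segments = np.stack(
        [np.column_stack([lon[src], lat[src]]), np.column_stack([lon[dst], lat[dst]])],
        axis=1,
    )
    lc = LineCollection(segments, cmap=cmap, linewidths=1.5)
    lc.set_array(edge_weight)  # the colormap maps each weight to a color
    ax.add_collection(lc)
    ax.scatter(lon, lat, s=20, color="black", zorder=3)
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    ax.autoscale_view()
    return lc
```

Returning the collection lets a caller attach a colorbar, e.g. `fig.colorbar(lc, ax=ax)`.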
- graphtoolbox.utils.visualizations.plot_all_graph_maps(graph_list, edge_index, df_pos, **kwargs)
Plot a grid of graph visualizations with varying edge weights.
- Parameters:
graph_list (list[torch.Tensor]) – List of edge weight tensors to visualize.
edge_index (torch.Tensor) – Edge index shared across graphs.
df_pos (pandas.DataFrame) – Node coordinates for plotting.
nrows (int, optional) – Number of rows in the subplot grid (default: 2).
ncols (int, optional) – Number of columns in the subplot grid (computed automatically).
figsize (tuple, optional) – Overall figure size.
Examples
>>> plot_all_graph_maps([W1, W2, W3], edge_index, df_pos, nrows=2)
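The documentation only says the column count is computed automatically. One plausible rule, assumed here and not taken from the library, is `ncols = ceil(len(graph_list) / nrows)`. The hypothetical `make_graph_grid` helper below builds such a grid and hides any surplus panels.

```python
import math

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def make_graph_grid(num_graphs, nrows=2, panel_size=(4.0, 4.0)):
    """Hypothetical layout helper for a grid of graph panels.

    Assumes ncols = ceil(num_graphs / nrows); the library's own rule is not documented.
    Returns the figure and one visible axis per graph.
    """
    if num_graphs < 1:
        raise ValueError("need at least one graph")
    ncols = math.ceil(num_graphs / nrows)
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(panel_size[0] * ncols, panel_size[1] * nrows),
        squeeze=False,  # always a 2-D array of axes
    )
    flat = list(axes.flat)
    for ax in flat[num_graphs:]:
        ax.set_visible(False)  # hide panels that have no graph
    return fig, flat[:num_graphs]
```

Each returned axis could then be passed to `plot_graph_map` together with one entry of `graph_list` and the shared `edge_index`.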