graphtoolbox.training.trainer

Classes

EarlyStopping([patience, min_delta])

Implements early stopping to terminate training when validation loss stops improving.

Trainer(model, dataset_train, dataset_val, ...)

Train, validate, and evaluate a graph neural network model on temporal graph datasets.

class graphtoolbox.training.trainer.EarlyStopping(patience=10, min_delta=0.0)

Bases: object

Implements early stopping to terminate training when validation loss stops improving.

Parameters:
  • patience (int, default=10) – Number of epochs with no improvement after which training will be stopped.

  • min_delta (float, default=0.0) – Minimum change in validation loss to qualify as an improvement.

counter

Number of consecutive epochs without improvement.

Type:

int

best_loss

Lowest recorded validation loss.

Type:

float

early_stop

Whether the early stopping condition has been met.

Type:

bool

Examples

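>>> from graphtoolbox.training.trainer import EarlyStopping
>>> stopper = EarlyStopping(patience=5, min_delta=0.01)
>>> for epoch in range(100):
...     val_loss = compute_validation_loss()  # placeholder: your own validation routine
...     stopper(val_loss)
...     if stopper.early_stop:
...         print("Stopped early at epoch", epoch)
...         break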
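A minimal sketch of the bookkeeping behind the attributes above, assuming the common convention that a loss counts as an improvement only when it falls below `best_loss - min_delta`, and that `early_stop` is set once `counter` reaches `patience`. The class name `EarlyStoppingSketch` is hypothetical; the library's exact comparison rule may differ.

```python
import math


class EarlyStoppingSketch:
    """Hypothetical re-implementation of EarlyStopping's bookkeeping (not the library's code)."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0           # consecutive epochs without improvement
        self.best_loss = math.inf  # lowest validation loss seen so far
        self.early_stop = False    # becomes True once patience is exhausted

    def __call__(self, val_loss: float) -> None:
        # Assumed rule: an improvement must beat the best loss by more than min_delta.
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


# Usage: a loss curve that plateaus after the third epoch.
stopper = EarlyStoppingSketch(patience=2, min_delta=0.01)
stopped_at = None
for epoch, loss in enumerate([1.0, 0.8, 0.7, 0.695, 0.71, 0.72]):
    stopper(loss)
    if stopper.early_stop:
        stopped_at = epoch
        break
print(stopped_at, stopper.best_loss)
```

Note that 0.695 does not count as an improvement over 0.7 here, because it does not beat it by more than `min_delta=0.01`.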
class graphtoolbox.training.trainer.Trainer(model, dataset_train, dataset_val, dataset_test, batch_size, model_kwargs: Dict | None = None, reconcile: bool = True, **kwargs)

Bases: object

Train, validate, and evaluate a graph neural network model on temporal graph datasets.

This class handles the full training loop, including:
  • batched training with PyTorch Geometric loaders,
  • validation and early stopping,
  • checkpointing and loss tracking,
  • optional attention collection and saving,
  • optional per-group outputs for additive models (which sum to the final prediction),
  • inference and hierarchical reconciliation (MinT).

Parameters:
  • model (torch.nn.Module) – Graph neural network model to train.

  • dataset_train (GraphDataset) – Training dataset.

  • dataset_val (GraphDataset) – Validation dataset.

  • dataset_test (GraphDataset) – Test dataset.

  • batch_size (int) – Number of graph samples per batch.

  • model_kwargs (dict, optional) – Dictionary of model hyperparameters (loaded from config if None).

  • reconcile (bool, default=True) – Whether to apply MinT reconciliation to predictions.

  • **kwargs

    Optional keyword-only arguments:
    • edge_index (torch.Tensor[2, E]): graph connectivity in COO format, with E edges.

    • edge_weight (torch.Tensor[E] or None): optional per-edge weights.

    • return_attention (bool): collect attention during validation/test.

    • return_group_outputs (bool): ask additive models to return group-wise contributions.

    • lam_reg (float): graph smoothness regularizer weight (pairwise node prediction L2).
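The `lam_reg` term can be pictured as a penalty on differences between the predictions of connected nodes. The sketch below assumes the penalty is the edge-weighted mean of squared L2 distances, `lam_reg * mean_e w_e * ||y_src - y_dst||^2`, with unit weights when `edge_weight` is None; whether the library sums or averages, and how it uses edge weights, is not documented here. `smoothness_penalty` is a hypothetical helper written in plain Python for clarity rather than with torch tensors.

```python
from typing import Optional, Sequence


def smoothness_penalty(
    preds: Sequence[Sequence[float]],     # preds[node][t]
    edge_index: Sequence[Sequence[int]],  # [[src, ...], [dst, ...]], shape [2, E]
    edge_weight: Optional[Sequence[float]] = None,
    lam_reg: float = 1.0,
) -> float:
    """Hypothetical edge-weighted L2 smoothness penalty over node predictions."""
    src, dst = edge_index
    num_edges = len(src)
    if num_edges == 0:
        return 0.0
    weights = edge_weight if edge_weight is not None else [1.0] * num_edges
    total = 0.0
    for i, j, w in zip(src, dst, weights):
        # Squared L2 distance between the prediction vectors of the two endpoints.
        total += w * sum((a - b) ** 2 for a, b in zip(preds[i], preds[j]))
    # Assumed normalisation: average over edges, then scale by lam_reg.
    return lam_reg * total / num_edges


preds = [[1.0, 2.0], [1.0, 4.0], [3.0, 2.0]]
edge_index = [[0, 0], [1, 2]]
print(smoothness_penalty(preds, edge_index, lam_reg=0.5))
```

Each of the two edges contributes a squared distance of 4, so the averaged penalty is 4 and, scaled by `lam_reg=0.5`, the result is 2.0.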

is_trained

Whether the model has been trained.

Type:

bool

train_loader, val_loader, test_loader

Dataloaders for training, validation, and test.

Type:

PyGDataLoader

saving_directory

Directory in which model checkpoints are saved.

Type:

str

S, G, P

Matrices for hierarchical MinT reconciliation.

Type:

torch.Tensor
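For orientation, the standard MinT estimator maps the base forecasts $\hat{y}$ (one per series in the hierarchy) to coherent forecasts $\tilde{y}$ using the summing matrix $S$ and an estimate $W$ of the base-forecast error covariance. Reading the attributes as $S$ = summing matrix, $G$ = bottom-level combination matrix and $P = SG$ is an assumption based on common notation; which $W$ the Trainer uses is not stated here.

```latex
\tilde{y} = P\,\hat{y} = S G \hat{y}, \qquad
G = \left(S^{\top} W^{-1} S\right)^{-1} S^{\top} W^{-1}, \qquad
G S = I \;\Rightarrow\; P S = S
```

The identity $GS = I$ is what leaves already-coherent forecasts unchanged; choosing $W = I$ reduces MinT to ordinary least squares (OLS) reconciliation.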

nodes

Node names (order used by the model).

Type:

list[str]

num_nodes

Number of base nodes.

Type:

int

Notes

When the model class name is ‘AdditiveGraphModel’ and return_group_outputs is not set, Trainer will auto-enable group output collection.
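The rule in this note can be sketched as a small flag-resolution function. This is a hypothetical reading: it assumes "not set" means the key is absent or None, and that an explicit user value always wins. `resolve_group_outputs_flag` and the two stand-in classes are illustrative only.

```python
from typing import Any, Dict


def resolve_group_outputs_flag(model: Any, kwargs: Dict[str, Any]) -> bool:
    """Hypothetical reading of the note: auto-enable group outputs for AdditiveGraphModel."""
    value = kwargs.get("return_group_outputs")
    if value is not None:
        # An explicit user choice always wins.
        return bool(value)
    # Otherwise enable only when the model's class is named AdditiveGraphModel.
    return type(model).__name__ == "AdditiveGraphModel"


class AdditiveGraphModel:  # stand-in class for illustration only
    pass


class OtherModel:  # stand-in class for illustration only
    pass


print(resolve_group_outputs_flag(AdditiveGraphModel(), {}))
```

Matching on the class name rather than with `isinstance` mirrors the wording of the note, but it also means a subclass with a different name would not be auto-enabled.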

train(**kwargs) → Tuple

Train the model and optionally evaluate during training.

Supports early stopping, checkpoint saving, attention visualization and (for additive models) returning per-group outputs.

Parameters:
  • num_epochs (int, optional) – Number of epochs to train (default: from model_kwargs).

  • optimizer (torch.optim.Optimizer, optional) – Optimizer (default: Adam with model_kwargs[‘lr’]).

  • patience (int, default=20) – Early stopping patience.

  • min_delta (float, default=0.0) – Minimum delta to count as improvement.

  • force_training (bool, default=False) – If True, retrains even if checkpoint exists.

  • saving_directory (str, optional) – Folder to store model weights.

  • plot_loss (bool, optional) – If True, plots training/validation curves.

  • dynamic_graph (bool, optional) – Enable dynamic adjacency matrix updates per epoch.

  • save (bool, optional) – If True, saves attention maps.

Returns:

If return_attention is False and return_group_outputs is False:

(preds, targets, edge_index, edge_weight)

If return_attention is True:

(preds, targets, edge_index, attention_mats)

If return_group_outputs is True and return_attention is False:

(preds, targets, edge_index, edge_weight, group_outputs)

If both are True:

(preds, targets, edge_index, attention_mats, group_outputs)

Where:
  • preds: torch.Tensor[num_nodes, T] in original units (reconciled if enabled)
  • targets: torch.Tensor[num_nodes, T] in original units
  • attention_mats: dict[str, list[torch.Tensor]], if collected
  • group_outputs: dict[str, torch.Tensor[num_nodes, T]], if collected

Return type:

Tuple
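Because the tuple length and the meaning of its fourth element depend on the two flags, callers may find it convenient to name the elements. `unpack_train_output` is a hypothetical helper that simply mirrors the four cases listed above; it is not part of graphtoolbox.

```python
from typing import Any, Dict, Tuple


def unpack_train_output(
    result: Tuple[Any, ...],
    return_attention: bool = False,
    return_group_outputs: bool = False,
) -> Dict[str, Any]:
    """Hypothetical helper: name the elements of the tuple returned by train()."""
    # The fourth slot holds attention matrices when attention is collected, else edge weights.
    names = ["preds", "targets", "edge_index",
             "attention_mats" if return_attention else "edge_weight"]
    if return_group_outputs:
        names.append("group_outputs")
    if len(result) != len(names):
        raise ValueError(f"expected {len(names)} values, got {len(result)}")
    return dict(zip(names, result))


# Usage with placeholder values standing in for tensors:
out = unpack_train_output(("p", "t", "e", "a", "g"), return_attention=True, return_group_outputs=True)
print(sorted(out))
```

In real use the flags passed here must match those given to the Trainer, otherwise the fourth element is mislabelled.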

Notes

  • Applies graph smoothness regularization weighted by lam_reg.

  • Uses early stopping with best-checkpoint saving.
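The interplay of these two notes, early stopping plus keeping the best checkpoint, can be sketched without any framework. Assumptions: `run_training` is hypothetical, the model state is a plain dict, and "saving" is an in-memory deep copy standing in for something like `torch.save(model.state_dict(), path)`. Regularization is omitted.

```python
import copy
from typing import Callable, Dict, List, Tuple


def run_training(
    state: Dict[str, float],
    train_step: Callable[[Dict[str, float]], None],
    val_loss_fn: Callable[[Dict[str, float]], float],
    num_epochs: int = 100,
    patience: int = 20,
    min_delta: float = 0.0,
) -> Tuple[Dict[str, float], List[float]]:
    """Hypothetical loop: keep the best-validation snapshot, stop after `patience` bad epochs."""
    best_loss = float("inf")
    best_state = copy.deepcopy(state)
    bad_epochs = 0
    history: List[float] = []
    for _ in range(num_epochs):
        train_step(state)                 # update parameters in place
        val_loss = val_loss_fn(state)
        history.append(val_loss)
        if val_loss < best_loss - min_delta:
            best_loss = val_loss
            best_state = copy.deepcopy(state)  # "checkpoint" the improved weights
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break
    # Return the best weights, not the last ones.
    return best_state, history


# Usage: a one-parameter "model" that overshoots its optimum at w = 3.
state = {"w": 0.0}
best, history = run_training(
    state,
    train_step=lambda s: s.update(w=s["w"] + 1.0),
    val_loss_fn=lambda s: (s["w"] - 3.0) ** 2,
    patience=2,
)
print(best, history)
```

Training runs two epochs past the optimum before stopping, but the returned state is the snapshot taken at w = 3.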

evaluate(losses: List[str] | str = ['mape', 'rmse'])

Evaluate the trained model on the test set using the given metrics.

Parameters:

losses (str or list of str, default=['mape', 'rmse']) – Metrics to compute. Supported: ‘mape’, ‘rmse’.

Returns:

None. The computed evaluation metrics are printed.

Return type:

None
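For reference, the two supported metrics under their usual definitions. Whether `evaluate` reports MAPE as a percentage or a fraction, and whether it averages per node before averaging over time, is not specified above; the sketch assumes a percentage averaged over all entries. `mape` and `rmse` here are hypothetical helpers operating on flat lists.

```python
import math
from typing import Sequence


def mape(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute percentage error, in percent (assumes no zero targets)."""
    return 100.0 * sum(abs(t - p) / abs(t) for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root mean squared error, in the units of the targets."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


y_true = [100.0, 200.0, 400.0]
y_pred = [110.0, 190.0, 400.0]
print(mape(y_true, y_pred), rmse(y_true, y_pred))
```

MAPE is scale-free but undefined for zero targets, while RMSE stays in the original units, which is why predictions are compared after conversion back to original units.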