graphtoolbox.training.trainer

Classes

EarlyStopping([patience, min_delta])

Implements early stopping to terminate training when validation loss stops improving.

RollingTrainer(dataset_train_0, ...[, ...])

Warm-start rolling retraining without future leakage.

Trainer(model, dataset_train, dataset_val, ...)

Train, validate, and evaluate a graph neural network model on temporal graph datasets.

class graphtoolbox.training.trainer.EarlyStopping(patience=10, min_delta=0.0)

Bases: object

Implements early stopping to terminate training when validation loss stops improving.

Parameters:
  • patience (int, default=10) – Number of epochs with no improvement after which training will be stopped.

  • min_delta (float, default=0.0) – Minimum change in validation loss to qualify as an improvement.

counter

Number of consecutive epochs without improvement.

Type:

int

best_loss

Lowest recorded validation loss.

Type:

float

early_stop

Whether the early stopping condition has been met.

Type:

bool
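The three attributes above imply a simple stopping rule. The class below is a minimal sketch of that rule. It assumes that an improvement means the new loss is lower than best_loss - min_delta. EarlyStoppingSketch is a hypothetical re-implementation for illustration, not the library's source, and the real comparison may differ.

```python
import math


class EarlyStoppingSketch:
    """Hypothetical re-implementation of the documented stopping rule."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0            # consecutive epochs without improvement
        self.best_loss = math.inf   # lowest validation loss seen so far
        self.early_stop = False     # becomes True once patience is exhausted

    def __call__(self, val_loss: float) -> None:
        # Assumption: only a drop larger than min_delta counts as an improvement.
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


# Toy loss curve: improves twice, then stalls for two epochs.
stopper = EarlyStoppingSketch(patience=2)
for epoch, loss in enumerate([1.0, 0.9, 0.95, 0.96]):
    stopper(loss)
    if stopper.early_stop:
        print("Stopped early at epoch", epoch)
        break
```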

Examples

>>> from graphtoolbox.training.trainer import EarlyStopping
>>> stopper = EarlyStopping(patience=5, min_delta=0.01)
>>> for epoch in range(100):
...     val_loss = compute_validation_loss()  # user-supplied validation step
...     stopper(val_loss)
...     if stopper.early_stop:
...         print("Stopped early at epoch", epoch)
...         break

class graphtoolbox.training.trainer.Trainer(model, dataset_train, dataset_val, dataset_test, batch_size, model_kwargs: Dict | None = None, reconcile: bool = True, top_forecasts_train: Tensor | None = None, top_forecasts_val: Tensor | None = None, top_forecasts_test: Tensor | None = None, top_level_model: str | None = 'ridge', **kwargs)

Bases: object

Train, validate, and evaluate a graph neural network model on temporal graph datasets.

This class handles the full training loop, including:
  • batched training with PyTorch Geometric loaders
  • validation and early stopping
  • checkpointing and loss tracking
  • optional attention collection and saving
  • optional per-group outputs for additive models (the group outputs sum to the final prediction)
  • inference and hierarchical reconciliation (MinT)

Parameters:
  • model (torch.nn.Module) – Graph neural network model to train.

  • dataset_train (GraphDataset) – Training dataset.

  • dataset_val (GraphDataset) – Validation dataset.

  • dataset_test (GraphDataset) – Test dataset.

  • batch_size (int) – Number of graph samples per batch.

  • model_kwargs (dict, optional) – Dictionary of model hyperparameters (loaded from config if None).

  • reconcile (bool, default=True) – Whether to apply MinT reconciliation to predictions.

  • **kwargs

    Optional keyword arguments:
    • edge_index (torch.Tensor[2, E]): graph connectivity in COO format.

    • edge_weight (torch.Tensor[E] or None): per-edge weights.

    • return_attention (bool): collect attention weights during validation and testing.

    • return_group_outputs (bool): ask additive models to return group-wise contributions.

    • lam_reg (float): weight of the graph smoothness regularizer (a pairwise L2 penalty on node predictions).

is_trained

Whether the model has been trained.

Type:

bool

train_loader, val_loader, test_loader

Dataloaders for training, validation, and test.

Type:

PyGDataLoader

saving_directory

Path to saved model checkpoints.

Type:

str

S, G, P

Matrices for hierarchical MinT reconciliation.

Type:

torch.Tensor
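The docstring does not define S, G, and P. In the usual MinT formulation, S is the summing matrix, G = (SᵀW⁻¹S)⁻¹SᵀW⁻¹ maps base forecasts to bottom-level forecasts, and the reconciled forecast is S G ŷ. The NumPy sketch below assumes that P denotes the product S G. It uses a toy two-node hierarchy with a made-up diagonal W. The Trainer's own matrices and covariance estimate may be built differently.

```python
import numpy as np

# Toy hierarchy: total = a + b. Rows of S are [total, a, b]; columns are the bottom series [a, b].
S = np.array([[1.0, 1.0],
              [1.0, 0.0],
              [0.0, 1.0]])

# Assumed base-forecast error covariance W (diagonal here, i.e. WLS-style weights).
W = np.diag([2.0, 1.0, 1.0])
W_inv = np.linalg.inv(W)

# G maps all base forecasts to reconciled bottom-level forecasts.
G = np.linalg.inv(S.T @ W_inv @ S) @ S.T @ W_inv
# P = S @ G projects base forecasts onto the coherent subspace (assumed meaning of P).
P = S @ G

# Incoherent base forecasts: 10 != 4 + 5.
y_hat = np.array([10.0, 4.0, 5.0])
y_tilde = P @ y_hat  # reconciled forecasts: the total now equals a + b
```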

nodes

Node names (order used by the model).

Type:

list[str]

num_nodes

Number of base nodes.

Type:

int

Notes

When the model class name is ‘AdditiveGraphModel’ and return_group_outputs is not set, Trainer will auto-enable group output collection.

train(**kwargs) → Tuple

Train the model and optionally evaluate during training.

Supports early stopping, checkpoint saving, attention visualization and (for additive models) returning per-group outputs.

Parameters:
  • num_epochs (int, optional) – Number of epochs to train (default: from model_kwargs).

  • optimizer (torch.optim.Optimizer, optional) – Optimizer (default: Adam with model_kwargs['lr']).

  • patience (int, default=20) – Early stopping patience.

  • min_delta (float, default=0.0) – Minimum delta to count as improvement.

  • force_training (bool, default=False) – If True, retrain from scratch even if a checkpoint exists. If False, first look for a checkpoint trained for the requested number of epochs and load it directly. If only a checkpoint with fewer epochs exists (e.g. 100 epochs saved, 200 requested), resume from it and train only the remaining epochs.

  • saving_directory (str, optional) – Folder to store model weights.

  • plot_loss (bool, optional) – If True, plots training/validation curves.

  • dynamic_graph (bool, optional) – Enable dynamic adjacency matrix updates per epoch.

  • save (bool, optional) – If True, saves attention maps.
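The checkpoint rules described for force_training can be summarized as a small decision function. This is a hypothetical sketch: plan_training and its return convention are invented for illustration. It assumes checkpoints are identified by their epoch count and that checkpoints longer than the request are ignored, which the documentation does not specify.

```python
from typing import Iterable, Optional, Tuple


def plan_training(saved_epochs: Iterable[int], requested: int,
                  force_training: bool = False) -> Tuple[str, Optional[int], int]:
    """Hypothetical helper returning (action, checkpoint_epoch, epochs_to_run)."""
    if force_training:
        # Ignore any checkpoint and train the full number of epochs.
        return ("scratch", None, requested)
    # Assumption: checkpoints longer than the request are not considered.
    saved = sorted(e for e in saved_epochs if e <= requested)
    if saved and saved[-1] == requested:
        # Checkpoint at the exact epoch count: load it, nothing left to train.
        return ("load", requested, 0)
    if saved:
        # Resume from the longest shorter checkpoint for the remaining epochs.
        return ("resume", saved[-1], requested - saved[-1])
    return ("scratch", None, requested)
```

For example, plan_training([100], 200) resumes from epoch 100 and trains 100 more epochs, which matches the case given in the parameter description.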

Returns:

If return_attention is False and return_group_outputs is False:

(preds, targets, edge_index, edge_weight)

If return_attention is True:

(preds, targets, edge_index, attention_mats)

If return_group_outputs is True and return_attention is False:

(preds, targets, edge_index, edge_weight, group_outputs)

If both are True:

(preds, targets, edge_index, attention_mats, group_outputs)

Where:
  • preds: torch.Tensor[num_nodes, T] in original units (reconciled if enabled)
  • targets: torch.Tensor[num_nodes, T] in original units
  • attention_mats: dict[str, list[torch.Tensor]], if collected
  • group_outputs: dict[str, torch.Tensor[num_nodes, T]], if collected

Return type:

Tuple
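Both the length of the returned tuple and its fourth element depend on two flags, so a caller may want to name the elements explicitly. The helper below is hypothetical and encodes only the four cases listed above. The test uses string placeholders in place of real tensors.

```python
from typing import Any, Dict, Tuple


def unpack_train_output(out: Tuple[Any, ...], return_attention: bool = False,
                        return_group_outputs: bool = False) -> Dict[str, Any]:
    """Hypothetical helper: name the elements of Trainer.train()'s return tuple."""
    # The fourth element is attention_mats when attention is collected, else edge_weight.
    fourth = "attention_mats" if return_attention else "edge_weight"
    names = ["preds", "targets", "edge_index", fourth]
    if return_group_outputs:
        names.append("group_outputs")
    if len(out) != len(names):
        raise ValueError(f"expected {len(names)} elements, got {len(out)}")
    return dict(zip(names, out))
```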

Notes

  • Applies graph smoothness regularization weighted by lam_reg.

  • Uses early stopping with best-checkpoint saving.
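The documentation describes lam_reg only as the weight of a "pairwise node prediction L2" smoothness term. One common form of such a term averages the squared difference between predictions at the two ends of each edge, as sketched below in PyTorch. The function name, the use of edge_index to select the pairs, and the averaging over edges are all assumptions; Trainer may define the pairs or the normalization differently.

```python
from typing import Optional

import torch


def graph_smoothness_penalty(preds: torch.Tensor, edge_index: torch.Tensor,
                             edge_weight: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical penalty: mean over edges of w_ij * ||pred_i - pred_j||^2.

    preds: [num_nodes, T]; edge_index: [2, E] (source, target); edge_weight: [E].
    """
    src, dst = edge_index
    # Squared L2 distance between the prediction vectors of each edge's endpoints.
    diff = (preds[src] - preds[dst]).pow(2).sum(dim=-1)
    if edge_weight is not None:
        diff = diff * edge_weight
    return diff.mean()


# Toy path graph 0 -> 1 -> 2 with 2 time steps per node.
preds = torch.tensor([[1.0, 2.0], [1.0, 2.0], [3.0, 2.0]], requires_grad=True)
edge_index = torch.tensor([[0, 1], [1, 2]])
mse = torch.tensor(0.5)  # stand-in for the forecasting loss
lam_reg = 0.1
loss = mse + lam_reg * graph_smoothness_penalty(preds, edge_index)
loss.backward()  # the penalty is differentiable with respect to preds
```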

evaluate(losses: List[str] | str = ['mape', 'rmse'])

Evaluate the trained model on the test set using the given metrics.

Parameters:

losses (str or list of str, default=['mape', 'rmse']) – Metrics to compute. Supported: ‘mape’, ‘rmse’.

Returns:

Nothing; the computed metrics are printed rather than returned.

Return type:

None
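The two supported metrics have the standard definitions sketched below. The documentation does not say whether evaluate() reports MAPE as a fraction or a percentage, or whether it averages over all nodes and timestamps jointly (as here) or per node, so treat those choices as assumptions.

```python
import numpy as np


def mape(preds: np.ndarray, targets: np.ndarray) -> float:
    # Mean absolute percentage error in percent; assumes no zero targets.
    return float(np.mean(np.abs((targets - preds) / targets)) * 100.0)


def rmse(preds: np.ndarray, targets: np.ndarray) -> float:
    # Root mean squared error in the units of the targets.
    return float(np.sqrt(np.mean((targets - preds) ** 2)))


# Toy arrays shaped [num_nodes, T].
targets = np.array([[100.0, 200.0], [50.0, 50.0]])
preds = np.array([[110.0, 180.0], [50.0, 55.0]])
```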

class graphtoolbox.training.trainer.RollingTrainer(dataset_train_0: GraphDataset, dataset_val_0: GraphDataset, dataset_test_full: GraphDataset, model_class: Type[Module], model_kwargs: Dict[str, Any], window_size: int, step_size: int, batch_size: int = 32, reconcile: bool = True, trainer_kwargs: Dict[str, Any] | None = None, num_epochs_initial: int = 50, num_epochs_update: int = 5)

Bases: object

Warm-start rolling retraining without future leakage.

Behavior:

  • Window 0:

    TRAIN = df_train_base
    VAL = df_val_base
    TEST = first test window

  • Window k >= 1:

    TRAIN = df_train_base + test_window_{k-1} (historical data plus the previous test window, e.g. the past month)
    VAL = df_val_base
    TEST = test_window_k

  • Last window:

    If fewer than window_size timestamps remain, a final, shorter window covering the remaining timestamps is created.

Every window after the first warm-starts from the model trained on the previous window instead of being retrained from scratch.
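The window schedule above can be sketched as follows. rolling_splits is hypothetical and indexes the timestamps of the test period from 0. It assumes consecutive, non-overlapping windows (as if step_size equalled window_size), because the documentation does not say how step_size is used. For each window, it returns the previous test window, which is appended to the base training set, together with the current test window.

```python
from typing import List, Optional, Tuple

Range = Tuple[int, int]  # half-open [start, end) over test-period timestamps


def rolling_splits(n_test: int, window_size: int) -> List[Tuple[Optional[Range], Range]]:
    """Hypothetical schedule: (extra training range, test range) for each window."""
    splits = []
    prev: Optional[Range] = None
    for start in range(0, n_test, window_size):
        # The last window is truncated to the remaining timestamps.
        test = (start, min(start + window_size, n_test))
        # Window 0 adds nothing; window k adds test window k-1 to the base training data.
        splits.append((prev, test))
        prev = test
    return splits


splits = rolling_splits(n_test=10, window_size=4)
# No future leakage: any added training range ends where the current test range begins.
for extra, test in splits:
    assert extra is None or extra[1] <= test[0]
```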

run() → List[Dict[str, Any]]

Run rolling training and evaluation over all windows.

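Returns:

One dict per window with the keys:
  • "window_index": int
  • "window_start": Timestamp
  • "window_end": Timestamp
  • "preds": Tensor[num_nodes, T_eff]
  • "targets": Tensor[num_nodes, T_eff]

where T_eff is the number of timestamps in that window (smaller for a truncated final window).

Return type:

List[Dict[str, Any]], with one entry per window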
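Each window returns preds and targets of shape [num_nodes, T_eff], so a caller will usually stitch them back into one series covering the whole test period. The PyTorch sketch below uses fabricated placeholder tensors instead of real run() output and omits window_start and window_end. It assumes windows do not overlap, so concatenating along the time axis reconstructs the test period.

```python
import torch

# Placeholder for RollingTrainer.run() output: 2 windows, 3 nodes, T_eff = 4 then 2.
results = [
    {"window_index": 1, "preds": torch.zeros(3, 2), "targets": torch.ones(3, 2)},
    {"window_index": 0, "preds": torch.ones(3, 4), "targets": torch.ones(3, 4)},
]

# Order by window, then concatenate along the time axis (dim=1).
results = sorted(results, key=lambda r: r["window_index"])
all_preds = torch.cat([r["preds"] for r in results], dim=1)
all_targets = torch.cat([r["targets"] for r in results], dim=1)

# Overall RMSE across every node and timestamp of the test period.
overall_rmse = torch.sqrt(torch.mean((all_preds - all_targets) ** 2)).item()
```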