graphtoolbox.models.gnn

Classes

AdditiveGraphModel(feature_group_dims, ...)

Additive Graph Model (GAM-like GNN).

ConvAdapter(conv_class, in_dim, hidden_channels)

GCNEncoder(in_channels, out_channels)

Simple 2-layer Graph Convolutional Network encoder.

VariationalGNNEncoder(in_channels, out_channels)

Variational graph encoder producing mean and log-standard-deviation (logstd) embeddings for Variational Graph Autoencoders (VGAE) or other graph-based latent-variable models.

myGNN(in_channels, num_layers, ...[, ...])

class graphtoolbox.models.gnn.ConvAdapter(conv_class, in_dim, hidden_channels, heads=1, base_kwargs=None)

Bases: Module

forward(x, edge_index, x0=None, **tensors)

Run the forward pass of the adapter.

Note

Call the module instance (e.g. adapter(x, edge_index)) rather than forward() directly: the instance call runs any registered hooks, whereas calling forward() directly silently skips them.

class graphtoolbox.models.gnn.myGNN(in_channels: int, num_layers: int, hidden_channels: int, out_channels: int, conv_class=<class 'torch_geometric.nn.conv.gat_conv.GATConv'>, conv_kwargs=None, heads=1, **kwargs)

Bases: Module

forward(x, edge_index, edge_weight=None, edge_attr=None, return_attention=False, **kwargs)

Run the forward pass of the network.

Note

Call the module instance (e.g. model(x, edge_index)) rather than forward() directly: the instance call runs any registered hooks, whereas calling forward() directly silently skips them.
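myGNN defaults to PyG's GATConv and exposes a heads argument. In GATConv, concat=True (the default) concatenates the head outputs, so a layer built with out_channels=hidden_channels emits heads * hidden_channels features, while concat=False averages the heads instead. How myGNN wires its layers is not documented here; the sketch below assumes that hidden layers concatenate and the final layer averages, and uses a hypothetical helper gat_layer_dims to show how the widths chain from layer to layer.

```python
from typing import List, Tuple


def gat_layer_dims(in_channels: int, hidden_channels: int, out_channels: int,
                   num_layers: int, heads: int) -> List[Tuple[int, int]]:
    """Return (input width, output width) for each layer of a hypothetical GAT stack.

    Assumption: hidden layers concatenate heads (width heads * hidden_channels)
    and the final layer averages them (width out_channels).
    """
    dims: List[Tuple[int, int]] = []
    width = in_channels
    for layer in range(num_layers):
        is_last = layer == num_layers - 1
        out_width = out_channels if is_last else heads * hidden_channels
        dims.append((width, out_width))
        width = out_width  # the next layer consumes this layer's output
    return dims


print(gat_layer_dims(in_channels=8, hidden_channels=16, out_channels=1,
                     num_layers=3, heads=4))
# [(8, 64), (64, 64), (64, 1)]
```

With heads=1 concatenation and averaging coincide, which is why the width question only matters once several heads are used.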

class graphtoolbox.models.gnn.AdditiveGraphModel(feature_group_dims: dict, num_layers: int, hidden_channels: int, out_channels: int, conv_class=None, conv_kwargs=None, **kwargs)

Bases: Module

Additive Graph Model (GAM-like GNN).

This model builds per-feature-group encoders and then applies one shared GNN backbone over a block-diagonal expansion of the original graph (one disjoint copy per feature group). The final per-node prediction is the average of group-wise outputs plus a learnable bias.
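The block-diagonal expansion can be pictured with plain Python lists. The sketch below illustrates the idea and is not the library's code: the helper expand_block_diagonal is hypothetical, and it assumes that copy g occupies node ids g * N to (g + 1) * N - 1 and reuses the original edge weights unchanged.

```python
from typing import List, Optional, Tuple

EdgeIndex = Tuple[List[int], List[int]]  # (source ids, target ids), COO layout


def expand_block_diagonal(edge_index: EdgeIndex, num_nodes: int, num_groups: int,
                          edge_weight: Optional[List[float]] = None
                          ) -> Tuple[EdgeIndex, Optional[List[float]]]:
    """Replicate a graph num_groups times as disjoint copies (hypothetical helper)."""
    src, dst = edge_index
    new_src: List[int] = []
    new_dst: List[int] = []
    for g in range(num_groups):
        offset = g * num_nodes  # shift node ids into copy g's block
        new_src.extend(s + offset for s in src)
        new_dst.extend(d + offset for d in dst)
    # Every copy keeps the original edge weights (list repetition).
    new_weight = None if edge_weight is None else edge_weight * num_groups
    return (new_src, new_dst), new_weight


# Undirected path graph 0 - 1 - 2 (both directions), expanded for 2 feature groups.
(ei_src, ei_dst), w = expand_block_diagonal(
    ([0, 1, 1, 2], [1, 0, 2, 1]), num_nodes=3, num_groups=2,
    edge_weight=[1.0, 1.0, 0.5, 0.5])
print(ei_src)  # [0, 1, 1, 2, 3, 4, 4, 5]
print(ei_dst)  # [1, 0, 2, 1, 4, 3, 5, 4]
```

Because no edge crosses between blocks, one message-passing call over the expanded graph is equivalent to separate calls on each copy. With PyTorch tensors the same offsetting can be written as torch.cat([edge_index + g * N for g in range(G)], dim=1).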

Mathematical form

\(h_g = \mathrm{Encoder}_g(x_g)\), \(y_g = \mathrm{GNN}(h_g, \mathcal{G})\), \(\hat{y} = b + \frac{1}{|G|} \sum_{g \in G} y_g\).
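The final combination step of the formula can be checked numerically. In this sketch out_channels = 1, the group names are hypothetical, and additive_prediction is an illustrative helper rather than part of graphtoolbox.

```python
from typing import Dict, List


def additive_prediction(group_outputs: Dict[str, List[float]], bias: float) -> List[float]:
    """y_hat_i = b + (1 / |G|) * sum over groups g of y_g[i] (out_channels = 1)."""
    outputs = list(group_outputs.values())
    num_groups = len(outputs)
    num_nodes = len(outputs[0])
    return [bias + sum(y[i] for y in outputs) / num_groups for i in range(num_nodes)]


# Hypothetical groups "weather" and "calendar" over 3 nodes.
y_groups = {"weather": [1.0, 2.0, 3.0], "calendar": [3.0, 0.0, -1.0]}
y_hat = additive_prediction(y_groups, bias=0.5)
print(y_hat)  # [2.5, 1.5, 1.5]
```

Because the combination is linear, y_g / |G| is group g's contribution to \(\hat{y} - b\) at each node; for node 0 above, weather contributes 0.5 and calendar 1.5.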

Key points

  • Single shared backbone over disjoint copies improves stability.

  • Calls the backbone once rather than once per group, which is intended to prevent exploding activations.

  • Provides additive interpretability via group_outputs.

Parameters:
  • feature_group_dims (dict[str, int]) – Mapping from group name to input feature dimension.

  • num_layers (int) – Number of GNN layers in the shared backbone.

  • hidden_channels (int) – Hidden dimension used by the encoders and the backbone.

  • out_channels (int) – Output dimension per node.

  • conv_class (type, optional) – PyG convolution class (default: GCNConv).

  • conv_kwargs (dict, optional) – Extra keyword arguments forwarded to the convolution constructor.

group_names

Ordered list of feature group names.

Type:

list[str]

num_groups

Number of feature groups.

Type:

int

encoders

Per-group MLP encoders producing hidden representations.

Type:

nn.ModuleDict

gnn_branch

Shared GNN backbone applied to the group embeddings stacked along the node dimension (one disjoint graph copy per group).

Type:

myGNN

bias

Learnable scalar bias added to predictions.

Type:

torch.nn.Parameter

group_index_map

Slices locating each group inside the concatenated feature tensor.

Type:

dict[str, slice]
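group_index_map follows directly from feature_group_dims if the groups are laid out contiguously in x in the order of group_names. The helper below is a hypothetical sketch of that bookkeeping under this assumption; with a tensor of shape [N, F_total], x[:, index_map[name]] would then select one group's columns.

```python
from typing import Dict


def build_group_index_map(feature_group_dims: Dict[str, int]) -> Dict[str, slice]:
    """Assign each group a contiguous column slice, in dict insertion order (hypothetical)."""
    index_map: Dict[str, slice] = {}
    start = 0
    for name, dim in feature_group_dims.items():
        index_map[name] = slice(start, start + dim)
        start += dim  # the next group starts where this one ends
    return index_map


dims = {"weather": 3, "calendar": 2}          # hypothetical groups, F_total = 5
index_map = build_group_index_map(dims)
row = [0.1, 0.2, 0.3, 1.0, 0.0]               # one node's concatenated features
print(index_map["calendar"])                  # slice(3, 5, None)
print(row[index_map["weather"]])              # [0.1, 0.2, 0.3]
```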

forward(x, edge_index, edge_weight=None, mask=None, return_attention: bool = False, return_group_outputs: bool = False, **kwargs)

Forward pass.

Fast path (return_attention=False) constructs a block-diagonal graph and performs one shared GNN pass. Slow path (return_attention=True) iterates per group to collect attention maps.

Parameters:
  • x (torch.Tensor) – Node features concatenated by group, shape [N, F_total].

  • edge_index (torch.Tensor) – Graph connectivity [2, E].

  • edge_weight (torch.Tensor | None) – Optional edge weights [E].

  • mask (Any | None) – Unused placeholder (trainer compatibility).

  • return_attention (bool) – If True, returns per-group attention statistics.

  • return_group_outputs (bool) – If True, also return per-group node outputs.

Returns:

  • y_hat: Tensor [N, out_channels] (always)

  • group_outputs (dict[str, Tensor]) if return_group_outputs is True

  • attention_per_group (dict[str, dict]) if return_attention is True

Return type:

torch.Tensor | tuple
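In the fast path, the backbone output for the expanded graph has G * N rows. Assuming copy g occupies rows g * N to (g + 1) * N - 1 (the same layout as the expanded edge_index), per-group outputs can be recovered by slicing, as in this hypothetical split_group_outputs helper. The actual method may organise this step differently.

```python
from typing import Dict, List


def split_group_outputs(stacked: List[List[float]], group_names: List[str],
                        num_nodes: int) -> Dict[str, List[List[float]]]:
    """Split [G * N, out_channels] rows into {group: [N, out_channels]} (hypothetical)."""
    if len(stacked) != len(group_names) * num_nodes:
        raise ValueError("expected exactly one block of num_nodes rows per group")
    return {
        name: stacked[g * num_nodes:(g + 1) * num_nodes]  # block g of the stack
        for g, name in enumerate(group_names)
    }


stacked = [[1.0], [2.0], [3.0], [0.0]]  # G = 2 groups, N = 2 nodes, out_channels = 1
outs = split_group_outputs(stacked, ["weather", "calendar"], num_nodes=2)
print(outs)  # {'weather': [[1.0], [2.0]], 'calendar': [[3.0], [0.0]]}
```

For a tensor, stacked.split(N, dim=0) yields the same blocks as a tuple.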

class graphtoolbox.models.gnn.GCNEncoder(in_channels, out_channels)

Bases: Module

Simple 2-layer Graph Convolutional Network encoder.

This encoder maps node features into a compact latent space using two GCNConv layers and ReLU activation.

Parameters:
  • in_channels (int) – Number of input node features.

  • out_channels (int) – Dimension of the latent embedding space.

Examples

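>>> import torch
>>> from graphtoolbox.models.gnn import GCNEncoder
>>> x = torch.randn(100, 32)  # 100 nodes, 32 features each
>>> edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # COO, shape [2, E]
>>> enc = GCNEncoder(in_channels=32, out_channels=16)
>>> z = enc(x, edge_index)
>>> z.shape
torch.Size([100, 16])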
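GCNConv propagates features with the symmetric normalization of the GCN propagation rule: self-loops are added (A_hat = A + I) and each edge s -> t is weighted by 1 / sqrt(deg(s) * deg(t)), with degrees taken in A_hat. The sketch below reproduces only that normalization and one aggregation step on scalar features. It omits the learnable weight matrix and bias, assumes a symmetric (undirected) edge_index, and its helpers sym_norm_edges and aggregate_once are illustrative, not PyG functions.

```python
import math
from typing import List, Tuple

Edge = Tuple[int, int, float]  # (source, target, weight)


def sym_norm_edges(src: List[int], dst: List[int], num_nodes: int) -> List[Edge]:
    """Add self-loops, then weight each edge s -> t by 1 / sqrt(deg(s) * deg(t)).

    Assumes edge_index is symmetric, so in- and out-degrees coincide.
    """
    src = src + list(range(num_nodes))  # A_hat = A + I
    dst = dst + list(range(num_nodes))
    deg = [0] * num_nodes
    for t in dst:  # degree of every node in A_hat
        deg[t] += 1
    return [(s, t, 1.0 / math.sqrt(deg[s] * deg[t])) for s, t in zip(src, dst)]


def aggregate_once(x: List[float], edges: List[Edge]) -> List[float]:
    """out[t] = sum over edges (s, t, w) of w * x[s]; no weight matrix, no bias."""
    out = [0.0] * len(x)
    for s, t, w in edges:
        out[t] += w * x[s]
    return out


# Undirected path graph 0 - 1 - 2, stored with both edge directions.
edges = sym_norm_edges([0, 1, 1, 2], [1, 0, 2, 1], num_nodes=3)
h = [max(0.0, v) for v in aggregate_once([1.0, 1.0, 1.0], edges)]  # then ReLU
```

The normalization keeps high-degree nodes from dominating the sum: node 1 (degree 3 in A_hat) weights its own feature by 1/3, while each endpoint weights itself by 1/2.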
forward(x, edge_index)

Forward pass through the GCN encoder.

Parameters:
  • x (torch.Tensor) – Node feature matrix of shape [N, F].

  • edge_index (torch.LongTensor) – Graph connectivity in COO format, shape [2, E].

Returns:

Latent node representations of shape [N, out_channels].

Return type:

torch.Tensor

class graphtoolbox.models.gnn.VariationalGNNEncoder(in_channels, out_channels, conv='gcn')

Bases: Module

Variational graph encoder producing mean and log-standard-deviation (logstd) embeddings for Variational Graph Autoencoders (VGAE) or other graph-based latent-variable models.

Supports both GCN and GraphSAGE convolutions.

Parameters:
  • in_channels (int) – Number of input node features.

  • out_channels (int) – Latent embedding dimension.

  • conv ({'gcn', 'sage'}, default='gcn') – Type of graph convolution to use.
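conv='sage' presumably selects PyG's SAGEConv (an assumption; the mapping is not spelled out here). With its default mean aggregation, SAGEConv computes x'_i = W1 x_i + W2 * mean over neighbours j of x_j: the node's own features get a separate weight, and there are no self-loops or degree normalization as in GCN. The sketch uses scalar features and scalar weights w_self and w_neigh in place of the matrices W1 and W2; sage_mean_layer is a hypothetical helper.

```python
from typing import List


def sage_mean_layer(x: List[float], src: List[int], dst: List[int],
                    w_self: float, w_neigh: float) -> List[float]:
    """x'_i = w_self * x_i + w_neigh * mean_{j -> i} x_j (scalar features, no bias)."""
    n = len(x)
    total = [0.0] * n
    count = [0] * n
    for s, t in zip(src, dst):  # message from source s to target t
        total[t] += x[s]
        count[t] += 1
    # Nodes without incoming edges get a neighbour mean of 0.
    return [w_self * x[i] + w_neigh * (total[i] / count[i] if count[i] else 0.0)
            for i in range(n)]


# Undirected path graph 0 - 1 - 2, both edge directions stored.
out = sage_mean_layer([1.0, 2.0, 4.0], [0, 1, 1, 2], [1, 0, 2, 1],
                      w_self=1.0, w_neigh=0.5)
print(out)  # [2.0, 3.25, 5.0]
```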

conv1

First convolution layer (shared for both mu/logstd branches).

Type:

torch_geometric.nn.MessagePassing

conv_mu

Convolution layer producing mean embeddings.

Type:

torch_geometric.nn.MessagePassing

conv_logstd

Convolution layer producing log-standard-deviation embeddings.

Type:

torch_geometric.nn.MessagePassing

Examples

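>>> import torch
>>> from graphtoolbox.models.gnn import VariationalGNNEncoder
>>> x = torch.randn(100, 32)  # 100 nodes, 32 features each
>>> edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # COO, shape [2, E]
>>> enc = VariationalGNNEncoder(in_channels=32, out_channels=16, conv='sage')
>>> mu, logstd = enc(x, edge_index)
>>> mu.shape, logstd.shape
(torch.Size([100, 16]), torch.Size([100, 16]))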
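A latent sample is usually drawn from (mu, logstd) with the reparameterization trick, z = mu + eps * exp(logstd) with eps ~ N(0, 1), so that gradients flow through mu and logstd. Because the second output is a log standard deviation (as the name logstd indicates), the noise scale is exp(logstd); for a log-variance it would be exp(0.5 * logvar). The sketch below uses plain lists and a hypothetical helper reparameterize; it is not part of graphtoolbox.

```python
import math
import random
from typing import List


def reparameterize(mu: List[float], logstd: List[float], rng: random.Random) -> List[float]:
    """Draw z = mu + eps * exp(logstd) with eps ~ N(0, 1), one eps per entry."""
    return [m + rng.gauss(0.0, 1.0) * math.exp(s) for m, s in zip(mu, logstd)]


rng = random.Random(0)  # seeded for reproducibility
z = reparameterize(mu=[0.0, 5.0], logstd=[-20.0, 0.0], rng=rng)
# z[0] is essentially 0.0 (std = exp(-20)); z[1] is one draw from N(5, 1).
```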
forward(x, edge_index)

Compute latent mean and log-standard-deviation representations.

Parameters:
  • x (torch.Tensor) – Node feature matrix of shape [N, F].

  • edge_index (torch.LongTensor) – Graph connectivity in COO format, shape [2, E].

Returns:

Mean and log-standard-deviation tensors (mu, logstd), each of shape [N, out_channels].

Return type:

tuple of torch.Tensor
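The (mu, logstd) pair typically feeds a KL regularizer towards a standard normal prior. To our understanding, PyG's VGAE.kl_loss computes -0.5 * mean over nodes of sum over latent dims of (1 + 2 * logstd - mu^2 - exp(logstd)^2); the pure-Python sketch below mirrors that formula with a hypothetical helper kl_to_standard_normal and nested lists in place of [N, out_channels] tensors.

```python
import math
from typing import List


def kl_to_standard_normal(mu: List[List[float]], logstd: List[List[float]]) -> float:
    """Mean over nodes of KL(N(mu, exp(logstd)^2) || N(0, 1)), summed over latent dims."""
    per_node = [
        -0.5 * sum(1 + 2 * s - m * m - math.exp(s) ** 2 for m, s in zip(m_row, s_row))
        for m_row, s_row in zip(mu, logstd)
    ]
    return sum(per_node) / len(per_node)


# Two nodes, one latent dimension each: KL is 0 for N(0, 1) and 0.5 for N(1, 1).
print(kl_to_standard_normal([[0.0], [1.0]], [[0.0], [0.0]]))  # 0.25
```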