Source code for easygraph.nn.convs.common

import torch
import torch.nn as nn


class MultiHeadWrapper(nn.Module):
    r"""A wrapper to apply multiple heads to a given layer.

    ``num_heads`` independent copies of the layer are built, each one is run on
    the same inputs, and the per-head outputs are merged by ``readout``.

    Args:
        ``num_heads`` (``int``): The number of heads.
        ``readout`` (``str``): The readout method. Can be ``"mean"``, ``"max"``, ``"sum"``, or ``"concat"``.
        ``layer`` (``nn.Module``): The layer to apply multiple heads. This is the layer *class* (or any
            callable returning an ``nn.Module``), not an instance: it is called once per head as
            ``layer(**kwargs)``.
        ``**kwargs``: The keyword arguments for the layer.
    """

    # Readout names understood by ``forward``.
    _READOUTS = ("mean", "max", "sum", "concat")

    def __init__(
        self, num_heads: int, readout: str, layer: nn.Module, **kwargs
    ) -> None:
        super().__init__()
        # Each head gets its own freshly constructed layer, hence its own parameters.
        self.layers = nn.ModuleList([layer(**kwargs) for _ in range(num_heads)])
        self.num_heads = num_heads
        self.readout = readout

    def forward(self, **kwargs) -> torch.Tensor:
        r"""The forward function.

        .. note::
            You must explicitly pass the keyword arguments to the layer. For example, if the layer is ``GATConv``, you must pass ``X=X`` and ``g=g``.

        Returns:
            ``torch.Tensor``: For ``"concat"``, the head outputs concatenated along the last
            dimension. Otherwise, the element-wise mean / max / sum over the heads.

        Raises:
            ``ValueError``: If ``self.readout`` is not one of the supported readout methods.
        """
        # Validate before running any head: an unknown readout should neither
        # cost a full forward pass nor be masked by an unrelated stacking error.
        if self.readout not in self._READOUTS:
            raise ValueError("Unknown readout type")

        head_outputs = [layer(**kwargs) for layer in self.layers]
        if self.readout == "concat":
            return torch.cat(head_outputs, dim=-1)

        # Stack the heads on a new leading dimension and reduce over it.
        outs = torch.stack(head_outputs)
        if self.readout == "mean":
            return outs.mean(dim=0)
        if self.readout == "max":
            # ``max(dim=...)`` returns (values, indices); only the values are needed.
            return outs.max(dim=0)[0]
        return outs.sum(dim=0)