# =============================================================================
# IMPORTS
# =============================================================================
import abc
import torch
# =============================================================================
# BASE CLASSES
# =============================================================================
class BaseReadout(abc.ABC, torch.nn.Module):
    """Base class for readout function.

    Combines ``abc.ABC`` with ``torch.nn.Module``. ``forward`` is declared
    abstract, so ``BaseReadout`` itself cannot be instantiated. Concrete
    readouts must subclass it and override ``forward``.
    """

    def __init__(self):
        # Zero-argument super() walks the MRO (ABC -> Module); ABC defines no
        # __init__, so this initialises torch.nn.Module's internal state.
        super().__init__()

    @abc.abstractmethod
    def forward(self, g, x=None, *args, **kwargs):
        """Compute the readout. Must be overridden by subclasses.

        Parameters
        ----------
        g
            Input consumed by the readout; presumably a graph object —
            the concrete type is defined by subclasses (TODO confirm).
        x : optional
            Optional additional input; defaults to ``None``.
        *args, **kwargs
            Extra arguments forwarded by subclasses as they see fit.

        Raises
        ------
        NotImplementedError
            Always, if reached via ``super().forward(...)``.
        """
        raise NotImplementedError

    def _forward(self, g, x, *args, **kwargs):
        """Internal hook that subclasses may override.

        Not abstract, so overriding it is optional. The base implementation
        always raises ``NotImplementedError``.
        """
        raise NotImplementedError