File size: 815 Bytes
1a1dde9
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Net(torch.nn.Module):
    """Two-layer relational GCN that outputs per-row log-probabilities.

    The network stacks two ``RGCNConv`` layers: the first maps the input
    representation to ``hidden_dim`` channels, the second maps
    ``hidden_dim`` to ``num_classes``. Both layers share the same
    ``num_relations`` and ``num_bases``.
    """

    def __init__(self, num_relations, num_classes, num_nodes=None, input_dim=None, hidden_dim=16, num_bases=30):
        """Build the two convolution layers.

        Args:
            num_relations: Number of relation types, used by both layers.
            num_classes: Output size of the second layer.
            num_nodes: Used as the first layer's input size when
                ``input_dim`` is not given.
            input_dim: Input feature dimensionality; takes precedence over
                ``num_nodes`` when both are provided.
            hidden_dim: Output size of the first layer / input size of the
                second layer.
            num_bases: Number of bases forwarded to both ``RGCNConv`` layers.

        Raises:
            ValueError: If neither ``num_nodes`` nor ``input_dim`` is given.
        """
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``, which strips assert statements.
        if num_nodes is None and input_dim is None:
            raise ValueError("Please provide input feature dimensionality or number of nodes")
        super().__init__()

        in_channels = num_nodes if input_dim is None else input_dim
        self.conv1 = RGCNConv(in_channels, hidden_dim, num_relations,
                              num_bases=num_bases)
        # Use the constructor argument here; the previous code read a global
        # ``dataset.num_relations``, which could be undefined or disagree
        # with the value given to ``conv1``.
        self.conv2 = RGCNConv(hidden_dim, num_classes, num_relations,
                              num_bases=num_bases)

    def forward(self, x, edge_index, edge_type):
        """Run both layers and return log-softmax scores along ``dim=1``.

        Args:
            x: Input passed to the first layer. Per the original author's
                note, ``None`` makes the layer use an embedding based on
                ``num_nodes`` (behaviour of ``RGCNConv`` -- TODO confirm
                against the installed version).
            edge_index: Edge index forwarded unchanged to both layers.
            edge_type: Edge types forwarded unchanged to both layers.

        Returns:
            ``log_softmax`` over dimension 1 of the second layer's output.
        """
        # if x is None, uses an embedding based on num_nodes
        x = F.relu(self.conv1(x, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return F.log_softmax(x, dim=1)