SubgraphLoader for heterogeneous graph #17

Chen-Cai-OSU opened this issue May 25, 2022 · 1 comment

Comments


Chen-Cai-OSU commented May 25, 2022

Hi,
I am trying to apply pyg_autoscale to a heterogeneous graph and need to modify the compute_subgraph method of the SubgraphLoader class. Could you elaborate on what offset and count are, and what relabel_fn does?
My current understanding is that compute_subgraph basically takes the subgraph spanned by n_id. Is this understanding accurate?
Many thanks!

    def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
        batch_ids, n_ids = zip(*batches)
        n_id = torch.cat(n_ids, dim=0)
        batch_id = torch.tensor(batch_ids)

        # We collect the in-mini-batch size (`batch_size`), the offset of each
        # partition in the mini-batch (`offset`), and the number of nodes in
        # each partition (`count`)
        batch_size = n_id.numel()
        offset = self.ptr[batch_id]
        count = self.ptr[batch_id.add_(1)].sub_(offset)

        rowptr, col, value = self.data.adj_t.csr()
        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
                                              self.bipartite)

        adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
                             sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
                             is_sorted=True)

        data = self.data.__class__(adj_t=adj_t)
        for k, v in self.data:
            if isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
                data[k] = v.index_select(0, n_id)

        return SubData(data, batch_size, n_id, offset, count)
rusty1s (Owner) commented May 25, 2022

Yes, that is correct. Importantly, batches denotes a list of chunks of contiguous node indices that we want to group into a single mini-batch/subgraph, for example [[0, 1, 2], [5, 6, 7], [10, 11, 12, 13]], for which offset would be [0, 5, 10] and count would be [3, 3, 4]. relabel_fn then computes the induced subgraph of these chunks of nodes and relabels their node indices to [0, ..., 9].
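The example above can be traced with a minimal pure-Python sketch. It is only an illustration under stated assumptions: the partition pointer `ptr`, the toy edge list `edges`, and the helper `induced_and_relabel` are hypothetical and not part of pyg_autoscale, and the real relabel_fn operates on CSR tensors (rowptr, col, value) rather than edge lists, so its exact behavior may differ from this simplification.

```python
# Hypothetical partition pointer: partition b covers nodes ptr[b] .. ptr[b + 1] - 1.
ptr = [0, 3, 5, 8, 10, 14]
batch_ids = [0, 2, 4]  # the partitions grouped into one mini-batch

# Start index and size of each selected partition, as in compute_subgraph.
offset = [ptr[b] for b in batch_ids]
count = [ptr[b + 1] - ptr[b] for b in batch_ids]

# Global node ids of the mini-batch: the chunks concatenated in order.
n_id = [n for o, c in zip(offset, count) for n in range(o, o + c)]


def induced_and_relabel(edges, n_id):
    """Keep edges with both endpoints in n_id and map global ids to 0..len(n_id)-1.

    Simplified stand-in for what the reply describes relabel_fn as doing.
    """
    local = {g: i for i, g in enumerate(n_id)}
    return [(local[u], local[v]) for u, v in edges if u in local and v in local]


# Toy edge list in global node ids (an assumption for illustration only).
edges = [(0, 1), (1, 5), (2, 3), (6, 10), (12, 13), (4, 7), (13, 0)]
sub_edges = induced_and_relabel(edges, n_id)

print(offset)     # [0, 5, 10]
print(count)      # [3, 3, 4]
print(n_id)       # [0, 1, 2, 5, 6, 7, 10, 11, 12, 13]
print(sub_edges)  # [(0, 1), (1, 3), (4, 6), (8, 9), (9, 0)]
```

Edges that touch nodes 3 or 4 are dropped because those nodes belong to a partition that is not in the mini-batch, and the ten remaining nodes receive the local indices 0 to 9.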
