From 1d766d9047172ec2db4b6f221e697623f870a6ad Mon Sep 17 00:00:00 2001 From: Kyle Montemayor Date: Mon, 22 Jun 2026 18:41:59 +0000 Subject: [PATCH 1/2] Remove positive/negative-edge zero-fanout workaround in k-hop sampler DistNeighborSampler now skips label (positive/negative supervision) edges directly in its k-hop loop, mirroring DistPPRNeighborSampler. As a result patch_fanout_for_sampling no longer injects zero fanout for label edges; it excludes them from num_neighbors instead, resolving the long-standing TODO at neighborloader.py. Behavior-preserving (zero fanout and skipping produce identical subgraphs) and covers both colocated and graph-store modes. - dist_neighbor_sampler.py: skip is_label_edge_type edges before indexing num_neighbors - neighborloader.py: patch_fanout_for_sampling excludes label edges; add empty-edges guard - base_dist_loader.py: fix stale create_sampling_config docstring - neighborloader_test.py: update expectations (labels excluded, not zeroed) Co-Authored-By: Claude Opus 4.8 (1M context) --- gigl/distributed/base_dist_loader.py | 5 +- gigl/distributed/dist_neighbor_sampler.py | 8 +++ gigl/distributed/utils/neighborloader.py | 66 +++++++++++-------- .../distributed/utils/neighborloader_test.py | 11 ++-- 4 files changed, 55 insertions(+), 35 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 09b4dcac0..74dd631b3 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -388,8 +388,9 @@ def create_sampling_config( ) -> SamplingConfig: """Creates a SamplingConfig with patched fanout. - Patches ``num_neighbors`` to zero-out label edge types, then creates - the SamplingConfig used by both colocated and graph store modes. + Excludes label edge types from ``num_neighbors`` (the samplers skip them + during traversal), then creates the SamplingConfig used by both colocated + and graph store modes. Args: num_neighbors: Fanout per hop. diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py index a91737b5b..6b62c951f 100644 --- a/gigl/distributed/dist_neighbor_sampler.py +++ b/gigl/distributed/dist_neighbor_sampler.py @@ -12,6 +12,7 @@ from graphlearn_torch.utils import count_dict, merge_dict, reverse_edge_type from gigl.distributed.base_sampler import BaseDistNeighborSampler +from gigl.types.graph import is_label_edge_type class DistNeighborSampler(BaseDistNeighborSampler): @@ -72,6 +73,13 @@ async def _sample_from_nodes( nbr_dict: dict[EdgeType, list[torch.Tensor]] = {} edge_dict: dict[EdgeType, torch.Tensor] = {} for etype in self.edge_types: + if is_label_edge_type(etype): + # Label (positive/negative supervision) edges are injected + # into the graph for ABLP but must never be traversed during + # sampling (doing so would leak ground-truth targets). Skipping + # them here means they need no num_neighbors entry. Mirrors + # DistPPRNeighborSampler (dist_ppr_sampler.py). + continue req_num = self.num_neighbors[etype][i] if self.edge_dir == "in": srcs = src_dict.get(etype[-1], None) diff --git a/gigl/distributed/utils/neighborloader.py b/gigl/distributed/utils/neighborloader.py index 00d08c79d..2d4b61102 100644 --- a/gigl/distributed/utils/neighborloader.py +++ b/gigl/distributed/utils/neighborloader.py @@ -53,23 +53,24 @@ def patch_fanout_for_sampling( num_neighbors: Union[list[int], dict[EdgeType, list[int]]], ) -> Union[list[int], dict[EdgeType, list[int]]]: """ - Sets up an approprirate fanout for sampling. + Normalizes the user-provided fanout into the per-edge-type form the samplers expect. - Does the following: - - For all label edge types, sets the fanout to be zero. - - For all other edge types, if the fanout is not specified, uses the original fanout. + For heterogeneous datasets, broadcasts a single fanout list to every message-passing + edge type, or validates a per-edge-type dict. - Note that if fanout is provided as a dict, the keys (edges) in the fanout must be in `edge_types`. + Label edge types (positive/negative supervision edges injected by ABLP) are excluded + from the returned fanout: the samplers skip them during traversal (see + ``DistNeighborSampler._sample_from_nodes`` and ``DistPPRNeighborSampler``), so they + need no fanout entry. Any label edge type a caller supplies by accident is dropped. - We add this because the existing sampling logic (below) makes strict assumptions that we need to conform to. - https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_neighbor_sampler.py#L317-L318 + For homogeneous datasets (``edge_types`` is None) the fanout list is returned unchanged. Args: edge_types (Optional[list[EdgeType]]): List of all edge types in the graph, is None for homogeneous datasets - num_neighbors (dict[EdgeType, list[int]]): Specified fanout by the user + num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Specified fanout by the user Returns: - Union[list[int], dict[EdgeType, list[int]]]: Modified fanout that is appropriate for sampling. Is a list[int] - if the dataset is homogeneous, otherwise is dict[EdgeType, list[int]] + Union[list[int], dict[EdgeType, list[int]]]: Normalized fanout. A list[int] for + homogeneous datasets, otherwise a dict[EdgeType, list[int]] over message-passing edges. """ if edge_types is None: if isinstance(num_neighbors, abc.Mapping): @@ -79,32 +80,43 @@ def patch_fanout_for_sampling( if not all(hop >= 0 for hop in num_neighbors): raise ValueError(f"Hops provided must be non-negative, got {num_neighbors}") return num_neighbors + + message_passing_edge_types = [ + edge_type for edge_type in edge_types if not is_label_edge_type(edge_type) + ] + # A graph with only label edges and no message-passing edges cannot be sampled; + # fail explicitly rather than letting the validation tail raise StopIteration on an + # empty dict. + if not message_passing_edge_types: + raise ValueError( + f"No message-passing edge types found in dataset edge types {edge_types}; " + "cannot construct a fanout." + ) + if isinstance(num_neighbors, list): original_fanout = num_neighbors - should_broadcast_fanout = True - num_neighbors = {} + num_neighbors = { + edge_type: original_fanout for edge_type in message_passing_edge_types + } else: extra_edge_types = set(num_neighbors.keys()) - set(edge_types) if extra_edge_types: raise ValueError( f"Found extra edge types {extra_edge_types} in fanout which is not in dataset edge types {edge_types}." ) - original_fanout = next(iter(num_neighbors.values())) - should_broadcast_fanout = False - num_neighbors = deepcopy(num_neighbors) - - num_hop = len(original_fanout) - zero_samples = [0 for _ in range(num_hop)] - for edge_type in edge_types: - # TODO(kmonte): stop setting fanout for positive/negative edges once GLT sampling correctly ignores those edges during fanout. - if is_label_edge_type(edge_type): - num_neighbors[edge_type] = zero_samples - elif should_broadcast_fanout and edge_type not in num_neighbors: - num_neighbors[edge_type] = original_fanout - elif not should_broadcast_fanout and edge_type not in num_neighbors: + # Label edges are injected internally and are never sampled; drop any the + # caller supplied so the fanout dict only contains message-passing edges. + num_neighbors = { + edge_type: deepcopy(fanout) + for edge_type, fanout in num_neighbors.items() + if not is_label_edge_type(edge_type) + } + missing_edge_types = set(message_passing_edge_types) - set(num_neighbors.keys()) + if missing_edge_types: raise ValueError( - f"Found non-labeled edge type in dataset {edge_type} which is not in the provided fanout {num_neighbors.keys()}. \ - If fanout is provided as a dict, all edges must be present." + f"Found non-labeled edge type(s) {missing_edge_types} in the dataset which are not in " + f"the provided fanout {set(num_neighbors.keys())}. If fanout is provided as a dict, " + "all message-passing edges must be present." ) hops = len(next(iter(num_neighbors.values()))) diff --git a/tests/unit/distributed/utils/neighborloader_test.py b/tests/unit/distributed/utils/neighborloader_test.py index a9b9c023f..a2eb07bc2 100644 --- a/tests/unit/distributed/utils/neighborloader_test.py +++ b/tests/unit/distributed/utils/neighborloader_test.py @@ -74,10 +74,10 @@ def test_shard_nodes_by_process( _U2I_EDGE_TYPE: [2, 7], _I2U_EDGE_TYPE: [3, 4], }, + # Label edge types are excluded from the fanout (the samplers skip them). expected_num_neighbors={ _U2I_EDGE_TYPE: [2, 7], _I2U_EDGE_TYPE: [3, 4], - _LABELED_EDGE_TYPE: [0, 0], }, ), param( @@ -86,14 +86,14 @@ def test_shard_nodes_by_process( num_neighbors={ _U2I_EDGE_TYPE: [2, 7], _I2U_EDGE_TYPE: [3, 4], - # If labeled edge type fanout is provided by the user, we assume it was by accident, since users shouldn't be aware of this injected edge type, - # and still set the fanout of it to be 0. + # A label edge type supplied by the caller is assumed accidental + # (callers should not be aware of this injected edge type) and is + # dropped from the fanout rather than zeroed. _LABELED_EDGE_TYPE: [2, 2], }, expected_num_neighbors={ _U2I_EDGE_TYPE: [2, 7], _I2U_EDGE_TYPE: [3, 4], - _LABELED_EDGE_TYPE: [0, 0], }, ), param( @@ -103,7 +103,6 @@ def test_shard_nodes_by_process( expected_num_neighbors={ _U2I_EDGE_TYPE: [1, 3], _I2U_EDGE_TYPE: [1, 3], - _LABELED_EDGE_TYPE: [0, 0], }, ), param( @@ -114,7 +113,7 @@ def test_shard_nodes_by_process( ), ] ) - def test_patch_neighbors_with_zero_fanout( + def test_patch_fanout_for_sampling( self, _, edge_types: Optional[list[EdgeType]], From ae8b4623fcae588e7f4823a1213d473a4ff731a1 Mon Sep 17 00:00:00 2001 From: Kyle Montemayor Date: Tue, 23 Jun 2026 17:08:12 +0000 Subject: [PATCH 2/2] Throw on label edges and missing message-passing edge types in fanout MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit patch_fanout_for_sampling now raises ValueError if a caller passes any label (positive/negative supervision) edge type in num_neighbors — previously these were silently dropped. The existing guard that raises when there are no message-passing edge types to fan out around is retained. Fail fast on invalid fanout input rather than silently correcting it. - neighborloader.py: dict branch rejects caller-provided label edge types - base_dist_loader.py: update create_sampling_config docstring - neighborloader_test.py: label-edge and no-message-passing-edges cases now assert raises - dist_ablp_neighborloader_test.py: build num_neighbors over message-passing edges only Co-Authored-By: Claude Opus 4.8 (1M context) --- gigl/distributed/base_dist_loader.py | 7 ++-- gigl/distributed/utils/neighborloader.py | 32 +++++++++++++------ .../dist_ablp_neighborloader_test.py | 10 +++++- .../distributed/utils/neighborloader_test.py | 30 ++++++++--------- 4 files changed, 49 insertions(+), 30 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 74dd631b3..02d7c69e1 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -388,9 +388,10 @@ def create_sampling_config( ) -> SamplingConfig: """Creates a SamplingConfig with patched fanout. - Excludes label edge types from ``num_neighbors`` (the samplers skip them - during traversal), then creates the SamplingConfig used by both colocated - and graph store modes. + ``num_neighbors`` must cover the message-passing edge types; label edge types + are injected internally, are never sampled, and must not be specified (passing + one raises ``ValueError``). Then creates the SamplingConfig used by both + colocated and graph store modes. Args: num_neighbors: Fanout per hop. diff --git a/gigl/distributed/utils/neighborloader.py b/gigl/distributed/utils/neighborloader.py index 2d4b61102..876582208 100644 --- a/gigl/distributed/utils/neighborloader.py +++ b/gigl/distributed/utils/neighborloader.py @@ -58,10 +58,11 @@ def patch_fanout_for_sampling( For heterogeneous datasets, broadcasts a single fanout list to every message-passing edge type, or validates a per-edge-type dict. - Label edge types (positive/negative supervision edges injected by ABLP) are excluded - from the returned fanout: the samplers skip them during traversal (see + Label edge types (positive/negative supervision edges injected by ABLP) are never + sampled -- the samplers skip them during traversal (see ``DistNeighborSampler._sample_from_nodes`` and ``DistPPRNeighborSampler``), so they - need no fanout entry. Any label edge type a caller supplies by accident is dropped. + take no fanout. Callers must not specify them: any label edge type in ``num_neighbors`` + raises ``ValueError``. For homogeneous datasets (``edge_types`` is None) the fanout list is returned unchanged. @@ -71,6 +72,10 @@ def patch_fanout_for_sampling( Returns: Union[list[int], dict[EdgeType, list[int]]]: Normalized fanout. A list[int] for homogeneous datasets, otherwise a dict[EdgeType, list[int]] over message-passing edges. + Raises: + ValueError: If a label edge type is supplied in ``num_neighbors``, if the dataset has + no message-passing edge types to fan out around, or on malformed fanout (extra or + missing edge types, inconsistent hop counts, or negative fanout). """ if edge_types is None: if isinstance(num_neighbors, abc.Mapping): @@ -99,18 +104,25 @@ def patch_fanout_for_sampling( edge_type: original_fanout for edge_type in message_passing_edge_types } else: + # Label (positive/negative supervision) edges are injected internally and are + # never sampled, so callers must not specify them in the fanout. Reject them + # explicitly: they are part of the dataset edge types, so the extra-edge-types + # check below would not catch them. + provided_label_edge_types = { + edge_type for edge_type in num_neighbors if is_label_edge_type(edge_type) + } + if provided_label_edge_types: + raise ValueError( + f"Label edge types {provided_label_edge_types} were provided in num_neighbors. " + "Label (positive/negative supervision) edges are injected internally and are never " + "sampled, so they must not be specified in the fanout." + ) extra_edge_types = set(num_neighbors.keys()) - set(edge_types) if extra_edge_types: raise ValueError( f"Found extra edge types {extra_edge_types} in fanout which is not in dataset edge types {edge_types}." ) - # Label edges are injected internally and are never sampled; drop any the - # caller supplied so the fanout dict only contains message-passing edges. - num_neighbors = { - edge_type: deepcopy(fanout) - for edge_type, fanout in num_neighbors.items() - if not is_label_edge_type(edge_type) - } + num_neighbors = deepcopy(num_neighbors) missing_edge_types = set(message_passing_edge_types) - set(num_neighbors.keys()) if missing_edge_types: raise ValueError( diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index 31d3d1cbc..2d8e4c626 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -30,6 +30,7 @@ DEFAULT_HOMOGENEOUS_EDGE_TYPE, GraphPartitionData, PartitionOutput, + is_label_edge_type, message_passing_to_negative_label, message_passing_to_positive_label, to_heterogeneous_node, @@ -209,7 +210,14 @@ def _run_dblp_supervised( assert isinstance(dataset.train_node_ids, dict) assert isinstance(dataset.graph, dict) fanout = [2, 2] - num_neighbors = {edge_type: fanout for edge_type in dataset.graph.keys()} + # Label edge types must not be specified in the fanout (they are injected + # internally and never sampled), so build num_neighbors over message-passing + # edges only. + num_neighbors = { + edge_type: fanout + for edge_type in dataset.graph.keys() + if not is_label_edge_type(edge_type) + } create_test_process_group() loader = DistABLPLoader( dataset=dataset, diff --git a/tests/unit/distributed/utils/neighborloader_test.py b/tests/unit/distributed/utils/neighborloader_test.py index a2eb07bc2..00c49fb55 100644 --- a/tests/unit/distributed/utils/neighborloader_test.py +++ b/tests/unit/distributed/utils/neighborloader_test.py @@ -80,22 +80,6 @@ def test_shard_nodes_by_process( _I2U_EDGE_TYPE: [3, 4], }, ), - param( - "Test patch_fanout_for_sampling on num_neighbors dict with labeled edge type in dataset and fanout", - edge_types=[_U2I_EDGE_TYPE, _I2U_EDGE_TYPE, _LABELED_EDGE_TYPE], - num_neighbors={ - _U2I_EDGE_TYPE: [2, 7], - _I2U_EDGE_TYPE: [3, 4], - # A label edge type supplied by the caller is assumed accidental - # (callers should not be aware of this injected edge type) and is - # dropped from the fanout rather than zeroed. - _LABELED_EDGE_TYPE: [2, 2], - }, - expected_num_neighbors={ - _U2I_EDGE_TYPE: [2, 7], - _I2U_EDGE_TYPE: [3, 4], - }, - ), param( "Test patch_fanout_for_sampling on num_neighbors list", edge_types=[_U2I_EDGE_TYPE, _I2U_EDGE_TYPE, _LABELED_EDGE_TYPE], @@ -164,6 +148,20 @@ def test_patch_fanout_for_sampling( _I2U_EDGE_TYPE: [3, 4], }, ), + param( + "Test that providing a label edge type in num_neighbors raises", + edge_types=[_U2I_EDGE_TYPE, _I2U_EDGE_TYPE, _LABELED_EDGE_TYPE], + num_neighbors={ + _U2I_EDGE_TYPE: [2, 7], + _I2U_EDGE_TYPE: [3, 4], + _LABELED_EDGE_TYPE: [2, 2], + }, + ), + param( + "Test that a dataset with no message-passing edge types raises", + edge_types=[_LABELED_EDGE_TYPE], + num_neighbors=[1, 3], + ), ] ) def test_patch_neighbors_failure(