diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 09b4dcac0..02d7c69e1 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -388,8 +388,10 @@ def create_sampling_config( ) -> SamplingConfig: """Creates a SamplingConfig with patched fanout. - Patches ``num_neighbors`` to zero-out label edge types, then creates - the SamplingConfig used by both colocated and graph store modes. + ``num_neighbors`` must cover the message-passing edge types; label edge types + are injected internally, are never sampled, and must not be specified (passing + one raises ``ValueError``). The function then creates the SamplingConfig + used by both colocated and graph store modes. Args: num_neighbors: Fanout per hop. diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py index a91737b5b..6b62c951f 100644 --- a/gigl/distributed/dist_neighbor_sampler.py +++ b/gigl/distributed/dist_neighbor_sampler.py @@ -12,6 +12,7 @@ from graphlearn_torch.utils import count_dict, merge_dict, reverse_edge_type from gigl.distributed.base_sampler import BaseDistNeighborSampler +from gigl.types.graph import is_label_edge_type class DistNeighborSampler(BaseDistNeighborSampler): @@ -72,6 +73,13 @@ async def _sample_from_nodes( nbr_dict: dict[EdgeType, list[torch.Tensor]] = {} edge_dict: dict[EdgeType, torch.Tensor] = {} for etype in self.edge_types: + if is_label_edge_type(etype): + # Label (positive/negative supervision) edges are injected + # into the graph for ABLP but must never be traversed during + # sampling (doing so would leak ground-truth targets). Skipping + # them here means they need no num_neighbors entry. Mirrors + # DistPPRNeighborSampler (dist_ppr_sampler.py).
+ continue req_num = self.num_neighbors[etype][i] if self.edge_dir == "in": srcs = src_dict.get(etype[-1], None) diff --git a/gigl/distributed/utils/neighborloader.py b/gigl/distributed/utils/neighborloader.py index 00d08c79d..876582208 100644 --- a/gigl/distributed/utils/neighborloader.py +++ b/gigl/distributed/utils/neighborloader.py @@ -53,23 +53,29 @@ def patch_fanout_for_sampling( num_neighbors: Union[list[int], dict[EdgeType, list[int]]], ) -> Union[list[int], dict[EdgeType, list[int]]]: """ - Sets up an approprirate fanout for sampling. + Normalizes the user-provided fanout into the per-edge-type form the samplers expect. - Does the following: - - For all label edge types, sets the fanout to be zero. - - For all other edge types, if the fanout is not specified, uses the original fanout. + For heterogeneous datasets, broadcasts a single fanout list to every message-passing + edge type, or validates a per-edge-type dict. - Note that if fanout is provided as a dict, the keys (edges) in the fanout must be in `edge_types`. + Label edge types (positive/negative supervision edges injected by ABLP) are never + sampled -- the samplers skip them during traversal (see + ``DistNeighborSampler._sample_from_nodes`` and ``DistPPRNeighborSampler``), so they + take no fanout. Callers must not specify them: any label edge type in ``num_neighbors`` + raises ``ValueError``. - We add this because the existing sampling logic (below) makes strict assumptions that we need to conform to. - https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_neighbor_sampler.py#L317-L318 + For homogeneous datasets (``edge_types`` is None) the fanout list is returned unchanged. 
Args: edge_types (Optional[list[EdgeType]]): List of all edge types in the graph, is None for homogeneous datasets - num_neighbors (dict[EdgeType, list[int]]): Specified fanout by the user + num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Specified fanout by the user Returns: - Union[list[int], dict[EdgeType, list[int]]]: Modified fanout that is appropriate for sampling. Is a list[int] - if the dataset is homogeneous, otherwise is dict[EdgeType, list[int]] + Union[list[int], dict[EdgeType, list[int]]]: Normalized fanout. A list[int] for + homogeneous datasets, otherwise a dict[EdgeType, list[int]] over message-passing edges. + Raises: + ValueError: If a label edge type is supplied in ``num_neighbors``, if the dataset has + no message-passing edge types to fan out around, or on malformed fanout (extra or + missing edge types, inconsistent hop counts, or negative fanout). """ if edge_types is None: if isinstance(num_neighbors, abc.Mapping): @@ -79,32 +85,50 @@ def patch_fanout_for_sampling( if not all(hop >= 0 for hop in num_neighbors): raise ValueError(f"Hops provided must be non-negative, got {num_neighbors}") return num_neighbors + + message_passing_edge_types = [ + edge_type for edge_type in edge_types if not is_label_edge_type(edge_type) + ] + # A graph with only label edges and no message-passing edges cannot be sampled; + # fail explicitly rather than letting the validation tail raise StopIteration on an + # empty dict. + if not message_passing_edge_types: + raise ValueError( + f"No message-passing edge types found in dataset edge types {edge_types}; " + "cannot construct a fanout." 
+ ) + if isinstance(num_neighbors, list): original_fanout = num_neighbors - should_broadcast_fanout = True - num_neighbors = {} + num_neighbors = { + edge_type: original_fanout for edge_type in message_passing_edge_types + } else: + # Label (positive/negative supervision) edges are injected internally and are + # never sampled, so callers must not specify them in the fanout. Reject them + # explicitly: they are part of the dataset edge types, so the extra-edge-types + # check below would not catch them. + provided_label_edge_types = { + edge_type for edge_type in num_neighbors if is_label_edge_type(edge_type) + } + if provided_label_edge_types: + raise ValueError( + f"Label edge types {provided_label_edge_types} were provided in num_neighbors. " + "Label (positive/negative supervision) edges are injected internally and are never " + "sampled, so they must not be specified in the fanout." + ) extra_edge_types = set(num_neighbors.keys()) - set(edge_types) if extra_edge_types: raise ValueError( f"Found extra edge types {extra_edge_types} in fanout which is not in dataset edge types {edge_types}." ) - original_fanout = next(iter(num_neighbors.values())) - should_broadcast_fanout = False num_neighbors = deepcopy(num_neighbors) - - num_hop = len(original_fanout) - zero_samples = [0 for _ in range(num_hop)] - for edge_type in edge_types: - # TODO(kmonte): stop setting fanout for positive/negative edges once GLT sampling correctly ignores those edges during fanout. - if is_label_edge_type(edge_type): - num_neighbors[edge_type] = zero_samples - elif should_broadcast_fanout and edge_type not in num_neighbors: - num_neighbors[edge_type] = original_fanout - elif not should_broadcast_fanout and edge_type not in num_neighbors: + missing_edge_types = set(message_passing_edge_types) - set(num_neighbors.keys()) + if missing_edge_types: raise ValueError( - f"Found non-labeled edge type in dataset {edge_type} which is not in the provided fanout {num_neighbors.keys()}. 
\ - If fanout is provided as a dict, all edges must be present." + f"Found non-labeled edge type(s) {missing_edge_types} in the dataset which are not in " + f"the provided fanout {set(num_neighbors.keys())}. If fanout is provided as a dict, " + "all message-passing edges must be present." ) hops = len(next(iter(num_neighbors.values()))) diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index 31d3d1cbc..2d8e4c626 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -30,6 +30,7 @@ DEFAULT_HOMOGENEOUS_EDGE_TYPE, GraphPartitionData, PartitionOutput, + is_label_edge_type, message_passing_to_negative_label, message_passing_to_positive_label, to_heterogeneous_node, @@ -209,7 +210,14 @@ def _run_dblp_supervised( assert isinstance(dataset.train_node_ids, dict) assert isinstance(dataset.graph, dict) fanout = [2, 2] - num_neighbors = {edge_type: fanout for edge_type in dataset.graph.keys()} + # Label edge types must not be specified in the fanout (they are injected + # internally and never sampled), so build num_neighbors over message-passing + # edges only. + num_neighbors = { + edge_type: fanout + for edge_type in dataset.graph.keys() + if not is_label_edge_type(edge_type) + } create_test_process_group() loader = DistABLPLoader( dataset=dataset, diff --git a/tests/unit/distributed/utils/neighborloader_test.py b/tests/unit/distributed/utils/neighborloader_test.py index a9b9c023f..00c49fb55 100644 --- a/tests/unit/distributed/utils/neighborloader_test.py +++ b/tests/unit/distributed/utils/neighborloader_test.py @@ -74,26 +74,10 @@ def test_shard_nodes_by_process( _U2I_EDGE_TYPE: [2, 7], _I2U_EDGE_TYPE: [3, 4], }, + # Label edge types are excluded from the fanout (the samplers skip them). 
expected_num_neighbors={ _U2I_EDGE_TYPE: [2, 7], _I2U_EDGE_TYPE: [3, 4], - _LABELED_EDGE_TYPE: [0, 0], - }, - ), - param( - "Test patch_fanout_for_sampling on num_neighbors dict with labeled edge type in dataset and fanout", - edge_types=[_U2I_EDGE_TYPE, _I2U_EDGE_TYPE, _LABELED_EDGE_TYPE], - num_neighbors={ - _U2I_EDGE_TYPE: [2, 7], - _I2U_EDGE_TYPE: [3, 4], - # If labeled edge type fanout is provided by the user, we assume it was by accident, since users shouldn't be aware of this injected edge type, - # and still set the fanout of it to be 0. - _LABELED_EDGE_TYPE: [2, 2], - }, - expected_num_neighbors={ - _U2I_EDGE_TYPE: [2, 7], - _I2U_EDGE_TYPE: [3, 4], - _LABELED_EDGE_TYPE: [0, 0], }, ), param( @@ -103,7 +87,6 @@ def test_shard_nodes_by_process( expected_num_neighbors={ _U2I_EDGE_TYPE: [1, 3], _I2U_EDGE_TYPE: [1, 3], - _LABELED_EDGE_TYPE: [0, 0], }, ), param( @@ -114,7 +97,7 @@ def test_shard_nodes_by_process( ), ] ) - def test_patch_neighbors_with_zero_fanout( + def test_patch_fanout_for_sampling( self, _, edge_types: Optional[list[EdgeType]], @@ -165,6 +148,20 @@ def test_patch_neighbors_with_zero_fanout( _I2U_EDGE_TYPE: [3, 4], }, ), + param( + "Test that providing a label edge type in num_neighbors raises", + edge_types=[_U2I_EDGE_TYPE, _I2U_EDGE_TYPE, _LABELED_EDGE_TYPE], + num_neighbors={ + _U2I_EDGE_TYPE: [2, 7], + _I2U_EDGE_TYPE: [3, 4], + _LABELED_EDGE_TYPE: [2, 2], + }, + ), + param( + "Test that a dataset with no message-passing edge types raises", + edge_types=[_LABELED_EDGE_TYPE], + num_neighbors=[1, 3], + ), ] ) def test_patch_neighbors_failure(