Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions gigl/distributed/base_dist_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,10 @@ def create_sampling_config(
) -> SamplingConfig:
"""Creates a SamplingConfig with patched fanout.

Patches ``num_neighbors`` to zero-out label edge types, then creates
the SamplingConfig used by both colocated and graph store modes.
``num_neighbors`` must cover the message-passing edge types; label edge types
are injected internally, are never sampled, and must not be specified (passing
one raises ``ValueError``). The function then creates the SamplingConfig used by
both colocated and graph store modes.

Args:
num_neighbors: Fanout per hop.
Expand Down
8 changes: 8 additions & 0 deletions gigl/distributed/dist_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from graphlearn_torch.utils import count_dict, merge_dict, reverse_edge_type

from gigl.distributed.base_sampler import BaseDistNeighborSampler
from gigl.types.graph import is_label_edge_type


class DistNeighborSampler(BaseDistNeighborSampler):
Expand Down Expand Up @@ -72,6 +73,13 @@ async def _sample_from_nodes(
nbr_dict: dict[EdgeType, list[torch.Tensor]] = {}
edge_dict: dict[EdgeType, torch.Tensor] = {}
for etype in self.edge_types:
if is_label_edge_type(etype):
# Label (positive/negative supervision) edges are injected
# into the graph for ABLP but must never be traversed during
# sampling (doing so would leak ground-truth targets). Skipping
# them here means they need no num_neighbors entry. Mirrors
# DistPPRNeighborSampler (dist_ppr_sampler.py).
continue
req_num = self.num_neighbors[etype][i]
if self.edge_dir == "in":
srcs = src_dict.get(etype[-1], None)
Expand Down
76 changes: 50 additions & 26 deletions gigl/distributed/utils/neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,29 @@ def patch_fanout_for_sampling(
num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
) -> Union[list[int], dict[EdgeType, list[int]]]:
"""
Sets up an appropriate fanout for sampling.
Normalizes the user-provided fanout into the per-edge-type form the samplers expect.

Does the following:
- For all label edge types, sets the fanout to be zero.
- For all other edge types, if the fanout is not specified, uses the original fanout.
For heterogeneous datasets, broadcasts a single fanout list to every message-passing
edge type, or validates a per-edge-type dict.

Note that if fanout is provided as a dict, the keys (edges) in the fanout must be in `edge_types`.
Label edge types (positive/negative supervision edges injected by ABLP) are never
sampled -- the samplers skip them during traversal (see
``DistNeighborSampler._sample_from_nodes`` and ``DistPPRNeighborSampler``), so they
take no fanout. Callers must not specify them: any label edge type in ``num_neighbors``
raises ``ValueError``.

We add this because the existing sampling logic (below) makes strict assumptions that we need to conform to.
https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_neighbor_sampler.py#L317-L318
For homogeneous datasets (``edge_types`` is None) the fanout list is returned unchanged.

Args:
edge_types (Optional[list[EdgeType]]): List of all edge types in the graph; None for homogeneous datasets
num_neighbors (dict[EdgeType, list[int]]): Specified fanout by the user
num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Specified fanout by the user
Returns:
Union[list[int], dict[EdgeType, list[int]]]: Modified fanout that is appropriate for sampling. Is a list[int]
if the dataset is homogeneous, otherwise is dict[EdgeType, list[int]]
Union[list[int], dict[EdgeType, list[int]]]: Normalized fanout. A list[int] for
homogeneous datasets, otherwise a dict[EdgeType, list[int]] over message-passing edges.
Raises:
ValueError: If a label edge type is supplied in ``num_neighbors``, if the dataset has
no message-passing edge types to fan out around, or on malformed fanout (extra or
missing edge types, inconsistent hop counts, or negative fanout).
"""
if edge_types is None:
if isinstance(num_neighbors, abc.Mapping):
Expand All @@ -79,32 +85,50 @@ def patch_fanout_for_sampling(
if not all(hop >= 0 for hop in num_neighbors):
raise ValueError(f"Hops provided must be non-negative, got {num_neighbors}")
return num_neighbors

message_passing_edge_types = [
edge_type for edge_type in edge_types if not is_label_edge_type(edge_type)
]
Comment thread
kmontemayor2-sc marked this conversation as resolved.
# A graph with only label edges and no message-passing edges cannot be sampled;
# fail explicitly rather than letting the validation tail raise StopIteration on an
# empty dict.
if not message_passing_edge_types:
raise ValueError(
f"No message-passing edge types found in dataset edge types {edge_types}; "
"cannot construct a fanout."
)

if isinstance(num_neighbors, list):
original_fanout = num_neighbors
should_broadcast_fanout = True
num_neighbors = {}
num_neighbors = {
edge_type: original_fanout for edge_type in message_passing_edge_types
}
else:
# Label (positive/negative supervision) edges are injected internally and are
# never sampled, so callers must not specify them in the fanout. Reject them
# explicitly: they are part of the dataset edge types, so the extra-edge-types
# check below would not catch them.
provided_label_edge_types = {
edge_type for edge_type in num_neighbors if is_label_edge_type(edge_type)
}
if provided_label_edge_types:
raise ValueError(
f"Label edge types {provided_label_edge_types} were provided in num_neighbors. "
"Label (positive/negative supervision) edges are injected internally and are never "
"sampled, so they must not be specified in the fanout."
)
extra_edge_types = set(num_neighbors.keys()) - set(edge_types)
if extra_edge_types:
raise ValueError(
f"Found extra edge types {extra_edge_types} in fanout which is not in dataset edge types {edge_types}."
)
original_fanout = next(iter(num_neighbors.values()))
should_broadcast_fanout = False
num_neighbors = deepcopy(num_neighbors)

num_hop = len(original_fanout)
zero_samples = [0 for _ in range(num_hop)]
for edge_type in edge_types:
# TODO(kmonte): stop setting fanout for positive/negative edges once GLT sampling correctly ignores those edges during fanout.
if is_label_edge_type(edge_type):
num_neighbors[edge_type] = zero_samples
elif should_broadcast_fanout and edge_type not in num_neighbors:
num_neighbors[edge_type] = original_fanout
elif not should_broadcast_fanout and edge_type not in num_neighbors:
missing_edge_types = set(message_passing_edge_types) - set(num_neighbors.keys())
if missing_edge_types:
raise ValueError(
f"Found non-labeled edge type in dataset {edge_type} which is not in the provided fanout {num_neighbors.keys()}. \
If fanout is provided as a dict, all edges must be present."
f"Found non-labeled edge type(s) {missing_edge_types} in the dataset which are not in "
f"the provided fanout {set(num_neighbors.keys())}. If fanout is provided as a dict, "
"all message-passing edges must be present."
)

hops = len(next(iter(num_neighbors.values())))
Expand Down
10 changes: 9 additions & 1 deletion tests/unit/distributed/dist_ablp_neighborloader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DEFAULT_HOMOGENEOUS_EDGE_TYPE,
GraphPartitionData,
PartitionOutput,
is_label_edge_type,
message_passing_to_negative_label,
message_passing_to_positive_label,
to_heterogeneous_node,
Expand Down Expand Up @@ -209,7 +210,14 @@ def _run_dblp_supervised(
assert isinstance(dataset.train_node_ids, dict)
assert isinstance(dataset.graph, dict)
fanout = [2, 2]
num_neighbors = {edge_type: fanout for edge_type in dataset.graph.keys()}
# Label edge types must not be specified in the fanout (they are injected
# internally and never sampled), so build num_neighbors over message-passing
# edges only.
num_neighbors = {
edge_type: fanout
for edge_type in dataset.graph.keys()
if not is_label_edge_type(edge_type)
}
create_test_process_group()
loader = DistABLPLoader(
dataset=dataset,
Expand Down
35 changes: 16 additions & 19 deletions tests/unit/distributed/utils/neighborloader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,10 @@ def test_shard_nodes_by_process(
_U2I_EDGE_TYPE: [2, 7],
_I2U_EDGE_TYPE: [3, 4],
},
# Label edge types are excluded from the fanout (the samplers skip them).
expected_num_neighbors={
_U2I_EDGE_TYPE: [2, 7],
_I2U_EDGE_TYPE: [3, 4],
_LABELED_EDGE_TYPE: [0, 0],
},
),
param(
"Test patch_fanout_for_sampling on num_neighbors dict with labeled edge type in dataset and fanout",
edge_types=[_U2I_EDGE_TYPE, _I2U_EDGE_TYPE, _LABELED_EDGE_TYPE],
num_neighbors={
_U2I_EDGE_TYPE: [2, 7],
_I2U_EDGE_TYPE: [3, 4],
# If labeled edge type fanout is provided by the user, we assume it was by accident, since users shouldn't be aware of this injected edge type,
# and still set the fanout of it to be 0.
_LABELED_EDGE_TYPE: [2, 2],
},
expected_num_neighbors={
_U2I_EDGE_TYPE: [2, 7],
_I2U_EDGE_TYPE: [3, 4],
_LABELED_EDGE_TYPE: [0, 0],
},
),
param(
Expand All @@ -103,7 +87,6 @@ def test_shard_nodes_by_process(
expected_num_neighbors={
_U2I_EDGE_TYPE: [1, 3],
_I2U_EDGE_TYPE: [1, 3],
_LABELED_EDGE_TYPE: [0, 0],
},
),
param(
Expand All @@ -114,7 +97,7 @@ def test_shard_nodes_by_process(
),
]
)
def test_patch_neighbors_with_zero_fanout(
def test_patch_fanout_for_sampling(
self,
_,
edge_types: Optional[list[EdgeType]],
Expand Down Expand Up @@ -165,6 +148,20 @@ def test_patch_neighbors_with_zero_fanout(
_I2U_EDGE_TYPE: [3, 4],
},
),
param(
"Test that providing a label edge type in num_neighbors raises",
edge_types=[_U2I_EDGE_TYPE, _I2U_EDGE_TYPE, _LABELED_EDGE_TYPE],
num_neighbors={
_U2I_EDGE_TYPE: [2, 7],
_I2U_EDGE_TYPE: [3, 4],
_LABELED_EDGE_TYPE: [2, 2],
},
),
param(
"Test that a dataset with no message-passing edge types raises",
edge_types=[_LABELED_EDGE_TYPE],
num_neighbors=[1, 3],
),
]
)
def test_patch_neighbors_failure(
Expand Down