diff --git a/nemo_curator/stages/text/filters/heuristic/string.py b/nemo_curator/stages/text/filters/heuristic/string.py index ce1a0d2b4e..6ddf54a297 100644 --- a/nemo_curator/stages/text/filters/heuristic/string.py +++ b/nemo_curator/stages/text/filters/heuristic/string.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re from typing import Literal from nemo_curator.stages.text.filters.doc_filter import DocumentFilter @@ -111,15 +112,28 @@ def keep_document(self, score: float) -> bool: class UrlsFilter(DocumentFilter): """ If more than 20% of the document is comprised of URLs, then discard. + + Args: + max_url_to_text_ratio: Maximum ratio of URL characters to total + characters before the document is dropped. + url_regex: Optional URL regex (compiled pattern or string). When + ``None`` the default project-wide regex is used. Useful when the + default is too strict or too permissive for a particular corpus — + e.g. ``r"https?://[^\\s]+"`` or ``r"https?://[^\\s\\\"'<>]+"``. """ - def __init__(self, max_url_to_text_ratio: float = 0.2): + def __init__( + self, + max_url_to_text_ratio: float = 0.2, + url_regex: re.Pattern | str | None = None, + ): super().__init__() self._cutoff = max_url_to_text_ratio self._name = "urls_ratio" + self._url_regex = re.compile(url_regex) if isinstance(url_regex, str) else (url_regex or regex_url) def score_document(self, text: str) -> float: - all_urls = regex_url.findall(text) + all_urls = self._url_regex.findall(text) url_chars = sum([len(url) for url in all_urls]) nchar = len(text) # Remove if the document is empty @@ -431,13 +445,18 @@ def keep_document(self, score: float) -> bool: class PornographicUrlsFilter(DocumentFilter): """ Check if any of the URLs within the document point to pornography. + + Args: + url_regex: Optional URL regex (compiled pattern or string). When + ``None`` the default project-wide regex is used. 
""" - def __init__(self): + def __init__(self, url_regex: re.Pattern | str | None = None): super().__init__() + self._url_regex = re.compile(url_regex) if isinstance(url_regex, str) else (url_regex or regex_url) def score_document(self, text: str) -> int: - all_urls = regex_url.findall(text) + all_urls = self._url_regex.findall(text) for url in all_urls: if "porn" in url: return 1 diff --git a/nemo_curator/stages/text/utils/constants.py b/nemo_curator/stages/text/utils/constants.py index b07c81b51f..ad4a1148fa 100644 --- a/nemo_curator/stages/text/utils/constants.py +++ b/nemo_curator/stages/text/utils/constants.py @@ -72,6 +72,13 @@ regex_alpha = regex.compile("[[:alpha:]]") regex_digit = regex.compile("[[:digit:]]") regex_alphanum = re.compile("[a-zA-Z0-9\n?!,.]") -regex_url = re.compile("http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+") +# NOTE: the `-` inside the character class is escaped on purpose. With +# `[$-_…]` the regex engine treats `$-_` as a *range* spanning U+0024 +# through U+005F, which silently includes `<`, `>`, `;`, `'`, etc., so +# matches bleed past the actual URL into surrounding HTML/punctuation. +# That accidental range was also the only thing matching `/`, `:`, `?` +# and `=`, so they are listed explicitly to keep paths, ports and query +# strings inside the match. +regex_url = re.compile(r"http[s]?://(?:[a-zA-Z]|[0-9]|[$\-_@.&+/:?=]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+") regex_paren = re.compile(r"{|}|⟨|⟩|\[|\]|\(|\)") regex_hash = re.compile("#+") diff --git a/tests/stages/text/modules/test_filters.py b/tests/stages/text/modules/test_filters.py index 3ed3c96484..1b12aa340d 100644 --- a/tests/stages/text/modules/test_filters.py +++ b/tests/stages/text/modules/test_filters.py @@ -13,11 +13,13 @@ # limitations under the License. 
import os +import re import numpy as np import pandas as pd import pytest +from nemo_curator.stages.text.utils.constants import regex_url from nemo_curator.stages.text.filters import DocumentFilter, Filter, Score, ScoreFilter from nemo_curator.stages.text.filters.heuristic import ( BoilerPlateStringFilter, @@ -555,6 +557,53 @@ def test_urls(self) -> None: ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" + def test_url_regex_does_not_swallow_html_tags(self) -> None: + # Regression for #1601: previously the `[$-_…]` range matched `<`, + # `>`, `;`, `:`, etc., so a URL match bled past the URL into the + # surrounding HTML/punctuation. + assert regex_url.findall("see <b>http://x.com</b> for details") == ["http://x.com"] + assert regex_url.findall("click http://example.com;next") == ["http://example.com"] + + def test_url_regex_still_matches_allowed_characters(self) -> None: + # The fix must not regress on characters the original character + # class intended to allow: letters, digits, `$`, `_`, `@`, `.`, + # `&`, `+`, `-`, `!`, `*`, `(`, `)`, `,`, and percent-encoded escapes. + text = "ref https://A.B-C_D+E&f!*(g),h%2F end" + + assert regex_url.findall(text) == ["https://A.B-C_D+E&f!*(g),h%2F"] + + def test_urls_filter_accepts_custom_regex(self) -> None: + # Per the discussion on #1601, the URL regex should be + # customizable on the filter so callers can swap in a stricter or + # looser pattern (e.g. `r"https?://[^\s]+"`). + dataset = list_to_dataset( + [ + "ftp://files.example.com/archive.tar.gz", + "no urls here!", + "https://www.nvidia.com/en-us/", + ] + ) + # Custom regex matches `ftp://` URLs that the default does not. 
+ filters = ScoreFilter(UrlsFilter(url_regex=r"ftp://[^\s]+")) + + filtered_data = filters.process(dataset) + + expected_data = DocumentBatch( + data=pd.DataFrame({"text": ["no urls here!", "https://www.nvidia.com/en-us/"]}), + task_id="batch_1_urls_ratio", + dataset_name="test_1", + ) + assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" + + def test_urls_filter_accepts_compiled_pattern(self) -> None: + # The custom regex argument should also accept a pre-compiled + # `re.Pattern` instance, not just a string. + compiled = re.compile(r"https?://[^\s]+") + urls_filter = UrlsFilter(url_regex=compiled) + + # The constructor stores the same compiled object, not a re-compile. + assert urls_filter._url_regex is compiled + def test_bullets(self) -> None: dataset = list_to_dataset( [