Skip to content

Commit

Permalink
parsed out
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Feb 6, 2024
1 parent a38ffb1 commit 91544e5
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 92 deletions.
88 changes: 0 additions & 88 deletions python/dolma/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import json
import warnings
from pathlib import Path
from typing import List, Optional, Union

import smart_open
import urllib3.util

# warning raised by pkg_resources used in a lot of google packages
warnings.filterwarnings("ignore", message=r".*declare_namespace\(\'.*google.*", category=DeprecationWarning)
Expand Down Expand Up @@ -59,86 +54,3 @@ def mixer(config: dict):
_dolma.mixer_entrypoint(json.dumps(config))
except RuntimeError as e:
raise DolmaRustPipelineError(f"Error running mixer: {e}") from e


class UrlBlocker:
"""
A class that provides URL blocking functionality based on a set of rules.
Args:
rules (List[str]): A list of rules to be used for blocking URLs.
Attributes:
engine: The underlying engine used for URL blocking.
Methods:
from_adblockplus_filepath: Create an instance of UrlBlocker from an AdBlock Plus file.
check_network_urls: Check if a given URL should be blocked based on the rules.
"""

def __init__(
self,
rules: List[str],
) -> None:
"""
Initialize the UrlBlocker instance.
Args:
rules (List[str]): A list of rules to be used for blocking URLs.
"""
self.engine = _dolma.UrlBlocker(rules=rules)

@classmethod
def from_adblockplus_filepath(
cls,
adblockplus_filepath: Union[str, Path],
) -> "UrlBlocker":
"""
Create an instance of UrlBlocker from an AdBlock Plus file.
Args:
adblockplus_filepath (str): The filepath of the AdBlock Plus file.
Returns:
UrlBlocker: An instance of UrlBlocker created from the AdBlock Plus file.
"""
with smart_open.open(adblockplus_filepath, "rt") as adb_file:
rules = [ln.strip() for ln in adb_file if not ln.startswith("!")]
return cls(rules)

def check_network_urls(
self,
url: str,
source_url: Optional[str] = None,
request_type: str = "",
) -> bool:
"""
Check if a given URL should be blocked based on the rules.
Args:
url (str): The URL to be checked.
source_url (str): The source URL of the request. If not provided, the host from the URL will be used.
request_type (str): The type of the request. For a list of valid request types, see the adblockplus
documentation: https://help.adblockplus.org/hc/en-us/articles/360062733293-How-to-write-filters
Returns:
bool: True if the URL should be blocked, False otherwise.
"""
parsed = urllib3.util.parse_url(url)
if parsed.scheme is None:
# if the URL does not have a scheme, we assume it is an HTTP URL
url = f"http://{url}"

if source_url is None:
# if the source URL is not provided, we use the host from the URL
source_url = parsed.host

return self.engine.check_network_urls(
url=url,
source_url=source_url,
request_type=request_type,
)
5 changes: 3 additions & 2 deletions python/dolma/taggers/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ..core.data_types import DocResult, DocumentWithMetadata, Span
from ..core.registry import TaggerRegistry
from ..core.taggers import BaseTaggerWithMetadata
from ..core.url_blocker import UrlBlocker


class BaseUrlTagger(BaseTaggerWithMetadata):
Expand Down Expand Up @@ -64,9 +65,9 @@ class DomainBlocklistPhishingTagger(BaseDomainTagger):

class AdbUrlTagger(BaseUrlTagger):
def __init__(self) -> None:
from dolma import UrlBlocker
# from dolma import UrlBlocker

self.engine = UrlBlocker.from_adblockplus_filepath(self.BLOCKLIST_PATH)
self.engine = UrlBlocker.from_adb_paths(self.BLOCKLIST_PATH)

def check_url(self, url: str) -> bool:
return self.engine.check_network_urls(url)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_urls.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from unittest import TestCase

from dolma import UrlBlocker
from dolma.core.url_blocker import UrlBlocker

LOCAL_DATA = Path(__file__).parent.parent / "data"

Expand Down Expand Up @@ -29,7 +29,7 @@ def test_brave_adblocker(self):
self.assertFalse(engine.check_network_urls(not_to_block))

def test_load_from_file(self):
engine = UrlBlocker.from_adblockplus_filepath(LOCAL_DATA / "urls/easylist.txt.gz")
engine = UrlBlocker.from_adb_paths(LOCAL_DATA / "urls/easylist.txt.gz")

# global rules
self.assertTrue(engine.check_network_urls("berush.com"))
Expand Down

0 comments on commit 91544e5

Please sign in to comment.