Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP DO NOT MERGE] Learn2Code Feature Branch #233

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
14 changes: 8 additions & 6 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ on:
branches:
- main
- master
- learn2code
tags:
- "*"
pull_request:
branches:
- main
- master
- learn2code
workflow_dispatch:

permissions:
Expand Down Expand Up @@ -56,7 +58,7 @@ jobs:
run: |
set +e
has_updated=$(git diff --name-only '${{ github.event.pull_request.base.sha }}' | grep -E 'pyproject.toml|Cargo.toml')
is_main_or_release='${{ github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/tags/') }}'
is_main_or_release='${{ github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' || github.ref == 'refs/heads/learn2code' || startsWith(github.ref, 'refs/tags/') }}'
if [[ -n "${has_updated}" ]] || [[ "${is_main_or_release}" == 'true' ]]; then
echo "should_build=true" >> $GITHUB_OUTPUT
else
Expand Down Expand Up @@ -202,7 +204,7 @@ jobs:
sudo apt-get update
sudo apt-get install --yes --upgrade build-essential cmake protobuf-compiler libssl-dev glibc-source musl-tools
- name: Upload wheels
uses: actions/[email protected]
uses: actions/[email protected].1
with:
overwrite: true
name: release-wheel-linux-${{ matrix.target }}-${{ github.run_id }}
Expand All @@ -228,7 +230,7 @@ jobs:
target: ${{ matrix.target }}
args: --release --out dist --find-interpreter
- name: Upload wheels
uses: actions/[email protected]
uses: actions/[email protected].1
with:
overwrite: true
name: release-wheel-windows-${{ matrix.target }}-${{ github.run_id }}
Expand All @@ -253,7 +255,7 @@ jobs:
target: ${{ matrix.target }}
args: --release --out dist --find-interpreter
- name: Upload wheels
uses: actions/[email protected]
uses: actions/[email protected].1
with:
overwrite: true
name: release-wheel-macos-${{ matrix.target }}-${{ github.run_id }}
Expand All @@ -271,7 +273,7 @@ jobs:
command: sdist
args: --out dist
- name: Upload sdist
uses: actions/[email protected]
uses: actions/[email protected].1
with:
overwrite: true
name: release-sdist-${{ github.run_id }}
Expand All @@ -281,7 +283,7 @@ jobs:
name: Release
runs-on: ubuntu-latest
if: "startsWith(github.ref, 'refs/tags/')"
needs: [build-linux, build-windows, build-macos, sdist]
needs: [build-linux, build-windows, build-macos, sdist, tests]
steps:
- uses: actions/download-artifact@v4
with:
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dolma"
version = "1.1.1"
version = "1.2.0-dev7"
edition = "2021"
license = "Apache-2.0"

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dolma"
version = "1.1.2"
version = "1.2.0.dev7"
description = "Data filters"
license = { text = "Apache-2.0" }
readme = "README.md"
Expand Down
2 changes: 2 additions & 0 deletions python/dolma/taggers/code/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
CodeSecretsTagger,
CodeStarCoderTaggers,
CodeStarCoderTaggers2,
Learn2CodeTaggers,
)

__all__ = [
Expand All @@ -12,4 +13,5 @@
"CodeRedPajamaTaggers",
"CodeStarCoderTaggers",
"CodeStarCoderTaggers2",
"Learn2CodeTaggers",
]
111 changes: 111 additions & 0 deletions python/dolma/taggers/code/code_taggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@
if CODE_DEPENDENCIES_AVAILABLE:
from .starcoder import get_nl_ratio
from .utils import (
b64_filter,
filter_html,
get_ext_to_lang_mapping,
get_line_stats,
get_proportion_alphabetic_chars,
get_secrets,
get_whitespace_regex,
hexadecimal_filter,
special_text_file_filter,
unicode_filter,
)


Expand Down Expand Up @@ -269,3 +275,108 @@ def predict(self, doc: DocumentWithMetadata) -> DocResult: # type: ignore
spans.append(Span(start=0, end=doc_length, type="code_to_text_ratio_html_doc", score=code_to_text_ratio))

return DocResult(doc=doc, spans=spans)


@TaggerRegistry.add("learn2code_taggers_v1")
class Learn2CodeTaggers(BaseTaggerWithMetadata):
    """
    Code-quality taggers based on a mix of filters from the StarCoder and
    Granite data pipelines.

    Emits one document-level span per heuristic: document size, GitHub
    stars, alphabetic-character proportion, XML-template detection, line
    statistics, base64/hexadecimal/unicode run detection, comment ratio,
    HTML displayed-text ratio, and special-text-file detection.
    """

    def __init__(self) -> None:
        # Fails fast if optional code-tagging dependencies are missing.
        check_code_dependencies()
        self.ext_to_lang_mapping = get_ext_to_lang_mapping()
        super().__init__()

    def predict(self, doc: DocumentWithMetadata) -> DocResult:  # type: ignore
        """Score ``doc`` with every Learn2Code heuristic as document-level spans."""
        doc_length = len(doc.text)

        # GitHub dumps use either star field name depending on source; default to 0.
        num_github_stars = doc.metadata.get("max_stars_count", 0) or doc.metadata.get("star_events_count", 0) or 0
        proportion_alpha = get_proportion_alphabetic_chars(doc.text)
        has_xml_template = 1.0 if "<?xml version=" in doc.text[:100] else 0.0
        line_stats = get_line_stats(doc.text)
        b64_filter_results = b64_filter(doc.text)
        hexadecimal_filter_results = hexadecimal_filter(doc.text)
        unicode_filter_results = unicode_filter(doc.text)

        try:
            lang = self.ext_to_lang_mapping[doc.metadata.get("ext", "-no-lang")]
        except KeyError:
            lang = "-no-lang"

        # BUG FIX: "path" may be absent or explicitly None in metadata; fall
        # back to "" so special_text_file_filter never receives None (Path(None)
        # raises TypeError).
        filename = doc.metadata.get("path", "") or ""

        try:
            proportion_comments_doc = get_nl_ratio(doc.text, lang)
        except Exception:  # pylint: disable=broad-except
            # get_nl_ratio can fail on unparseable sources; -1 marks "unknown".
            proportion_comments_doc = -1

        # Only meaningful for HTML documents; other languages get a neutral 1.0.
        if lang == "html":
            try:
                proportion_text_in_html = filter_html(doc.text)
            except Exception:  # pylint: disable=broad-except
                proportion_text_in_html = -1.0
        else:
            proportion_text_in_html = 1.0

        is_special_text_file = 1 if special_text_file_filter(filename, lang) else 0

        # One document-level span per heuristic, all covering [0, doc_length).
        scores = [
            ("num_chars_doc", float(doc_length)),
            ("num_github_stars_doc", float(num_github_stars)),
            ("proportion_alpha_doc", proportion_alpha),
            ("has_xml_template_doc", has_xml_template),
            ("num_lines_doc", float(line_stats.total_count)),
            ("mean_line_length_doc", line_stats.mean_length),
            ("max_line_length_doc", float(line_stats.max_length)),
            ("longest_seq_b64_doc", float(b64_filter_results.longest_match)),
            ("proportion_b64_doc", b64_filter_results.proportion_match),
            ("longest_seq_hexadecimal_doc", float(hexadecimal_filter_results.longest_match)),
            ("proportion_hexadecimal_doc", hexadecimal_filter_results.proportion_match),
            ("longest_seq_unicode_doc", float(unicode_filter_results.longest_match)),
            ("proportion_unicode_doc", unicode_filter_results.proportion_match),
            ("proportion_comments_doc", proportion_comments_doc),
            ("proportion_text_in_html_doc", proportion_text_in_html),
            ("is_special_text_file_doc", float(is_special_text_file)),
        ]
        spans: List[Span] = [
            Span(start=0, end=doc_length, type=span_type, score=score) for span_type, score in scores
        ]

        return DocResult(doc=doc, spans=spans)
82 changes: 82 additions & 0 deletions python/dolma/taggers/code/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import json
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Generator

Expand Down Expand Up @@ -54,6 +56,73 @@ def get_secrets(code: str):
return secrets


@dataclass
class StarCoderRegexFilterResults:
    """Summary of how much of a text a single regex pattern matches."""

    # Length (in characters) of the longest single match; 0 when no match.
    longest_match: int
    # Fraction of the text's characters covered by matches, in [0.0, 1.0].
    proportion_match: float


def regex_match(regex_string: str, text: str) -> StarCoderRegexFilterResults:
    """Run ``regex_string`` over ``text`` and summarize the matches.

    Args:
        regex_string: Pattern passed to ``re.findall``.
        text: Text to scan; may be empty.

    Returns:
        StarCoderRegexFilterResults with the longest match length and the
        proportion of characters inside any match. Empty ``text`` yields
        zeros.
    """
    match_lengths = [len(match) for match in re.findall(regex_string, text)]
    longest_match = max(match_lengths, default=0)
    # BUG FIX: the original divided by len(text) unconditionally, raising
    # ZeroDivisionError for empty documents.
    proportion_match = sum(match_lengths) / len(text) if text else 0.0

    return StarCoderRegexFilterResults(longest_match=longest_match, proportion_match=proportion_match)


def b64_filter(text: str) -> StarCoderRegexFilterResults:
    """Score runs of base64-looking characters (StarCoder2 heuristic).

    Flags any run of 64 or more base64-alphabet characters (letters,
    digits, '+', '/', '=', and embedded newlines) and reports the longest
    run plus the proportion of the text covered by such runs.
    """
    b64_run_pattern = r"[a-zA-Z0-9+/\n=]{64,}"
    return regex_match(b64_run_pattern, text)


def hexadecimal_filter(text: str) -> StarCoderRegexFilterResults:
    """Score runs of hexadecimal byte sequences (StarCoder2 heuristic).

    Flags 8+ consecutive two-digit hex bytes, optionally prefixed with
    "0x" or "\\x" and separated by commas or whitespace.

    NOTE (from original author): the escaped-literal case, e.g.
    "\\x48\\x31\\xc0\\x50\\x68\\x2f\\x2f\\x73\\x68", is a bit broken —
    it always drops the first byte of the sequence because of how \\b is
    interpreted in that context.
    """
    hex_run_pattern = r"(?:\b(?:0x|\\x)?[0-9a-fA-F]{2}(?:,|\b\s*)){8,}"
    return regex_match(hex_run_pattern, text)


def unicode_filter(text: str) -> StarCoderRegexFilterResults:
    """Score runs of escaped unicode code points (StarCoder2 heuristic).

    Flags 8+ consecutive "\\uXXXX" escape sequences and reports the
    longest run plus the proportion of the text they cover.
    """
    unicode_run_pattern = r"(?:\\u[0-9a-fA-F]{4}){8,}"
    return regex_match(unicode_run_pattern, text)


def get_proportion_alphabetic_chars(text: str) -> float:
    """Return the proportion of characters in ``text`` that are ASCII letters.

    Args:
        text: Text to analyze; may be empty.

    Returns:
        Fraction in [0.0, 1.0]; 0.0 for empty input (the original raised
        ZeroDivisionError on "").
    """
    # BUG FIX: guard against empty text before dividing.
    if not text:
        return 0.0
    # Strip everything that is NOT A-Z/a-z, leaving only the alphabetic
    # characters (the original misleadingly named this "nonalpha").
    alpha_only = re.sub(r"[^A-Za-z]", "", text)
    return len(alpha_only) / len(text)


@dataclass
class LineStats:
    """Per-line summary statistics for a blob of text."""

    # Number of lines after splitting on "\n" (always >= 1, even for "").
    total_count: int
    # Average line length in characters.
    mean_length: float
    # Length of the longest line in characters.
    max_length: int


def get_line_stats(text: str) -> LineStats:
    """Compute line count, mean line length, and max line length for ``text``."""
    lengths = [len(line) for line in text.split("\n")]
    count = len(lengths)
    return LineStats(
        total_count=count,
        mean_length=sum(lengths) / count,
        max_length=max(lengths),
    )


def filter_html(html: str) -> float:
"""Filter HTML files based on displayed text VS code ratio"""
try:
Expand All @@ -80,3 +149,16 @@ def get_ext_to_lang_mapping() -> Dict[str, str]:
path = Path(__file__).parent / "../../data/ext_to_lang_mapping.json"
with smart_open.open(path, "r") as f:
return json.load(f)


def special_text_file_filter(filepath: str, lang: str) -> bool:
    """Flag plain-text files whose filename marks them as high-value docs.

    Args:
        filepath: Path of the file; may be None or "" when the source
            document carries no "path" metadata.
        lang: Language label resolved from the file extension.

    Returns:
        True for "text" files named like requirements lists, READMEs,
        TODOs, descriptions, or CMakeLists; False otherwise.
    """
    # BUG FIX: callers pass doc.metadata.get("path", None), so filepath can
    # legitimately be None; Path(None) would raise TypeError.
    if lang != "text" or not filepath:  # TODO: include markdown as well?
        return False

    filename = Path(filepath).stem.lower()

    if "requirement" in filename:
        return True

    return filename in {"readme", "todo", "description", "cmakelists"}
Loading
Loading