diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 1708201c..ba9aef5b 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -24,4 +24,4 @@ contact_links: url: https://g.co/vulnz about: > To report a security issue, please use https://g.co/vulnz. The Google Security Team will - respond within 5 working days of your report on https://g.co/vulnz. \ No newline at end of file + respond within 5 working days of your report on https://g.co/vulnz. diff --git a/.github/scripts/add-new-checks.sh b/.github/scripts/add-new-checks.sh new file mode 100755 index 00000000..1b4bf6fb --- /dev/null +++ b/.github/scripts/add-new-checks.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Script to add new required status checks to an existing branch protection rule. +# This preserves all your current settings and just adds the new checks + +echo "Adding new PR validation checks to existing branch protection..." + +# Add the new checks to existing ones +echo "Adding new checks: enforce, size, and protect-infrastructure..." +gh api repos/:owner/:repo/branches/main/protection/required_status_checks/contexts \ + --method POST \ + --input - <<< '["enforce", "size", "protect-infrastructure"]' + +echo "" +echo "โœ“ New checks added!" 
+echo "" +echo "Updated required status checks will include:" +echo "- test (3.10) [existing]" +echo "- test (3.11) [existing]" +echo "- test (3.12) [existing]" +echo "- Validate PR Template [existing]" +echo "- live-api-tests [existing]" +echo "- ollama-integration-test [existing]" +echo "- enforce [NEW - linked issue validation]" +echo "- size [NEW - PR size limit]" +echo "- protect-infrastructure [NEW - infrastructure file protection]" diff --git a/.github/scripts/add-size-labels.sh b/.github/scripts/add-size-labels.sh new file mode 100755 index 00000000..d3e6795c --- /dev/null +++ b/.github/scripts/add-size-labels.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Add size labels to PRs based on their change count + +echo "Adding size labels to PRs..." 
+ +# Get all open PRs with their additions and deletions +gh pr list --limit 50 --json number,additions,deletions --jq '.[]' | while read -r pr_data; do + pr_number=$(echo "$pr_data" | jq -r '.number') + additions=$(echo "$pr_data" | jq -r '.additions') + deletions=$(echo "$pr_data" | jq -r '.deletions') + total_changes=$((additions + deletions)) + + # Determine size label + if [ $total_changes -lt 50 ]; then + size_label="size/XS" + elif [ $total_changes -lt 150 ]; then + size_label="size/S" + elif [ $total_changes -lt 600 ]; then + size_label="size/M" + elif [ $total_changes -lt 1000 ]; then + size_label="size/L" + else + size_label="size/XL" + fi + + echo "PR #$pr_number: $total_changes lines -> $size_label" + + # Remove any existing size labels first + existing_labels=$(gh pr view $pr_number --json labels --jq '.labels[].name' | grep "^size/" || true) + if [ ! -z "$existing_labels" ]; then + echo " Removing existing label: $existing_labels" + gh pr edit $pr_number --remove-label "$existing_labels" + fi + + # Add the new size label + gh pr edit $pr_number --add-label "$size_label" + + sleep 1 # Avoid rate limiting +done + +echo "Done adding size labels!" diff --git a/.github/scripts/revalidate-all-prs.sh b/.github/scripts/revalidate-all-prs.sh new file mode 100755 index 00000000..5bf85c69 --- /dev/null +++ b/.github/scripts/revalidate-all-prs.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Revalidate all open PRs + +echo "Fetching all open PRs..." +PR_NUMBERS=$(gh pr list --limit 50 --json number --jq '.[].number') +TOTAL=$(echo "$PR_NUMBERS" | wc -w | tr -d ' ') + +echo "Found $TOTAL open PRs" +echo "Starting revalidation..." +echo "" + +COUNT=0 +for pr in $PR_NUMBERS; do + COUNT=$((COUNT + 1)) + echo "[$COUNT/$TOTAL] Triggering revalidation for PR #$pr..." + gh workflow run revalidate-pr.yml -f pr_number=$pr + + # Small delay to avoid rate limiting + sleep 2 +done + +echo "" +echo "All workflows triggered!" +echo "" +echo "To monitor progress:" +echo " gh run list --workflow=revalidate-pr.yml --limit=$TOTAL" +echo "" +echo "To see results, check comments on each PR" diff --git a/.github/workflows/auto-update-pr.yaml b/.github/workflows/auto-update-pr.yaml new file mode 100644 index 00000000..ed5cee96 --- /dev/null +++ b/.github/workflows/auto-update-pr.yaml @@ -0,0 +1,166 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: Auto Update PR + +on: + push: + branches: [main] + schedule: + # Run daily at 2 AM UTC to catch stale PRs + - cron: '0 2 * * *' + workflow_dispatch: + inputs: + pr_number: + description: 'PR number to update (optional, updates all if not specified)' + required: false + type: string + +permissions: + contents: write # Required for updateBranch API + pull-requests: write + issues: write + +jobs: + update-prs: + runs-on: ubuntu-latest + concurrency: + group: auto-update-pr-${{ github.event_name }} + cancel-in-progress: true + steps: + - name: Update PRs that are behind main + uses: actions/github-script@v7 + with: + script: | + const prNumber = context.payload.inputs?.pr_number; + + // Get list of open PRs + const prs = prNumber + ? [(await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: parseInt(prNumber) + })).data] + : await github.paginate(github.rest.pulls.list, { + owner: context.repo.owner, + repo: context.repo.repo, + state: 'open', + sort: 'updated', + direction: 'desc' + }); + + console.log(`Found ${prs.length} open PRs to check`); + + // Constants for comment flood control + const UPDATE_COMMENT_COOLDOWN_DAYS = 7; + const COOLDOWN_MS = UPDATE_COMMENT_COOLDOWN_DAYS * 24 * 60 * 60 * 1000; + + for (const pr of prs) { + // Skip bot PRs and drafts + if (pr.user.login.includes('[bot]')) { + console.log(`Skipping bot PR #${pr.number} from ${pr.user.login}`); + continue; + } + if (pr.draft) { + console.log(`Skipping draft PR #${pr.number}`); + continue; + } + + try { + // Check if PR is behind main (base...head comparison) + const { data: comparison } = await github.rest.repos.compareCommits({ + owner: context.repo.owner, + repo: context.repo.repo, + base: pr.base.ref, // main branch + head: `${pr.head.repo.owner.login}:${pr.head.ref}` // Fully qualified ref for forks + }); + + if (comparison.behind_by > 0) { + console.log(`PR #${pr.number} is ${comparison.behind_by} commits behind ${pr.base.ref}`); + + // 
Check if the PR allows maintainer edits + if (pr.maintainer_can_modify) { + // Try to update the branch + try { + await github.rest.pulls.updateBranch({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: pr.number + }); + + console.log(`โœ… Updated PR #${pr.number}`); + + // Add a comment + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number, + body: `๐Ÿ”„ **Branch Updated**\n\nYour branch was ${comparison.behind_by} commits behind \`${pr.base.ref}\` and has been automatically updated. CI checks will re-run shortly.` + }); + } catch (updateError) { + console.log(`Could not auto-update PR #${pr.number}: ${updateError.message}`); + + // Determine the reason for failure + let failureReason = ''; + if (updateError.status === 409 || updateError.message.includes('merge conflict')) { + failureReason = '\n\n**Note:** Automatic update failed due to merge conflicts. Please resolve them manually.'; + } else if (updateError.status === 422) { + failureReason = '\n\n**Note:** Cannot push to fork. 
Please update manually.'; + } + + // Notify the contributor to update manually + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number, + body: `โš ๏ธ **Branch Update Required**\n\nYour branch is ${comparison.behind_by} commits behind \`${pr.base.ref}\`.${failureReason}\n\nPlease update your branch:\n\n\`\`\`bash\ngit fetch origin ${pr.base.ref}\ngit merge origin/${pr.base.ref}\ngit push\n\`\`\`\n\nOr use GitHub's "Update branch" button if available.` + }); + } + } else { + // Can't modify, just notify + console.log(`PR #${pr.number} doesn't allow maintainer edits`); + + // Check if we already commented recently (within last 7 days) + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number, + since: new Date(Date.now() - COOLDOWN_MS).toISOString() + }); + + const hasRecentUpdateComment = comments.some(c => + c.body?.includes('Branch Update Required') && + c.user?.login === 'github-actions[bot]' + ); + + if (!hasRecentUpdateComment) { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number, + body: `โš ๏ธ **Branch Update Required**\n\nYour branch is ${comparison.behind_by} commits behind \`${pr.base.ref}\`. 
Please update your branch to ensure CI checks run with the latest code:\n\n\`\`\`bash\ngit fetch origin ${pr.base.ref}\ngit merge origin/${pr.base.ref}\ngit push\n\`\`\`\n\nNote: Enable "Allow edits by maintainers" to allow automatic updates.` + }); + } + } + } else { + console.log(`PR #${pr.number} is up to date`); + } + } catch (error) { + console.error(`Error processing PR #${pr.number}:`, error.message); + } + } + + // Log rate limit status + const { data: rateLimit } = await github.rest.rateLimit.get(); + console.log(`API rate limit remaining: ${rateLimit.rate.remaining}/${rateLimit.rate.limit}`); diff --git a/.github/workflows/check-infrastructure-changes.yml b/.github/workflows/check-infrastructure-changes.yml new file mode 100644 index 00000000..915d7e38 --- /dev/null +++ b/.github/workflows/check-infrastructure-changes.yml @@ -0,0 +1,100 @@ +name: Protect Infrastructure Files + +on: + pull_request_target: + types: [opened, synchronize, reopened] + workflow_dispatch: + +permissions: + contents: read + pull-requests: write + +jobs: + protect-infrastructure: + if: github.event_name == 'workflow_dispatch' || github.event.pull_request.draft == false + runs-on: ubuntu-latest + + steps: + - name: Check for infrastructure file changes + if: github.event_name == 'pull_request_target' + uses: actions/github-script@v7 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + // Get the PR author and check if they're a maintainer + const prAuthor = context.payload.pull_request.user.login; + const { data: authorPermission } = await github.rest.repos.getCollaboratorPermissionLevel({ + owner: context.repo.owner, + repo: context.repo.repo, + username: prAuthor + }); + + const isMaintainer = ['admin', 'maintain'].includes(authorPermission.permission); + + // Get list of files changed in the PR + const { data: files } = await github.rest.pulls.listFiles({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.payload.pull_request.number + }); 
+ + // Check for infrastructure file changes + const infrastructureFiles = files.filter(file => + file.filename.startsWith('.github/') || + file.filename === 'pyproject.toml' || + file.filename === 'tox.ini' || + file.filename === '.pre-commit-config.yaml' || + file.filename === '.pylintrc' || + file.filename === 'Dockerfile' || + file.filename === 'autoformat.sh' || + file.filename === '.gitignore' || + file.filename === 'CONTRIBUTING.md' || + file.filename === 'LICENSE' || + file.filename === 'CITATION.cff' + ); + + if (infrastructureFiles.length > 0 && !isMaintainer) { + // Check if changes are only formatting/whitespace + let hasStructuralChanges = false; + for (const file of infrastructureFiles) { + const additions = file.additions || 0; + const deletions = file.deletions || 0; + const changes = file.changes || 0; + + // If file has significant changes (not just whitespace), consider it structural + if (additions > 5 || deletions > 5 || changes > 10) { + hasStructuralChanges = true; + break; + } + } + + const fileList = infrastructureFiles.map(f => ` - ${f.filename} (${f.changes} changes)`).join('\n'); + + // Post a comment explaining the issue + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `โŒ **Infrastructure File Protection**\n\n` + + `This PR modifies protected infrastructure files:\n\n${fileList}\n\n` + + `Only repository maintainers are allowed to modify infrastructure files (including \`.github/\`, build configuration, and repository documentation).\n\n` + + `**Note**: If these are only formatting changes, please:\n` + + `1. Revert changes to \`.github/\` files\n` + + `2. Use \`./autoformat.sh\` to format only source code directories\n` + + `3. Avoid running formatters on infrastructure files\n\n` + + `If structural changes are necessary:\n` + + `1. Open an issue describing the needed infrastructure changes\n` + + `2. 
A maintainer will review and implement the changes if approved\n\n` + + `For more information, see our [Contributing Guidelines](https://github.com/google/langextract/blob/main/CONTRIBUTING.md).` + }); + + core.setFailed( + `This PR modifies ${infrastructureFiles.length} protected infrastructure file(s). ` + + `Only maintainers can modify these files. ` + + `Use ./autoformat.sh to format code without touching infrastructure.` + ); + } else if (infrastructureFiles.length > 0 && isMaintainer) { + core.info(`PR modifies ${infrastructureFiles.length} infrastructure file(s) - allowed for maintainer ${prAuthor}`); + } else { + core.info('No infrastructure files modified'); + } diff --git a/.github/workflows/check-linked-issue.yml b/.github/workflows/check-linked-issue.yml new file mode 100644 index 00000000..916c338a --- /dev/null +++ b/.github/workflows/check-linked-issue.yml @@ -0,0 +1,90 @@ +name: Require linked issue with community support + +on: + pull_request_target: + types: [opened, edited, synchronize, reopened] + workflow_dispatch: + +permissions: + contents: read + issues: read + pull-requests: write + +jobs: + enforce: + if: github.event_name == 'workflow_dispatch' || github.event.pull_request.draft == false + runs-on: ubuntu-latest + + steps: + - name: Verify linked issue + if: github.event_name == 'pull_request_target' + uses: nearform-actions/github-action-check-linked-issues@v1.2.7 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + comment: true + exclude-branches: main + custom-body: | + No linked issues found. Please add "Fixes #" to your pull request description. + + Per our [Contributing Guidelines](https://github.com/google/langextract/blob/main/CONTRIBUTING.md#pull-request-guidelines), all PRs must: + - Reference an issue with "Fixes #123" or "Closes #123" + - The linked issue should have 5+ ๐Ÿ‘ reactions + - Include discussion demonstrating the importance of the change + + Use GitHub automation to close the issue when this PR is merged. 
+ + - name: Check community support + if: github.event_name == 'pull_request_target' + uses: actions/github-script@v7 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + // Check if PR author is a maintainer + const prAuthor = context.payload.pull_request.user.login; + const { data: authorPermission } = await github.rest.repos.getCollaboratorPermissionLevel({ + owner: context.repo.owner, + repo: context.repo.repo, + username: prAuthor + }); + + const isMaintainer = ['admin', 'maintain'].includes(authorPermission.permission); + + const body = context.payload.pull_request.body || ''; + const match = body.match(/(?:Fixes|Closes|Resolves)\s+#(\d+)/i); + + if (!match) { + core.setFailed('No linked issue found'); + return; + } + + const issueNumber = Number(match[1]); + const { repository } = await github.graphql(` + query($owner: String!, $repo: String!, $number: Int!) { + repository(owner: $owner, name: $repo) { + issue(number: $number) { + reactionGroups { + content + users { + totalCount + } + } + } + } + } + `, { + owner: context.repo.owner, + repo: context.repo.repo, + number: issueNumber + }); + + const reactions = repository.issue.reactionGroups; + const thumbsUp = reactions.find(g => g.content === 'THUMBS_UP')?.users.totalCount || 0; + + core.info(`Issue #${issueNumber} has ${thumbsUp} ๐Ÿ‘ reactions`); + + const REQUIRED_THUMBS_UP = 5; + if (thumbsUp < REQUIRED_THUMBS_UP && !isMaintainer) { + core.setFailed(`Issue #${issueNumber} needs at least ${REQUIRED_THUMBS_UP} ๐Ÿ‘ reactions (currently has ${thumbsUp})`); + } else if (isMaintainer && thumbsUp < REQUIRED_THUMBS_UP) { + core.info(`Maintainer ${prAuthor} bypassing community support requirement (issue has ${thumbsUp} ๐Ÿ‘ reactions)`); + } diff --git a/.github/workflows/check-pr-size.yml b/.github/workflows/check-pr-size.yml new file mode 100644 index 00000000..d97e54bd --- /dev/null +++ b/.github/workflows/check-pr-size.yml @@ -0,0 +1,62 @@ +name: Check PR size + +on: + pull_request_target: + 
types: [opened, synchronize, reopened] + workflow_dispatch: + inputs: + pr_number: + description: 'PR number to check (optional)' + required: false + type: string + +permissions: + contents: read + pull-requests: write + +jobs: + size: + runs-on: ubuntu-latest + steps: + - name: Get PR data for manual trigger + if: github.event_name == 'workflow_dispatch' && github.event.inputs.pr_number + id: get_pr + uses: actions/github-script@v7 + with: + script: | + const { data: pr } = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: ${{ github.event.inputs.pr_number }} + }); + return pr; + + - name: Evaluate PR size + if: github.event_name == 'pull_request_target' || (github.event_name == 'workflow_dispatch' && github.event.inputs.pr_number) + uses: actions/github-script@v7 + with: + script: | + const pr = context.payload.pull_request || ${{ steps.get_pr.outputs.result || '{}' }}; + const totalChanges = pr.additions + pr.deletions; + + core.info(`PR contains ${pr.additions} additions and ${pr.deletions} deletions (${totalChanges} total)`); + + const sizeLabel = + totalChanges < 50 ? 'size/XS' : + totalChanges < 150 ? 'size/S' : + totalChanges < 600 ? 'size/M' : + totalChanges < 1000 ? 'size/L' : 'size/XL'; + + await github.rest.issues.addLabels({ + ...context.repo, + issue_number: pr.number, + labels: [sizeLabel] + }); + + const MAX_LINES = 1000; + if (totalChanges > MAX_LINES) { + core.setFailed( + `This PR contains ${totalChanges} lines of changes, which exceeds the maximum of ${MAX_LINES} lines. ` + + `Please split this into smaller, focused pull requests.` + ); + } diff --git a/.github/workflows/check-pr-up-to-date.yaml b/.github/workflows/check-pr-up-to-date.yaml new file mode 100644 index 00000000..59c259c0 --- /dev/null +++ b/.github/workflows/check-pr-up-to-date.yaml @@ -0,0 +1,87 @@ +# Copyright 2025 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Check PR Up-to-Date + +on: + pull_request: + types: [opened, synchronize] + +permissions: + contents: read + pull-requests: write + +jobs: + check-up-to-date: + runs-on: ubuntu-latest + # Skip for bot PRs + if: ${{ !contains(github.actor, '[bot]') }} + concurrency: + group: check-pr-${{ github.event.pull_request.number }} + cancel-in-progress: true + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 2 # Sufficient for rev-list comparison + + - name: Check if PR is up-to-date with main + id: check + run: | + # Fetch the latest main branch + git fetch origin main + + # Check how many commits behind main + BEHIND=$(git rev-list --count HEAD..origin/main) + + echo "commits_behind=$BEHIND" >> $GITHUB_OUTPUT + + if [ "$BEHIND" -gt 0 ]; then + echo "::warning::PR is $BEHIND commits behind main" + exit 0 # Don't fail the check, just warn + else + echo "PR is up-to-date with main" + fi + + - name: Comment if PR needs update + if: ${{ steps.check.outputs.commits_behind != '0' }} + uses: actions/github-script@v7 + with: + script: | + const behind = ${{ steps.check.outputs.commits_behind }}; + const COMMENT_COOLDOWN_HOURS = 24; + const COOLDOWN_MS = COMMENT_COOLDOWN_HOURS * 60 * 60 * 1000; + + // Check for recent similar comments + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: 
context.payload.pull_request.number, + per_page: 10 + }); + + const hasRecentComment = comments.some(c => + c.body?.includes('commits behind `main`') && + c.user?.login === 'github-actions[bot]' && + new Date(c.created_at) > new Date(Date.now() - COOLDOWN_MS) + ); + + if (!hasRecentComment) { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `๐Ÿ“Š **PR Status**: ${behind} commits behind \`main\`\n\nConsider updating your branch for the most accurate CI results:\n\n**Option 1**: Use GitHub's "Update branch" button (if available)\n\n**Option 2**: Update locally:\n\`\`\`bash\ngit fetch origin main\ngit merge origin/main\ngit push\n\`\`\`\n\n*Note: If you use a different remote name (e.g., upstream), adjust the commands accordingly.*\n\nThis ensures your changes are tested against the latest code.` + }); + } diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fc8a2a87..dc166fa9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -15,20 +15,72 @@ name: CI on: + workflow_dispatch: push: branches: ["main"] pull_request: branches: ["main"] + pull_request_target: + types: [labeled] permissions: contents: read + issues: write + pull-requests: write jobs: + # Validates formatting on the PR branch directly (not the merge commit) + # to ensure code style compliance before merging + format-check: + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' + steps: + - name: Checkout PR branch + uses: actions/checkout@v4 + with: + # Check the actual PR branch to catch formatting issues + ref: ${{ github.event.pull_request.head.sha }} + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + + - name: Install format tools + run: | + python -m pip install --upgrade pip + pip install tox + + - name: Check formatting + id: format-check + run: | + tox -e format + + - name: Comment on PR if 
formatting fails + if: always() && steps.format-check.outcome == 'failure' + uses: actions/github-script@v7 + with: + script: | + github.rest.issues.createComment({ + issue_number: context.payload.pull_request.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: `โŒ **Formatting Check Failed** + + Your PR has formatting issues. Please run the following command locally and push the changes: + + \`\`\`bash + ./autoformat.sh + \`\`\` + + This will automatically fix all formatting issues using pyink (Google's Python formatter) and isort.` + }) + test: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 @@ -42,6 +94,203 @@ jobs: python -m pip install --upgrade pip pip install -e ".[dev,test]" - - name: Run tox (lint + tests) + - name: Run unit tests and linting + run: | + PY_VERSION=$(echo "${{ matrix.python-version }}" | tr -d '.') + # Format check is handled by separate job for better isolation + tox -e py${PY_VERSION},lint-src,lint-tests + + live-api-tests: + needs: test + runs-on: ubuntu-latest + if: | + github.event_name == 'push' || + (github.event_name == 'pull_request' && + github.event.pull_request.head.repo.full_name == github.repository) + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev,test]" + + - name: Run live API tests + env: + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + LANGEXTRACT_API_KEY: ${{ secrets.GEMINI_API_KEY }} # For backward compatibility + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + if [[ -z "$GEMINI_API_KEY" && -z "$OPENAI_API_KEY" ]]; then + echo "::notice::Live API tests skipped - no provider secrets configured" + exit 0 + fi + tox -e live-api + + ollama-integration-test: + needs: test + runs-on: ubuntu-latest + if: 
github.event_name == 'pull_request' + + steps: + - uses: actions/checkout@v4 + + - name: Detect file changes + id: changes + uses: tj-actions/changed-files@v46 + with: + files: | + langextract/inference.py + examples/ollama/** + tests/test_ollama_integration.py + .github/workflows/ci.yaml + + - name: Skip if no Ollama changes + if: steps.changes.outputs.any_changed == 'false' + run: | + echo "No Ollama-related changes detected โ€“ skipping job." + exit 0 + + - name: Set up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: "3.11" + + - name: Launch Ollama container + run: | + docker run -d --name ollama \ + -p 127.0.0.1:11434:11434 \ + -v ollama:/root/.ollama \ + ollama/ollama:0.5.4 + for i in {1..20}; do + curl -fs http://localhost:11434/api/version && break + sleep 3 + done + + - name: Pull gemma2 model + run: docker exec ollama ollama pull gemma2:2b || true + + - name: Install tox + run: | + python -m pip install --upgrade pip + pip install tox + + - name: Run Ollama integration tests + run: tox -e ollama-integration + + test-fork-pr: + runs-on: ubuntu-latest + # Triggered when a maintainer adds 'ready-to-merge' label + if: | + github.event_name == 'pull_request_target' && + github.event.action == 'labeled' && + contains(github.event.label.name, 'ready-to-merge') + + steps: + - name: Check if user is maintainer + uses: actions/github-script@v7 + with: + script: | + const { data: permission } = await github.rest.repos.getCollaboratorPermissionLevel({ + owner: context.repo.owner, + repo: context.repo.repo, + username: context.actor + }); + + const isMaintainer = ['admin', 'maintain'].includes(permission.permission); + if (!isMaintainer) { + throw new Error(`User ${context.actor} does not have maintainer permissions.`); + } + + - name: Checkout PR branch directly + uses: actions/checkout@v4 + with: + # Validate formatting on actual PR code before running expensive tests + ref: ${{ github.event.pull_request.head.sha }} + fetch-depth: 0 + + - name: 
Set up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev,test]" + + - name: Validate PR formatting + run: | + echo "Validating code formatting..." + tox -e format || { + echo "::error::Code formatting does not meet project standards. Please run ./autoformat.sh locally and push the changes." + exit 1 + } + + - name: Checkout main branch + uses: actions/checkout@v4 + with: + ref: main + fetch-depth: 0 + + - name: Merge PR safely for testing + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + # pull_request_target runs in base repo context, so this is safe + git fetch origin pull/${{ github.event.pull_request.number }}/head:pr-to-test + git merge pr-to-test --no-ff --no-edit + + - name: Add status comment + uses: actions/github-script@v7 + with: + script: | + github.rest.issues.createComment({ + issue_number: context.payload.pull_request.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: 'Running live API tests... This will take a few minutes.' + }); + + - name: Run live API tests + env: + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + LANGEXTRACT_API_KEY: ${{ secrets.GEMINI_API_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | - tox \ No newline at end of file + if [[ -z "$GEMINI_API_KEY" && -z "$OPENAI_API_KEY" ]]; then + echo "::error::Live API tests skipped - no provider secrets configured" + exit 1 + fi + tox -e live-api + + - name: Report success + if: success() + uses: actions/github-script@v7 + with: + script: | + github.rest.issues.createComment({ + issue_number: context.payload.pull_request.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: 'โœ… Live API tests passed! All endpoints are working correctly.' 
+ }); + + - name: Report failure + if: failure() + uses: actions/github-script@v7 + with: + script: | + github.rest.issues.createComment({ + issue_number: context.payload.pull_request.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: 'โŒ Live API tests failed. Please check the workflow logs for details.' + }); diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 00000000..cb3ff700 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,55 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: Publish to PyPI + +on: + release: + types: [published] + +permissions: + contents: read + id-token: write + +jobs: + pypi-publish: + name: Publish to PyPI + runs-on: ubuntu-latest + environment: pypi + permissions: + id-token: write + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install build dependencies + run: | + python -m pip install --upgrade pip + pip install build + + - name: Build package + run: python -m build + + - name: Verify build artifacts + run: | + ls -la dist/ + pip install twine + twine check dist/* + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/revalidate-pr.yml b/.github/workflows/revalidate-pr.yml new file mode 100644 index 00000000..2b0deb8b --- /dev/null +++ b/.github/workflows/revalidate-pr.yml @@ -0,0 +1,157 @@ +name: Revalidate PR + +on: + workflow_dispatch: + inputs: + pr_number: + description: 'PR number to validate' + required: true + type: string + +permissions: + contents: read + pull-requests: write + issues: write + checks: write + statuses: write + +jobs: + revalidate: + runs-on: ubuntu-latest + steps: + - name: Get PR data + id: pr_data + uses: actions/github-script@v7 + with: + script: | + const { data: pr } = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: ${{ inputs.pr_number }} + }); + + core.info(`Validating PR #${pr.number}: ${pr.title}`); + core.info(`Author: ${pr.user.login}`); + core.info(`Changes: +${pr.additions} -${pr.deletions}`); + + // Store head SHA for creating status + core.setOutput('head_sha', pr.head.sha); + + return pr; + + - name: Create pending status + uses: actions/github-script@v7 + with: + script: | + await github.rest.repos.createCommitStatus({ + owner: context.repo.owner, + repo: context.repo.repo, + sha: '${{ steps.pr_data.outputs.head_sha }}', + state: 'pending', + context: 'Manual 
Validation', + description: 'Running validation checks...' + }); + + - name: Validate PR + id: validate + uses: actions/github-script@v7 + with: + script: | + const pr = ${{ steps.pr_data.outputs.result }}; + const errors = []; + let passed = true; + + // Check size + const totalChanges = pr.additions + pr.deletions; + const MAX_LINES = 1000; + if (totalChanges > MAX_LINES) { + errors.push(`PR size (${totalChanges} lines) exceeds ${MAX_LINES} line limit`); + passed = false; + } + + // Check template + const body = pr.body || ''; + const requiredSections = ["# Description", "Fixes #", "# How Has This Been Tested?", "# Checklist"]; + const missingSections = requiredSections.filter(section => !body.includes(section)); + + if (missingSections.length > 0) { + errors.push(`Missing PR template sections: ${missingSections.join(', ')}`); + passed = false; + } + + if (body.match(/Replace this with|Choose one:|Fixes #\[issue number\]/i)) { + errors.push('PR template contains unmodified placeholders'); + passed = false; + } + + // Check linked issue + const issueMatch = body.match(/(?:Fixes|Closes|Resolves)\s+#(\d+)/i); + if (!issueMatch) { + errors.push('No linked issue found'); + passed = false; + } + + // Store results + core.setOutput('passed', passed); + core.setOutput('errors', errors.join('; ')); + core.setOutput('totalChanges', totalChanges); + core.setOutput('hasTemplate', missingSections.length === 0); + core.setOutput('hasIssue', !!issueMatch); + + if (!passed) { + core.setFailed(errors.join('; ')); + } + + - name: Update commit status + if: always() + uses: actions/github-script@v7 + with: + script: | + const passed = ${{ steps.validate.outputs.passed }}; + const errors = '${{ steps.validate.outputs.errors }}'; + + await github.rest.repos.createCommitStatus({ + owner: context.repo.owner, + repo: context.repo.repo, + sha: '${{ steps.pr_data.outputs.head_sha }}', + state: passed ? 'success' : 'failure', + context: 'Manual Validation', + description: passed ? 
'All validation checks passed' : errors.substring(0, 140), + target_url: `https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}` + }); + + - name: Add validation comment + if: always() + uses: actions/github-script@v7 + with: + script: | + const pr = ${{ steps.pr_data.outputs.result }}; + const passed = ${{ steps.validate.outputs.passed }}; + const totalChanges = ${{ steps.validate.outputs.totalChanges }}; + const hasTemplate = ${{ steps.validate.outputs.hasTemplate }}; + const hasIssue = ${{ steps.validate.outputs.hasIssue }}; + const errors = '${{ steps.validate.outputs.errors }}'.split('; ').filter(e => e); + + let body = `### Manual Validation Results\n\n`; + body += `**Status**: ${passed ? 'โœ… Passed' : 'โŒ Failed'}\n\n`; + body += `| Check | Status | Details |\n`; + body += `|-------|--------|----------|\n`; + body += `| PR Size | ${totalChanges <= 1000 ? 'โœ…' : 'โŒ'} | ${totalChanges} lines ${totalChanges > 1000 ? '(exceeds 1000 limit)' : ''} |\n`; + body += `| Template | ${hasTemplate ? 'โœ…' : 'โŒ'} | ${hasTemplate ? 'Complete' : 'Missing required sections'} |\n`; + body += `| Linked Issue | ${hasIssue ? 'โœ…' : 'โŒ'} | ${hasIssue ? 
'Found' : 'Missing Fixes/Closes #XXX'} |\n`; + + if (errors.length > 0) { + body += `\n**Errors:**\n`; + errors.forEach(error => { + body += `- โŒ ${error}\n`; + }); + } + + body += `\n[View workflow run](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId})`; + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number, + body: body + }); diff --git a/.github/workflows/validate_pr_template.yaml b/.github/workflows/validate_pr_template.yaml new file mode 100644 index 00000000..de17835a --- /dev/null +++ b/.github/workflows/validate_pr_template.yaml @@ -0,0 +1,43 @@ +name: Validate PR template + +on: + pull_request_target: + types: [opened, edited, synchronize, reopened] + workflow_dispatch: + +permissions: + contents: read + +jobs: + check: + if: github.event_name == 'workflow_dispatch' || github.event.pull_request.draft == false # drafts can save early + runs-on: ubuntu-latest + + steps: + - name: Fail if template untouched + if: github.event_name == 'pull_request_target' + env: + PR_BODY: ${{ github.event.pull_request.body }} + run: | + printf '%s\n' "$PR_BODY" | tr -d '\r' > body.txt + + # Required sections from the template + required=( "# Description" "Fixes #" "# How Has This Been Tested?" 
"# Checklist" ) + err=0 + + # Check for required sections + for h in "${required[@]}"; do + grep -Fq "$h" body.txt || { echo "::error::$h missing"; err=1; } + done + + # Check for placeholder text that should be replaced + grep -Eiq 'Replace this with|Choose one:' body.txt && { + echo "::error::Template placeholders still present"; err=1; + } + + # Also check for the unmodified issue number placeholder + grep -Fq 'Fixes #[issue number]' body.txt && { + echo "::error::Issue number placeholder not updated"; err=1; + } + + exit $err diff --git a/.gitignore b/.gitignore index 458f449d..fc93e588 100644 --- a/.gitignore +++ b/.gitignore @@ -51,4 +51,4 @@ docs/_build/ *.swp # OS-specific -.DS_Store \ No newline at end of file +.DS_Store diff --git a/.hgignore b/.hgignore index 4ef06c6c..3fb66f47 100644 --- a/.hgignore +++ b/.hgignore @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -gdm/codeai/codemind/cli/GEMINI.md \ No newline at end of file +gdm/codeai/codemind/cli/GEMINI.md diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..84410316 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,46 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Pre-commit hooks for LangExtract +# Install with: pre-commit install +# Run manually: pre-commit run --all-files + +repos: + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (import sorting) + # Configuration is in pyproject.toml + + - repo: https://github.com/google/pyink + rev: 24.3.0 + hooks: + - id: pyink + name: pyink (Google's Black fork) + args: ["--config", "pyproject.toml"] + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: end-of-file-fixer + exclude: \.gif$|\.svg$ + - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: check-merge-conflict + - id: check-case-conflict + - id: mixed-line-ending + args: ['--fix=lf'] diff --git a/.pylintrc b/.pylintrc index 5709bc73..2e09c87f 100644 --- a/.pylintrc +++ b/.pylintrc @@ -14,10 +14,418 @@ [MASTER] +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=0 + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +# Note: These plugins require Pylint >= 3.0 +load-plugins= + pylint.extensions.docparams, + pylint.extensions.typing + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + [MESSAGES CONTROL] -disable=all -enable=F + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. 
+enable= + useless-suppression + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). +disable= + abstract-method, # Protocol/ABC classes often have abstract methods + too-few-public-methods, # Valid for data classes with minimal interface + fixme, # TODO/FIXME comments are useful for tracking work + # --- Code style and formatting --- + line-too-long, # Handled by pyink formatter + bad-indentation, # Pyink uses 2-space indentation + # --- Design complexity --- + too-many-positional-arguments, + too-many-locals, + too-many-arguments, + too-many-branches, + too-many-statements, + too-many-nested-blocks, + # --- Style preferences --- + no-else-return, + no-else-raise, + # --- Documentation --- + missing-function-docstring, + missing-class-docstring, + missing-raises-doc, + # --- Gradual improvements --- + deprecated-typing-alias, # For typing.Type etc. + unspecified-encoding, + unused-import + [REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. output-format=text -reports=no \ No newline at end of file + +# Tells whether to display a full report or only the messages +reports=no + +# Activate the evaluation score. +score=no + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Naming style matching correct attribute names. 
+attr-naming-style=snake_case + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo,bar,baz,toto,tutu,tata + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Good variable names which should always be accepted, separated by a comma. +good-names=i,j,k,ex,Run,_,id,ok + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Naming style matching correct variable names. +variable-naming-style=snake_case + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 
+expected-line-ending-format=LF
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=2
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string="  "
+
+# Maximum number of characters on a single line.
+max-line-length=80
+
+# Maximum number of lines in a module.
+max-module-lines=2000
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,XXX,TODO
+
+
+[SIMILARITIES]
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it working
+# install python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains private dictionary; one word per line.
+spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,dataclasses.InitVar,typing.Any + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. 
It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules=dotenv,absl,more_itertools,pandas,requests,pydantic,yaml,IPython.display, + tqdm,numpy,google,langfun,typing_extensions + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. 
+defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=cls + + +[DESIGN] + +# Maximum number of arguments for function / method. +max-args=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=10 + +# Maximum number of boolean expressions in an if statement. +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=0 + + +[IMPORTS] + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=yes + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled). +ext-import-graph= + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled). 
+import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant,numpy,pandas,torch,langfun,pyglove + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "BaseException, Exception". +overgeneral-exceptions=BaseException, + Exception diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 00000000..2eb3134a --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 +# +# This file contains citation metadata for LangExtract. +# For more information visit: https://citation-file-format.github.io/ + +cff-version: 1.2.0 +title: "LangExtract" +message: "If you use this software, please cite it as below." +type: software +authors: + - given-names: Akshay + family-names: Goel + email: goelak@google.com + affiliation: Google LLC +repository-code: "https://github.com/google/langextract" +url: "https://github.com/google/langextract" +repository: "https://github.com/google/langextract" +abstract: "LangExtract: A library for extracting structured data from language models" +keywords: + - language-models + - structured-data-extraction + - nlp + - machine-learning + - python +license: Apache-2.0 +version: 1.0.3 +date-released: 2025-07-30 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 724ff7f6..aa5038d4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,13 +23,117 @@ sign a new one. 
This project follows HAI-DEF's [Community guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines) -## Contribution process +## Reporting Issues -### Code Reviews +If you encounter a bug or have a feature request, please open an issue on GitHub. +We have templates to help guide you: + +- **[Bug Report](.github/ISSUE_TEMPLATE/1-bug.md)**: For reporting bugs or unexpected behavior +- **[Feature Request](.github/ISSUE_TEMPLATE/2-feature-request.md)**: For suggesting new features or improvements + +When creating an issue, GitHub will prompt you to choose the appropriate template. +Please provide as much detail as possible to help us understand and address your concern. + +## Contribution Process + +### 1. Development Setup + +To get started, clone the repository and install the necessary dependencies for development and testing. Detailed instructions can be found in the [Installation from Source](https://github.com/google/langextract#from-source) section of the `README.md`. + +**Windows Users**: The formatting scripts use bash. Please use one of: +- Git Bash (comes with Git for Windows) +- WSL (Windows Subsystem for Linux) +- PowerShell with bash-compatible commands + +### 2. Code Style and Formatting + +This project uses automated tools to maintain a consistent code style. Before submitting a pull request, please format your code: + +```bash +# Run the auto-formatter +./autoformat.sh +``` + +This script uses: +- `isort` to organize imports with Google style (single-line imports) +- `pyink` (Google's fork of Black) to format code according to Google's Python Style Guide + +You can also run the formatters manually: +```bash +isort langextract tests +pyink langextract tests --config pyproject.toml +``` + +Note: The formatters target only `langextract` and `tests` directories by default to avoid +formatting virtual environments or other non-source directories. + +### 3. 
Pre-commit Hooks (Recommended) + +For automatic formatting checks before each commit: + +```bash +# Install pre-commit +pip install pre-commit + +# Install the git hooks +pre-commit install + +# Run manually on all files +pre-commit run --all-files +``` + +### 4. Linting and Testing + +All contributions must pass linting checks and unit tests. Please run these locally before submitting your changes: + +```bash +# Run linting with Pylint 3.x +pylint --rcfile=.pylintrc langextract tests + +# Run tests +pytest tests +``` + +**Note on Pylint Configuration**: We use a modern, minimal configuration that: +- Only disables truly noisy checks (not entire categories) +- Keeps critical error detection enabled +- Uses plugins for enhanced docstring and type checking +- Aligns with our pyink formatter (80-char lines, 2-space indents) + +For full testing across Python versions: +```bash +tox # runs pylint + pytest on Python 3.10 and 3.11 +``` + +### 5. Submit Your Pull Request All submissions, including submissions by project members, require review. We use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests) for this purpose. +When you create a pull request, GitHub will automatically populate it with our +[pull request template](.github/PULL_REQUEST_TEMPLATE/pull_request_template.md). +Please fill out all sections of the template to help reviewers understand your changes. + +#### Pull Request Guidelines + +- **Keep PRs focused and small**: Each PR should address a single issue and contain one cohesive change. 
PRs are automatically labeled by size to help reviewers: + - **size/XS**: < 50 lines โ€” Small fixes and documentation updates + - **size/S**: 50-150 lines โ€” Typical features or bug fixes + - **size/M**: 150-600 lines โ€” Larger features that remain well-scoped + - **size/L**: 600-1000 lines โ€” Consider splitting into smaller PRs if possible + - **size/XL**: > 1000 lines โ€” Requires strong justification and may need special review +- **Reference related issues**: All PRs must include "Fixes #123" or "Closes #123" in the description. The linked issue should have at least 5 ๐Ÿ‘ reactions from the community and include discussion that demonstrates the importance and need for the change. +- **No infrastructure changes**: Contributors cannot modify infrastructure files, build configuration, and core documentation. These files are protected and can only be changed by maintainers. Use `./autoformat.sh` to format code without affecting infrastructure files. In special circumstances, build configuration updates may be considered if they include discussion and evidence of robust testing, ideally with community support. +- **Single-change commits**: A PR should typically comprise a single git commit. Squash multiple commits before submitting. +- **Clear description**: Explain what your change does and why it's needed. +- **Ensure all tests pass**: Check that both formatting and tests are green before requesting review. +- **Respond to feedback promptly**: Address reviewer comments in a timely manner. 
+ +If your change is large or complex, consider: +- Opening an issue first to discuss the approach +- Breaking it into multiple smaller PRs +- Clearly explaining in the PR description why a larger change is necessary + For more details, read HAI-DEF's [Contributing guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines#contributing) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..e8a74312 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,11 @@ +# Production Dockerfile for LangExtract +FROM python:3.10-slim + +# Set working directory +WORKDIR /app + +# Install LangExtract from PyPI +RUN pip install --no-cache-dir langextract + +# Set default command +CMD ["python"] diff --git a/README.md b/README.md index b49d4b6a..2cbe5820 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@

- LangExtract Logo + LangExtract Logo

# LangExtract -[![PyPI version](https://badge.fury.io/py/langextract.svg)](https://badge.fury.io/py/langextract) +[![PyPI version](https://img.shields.io/pypi/v/langextract.svg)](https://pypi.org/project/langextract/) [![GitHub stars](https://img.shields.io/github/stars/google/langextract.svg?style=social&label=Star)](https://github.com/google/langextract) ![Tests](https://github.com/google/langextract/actions/workflows/ci.yaml/badge.svg) @@ -17,6 +17,8 @@ - [Quick Start](#quick-start) - [Installation](#installation) - [API Key Setup for Cloud Models](#api-key-setup-for-cloud-models) +- [Using OpenAI Models](#using-openai-models) +- [Using Local LLMs with Ollama](#using-local-llms-with-ollama) - [More Examples](#more-examples) - [*Romeo and Juliet* Full Text Extraction](#romeo-and-juliet-full-text-extraction) - [Medication Extraction](#medication-extraction) @@ -111,7 +113,7 @@ The extractions can be saved to a `.jsonl` file, a popular format for working wi ```python # Save the results to a JSONL file -lx.io.save_annotated_documents([result], output_name="extraction_results.jsonl") +lx.io.save_annotated_documents([result], output_name="extraction_results.jsonl", output_dir=".") # Generate the visualization from the file html_content = lx.visualize("extraction_results.jsonl") @@ -121,7 +123,7 @@ with open("visualization.html", "w") as f: This creates an animated and interactive HTML file: -![Romeo and Juliet Basic Visualization ](docs/_static/romeo_juliet_basic.gif) +![Romeo and Juliet Basic Visualization ](https://raw.githubusercontent.com/google/langextract/main/docs/_static/romeo_juliet_basic.gif) > **Note on LLM Knowledge Utilization:** This example demonstrates extractions that stay close to the text evidence - extracting "longing" for Lady Juliet's emotional state and identifying "yearning" from "gazed longingly at the stars." 
The task could be modified to generate attributes that draw more heavily from the LLM's world knowledge (e.g., adding `"identity": "Capulet family daughter"` or `"literary_context": "tragic heroine"`). The balance between text-evidence and knowledge-inference is controlled by your prompt instructions and example attributes. @@ -142,7 +144,7 @@ result = lx.extract( ) ``` -This approach can extract hundreds of entities from full novels while maintaining high accuracy. The interactive visualization seamlessly handles large result sets, making it easy to explore hundreds of entities from the output JSONL file. **[See the full *Romeo and Juliet* extraction example โ†’](docs/examples/longer_text_example.md)** for detailed results and performance insights. +This approach can extract hundreds of entities from full novels while maintaining high accuracy. The interactive visualization seamlessly handles large result sets, making it easy to explore hundreds of entities from the output JSONL file. **[See the full *Romeo and Juliet* extraction example โ†’](https://github.com/google/langextract/blob/main/docs/examples/longer_text_example.md)** for detailed results and performance insights. ## Installation @@ -181,10 +183,16 @@ pip install -e ".[dev]" pip install -e ".[test]" ``` +### Docker + +```bash +docker build -t langextract . +docker run --rm -e LANGEXTRACT_API_KEY="your-api-key" langextract python your_script.py +``` ## API Key Setup for Cloud Models -When using LangExtract with cloud-hosted models (like Gemini), you'll need to +When using LangExtract with cloud-hosted models (like Gemini or OpenAI), you'll need to set up an API key. On-device models don't require an API key. For developers using local LLMs, LangExtract offers built-in support for Ollama and can be extended to other third-party APIs by updating the inference endpoints. 
@@ -195,6 +203,7 @@ Get API keys from: * [AI Studio](https://aistudio.google.com/app/apikey) for Gemini models * [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/sdks/overview) for enterprise use +* [OpenAI Platform](https://platform.openai.com/api-keys) for OpenAI models ### Setting up API key in your environment @@ -244,6 +253,50 @@ result = lx.extract( ) ``` +## Using OpenAI Models + +LangExtract also supports OpenAI models. Example OpenAI configuration: + +```python +import langextract as lx + +result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=lx.inference.OpenAILanguageModel, + model_id="gpt-4o", + api_key=os.environ.get('OPENAI_API_KEY'), + fence_output=True, + use_schema_constraints=False +) +``` + +Note: OpenAI models require `fence_output=True` and `use_schema_constraints=False` because LangExtract doesn't implement schema constraints for OpenAI yet. + +## Using Local LLMs with Ollama + +LangExtract supports local inference using Ollama, allowing you to run models without API keys: + +```python +import langextract as lx + +result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=lx.inference.OllamaLanguageModel, + model_id="gemma2:2b", # or any Ollama model + model_url="http://localhost:11434", + fence_output=False, + use_schema_constraints=False +) +``` + +**Quick setup:** Install Ollama from [ollama.com](https://ollama.com/), run `ollama pull gemma2:2b`, then `ollama serve`. + +For detailed installation, Docker setup, and examples, see [`examples/ollama/`](examples/ollama/). + ## More Examples Additional examples of LangExtract in action: @@ -252,7 +305,7 @@ Additional examples of LangExtract in action: LangExtract can process complete documents directly from URLs. 
This example demonstrates extraction from the full text of *Romeo and Juliet* from Project Gutenberg (147,843 characters), showing parallel processing, sequential extraction passes, and performance optimization for long document processing. -**[View *Romeo and Juliet* Full Text Example โ†’](docs/examples/longer_text_example.md)** +**[View *Romeo and Juliet* Full Text Example โ†’](https://github.com/google/langextract/blob/main/docs/examples/longer_text_example.md)** ### Medication Extraction @@ -260,7 +313,7 @@ LangExtract can process complete documents directly from URLs. This example demo LangExtract excels at extracting structured medical information from clinical text. These examples demonstrate both basic entity recognition (medication names, dosages, routes) and relationship extraction (connecting medications to their attributes), showing LangExtract's effectiveness for healthcare applications. -**[View Medication Examples โ†’](docs/examples/medication_examples.md)** +**[View Medication Examples โ†’](https://github.com/google/langextract/blob/main/docs/examples/medication_examples.md)** ### Radiology Report Structuring: RadExtract @@ -270,7 +323,7 @@ Explore RadExtract, a live interactive demo on HuggingFace Spaces that shows how ## Contributing -Contributions are welcome! See [CONTRIBUTING.md](CONTRIBUTING.md) to get started +Contributions are welcome! See [CONTRIBUTING.md](https://github.com/google/langextract/blob/main/CONTRIBUTING.md) to get started with development, testing, and pull requests. You must sign a [Contributor License Agreement](https://cla.developers.google.com/about) before submitting patches. 
@@ -297,14 +350,58 @@ Or reproduce the full CI matrix locally with tox: tox # runs pylint + pytest on Python 3.10 and 3.11 ``` +### Ollama Integration Testing + +If you have Ollama installed locally, you can run integration tests: + +```bash +# Test Ollama integration (requires Ollama running with gemma2:2b model) +tox -e ollama-integration +``` + +This test will automatically detect if Ollama is available and run real inference tests. + +## Development + +### Code Formatting + +This project uses automated formatting tools to maintain consistent code style: + +```bash +# Auto-format all code +./autoformat.sh + +# Or run formatters separately +isort langextract tests --profile google --line-length 80 +pyink langextract tests --config pyproject.toml +``` + +### Pre-commit Hooks + +For automatic formatting checks: +```bash +pre-commit install # One-time setup +pre-commit run --all-files # Manual run +``` + +### Linting + +Run linting before submitting PRs: + +```bash +pylint --rcfile=.pylintrc langextract tests +``` + +See [CONTRIBUTING.md](CONTRIBUTING.md) for full development guidelines. + ## Disclaimer This is not an officially supported Google product. If you use LangExtract in production or publications, please cite accordingly and -acknowledge usage. Use is subject to the [Apache 2.0 License](LICENSE). +acknowledge usage. Use is subject to the [Apache 2.0 License](https://github.com/google/langextract/blob/main/LICENSE). For health-related applications, use of LangExtract is also subject to the [Health AI Developer Foundations Terms of Use](https://developers.google.com/health-ai-developer-foundations/terms). 
--- -**Happy Extracting!** \ No newline at end of file +**Happy Extracting!** diff --git a/autoformat.sh b/autoformat.sh new file mode 100755 index 00000000..5b7b1897 --- /dev/null +++ b/autoformat.sh @@ -0,0 +1,125 @@ +#!/bin/bash +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Autoformat LangExtract codebase +# +# Usage: ./autoformat.sh [target_directory ...] +# If no target is specified, formats the current directory +# +# This script runs: +# 1. isort for import sorting +# 2. pyink (Google's Black fork) for code formatting +# 3. pre-commit hooks for additional formatting (trailing whitespace, end-of-file, etc.) + +set -e + +echo "LangExtract Auto-formatter" +echo "==========================" +echo + +# Check for required tools +check_tool() { + if ! command -v "$1" &> /dev/null; then + echo "Error: $1 not found. Please install with: pip install $1" + exit 1 + fi +} + +check_tool "isort" +check_tool "pyink" +check_tool "pre-commit" + +# Parse command line arguments +show_usage() { + echo "Usage: $0 [target_directory ...]" + echo + echo "Formats Python code using isort and pyink according to Google style." 
+ echo + echo "Arguments:" + echo " target_directory One or more directories to format (default: langextract tests)" + echo + echo "Examples:" + echo " $0 # Format langextract and tests directories" + echo " $0 langextract # Format only langextract directory" + echo " $0 src tests # Format multiple specific directories" +} + +# Check for help flag +if [ "$1" = "-h" ] || [ "$1" = "--help" ]; then + show_usage + exit 0 +fi + +# Determine target directories +if [ $# -eq 0 ]; then + TARGETS="langextract tests" + echo "No target specified. Formatting default directories: langextract tests" +else + TARGETS="$@" + echo "Formatting targets: $TARGETS" +fi + +# Find pyproject.toml relative to script location +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +CONFIG_FILE="${SCRIPT_DIR}/pyproject.toml" + +if [ ! -f "$CONFIG_FILE" ]; then + echo "Warning: pyproject.toml not found at ${CONFIG_FILE}" + echo "Using default configuration." + CONFIG_ARG="" +else + CONFIG_ARG="--config $CONFIG_FILE" +fi + +echo + +# Run isort +echo "Running isort to organize imports..." +if isort $TARGETS; then + echo "Import sorting complete" +else + echo "Import sorting failed" + exit 1 +fi + +echo + +# Run pyink +echo "Running pyink to format code (Google style, 80 chars)..." +if pyink $TARGETS $CONFIG_ARG; then + echo "Code formatting complete" +else + echo "Code formatting failed" + exit 1 +fi + +echo + +# Run pre-commit hooks for additional formatting +echo "Running pre-commit hooks for additional formatting..." +if pre-commit run --all-files; then + echo "Pre-commit hooks passed" +else + echo "Pre-commit hooks made changes - please review" + # Exit with success since formatting was applied + exit 0 +fi + +echo +echo "All formatting complete!" 
+echo +echo "Next steps:" +echo " - Run: pylint --rcfile=${SCRIPT_DIR}/.pylintrc $TARGETS" +echo " - Commit your changes" diff --git a/commit_message.txt b/commit_message.txt new file mode 100644 index 00000000..9070dadb --- /dev/null +++ b/commit_message.txt @@ -0,0 +1 @@ +Fix: Resolve merge conflict and update docstrings in inference.py \ No newline at end of file diff --git a/docs/examples/longer_text_example.md b/docs/examples/longer_text_example.md index 62d1ff39..5adb4c06 100644 --- a/docs/examples/longer_text_example.md +++ b/docs/examples/longer_text_example.md @@ -76,7 +76,7 @@ result = lx.extract( print(f"Extracted {len(result.extractions)} entities from {len(result.text):,} characters") # Save and visualize the results -lx.io.save_annotated_documents([result], output_name="romeo_juliet_extractions.jsonl") +lx.io.save_annotated_documents([result], output_name="romeo_juliet_extractions.jsonl", output_dir=".") # Generate the interactive visualization html_content = lx.visualize("romeo_juliet_extractions.jsonl") @@ -171,4 +171,4 @@ LangExtract combines precise text positioning with world knowledge enrichment, e --- -ยน Models like Gemini 1.5 Pro show strong performance on many benchmarks, but [needle-in-a-haystack tests](https://cloud.google.com/blog/products/ai-machine-learning/the-needle-in-the-haystack-test-and-how-gemini-pro-solves-it) across million-token contexts indicate that performance can vary in multi-fact retrieval scenarios. This demonstrates how LangExtract's smaller context windows approach ensures consistently high quality across entire documents by avoiding the complexity and potential degradation of massive single-context processing. 
\ No newline at end of file
+ยน Models like Gemini 1.5 Pro show strong performance on many benchmarks, but [needle-in-a-haystack tests](https://cloud.google.com/blog/products/ai-machine-learning/the-needle-in-the-haystack-test-and-how-gemini-pro-solves-it) across million-token contexts indicate that performance can vary in multi-fact retrieval scenarios. This demonstrates how LangExtract's smaller context windows approach ensures consistently high quality across entire documents by avoiding the complexity and potential degradation of massive single-context processing.
diff --git a/docs/examples/medication_examples.md b/docs/examples/medication_examples.md
index 7fb27b11..d6474964 100644
--- a/docs/examples/medication_examples.md
+++ b/docs/examples/medication_examples.md
@@ -62,7 +62,7 @@ for entity in result.extractions:
   print(f"โ€ข {entity.extraction_class.capitalize()}: {entity.extraction_text}{position_info}")
 
 # Save and visualize the results
-lx.io.save_annotated_documents([result], output_name="medical_ner_extraction.jsonl")
+lx.io.save_annotated_documents([result], output_name="medical_ner_extraction.jsonl", output_dir=".")
 
 # Generate the interactive visualization
 html_content = lx.visualize("medical_ner_extraction.jsonl")
@@ -193,7 +193,11 @@ for med_name, extractions in medication_groups.items():
     print(f"  โ€ข {extraction.extraction_class.capitalize()}: {extraction.extraction_text}{position_info}")
 
 # Save and visualize the results
-lx.io.save_annotated_documents([result], output_name="medical_relationship_extraction.jsonl")
+lx.io.save_annotated_documents(
+    [result],
+    output_name="medical_relationship_extraction.jsonl",
+    output_dir="."
+) # Generate the interactive visualization html_content = lx.visualize("medical_relationship_extraction.jsonl") @@ -239,4 +243,4 @@ This example demonstrates how attributes enable efficient relationship extractio - **Relationship Extraction**: Groups related entities using attributes - **Position Tracking**: Records exact positions of extracted entities in the source text - **Structured Output**: Organizes information in a format suitable for healthcare applications -- **Interactive Visualization**: Generates HTML visualizations for exploring complex medical extractions with entity groupings and relationships clearly displayed \ No newline at end of file +- **Interactive Visualization**: Generates HTML visualizations for exploring complex medical extractions with entity groupings and relationships clearly displayed diff --git a/examples/ollama/.dockerignore b/examples/ollama/.dockerignore new file mode 100644 index 00000000..77374252 --- /dev/null +++ b/examples/ollama/.dockerignore @@ -0,0 +1,35 @@ +# Ignore Python cache +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python + +# Ignore version control +.git/ +.gitignore + +# Ignore OS files +.DS_Store +Thumbs.db + +# Ignore virtual environments +venv/ +env/ +.venv/ + +# Ignore IDE files +.vscode/ +.idea/ +*.swp +*.swo + +# Ignore test artifacts +.pytest_cache/ +.coverage +htmlcov/ + +# Ignore build artifacts +build/ +dist/ +*.egg-info/ diff --git a/examples/ollama/Dockerfile b/examples/ollama/Dockerfile new file mode 100644 index 00000000..48690a6a --- /dev/null +++ b/examples/ollama/Dockerfile @@ -0,0 +1,23 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM python:3.11-slim-bookworm + +WORKDIR /app + +RUN pip install langextract + +COPY quickstart.py . + +CMD ["python", "quickstart.py"] diff --git a/examples/ollama/README.md b/examples/ollama/README.md new file mode 100644 index 00000000..2fff8593 --- /dev/null +++ b/examples/ollama/README.md @@ -0,0 +1,32 @@ +# Ollama Examples + +This directory contains examples for using LangExtract with Ollama for local LLM inference. + +For setup instructions and documentation, see the [main README's Ollama section](../../README.md#using-local-llms-with-ollama). + +## Quick Reference + +**Local setup:** +```bash +ollama pull gemma2:2b +python quickstart.py +``` + +**Docker setup:** +```bash +docker-compose up +``` + +## Files + +- `quickstart.py` - Basic extraction example with configurable model +- `docker-compose.yml` - Production-ready Docker setup with health checks +- `Dockerfile` - Container definition for LangExtract + +## Model License + +Ollama models come with their own licenses. For example: +- Gemma models: [Gemma Terms of Use](https://ai.google.dev/gemma/terms) +- Llama models: [Meta Llama License](https://llama.meta.com/llama-downloads/) + +Please review the license for any model you use. diff --git a/examples/ollama/docker-compose.yml b/examples/ollama/docker-compose.yml new file mode 100644 index 00000000..431765ea --- /dev/null +++ b/examples/ollama/docker-compose.yml @@ -0,0 +1,42 @@ +# Copyright 2025 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +services: + ollama: + image: ollama/ollama:0.5.4 + ports: + - "127.0.0.1:11434:11434" # Bind only to localhost for security + volumes: + - ollama-data:/root/.ollama # Cross-platform support + command: serve + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:11434/api/version"] + interval: 5s + timeout: 3s + retries: 5 + start_period: 10s + + langextract: + build: . + depends_on: + ollama: + condition: service_healthy + environment: + - OLLAMA_HOST=http://ollama:11434 + volumes: + - .:/app + command: python quickstart.py + +volumes: + ollama-data: diff --git a/examples/ollama/quickstart.py b/examples/ollama/quickstart.py new file mode 100644 index 00000000..ed578412 --- /dev/null +++ b/examples/ollama/quickstart.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Quick-start example for using Ollama with langextract.""" + +import argparse +import os + +import langextract as lx + + +def run_extraction(model_id="gemma2:2b", temperature=0.3): + """Run a simple extraction example using Ollama.""" + input_text = "Isaac Asimov was a prolific science fiction writer." + + prompt = "Extract the author's full name and their primary literary genre." + + examples = [ + lx.data.ExampleData( + text=( + "J.R.R. Tolkien was an English writer, best known for" + " high-fantasy." + ), + extractions=[ + lx.data.Extraction( + extraction_class="author_details", + # extraction_text includes full context with ellipsis for clarity + extraction_text="J.R.R. Tolkien was an English writer...", + attributes={ + "name": "J.R.R. Tolkien", + "genre": "high-fantasy", + }, + ) + ], + ) + ] + + result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=lx.inference.OllamaLanguageModel, + model_id=model_id, + model_url=os.getenv("OLLAMA_HOST", "http://localhost:11434"), + temperature=temperature, + fence_output=False, + use_schema_constraints=False, + ) + + return result + + +def main(): + """Main function to run the quick-start example.""" + parser = argparse.ArgumentParser(description="Run Ollama extraction example") + parser.add_argument( + "--model-id", + default=os.getenv("MODEL_ID", "gemma2:2b"), + help="Ollama model ID (default: gemma2:2b or MODEL_ID env var)", + ) + parser.add_argument( + "--temperature", + type=float, + default=float(os.getenv("TEMPERATURE", "0.3")), + help="Model temperature (default: 0.3 or TEMPERATURE env var)", + ) + args = parser.parse_args() + + print(f"๐Ÿš€ Running Ollama quick-start example with {args.model_id}...") + print("-" * 50) + + try: + result = run_extraction( + model_id=args.model_id, temperature=args.temperature + ) + + for extraction in result.extractions: + print(f"Class: {extraction.extraction_class}") + print(f"Text: 
{extraction.extraction_text}") + print(f"Attributes: {extraction.attributes}") + + print("\nโœ… SUCCESS! Ollama is working with langextract") + return True + + except ConnectionError as e: + print(f"\nConnectionError: {e}") + print("Make sure Ollama is running: 'ollama serve'") + return False + except Exception as e: + print(f"\nError: {type(e).__name__}: {e}") + return False + + +if __name__ == "__main__": + success = main() + exit(0 if success else 1) diff --git a/exceptions.py b/exceptions.py new file mode 100644 index 00000000..0199da56 --- /dev/null +++ b/exceptions.py @@ -0,0 +1,30 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base exceptions for LangExtract. + +This module defines the base exception class that all LangExtract exceptions +inherit from. Individual modules define their own specific exceptions. +""" + +__all__ = ["LangExtractError"] + + +class LangExtractError(Exception): + """Base exception for all LangExtract errors. + + All exceptions raised by LangExtract should inherit from this class. + This allows users to catch all LangExtract-specific errors with a single + except clause. 
+ """ diff --git a/kokoro/presubmit.cfg b/kokoro/presubmit.cfg index 6d821424..746c6d14 100644 --- a/kokoro/presubmit.cfg +++ b/kokoro/presubmit.cfg @@ -28,4 +28,4 @@ container_properties { xunit_test_results { target_name: "pytest_results" result_xml_path: "git/repo/pytest_results/test.xml" -} \ No newline at end of file +} diff --git a/kokoro/test.sh b/kokoro/test.sh index ba75ace2..87134817 100644 --- a/kokoro/test.sh +++ b/kokoro/test.sh @@ -103,4 +103,4 @@ deactivate echo "=========================================" echo "Kokoro test script for langextract finished successfully." -echo "=========================================" \ No newline at end of file +echo "=========================================" diff --git a/langextract/__init__.py b/langextract/__init__.py index 817e04f2..0eac21d0 100644 --- a/langextract/__init__.py +++ b/langextract/__init__.py @@ -18,13 +18,14 @@ from collections.abc import Iterable, Sequence import os -from typing import Any, Type, TypeVar, cast +from typing import Any, cast, Type, TypeVar import warnings import dotenv from langextract import annotation from langextract import data +from langextract import exceptions from langextract import inference from langextract import io from langextract import prompting @@ -32,6 +33,19 @@ from langextract import schema from langextract import visualization +__all__ = [ + "extract", + "visualize", + "annotation", + "data", + "exceptions", + "inference", + "io", + "prompting", + "resolver", + "schema", + "visualization", +] LanguageModelT = TypeVar("LanguageModelT", bound=inference.BaseLanguageModel) @@ -241,4 +255,4 @@ def extract( batch_length=batch_length, debug=debug, extraction_passes=extraction_passes, - ) + ) \ No newline at end of file diff --git a/langextract/annotation.py b/langextract/annotation.py index fe3b5a54..a370be9e 100644 --- a/langextract/annotation.py +++ b/langextract/annotation.py @@ -31,6 +31,7 @@ from langextract import chunking from langextract import data +from 
langextract import exceptions from langextract import inference from langextract import progress from langextract import prompting @@ -39,7 +40,7 @@ ATTRIBUTE_SUFFIX = "_attributes" -class DocumentRepeatError(Exception): +class DocumentRepeatError(exceptions.LangExtractError): """Exception raised when identical document ids are present.""" diff --git a/langextract/chunking.py b/langextract/chunking.py index 3625d7a1..abddf8f7 100644 --- a/langextract/chunking.py +++ b/langextract/chunking.py @@ -28,10 +28,11 @@ import more_itertools from langextract import data +from langextract import exceptions from langextract import tokenizer -class TokenUtilError(Exception): +class TokenUtilError(exceptions.LangExtractError): """Error raised when token_util returns unexpected values.""" @@ -450,7 +451,9 @@ def __next__(self) -> TextChunk: curr_chunk.start_index, token_index + 1 ) if self._tokens_exceed_buffer(test_chunk): - if start_of_new_line > 0: + # Only break at newline if: 1) newline exists (> 0) and + # 2) it's after chunk start (prevents empty intervals) + if start_of_new_line > 0 and start_of_new_line > curr_chunk.start_index: # Terminate the curr_chunk at the start of the most recent newline. curr_chunk = create_token_interval( curr_chunk.start_index, start_of_new_line diff --git a/langextract/exceptions.py b/langextract/exceptions.py new file mode 100644 index 00000000..b3103ab7 --- /dev/null +++ b/langextract/exceptions.py @@ -0,0 +1,26 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base exceptions for LangExtract.""" + +__all__ = ["LangExtractError"] + + +class LangExtractError(Exception): + """Base exception for all LangExtract errors. + + All exceptions raised by LangExtract should inherit from this class. + This allows users to catch all LangExtract-specific errors with a single + except clause. + """ diff --git a/langextract/inference.py b/langextract/inference.py index 822661fd..b2f2785e 100644 --- a/langextract/inference.py +++ b/langextract/inference.py @@ -15,25 +15,23 @@ """Simple library for performing language model inference.""" import abc -from collections.abc import Iterator, Mapping, Sequence import concurrent.futures import dataclasses import enum import json import textwrap +from collections.abc import Iterator, Mapping, Sequence from typing import Any +from urllib.parse import urljoin, urlparse -from google import genai import langfun as lf +import openai import requests -from typing_extensions import override import yaml +from google import genai +from typing_extensions import override - - -from langextract import data -from langextract import schema - +from langextract import data, exceptions, schema _OLLAMA_DEFAULT_MODEL_URL = 'http://localhost:11434' @@ -52,7 +50,7 @@ def __str__(self) -> str: return f'Score: {self.score:.2f}\nOutput:\n{formatted_lines}' -class InferenceOutputError(Exception): +class InferenceOutputError(exceptions.LangExtractError): """Exception raised when no scored outputs are available from the language model.""" def __init__(self, message: str): @@ -64,7 +62,7 @@ class BaseLanguageModel(abc.ABC): """An abstract inference class for managing LLM inference. Attributes: - _constraint: A `Constraint` object specifying constraints for model output. + _constraint: A Constraint object specifying constraints for model output. 
""" def __init__(self, constraint: schema.Constraint = schema.Constraint()): @@ -99,49 +97,6 @@ class InferenceType(enum.Enum): MULTIPROCESS = 'multiprocess' -# TODO: Add support for llm options. -@dataclasses.dataclass(init=False) -class LangFunLanguageModel(BaseLanguageModel): - """Language model inference class using LangFun language class. - - See https://github.com/google/langfun for more details on LangFun. - """ - - _lm: lf.core.language_model.LanguageModel # underlying LangFun model - _constraint: schema.Constraint = dataclasses.field( - default_factory=schema.Constraint, repr=False, compare=False - ) - _extra_kwargs: dict[str, Any] = dataclasses.field( - default_factory=dict, repr=False, compare=False - ) - - def __init__( - self, - language_model: lf.core.language_model.LanguageModel, - constraint: schema.Constraint = schema.Constraint(), - **kwargs, - ) -> None: - self._lm = language_model - self._constraint = constraint - - # Preserve any unused kwargs for debugging / future use - self._extra_kwargs = kwargs or {} - super().__init__(constraint=constraint) - - @override - def infer( - self, batch_prompts: Sequence[str], **kwargs - ) -> Iterator[Sequence[ScoredOutput]]: - responses = self._lm.sample(prompts=batch_prompts) - for a_response in responses: - for sample in a_response.samples: - yield [ - ScoredOutput( - score=sample.response.score, output=sample.response.text - ) - ] - - @dataclasses.dataclass(init=False) class OllamaLanguageModel(BaseLanguageModel): """Language model inference class using Ollama based host.""" @@ -158,13 +113,13 @@ class OllamaLanguageModel(BaseLanguageModel): def __init__( self, - model: str, + model_id: str, model_url: str = _OLLAMA_DEFAULT_MODEL_URL, structured_output_format: str = 'json', constraint: schema.Constraint = schema.Constraint(), **kwargs, ) -> None: - self._model = model + self._model = model_id self._model_url = model_url self._structured_output_format = structured_output_format self._constraint = constraint 
@@ -204,7 +159,7 @@ def _ollama_query(
 ) -> Mapping[str, Any]:
   """Sends a prompt to an Ollama model and returns the generated response.
 
-  This function makes an HTTP POST request to the `/api/generate` endpoint of
+  This function makes an HTTP POST request to the /api/generate endpoint of
   an Ollama server. It can optionally load the specified model first, generate
   a response (with or without streaming), then return a parsed JSON response.
 
@@ -229,12 +184,12 @@ def _ollama_query(
       generation completes.
     num_threads: Number of CPU threads to use. If None, Ollama uses a default
       heuristic.
-    num_ctx: Number of context tokens allowed. If None, uses model’s default
+    num_ctx: Number of context tokens allowed. If None, uses model's default
      or config.
 
   Returns:
-    A mapping (dictionary-like) containing the server’s JSON response. For
-    non-streaming calls, the `"response"` key typically contains the entire
+    A mapping (dictionary-like) containing the server's JSON response. For
+    non-streaming calls, the "response" key typically contains the entire
     generated text.
 
   Raises:
@@ -250,12 +205,20 @@ def _ollama_query(
   if top_k:
     options['top_k'] = top_k
   if num_threads:
     options['num_thread'] = num_threads
   if max_output_tokens:
     options['num_predict'] = max_output_tokens
   if num_ctx:
     options['num_ctx'] = num_ctx
-  model_url = model_url + '/api/generate'
+
+  # Properly construct the API endpoint URL
+  # Validate URL to prevent SSRF attacks
+  parsed_url = urlparse(model_url)
+  if parsed_url.scheme not in ('http', 'https'):
+    raise ValueError(f"Invalid URL scheme: {parsed_url.scheme}. Only http and https are allowed.")
+  if not parsed_url.netloc:
+    raise ValueError(f"Invalid URL: {model_url}. 
Missing hostname.")
+  api_endpoint = urljoin(model_url, '/api/generate')
 
   payload = {
       'model': model,
@@ -268,7 +231,7 @@
   }
   try:
     response = requests.post(
-        model_url,
+        api_endpoint,
         headers={
             'Content-Type': 'application/json',
             'Accept': 'application/json',
@@ -290,7 +253,7 @@
       return response.json()
     if response.status_code == 404:
       raise ValueError(
-          f"Can't find Ollama {model}. Try launching `ollama run {model}`"
+          f"Can't find Ollama {model}. Try launching ollama run {model}"
           ' from command line.'
       )
     else:
@@ -333,7 +296,7 @@ def __init__(
       temperature: Sampling temperature.
       max_workers: Maximum number of parallel API calls.
       **kwargs: Ignored extra parameters so callers can pass a superset of
-        arguments shared across back-ends without raising ``TypeError``.
+        arguments shared across back-ends without raising `TypeError`.
     """
     self.model_id = model_id
     self.api_key = api_key
@@ -429,7 +392,11 @@ def infer(
       yield [result]
 
   def parse_output(self, output: str) -> Any:
-    """Parses Gemini output as JSON or YAML."""
+    """Parses Gemini output as JSON or YAML.
+
+    Note: This expects raw JSON/YAML without code fences.
+    Code fence extraction is handled by resolver.py. 
+    """
     try:
       if self.format_type == data.FormatType.JSON:
         return json.loads(output)
@@ -439,3 +406,172 @@
       raise ValueError(
           f'Failed to parse output as {self.format_type.name}: {str(e)}'
       ) from e
+
+
+@dataclasses.dataclass(init=False)
+class OpenAILanguageModel(BaseLanguageModel):
+  """Language model inference using OpenAI's API with structured output."""
+
+  model_id: str = 'gpt-4o-mini'
+  api_key: str | None = None
+  base_url: str | None = None
+  organization: str | None = None
+  format_type: data.FormatType = data.FormatType.JSON
+  temperature: float = 0.0
+  max_workers: int = 10
+  _client: openai.OpenAI | None = dataclasses.field(
+      default=None, repr=False, compare=False
+  )
+  _extra_kwargs: dict[str, Any] = dataclasses.field(
+      default_factory=dict, repr=False, compare=False
+  )
+
+  def __init__(
+      self,
+      model_id: str = 'gpt-4o-mini',
+      api_key: str | None = None,
+      base_url: str | None = None,
+      organization: str | None = None,
+      format_type: data.FormatType = data.FormatType.JSON,
+      temperature: float = 0.0,
+      max_workers: int = 10,
+      **kwargs,
+  ) -> None:
+    """Initialize the OpenAI language model.
+
+    Args:
+      model_id: The OpenAI model ID to use (e.g., 'gpt-4o-mini', 'gpt-4o').
+      api_key: API key for OpenAI service.
+      base_url: Base URL for OpenAI service.
+      organization: Optional OpenAI organization ID.
+      format_type: Output format (JSON or YAML).
+      temperature: Sampling temperature.
+      max_workers: Maximum number of parallel API calls.
+      **kwargs: Ignored extra parameters so callers can pass a superset of
+        arguments shared across back-ends without raising `TypeError`. 
+ """ + self.model_id = model_id + self.api_key = api_key + self.base_url = base_url + self.organization = organization + self.format_type = format_type + self.temperature = temperature + self.max_workers = max_workers + self._extra_kwargs = kwargs or {} + + if not self.api_key: + raise ValueError('API key not provided.') + + # Initialize the OpenAI client + self._client = openai.OpenAI( + api_key=self.api_key, + base_url=self.base_url, + organization=self.organization, + ) + + super().__init__( + constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE) + ) + + def _process_single_prompt(self, prompt: str, config: dict) -> ScoredOutput: + """Process a single prompt and return a ScoredOutput.""" + try: + # Prepare the system message for structured output + system_message = '' + if self.format_type == data.FormatType.JSON: + system_message = ( + 'You are a helpful assistant that responds in JSON format.' + ) + elif self.format_type == data.FormatType.YAML: + system_message = ( + 'You are a helpful assistant that responds in YAML format.' + ) + + # Create the chat completion using the v1.x client API + response = self._client.chat.completions.create( + model=self.model_id, + messages=[ + {'role': 'system', 'content': system_message}, + {'role': 'user', 'content': prompt}, + ], + temperature=config.get('temperature', self.temperature), + max_tokens=config.get('max_output_tokens'), + top_p=config.get('top_p'), + n=1, + ) + + # Extract the response text using the v1.x response format + output_text = response.choices[0].message.content + + return ScoredOutput(score=1.0, output=output_text) + + except Exception as e: + raise InferenceOutputError(f'OpenAI API error: {str(e)}') from e + + def infer( + self, batch_prompts: Sequence[str], **kwargs + ) -> Iterator[Sequence[ScoredOutput]]: + """Runs inference on a list of prompts via OpenAI's API. + + Args: + batch_prompts: A list of string prompts. 
+ **kwargs: Additional generation params (temperature, top_p, etc.) + + Yields: + Lists of ScoredOutputs. + """ + config = { + 'temperature': kwargs.get('temperature', self.temperature), + } + if 'max_output_tokens' in kwargs: + config['max_output_tokens'] = kwargs['max_output_tokens'] + if 'top_p' in kwargs: + config['top_p'] = kwargs['top_p'] + + # Use parallel processing for batches larger than 1 + if len(batch_prompts) > 1 and self.max_workers > 1: + with concurrent.futures.ThreadPoolExecutor( + max_workers=min(self.max_workers, len(batch_prompts)) + ) as executor: + future_to_index = { + executor.submit( + self._process_single_prompt, prompt, config.copy() + ): i + for i, prompt in enumerate(batch_prompts) + } + + results: list[ScoredOutput | None] = [None] * len(batch_prompts) + for future in concurrent.futures.as_completed(future_to_index): + index = future_to_index[future] + try: + results[index] = future.result() + except Exception as e: + raise InferenceOutputError( + f'Parallel inference error: {str(e)}' + ) from e + + for result in results: + if result is None: + raise InferenceOutputError('Failed to process one or more prompts') + yield [result] + else: + # Sequential processing for single prompt or worker + for prompt in batch_prompts: + result = self._process_single_prompt(prompt, config.copy()) + yield [result] + + def parse_output(self, output: str) -> Any: + """Parses OpenAI output as JSON or YAML. + + Note: This expects raw JSON/YAML without code fences. + Code fence extraction is handled by resolver.py. 
+ """ + try: + if self.format_type == data.FormatType.JSON: + return json.loads(output) + else: + return yaml.safe_load(output) + except Exception as e: + raise ValueError( + f'Failed to parse output as {self.format_type.name}: {str(e)}' + ) from e \ No newline at end of file diff --git a/langextract/io.py b/langextract/io.py index 7f94a193..59dead7a 100644 --- a/langextract/io.py +++ b/langextract/io.py @@ -18,23 +18,21 @@ import dataclasses import json import os +import pathlib from typing import Any, Iterator import pandas as pd import requests -import os -import pathlib -import os -import pathlib from langextract import data from langextract import data_lib +from langextract import exceptions from langextract import progress DEFAULT_TIMEOUT_SECONDS = 30 -class InvalidDatasetError(Exception): +class InvalidDatasetError(exceptions.LangExtractError): """Error raised when Dataset is empty or invalid.""" @@ -83,7 +81,7 @@ def load(self, delimiter: str = ',') -> Iterator[data.Document]: def save_annotated_documents( annotated_documents: Iterator[data.AnnotatedDocument], - output_dir: pathlib.Path | None = None, + output_dir: pathlib.Path | str | None = None, output_name: str = 'data.jsonl', show_progress: bool = True, ) -> None: @@ -92,7 +90,7 @@ def save_annotated_documents( Args: annotated_documents: Iterator over AnnotatedDocument objects to save. output_dir: The directory to which the JSONL file should be written. - Defaults to 'test_output/' if None. + Can be a Path object or a string. Defaults to 'test_output/' if None. output_name: File name for the JSONL file. show_progress: Whether to show a progress bar during saving. 
@@ -102,6 +100,8 @@ def save_annotated_documents( """ if output_dir is None: output_dir = pathlib.Path('test_output') + else: + output_dir = pathlib.Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/langextract/progress.py b/langextract/progress.py index a79b9126..41c4f3b8 100644 --- a/langextract/progress.py +++ b/langextract/progress.py @@ -16,6 +16,7 @@ from typing import Any import urllib.parse + import tqdm # ANSI color codes for terminal output diff --git a/langextract/prompting.py b/langextract/prompting.py index 5d6623b1..4484273b 100644 --- a/langextract/prompting.py +++ b/langextract/prompting.py @@ -16,17 +16,18 @@ import dataclasses import json +import os +import pathlib import pydantic import yaml -import os -import pathlib from langextract import data +from langextract import exceptions from langextract import schema -class PromptBuilderError(Exception): +class PromptBuilderError(exceptions.LangExtractError): """Failure to build prompt.""" diff --git a/langextract/resolver.py b/langextract/resolver.py index e9085f16..c6496b82 100644 --- a/langextract/resolver.py +++ b/langextract/resolver.py @@ -31,6 +31,7 @@ import yaml from langextract import data +from langextract import exceptions from langextract import schema from langextract import tokenizer @@ -151,7 +152,7 @@ def align( ExtractionValueType = str | int | float | dict | list | None -class ResolverParsingError(Exception): +class ResolverParsingError(exceptions.LangExtractError): """Error raised when content cannot be parsed as the given format.""" diff --git a/langextract/schema.py b/langextract/schema.py index 2c02baac..dd553bdc 100644 --- a/langextract/schema.py +++ b/langextract/schema.py @@ -22,7 +22,6 @@ import enum from typing import Any - from langextract import data diff --git a/langextract/tokenizer.py b/langextract/tokenizer.py index f4036f36..5028fb0f 100644 --- a/langextract/tokenizer.py +++ b/langextract/tokenizer.py @@ -30,8 +30,10 @@ from absl import 
logging +from langextract import exceptions -class BaseTokenizerError(Exception): + +class BaseTokenizerError(exceptions.LangExtractError): """Base class for all tokenizer-related errors.""" diff --git a/langextract/visualization.py b/langextract/visualization.py index 513cfa58..fc9be321 100644 --- a/langextract/visualization.py +++ b/langextract/visualization.py @@ -28,19 +28,38 @@ import html import itertools import json +import pathlib import textwrap -import os -import pathlib from langextract import data as _data from langextract import io as _io # Fallback if IPython is not present try: - from IPython.display import HTML # type: ignore -except Exception: + from IPython import get_ipython # type: ignore[import-not-found] + from IPython.display import HTML # type: ignore[import-not-found] +except ImportError: + + def get_ipython(): # type: ignore[no-redef] + return None + HTML = None # pytype: disable=annotation-type-mismatch + +def _is_jupyter() -> bool: + """Check if we're in a Jupyter/IPython environment that can display HTML.""" + try: + if get_ipython is None: + return False + ip = get_ipython() + if ip is None: + return False + # Simple check: if we're in IPython and NOT in a plain terminal + return ip.__class__.__name__ != 'TerminalInteractiveShell' + except Exception: + return False + + _PALETTE: list[str] = [ '#D2E3FC', # Light Blue (Primary Container) '#C8E6C9', # Light Green (Tertiary Container) @@ -119,9 +138,7 @@ padding: 8px 10px; margin-top: 8px; font-size: 13px; } .lx-current-highlight { - text-decoration: underline; - text-decoration-color: #ff4444; - text-decoration-thickness: 3px; + border-bottom: 4px solid #ff4444; font-weight: bold; animation: lx-pulse 1s ease-in-out; } @@ -130,9 +147,9 @@ 50% { text-decoration-color: #ff0000; } 100% { text-decoration-color: #ff4444; } } - .lx-legend { - font-size: 12px; margin-bottom: 8px; - padding-bottom: 8px; border-bottom: 1px solid #e0e0e0; + .lx-legend { + font-size: 12px; margin-bottom: 8px; + 
padding-bottom: 8px; border-bottom: 1px solid #e0e0e0; } .lx-label { display: inline-block; @@ -456,12 +473,12 @@ def _extraction_sort_key(extraction):
-
- Entity 1/{len(extractions)} | + Entity 1/{len(extractions)} | Pos {pos_info_str}
@@ -540,7 +557,7 @@ def visualize( animation_speed: float = 1.0, show_legend: bool = True, gif_optimized: bool = True, -) -> 'HTML | str': +) -> HTML | str: """Visualises extraction data as animated highlighted HTML. Args: @@ -584,7 +601,9 @@ def visualize( ' animate.

' ) full_html = _VISUALIZATION_CSS + empty_html - return HTML(full_html) if HTML is not None else full_html + if HTML is not None and _is_jupyter(): + return HTML(full_html) + return full_html color_map = _assign_colors(valid_extractions) @@ -605,4 +624,6 @@ def visualize( 'class="lx-animated-wrapper lx-gif-optimized"', ) - return HTML(full_html) if HTML is not None else full_html + if HTML is not None and _is_jupyter(): + return HTML(full_html) + return full_html diff --git a/pyproject.toml b/pyproject.toml index a2dfd19c..95f6da26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ build-backend = "setuptools.build_meta" [project] name = "langextract" -version = "0.1.0" +version = "1.0.4" description = "LangExtract: A library for extracting structured data from language models" readme = "README.md" requires-python = ">=3.10" @@ -32,32 +32,41 @@ dependencies = [ "async_timeout>=4.0.0", "exceptiongroup>=1.1.0", "google-genai>=0.1.0", - "langfun>=0.1.0", "ml-collections>=0.1.0", "more-itertools>=8.0.0", "numpy>=1.20.0", - "openai>=0.27.0", + "openai>=1.50.0", "pandas>=1.3.0", "pydantic>=1.8.0", "python-dotenv>=0.19.0", - "python-magic>=0.4.27", + "PyYAML>=6.0", "requests>=2.25.0", "tqdm>=4.64.0", "typing-extensions>=4.0.0" ] +[project.urls] +"Homepage" = "https://github.com/google/langextract" +"Repository" = "https://github.com/google/langextract" +"Documentation" = "https://github.com/google/langextract/blob/main/README.md" +"Bug Tracker" = "https://github.com/google/langextract/issues" + [project.optional-dependencies] dev = [ - "black>=23.7.0", - "pylint>=2.17.5", - "pytest>=7.4.0", + "pyink~=24.3.0", + "isort>=5.13.0", + "pylint>=3.0.0", "pytype>=2024.10.11", - "tox>=4.0.0", + "tox>=4.0.0" ] test = [ "pytest>=7.4.0", "tomli>=2.0.0" ] +notebook = [ + "ipython>=7.0.0", + "notebook>=6.0.0" +] [tool.setuptools] packages = ["langextract"] @@ -72,6 +81,28 @@ include-package-data = false "*.svg", ] -[tool.pytest] +[tool.pytest.ini_options] testpaths = 
["tests"] -python_files = "*_test.py" \ No newline at end of file +python_files = "*_test.py" +python_classes = "Test*" +python_functions = "test_*" +# Show extra test summary info +addopts = "-ra" +markers = [ + "live_api: marks tests as requiring live API access", +] + +[tool.pyink] +# Configuration for Google's style guide +line-length = 80 +unstable = true +pyink-indentation = 2 +pyink-use-majority-quotes = true + +[tool.isort] +# Configuration for Google's style guide +profile = "google" +line_length = 80 +force_sort_within_sections = true +# Allow multiple imports on one line for these modules +single_line_exclusions = ["typing", "typing_extensions", "collections.abc"] diff --git a/tests/.pylintrc b/tests/.pylintrc new file mode 100644 index 00000000..4b06ddd5 --- /dev/null +++ b/tests/.pylintrc @@ -0,0 +1,52 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Test-specific Pylint configuration +# Inherits from parent ../.pylintrc and adds test-specific relaxations + +[MASTER] +# Python will merge with parent; no need to repeat plugins. 
+ +[MESSAGES CONTROL] +# Additional disables for test code only +disable= + # --- Test-specific relaxations --- + duplicate-code, # Test fixtures often have similar patterns + too-many-lines, # Large test files are common + missing-module-docstring, # Tests don't need module docs + missing-class-docstring, # Test classes are self-explanatory + missing-function-docstring, # Test method names describe intent + line-too-long, # Golden strings and test data + invalid-name, # setUp, tearDown, maxDiff, etc. + protected-access, # Tests often access private members + use-dict-literal, # Parametrized tests benefit from dict() + bad-indentation, # pyink 2-space style conflicts with pylint + unused-argument, # Mock callbacks often have unused args + import-error, # Test dependencies may not be installed + unused-import, # Some imports are for test fixtures + too-many-positional-arguments # Test methods can have many args + +[DESIGN] +# Relax complexity limits for tests +max-args = 10 # Fixtures often take many params +max-locals = 25 # Complex test setups +max-statements = 75 # Detailed test scenarios +max-branches = 15 # Multiple test conditions + +[BASIC] +# Allow common test naming patterns +good-names=i,j,k,ex,Run,_,id,ok,fd,fp,maxDiff,setUp,tearDown + +# Include test-specific naming patterns +method-rgx=[a-z_][a-z0-9_]{2,50}$|test[A-Z_][a-zA-Z0-9]*$|assert[A-Z][a-zA-Z0-9]*$ diff --git a/tests/annotation_test.py b/tests/annotation_test.py index bfa87c09..a5540e4e 100644 --- a/tests/annotation_test.py +++ b/tests/annotation_test.py @@ -20,6 +20,7 @@ from absl.testing import absltest from absl.testing import parameterized + from langextract import annotation from langextract import data from langextract import inference @@ -34,7 +35,7 @@ class AnnotatorTest(absltest.TestCase): def setUp(self): super().setUp() self.mock_language_model = self.enter_context( - mock.patch.object(inference, "LangFunLanguageModel", autospec=True) + mock.patch.object(inference, 
"GeminiLanguageModel", autospec=True) ) self.annotator = annotation.Annotator( language_model=self.mock_language_model, @@ -687,7 +688,7 @@ def test_annotate_documents( batch_length: int = 1, ): mock_language_model = self.enter_context( - mock.patch.object(inference, "LangFunLanguageModel", autospec=True) + mock.patch.object(inference, "GeminiLanguageModel", autospec=True) ) # Define a side effect function so return length based on batch length. @@ -760,7 +761,7 @@ def test_annotate_documents_exceptions( batch_length: int = 1, ): mock_language_model = self.enter_context( - mock.patch.object(inference, "LangFunLanguageModel", autospec=True) + mock.patch.object(inference, "GeminiLanguageModel", autospec=True) ) mock_language_model.infer.return_value = [ [ @@ -797,7 +798,7 @@ class AnnotatorMultiPassTest(absltest.TestCase): def setUp(self): super().setUp() self.mock_language_model = self.enter_context( - mock.patch.object(inference, "LangFunLanguageModel", autospec=True) + mock.patch.object(inference, "GeminiLanguageModel", autospec=True) ) self.annotator = annotation.Annotator( language_model=self.mock_language_model, diff --git a/tests/chunking_test.py b/tests/chunking_test.py index ad4f17b5..1eddd98a 100644 --- a/tests/chunking_test.py +++ b/tests/chunking_test.py @@ -14,11 +14,12 @@ import textwrap +from absl.testing import absltest +from absl.testing import parameterized + from langextract import chunking from langextract import data from langextract import tokenizer -from absl.testing import absltest -from absl.testing import parameterized class SentenceIterTest(absltest.TestCase): @@ -94,7 +95,6 @@ def test_sentence_with_multiple_newlines_and_right_interval(self): + "Mr\n\nBond\n\nasks why?" 
) tokenized_text = tokenizer.tokenize(text) - # To take the whole text chunk_interval = tokenizer.TokenInterval( start_index=0, end_index=len(tokenized_text.tokens) ) @@ -191,6 +191,33 @@ def test_long_token_gets_own_chunk(self): with self.assertRaises(StopIteration): next(chunk_iter) + def test_newline_at_chunk_boundary_does_not_create_empty_interval(self): + """Test that newlines at chunk boundaries don't create empty token intervals. + + When a newline occurs exactly at a chunk boundary, the chunking algorithm + should not attempt to create an empty interval (where start_index == end_index). + This was causing a ValueError in create_token_interval(). + """ + text = "First sentence.\nSecond sentence that is longer.\nThird sentence." + tokenized_text = tokenizer.tokenize(text) + + chunk_iter = chunking.ChunkIterator(tokenized_text, max_char_buffer=20) + chunks = list(chunk_iter) + + for chunk in chunks: + self.assertLess( + chunk.token_interval.start_index, + chunk.token_interval.end_index, + "Chunk should have non-empty interval", + ) + + expected_intervals = [(0, 3), (3, 6), (6, 9), (9, 12)] + actual_intervals = [ + (chunk.token_interval.start_index, chunk.token_interval.end_index) + for chunk in chunks + ] + self.assertEqual(actual_intervals, expected_intervals) + def test_chunk_unicode_text(self): text = textwrap.dedent("""\ Chief Complaint: @@ -352,7 +379,7 @@ def test_make_batches_of_textchunk( self.assertListEqual( actual_batches, expected_batches, - "Batches do not match expected", + "Batched chunks should match expected structure", ) @@ -368,7 +395,9 @@ def test_string_output(self): )""") document = data.Document(text=text, document_id="test_doc_123") tokenized_text = tokenizer.tokenize(text) - chunk_iter = chunking.ChunkIterator(tokenized_text, max_char_buffer=7, document=document) + chunk_iter = chunking.ChunkIterator( + tokenized_text, max_char_buffer=7, document=document + ) text_chunk = next(chunk_iter) self.assertEqual(str(text_chunk), expected) @@ 
-407,7 +436,7 @@ def test_multiple_chunks_with_additional_context(self): ) chunks = list(chunk_iter) self.assertGreater( - len(chunks), 1, "Expected multiple chunks due to max_char_buffer limit" + len(chunks), 1, "Should create multiple chunks with small buffer" ) additional_contexts = [chunk.additional_context for chunk in chunks] expected_additional_contexts = [self._ADDITIONAL_CONTEXT] * len(chunks) diff --git a/tests/data_lib_test.py b/tests/data_lib_test.py index 0eed51cc..e1cbdeb0 100644 --- a/tests/data_lib_test.py +++ b/tests/data_lib_test.py @@ -14,13 +14,13 @@ import json +from absl.testing import absltest +from absl.testing import parameterized import numpy as np from langextract import data from langextract import data_lib from langextract import tokenizer -from absl.testing import absltest -from absl.testing import parameterized class DataLibToDictParameterizedTest(parameterized.TestCase): diff --git a/tests/inference_test.py b/tests/inference_test.py index d9cf6b57..2bb91f13 100644 --- a/tests/inference_test.py +++ b/tests/inference_test.py @@ -13,56 +13,12 @@ # limitations under the License. 
from unittest import mock -import langfun as lf -from absl.testing import absltest -from langextract import inference - - -class TestLangFunLanguageModel(absltest.TestCase): - @mock.patch.object( - inference.lf.core.language_model, "LanguageModel", autospec=True - ) - def test_langfun_infer(self, mock_lf_model): - mock_client_instance = mock_lf_model.return_value - metadata = { - "score": -0.004259720362824737, - "logprobs": None, - "is_cached": False, - } - source = lf.UserMessage( - text="What's heart in Italian?.", - sender="User", - metadata={"formatted_text": "What's heart in Italian?."}, - tags=["lm-input"], - ) - sample = lf.LMSample( - response=lf.AIMessage( - text="Cuore", - sender="AI", - metadata=metadata, - source=source, - tags=["lm-response"], - ), - score=-0.004259720362824737, - ) - actual_response = lf.LMSamplingResult( - samples=[sample], - ) - - # Mock the sample response. - mock_client_instance.sample.return_value = [actual_response] - model = inference.LangFunLanguageModel(language_model=mock_client_instance) - - batch_prompts = ["What's heart in Italian?"] - expected_results = [ - [inference.ScoredOutput(score=-0.004259720362824737, output="Cuore")] - ] - - results = list(model.infer(batch_prompts)) +from absl.testing import absltest +from absl.testing import parameterized - mock_client_instance.sample.assert_called_once_with(prompts=batch_prompts) - self.assertEqual(results, expected_results) +from langextract import data +from langextract import inference class TestOllamaLanguageModel(absltest.TestCase): @@ -118,7 +74,7 @@ def test_ollama_infer(self, mock_ollama_query): } mock_ollama_query.return_value = gemma_response model = inference.OllamaLanguageModel( - model="gemma2:latest", + model_id="gemma2:latest", model_url="http://localhost:11434", structured_output_format="json", ) @@ -139,5 +95,132 @@ def test_ollama_infer(self, mock_ollama_query): self.assertEqual(results, expected_results) +class 
TestOpenAILanguageModelInference(parameterized.TestCase): + + @parameterized.named_parameters( + ("without", "test-api-key", None, "gpt-4o-mini", 0.5), + ("with", "test-api-key", "http://127.0.0.1:9001/v1", "gpt-4o-mini", 0.5), + ) + @mock.patch("openai.OpenAI") + def test_openai_infer_with_parameters( + self, api_key, base_url, model_id, temperature, mock_openai_class + ): + # Mock the OpenAI client and chat completion response + mock_client = mock.Mock() + mock_openai_class.return_value = mock_client + + # Mock response structure for v1.x API + mock_response = mock.Mock() + mock_response.choices = [ + mock.Mock(message=mock.Mock(content='{"name": "John", "age": 30}')) + ] + mock_client.chat.completions.create.return_value = mock_response + + # Create model instance + model = inference.OpenAILanguageModel( + model_id=model_id, + api_key=api_key, + base_url=base_url, + temperature=temperature, + ) + + # Test inference + batch_prompts = ["Extract name and age from: John is 30 years old"] + results = list(model.infer(batch_prompts)) + + # Verify API was called correctly + mock_client.chat.completions.create.assert_called_once_with( + model="gpt-4o-mini", + messages=[ + { + "role": "system", + "content": ( + "You are a helpful assistant that responds in JSON format." 
+ ), + }, + { + "role": "user", + "content": "Extract name and age from: John is 30 years old", + }, + ], + temperature=temperature, + max_tokens=None, + top_p=None, + n=1, + ) + + # Check results + expected_results = [[ + inference.ScoredOutput(score=1.0, output='{"name": "John", "age": 30}') + ]] + self.assertEqual(results, expected_results) + + +class TestOpenAILanguageModel(absltest.TestCase): + + def test_openai_parse_output_json(self): + model = inference.OpenAILanguageModel( + api_key="test-key", format_type=data.FormatType.JSON + ) + + # Test valid JSON parsing + output = '{"key": "value", "number": 42}' + parsed = model.parse_output(output) + self.assertEqual(parsed, {"key": "value", "number": 42}) + + # Test invalid JSON + with self.assertRaises(ValueError) as context: + model.parse_output("invalid json") + self.assertIn("Failed to parse output as JSON", str(context.exception)) + + def test_openai_parse_output_yaml(self): + model = inference.OpenAILanguageModel( + api_key="test-key", format_type=data.FormatType.YAML + ) + + # Test valid YAML parsing + output = "key: value\nnumber: 42" + parsed = model.parse_output(output) + self.assertEqual(parsed, {"key": "value", "number": 42}) + + # Test invalid YAML + with self.assertRaises(ValueError) as context: + model.parse_output("invalid: yaml: bad") + self.assertIn("Failed to parse output as YAML", str(context.exception)) + + def test_openai_no_api_key_raises_error(self): + with self.assertRaises(ValueError) as context: + inference.OpenAILanguageModel(api_key=None) + self.assertEqual(str(context.exception), "API key not provided.") + + @mock.patch("openai.OpenAI") + def test_openai_temperature_zero(self, mock_openai_class): + # Test that temperature=0.0 is properly passed through + mock_client = mock.Mock() + mock_openai_class.return_value = mock_client + + mock_response = mock.Mock() + mock_response.choices = [ + mock.Mock(message=mock.Mock(content='{"result": "test"}')) + ] + 
mock_client.chat.completions.create.return_value = mock_response + + model = inference.OpenAILanguageModel( + api_key="test-key", temperature=0.0 # Testing zero temperature + ) + + list(model.infer(["test prompt"])) + + # Verify temperature=0.0 was passed to the API + mock_client.chat.completions.create.assert_called_with( + model="gpt-4o-mini", + messages=mock.ANY, + temperature=0.0, + max_tokens=None, + top_p=None, + n=1, + ) + + if __name__ == "__main__": absltest.main() diff --git a/tests/init_test.py b/tests/init_test.py index b68371f7..d79a07f4 100644 --- a/tests/init_test.py +++ b/tests/init_test.py @@ -18,11 +18,12 @@ from unittest import mock from absl.testing import absltest -import langextract as lx + from langextract import data from langextract import inference from langextract import prompting from langextract import schema +import langextract as lx class InitTest(absltest.TestCase): @@ -142,5 +143,6 @@ def test_lang_extract_as_lx_extract( self.assertDataclassEqual(expected_result, actual_result) + if __name__ == "__main__": absltest.main() diff --git a/tests/prompting_test.py b/tests/prompting_test.py index 93712121..5449139b 100644 --- a/tests/prompting_test.py +++ b/tests/prompting_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from absl.testing import parameterized + from langextract import data from langextract import prompting from langextract import schema diff --git a/tests/resolver_test.py b/tests/resolver_test.py index 61d2a5e6..b96270ee 100644 --- a/tests/resolver_test.py +++ b/tests/resolver_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest from absl.testing import parameterized + from langextract import chunking from langextract import data from langextract import resolver as resolver_lib diff --git a/tests/schema_test.py b/tests/schema_test.py index 4664da08..d4b067b5 100644 --- a/tests/schema_test.py +++ b/tests/schema_test.py @@ -16,11 +16,9 @@ import textwrap from unittest import mock - - - from absl.testing 
import absltest from absl.testing import parameterized + from langextract import data from langextract import schema diff --git a/tests/test_live_api.py b/tests/test_live_api.py new file mode 100644 index 00000000..8d9801eb --- /dev/null +++ b/tests/test_live_api.py @@ -0,0 +1,585 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Live API integration tests that require real API keys. + +These tests are skipped if API keys are not available in the environment. +They should run in CI after all other tests pass. 
+""" + +from functools import wraps +import os +import re +import textwrap +import time +import unittest + +from dotenv import load_dotenv +import pytest + +import langextract as lx +from langextract.inference import OpenAILanguageModel + +load_dotenv() + +DEFAULT_GEMINI_MODEL = "gemini-2.5-flash" +DEFAULT_OPENAI_MODEL = "gpt-4o" + +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get( + "LANGEXTRACT_API_KEY" +) +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") + +skip_if_no_gemini = pytest.mark.skipif( + not GEMINI_API_KEY, + reason=( + "Gemini API key not available (set GEMINI_API_KEY or" + " LANGEXTRACT_API_KEY)" + ), +) +skip_if_no_openai = pytest.mark.skipif( + not OPENAI_API_KEY, + reason="OpenAI API key not available (set OPENAI_API_KEY)", +) + +live_api = pytest.mark.live_api + +GEMINI_MODEL_PARAMS = { + "temperature": 0.0, + "top_p": 0.0, + "max_output_tokens": 256, +} + +OPENAI_MODEL_PARAMS = { + "temperature": 0.0, +} + + +INITIAL_RETRY_DELAY = 1.0 +MAX_RETRY_DELAY = 8.0 + + +def retry_on_transient_errors(max_retries=3, backoff_factor=2.0): + """Decorator to retry tests on transient API errors with exponential backoff. 
+ + Args: + max_retries: Maximum number of retry attempts + backoff_factor: Multiplier for exponential backoff (e.g., 2.0 = 1s, 2s, 4s) + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + last_exception = None + delay = INITIAL_RETRY_DELAY + + for attempt in range(max_retries + 1): + try: + return func(*args, **kwargs) + except ( + lx.exceptions.LangExtractError, + ConnectionError, + TimeoutError, + OSError, + RuntimeError, + ) as e: + last_exception = e + error_str = str(e).lower() + error_type = type(e).__name__ + + transient_errors = [ + "503", + "service unavailable", + "temporarily unavailable", + "rate limit", + "429", + "too many requests", + "connection reset", + "timeout", + "deadline exceeded", + ] + + is_transient = any( + err in error_str for err in transient_errors + ) or error_type in ["ServiceUnavailable", "RateLimitError", "Timeout"] + + if is_transient and attempt < max_retries: + print( + f"\nTransient error ({error_type}) on attempt" + f" {attempt + 1}/{max_retries + 1}: {e}" + ) + print(f"Retrying in {delay} seconds...") + time.sleep(delay) + delay = min(delay * backoff_factor, MAX_RETRY_DELAY) + else: + raise + + raise last_exception + + return wrapper + + return decorator + + +@pytest.fixture(autouse=True) +def add_delay_between_tests(): + """Add a small delay between tests to avoid rate limiting.""" + yield + time.sleep(0.5) + + +def get_basic_medication_examples(): + """Get example data for basic medication extraction.""" + return [ + lx.data.ExampleData( + text="Patient was given 250 mg IV Cefazolin TID for one week.", + extractions=[ + lx.data.Extraction( + extraction_class="dosage", extraction_text="250 mg" + ), + lx.data.Extraction( + extraction_class="route", extraction_text="IV" + ), + lx.data.Extraction( + extraction_class="medication", extraction_text="Cefazolin" + ), + lx.data.Extraction( + extraction_class="frequency", + extraction_text="TID", # TID = three times a day + ), + lx.data.Extraction( + 
extraction_class="duration", extraction_text="for one week" + ), + ], + ) + ] + + +def get_relationship_examples(): + """Get example data for medication relationship extraction.""" + return [ + lx.data.ExampleData( + text=( + "Patient takes Aspirin 100mg daily for heart health and" + " Simvastatin 20mg at bedtime." + ), + extractions=[ + # First medication group + lx.data.Extraction( + extraction_class="medication", + extraction_text="Aspirin", + attributes={"medication_group": "Aspirin"}, + ), + lx.data.Extraction( + extraction_class="dosage", + extraction_text="100mg", + attributes={"medication_group": "Aspirin"}, + ), + lx.data.Extraction( + extraction_class="frequency", + extraction_text="daily", + attributes={"medication_group": "Aspirin"}, + ), + lx.data.Extraction( + extraction_class="condition", + extraction_text="heart health", + attributes={"medication_group": "Aspirin"}, + ), + # Second medication group + lx.data.Extraction( + extraction_class="medication", + extraction_text="Simvastatin", + attributes={"medication_group": "Simvastatin"}, + ), + lx.data.Extraction( + extraction_class="dosage", + extraction_text="20mg", + attributes={"medication_group": "Simvastatin"}, + ), + lx.data.Extraction( + extraction_class="frequency", + extraction_text="at bedtime", + attributes={"medication_group": "Simvastatin"}, + ), + ], + ) + ] + + +def extract_by_class(result, extraction_class): + """Helper to extract entities by class. + + Returns a set of extraction texts for the given class. + """ + return { + e.extraction_text + for e in result.extractions + if e.extraction_class == extraction_class + } + + +def assert_extractions_contain(test_case, result, expected_classes): + """Assert that result contains all expected extraction classes. + + Uses unittest assertions for richer error messages. 
+ """ + actual_classes = {e.extraction_class for e in result.extractions} + missing_classes = expected_classes - actual_classes + test_case.assertFalse( + missing_classes, + f"Missing expected classes: {missing_classes}. Found extractions:" + f" {[f'{e.extraction_class}:{e.extraction_text}' for e in result.extractions]}", + ) + + +def assert_valid_char_intervals(test_case, result): + """Assert that all extractions have valid char intervals and alignment status.""" + for extraction in result.extractions: + test_case.assertIsNotNone( + extraction.char_interval, + f"Missing char_interval for extraction: {extraction.extraction_text}", + ) + test_case.assertIsNotNone( + extraction.alignment_status, + "Missing alignment_status for extraction:" + f" {extraction.extraction_text}", + ) + if hasattr(result, "text") and result.text: + text_length = len(result.text) + test_case.assertGreaterEqual( + extraction.char_interval.start_pos, + 0, + f"Invalid start_pos for extraction: {extraction.extraction_text}", + ) + test_case.assertLessEqual( + extraction.char_interval.end_pos, + text_length, + f"Invalid end_pos for extraction: {extraction.extraction_text}", + ) + + +class TestLiveAPIGemini(unittest.TestCase): + """Tests using real Gemini API.""" + + @skip_if_no_gemini + @live_api + @retry_on_transient_errors(max_retries=2) + def test_medication_extraction(self): + """Test medication extraction with entities in order.""" + prompt = textwrap.dedent("""\ + Extract medication information including medication name, dosage, route, frequency, + and duration in the order they appear in the text.""") + + examples = get_basic_medication_examples() + input_text = "Patient took 400 mg PO Ibuprofen q4h for two days." 
+
+    result = lx.extract(
+        text_or_documents=input_text,
+        prompt_description=prompt,
+        examples=examples,
+        model_id=DEFAULT_GEMINI_MODEL,
+        api_key=GEMINI_API_KEY,
+        language_model_params=GEMINI_MODEL_PARAMS,
+    )
+
+    assert result is not None
+    assert hasattr(result, "extractions")
+    assert len(result.extractions) > 0
+
+    expected_classes = {
+        "dosage",
+        "route",
+        "medication",
+        "frequency",
+        "duration",
+    }
+    assert_extractions_contain(self, result, expected_classes)
+    assert_valid_char_intervals(self, result)
+
+    # Using regex for precise matching to avoid false positives
+    medication_texts = extract_by_class(result, "medication")
+    self.assertTrue(
+        any(
+            re.search(r"\bIbuprofen\b", text, re.IGNORECASE)
+            for text in medication_texts
+        ),
+        f"No Ibuprofen found in: {medication_texts}",
+    )
+
+    dosage_texts = extract_by_class(result, "dosage")
+    self.assertTrue(
+        any(
+            re.search(r"\b400\s*mg\b", text, re.IGNORECASE)
+            for text in dosage_texts
+        ),
+        f"No 400mg dosage found in: {dosage_texts}",
+    )
+
+    route_texts = extract_by_class(result, "route")
+    self.assertTrue(
+        any(
+            re.search(r"\b(PO|oral)\b", text, re.IGNORECASE)
+            for text in route_texts
+        ),
+        f"No PO/oral route found in: {route_texts}",
+    )
+
+  @skip_if_no_gemini
+  @live_api
+  @retry_on_transient_errors(max_retries=2)
+  @pytest.mark.xfail(
+      reason=(
+          "Known tokenizer issue with non-Latin characters - see GitHub"
+          " issue #13"
+      ),
+      strict=True,
+  )
+  def test_multilingual_medication_extraction(self):
+    """Test medication extraction with Japanese text."""
+    text = (  # "The patient takes 10 mg of medication daily."
+        "患者は毎日10mgの薬を服用します。"
+    )
+
+    prompt = "Extract medication information including dosage and frequency."
+ + examples = [ + lx.data.ExampleData( + text="The patient takes 20mg of aspirin twice daily.", + extractions=[ + lx.data.Extraction( + extraction_class="medication", + extraction_text="aspirin", + attributes={"dosage": "20mg", "frequency": "twice daily"}, + ), + ], + ) + ] + + result = lx.extract( + text_or_documents=text, + prompt_description=prompt, + examples=examples, + model_id=DEFAULT_GEMINI_MODEL, + api_key=GEMINI_API_KEY, + language_model_params=GEMINI_MODEL_PARAMS, + ) + + assert result is not None + assert hasattr(result, "extractions") + assert len(result.extractions) > 0 + + medication_extractions = [ + e for e in result.extractions if e.extraction_class == "medication" + ] + assert ( + len(medication_extractions) > 0 + ), "No medication entities found in Japanese text" + assert_valid_char_intervals(self, result) + + @skip_if_no_gemini + @live_api + @retry_on_transient_errors(max_retries=2) + def test_medication_relationship_extraction(self): + """Test relationship extraction for medications with Gemini.""" + input_text = """ + The patient was prescribed Lisinopril and Metformin last month. + He takes the Lisinopril 10mg daily for hypertension, but often misses + his Metformin 500mg dose which should be taken twice daily for diabetes. + """ + + prompt = textwrap.dedent(""" + Extract medications with their details, using attributes to group related information: + + 1. Extract entities in the order they appear in the text + 2. Each entity must have a 'medication_group' attribute linking it to its medication + 3. 
All details about a medication should share the same medication_group value + """) + + examples = get_relationship_examples() + + result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + model_id=DEFAULT_GEMINI_MODEL, + api_key=GEMINI_API_KEY, + language_model_params=GEMINI_MODEL_PARAMS, + ) + + assert result is not None + assert len(result.extractions) > 0 + assert_valid_char_intervals(self, result) + + medication_groups = {} + for extraction in result.extractions: + assert ( + extraction.attributes is not None + ), f"Missing attributes for {extraction.extraction_text}" + assert ( + "medication_group" in extraction.attributes + ), f"Missing medication_group for {extraction.extraction_text}" + + group_name = extraction.attributes["medication_group"] + medication_groups.setdefault(group_name, []).append(extraction) + + assert ( + len(medication_groups) >= 2 + ), f"Expected at least 2 medications, found {len(medication_groups)}" + + # Allow flexible matching for dosage field (could be "dosage" or "dose") + for med_name, extractions in medication_groups.items(): + extraction_classes = {e.extraction_class for e in extractions} + # At minimum, each group should have the medication itself + assert ( + "medication" in extraction_classes + ), f"{med_name} group missing medication entity" + # Dosage is expected but might be formatted differently + assert any( + c in extraction_classes for c in ["dosage", "dose"] + ), f"{med_name} group missing dosage" + + +class TestLiveAPIOpenAI(unittest.TestCase): + """Tests using real OpenAI API.""" + + @skip_if_no_openai + @live_api + @retry_on_transient_errors(max_retries=2) + def test_medication_extraction(self): + """Test medication extraction with OpenAI models.""" + prompt = textwrap.dedent("""\ + Extract medication information including medication name, dosage, route, frequency, + and duration in the order they appear in the text.""") + + examples = get_basic_medication_examples() + 
input_text = "Patient took 400 mg PO Ibuprofen q4h for two days." + + result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=OpenAILanguageModel, + model_id=DEFAULT_OPENAI_MODEL, + api_key=OPENAI_API_KEY, + fence_output=True, + use_schema_constraints=False, + language_model_params=OPENAI_MODEL_PARAMS, + ) + + assert result is not None + assert hasattr(result, "extractions") + assert len(result.extractions) > 0 + + expected_classes = { + "dosage", + "route", + "medication", + "frequency", + "duration", + } + assert_extractions_contain(self, result, expected_classes) + assert_valid_char_intervals(self, result) + + # Using regex for precise matching to avoid false positives + medication_texts = extract_by_class(result, "medication") + self.assertTrue( + any( + re.search(r"\bIbuprofen\b", text, re.IGNORECASE) + for text in medication_texts + ), + f"No Ibuprofen found in: {medication_texts}", + ) + + dosage_texts = extract_by_class(result, "dosage") + self.assertTrue( + any( + re.search(r"\b400\s*mg\b", text, re.IGNORECASE) + for text in dosage_texts + ), + f"No 400mg dosage found in: {dosage_texts}", + ) + + route_texts = extract_by_class(result, "route") + self.assertTrue( + any( + re.search(r"\b(PO|oral)\b", text, re.IGNORECASE) + for text in route_texts + ), + f"No PO/oral route found in: {route_texts}", + ) + + @skip_if_no_openai + @live_api + @retry_on_transient_errors(max_retries=2) + def test_medication_relationship_extraction(self): + """Test relationship extraction for medications with OpenAI.""" + input_text = """ + The patient was prescribed Lisinopril and Metformin last month. + He takes the Lisinopril 10mg daily for hypertension, but often misses + his Metformin 500mg dose which should be taken twice daily for diabetes. + """ + + prompt = textwrap.dedent(""" + Extract medications with their details, using attributes to group related information: + + 1. 
Extract entities in the order they appear in the text + 2. Each entity must have a 'medication_group' attribute linking it to its medication + 3. All details about a medication should share the same medication_group value + """) + + examples = get_relationship_examples() + + result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=OpenAILanguageModel, + model_id=DEFAULT_OPENAI_MODEL, + api_key=OPENAI_API_KEY, + fence_output=True, + use_schema_constraints=False, + language_model_params=OPENAI_MODEL_PARAMS, + ) + + assert result is not None + assert len(result.extractions) > 0 + assert_valid_char_intervals(self, result) + + medication_groups = {} + for extraction in result.extractions: + assert ( + extraction.attributes is not None + ), f"Missing attributes for {extraction.extraction_text}" + assert ( + "medication_group" in extraction.attributes + ), f"Missing medication_group for {extraction.extraction_text}" + + group_name = extraction.attributes["medication_group"] + medication_groups.setdefault(group_name, []).append(extraction) + + assert ( + len(medication_groups) >= 2 + ), f"Expected at least 2 medications, found {len(medication_groups)}" + + # Allow flexible matching for dosage field (could be "dosage" or "dose") + for med_name, extractions in medication_groups.items(): + extraction_classes = {e.extraction_class for e in extractions} + # At minimum, each group should have the medication itself + assert ( + "medication" in extraction_classes + ), f"{med_name} group missing medication entity" + # Dosage is expected but might be formatted differently + assert any( + c in extraction_classes for c in ["dosage", "dose"] + ), f"{med_name} group missing dosage" diff --git a/tests/test_ollama_integration.py b/tests/test_ollama_integration.py new file mode 100644 index 00000000..5ab4397d --- /dev/null +++ b/tests/test_ollama_integration.py @@ -0,0 +1,80 @@ +# Copyright 2025 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for Ollama functionality.""" +import socket + +import pytest +import requests + +import langextract as lx + + +def _ollama_available(): + """Check if Ollama is running on localhost:11434.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + result = sock.connect_ex(("localhost", 11434)) + return result == 0 + + +@pytest.mark.skipif(not _ollama_available(), reason="Ollama not running") +def test_ollama_extraction(): + """Test extraction using Ollama when available.""" + input_text = "Isaac Asimov was a prolific science fiction writer." + prompt = "Extract the author's full name and their primary literary genre." + + examples = [ + lx.data.ExampleData( + text=( + "J.R.R. Tolkien was an English writer, best known for" + " high-fantasy." + ), + extractions=[ + lx.data.Extraction( + extraction_class="author_details", + extraction_text="J.R.R. Tolkien was an English writer...", + attributes={ + "name": "J.R.R. 
Tolkien", + "genre": "high-fantasy", + }, + ) + ], + ) + ] + + model_id = "gemma2:2b" + + try: + result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=lx.inference.OllamaLanguageModel, + model_id=model_id, + model_url="http://localhost:11434", + temperature=0.3, + fence_output=False, + use_schema_constraints=False, + ) + + assert len(result.extractions) > 0 + extraction = result.extractions[0] + assert extraction.extraction_class == "author_details" + if extraction.attributes: + assert "asimov" in extraction.attributes.get("name", "").lower() + + except ValueError as e: + if "Can't find Ollama" in str(e): + pytest.skip(f"Ollama model {model_id} not available") + raise diff --git a/tests/tokenizer_test.py b/tests/tokenizer_test.py index 9d296978..021f802a 100644 --- a/tests/tokenizer_test.py +++ b/tests/tokenizer_test.py @@ -14,10 +14,11 @@ import textwrap -from langextract import tokenizer from absl.testing import absltest from absl.testing import parameterized +from langextract import tokenizer + class TokenizerTest(parameterized.TestCase): diff --git a/tests/visualization_test.py b/tests/visualization_test.py index 0cb7fbe2..647107f9 100644 --- a/tests/visualization_test.py +++ b/tests/visualization_test.py @@ -17,6 +17,7 @@ from unittest import mock from absl.testing import absltest + from langextract import data as lx_data from langextract import visualization diff --git a/tox.ini b/tox.ini index e8988af7..7abd98a0 100644 --- a/tox.ini +++ b/tox.ini @@ -13,7 +13,7 @@ # limitations under the License. 
[tox] -envlist = py310, py311 +envlist = py310, py311, py312, format, lint-src, lint-tests skip_missing_interpreters = True [testenv] @@ -22,5 +22,41 @@ setenv = deps = .[dev,test] commands = - pylint --rcfile=.pylintrc --score n langextract tests - pytest -q \ No newline at end of file + pytest -ra -m "not live_api" + +[testenv:format] +skip_install = true +deps = + isort>=5.13.2 + pyink~=24.3.0 +commands = + isort langextract tests --check-only --diff + pyink langextract tests --check --diff --config pyproject.toml + +[testenv:lint-src] +deps = + pylint>=3.0.0 +commands = + pylint --rcfile=.pylintrc langextract + +[testenv:lint-tests] +deps = + pylint>=3.0.0 +commands = + pylint --rcfile=tests/.pylintrc tests + +[testenv:live-api] +basepython = python3.11 +passenv = + GEMINI_API_KEY + LANGEXTRACT_API_KEY + OPENAI_API_KEY +deps = {[testenv]deps} +commands = + pytest tests/test_live_api.py -v -m live_api --maxfail=1 + +[testenv:ollama-integration] +basepython = python3.11 +deps = {[testenv]deps} +commands = + pytest tests/test_ollama_integration.py -v --tb=short