diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
index 1708201c..ba9aef5b 100644
--- a/.github/ISSUE_TEMPLATE/config.yml
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -24,4 +24,4 @@ contact_links:
     url: https://g.co/vulnz
     about: >
       To report a security issue, please use https://g.co/vulnz. The Google Security Team will
-      respond within 5 working days of your report on https://g.co/vulnz.
\ No newline at end of file
+      respond within 5 working days of your report on https://g.co/vulnz.
diff --git a/.github/workflows/check-infrastructure-changes.yml b/.github/workflows/check-infrastructure-changes.yml
new file mode 100644
index 00000000..c8c24c65
--- /dev/null
+++ b/.github/workflows/check-infrastructure-changes.yml
@@ -0,0 +1,101 @@
+name: Protect Infrastructure Files
+
+on:
+  pull_request_target:
+    types: [opened, synchronize, reopened]
+  workflow_dispatch:
+
+permissions:
+  contents: read
+  pull-requests: write
+
+jobs:
+  protect-infrastructure:
+    if: github.event_name == 'workflow_dispatch' || github.event.pull_request.draft == false
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Check for infrastructure file changes
+        if: github.event_name == 'pull_request_target'
+        uses: actions/github-script@v7
+        with:
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          script: |
+            // Get the PR author and check if they're a maintainer
+            const prAuthor = context.payload.pull_request.user.login;
+            const { data: authorPermission } = await github.rest.repos.getCollaboratorPermissionLevel({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              username: prAuthor
+            });
+
+            const isMaintainer = ['admin', 'maintain'].includes(authorPermission.permission);
+
+            // Get list of files changed in the PR
+            const { data: files } = await github.rest.pulls.listFiles({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              pull_number: context.payload.pull_request.number
+            });
+
+            // Check for infrastructure file changes
+            const infrastructureFiles = files.filter(file =>
+              file.filename.startsWith('.github/') ||
+              file.filename === 'pyproject.toml' ||
+              file.filename === 'tox.ini' ||
+              file.filename === '.pre-commit-config.yaml' ||
+              file.filename === '.pylintrc' ||
+              file.filename === 'Dockerfile' ||
+              file.filename === 'autoformat.sh' ||
+              file.filename === '.gitignore' ||
+              file.filename === 'CONTRIBUTING.md' ||
+              file.filename === 'LICENSE' ||
+              file.filename === 'CITATION.cff'
+            );
+
+            if (infrastructureFiles.length > 0 && !isMaintainer) {
+              // Check whether the changes look structural rather than formatting/whitespace-only
+              let hasStructuralChanges = false;
+              for (const file of infrastructureFiles) {
+                const additions = file.additions || 0;
+                const deletions = file.deletions || 0;
+                const changes = file.changes || 0;
+
+                // If file has significant changes (not just whitespace), consider it structural
+                if (additions > 5 || deletions > 5 || changes > 10) {
+                  hasStructuralChanges = true;
+                  break;
+                }
+              }
+
+              const fileList = infrastructureFiles.map(f => `  - ${f.filename} (${f.changes} changes)`).join('\n');
+
+              // Post a comment explaining the issue
+              await github.rest.issues.createComment({
+                owner: context.repo.owner,
+                repo: context.repo.repo,
+                issue_number: context.payload.pull_request.number,
+                body: `❌ **Infrastructure File Protection**\n\n` +
+                      `This PR modifies protected infrastructure files:\n\n${fileList}\n\n` +
+                      `Only repository maintainers are allowed to modify infrastructure files (including \`.github/\`, build configuration, and repository documentation).\n\n` +
+                      `**Note**: If these are only formatting changes, please:\n` +
+                      `1. Revert changes to \`.github/\` files\n` +
+                      `2. Use \`./autoformat.sh\` to format only source code directories\n` +
+                      `3. Avoid running formatters on infrastructure files\n\n` +
+                      `If structural changes are necessary:\n` +
+                      `1. Open an issue describing the needed infrastructure changes\n` +
+                      `2. A maintainer will review and implement the changes if approved\n\n` +
+                      `For more information, see our [Contributing Guidelines](https://github.com/google/langextract/blob/main/CONTRIBUTING.md).`
+              });
+
+              core.setFailed(
+                `This PR modifies ${infrastructureFiles.length} protected infrastructure file(s)` +
+                (hasStructuralChanges ? ' with structural changes. ' : ' (the changes appear to be formatting-only). ') +
+                `Only maintainers can modify these files. ` +
+                `Use ./autoformat.sh to format code without touching infrastructure.`
+              );
+            } else if (infrastructureFiles.length > 0 && isMaintainer) {
+              core.info(`PR modifies ${infrastructureFiles.length} infrastructure file(s) - allowed for maintainer ${prAuthor}`);
+            } else {
+              core.info('No infrastructure files modified');
+            }
diff --git a/.github/workflows/check-linked-issue.yml b/.github/workflows/check-linked-issue.yml
new file mode 100644
index 00000000..b6ce7e62
--- /dev/null
+++ b/.github/workflows/check-linked-issue.yml
@@ -0,0 +1,90 @@
+name: Require linked issue with community support
+
+on:
+  pull_request_target:
+    types: [opened, edited, synchronize, reopened]
+  workflow_dispatch:
+
+permissions:
+  contents: read
+  issues: read
+  pull-requests: write
+
+jobs:
+  enforce:
+    if: github.event_name == 'workflow_dispatch' || github.event.pull_request.draft == false
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Verify linked issue
+        if: github.event_name == 'pull_request_target'
+        uses: nearform-actions/github-action-check-linked-issues@v1.2.7
+        with:
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          comment: true
+          exclude-branches: main
+          custom-body: |
+            No linked issues found. Please add "Fixes #<issue number>" to your pull request description.
+
+            Per our [Contributing Guidelines](https://github.com/google/langextract/blob/main/CONTRIBUTING.md#pull-request-guidelines), all PRs must:
+            - Reference an issue with "Fixes #123" or "Closes #123"
+            - The linked issue should have 5+ 👍 reactions
+            - Include discussion demonstrating the importance of the change
+
+            Use GitHub automation to close the issue when this PR is merged.
+
+      - name: Check community support
+        if: github.event_name == 'pull_request_target'
+        uses: actions/github-script@v7
+        with:
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          script: |
+            // Check if PR author is a maintainer
+            const prAuthor = context.payload.pull_request.user.login;
+            const { data: authorPermission } = await github.rest.repos.getCollaboratorPermissionLevel({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              username: prAuthor
+            });
+
+            const isMaintainer = ['admin', 'maintain'].includes(authorPermission.permission);
+
+            const body = context.payload.pull_request.body || '';
+            const match = body.match(/(?:Fixes|Closes|Resolves)\s+#(\d+)/i);
+
+            if (!match) {
+              core.setFailed('No linked issue found');
+              return;
+            }
+
+            const issueNumber = Number(match[1]);
+            const { repository } = await github.graphql(`
+              query($owner: String!, $repo: String!, $number: Int!) {
+                repository(owner: $owner, name: $repo) {
+                  issue(number: $number) {
+                    reactionGroups {
+                      content
+                      users {
+                        totalCount
+                      }
+                    }
+                  }
+                }
+              }
+            `, {
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              number: issueNumber
+            });
+
+            const reactions = repository.issue.reactionGroups;
+            const thumbsUp = reactions.find(g => g.content === 'THUMBS_UP')?.users.totalCount || 0;
+
+            core.info(`Issue #${issueNumber} has ${thumbsUp} 👍 reactions`);
+
+            const REQUIRED_THUMBS_UP = 5;
+            if (thumbsUp < REQUIRED_THUMBS_UP && !isMaintainer) {
+              core.setFailed(`Issue #${issueNumber} needs at least ${REQUIRED_THUMBS_UP} 👍 reactions (currently has ${thumbsUp})`);
+            } else if (isMaintainer && thumbsUp < REQUIRED_THUMBS_UP) {
+              core.info(`Maintainer ${prAuthor} bypassing community support requirement (issue has ${thumbsUp} 👍 reactions)`);
+            }
\ No newline at end of file
diff --git a/.github/workflows/check-pr-size.yml b/.github/workflows/check-pr-size.yml
new file mode 100644
index 00000000..276b5d68
--- /dev/null
+++ b/.github/workflows/check-pr-size.yml
@@ -0,0 +1,44 @@
+name: Check PR size
+
+on:
+  pull_request_target:
+    types: [opened, synchronize, reopened]
+  workflow_dispatch:
+
+permissions:
+  contents: read
+  pull-requests: write
+
+jobs:
+  size:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Evaluate PR size
+        if: github.event_name == 'pull_request_target'
+        uses: actions/github-script@v7
+        with:
+          script: |
+            const pr = context.payload.pull_request;
+            const totalChanges = pr.additions + pr.deletions;
+
+            core.info(`PR contains ${pr.additions} additions and ${pr.deletions} deletions (${totalChanges} total)`);
+
+            const sizeLabel =
+              totalChanges < 50 ? 'size/XS' :
+              totalChanges < 150 ? 'size/S' :
+              totalChanges < 600 ? 'size/M' :
+              totalChanges < 1000 ? 'size/L' : 'size/XL';
+
+            await github.rest.issues.addLabels({
+              ...context.repo,
+              issue_number: pr.number,
+              labels: [sizeLabel]
+            });
+
+            const MAX_LINES = 1000;
+            if (totalChanges > MAX_LINES) {
+              core.setFailed(
+                `This PR contains ${totalChanges} lines of changes, which exceeds the maximum of ${MAX_LINES} lines. ` +
+                `Please split this into smaller, focused pull requests.`
+              );
+            }
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index fc8a2a87..4b060a20 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -15,6 +15,7 @@ name: CI
 
 on:
+  workflow_dispatch:
   push:
     branches: ["main"]
   pull_request:
@@ -28,7 +29,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10", "3.11"]
+        python-version: ["3.10", "3.11", "3.12"]
 
     steps:
       - uses: actions/checkout@v4
@@ -42,6 +43,96 @@
           python -m pip install --upgrade pip
           pip install -e ".[dev,test]"
 
-      - name: Run tox (lint + tests)
+      - name: Run unit tests and linting
         run: |
-          tox
\ No newline at end of file
+          PY_VERSION=$(echo "${{ matrix.python-version }}" | tr -d '.')
+          tox -e py${PY_VERSION},format,lint-src,lint-tests
+
+  live-api-tests:
+    needs: test
+    runs-on: ubuntu-latest
+    if: |
+      github.event_name == 'push' ||
+      (github.event_name == 'pull_request' &&
+       github.event.pull_request.head.repo.full_name == github.repository)
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e ".[dev,test]"
+
+      - name: Run live API tests
+        env:
+          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
+          LANGEXTRACT_API_KEY: ${{ secrets.GEMINI_API_KEY }}  # For backward compatibility
+          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+        run: |
+          if [[ -z "$GEMINI_API_KEY" && -z "$OPENAI_API_KEY" ]]; then
+            echo "::notice::Live API tests skipped - no provider secrets configured"
+            exit 0
+          fi
+          tox -e live-api
+
+  ollama-integration-test:
+    needs: test
+    runs-on: ubuntu-latest
+    if: github.event_name == 'pull_request'
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Detect file changes
+        id: changes
+        uses: tj-actions/changed-files@v46
+        with:
+          files: |
+            langextract/inference.py
+            examples/ollama/**
+            tests/test_ollama_integration.py
+            .github/workflows/ci.yaml
+
+      - name: Skip if no Ollama changes
+        if: steps.changes.outputs.any_changed == 'false'
+        run: |
+          echo "No Ollama-related changes detected – skipping job."
+          exit 0
+
+      - name: Set up Python 3.11
+        if: steps.changes.outputs.any_changed == 'true'
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+
+      - name: Launch Ollama container
+        if: steps.changes.outputs.any_changed == 'true'
+        run: |
+          docker run -d --name ollama \
+            -p 127.0.0.1:11434:11434 \
+            -v ollama:/root/.ollama \
+            ollama/ollama:0.5.4
+          for i in {1..20}; do
+            curl -fs http://localhost:11434/api/version && break
+            sleep 3
+          done
+
+      - name: Pull gemma2 model
+        if: steps.changes.outputs.any_changed == 'true'
+        run: docker exec ollama ollama pull gemma2:2b || true
+
+      - name: Install tox
+        if: steps.changes.outputs.any_changed == 'true'
+        run: |
+          python -m pip install --upgrade pip
+          pip install tox
+
+      - name: Run Ollama integration tests
+        if: steps.changes.outputs.any_changed == 'true'
+        run: tox -e ollama-integration
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 00000000..cb3ff700
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,55 @@
+# Copyright 2025 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: Publish to PyPI
+
+on:
+  release:
+    types: [published]
+
+permissions:
+  contents: read
+  id-token: write
+
+jobs:
+  pypi-publish:
+    name: Publish to PyPI
+    runs-on: ubuntu-latest
+    environment: pypi
+    permissions:
+      id-token: write
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install build dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build
+
+      - name: Build package
+        run: python -m build
+
+      - name: Verify build artifacts
+        run: |
+          ls -la dist/
+          pip install twine
+          twine check dist/*
+
+      - name: Publish to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
diff --git a/.github/workflows/validate_pr_template.yaml b/.github/workflows/validate_pr_template.yaml
new file mode 100644
index 00000000..26621213
--- /dev/null
+++ b/.github/workflows/validate_pr_template.yaml
@@ -0,0 +1,43 @@
+name: Validate PR template
+
+on:
+  pull_request_target:
+    types: [opened, edited, synchronize, reopened]
+  workflow_dispatch:
+
+permissions:
+  contents: read
+
+jobs:
+  check:
+    if: github.event_name == 'workflow_dispatch' || github.event.pull_request.draft == false # drafts can save early
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Fail if template untouched
+        if: github.event_name == 'pull_request_target'
+        env:
+          PR_BODY: ${{ github.event.pull_request.body }}
+        run: |
+          printf '%s\n' "$PR_BODY" | tr -d '\r' > body.txt
+
+          # Required sections from the template
+          required=( "# Description" "Fixes #" "# How Has This Been Tested?" "# Checklist" )
+          err=0
+
+          # Check for required sections
+          for h in "${required[@]}"; do
+            grep -Fq "$h" body.txt || { echo "::error::$h missing"; err=1; }
+          done
+
+          # Check for placeholder text that should be replaced
+          grep -Eiq 'Replace this with|Choose one:' body.txt && {
+            echo "::error::Template placeholders still present"; err=1;
+          }
+
+          # Also check for the unmodified issue number placeholder
+          grep -Fq 'Fixes #[issue number]' body.txt && {
+            echo "::error::Issue number placeholder not updated"; err=1;
+          }
+
+          exit $err
diff --git a/.gitignore b/.gitignore
index 458f449d..fc93e588 100644
--- a/.gitignore
+++ b/.gitignore
@@ -51,4 +51,4 @@ docs/_build/
 *.swp
 
 # OS-specific
-.DS_Store
\ No newline at end of file
+.DS_Store
diff --git a/.hgignore b/.hgignore
index 4ef06c6c..3fb66f47 100644
--- a/.hgignore
+++ b/.hgignore
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-gdm/codeai/codemind/cli/GEMINI.md
\ No newline at end of file
+gdm/codeai/codemind/cli/GEMINI.md
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..84410316
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,46 @@
+# Copyright 2025 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# Pre-commit hooks for LangExtract +# Install with: pre-commit install +# Run manually: pre-commit run --all-files + +repos: + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (import sorting) + # Configuration is in pyproject.toml + + - repo: https://github.com/google/pyink + rev: 24.3.0 + hooks: + - id: pyink + name: pyink (Google's Black fork) + args: ["--config", "pyproject.toml"] + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: end-of-file-fixer + exclude: \.gif$|\.svg$ + - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: check-merge-conflict + - id: check-case-conflict + - id: mixed-line-ending + args: ['--fix=lf'] diff --git a/.pylintrc b/.pylintrc index 5709bc73..2e09c87f 100644 --- a/.pylintrc +++ b/.pylintrc @@ -14,10 +14,418 @@ [MASTER] +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=0 + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +# Note: These plugins require Pylint >= 3.0 +load-plugins= + pylint.extensions.docparams, + pylint.extensions.typing + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + [MESSAGES CONTROL] -disable=all -enable=F + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. +enable= + useless-suppression + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). +disable= + abstract-method, # Protocol/ABC classes often have abstract methods + too-few-public-methods, # Valid for data classes with minimal interface + fixme, # TODO/FIXME comments are useful for tracking work + # --- Code style and formatting --- + line-too-long, # Handled by pyink formatter + bad-indentation, # Pyink uses 2-space indentation + # --- Design complexity --- + too-many-positional-arguments, + too-many-locals, + too-many-arguments, + too-many-branches, + too-many-statements, + too-many-nested-blocks, + # --- Style preferences --- + no-else-return, + no-else-raise, + # --- Documentation --- + missing-function-docstring, + missing-class-docstring, + missing-raises-doc, + # --- Gradual improvements --- + deprecated-typing-alias, # For typing.Type etc. + unspecified-encoding, + unused-import + [REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. output-format=text -reports=no \ No newline at end of file + +# Tells whether to display a full report or only the messages +reports=no + +# Activate the evaluation score. +score=no + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. 
When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,bar,baz,toto,tutu,tata
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,j,k,ex,Run,_,id,ok
+
+# Good variable names regexes, separated by a comma. If names match any regex,
+# they will always be accepted
+good-names-rgxs=
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=LF
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=2
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string="  "
+
+# Maximum number of characters on a single line.
+max-line-length=80
+
+# Maximum number of lines in a module.
+max-module-lines=2000
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,XXX,TODO
+
+
+[SIMILARITIES]
+
+# Ignore comments when computing similarities.
+ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package.. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,dataclasses.InitVar,typing.Any + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules=dotenv,absl,more_itertools,pandas,requests,pydantic,yaml,IPython.display, + tqdm,numpy,google,langfun,typing_extensions + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. 
+additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=cls + + +[DESIGN] + +# Maximum number of arguments for function / method. +max-args=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=10 + +# Maximum number of boolean expressions in an if statement. +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=0 + + +[IMPORTS] + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=yes + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled). +ext-import-graph= + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled). +import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant,numpy,pandas,torch,langfun,pyglove + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. 
Defaults to +# "BaseException, Exception". +overgeneral-exceptions=BaseException, + Exception diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 00000000..2eb3134a --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 +# +# This file contains citation metadata for LangExtract. +# For more information visit: https://citation-file-format.github.io/ + +cff-version: 1.2.0 +title: "LangExtract" +message: "If you use this software, please cite it as below." +type: software +authors: + - given-names: Akshay + family-names: Goel + email: goelak@google.com + affiliation: Google LLC +repository-code: "https://github.com/google/langextract" +url: "https://github.com/google/langextract" +repository: "https://github.com/google/langextract" +abstract: "LangExtract: A library for extracting structured data from language models" +keywords: + - language-models + - structured-data-extraction + - nlp + - machine-learning + - python +license: Apache-2.0 +version: 1.0.3 +date-released: 2025-07-30 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 724ff7f6..aa5038d4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,13 +23,117 @@ sign a new one. This project follows HAI-DEF's [Community guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines) -## Contribution process +## Reporting Issues -### Code Reviews +If you encounter a bug or have a feature request, please open an issue on GitHub. +We have templates to help guide you: + +- **[Bug Report](.github/ISSUE_TEMPLATE/1-bug.md)**: For reporting bugs or unexpected behavior +- **[Feature Request](.github/ISSUE_TEMPLATE/2-feature-request.md)**: For suggesting new features or improvements + +When creating an issue, GitHub will prompt you to choose the appropriate template. +Please provide as much detail as possible to help us understand and address your concern. + +## Contribution Process + +### 1. Development Setup + +To get started, clone the repository and install the necessary dependencies for development and testing. Detailed instructions can be found in the [Installation from Source](https://github.com/google/langextract#from-source) section of the `README.md`. + +**Windows Users**: The formatting scripts use bash. Please use one of: +- Git Bash (comes with Git for Windows) +- WSL (Windows Subsystem for Linux) +- PowerShell with bash-compatible commands + +### 2. Code Style and Formatting + +This project uses automated tools to maintain a consistent code style. Before submitting a pull request, please format your code: + +```bash +# Run the auto-formatter +./autoformat.sh +``` + +This script uses: +- `isort` to organize imports with Google style (single-line imports) +- `pyink` (Google's fork of Black) to format code according to Google's Python Style Guide + +You can also run the formatters manually: +```bash +isort langextract tests +pyink langextract tests --config pyproject.toml +``` + +Note: The formatters target only `langextract` and `tests` directories by default to avoid +formatting virtual environments or other non-source directories. + +### 3. Pre-commit Hooks (Recommended) + +For automatic formatting checks before each commit: + +```bash +# Install pre-commit +pip install pre-commit + +# Install the git hooks +pre-commit install + +# Run manually on all files +pre-commit run --all-files +``` + +### 4. Linting and Testing + +All contributions must pass linting checks and unit tests. 
Please run these locally before submitting your changes:
+
+```bash
+# Run linting with Pylint 3.x
+pylint --rcfile=.pylintrc langextract tests
+
+# Run tests
+pytest tests
+```
+
+**Note on Pylint Configuration**: We use a modern, minimal configuration that:
+- Only disables truly noisy checks (not entire categories)
+- Keeps critical error detection enabled
+- Uses plugins for enhanced docstring and type checking
+- Aligns with our pyink formatter (80-char lines, 2-space indents)
+
+For full testing across Python versions:
+```bash
+tox # runs pylint + pytest on Python 3.10, 3.11, and 3.12
+```
+
+### 5. Submit Your Pull Request
 
 All submissions, including submissions by project members, require review. We
 use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests)
 for this purpose.
 
+When you create a pull request, GitHub will automatically populate it with our
+[pull request template](.github/PULL_REQUEST_TEMPLATE/pull_request_template.md).
+Please fill out all sections of the template to help reviewers understand your changes.
+
+#### Pull Request Guidelines
+
+- **Keep PRs focused and small**: Each PR should address a single issue and contain one cohesive change. PRs are automatically labeled by size to help reviewers:
+  - **size/XS**: < 50 lines — Small fixes and documentation updates
+  - **size/S**: 50-150 lines — Typical features or bug fixes
+  - **size/M**: 150-600 lines — Larger features that remain well-scoped
+  - **size/L**: 600-1000 lines — Consider splitting into smaller PRs if possible
+  - **size/XL**: > 1000 lines — Requires strong justification and may need special review
+- **Reference related issues**: All PRs must include "Fixes #123" or "Closes #123" in the description (see the sample description after this list). The linked issue should have at least 5 👍 reactions from the community and include discussion that demonstrates the importance and need for the change.
+- **No infrastructure changes**: Contributors cannot modify infrastructure files, build configuration, and core documentation. These files are protected and can only be changed by maintainers. Use `./autoformat.sh` to format code without affecting infrastructure files. In special circumstances, build configuration updates may be considered if they include discussion and evidence of robust testing, ideally with community support.
+- **Single-change commits**: A PR should typically comprise a single git commit. Squash multiple commits before submitting.
+- **Clear description**: Explain what your change does and why it's needed.
+- **Ensure all tests pass**: Check that both formatting and tests are green before requesting review.
+- **Respond to feedback promptly**: Address reviewer comments in a timely manner.
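+
+To make these requirements concrete, here is a sample description that satisfies the required template sections (the issue number, summary, and checklist items are illustrative only):
+
+```text
+# Description
+
+Fixes #123. Adds a retry to the Ollama request path so transient connection
+errors no longer abort an extraction run.
+
+# How Has This Been Tested?
+
+Added a unit test that simulates a dropped connection; ran `tox -e py311`.
+
+# Checklist
+
+- [x] Code formatted with ./autoformat.sh
+- [x] Linting and tests pass locally
+```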
+ +If your change is large or complex, consider: +- Opening an issue first to discuss the approach +- Breaking it into multiple smaller PRs +- Clearly explaining in the PR description why a larger change is necessary + For more details, read HAI-DEF's [Contributing guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines#contributing) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..e8a74312 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,11 @@ +# Production Dockerfile for LangExtract +FROM python:3.10-slim + +# Set working directory +WORKDIR /app + +# Install LangExtract from PyPI +RUN pip install --no-cache-dir langextract + +# Set default command +CMD ["python"] diff --git a/README.md b/README.md index b49d4b6a..2cbe5820 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@

-[LangExtract Logo image: local-path <img> markup not preserved in this rendering]
+[LangExtract Logo image: hosted-URL <img> markup not preserved in this rendering]

# LangExtract -[![PyPI version](https://badge.fury.io/py/langextract.svg)](https://badge.fury.io/py/langextract) +[![PyPI version](https://img.shields.io/pypi/v/langextract.svg)](https://pypi.org/project/langextract/) [![GitHub stars](https://img.shields.io/github/stars/google/langextract.svg?style=social&label=Star)](https://github.com/google/langextract) ![Tests](https://github.com/google/langextract/actions/workflows/ci.yaml/badge.svg) @@ -17,6 +17,8 @@ - [Quick Start](#quick-start) - [Installation](#installation) - [API Key Setup for Cloud Models](#api-key-setup-for-cloud-models) +- [Using OpenAI Models](#using-openai-models) +- [Using Local LLMs with Ollama](#using-local-llms-with-ollama) - [More Examples](#more-examples) - [*Romeo and Juliet* Full Text Extraction](#romeo-and-juliet-full-text-extraction) - [Medication Extraction](#medication-extraction) @@ -111,7 +113,7 @@ The extractions can be saved to a `.jsonl` file, a popular format for working wi ```python # Save the results to a JSONL file -lx.io.save_annotated_documents([result], output_name="extraction_results.jsonl") +lx.io.save_annotated_documents([result], output_name="extraction_results.jsonl", output_dir=".") # Generate the visualization from the file html_content = lx.visualize("extraction_results.jsonl") @@ -121,7 +123,7 @@ with open("visualization.html", "w") as f: This creates an animated and interactive HTML file: -![Romeo and Juliet Basic Visualization ](docs/_static/romeo_juliet_basic.gif) +![Romeo and Juliet Basic Visualization ](https://raw.githubusercontent.com/google/langextract/main/docs/_static/romeo_juliet_basic.gif) > **Note on LLM Knowledge Utilization:** This example demonstrates extractions that stay close to the text evidence - extracting "longing" for Lady Juliet's emotional state and identifying "yearning" from "gazed longingly at the stars." The task could be modified to generate attributes that draw more heavily from the LLM's world knowledge (e.g., adding `"identity": "Capulet family daughter"` or `"literary_context": "tragic heroine"`). The balance between text-evidence and knowledge-inference is controlled by your prompt instructions and example attributes. @@ -142,7 +144,7 @@ result = lx.extract( ) ``` -This approach can extract hundreds of entities from full novels while maintaining high accuracy. The interactive visualization seamlessly handles large result sets, making it easy to explore hundreds of entities from the output JSONL file. **[See the full *Romeo and Juliet* extraction example โ†’](docs/examples/longer_text_example.md)** for detailed results and performance insights. +This approach can extract hundreds of entities from full novels while maintaining high accuracy. The interactive visualization seamlessly handles large result sets, making it easy to explore hundreds of entities from the output JSONL file. **[See the full *Romeo and Juliet* extraction example โ†’](https://github.com/google/langextract/blob/main/docs/examples/longer_text_example.md)** for detailed results and performance insights. ## Installation @@ -181,10 +183,16 @@ pip install -e ".[dev]" pip install -e ".[test]" ``` +### Docker + +```bash +docker build -t langextract . +docker run --rm -e LANGEXTRACT_API_KEY="your-api-key" langextract python your_script.py +``` ## API Key Setup for Cloud Models -When using LangExtract with cloud-hosted models (like Gemini), you'll need to +When using LangExtract with cloud-hosted models (like Gemini or OpenAI), you'll need to set up an API key. 
On-device models don't require an API key. For developers using local LLMs, LangExtract offers built-in support for Ollama and can be extended to other third-party APIs by updating the inference endpoints. @@ -195,6 +203,7 @@ Get API keys from: * [AI Studio](https://aistudio.google.com/app/apikey) for Gemini models * [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/sdks/overview) for enterprise use +* [OpenAI Platform](https://platform.openai.com/api-keys) for OpenAI models ### Setting up API key in your environment @@ -244,6 +253,50 @@ result = lx.extract( ) ``` +## Using OpenAI Models + +LangExtract also supports OpenAI models. Example OpenAI configuration: + +```python +import langextract as lx + +result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=lx.inference.OpenAILanguageModel, + model_id="gpt-4o", + api_key=os.environ.get('OPENAI_API_KEY'), + fence_output=True, + use_schema_constraints=False +) +``` + +Note: OpenAI models require `fence_output=True` and `use_schema_constraints=False` because LangExtract doesn't implement schema constraints for OpenAI yet. + +## Using Local LLMs with Ollama + +LangExtract supports local inference using Ollama, allowing you to run models without API keys: + +```python +import langextract as lx + +result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=lx.inference.OllamaLanguageModel, + model_id="gemma2:2b", # or any Ollama model + model_url="http://localhost:11434", + fence_output=False, + use_schema_constraints=False +) +``` + +**Quick setup:** Install Ollama from [ollama.com](https://ollama.com/), run `ollama pull gemma2:2b`, then `ollama serve`. + +For detailed installation, Docker setup, and examples, see [`examples/ollama/`](examples/ollama/). + ## More Examples Additional examples of LangExtract in action: @@ -252,7 +305,7 @@ Additional examples of LangExtract in action: LangExtract can process complete documents directly from URLs. This example demonstrates extraction from the full text of *Romeo and Juliet* from Project Gutenberg (147,843 characters), showing parallel processing, sequential extraction passes, and performance optimization for long document processing. -**[View *Romeo and Juliet* Full Text Example โ†’](docs/examples/longer_text_example.md)** +**[View *Romeo and Juliet* Full Text Example โ†’](https://github.com/google/langextract/blob/main/docs/examples/longer_text_example.md)** ### Medication Extraction @@ -260,7 +313,7 @@ LangExtract can process complete documents directly from URLs. This example demo LangExtract excels at extracting structured medical information from clinical text. These examples demonstrate both basic entity recognition (medication names, dosages, routes) and relationship extraction (connecting medications to their attributes), showing LangExtract's effectiveness for healthcare applications. -**[View Medication Examples โ†’](docs/examples/medication_examples.md)** +**[View Medication Examples โ†’](https://github.com/google/langextract/blob/main/docs/examples/medication_examples.md)** ### Radiology Report Structuring: RadExtract @@ -270,7 +323,7 @@ Explore RadExtract, a live interactive demo on HuggingFace Spaces that shows how ## Contributing -Contributions are welcome! See [CONTRIBUTING.md](CONTRIBUTING.md) to get started +Contributions are welcome! 
See [CONTRIBUTING.md](https://github.com/google/langextract/blob/main/CONTRIBUTING.md) to get started with development, testing, and pull requests. You must sign a [Contributor License Agreement](https://cla.developers.google.com/about) before submitting patches. @@ -297,14 +350,58 @@ Or reproduce the full CI matrix locally with tox: tox # runs pylint + pytest on Python 3.10 and 3.11 ``` +### Ollama Integration Testing + +If you have Ollama installed locally, you can run integration tests: + +```bash +# Test Ollama integration (requires Ollama running with gemma2:2b model) +tox -e ollama-integration +``` + +This test will automatically detect if Ollama is available and run real inference tests. + +## Development + +### Code Formatting + +This project uses automated formatting tools to maintain consistent code style: + +```bash +# Auto-format all code +./autoformat.sh + +# Or run formatters separately +isort langextract tests --profile google --line-length 80 +pyink langextract tests --config pyproject.toml +``` + +### Pre-commit Hooks + +For automatic formatting checks: +```bash +pre-commit install # One-time setup +pre-commit run --all-files # Manual run +``` + +### Linting + +Run linting before submitting PRs: + +```bash +pylint --rcfile=.pylintrc langextract tests +``` + +See [CONTRIBUTING.md](CONTRIBUTING.md) for full development guidelines. + ## Disclaimer This is not an officially supported Google product. If you use LangExtract in production or publications, please cite accordingly and -acknowledge usage. Use is subject to the [Apache 2.0 License](LICENSE). +acknowledge usage. Use is subject to the [Apache 2.0 License](https://github.com/google/langextract/blob/main/LICENSE). For health-related applications, use of LangExtract is also subject to the [Health AI Developer Foundations Terms of Use](https://developers.google.com/health-ai-developer-foundations/terms). --- -**Happy Extracting!** \ No newline at end of file +**Happy Extracting!** diff --git a/VERTEX_AI_INTEGRATION.md b/VERTEX_AI_INTEGRATION.md new file mode 100644 index 00000000..281fba52 --- /dev/null +++ b/VERTEX_AI_INTEGRATION.md @@ -0,0 +1,238 @@ +# Gemini Vertex AI Integration for LangExtract + +This document describes the new Gemini Vertex AI integration added to LangExtract, which allows you to use Google Cloud Vertex AI instead of API keys for authentication. 
+ +## Overview + +The new `GeminiVertexLanguageModel` class provides: +- **Vertex AI Authentication**: Use Google Cloud project and location instead of API keys +- **Thinking Budget Control**: Configure reasoning capabilities for supported models +- **Full Feature Compatibility**: All existing features work with Vertex AI +- **Enhanced Security**: Leverage Google Cloud IAM and service accounts + +## Quick Start + +### Basic Usage + +```python +import langextract as lx + +# Define your examples +examples = [ + lx.data.ExampleData( + text="Patient takes Aspirin 100mg daily for heart health.", + extractions=[ + lx.data.Extraction( + extraction_class="medication", + extraction_text="Aspirin" + ), + lx.data.Extraction( + extraction_class="dosage", + extraction_text="100mg" + ), + ] + ) +] + +# Extract using Vertex AI +result = lx.extract( + text_or_documents="Patient was prescribed Lisinopril 10mg daily for hypertension.", + prompt_description="Extract medication information", + examples=examples, + project="your-gcp-project-id", # Your GCP project + location="global", # Vertex AI location + language_model_type=lx.inference.GeminiVertexLanguageModel +) +``` + +### Advanced Configuration + +```python +# Advanced Vertex AI configuration +result = lx.extract( + text_or_documents=text, + prompt_description=prompt, + examples=examples, + model_id="gemini-2.5-flash", + project="your-gcp-project-id", + location="us-central1", # Specific region + thinking_budget=1000, # Enable reasoning (0 = no thinking) + language_model_type=lx.inference.GeminiVertexLanguageModel, + temperature=0.1, + max_workers=5, + language_model_params={ + "safety_settings": [ + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, + ] + } +) +``` + +## Key Features + +### 1. Vertex AI Authentication +- **Project-based**: Use your Google Cloud project ID instead of API keys +- **Location-aware**: Specify the region for your Vertex AI deployment +- **IAM Integration**: Leverage Google Cloud's identity and access management + +### 2. Thinking Budget +- **Reasoning Control**: Set `thinking_budget` to control model reasoning steps +- **Performance Tuning**: Higher values allow more complex reasoning +- **Cost Management**: Set to 0 to disable thinking and reduce costs + +### 3. Full Compatibility +- **Schema Constraints**: Full support for structured outputs +- **Parallel Processing**: Multi-worker support for batch processing +- **Safety Settings**: Configure content filtering and safety thresholds + +## Parameters + +### New Parameters in `lx.extract()` + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `project` | `str \| None` | `None` | Google Cloud project ID for Vertex AI | +| `location` | `str` | `"global"` | Vertex AI location/region | +| `thinking_budget` | `int` | `0` | Reasoning budget (0 = no thinking) | + +### GeminiVertexLanguageModel Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `project` | `str \| None` | `None` | **Required** - GCP project ID | +| `location` | `str` | `"global"` | Vertex AI region | +| `model_id` | `str` | `"gemini-2.5-flash"` | Model identifier | +| `thinking_budget` | `int` | `0` | Reasoning steps allowed | +| `temperature` | `float` | `0.0` | Sampling temperature | +| `max_workers` | `int` | `10` | Parallel processing workers | + +## Authentication Setup + +### Prerequisites +1. 
**Google Cloud Project**: Active GCP project with Vertex AI enabled +2. **Authentication**: One of the following: + - Application Default Credentials (ADC) + - Service Account Key + - Google Cloud SDK authentication + +### Setup Steps + +1. **Install Google Cloud SDK** (if not already installed): + ```bash + # Follow instructions at: https://cloud.google.com/sdk/docs/install + ``` + +2. **Authenticate**: + ```bash + # Option 1: User authentication + gcloud auth application-default login + + # Option 2: Service account (recommended for production) + export GOOGLE_APPLICATION_CREDENTIALS="path/to/service-account-key.json" + ``` + +3. **Enable Vertex AI API**: + ```bash + gcloud services enable aiplatform.googleapis.com --project=your-project-id + ``` + +## Migration from API Key + +### Before (API Key) +```python +result = lx.extract( + text_or_documents=text, + prompt_description=prompt, + examples=examples, + api_key="your-api-key", + language_model_type=lx.inference.GeminiLanguageModel +) +``` + +### After (Vertex AI) +```python +result = lx.extract( + text_or_documents=text, + prompt_description=prompt, + examples=examples, + project="your-gcp-project-id", + location="global", + language_model_type=lx.inference.GeminiVertexLanguageModel +) +``` + +## Error Handling + +### Common Errors and Solutions + +1. **Missing Project ID**: + ``` + ValueError: Project ID not provided for Vertex AI + ``` + **Solution**: Provide the `project` parameter + +2. **Authentication Error**: + ``` + google.auth.exceptions.DefaultCredentialsError + ``` + **Solution**: Set up authentication (see Authentication Setup) + +3. **Mutually Exclusive Parameters**: + ``` + ValueError: Both api_key and project parameters are provided + ``` + **Solution**: Use either `api_key` OR `project`, not both + +## Performance Considerations + +### Thinking Budget +- **Low values (0-100)**: Fast responses, basic reasoning +- **Medium values (100-1000)**: Balanced performance and reasoning +- **High values (1000+)**: Deep reasoning, slower responses + +### Regional Deployment +- **Global**: Default, automatically routed +- **Regional**: Lower latency for specific regions +- **Multi-region**: Consider for high availability + +## Examples + +See the complete example in `examples/vertex_ai_example.py` for: +- Basic Vertex AI usage +- Advanced configuration +- Error handling +- Comparison with API key approach + +## Troubleshooting + +### Debugging Authentication +```python +# Test authentication +from google.auth import default +credentials, project = default() +print(f"Authenticated project: {project}") +``` + +### Verbose Logging +```python +import logging +logging.basicConfig(level=logging.DEBUG) +``` + +## Benefits of Vertex AI + +1. **Enterprise Security**: Leverage Google Cloud's security model +2. **Cost Management**: Better cost tracking and budgeting +3. **Scalability**: Enterprise-grade scaling and reliability +4. **Integration**: Seamless integration with other GCP services +5. **Compliance**: Meet enterprise compliance requirements + +## Next Steps + +1. **Try the Example**: Run `examples/vertex_ai_example.py` +2. **Set Up Authentication**: Configure your GCP credentials +3. **Test Integration**: Use the test script to verify setup +4. **Migrate Gradually**: Move from API keys to Vertex AI incrementally + +For more information, see the [Google Cloud Vertex AI documentation](https://cloud.google.com/vertex-ai/docs). 
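+
+## Appendix: Combined Sanity Check
+
+The sketch below ties the setup steps together: it verifies that Application Default Credentials resolve, then instantiates the model class directly using the constructor parameters documented in the table above. The project ID is a placeholder, and this is an illustrative sketch rather than part of the tested examples.
+
+```python
+from google.auth import default
+
+import langextract as lx
+
+# Confirm ADC is configured and see which project it resolves to.
+credentials, adc_project = default()
+print(f"ADC project: {adc_project}")
+
+# Instantiate the Vertex AI model directly with the documented defaults.
+model = lx.inference.GeminiVertexLanguageModel(
+    project="your-gcp-project-id",  # placeholder: use your own project
+    location="global",
+    model_id="gemini-2.5-flash",
+    thinking_budget=0,  # 0 disables reasoning for the cheapest check
+)
+print("Vertex AI model initialized")
+```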
diff --git a/autoformat.sh b/autoformat.sh new file mode 100755 index 00000000..5b7b1897 --- /dev/null +++ b/autoformat.sh @@ -0,0 +1,125 @@ +#!/bin/bash +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Autoformat LangExtract codebase +# +# Usage: ./autoformat.sh [target_directory ...] +# If no target is specified, formats the current directory +# +# This script runs: +# 1. isort for import sorting +# 2. pyink (Google's Black fork) for code formatting +# 3. pre-commit hooks for additional formatting (trailing whitespace, end-of-file, etc.) + +set -e + +echo "LangExtract Auto-formatter" +echo "==========================" +echo + +# Check for required tools +check_tool() { + if ! command -v "$1" &> /dev/null; then + echo "Error: $1 not found. Please install with: pip install $1" + exit 1 + fi +} + +check_tool "isort" +check_tool "pyink" +check_tool "pre-commit" + +# Parse command line arguments +show_usage() { + echo "Usage: $0 [target_directory ...]" + echo + echo "Formats Python code using isort and pyink according to Google style." + echo + echo "Arguments:" + echo " target_directory One or more directories to format (default: langextract tests)" + echo + echo "Examples:" + echo " $0 # Format langextract and tests directories" + echo " $0 langextract # Format only langextract directory" + echo " $0 src tests # Format multiple specific directories" +} + +# Check for help flag +if [ "$1" = "-h" ] || [ "$1" = "--help" ]; then + show_usage + exit 0 +fi + +# Determine target directories +if [ $# -eq 0 ]; then + TARGETS="langextract tests" + echo "No target specified. Formatting default directories: langextract tests" +else + TARGETS="$@" + echo "Formatting targets: $TARGETS" +fi + +# Find pyproject.toml relative to script location +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +CONFIG_FILE="${SCRIPT_DIR}/pyproject.toml" + +if [ ! -f "$CONFIG_FILE" ]; then + echo "Warning: pyproject.toml not found at ${CONFIG_FILE}" + echo "Using default configuration." + CONFIG_ARG="" +else + CONFIG_ARG="--config $CONFIG_FILE" +fi + +echo + +# Run isort +echo "Running isort to organize imports..." +if isort $TARGETS; then + echo "Import sorting complete" +else + echo "Import sorting failed" + exit 1 +fi + +echo + +# Run pyink +echo "Running pyink to format code (Google style, 80 chars)..." +if pyink $TARGETS $CONFIG_ARG; then + echo "Code formatting complete" +else + echo "Code formatting failed" + exit 1 +fi + +echo + +# Run pre-commit hooks for additional formatting +echo "Running pre-commit hooks for additional formatting..." +if pre-commit run --all-files; then + echo "Pre-commit hooks passed" +else + echo "Pre-commit hooks made changes - please review" + # Exit with success since formatting was applied + exit 0 +fi + +echo +echo "All formatting complete!" 
+echo
+echo "Next steps:"
+echo "  - Run: pylint --rcfile=${SCRIPT_DIR}/.pylintrc $TARGETS"
+echo "  - Commit your changes"
diff --git a/docs/examples/longer_text_example.md b/docs/examples/longer_text_example.md
index 62d1ff39..5adb4c06 100644
--- a/docs/examples/longer_text_example.md
+++ b/docs/examples/longer_text_example.md
@@ -76,7 +76,7 @@ result = lx.extract(
 print(f"Extracted {len(result.extractions)} entities from {len(result.text):,} characters")
 
 # Save and visualize the results
-lx.io.save_annotated_documents([result], output_name="romeo_juliet_extractions.jsonl")
+lx.io.save_annotated_documents([result], output_name="romeo_juliet_extractions.jsonl", output_dir=".")
 
 # Generate the interactive visualization
 html_content = lx.visualize("romeo_juliet_extractions.jsonl")
@@ -171,4 +171,4 @@ LangExtract combines precise text positioning with world knowledge enrichment, e
 
 ---
 
-¹ Models like Gemini 1.5 Pro show strong performance on many benchmarks, but [needle-in-a-haystack tests](https://cloud.google.com/blog/products/ai-machine-learning/the-needle-in-the-haystack-test-and-how-gemini-pro-solves-it) across million-token contexts indicate that performance can vary in multi-fact retrieval scenarios. This demonstrates how LangExtract's smaller context windows approach ensures consistently high quality across entire documents by avoiding the complexity and potential degradation of massive single-context processing.
\ No newline at end of file
+¹ Models like Gemini 1.5 Pro show strong performance on many benchmarks, but [needle-in-a-haystack tests](https://cloud.google.com/blog/products/ai-machine-learning/the-needle-in-the-haystack-test-and-how-gemini-pro-solves-it) across million-token contexts indicate that performance can vary in multi-fact retrieval scenarios. This demonstrates how LangExtract's smaller context windows approach ensures consistently high quality across entire documents by avoiding the complexity and potential degradation of massive single-context processing.
diff --git a/docs/examples/medication_examples.md b/docs/examples/medication_examples.md
index 7fb27b11..d6474964 100644
--- a/docs/examples/medication_examples.md
+++ b/docs/examples/medication_examples.md
@@ -62,7 +62,7 @@ for entity in result.extractions:
     print(f"• {entity.extraction_class.capitalize()}: {entity.extraction_text}{position_info}")
 
 # Save and visualize the results
-lx.io.save_annotated_documents([result], output_name="medical_ner_extraction.jsonl")
+lx.io.save_annotated_documents([result], output_name="medical_ner_extraction.jsonl", output_dir=".")
 
 # Generate the interactive visualization
 html_content = lx.visualize("medical_ner_extraction.jsonl")
@@ -193,7 +193,11 @@ for med_name, extractions in medication_groups.items():
     print(f"  • {extraction.extraction_class.capitalize()}: {extraction.extraction_text}{position_info}")
 
 # Save and visualize the results
-lx.io.save_annotated_documents([result], output_name="medical_relationship_extraction.jsonl")
+lx.io.save_annotated_documents(
+    [result],
+    output_name="medical_relationship_extraction.jsonl",
+    output_dir="."
+) # Generate the interactive visualization html_content = lx.visualize("medical_relationship_extraction.jsonl") @@ -239,4 +243,4 @@ This example demonstrates how attributes enable efficient relationship extractio - **Relationship Extraction**: Groups related entities using attributes - **Position Tracking**: Records exact positions of extracted entities in the source text - **Structured Output**: Organizes information in a format suitable for healthcare applications -- **Interactive Visualization**: Generates HTML visualizations for exploring complex medical extractions with entity groupings and relationships clearly displayed \ No newline at end of file +- **Interactive Visualization**: Generates HTML visualizations for exploring complex medical extractions with entity groupings and relationships clearly displayed diff --git a/examples/ollama/.dockerignore b/examples/ollama/.dockerignore new file mode 100644 index 00000000..77374252 --- /dev/null +++ b/examples/ollama/.dockerignore @@ -0,0 +1,35 @@ +# Ignore Python cache +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python + +# Ignore version control +.git/ +.gitignore + +# Ignore OS files +.DS_Store +Thumbs.db + +# Ignore virtual environments +venv/ +env/ +.venv/ + +# Ignore IDE files +.vscode/ +.idea/ +*.swp +*.swo + +# Ignore test artifacts +.pytest_cache/ +.coverage +htmlcov/ + +# Ignore build artifacts +build/ +dist/ +*.egg-info/ diff --git a/examples/ollama/Dockerfile b/examples/ollama/Dockerfile new file mode 100644 index 00000000..48690a6a --- /dev/null +++ b/examples/ollama/Dockerfile @@ -0,0 +1,23 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM python:3.11-slim-bookworm + +WORKDIR /app + +RUN pip install langextract + +COPY quickstart.py . + +CMD ["python", "quickstart.py"] diff --git a/examples/ollama/README.md b/examples/ollama/README.md new file mode 100644 index 00000000..2fff8593 --- /dev/null +++ b/examples/ollama/README.md @@ -0,0 +1,32 @@ +# Ollama Examples + +This directory contains examples for using LangExtract with Ollama for local LLM inference. + +For setup instructions and documentation, see the [main README's Ollama section](../../README.md#using-local-llms-with-ollama). + +## Quick Reference + +**Local setup:** +```bash +ollama pull gemma2:2b +python quickstart.py +``` + +**Docker setup:** +```bash +docker-compose up +``` + +## Files + +- `quickstart.py` - Basic extraction example with configurable model +- `docker-compose.yml` - Production-ready Docker setup with health checks +- `Dockerfile` - Container definition for LangExtract + +## Model License + +Ollama models come with their own licenses. For example: +- Gemma models: [Gemma Terms of Use](https://ai.google.dev/gemma/terms) +- Llama models: [Meta Llama License](https://llama.meta.com/llama-downloads/) + +Please review the license for any model you use. 
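The Compose file that follows health-checks the Ollama server on its `/api/version` endpoint before starting the extraction container. The same endpoint doubles as a quick pre-flight check when running the example locally. A minimal sketch (not part of the files in this change; it assumes the `requests` package and reuses the `OLLAMA_HOST` fallback from `quickstart.py`):

```python
# Pre-flight check: confirm the Ollama server is reachable before running
# quickstart.py. Mirrors the /api/version healthcheck in docker-compose.yml.
import os

import requests

host = os.getenv("OLLAMA_HOST", "http://localhost:11434")
resp = requests.get(f"{host}/api/version", timeout=5)
resp.raise_for_status()
print(f"Ollama {resp.json().get('version', 'unknown')} is reachable at {host}")
```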
diff --git a/examples/ollama/docker-compose.yml b/examples/ollama/docker-compose.yml new file mode 100644 index 00000000..431765ea --- /dev/null +++ b/examples/ollama/docker-compose.yml @@ -0,0 +1,42 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +services: + ollama: + image: ollama/ollama:0.5.4 + ports: + - "127.0.0.1:11434:11434" # Bind only to localhost for security + volumes: + - ollama-data:/root/.ollama # Cross-platform support + command: serve + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:11434/api/version"] + interval: 5s + timeout: 3s + retries: 5 + start_period: 10s + + langextract: + build: . + depends_on: + ollama: + condition: service_healthy + environment: + - OLLAMA_HOST=http://ollama:11434 + volumes: + - .:/app + command: python quickstart.py + +volumes: + ollama-data: diff --git a/examples/ollama/quickstart.py b/examples/ollama/quickstart.py new file mode 100644 index 00000000..ed578412 --- /dev/null +++ b/examples/ollama/quickstart.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quick-start example for using Ollama with langextract.""" + +import argparse +import os + +import langextract as lx + + +def run_extraction(model_id="gemma2:2b", temperature=0.3): + """Run a simple extraction example using Ollama.""" + input_text = "Isaac Asimov was a prolific science fiction writer." + + prompt = "Extract the author's full name and their primary literary genre." + + examples = [ + lx.data.ExampleData( + text=( + "J.R.R. Tolkien was an English writer, best known for" + " high-fantasy." + ), + extractions=[ + lx.data.Extraction( + extraction_class="author_details", + # extraction_text includes full context with ellipsis for clarity + extraction_text="J.R.R. Tolkien was an English writer...", + attributes={ + "name": "J.R.R. 
Tolkien", + "genre": "high-fantasy", + }, + ) + ], + ) + ] + + result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=lx.inference.OllamaLanguageModel, + model_id=model_id, + model_url=os.getenv("OLLAMA_HOST", "http://localhost:11434"), + temperature=temperature, + fence_output=False, + use_schema_constraints=False, + ) + + return result + + +def main(): + """Main function to run the quick-start example.""" + parser = argparse.ArgumentParser(description="Run Ollama extraction example") + parser.add_argument( + "--model-id", + default=os.getenv("MODEL_ID", "gemma2:2b"), + help="Ollama model ID (default: gemma2:2b or MODEL_ID env var)", + ) + parser.add_argument( + "--temperature", + type=float, + default=float(os.getenv("TEMPERATURE", "0.3")), + help="Model temperature (default: 0.3 or TEMPERATURE env var)", + ) + args = parser.parse_args() + + print(f"🚀 Running Ollama quick-start example with {args.model_id}...") + print("-" * 50) + + try: + result = run_extraction( + model_id=args.model_id, temperature=args.temperature + ) + + for extraction in result.extractions: + print(f"Class: {extraction.extraction_class}") + print(f"Text: {extraction.extraction_text}") + print(f"Attributes: {extraction.attributes}") + + print("\n✅ SUCCESS! Ollama is working with langextract") + return True + + except ConnectionError as e: + print(f"\nConnectionError: {e}") + print("Make sure Ollama is running: 'ollama serve'") + return False + except Exception as e: + print(f"\nError: {type(e).__name__}: {e}") + return False + + +if __name__ == "__main__": + success = main() + exit(0 if success else 1) diff --git a/examples/vertex_ai_example.py b/examples/vertex_ai_example.py new file mode 100644 index 00000000..139ef495 --- /dev/null +++ b/examples/vertex_ai_example.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +Example demonstrating Gemini Vertex AI integration with langextract. + +This example shows how to use the new GeminiVertexLanguageModel for +extraction tasks using Google Cloud Vertex AI instead of API keys. +""" + +import langextract as lx + +def main(): + # Example text for extraction + text = """ + Patient was prescribed Lisinopril 10mg daily for hypertension and + Metformin 500mg twice daily for diabetes management. The patient + should take Lisinopril in the morning and Metformin with meals. + """ + + # Define examples for medication extraction + examples = [ + lx.data.ExampleData( + text="Patient takes Aspirin 100mg daily for heart health.", + extractions=[ + lx.data.Extraction( + extraction_class="medication", + extraction_text="Aspirin", + attributes={"medication_group": "Aspirin"} + ), + lx.data.Extraction( + extraction_class="dosage", + extraction_text="100mg", + attributes={"medication_group": "Aspirin"} + ), + lx.data.Extraction( + extraction_class="frequency", + extraction_text="daily", + attributes={"medication_group": "Aspirin"} + ), + lx.data.Extraction( + extraction_class="condition", + extraction_text="heart health", + attributes={"medication_group": "Aspirin"} + ), + ], + ) + ] + + prompt = """ + Extract medication information including: + - medication name + - dosage + - frequency + - condition being treated + + Group related information using the medication_group attribute.
+ """ + + print("=== Vertex AI Example ===") + print("Using GeminiVertexLanguageModel with project and location") + + try: + # Example using Vertex AI (replace with your project details) + result = lx.extract( + text_or_documents=text, + prompt_description=prompt, + examples=examples, + model_id="gemini-2.5-flash", + project="your-project-id", # Replace with your GCP project ID + location="global", # or your preferred region like "us-central1" + thinking_budget=0, # Set to higher values for more reasoning + language_model_type=lx.inference.GeminiVertexLanguageModel, + temperature=0.1, + use_schema_constraints=True, + ) + + print(f"Extracted {len(result.extractions)} entities:") + for extraction in result.extractions: + print(f" - {extraction.extraction_class}: {extraction.extraction_text}") + if extraction.attributes: + print(f" Attributes: {extraction.attributes}") + + except ValueError as e: + print(f"Configuration error: {e}") + print("\nTo use this example:") + print("1. Replace 'your-project-id' with your actual GCP project ID") + print("2. Ensure you have Vertex AI enabled and proper authentication") + print("3. Make sure you have the required permissions for Vertex AI") + + print("\n=== API Key Example (for comparison) ===") + print("Using standard GeminiLanguageModel with API key") + + try: + # Example using API key (traditional approach) + result = lx.extract( + text_or_documents=text, + prompt_description=prompt, + examples=examples, + model_id="gemini-2.5-flash", + api_key="your-api-key", # Replace with your API key or set LANGEXTRACT_API_KEY + language_model_type=lx.inference.GeminiLanguageModel, + temperature=0.1, + use_schema_constraints=True, + ) + + print(f"Extracted {len(result.extractions)} entities:") + for extraction in result.extractions: + print(f" - {extraction.extraction_class}: {extraction.extraction_text}") + if extraction.attributes: + print(f" Attributes: {extraction.attributes}") + + except ValueError as e: + print(f"Configuration error: {e}") + print("\nTo use this example:") + print("1. Replace 'your-api-key' with your actual Gemini API key") + print("2. Or set the LANGEXTRACT_API_KEY environment variable") + + print("\n=== Advanced Vertex AI Features ===") + print("Demonstrating thinking budget and safety settings") + + # Example with advanced Vertex AI features + try: + # Advanced configuration with thinking budget and safety settings + advanced_params = { + "safety_settings": [ + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, + ] + } + + result = lx.extract( + text_or_documents=text, + prompt_description=prompt, + examples=examples, + model_id="gemini-2.5-flash", + project="your-project-id", # Replace with your GCP project ID + location="global", + thinking_budget=1000, # Allow more reasoning steps + language_model_type=lx.inference.GeminiVertexLanguageModel, + temperature=0.1, + language_model_params=advanced_params, + use_schema_constraints=True, + ) + + print(f"Advanced extraction found {len(result.extractions)} entities") + + except ValueError as e: + print(f"Advanced configuration error: {e}") + +if __name__ == "__main__": + main() diff --git a/exceptions.py b/exceptions.py new file mode 100644 index 00000000..0199da56 --- /dev/null +++ b/exceptions.py @@ -0,0 +1,30 @@ +# Copyright 2025 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base exceptions for LangExtract. + +This module defines the base exception class that all LangExtract exceptions +inherit from. Individual modules define their own specific exceptions. +""" + +__all__ = ["LangExtractError"] + + +class LangExtractError(Exception): + """Base exception for all LangExtract errors. + + All exceptions raised by LangExtract should inherit from this class. + This allows users to catch all LangExtract-specific errors with a single + except clause. + """ diff --git a/kokoro/presubmit.cfg b/kokoro/presubmit.cfg index 6d821424..746c6d14 100644 --- a/kokoro/presubmit.cfg +++ b/kokoro/presubmit.cfg @@ -28,4 +28,4 @@ container_properties { xunit_test_results { target_name: "pytest_results" result_xml_path: "git/repo/pytest_results/test.xml" -} \ No newline at end of file +} diff --git a/kokoro/test.sh b/kokoro/test.sh index ba75ace2..87134817 100644 --- a/kokoro/test.sh +++ b/kokoro/test.sh @@ -103,4 +103,4 @@ deactivate echo "=========================================" echo "Kokoro test script for langextract finished successfully." -echo "=========================================" \ No newline at end of file +echo "=========================================" diff --git a/langextract/__init__.py b/langextract/__init__.py index 817e04f2..20f80816 100644 --- a/langextract/__init__.py +++ b/langextract/__init__.py @@ -18,13 +18,14 @@ from collections.abc import Iterable, Sequence import os -from typing import Any, Type, TypeVar, cast +from typing import Any, cast, Type, TypeVar import warnings import dotenv from langextract import annotation from langextract import data +from langextract import exceptions from langextract import inference from langextract import io from langextract import prompting @@ -32,6 +33,19 @@ from langextract import schema from langextract import visualization +__all__ = [ + "extract", + "visualize", + "annotation", + "data", + "exceptions", + "inference", + "io", + "prompting", + "resolver", + "schema", + "visualization", +] LanguageModelT = TypeVar("LanguageModelT", bound=inference.BaseLanguageModel) @@ -48,6 +62,9 @@ def extract( examples: Sequence[data.ExampleData] | None = None, model_id: str = "gemini-2.5-flash", api_key: str | None = None, + project: str | None = None, + location: str = "global", + thinking_budget: int = 0, language_model_type: Type[LanguageModelT] = inference.GeminiLanguageModel, format_type: data.FormatType = data.FormatType.JSON, max_char_buffer: int = 1000, @@ -76,15 +93,20 @@ def extract( of Document objects. prompt_description: Instructions for what to extract from the text. examples: List of ExampleData objects to guide the extraction. - api_key: API key for Gemini or other LLM services (can also use - environment variable LANGEXTRACT_API_KEY). Cost considerations: Most - APIs charge by token volume. Smaller max_char_buffer values increase the - number of API calls, while extraction_passes > 1 reprocesses tokens - multiple times. 
Note that max_workers improves processing speed without - additional token costs. Refer to your API provider's pricing details and - monitor usage with small test runs to estimate costs. model_id: The model ID to use for extraction. + api_key: API key for Gemini or other LLM services (can also use + environment variable LANGEXTRACT_API_KEY). Used for standard Gemini API + access. Mutually exclusive with project parameter. + project: Google Cloud project ID for Vertex AI access. Used with + GeminiVertexLanguageModel. Mutually exclusive with api_key parameter. + location: Google Cloud location/region for Vertex AI. Defaults to "global". + Only used when project is specified. + thinking_budget: Thinking budget for reasoning models (0 = no thinking). + Higher values allow more reasoning steps. Only supported by Vertex AI + models. Defaults to 0. language_model_type: The type of language model to use for inference. + Use GeminiLanguageModel for API key access or GeminiVertexLanguageModel + for Vertex AI access. format_type: The format type for the output (JSON or YAML). max_char_buffer: Max number of characters for inference. temperature: The sampling temperature for generation. Higher values (e.g., @@ -133,7 +155,8 @@ def extract( Raises: ValueError: If examples is None or empty. - ValueError: If no API key is provided or found in environment variables. + ValueError: If no API key or project is provided for cloud-hosted models. + ValueError: If both api_key and project are provided (mutually exclusive). requests.RequestException: If URL download fails. """ if not examples: @@ -166,37 +189,67 @@ def extract( ) prompt_template.examples.extend(examples) + # Validate authentication parameters + if api_key and project: + raise ValueError( + "Both api_key and project parameters are provided. These are mutually " + "exclusive. Use api_key for standard Gemini API access or project for " + "Vertex AI access." + ) + # Generate schema constraints if enabled model_schema = None schema_constraint = None # TODO: Unify schema generation. 
- if ( - use_schema_constraints - and language_model_type == inference.GeminiLanguageModel + if use_schema_constraints and language_model_type in ( + inference.GeminiLanguageModel, + inference.GeminiVertexLanguageModel, ): model_schema = schema.GeminiSchema.from_examples(prompt_template.examples) - if not api_key: - api_key = os.environ.get("LANGEXTRACT_API_KEY") + # Handle authentication for different model types + if language_model_type == inference.GeminiVertexLanguageModel: + # Vertex AI authentication + if not project: + raise ValueError( + "Project ID must be provided for Vertex AI models via the project " + "parameter" + ) + + base_lm_kwargs: dict[str, Any] = { + "project": project, + "location": location, + "model_id": model_id, + "gemini_schema": model_schema, + "format_type": format_type, + "temperature": temperature, + "thinking_budget": thinking_budget, + "constraint": schema_constraint, + "max_workers": max_workers, + } + else: + # Standard API key authentication + if not api_key: + api_key = os.environ.get("LANGEXTRACT_API_KEY") - # Currently only Gemini is supported + # Currently only Gemini is supported for API key access if not api_key and language_model_type == inference.GeminiLanguageModel: raise ValueError( "API key must be provided for cloud-hosted models via the api_key" " parameter or the LANGEXTRACT_API_KEY environment variable" ) - base_lm_kwargs: dict[str, Any] = { - "api_key": api_key, - "model_id": model_id, - "gemini_schema": model_schema, - "format_type": format_type, - "temperature": temperature, - "model_url": model_url, - "constraint": schema_constraint, - "max_workers": max_workers, - } + base_lm_kwargs: dict[str, Any] = { + "api_key": api_key, + "model_id": model_id, + "gemini_schema": model_schema, + "format_type": format_type, + "temperature": temperature, + "model_url": model_url, + "constraint": schema_constraint, + "max_workers": max_workers, + } # Merge user-provided params which have precedence over defaults. base_lm_kwargs.update(language_model_params or {}) diff --git a/langextract/annotation.py b/langextract/annotation.py index fe3b5a54..a370be9e 100644 --- a/langextract/annotation.py +++ b/langextract/annotation.py @@ -31,6 +31,7 @@ from langextract import chunking from langextract import data +from langextract import exceptions from langextract import inference from langextract import progress from langextract import prompting @@ -39,7 +40,7 @@ ATTRIBUTE_SUFFIX = "_attributes" -class DocumentRepeatError(Exception): +class DocumentRepeatError(exceptions.LangExtractError): """Exception raised when identical document ids are present.""" diff --git a/langextract/chunking.py b/langextract/chunking.py index 3625d7a1..2663ed85 100644 --- a/langextract/chunking.py +++ b/langextract/chunking.py @@ -28,10 +28,11 @@ import more_itertools from langextract import data +from langextract import exceptions from langextract import tokenizer -class TokenUtilError(Exception): +class TokenUtilError(exceptions.LangExtractError): """Error raised when token_util returns unexpected values.""" diff --git a/langextract/exceptions.py b/langextract/exceptions.py new file mode 100644 index 00000000..b3103ab7 --- /dev/null +++ b/langextract/exceptions.py @@ -0,0 +1,26 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base exceptions for LangExtract.""" + +__all__ = ["LangExtractError"] + + +class LangExtractError(Exception): + """Base exception for all LangExtract errors. + + All exceptions raised by LangExtract should inherit from this class. + This allows users to catch all LangExtract-specific errors with a single + except clause. + """ diff --git a/langextract/inference.py b/langextract/inference.py index 822661fd..690660e4 100644 --- a/langextract/inference.py +++ b/langextract/inference.py @@ -24,17 +24,16 @@ from typing import Any from google import genai -import langfun as lf +from google.genai import types +import openai import requests from typing_extensions import override import yaml - - from langextract import data +from langextract import exceptions from langextract import schema - _OLLAMA_DEFAULT_MODEL_URL = 'http://localhost:11434' @@ -52,7 +51,7 @@ def __str__(self) -> str: return f'Score: {self.score:.2f}\nOutput:\n{formatted_lines}' -class InferenceOutputError(Exception): +class InferenceOutputError(exceptions.LangExtractError): """Exception raised when no scored outputs are available from the language model.""" def __init__(self, message: str): @@ -99,49 +98,6 @@ class InferenceType(enum.Enum): MULTIPROCESS = 'multiprocess' -# TODO: Add support for llm options. -@dataclasses.dataclass(init=False) -class LangFunLanguageModel(BaseLanguageModel): - """Language model inference class using LangFun language class. - - See https://github.com/google/langfun for more details on LangFun. 
- """ - - _lm: lf.core.language_model.LanguageModel # underlying LangFun model - _constraint: schema.Constraint = dataclasses.field( - default_factory=schema.Constraint, repr=False, compare=False - ) - _extra_kwargs: dict[str, Any] = dataclasses.field( - default_factory=dict, repr=False, compare=False - ) - - def __init__( - self, - language_model: lf.core.language_model.LanguageModel, - constraint: schema.Constraint = schema.Constraint(), - **kwargs, - ) -> None: - self._lm = language_model - self._constraint = constraint - - # Preserve any unused kwargs for debugging / future use - self._extra_kwargs = kwargs or {} - super().__init__(constraint=constraint) - - @override - def infer( - self, batch_prompts: Sequence[str], **kwargs - ) -> Iterator[Sequence[ScoredOutput]]: - responses = self._lm.sample(prompts=batch_prompts) - for a_response in responses: - for sample in a_response.samples: - yield [ - ScoredOutput( - score=sample.response.score, output=sample.response.text - ) - ] - - @dataclasses.dataclass(init=False) class OllamaLanguageModel(BaseLanguageModel): """Language model inference class using Ollama based host.""" @@ -158,13 +114,13 @@ class OllamaLanguageModel(BaseLanguageModel): def __init__( self, - model: str, + model_id: str, model_url: str = _OLLAMA_DEFAULT_MODEL_URL, structured_output_format: str = 'json', constraint: schema.Constraint = schema.Constraint(), **kwargs, ) -> None: - self._model = model + self._model = model_id self._model_url = model_url self._structured_output_format = structured_output_format self._constraint = constraint @@ -429,7 +385,344 @@ def infer( yield [result] def parse_output(self, output: str) -> Any: - """Parses Gemini output as JSON or YAML.""" + """Parses Gemini output as JSON or YAML. + + Note: This expects raw JSON/YAML without code fences. + Code fence extraction is handled by resolver.py. + """ + try: + if self.format_type == data.FormatType.JSON: + return json.loads(output) + else: + return yaml.safe_load(output) + except Exception as e: + raise ValueError( + f'Failed to parse output as {self.format_type.name}: {str(e)}' + ) from e + + +@dataclasses.dataclass(init=False) +class GeminiVertexLanguageModel(BaseLanguageModel): + """Language model inference using Google's Gemini Vertex AI with structured output.""" + + model_id: str = 'gemini-2.5-flash' + project: str | None = None + location: str = 'global' + gemini_schema: schema.GeminiSchema | None = None + format_type: data.FormatType = data.FormatType.JSON + temperature: float = 0.0 + thinking_budget: int = 0 + max_workers: int = 10 + _extra_kwargs: dict[str, Any] = dataclasses.field( + default_factory=dict, repr=False, compare=False + ) + + def __init__( + self, + model_id: str = 'gemini-2.5-flash', + project: str | None = None, + location: str = 'global', + gemini_schema: schema.GeminiSchema | None = None, + format_type: data.FormatType = data.FormatType.JSON, + temperature: float = 0.0, + thinking_budget: int = 0, + max_workers: int = 10, + **kwargs, + ) -> None: + """Initialize the Gemini Vertex AI language model. + + Args: + model_id: The Gemini model ID to use. + project: Google Cloud project ID for Vertex AI. + location: Google Cloud location/region for Vertex AI. + gemini_schema: Optional schema for structured output. + format_type: Output format (JSON or YAML). + temperature: Sampling temperature. + thinking_budget: Thinking budget for reasoning (0 = no thinking). + max_workers: Maximum number of parallel API calls. 
+ **kwargs: Ignored extra parameters so callers can pass a superset of + arguments shared across back-ends without raising ``TypeError``. + """ + self.model_id = model_id + self.project = project + self.location = location + self.gemini_schema = gemini_schema + self.format_type = format_type + self.temperature = temperature + self.thinking_budget = thinking_budget + self.max_workers = max_workers + self._extra_kwargs = kwargs or {} + + if not self.project: + raise ValueError('Project ID not provided for Vertex AI.') + + self._client = genai.Client( + vertexai=True, + project=self.project, + location=self.location, + ) + + super().__init__( + constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE) + ) + + def _process_single_prompt(self, prompt: str, config: dict) -> ScoredOutput: + """Process a single prompt and return a ScoredOutput.""" + try: + if self.gemini_schema: + response_schema = self.gemini_schema.schema_dict + mime_type = ( + 'application/json' + if self.format_type == data.FormatType.JSON + else 'application/yaml' + ) + config['response_mime_type'] = mime_type + config['response_schema'] = response_schema + + # Add thinking config if thinking_budget is specified + if self.thinking_budget > 0: + config['thinking_config'] = types.ThinkingConfig( + thinking_budget=self.thinking_budget + ) + + response = self._client.models.generate_content( + model=self.model_id, contents=prompt, config=config + ) + + return ScoredOutput(score=1.0, output=response.text) + + except Exception as e: + raise InferenceOutputError(f'Gemini Vertex AI error: {str(e)}') from e + + def infer( + self, batch_prompts: Sequence[str], **kwargs + ) -> Iterator[Sequence[ScoredOutput]]: + """Runs inference on a list of prompts via Gemini Vertex AI. + + Args: + batch_prompts: A list of string prompts. + **kwargs: Additional generation params (temperature, top_p, top_k, etc.) + + Yields: + Lists of ScoredOutputs. + """ + config = { + 'temperature': kwargs.get('temperature', self.temperature), + } + if 'max_output_tokens' in kwargs: + config['max_output_tokens'] = kwargs['max_output_tokens'] + if 'top_p' in kwargs: + config['top_p'] = kwargs['top_p'] + if 'top_k' in kwargs: + config['top_k'] = kwargs['top_k'] + if 'seed' in kwargs: + config['seed'] = kwargs['seed'] + + # Add safety settings if provided + if 'safety_settings' in kwargs: + config['safety_settings'] = kwargs['safety_settings'] + + # Use parallel processing for batches larger than 1 + if len(batch_prompts) > 1 and self.max_workers > 1: + with concurrent.futures.ThreadPoolExecutor( + max_workers=min(self.max_workers, len(batch_prompts)) + ) as executor: + future_to_index = { + executor.submit( + self._process_single_prompt, prompt, config.copy() + ): i + for i, prompt in enumerate(batch_prompts) + } + + results: list[ScoredOutput | None] = [None] * len(batch_prompts) + for future in concurrent.futures.as_completed(future_to_index): + index = future_to_index[future] + try: + results[index] = future.result() + except Exception as e: + raise InferenceOutputError( + f'Parallel inference error: {str(e)}' + ) from e + + for result in results: + if result is None: + raise InferenceOutputError('Failed to process one or more prompts') + yield [result] + else: + # Sequential processing for single prompt or worker + for prompt in batch_prompts: + result = self._process_single_prompt(prompt, config.copy()) + yield [result] + + def parse_output(self, output: str) -> Any: + """Parses Gemini output as JSON or YAML. 
+ + Note: This expects raw JSON/YAML without code fences. + Code fence extraction is handled by resolver.py. + """ + try: + if self.format_type == data.FormatType.JSON: + return json.loads(output) + else: + return yaml.safe_load(output) + except Exception as e: + raise ValueError( + f'Failed to parse output as {self.format_type.name}: {str(e)}' + ) from e + + +@dataclasses.dataclass(init=False) +class OpenAILanguageModel(BaseLanguageModel): + """Language model inference using OpenAI's API with structured output.""" + + model_id: str = 'gpt-4o-mini' + api_key: str | None = None + organization: str | None = None + format_type: data.FormatType = data.FormatType.JSON + temperature: float = 0.0 + max_workers: int = 10 + _client: openai.OpenAI | None = dataclasses.field( + default=None, repr=False, compare=False + ) + _extra_kwargs: dict[str, Any] = dataclasses.field( + default_factory=dict, repr=False, compare=False + ) + + def __init__( + self, + model_id: str = 'gpt-4o-mini', + api_key: str | None = None, + organization: str | None = None, + format_type: data.FormatType = data.FormatType.JSON, + temperature: float = 0.0, + max_workers: int = 10, + **kwargs, + ) -> None: + """Initialize the OpenAI language model. + + Args: + model_id: The OpenAI model ID to use (e.g., 'gpt-4o-mini', 'gpt-4o'). + api_key: API key for OpenAI service. + organization: Optional OpenAI organization ID. + format_type: Output format (JSON or YAML). + temperature: Sampling temperature. + max_workers: Maximum number of parallel API calls. + **kwargs: Ignored extra parameters so callers can pass a superset of + arguments shared across back-ends without raising ``TypeError``. + """ + self.model_id = model_id + self.api_key = api_key + self.organization = organization + self.format_type = format_type + self.temperature = temperature + self.max_workers = max_workers + self._extra_kwargs = kwargs or {} + + if not self.api_key: + raise ValueError('API key not provided.') + + # Initialize the OpenAI client + self._client = openai.OpenAI( + api_key=self.api_key, organization=self.organization + ) + + super().__init__( + constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE) + ) + + def _process_single_prompt(self, prompt: str, config: dict) -> ScoredOutput: + """Process a single prompt and return a ScoredOutput.""" + try: + # Prepare the system message for structured output + system_message = '' + if self.format_type == data.FormatType.JSON: + system_message = ( + 'You are a helpful assistant that responds in JSON format.' + ) + elif self.format_type == data.FormatType.YAML: + system_message = ( + 'You are a helpful assistant that responds in YAML format.' + ) + + # Create the chat completion using the v1.x client API + response = self._client.chat.completions.create( + model=self.model_id, + messages=[ + {'role': 'system', 'content': system_message}, + {'role': 'user', 'content': prompt}, + ], + temperature=config.get('temperature', self.temperature), + max_tokens=config.get('max_output_tokens'), + top_p=config.get('top_p'), + n=1, + ) + + # Extract the response text using the v1.x response format + output_text = response.choices[0].message.content + + return ScoredOutput(score=1.0, output=output_text) + + except Exception as e: + raise InferenceOutputError(f'OpenAI API error: {str(e)}') from e + + def infer( + self, batch_prompts: Sequence[str], **kwargs + ) -> Iterator[Sequence[ScoredOutput]]: + """Runs inference on a list of prompts via OpenAI's API. + + Args: + batch_prompts: A list of string prompts. 
+ **kwargs: Additional generation params (temperature, top_p, etc.) + + Yields: + Lists of ScoredOutputs. + """ + config = { + 'temperature': kwargs.get('temperature', self.temperature), + } + if 'max_output_tokens' in kwargs: + config['max_output_tokens'] = kwargs['max_output_tokens'] + if 'top_p' in kwargs: + config['top_p'] = kwargs['top_p'] + + # Use parallel processing for batches larger than 1 + if len(batch_prompts) > 1 and self.max_workers > 1: + with concurrent.futures.ThreadPoolExecutor( + max_workers=min(self.max_workers, len(batch_prompts)) + ) as executor: + future_to_index = { + executor.submit( + self._process_single_prompt, prompt, config.copy() + ): i + for i, prompt in enumerate(batch_prompts) + } + + results: list[ScoredOutput | None] = [None] * len(batch_prompts) + for future in concurrent.futures.as_completed(future_to_index): + index = future_to_index[future] + try: + results[index] = future.result() + except Exception as e: + raise InferenceOutputError( + f'Parallel inference error: {str(e)}' + ) from e + + for result in results: + if result is None: + raise InferenceOutputError('Failed to process one or more prompts') + yield [result] + else: + # Sequential processing for single prompt or worker + for prompt in batch_prompts: + result = self._process_single_prompt(prompt, config.copy()) + yield [result] + + def parse_output(self, output: str) -> Any: + """Parses OpenAI output as JSON or YAML. + + Note: This expects raw JSON/YAML without code fences. + Code fence extraction is handled by resolver.py. + """ try: if self.format_type == data.FormatType.JSON: return json.loads(output) diff --git a/langextract/io.py b/langextract/io.py index 7f94a193..59dead7a 100644 --- a/langextract/io.py +++ b/langextract/io.py @@ -18,23 +18,21 @@ import dataclasses import json import os +import pathlib from typing import Any, Iterator import pandas as pd import requests -import os -import pathlib -import os -import pathlib from langextract import data from langextract import data_lib +from langextract import exceptions from langextract import progress DEFAULT_TIMEOUT_SECONDS = 30 -class InvalidDatasetError(Exception): +class InvalidDatasetError(exceptions.LangExtractError): """Error raised when Dataset is empty or invalid.""" @@ -83,7 +81,7 @@ def load(self, delimiter: str = ',') -> Iterator[data.Document]: def save_annotated_documents( annotated_documents: Iterator[data.AnnotatedDocument], - output_dir: pathlib.Path | None = None, + output_dir: pathlib.Path | str | None = None, output_name: str = 'data.jsonl', show_progress: bool = True, ) -> None: @@ -92,7 +90,7 @@ def save_annotated_documents( Args: annotated_documents: Iterator over AnnotatedDocument objects to save. output_dir: The directory to which the JSONL file should be written. - Defaults to 'test_output/' if None. + Can be a Path object or a string. Defaults to 'test_output/' if None. output_name: File name for the JSONL file. show_progress: Whether to show a progress bar during saving. 
@@ -102,6 +100,8 @@ def save_annotated_documents( """ if output_dir is None: output_dir = pathlib.Path('test_output') + else: + output_dir = pathlib.Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/langextract/progress.py b/langextract/progress.py index a79b9126..41c4f3b8 100644 --- a/langextract/progress.py +++ b/langextract/progress.py @@ -16,6 +16,7 @@ from typing import Any import urllib.parse + import tqdm # ANSI color codes for terminal output diff --git a/langextract/prompting.py b/langextract/prompting.py index 5d6623b1..4484273b 100644 --- a/langextract/prompting.py +++ b/langextract/prompting.py @@ -16,17 +16,18 @@ import dataclasses import json +import os +import pathlib import pydantic import yaml -import os -import pathlib from langextract import data +from langextract import exceptions from langextract import schema -class PromptBuilderError(Exception): +class PromptBuilderError(exceptions.LangExtractError): """Failure to build prompt.""" diff --git a/langextract/resolver.py b/langextract/resolver.py index e9085f16..c6496b82 100644 --- a/langextract/resolver.py +++ b/langextract/resolver.py @@ -31,6 +31,7 @@ import yaml from langextract import data +from langextract import exceptions from langextract import schema from langextract import tokenizer @@ -151,7 +152,7 @@ def align( ExtractionValueType = str | int | float | dict | list | None -class ResolverParsingError(Exception): +class ResolverParsingError(exceptions.LangExtractError): """Error raised when content cannot be parsed as the given format.""" diff --git a/langextract/schema.py b/langextract/schema.py index 2c02baac..dd553bdc 100644 --- a/langextract/schema.py +++ b/langextract/schema.py @@ -22,7 +22,6 @@ import enum from typing import Any - from langextract import data diff --git a/langextract/tokenizer.py b/langextract/tokenizer.py index f4036f36..5028fb0f 100644 --- a/langextract/tokenizer.py +++ b/langextract/tokenizer.py @@ -30,8 +30,10 @@ from absl import logging +from langextract import exceptions -class BaseTokenizerError(Exception): + +class BaseTokenizerError(exceptions.LangExtractError): """Base class for all tokenizer-related errors.""" diff --git a/langextract/visualization.py b/langextract/visualization.py index 513cfa58..b382961a 100644 --- a/langextract/visualization.py +++ b/langextract/visualization.py @@ -28,10 +28,10 @@ import html import itertools import json -import textwrap - import os import pathlib +import textwrap + from langextract import data as _data from langextract import io as _io @@ -119,9 +119,7 @@ padding: 8px 10px; margin-top: 8px; font-size: 13px; } .lx-current-highlight { - text-decoration: underline; - text-decoration-color: #ff4444; - text-decoration-thickness: 3px; + border-bottom: 4px solid #ff4444; font-weight: bold; animation: lx-pulse 1s ease-in-out; } @@ -130,9 +128,9 @@ 50% { text-decoration-color: #ff0000; } 100% { text-decoration-color: #ff4444; } } - .lx-legend { - font-size: 12px; margin-bottom: 8px; - padding-bottom: 8px; border-bottom: 1px solid #e0e0e0; + .lx-legend { + font-size: 12px; margin-bottom: 8px; + padding-bottom: 8px; border-bottom: 1px solid #e0e0e0; } .lx-label { display: inline-block; @@ -456,12 +454,12 @@ def _extraction_sort_key(extraction):
- Entity 1/{len(extractions)} | + Entity 1/{len(extractions)} | Pos {pos_info_str}
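A pattern running through the library diffs above: DocumentRepeatError, TokenUtilError, InvalidDatasetError, PromptBuilderError, ResolverParsingError, and BaseTokenizerError (plus InferenceOutputError in inference.py) are all rebased onto the new `exceptions.LangExtractError` root, which `langextract/__init__.py` now re-exports. The practical payoff is that callers can catch every library-specific failure with a single clause. A minimal sketch of that usage (the input text, prompt, and example data here are placeholders, not taken from this change):

```python
# With the unified hierarchy, one except clause covers every
# LangExtract-specific failure (resolver parsing, tokenization, inference
# output, ...), while unrelated bugs still propagate unchanged.
import langextract as lx

examples = [
    lx.data.ExampleData(
        text="Patient was given 250 mg IV Cefazolin TID for one week.",
        extractions=[
            lx.data.Extraction(
                extraction_class="medication", extraction_text="Cefazolin"
            )
        ],
    )
]

try:
    result = lx.extract(
        text_or_documents="Patient takes Aspirin 100mg daily.",
        prompt_description="Extract the medication name.",
        examples=examples,
    )
    print(f"Extracted {len(result.extractions)} entities")
except lx.exceptions.LangExtractError as e:
    print(f"LangExtract failed: {type(e).__name__}: {e}")
```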
diff --git a/pyproject.toml b/pyproject.toml index a2dfd19c..7be5c5f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ build-backend = "setuptools.build_meta" [project] name = "langextract" -version = "0.1.0" +version = "1.0.4" description = "LangExtract: A library for extracting structured data from language models" readme = "README.md" requires-python = ">=3.10" @@ -32,27 +32,32 @@ dependencies = [ "async_timeout>=4.0.0", "exceptiongroup>=1.1.0", "google-genai>=0.1.0", - "langfun>=0.1.0", "ml-collections>=0.1.0", "more-itertools>=8.0.0", "numpy>=1.20.0", - "openai>=0.27.0", + "openai>=1.50.0", "pandas>=1.3.0", "pydantic>=1.8.0", "python-dotenv>=0.19.0", - "python-magic>=0.4.27", + "PyYAML>=6.0", "requests>=2.25.0", "tqdm>=4.64.0", "typing-extensions>=4.0.0" ] +[project.urls] +"Homepage" = "https://github.com/google/langextract" +"Repository" = "https://github.com/google/langextract" +"Documentation" = "https://github.com/google/langextract/blob/main/README.md" +"Bug Tracker" = "https://github.com/google/langextract/issues" + [project.optional-dependencies] dev = [ - "black>=23.7.0", - "pylint>=2.17.5", - "pytest>=7.4.0", + "pyink~=24.3.0", + "isort>=5.13.0", + "pylint>=3.0.0", "pytype>=2024.10.11", - "tox>=4.0.0", + "tox>=4.0.0" ] test = [ "pytest>=7.4.0", @@ -72,6 +77,28 @@ include-package-data = false "*.svg", ] -[tool.pytest] +[tool.pytest.ini_options] testpaths = ["tests"] -python_files = "*_test.py" \ No newline at end of file +python_files = "*_test.py" +python_classes = "Test*" +python_functions = "test_*" +# Show extra test summary info +addopts = "-ra" +markers = [ + "live_api: marks tests as requiring live API access", +] + +[tool.pyink] +# Configuration for Google's style guide +line-length = 80 +unstable = true +pyink-indentation = 2 +pyink-use-majority-quotes = true + +[tool.isort] +# Configuration for Google's style guide +profile = "google" +line_length = 80 +force_sort_within_sections = true +# Allow multiple imports on one line for these modules +single_line_exclusions = ["typing", "typing_extensions", "collections.abc"] diff --git a/test_vertex_integration.py b/test_vertex_integration.py new file mode 100644 index 00000000..b8fb6d72 --- /dev/null +++ b/test_vertex_integration.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +""" +Simple test to verify Vertex AI integration works correctly. 
+""" + +import sys +import os + +# Add the langextract directory to the path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '.')) + +def test_imports(): + """Test that all imports work correctly.""" + print("Testing imports...") + + try: + import langextract as lx + print("✓ langextract imported successfully") + except ImportError as e: + print(f"✗ Failed to import langextract: {e}") + return False + + try: + from langextract.inference import GeminiVertexLanguageModel + print("✓ GeminiVertexLanguageModel imported successfully") + except ImportError as e: + print(f"✗ Failed to import GeminiVertexLanguageModel: {e}") + return False + + return True + +def test_vertex_model_creation(): + """Test that the Vertex AI model can be created with proper parameters.""" + print("\nTesting Vertex AI model creation...") + + try: + from langextract.inference import GeminiVertexLanguageModel + + # Test with minimal required parameters + model = GeminiVertexLanguageModel( + project="test-project", + location="global", + model_id="gemini-2.5-flash" + ) + print("✓ GeminiVertexLanguageModel created successfully with minimal params") + + # Test with all parameters + model_full = GeminiVertexLanguageModel( + project="test-project", + location="us-central1", + model_id="gemini-2.5-flash", + temperature=0.5, + thinking_budget=100, + max_workers=5 + ) + print("✓ GeminiVertexLanguageModel created successfully with full params") + + # Verify attributes are set correctly + assert model_full.project == "test-project" + assert model_full.location == "us-central1" + assert model_full.thinking_budget == 100 + assert model_full.temperature == 0.5 + print("✓ Model attributes set correctly") + + return True + + except Exception as e: + print(f"✗ Failed to create GeminiVertexLanguageModel: {e}") + return False + +def test_vertex_model_validation(): + """Test that the Vertex AI model validates parameters correctly.""" + print("\nTesting parameter validation...") + + try: + from langextract.inference import GeminiVertexLanguageModel + + # Test missing project parameter + try: + model = GeminiVertexLanguageModel(project=None) + print("✗ Should have failed with missing project") + return False + except ValueError as e: + if "Project ID not provided" in str(e): + print("✓ Correctly validates missing project parameter") + else: + print(f"✗ Wrong error message: {e}") + return False + + return True + + except Exception as e: + print(f"✗ Unexpected error during validation test: {e}") + return False + +def test_extract_function_parameters(): + """Test that the extract function accepts new parameters.""" + print("\nTesting extract function parameters...") + + try: + import langextract as lx + + # Create minimal example data + examples = [ + lx.data.ExampleData( + text="Test text", + extractions=[ + lx.data.Extraction( + extraction_class="test", + extraction_text="test" + ) + ] + ) + ] + + # Test that the function accepts new parameters without error + # (We won't actually call it since we don't have valid credentials) + try: + # This should fail with authentication error, not parameter error + result = lx.extract( + text_or_documents="Test text", + prompt_description="Test prompt", + examples=examples, + project="test-project", + location="global", + thinking_budget=0, + language_model_type=lx.inference.GeminiVertexLanguageModel + ) + except ValueError as e: + # We expect this to fail due to authentication, not parameter issues + if "Project ID" in str(e) or "authentication" in str(e).lower():
print("✓ Extract function accepts new parameters correctly") + return True + else: + print(f"✗ Unexpected parameter error: {e}") + return False + except Exception as e: + # Any other error suggests the parameters were accepted + print("✓ Extract function accepts new parameters correctly") + return True + + return True + + except Exception as e: + print(f"✗ Failed to test extract function: {e}") + return False + +def main(): + """Run all tests.""" + print("=== Vertex AI Integration Test ===\n") + + tests = [ + test_imports, + test_vertex_model_creation, + test_vertex_model_validation, + test_extract_function_parameters + ] + + passed = 0 + total = len(tests) + + for test in tests: + if test(): + passed += 1 + print() # Add spacing between tests + + print(f"=== Results: {passed}/{total} tests passed ===") + + if passed == total: + print("🎉 All tests passed! Vertex AI integration is working correctly.") + return 0 + else: + print("❌ Some tests failed. Please check the implementation.") + return 1 + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/.pylintrc b/tests/.pylintrc new file mode 100644 index 00000000..4b06ddd5 --- /dev/null +++ b/tests/.pylintrc @@ -0,0 +1,52 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Test-specific Pylint configuration +# Inherits from parent ../.pylintrc and adds test-specific relaxations + +[MASTER] +# Python will merge with parent; no need to repeat plugins. + +[MESSAGES CONTROL] +# Additional disables for test code only +disable= + # --- Test-specific relaxations --- + duplicate-code, # Test fixtures often have similar patterns + too-many-lines, # Large test files are common + missing-module-docstring, # Tests don't need module docs + missing-class-docstring, # Test classes are self-explanatory + missing-function-docstring, # Test method names describe intent + line-too-long, # Golden strings and test data + invalid-name, # setUp, tearDown, maxDiff, etc.
+ protected-access, # Tests often access private members + use-dict-literal, # Parametrized tests benefit from dict() + bad-indentation, # pyink 2-space style conflicts with pylint + unused-argument, # Mock callbacks often have unused args + import-error, # Test dependencies may not be installed + unused-import, # Some imports are for test fixtures + too-many-positional-arguments # Test methods can have many args + +[DESIGN] +# Relax complexity limits for tests +max-args = 10 # Fixtures often take many params +max-locals = 25 # Complex test setups +max-statements = 75 # Detailed test scenarios +max-branches = 15 # Multiple test conditions + +[BASIC] +# Allow common test naming patterns +good-names=i,j,k,ex,Run,_,id,ok,fd,fp,maxDiff,setUp,tearDown + +# Include test-specific naming patterns +method-rgx=[a-z_][a-z0-9_]{2,50}$|test[A-Z_][a-zA-Z0-9]*$|assert[A-Z][a-zA-Z0-9]*$ diff --git a/tests/annotation_test.py b/tests/annotation_test.py index bfa87c09..a5540e4e 100644 --- a/tests/annotation_test.py +++ b/tests/annotation_test.py @@ -20,6 +20,7 @@ from absl.testing import absltest from absl.testing import parameterized + from langextract import annotation from langextract import data from langextract import inference @@ -34,7 +35,7 @@ class AnnotatorTest(absltest.TestCase): def setUp(self): super().setUp() self.mock_language_model = self.enter_context( - mock.patch.object(inference, "LangFunLanguageModel", autospec=True) + mock.patch.object(inference, "GeminiLanguageModel", autospec=True) ) self.annotator = annotation.Annotator( language_model=self.mock_language_model, @@ -687,7 +688,7 @@ def test_annotate_documents( batch_length: int = 1, ): mock_language_model = self.enter_context( - mock.patch.object(inference, "LangFunLanguageModel", autospec=True) + mock.patch.object(inference, "GeminiLanguageModel", autospec=True) ) # Define a side effect function so return length based on batch length. 
@@ -760,7 +761,7 @@ def test_annotate_documents_exceptions( batch_length: int = 1, ): mock_language_model = self.enter_context( - mock.patch.object(inference, "LangFunLanguageModel", autospec=True) + mock.patch.object(inference, "GeminiLanguageModel", autospec=True) ) mock_language_model.infer.return_value = [ [ @@ -797,7 +798,7 @@ class AnnotatorMultiPassTest(absltest.TestCase): def setUp(self): super().setUp() self.mock_language_model = self.enter_context( - mock.patch.object(inference, "LangFunLanguageModel", autospec=True) + mock.patch.object(inference, "GeminiLanguageModel", autospec=True) ) self.annotator = annotation.Annotator( language_model=self.mock_language_model, diff --git a/tests/chunking_test.py b/tests/chunking_test.py index ad4f17b5..f28866a8 100644 --- a/tests/chunking_test.py +++ b/tests/chunking_test.py @@ -14,11 +14,12 @@ import textwrap +from absl.testing import absltest +from absl.testing import parameterized + from langextract import chunking from langextract import data from langextract import tokenizer -from absl.testing import absltest -from absl.testing import parameterized class SentenceIterTest(absltest.TestCase): @@ -368,7 +369,9 @@ def test_string_output(self): )""") document = data.Document(text=text, document_id="test_doc_123") tokenized_text = tokenizer.tokenize(text) - chunk_iter = chunking.ChunkIterator(tokenized_text, max_char_buffer=7, document=document) + chunk_iter = chunking.ChunkIterator( + tokenized_text, max_char_buffer=7, document=document + ) text_chunk = next(chunk_iter) self.assertEqual(str(text_chunk), expected) diff --git a/tests/data_lib_test.py b/tests/data_lib_test.py index 0eed51cc..e1cbdeb0 100644 --- a/tests/data_lib_test.py +++ b/tests/data_lib_test.py @@ -14,13 +14,13 @@ import json +from absl.testing import absltest +from absl.testing import parameterized import numpy as np from langextract import data from langextract import data_lib from langextract import tokenizer -from absl.testing import absltest -from absl.testing import parameterized class DataLibToDictParameterizedTest(parameterized.TestCase): diff --git a/tests/inference_test.py b/tests/inference_test.py index d9cf6b57..88b84d42 100644 --- a/tests/inference_test.py +++ b/tests/inference_test.py @@ -13,56 +13,11 @@ # limitations under the License. from unittest import mock -import langfun as lf -from absl.testing import absltest -from langextract import inference +from absl.testing import absltest -class TestLangFunLanguageModel(absltest.TestCase): - @mock.patch.object( - inference.lf.core.language_model, "LanguageModel", autospec=True - ) - def test_langfun_infer(self, mock_lf_model): - mock_client_instance = mock_lf_model.return_value - metadata = { - "score": -0.004259720362824737, - "logprobs": None, - "is_cached": False, - } - source = lf.UserMessage( - text="What's heart in Italian?.", - sender="User", - metadata={"formatted_text": "What's heart in Italian?."}, - tags=["lm-input"], - ) - sample = lf.LMSample( - response=lf.AIMessage( - text="Cuore", - sender="AI", - metadata=metadata, - source=source, - tags=["lm-response"], - ), - score=-0.004259720362824737, - ) - actual_response = lf.LMSamplingResult( - samples=[sample], - ) - - # Mock the sample response. 
- mock_client_instance.sample.return_value = [actual_response] - model = inference.LangFunLanguageModel(language_model=mock_client_instance) - - batch_prompts = ["What's heart in Italian?"] - - expected_results = [ - [inference.ScoredOutput(score=-0.004259720362824737, output="Cuore")] - ] - - results = list(model.infer(batch_prompts)) - - mock_client_instance.sample.assert_called_once_with(prompts=batch_prompts) - self.assertEqual(results, expected_results) +from langextract import data +from langextract import inference class TestOllamaLanguageModel(absltest.TestCase): @@ -118,7 +73,7 @@ def test_ollama_infer(self, mock_ollama_query): } mock_ollama_query.return_value = gemma_response model = inference.OllamaLanguageModel( - model="gemma2:latest", + model_id="gemma2:latest", model_url="http://localhost:11434", structured_output_format="json", ) @@ -139,5 +94,120 @@ def test_ollama_infer(self, mock_ollama_query): self.assertEqual(results, expected_results) +class TestOpenAILanguageModel(absltest.TestCase): + + @mock.patch("openai.OpenAI") + def test_openai_infer(self, mock_openai_class): + # Mock the OpenAI client and chat completion response + mock_client = mock.Mock() + mock_openai_class.return_value = mock_client + + # Mock response structure for v1.x API + mock_response = mock.Mock() + mock_response.choices = [ + mock.Mock(message=mock.Mock(content='{"name": "John", "age": 30}')) + ] + mock_client.chat.completions.create.return_value = mock_response + + # Create model instance + model = inference.OpenAILanguageModel( + model_id="gpt-4o-mini", api_key="test-api-key", temperature=0.5 + ) + + # Test inference + batch_prompts = ["Extract name and age from: John is 30 years old"] + results = list(model.infer(batch_prompts)) + + # Verify API was called correctly + mock_client.chat.completions.create.assert_called_once_with( + model="gpt-4o-mini", + messages=[ + { + "role": "system", + "content": ( + "You are a helpful assistant that responds in JSON format." 
+ ), + }, + { + "role": "user", + "content": "Extract name and age from: John is 30 years old", + }, + ], + temperature=0.5, + max_tokens=None, + top_p=None, + n=1, + ) + + # Check results + expected_results = [[ + inference.ScoredOutput(score=1.0, output='{"name": "John", "age": 30}') + ]] + self.assertEqual(results, expected_results) + + def test_openai_parse_output_json(self): + model = inference.OpenAILanguageModel( + api_key="test-key", format_type=data.FormatType.JSON + ) + + # Test valid JSON parsing + output = '{"key": "value", "number": 42}' + parsed = model.parse_output(output) + self.assertEqual(parsed, {"key": "value", "number": 42}) + + # Test invalid JSON + with self.assertRaises(ValueError) as context: + model.parse_output("invalid json") + self.assertIn("Failed to parse output as JSON", str(context.exception)) + + def test_openai_parse_output_yaml(self): + model = inference.OpenAILanguageModel( + api_key="test-key", format_type=data.FormatType.YAML + ) + + # Test valid YAML parsing + output = "key: value\nnumber: 42" + parsed = model.parse_output(output) + self.assertEqual(parsed, {"key": "value", "number": 42}) + + # Test invalid YAML + with self.assertRaises(ValueError) as context: + model.parse_output("invalid: yaml: bad") + self.assertIn("Failed to parse output as YAML", str(context.exception)) + + def test_openai_no_api_key_raises_error(self): + with self.assertRaises(ValueError) as context: + inference.OpenAILanguageModel(api_key=None) + self.assertEqual(str(context.exception), "API key not provided.") + + @mock.patch("openai.OpenAI") + def test_openai_temperature_zero(self, mock_openai_class): + # Test that temperature=0.0 is properly passed through + mock_client = mock.Mock() + mock_openai_class.return_value = mock_client + + mock_response = mock.Mock() + mock_response.choices = [ + mock.Mock(message=mock.Mock(content='{"result": "test"}')) + ] + mock_client.chat.completions.create.return_value = mock_response + + model = inference.OpenAILanguageModel( + api_key="test-key", temperature=0.0 # Testing zero temperature + ) + + list(model.infer(["test prompt"])) + + # Verify temperature=0.0 was passed to the API + mock_client.chat.completions.create.assert_called_with( + model="gpt-4o-mini", + messages=mock.ANY, + temperature=0.0, + max_tokens=None, + top_p=None, + n=1, + ) + + if __name__ == "__main__": absltest.main() diff --git a/tests/init_test.py b/tests/init_test.py index b68371f7..d79a07f4 100644 --- a/tests/init_test.py +++ b/tests/init_test.py @@ -18,11 +18,12 @@ from unittest import mock from absl.testing import absltest -import langextract as lx + from langextract import data from langextract import inference from langextract import prompting from langextract import schema +import langextract as lx class InitTest(absltest.TestCase): @@ -142,5 +143,6 @@ def test_lang_extract_as_lx_extract( self.assertDataclassEqual(expected_result, actual_result) + if __name__ == "__main__": absltest.main() diff --git a/tests/prompting_test.py b/tests/prompting_test.py index 93712121..5449139b 100644 --- a/tests/prompting_test.py +++ b/tests/prompting_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from absl.testing import parameterized + from langextract import data from langextract import prompting from langextract import schema diff --git a/tests/resolver_test.py b/tests/resolver_test.py index 61d2a5e6..b96270ee 100644 --- a/tests/resolver_test.py +++ b/tests/resolver_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest from absl.testing import 
parameterized + from langextract import chunking from langextract import data from langextract import resolver as resolver_lib diff --git a/tests/schema_test.py b/tests/schema_test.py index 4664da08..d4b067b5 100644 --- a/tests/schema_test.py +++ b/tests/schema_test.py @@ -16,11 +16,9 @@ import textwrap from unittest import mock - - - from absl.testing import absltest from absl.testing import parameterized + from langextract import data from langextract import schema diff --git a/tests/test_live_api.py b/tests/test_live_api.py new file mode 100644 index 00000000..8d9801eb --- /dev/null +++ b/tests/test_live_api.py @@ -0,0 +1,585 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Live API integration tests that require real API keys. + +These tests are skipped if API keys are not available in the environment. +They should run in CI after all other tests pass. +""" + +from functools import wraps +import os +import re +import textwrap +import time +import unittest + +from dotenv import load_dotenv +import pytest + +import langextract as lx +from langextract.inference import OpenAILanguageModel + +load_dotenv() + +DEFAULT_GEMINI_MODEL = "gemini-2.5-flash" +DEFAULT_OPENAI_MODEL = "gpt-4o" + +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get( + "LANGEXTRACT_API_KEY" +) +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") + +skip_if_no_gemini = pytest.mark.skipif( + not GEMINI_API_KEY, + reason=( + "Gemini API key not available (set GEMINI_API_KEY or" + " LANGEXTRACT_API_KEY)" + ), +) +skip_if_no_openai = pytest.mark.skipif( + not OPENAI_API_KEY, + reason="OpenAI API key not available (set OPENAI_API_KEY)", +) + +live_api = pytest.mark.live_api + +GEMINI_MODEL_PARAMS = { + "temperature": 0.0, + "top_p": 0.0, + "max_output_tokens": 256, +} + +OPENAI_MODEL_PARAMS = { + "temperature": 0.0, +} + + +INITIAL_RETRY_DELAY = 1.0 +MAX_RETRY_DELAY = 8.0 + + +def retry_on_transient_errors(max_retries=3, backoff_factor=2.0): + """Decorator to retry tests on transient API errors with exponential backoff. 
+ + Args: + max_retries: Maximum number of retry attempts + backoff_factor: Multiplier for exponential backoff (e.g., 2.0 = 1s, 2s, 4s) + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + last_exception = None + delay = INITIAL_RETRY_DELAY + + for attempt in range(max_retries + 1): + try: + return func(*args, **kwargs) + except ( + lx.exceptions.LangExtractError, + ConnectionError, + TimeoutError, + OSError, + RuntimeError, + ) as e: + last_exception = e + error_str = str(e).lower() + error_type = type(e).__name__ + + transient_errors = [ + "503", + "service unavailable", + "temporarily unavailable", + "rate limit", + "429", + "too many requests", + "connection reset", + "timeout", + "deadline exceeded", + ] + + is_transient = any( + err in error_str for err in transient_errors + ) or error_type in ["ServiceUnavailable", "RateLimitError", "Timeout"] + + if is_transient and attempt < max_retries: + print( + f"\nTransient error ({error_type}) on attempt" + f" {attempt + 1}/{max_retries + 1}: {e}" + ) + print(f"Retrying in {delay} seconds...") + time.sleep(delay) + delay = min(delay * backoff_factor, MAX_RETRY_DELAY) + else: + raise + + raise last_exception + + return wrapper + + return decorator + + +@pytest.fixture(autouse=True) +def add_delay_between_tests(): + """Add a small delay between tests to avoid rate limiting.""" + yield + time.sleep(0.5) + + +def get_basic_medication_examples(): + """Get example data for basic medication extraction.""" + return [ + lx.data.ExampleData( + text="Patient was given 250 mg IV Cefazolin TID for one week.", + extractions=[ + lx.data.Extraction( + extraction_class="dosage", extraction_text="250 mg" + ), + lx.data.Extraction( + extraction_class="route", extraction_text="IV" + ), + lx.data.Extraction( + extraction_class="medication", extraction_text="Cefazolin" + ), + lx.data.Extraction( + extraction_class="frequency", + extraction_text="TID", # TID = three times a day + ), + lx.data.Extraction( + extraction_class="duration", extraction_text="for one week" + ), + ], + ) + ] + + +def get_relationship_examples(): + """Get example data for medication relationship extraction.""" + return [ + lx.data.ExampleData( + text=( + "Patient takes Aspirin 100mg daily for heart health and" + " Simvastatin 20mg at bedtime." + ), + extractions=[ + # First medication group + lx.data.Extraction( + extraction_class="medication", + extraction_text="Aspirin", + attributes={"medication_group": "Aspirin"}, + ), + lx.data.Extraction( + extraction_class="dosage", + extraction_text="100mg", + attributes={"medication_group": "Aspirin"}, + ), + lx.data.Extraction( + extraction_class="frequency", + extraction_text="daily", + attributes={"medication_group": "Aspirin"}, + ), + lx.data.Extraction( + extraction_class="condition", + extraction_text="heart health", + attributes={"medication_group": "Aspirin"}, + ), + # Second medication group + lx.data.Extraction( + extraction_class="medication", + extraction_text="Simvastatin", + attributes={"medication_group": "Simvastatin"}, + ), + lx.data.Extraction( + extraction_class="dosage", + extraction_text="20mg", + attributes={"medication_group": "Simvastatin"}, + ), + lx.data.Extraction( + extraction_class="frequency", + extraction_text="at bedtime", + attributes={"medication_group": "Simvastatin"}, + ), + ], + ) + ] + + +def extract_by_class(result, extraction_class): + """Helper to extract entities by class. + + Returns a set of extraction texts for the given class. 
+ """ + return { + e.extraction_text + for e in result.extractions + if e.extraction_class == extraction_class + } + + +def assert_extractions_contain(test_case, result, expected_classes): + """Assert that result contains all expected extraction classes. + + Uses unittest assertions for richer error messages. + """ + actual_classes = {e.extraction_class for e in result.extractions} + missing_classes = expected_classes - actual_classes + test_case.assertFalse( + missing_classes, + f"Missing expected classes: {missing_classes}. Found extractions:" + f" {[f'{e.extraction_class}:{e.extraction_text}' for e in result.extractions]}", + ) + + +def assert_valid_char_intervals(test_case, result): + """Assert that all extractions have valid char intervals and alignment status.""" + for extraction in result.extractions: + test_case.assertIsNotNone( + extraction.char_interval, + f"Missing char_interval for extraction: {extraction.extraction_text}", + ) + test_case.assertIsNotNone( + extraction.alignment_status, + "Missing alignment_status for extraction:" + f" {extraction.extraction_text}", + ) + if hasattr(result, "text") and result.text: + text_length = len(result.text) + test_case.assertGreaterEqual( + extraction.char_interval.start_pos, + 0, + f"Invalid start_pos for extraction: {extraction.extraction_text}", + ) + test_case.assertLessEqual( + extraction.char_interval.end_pos, + text_length, + f"Invalid end_pos for extraction: {extraction.extraction_text}", + ) + + +class TestLiveAPIGemini(unittest.TestCase): + """Tests using real Gemini API.""" + + @skip_if_no_gemini + @live_api + @retry_on_transient_errors(max_retries=2) + def test_medication_extraction(self): + """Test medication extraction with entities in order.""" + prompt = textwrap.dedent("""\ + Extract medication information including medication name, dosage, route, frequency, + and duration in the order they appear in the text.""") + + examples = get_basic_medication_examples() + input_text = "Patient took 400 mg PO Ibuprofen q4h for two days." 
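+
+    # GEMINI_MODEL_PARAMS pins temperature and top_p to 0.0 so the live call
+    # is as deterministic as the API allows; transient 429/503 errors are
+    # absorbed by the retry decorator applied above.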
+ + result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + model_id=DEFAULT_GEMINI_MODEL, + api_key=GEMINI_API_KEY, + language_model_params=GEMINI_MODEL_PARAMS, + ) + + assert result is not None + assert hasattr(result, "extractions") + assert len(result.extractions) > 0 + + expected_classes = { + "dosage", + "route", + "medication", + "frequency", + "duration", + } + assert_extractions_contain(self, result, expected_classes) + assert_valid_char_intervals(self, result) + + # Using regex for precise matching to avoid false positives + medication_texts = extract_by_class(result, "medication") + self.assertTrue( + any( + re.search(r"\bIbuprofen\b", text, re.IGNORECASE) + for text in medication_texts + ), + f"No Ibuprofen found in: {medication_texts}", + ) + + dosage_texts = extract_by_class(result, "dosage") + self.assertTrue( + any( + re.search(r"\b400\s*mg\b", text, re.IGNORECASE) + for text in dosage_texts + ), + f"No 400mg dosage found in: {dosage_texts}", + ) + + route_texts = extract_by_class(result, "route") + self.assertTrue( + any( + re.search(r"\b(PO|oral)\b", text, re.IGNORECASE) + for text in route_texts + ), + f"No PO/oral route found in: {route_texts}", + ) + + @skip_if_no_gemini + @live_api + @retry_on_transient_errors(max_retries=2) + @pytest.mark.xfail( + reason=( + "Known tokenizer issue with non-Latin characters - see GitHub" + " issue #13" + ), + strict=True, + ) + def test_multilingual_medication_extraction(self): + """Test medication extraction with Japanese text.""" + text = ( # "The patient takes 10 mg of medication daily." + "ๆ‚ฃ่€…ใฏๆฏŽๆ—ฅ10mgใฎ่–ฌใ‚’ๆœ็”จใ—ใพใ™ใ€‚" + ) + + prompt = "Extract medication information including dosage and frequency." + + examples = [ + lx.data.ExampleData( + text="The patient takes 20mg of aspirin twice daily.", + extractions=[ + lx.data.Extraction( + extraction_class="medication", + extraction_text="aspirin", + attributes={"dosage": "20mg", "frequency": "twice daily"}, + ), + ], + ) + ] + + result = lx.extract( + text_or_documents=text, + prompt_description=prompt, + examples=examples, + model_id=DEFAULT_GEMINI_MODEL, + api_key=GEMINI_API_KEY, + language_model_params=GEMINI_MODEL_PARAMS, + ) + + assert result is not None + assert hasattr(result, "extractions") + assert len(result.extractions) > 0 + + medication_extractions = [ + e for e in result.extractions if e.extraction_class == "medication" + ] + assert ( + len(medication_extractions) > 0 + ), "No medication entities found in Japanese text" + assert_valid_char_intervals(self, result) + + @skip_if_no_gemini + @live_api + @retry_on_transient_errors(max_retries=2) + def test_medication_relationship_extraction(self): + """Test relationship extraction for medications with Gemini.""" + input_text = """ + The patient was prescribed Lisinopril and Metformin last month. + He takes the Lisinopril 10mg daily for hypertension, but often misses + his Metformin 500mg dose which should be taken twice daily for diabetes. + """ + + prompt = textwrap.dedent(""" + Extract medications with their details, using attributes to group related information: + + 1. Extract entities in the order they appear in the text + 2. Each entity must have a 'medication_group' attribute linking it to its medication + 3. 
All details about a medication should share the same medication_group value + """) + + examples = get_relationship_examples() + + result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + model_id=DEFAULT_GEMINI_MODEL, + api_key=GEMINI_API_KEY, + language_model_params=GEMINI_MODEL_PARAMS, + ) + + assert result is not None + assert len(result.extractions) > 0 + assert_valid_char_intervals(self, result) + + medication_groups = {} + for extraction in result.extractions: + assert ( + extraction.attributes is not None + ), f"Missing attributes for {extraction.extraction_text}" + assert ( + "medication_group" in extraction.attributes + ), f"Missing medication_group for {extraction.extraction_text}" + + group_name = extraction.attributes["medication_group"] + medication_groups.setdefault(group_name, []).append(extraction) + + assert ( + len(medication_groups) >= 2 + ), f"Expected at least 2 medications, found {len(medication_groups)}" + + # Allow flexible matching for dosage field (could be "dosage" or "dose") + for med_name, extractions in medication_groups.items(): + extraction_classes = {e.extraction_class for e in extractions} + # At minimum, each group should have the medication itself + assert ( + "medication" in extraction_classes + ), f"{med_name} group missing medication entity" + # Dosage is expected but might be formatted differently + assert any( + c in extraction_classes for c in ["dosage", "dose"] + ), f"{med_name} group missing dosage" + + +class TestLiveAPIOpenAI(unittest.TestCase): + """Tests using real OpenAI API.""" + + @skip_if_no_openai + @live_api + @retry_on_transient_errors(max_retries=2) + def test_medication_extraction(self): + """Test medication extraction with OpenAI models.""" + prompt = textwrap.dedent("""\ + Extract medication information including medication name, dosage, route, frequency, + and duration in the order they appear in the text.""") + + examples = get_basic_medication_examples() + input_text = "Patient took 400 mg PO Ibuprofen q4h for two days." 
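+
+    # Unlike the Gemini test above, this call sets fence_output=True and
+    # use_schema_constraints=False, since this suite relies on fenced output
+    # rather than schema constraints for OpenAI models.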
+ + result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=OpenAILanguageModel, + model_id=DEFAULT_OPENAI_MODEL, + api_key=OPENAI_API_KEY, + fence_output=True, + use_schema_constraints=False, + language_model_params=OPENAI_MODEL_PARAMS, + ) + + assert result is not None + assert hasattr(result, "extractions") + assert len(result.extractions) > 0 + + expected_classes = { + "dosage", + "route", + "medication", + "frequency", + "duration", + } + assert_extractions_contain(self, result, expected_classes) + assert_valid_char_intervals(self, result) + + # Using regex for precise matching to avoid false positives + medication_texts = extract_by_class(result, "medication") + self.assertTrue( + any( + re.search(r"\bIbuprofen\b", text, re.IGNORECASE) + for text in medication_texts + ), + f"No Ibuprofen found in: {medication_texts}", + ) + + dosage_texts = extract_by_class(result, "dosage") + self.assertTrue( + any( + re.search(r"\b400\s*mg\b", text, re.IGNORECASE) + for text in dosage_texts + ), + f"No 400mg dosage found in: {dosage_texts}", + ) + + route_texts = extract_by_class(result, "route") + self.assertTrue( + any( + re.search(r"\b(PO|oral)\b", text, re.IGNORECASE) + for text in route_texts + ), + f"No PO/oral route found in: {route_texts}", + ) + + @skip_if_no_openai + @live_api + @retry_on_transient_errors(max_retries=2) + def test_medication_relationship_extraction(self): + """Test relationship extraction for medications with OpenAI.""" + input_text = """ + The patient was prescribed Lisinopril and Metformin last month. + He takes the Lisinopril 10mg daily for hypertension, but often misses + his Metformin 500mg dose which should be taken twice daily for diabetes. + """ + + prompt = textwrap.dedent(""" + Extract medications with their details, using attributes to group related information: + + 1. Extract entities in the order they appear in the text + 2. Each entity must have a 'medication_group' attribute linking it to its medication + 3. 
All details about a medication should share the same medication_group value + """) + + examples = get_relationship_examples() + + result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=OpenAILanguageModel, + model_id=DEFAULT_OPENAI_MODEL, + api_key=OPENAI_API_KEY, + fence_output=True, + use_schema_constraints=False, + language_model_params=OPENAI_MODEL_PARAMS, + ) + + assert result is not None + assert len(result.extractions) > 0 + assert_valid_char_intervals(self, result) + + medication_groups = {} + for extraction in result.extractions: + assert ( + extraction.attributes is not None + ), f"Missing attributes for {extraction.extraction_text}" + assert ( + "medication_group" in extraction.attributes + ), f"Missing medication_group for {extraction.extraction_text}" + + group_name = extraction.attributes["medication_group"] + medication_groups.setdefault(group_name, []).append(extraction) + + assert ( + len(medication_groups) >= 2 + ), f"Expected at least 2 medications, found {len(medication_groups)}" + + # Allow flexible matching for dosage field (could be "dosage" or "dose") + for med_name, extractions in medication_groups.items(): + extraction_classes = {e.extraction_class for e in extractions} + # At minimum, each group should have the medication itself + assert ( + "medication" in extraction_classes + ), f"{med_name} group missing medication entity" + # Dosage is expected but might be formatted differently + assert any( + c in extraction_classes for c in ["dosage", "dose"] + ), f"{med_name} group missing dosage" diff --git a/tests/test_ollama_integration.py b/tests/test_ollama_integration.py new file mode 100644 index 00000000..5ab4397d --- /dev/null +++ b/tests/test_ollama_integration.py @@ -0,0 +1,80 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for Ollama functionality.""" +import socket + +import pytest +import requests + +import langextract as lx + + +def _ollama_available(): + """Check if Ollama is running on localhost:11434.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + result = sock.connect_ex(("localhost", 11434)) + return result == 0 + + +@pytest.mark.skipif(not _ollama_available(), reason="Ollama not running") +def test_ollama_extraction(): + """Test extraction using Ollama when available.""" + input_text = "Isaac Asimov was a prolific science fiction writer." + prompt = "Extract the author's full name and their primary literary genre." + + examples = [ + lx.data.ExampleData( + text=( + "J.R.R. Tolkien was an English writer, best known for" + " high-fantasy." + ), + extractions=[ + lx.data.Extraction( + extraction_class="author_details", + extraction_text="J.R.R. Tolkien was an English writer...", + attributes={ + "name": "J.R.R. 
Tolkien", + "genre": "high-fantasy", + }, + ) + ], + ) + ] + + model_id = "gemma2:2b" + + try: + result = lx.extract( + text_or_documents=input_text, + prompt_description=prompt, + examples=examples, + language_model_type=lx.inference.OllamaLanguageModel, + model_id=model_id, + model_url="http://localhost:11434", + temperature=0.3, + fence_output=False, + use_schema_constraints=False, + ) + + assert len(result.extractions) > 0 + extraction = result.extractions[0] + assert extraction.extraction_class == "author_details" + if extraction.attributes: + assert "asimov" in extraction.attributes.get("name", "").lower() + + except ValueError as e: + if "Can't find Ollama" in str(e): + pytest.skip(f"Ollama model {model_id} not available") + raise diff --git a/tests/tokenizer_test.py b/tests/tokenizer_test.py index 9d296978..021f802a 100644 --- a/tests/tokenizer_test.py +++ b/tests/tokenizer_test.py @@ -14,10 +14,11 @@ import textwrap -from langextract import tokenizer from absl.testing import absltest from absl.testing import parameterized +from langextract import tokenizer + class TokenizerTest(parameterized.TestCase): diff --git a/tests/visualization_test.py b/tests/visualization_test.py index 0cb7fbe2..647107f9 100644 --- a/tests/visualization_test.py +++ b/tests/visualization_test.py @@ -17,6 +17,7 @@ from unittest import mock from absl.testing import absltest + from langextract import data as lx_data from langextract import visualization diff --git a/tox.ini b/tox.ini index e8988af7..7abd98a0 100644 --- a/tox.ini +++ b/tox.ini @@ -13,7 +13,7 @@ # limitations under the License. [tox] -envlist = py310, py311 +envlist = py310, py311, py312, format, lint-src, lint-tests skip_missing_interpreters = True [testenv] @@ -22,5 +22,41 @@ setenv = deps = .[dev,test] commands = - pylint --rcfile=.pylintrc --score n langextract tests - pytest -q \ No newline at end of file + pytest -ra -m "not live_api" + +[testenv:format] +skip_install = true +deps = + isort>=5.13.2 + pyink~=24.3.0 +commands = + isort langextract tests --check-only --diff + pyink langextract tests --check --diff --config pyproject.toml + +[testenv:lint-src] +deps = + pylint>=3.0.0 +commands = + pylint --rcfile=.pylintrc langextract + +[testenv:lint-tests] +deps = + pylint>=3.0.0 +commands = + pylint --rcfile=tests/.pylintrc tests + +[testenv:live-api] +basepython = python3.11 +passenv = + GEMINI_API_KEY + LANGEXTRACT_API_KEY + OPENAI_API_KEY +deps = {[testenv]deps} +commands = + pytest tests/test_live_api.py -v -m live_api --maxfail=1 + +[testenv:ollama-integration] +basepython = python3.11 +deps = {[testenv]deps} +commands = + pytest tests/test_ollama_integration.py -v --tb=short
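
The new tox environments map onto the following local commands (a sketch only;
it assumes tox is installed and, for live-api, that the keys listed under
passenv are exported):

    tox                          # envlist: unit tests on py310-py312 plus format and lint checks (live_api excluded)
    tox -e format                # isort + pyink in check-only mode
    tox -e lint-src,lint-tests   # pylint over langextract/ and tests/
    tox -e live-api              # opt-in live Gemini/OpenAI tests; stops on first failure
    tox -e ollama-integration    # requires Ollama listening on localhost:11434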