# Implement PyTorch-like functionalities using JAX #329
# Workflow file for this run

name: Python package
# `on` is the trigger key for GitHub Actions (a generic YAML 1.1 parser would
# read the bare word `on` as boolean true; GitHub's loader does not).
on: [push, pull_request]
# Least-privilege GITHUB_TOKEN: the jobs only check out code, restore/save the
# pip cache, install packages and run tests — nothing writes to the repository.
permissions:
  contents: read
jobs:
  # Main CI job: installs dependencies and the package, lints, and runs the
  # test suite once per Python version in the matrix.
  build:
    runs-on: ubuntu-latest
    timeout-minutes: 40
    strategy:
      matrix:
        # Quoted so YAML keeps them as strings ('3.10' unquoted would be 3.1).
        python-version: ['3.9', '3.10', '3.11', '3.12']
    steps:
      - uses: actions/checkout@v4

      # Writing to $GITHUB_ENV makes PYTHONPATH visible to all later steps.
      - name: Set PYTHONPATH
        run: |
          echo "::group::Set PYTHONPATH"
          set -x
          echo "PYTHONPATH=${{ github.workspace }}/src" >> "$GITHUB_ENV"
          echo "::endgroup::"

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Print Python and pip versions
        run: |
          python --version
          pip --version

      # The key includes the Python version, so parallel matrix legs do not
      # race to save one shared key. It also hashes every *requirements.txt
      # (requirements.txt and test_requirements.txt), so changing either file
      # invalidates the cache.
      - name: Cache pip dependencies
        uses: actions/cache@v4
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/*requirements.txt') }}
          restore-keys: |
            ${{ runner.os }}-pip-${{ matrix.python-version }}-
            ${{ runner.os }}-pip-

      - name: Print requirements files
        run: |
          echo "::group::Print requirements files"
          set -x
          # Guarded with -f like the install step. The default run shell is
          # `bash -e`, so an unguarded `cat` of a missing file fails the job.
          for req in requirements.txt test_requirements.txt; do
            if [ -f "$req" ]; then
              echo "Contents of $req:"
              cat "$req"
            else
              echo "$req not found"
            fi
          done
          echo "::endgroup::"

      - name: Install dependencies
        run: |
          echo "::group::Install dependencies"
          set -x
          python -m pip install --upgrade pip
          pip install flake8 pytest transformers
          # CPU-only PyTorch wheel index; ubuntu-latest runners have no GPU.
          pip install torch --index-url https://download.pytorch.org/whl/cpu
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
          if [ -f test_requirements.txt ]; then pip install -r test_requirements.txt; fi
          pip install jax jaxlib
          pip install optax
          pip install dm-haiku
          pip install langchain-community
          # Append the checkout root to the PYTHONPATH set in the first step.
          echo "PYTHONPATH=$PYTHONPATH:$(pwd)" >> "$GITHUB_ENV"
          echo "::endgroup::"
        timeout-minutes: 30

      - name: Install NextGenJAX package
        run: |
          echo "::group::Install NextGenJAX package"
          set -x
          pip install -e .
          echo "NextGenJAX installation completed"
          echo "::endgroup::"

      - name: Debug information
        run: |
          echo "::group::Debug information"
          python -c "import sys; print(sys.path)"
          pip list
          pip show nextgenjax
          echo "::endgroup::"

      - name: Test nextgenjax import
        run: |
          echo "::group::Test nextgenjax import"
          set -x
          python -c "import nextgenjax; print(f'nextgenjax imported successfully from {nextgenjax.__file__}')"
          echo "::endgroup::"

      - name: Lint with flake8
        run: |
          echo "::group::Lint with flake8"
          set -x
          # stop the build if there are Python syntax errors or undefined names
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
          echo "::endgroup::"

      - name: Run unit tests
        env:
          JAX_TRACEBACK_FILTERING: 'off'
        run: |
          echo "::group::Run unit tests"
          set -x
          echo "Current working directory: $(pwd)"
          echo "PYTHONPATH: $PYTHONPATH"
          echo "Python sys.path:"
          python -c "import sys; print('\n'.join(sys.path))"
          echo "Attempting to import nextgenjax:"
          python -c "import nextgenjax; print(f'nextgenjax imported successfully from {nextgenjax.__file__}')" || echo "Failed to import nextgenjax"
          # Only tests/unit runs here. Everything else under tests/ runs in
          # "Run other tests", so no test is executed twice.
          if [ -d tests/unit ]; then
            pytest tests/unit -v
          else
            echo "tests/unit not found; all tests run in 'Run other tests'"
          fi
          echo "::endgroup::"
        timeout-minutes: 10

      # Integration tests step removed as the 'tests/integration' directory does not exist
      - name: Run other tests
        env:
          JAX_TRACEBACK_FILTERING: 'off'
        run: |
          echo "::group::Run other tests"
          set -x
          pytest tests/ -v --ignore=tests/unit --ignore=tests/integration
          echo "::endgroup::"
        timeout-minutes: 10

  # NOTE(review): placeholder job. ubuntu-latest has no GPU, and both CUDA
  # steps run comment-only scripts (no-ops). As written, this job only checks
  # out the repo and sets up Python after every `build` matrix leg passes.
  cuda-tests:
    runs-on: ubuntu-latest
    needs: build
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'
        timeout-minutes: 5

      - name: Install CUDA dependencies
        run: |
          # Add commands to install CUDA dependencies
        timeout-minutes: 15

      - name: Run CUDA-specific tests
        run: |
          # Add commands to run CUDA-specific tests
        timeout-minutes: 20