# Implement PyTorch-like functionalities using JAX (PR #332)
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
name: Python package

# Run CI on every push and on every pull request.
# `on` looks like a YAML 1.1 boolean to generic linters; GitHub's own loader
# reads it as the trigger key, so it is intentionally left unquoted.
on: [push, pull_request]  # yamllint disable-line rule:truthy

jobs:
  # Main CI job: install dependencies, lint, and run the test suite once per
  # Python version in the matrix.
  build:
    runs-on: ubuntu-latest
    timeout-minutes: 40
    strategy:
      matrix:
        # Every version is quoted so YAML keeps them all as strings
        # (an unquoted 3.10 would be read as the float 3.1).
        python-version: ['3.9', '3.10', '3.11', '3.12']
    steps:
      - uses: actions/checkout@v4

      - name: Set PYTHONPATH
        run: |
          echo "::group::Set PYTHONPATH"
          set -x
          # Make the repository's src/ directory importable in later steps.
          echo "PYTHONPATH=${{ github.workspace }}/src" >> "$GITHUB_ENV"
          echo "::endgroup::"

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Print Python and pip versions
        run: |
          python --version
          pip --version

      - name: Cache pip dependencies
        uses: actions/cache@v4
        with:
          path: ~/.cache/pip
          # Key on the Python version and on BOTH requirement files installed
          # below: without the version every matrix leg shares a single cache
          # entry, and without test_requirements.txt edits to it never
          # invalidate the cache.
          key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements.txt', '**/test_requirements.txt') }}
          # Prefer a partial hit for the same Python version; fall back to any
          # pip cache (this prefix also matches entries saved under the old key).
          restore-keys: |
            ${{ runner.os }}-pip-${{ matrix.python-version }}-
            ${{ runner.os }}-pip-

      - name: Print requirements files
        run: |
          echo "::group::Print requirements files"
          set -x
          # Guard each file the same way the install step does, so a missing
          # optional requirements file does not fail the job here.
          for req in requirements.txt test_requirements.txt; do
            if [ -f "$req" ]; then
              echo "Contents of $req:"
              cat "$req"
            else
              echo "$req not found; skipping"
            fi
          done
          echo "::endgroup::"

      - name: Install dependencies
        timeout-minutes: 30
        run: |
          echo "::group::Install dependencies"
          set -x
          python -m pip install --upgrade pip
          pip install flake8 pytest transformers
          # CPU-only PyTorch build from the PyTorch CPU wheel index.
          pip install torch --index-url https://download.pytorch.org/whl/cpu
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
          if [ -f test_requirements.txt ]; then pip install -r test_requirements.txt; fi
          pip install jax jaxlib
          pip install optax
          pip install dm-haiku
          pip install langchain-community
          # Also put the repository root on PYTHONPATH for subsequent steps.
          echo "PYTHONPATH=$PYTHONPATH:$(pwd)" >> "$GITHUB_ENV"
          echo "::endgroup::"

      - name: Install NextGenJAX package
        run: |
          echo "::group::Install NextGenJAX package"
          set -x
          # Editable install of the package in the repository root.
          pip install -e .
          echo "NextGenJAX installation completed"
          echo "::endgroup::"

      - name: Debug information
        run: |
          echo "::group::Debug information"
          python -c "import sys; print(sys.path)"
          pip list
          pip show nextgenjax
          echo "::endgroup::"

      - name: Test nextgenjax import
        run: |
          echo "::group::Test nextgenjax import"
          set -x
          python -c "import nextgenjax; print(f'nextgenjax imported successfully from {nextgenjax.__file__}')"
          echo "::endgroup::"

      - name: Lint with flake8
        run: |
          echo "::group::Lint with flake8"
          set -x
          # stop the build if there are Python syntax errors or undefined names
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
          echo "::endgroup::"

      - name: Run unit tests
        timeout-minutes: 10
        env:
          # Disable JAX traceback filtering so test failures show full stacks.
          # Quoted so YAML keeps the string 'off' instead of the boolean false.
          JAX_TRACEBACK_FILTERING: 'off'
        run: |
          echo "::group::Run unit tests"
          set -x
          echo "Current working directory: $(pwd)"
          echo "PYTHONPATH: $PYTHONPATH"
          echo "Python sys.path:"
          python -c "import sys; print('\n'.join(sys.path))"
          echo "Attempting to import nextgenjax:"
          # Diagnostic only: an import failure is printed but does not stop the
          # step; the pytest run below is what decides the step result.
          python -c "import nextgenjax; print(f'nextgenjax imported successfully from {nextgenjax.__file__}')" || echo "Failed to import nextgenjax"
          pytest tests -v
          echo "::endgroup::"

      # Integration tests step removed as the 'tests/integration' directory does not exist.
      # NOTE(review): "Run unit tests" above already runs `pytest tests`, which
      # collects everything selected here, so these tests currently run twice.
      # Confirm the intended split (e.g. `pytest tests/unit` above) before changing it.
      - name: Run other tests
        timeout-minutes: 10
        run: |
          echo "::group::Run other tests"
          set -x
          pytest tests/ -v --ignore=tests/unit --ignore=tests/integration
          echo "::endgroup::"

  # Placeholder for CUDA/GPU coverage; runs only after every build leg passes.
  # No CUDA commands exist yet, so the steps emit warning annotations instead of
  # silently succeeding. NOTE(review): GitHub's standard ubuntu-latest runner
  # has no GPU; real CUDA tests will need a GPU-capable runner.
  cuda-tests:
    runs-on: ubuntu-latest
    needs: build
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        timeout-minutes: 5
        with:
          python-version: '3.9'

      - name: Install CUDA dependencies
        timeout-minutes: 15
        run: |
          # TODO: add commands to install CUDA dependencies.
          echo "::warning::CUDA dependency installation is not implemented yet"

      - name: Run CUDA-specific tests
        timeout-minutes: 20
        run: |
          # TODO: add commands to run CUDA-specific tests.
          echo "::warning::CUDA-specific tests are not implemented; no tests were run"