Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
Sea-Snell committed May 19, 2023
1 parent f1736f0 commit ccf8c78
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 13 deletions.
17 changes: 6 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# JaxSeq2
# JaxSeq

## Overview

Expand All @@ -13,8 +13,8 @@ Thanks to Jax's [pjit](https://jax.readthedocs.io/en/latest/jax.experimental.pji
### **1. pull from github**

``` bash
git clone https://github.com/Sea-Snell/JaxSeq2.git
cd JaxSeq2
git clone https://github.com/Sea-Snell/JAXSeq.git
cd JAXSeq
```

### **2. install dependencies**
Expand All @@ -24,15 +24,15 @@ Install with conda (cpu, tpu, or gpu).
**install with conda (cpu):**
``` shell
conda env create -f environment.yml
conda activate JaxSeq2
conda activate JaxSeq
python -m pip install --upgrade pip
python -m pip install -e .
```

**install with conda (gpu):**
``` shell
conda env create -f environment.yml
conda activate JaxSeq2
conda activate JaxSeq
python -m pip install --upgrade pip
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
python -m pip install -e .
Expand All @@ -41,7 +41,7 @@ python -m pip install -e .
**install with conda (tpu):**
``` shell
conda env create -f environment.yml
conda activate JaxSeq2
conda activate JaxSeq
python -m pip install --upgrade pip
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python -m pip install -e .
Expand Down Expand Up @@ -72,8 +72,3 @@ To further support TPU workflows the example scripts provide functionality for u
* [GPT-J Repo](https://github.com/kingoflolz/mesh-transformer-jax) [uses xmap instead of pjit]
* [Alpa](https://github.com/alpa-projects/alpa)
* [Jaxformer](https://github.com/salesforce/jaxformer)

## TODO
* [ ] Improve serving function, use redis pub-sub instead.
* [ ] Add OPT?
* [ ] Add GPT-Neo?
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: JaxSeq2
name: JaxSeq
channels:
- defaults
- conda-forge
Expand Down
2 changes: 1 addition & 1 deletion tpu_vm_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ bash ~/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh -b
source ~/miniconda3/bin/activate
conda init bash
conda env create -f environment.yml
conda activate JaxSeq2
conda activate JaxSeq
python -m pip install --upgrade pip && python -m pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# clean up
Expand Down

0 comments on commit ccf8c78

Please sign in to comment.