91 changes: 87 additions & 4 deletions README.md
@@ -2,7 +2,7 @@

![alt text](img/Block_diagram_upd.png?raw=true)

A GPU-accelerated JAX-based implementation of [PROTAX](https://pubmed.ncbi.nlm.nih.gov/27296980/). Contains all code and experiments for PROTAX-GPU
A GPU-accelerated JAX-based implementation of [PROTAX](https://pubmed.ncbi.nlm.nih.gov/27296980/). Contains all code and experiments for [PROTAX-GPU](https://royalsocietypublishing.org/doi/10.1098/rstb.2023.0124)

To reproduce the BOLD 7.8M dataset experiments, PROTAX-GPU requires an NVIDIA GPU with at least 8 GB of VRAM and CUDA compute capability 6.0 or later. This corresponds to GPUs in the NVIDIA Pascal, Volta™, Turing™, Ampere, and Hopper™ architecture families.

@@ -132,12 +132,12 @@ Once you have a trained model, you can use the classify_file function to classif

Run the sequence classification script:
```
python scripts/process_seqs.py [PATH_TO_QUERY_SEQUENCES] [PATH_TO_MODEL] [PATH_TO_TAXONOMY]
python scripts/process_seqs.py [PATH_TO_QUERY_SEQUENCES] [PATH_TO_MODEL] [PATH_TO_TAXONOMY] [PATH_TO_TAXONOMY_MAPPING]
```
Example:

```
python scripts/process_seqs.py data/refs.aln models/params/model.npz models/ref_db/taxonomy37k.npz
python scripts/process_seqs.py data/refs.aln models/params/model.npz models/ref_db/taxonomy37k.npz data/tax_mapping.priors
```
<!-- python scripts/process_seqs.py FinPROTAX/FinPROTAX/modelCOIfull/refs.aln models/params/model.npz models/ref_db/taxonomy37k.npz -->

@@ -146,10 +146,93 @@ Arguments:
- `PATH_TO_QUERY_SEQUENCES`: File containing the sequences to classify (e.g., a FASTA or alignment file). For experiments, `refs.aln` from [FinPROTAX](https://github.com/psomervuo/FinPROTAX/tree/main) can be used.
- `PATH_TO_MODEL`: Path to the model parameters. A base model is available at `models/params/model.npz`.
- `PATH_TO_TAXONOMY`: Path to the taxonomy `.npz` file. A taxonomy file is available at `models/ref_db/taxonomy37k.npz`.

- `PATH_TO_TAXONOMY_MAPPING`: Path to the taxonomy mapping file (node-to-label mapping). For experiments, `taxonomy.priors` from [FinPROTAX](https://github.com/psomervuo/FinPROTAX/tree/main) can be used.

Results are saved to `pyprotax_results.csv` in the current working directory.

<details>
<summary> <b>More details regarding input file formats </b></summary>

#### Structure of Query File (`PATH_TO_QUERY_SEQUENCES`)

The query file (passed to `classify_file` as `qdir`) has the following structure:

1. **Header Line (Taxonomic Metadata)**
- Starts with `>` followed by a unique identifier for the query sequence.
- Contains the full taxonomic lineage associated with the query, separated by commas.
- Example:
```
>COLFA029-10 Insecta,Coleoptera,Ptiliidae,Acrotrichinae,Acrotrichini,Acrotrichis
```

2. **Sequence Line**
- Contains the DNA sequence corresponding to the query.
- The sequence can include standard nucleotide codes (e.g., A, T, C, G).
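
As a rough illustration of the layout described above, the following sketch parses such records using only the standard library. The helper name `parse_query_records`, the placeholder sequence, and the assumption that every record is exactly one header line followed by one sequence line are ours; the reader inside PROTAX-GPU may behave differently.

```python
import io


def parse_query_records(handle):
    """Yield (query_id, lineage, sequence) tuples from a header/sequence file.

    Hypothetical helper: assumes each record is a '>' header line followed by
    exactly one sequence line, as in the example above.
    """
    query_id, lineage = None, []
    for raw in handle:
        line = raw.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith(">"):
            # Header: identifier, then an optional comma-separated lineage.
            head, _, rest = line[1:].partition(" ")
            query_id = head
            lineage = [taxon for taxon in rest.strip().split(",") if taxon]
        elif query_id is not None:
            yield query_id, lineage, line.upper()
            query_id, lineage = None, []


# Toy record: the header is the README example, the sequence is a placeholder.
example = io.StringIO(
    ">COLFA029-10 Insecta,Coleoptera,Ptiliidae,Acrotrichinae,Acrotrichini,Acrotrichis\n"
    "ACGTACGT\n"
)
records = list(parse_query_records(example))
```

Each tuple keeps the identifier, the lineage as a list of taxon names, and the upper-cased sequence, which makes it easy to compare the known lineage with the predicted one later.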

#### Structure of `PATH_TO_MODEL` (`.npz` file)

The `par_dir` file is a compressed `.npz` archive that contains the trained parameters of the PROTAX model. The file should include the following arrays:

1. **`beta`**
- Shape: `(M, R)`
- Description: Coefficients for the regression model, where `M` is the number of features and `R` is the number of ranks in the taxonomy.

2. **`scalings`**
- Shape: `(R, 4)`
- Description: Scaling parameters for the regression model. Each row corresponds to a rank in the taxonomy and contains four values:
- Mean scaling (columns 0 and 2).
- Variance scaling (columns 1 and 3).

3. **`node_layer`**
- Shape: `(N,)`
- Description: Indicates the layer (rank) in the taxonomy tree to which each node belongs.
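
To make the shape conventions above concrete, here is a minimal sketch that checks the three arrays against each other with NumPy. The helper `check_model_arrays` and the toy sizes (4 features, 7 ranks, 10 nodes) are assumptions made for illustration; they are not taken from the shipped model.

```python
import io

import numpy as np


def check_model_arrays(arrays):
    """Check that beta, scalings and node_layer have mutually consistent shapes.

    Hypothetical helper; the shape conventions follow the description above.
    """
    beta = arrays["beta"]
    scalings = arrays["scalings"]
    node_layer = arrays["node_layer"]
    n_features, n_ranks = beta.shape  # (M, R)
    if scalings.shape != (n_ranks, 4):  # (R, 4)
        raise ValueError(
            f"scalings has shape {scalings.shape}, expected {(n_ranks, 4)}"
        )
    if node_layer.ndim != 1:  # (N,)
        raise ValueError("node_layer must be one-dimensional")
    return n_features, n_ranks, node_layer.shape[0]


# Toy archive written to memory instead of a real model file.
buf = io.BytesIO()
np.savez_compressed(
    buf,
    beta=np.zeros((4, 7)),
    scalings=np.ones((7, 4)),
    node_layer=np.arange(10) % 7,
)
buf.seek(0)
with np.load(buf) as archive:
    dims = check_model_arrays(archive)
```

Running a check like this before classification turns a shape mismatch into an explicit error rather than a failure deep inside the model code.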

#### Structure of the `PATH_TO_TAXONOMY` `.npz` File

The `tax_dir` file is an `.npz` archive holding the taxonomy and reference sequence data required by PROTAX-GPU. It must contain the following arrays:

1. **`refs`**: A 2D array of shape `(R, L)`, where `R` is the number of reference sequences (not the number of ranks, as in the model file above) and `L` is the sequence length. Each row represents a reference sequence.

2. **`ok_pos`**: A binary 2D array of shape `(R, L)`, indicating valid positions (non-missing data) for each reference sequence.

3. **`priors`**: A 1D array of length `N`, where `N` is the number of nodes in the taxonomy. It specifies prior probabilities for each node.

4. **`segments`**: A 1D array of length `N`, containing segment identifiers for each node in the taxonomy.

5. **`paths`**: A 2D array of shape `(N, D)`, where `D` is the maximum depth of the taxonomy. Each row represents the path from the root node to a specific taxon.

6. **`node_state`**: A 2D array of shape `(N, S)`, where `S` is the state size (typically 2). Contains state information for each node in the taxonomy.

7. **`ref_rows` and `ref_cols`**: Two 1D arrays defining the row and column indices for mapping reference sequences to nodes in the taxonomy. These are used to construct a sparse binary matrix (`node2seq`).

8. **`node_layer`**: A 1D array of length `N`, defining the taxonomic layer (or rank) for each node.
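
The role of `ref_rows` and `ref_cols` can be sketched as follows. This example builds a dense boolean matrix for readability, whereas the description above refers to a sparse one. It also assumes that `ref_rows` holds node indices and `ref_cols` holds reference sequence indices, which the name `node2seq` suggests but the text does not state. The helper name and the toy taxonomy are hypothetical.

```python
import numpy as np


def dense_node2seq(ref_rows, ref_cols, n_nodes, n_refs):
    """Build a dense boolean (n_nodes, n_refs) matrix from index pairs.

    Assumption: ref_rows holds node indices and ref_cols holds reference
    sequence indices; entry (i, j) is True when reference j belongs to node i.
    """
    node2seq = np.zeros((n_nodes, n_refs), dtype=bool)
    node2seq[np.asarray(ref_rows), np.asarray(ref_cols)] = True
    return node2seq


# Toy taxonomy: node 0 is the root and covers all three references,
# node 1 covers references 0 and 1, node 2 covers reference 2.
ref_rows = [0, 0, 0, 1, 1, 2]
ref_cols = [0, 1, 2, 0, 1, 2]
node2seq = dense_node2seq(ref_rows, ref_cols, n_nodes=3, n_refs=3)
refs_per_node = node2seq.sum(axis=1)
```

Summing along the reference axis gives the number of reference sequences under each node, a quick sanity check on a taxonomy file.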



#### Taxonomy Mapping File (`PATH_TO_TAXONOMY_MAPPING`) Format

The `tax_map` file defines the mapping between node numbers and their taxonomy labels in the taxonomy tree. It should be a tab-separated text file in which each line describes a single node, in the following format:

| Column             | Description                                                                                   |
|--------------------|-----------------------------------------------------------------------------------------------|
| **Node number**    | A unique integer ID for a node in the taxonomy tree. Always the first column.                 |
| **Other fields**   | Optional additional metadata or numeric attributes.                                           |
| **Taxonomy label** | A string giving the taxonomy label of the node (e.g., "Insecta"); the first non-numeric column. |
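
A minimal sketch of a reader for this format is shown below. It is a hypothetical variant, not the `getMapping` function from this change. It treats any value that `float()` can parse as numeric, so a decimal field such as `0.5` placed before the label is skipped, whereas `str.isdigit()` returns `False` for `0.5`. It also raises an explicit error when a line has no label. The toy rows and their middle columns are made up for illustration.

```python
import io


def read_taxonomy_mapping(handle):
    """Return {node_number: label} from a tab-separated mapping file.

    Hypothetical variant of the parser: the label is taken to be the first
    column after the node number that cannot be parsed as a number.
    """

    def is_number(text):
        try:
            float(text)
            return True
        except ValueError:
            return False

    mapping = {}
    for line_no, raw in enumerate(handle, start=1):
        if not raw.strip():
            continue  # ignore empty lines
        parts = raw.rstrip("\n").split("\t")
        label = next((p for p in parts[1:] if p and not is_number(p)), None)
        if label is None:
            raise ValueError(f"line {line_no}: no taxonomy label found")
        mapping[int(parts[0])] = label
    return mapping


# Toy input; the numeric columns stand in for the optional extra fields.
example = io.StringIO("0\t0\t0.0\troot\t1.0\n1\t0\t1\tInsecta\t0.5\n")
node2label = read_taxonomy_mapping(example)
```

One caveat of this approach is that a label that happens to parse as a float (for example `nan`) would be skipped, so it only suits files whose labels are ordinary taxon names.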


</details>


## Output
The `pyprotax_results.csv` file contains, for each query, the predicted taxonomic labels and their associated probabilities. The columns follow the traversal of the taxonomy from the root to the leaves, so the layout is consistent across runs:

- **Columns 1–7** (`tax_label1`–`tax_label7`): The taxonomic labels at each level of the hierarchy. Each row corresponds to a single query, and these columns give the path from the root node to the assigned taxon, for example `[Insecta, Coleoptera, Ptiliidae, Acrotrichinae, Acrotrichini, Acrotrichis, Acrotrichis_rugulosa]`.

- **Columns 8–14** (`prob_level1`–`prob_level7`): The probability of the assigned label at each level of the hierarchy for the given query. Each value is the product of the branch probabilities along the path from the root to that level.
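
The "product of branch probabilities" can be illustrated with a cumulative product. This is only a sketch of the arithmetic with made-up numbers, not the JAX code used by PROTAX-GPU. The function name `path_probabilities` and the branch values are hypothetical.

```python
from itertools import accumulate
from operator import mul


def path_probabilities(branch_probs):
    """Cumulative product of branch probabilities from the root downwards."""
    return list(accumulate(branch_probs, mul))


# Hypothetical branch probabilities for one query, one per taxonomic level.
branch_probs = [1.0, 0.5, 0.5, 1.0, 0.5, 1.0, 1.0]
probs = path_probabilities(branch_probs)
```

Because every factor is at most 1, the resulting values can only stay equal or decrease from one level to the next, which is a useful consistency check when reading the probability columns.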


## Training
Run the training script from the command line, specifying the paths to your training data and target data with the `--train_dir` and `--targ_dir` arguments, respectively:
```
6 changes: 3 additions & 3 deletions scripts/process_seqs.py
@@ -5,11 +5,11 @@
protax_args = sys.argv
if len(protax_args) < 4:
    print(
        "Usage: python3 classify.py [PATH_TO_TAXONOMY_FILE] [PATH_TO_PARAMETERS] [PATH_TO_QUERY_SEQUENCES]"
        "Usage: python3 classify.py [PATH_TO_TAXONOMY_FILE] [PATH_TO_PARAMETERS] [PATH_TO_QUERY_SEQUENCES] [PATH_TO_TAXONOMY_MAPPING]"
    )

query_dir, model_dir, tax_dir = protax_args[1:4]
classify_file(query_dir, model_dir, tax_dir)
query_dir, model_dir, tax_dir, tax_map = protax_args[1:5]
classify_file(query_dir, model_dir, tax_dir, tax_map)
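
A note on this hunk: with four positional paths, `sys.argv` needs at least five entries, so the `< 4` check no longer catches a missing mapping path, and the script carries on after printing the usage text. The usage string also names `classify.py` and lists the paths in a different order from the unpacking. One possible way to address all three is sketched here with `argparse`; the function name `parse_cli` is hypothetical and the example paths are the ones from the README.

```python
import argparse


def parse_cli(argv):
    """Parse the four positional paths in the order the script unpacks them."""
    parser = argparse.ArgumentParser(prog="process_seqs.py")
    parser.add_argument("query_dir", help="file with the sequences to classify")
    parser.add_argument("model_dir", help="model parameters (.npz)")
    parser.add_argument("tax_dir", help="taxonomy (.npz)")
    parser.add_argument("tax_map", help="node-to-label mapping file")
    return parser.parse_args(argv)


args = parse_cli(
    [
        "data/refs.aln",
        "models/params/model.npz",
        "models/ref_db/taxonomy37k.npz",
        "data/tax_mapping.priors",
    ]
)
# In the script this would be followed by:
# classify_file(args.query_dir, args.model_dir, args.tax_dir, args.tax_map)
```

With this approach a missing argument makes `argparse` print a usage message generated from the declared arguments and exit with a non-zero status, so the help text cannot drift out of sync with the argument order.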

# testing
# query_dir = r"/home/roy/Documents/PROTAX-dsets/30k_small/refs.aln"
40 changes: 34 additions & 6 deletions src/protax/classify.py
@@ -41,8 +41,31 @@ def read_names(tdir):

return names

def getMapping(tax_map):
    """
    Generate a mapping from node numbers to taxonomy labels dynamically
    identifying the label column as the first column with non-numeric values.
    """
    taxonomy_dict = {}

    with open(tax_map, "r") as file:
        for line in file:
            parts = line.strip().split("\t")  # Split line into parts using tab as the delimiter

            # Determine the label column dynamically
            for i, part in enumerate(parts):
                if not part.isdigit():  # Check if the column contains non-numeric data
                    label_column = i
                    break

            node_number = int(parts[0])  # Assume the first column is always the node number
            label = parts[label_column]  # Dynamically select the label column
            taxonomy_dict[node_number] = label

def classify_file(qdir, par_dir, tax_dir, verbose=False):
    return taxonomy_dict


def classify_file(qdir, par_dir, tax_dir, tax_map, verbose=False):
"""
Process a batch of queries given a model and taxonomy directory
"""
@@ -56,7 +79,7 @@ def classify_file(qdir, par_dir, tax_dir, verbose=False):

tot_time = 0
res = []

node2label = getMapping(tax_map)
while True:
curr = f.readline().strip('\n')
curr = curr.replace('|', '\t').split('\t')
@@ -75,10 +98,13 @@ def classify_file(qdir, par_dir, tax_dir, verbose=False):


# TODO argmax at leaf level?
probabilities = np.array(jnp.max(probs, axis=0))
classified_layer = jnp.argmax(probs, axis=0)
res.append(classified_layer)


classified_tax_labels = [node2label[int(i)] for i in classified_layer]
out = []
out.extend(list(classified_tax_labels[-1].split(',')))
out.extend(probabilities[1:])
res.append(out)
if verbose:
pass
# sel_name = names[classified_layer.at[-1].get()]
@@ -87,6 +113,8 @@ def classify_file(qdir, par_dir, tax_dir, verbose=False):

# saving results
df = pd.DataFrame(np.array(res))
df.columns = ['tax_label1', 'tax_label2', 'tax_label3', 'tax_label4', 'tax_label5', 'tax_label6', 'tax_label7',
'prob_level1', 'prob_level2', 'prob_level3', 'prob_level4', 'prob_level5', 'prob_level6', 'prob_level7']
df.to_csv("pyprotax_results.csv")
print(f"finished in {tot_time}s")

@@ -146,5 +174,5 @@ def compute_perplexity(qdir, model_dir, tax_dir, verbose=False):
# testing for now

query_dir = r"FinPROTAX/FinPROTAX/modelCOIfull/refs.aln"
classify_file(query_dir, "models/params/model.npz", "models/ref_db/taxonomy37k.npz")
classify_file(query_dir, "models/params/model.npz", "models/ref_db/taxonomy37k.npz", "FinPROTAX/FinPROTAX/modelCOIfull/taxonomy.priors")
compute_perplexity(query_dir, "models/params/model.npz", "models/ref_db/taxonomy37k.npz")