Refine train and index building
sdake authored and rstarmer committed May 28, 2024
1 parent 41c94f7 commit cb2446e
Showing 5 changed files with 277 additions and 47 deletions.
42 changes: 42 additions & 0 deletions README.md
@@ -110,6 +110,48 @@ to improve the quality and capability of a foundational LLM.

**Deploy.** 👉 Local-first and cloud native: you choose where to operate, on our cloud or yours!

# Learn Workflow

To learn new information, create several data structures with Mouseion by following the steps below.

## Make Tokens

Produce a tokenized dataset, a map from sequence index to chunk index, a map from chunk index to sequence index,
and finally embeddings.

```console
todo in followup PR
```
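
The tokenization command is still pending. In the meantime, the sketch below illustrates the four artifacts the later steps consume; the shapes, dtypes, chunk length, and embedding dimensionality are assumptions, and the file names mirror one shard from `mouseion/datacollection_specs/train-arxiv.json`.

```python
# Hypothetical sketch only: shapes, dtypes, and chunk length are assumptions,
# not the project's actual choices.
import numpy as np

num_chunks = 1_000        # assumed number of fixed-length chunks in this shard
num_sequences = 100       # assumed number of source documents (sequences)
chunk_length = 64         # assumed tokens per chunk
embedding_dim = 768       # assumed embedding dimensionality

chunks = np.zeros((num_chunks, chunk_length), dtype=np.int32)   # token ids per chunk
chunk2seq = np.zeros(num_chunks, dtype=np.int64)                # chunk index -> owning sequence
seq2chunk = np.zeros(num_sequences, dtype=np.int64)             # sequence index -> first chunk
embeddings = np.zeros((num_chunks, embedding_dim), dtype=np.float32)  # one vector per chunk

np.save("00_arxiv_train_tokenized.npy", chunks)
np.save("00_arxiv_train_chunks_to_sequence.npy", chunk2seq)
np.save("00_arxiv_train_sequence_to_chunks.npy", seq2chunk)
np.save("00_arxiv_train_embedding.npy", embeddings)
```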

## Retrieve Neighbors

Query the FAISS index built in the next step for the nearest neighbors of each query chunk embedding:

```console
python retrieve_neighbours.py \
--query-embeddings $HOME/datasets/arxiv/00_arxiv_train_embeddings.npy \
--query-chunk2seq $HOME/datasets/arxiv/00_arxiv_train_chunks_to_sequence.npy \
--index $HOME/datasets/arxiv/arxiv.index \
--index-spec $HOME/train-arxiv.json \
--num-neighbours 4 \
--neighbors-output $HOME/datasets/arxiv/00_train_arxiv_train_neighbors.npy
```
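
The neighbors file is a NumPy array of neighbor chunk indices, one row per query chunk. A minimal inspection sketch (the exact shape and dtype are assumptions):

```python
import numpy as np

neighbors = np.load("00_train_arxiv_train_neighbors.npy", mmap_mode="r")
chunk2seq = np.load("00_arxiv_train_chunks_to_sequence.npy", mmap_mode="r")

print(neighbors.shape)   # expected (num_query_chunks, 4) when run with --num-neighbours 4
first = neighbors[0]     # neighbor chunk indices for the first query chunk
print(chunk2seq[first])  # sequences those neighbor chunks belong to
```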

## Build Index

Build a FAISS index over the embeddings listed in the data collection spec:

```console
mkdir -p $HOME/datasets/arxiv
python mouseion/mouseion/train_and_build_index.py \
  --specs mouseion/datacollection_specs/train-arxiv.json \
  --index-type 'IVF16384,PQ32' \
  --output "$HOME/datasets/arxiv/arxiv.index" \
  --use-gpus \
  --batch-size 65536
```
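
The `--index-type` value is a standard FAISS index factory string: `IVF16384,PQ32` requests an inverted-file index with 16,384 coarse clusters and 32-byte product-quantization codes, which is why the index must be trained on sample embeddings before vectors are added. A minimal standalone sketch of the same construction; the dimensionality and training data here are placeholders, not values the builder actually uses:

```python
import faiss
import numpy as np

dimension = 768  # placeholder; train_and_build_index.py derives this from the first embedding shard
index = faiss.index_factory(dimension, "IVF16384,PQ32")

training_sample = np.random.rand(100_000, dimension).astype(np.float32)  # stand-in for real embeddings
index.train(training_sample)   # IVF/PQ indexes require training before add()
index.add(training_sample)
faiss.write_index(index, "arxiv.index")
```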

![mousion_train_index gif](assets/gifs/mousion_train_index.gif)


# Build your own AKI (augmented knowledge index)

- Create a `DataLoader()` to read the authoritative documents in your document collection.
Binary file added assets/gifs/mousion_build_index.gif
182 changes: 182 additions & 0 deletions mouseion/datacollection_specs/train-arxiv.json
@@ -0,0 +1,182 @@
[
{
"chunks": "/home/sdake/datasets/arxiv/00_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/00_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/00_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/00_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/01_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/01_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/01_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/01_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/02_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/02_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/02_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/02_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/03_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/03_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/03_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/03_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/04_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/04_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/04_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/04_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/05_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/05_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/05_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/05_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/06_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/06_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/06_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/06_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/08_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/08_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/08_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/08_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/09_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/09_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/09_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/09_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/10_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/10_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/10_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/10_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/11_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/11_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/11_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/11_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/12_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/12_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/12_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/12_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/13_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/13_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/13_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/13_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/14_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/14_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/14_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/14_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/15_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/15_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/15_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/15_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/16_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/16_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/16_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/16_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/17_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/17_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/17_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/17_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/18_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/18_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/18_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/18_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/20_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/20_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/20_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/20_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/21_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/21_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/21_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/21_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/22_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/22_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/22_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/22_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/23_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/23_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/23_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/23_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/24_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/24_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/24_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/24_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/25_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/25_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/25_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/25_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/26_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/26_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/26_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/26_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/27_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/27_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/27_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/27_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/28_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/28_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/28_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/28_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/29_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/29_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/29_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/29_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/30_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/30_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/30_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/30_arxiv_train_embedding.npy"
},
{
"chunks": "/home/sdake/datasets/arxiv/31_arxiv_train_tokenized.npy",
"chunk2seq": "/home/sdake/datasets/arxiv/31_arxiv_train_sequence_to_chunks.npy",
"seq2chunk": "/home/sdake/datasets/arxiv/31_arxiv_train_chunks_to_sequence.npy",
"embeddings": "/home/sdake/datasets/arxiv/31_arxiv_train_embedding.npy"
}
]
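
Each entry in this spec describes one arXiv shard by its four artifacts (tokenized chunks, the two chunk/sequence maps, and embeddings), keyed by absolute paths on the author's machine. A minimal sketch of iterating the spec, mirroring the loader in `train_and_build_index.py`; the paths will need to point at your own copies:

```python
import json
import pathlib
import numpy as np

specs = json.load(pathlib.Path("mouseion/datacollection_specs/train-arxiv.json").open("r"))
for shard in specs:
    # mmap keeps memory use low; embedding shards can be large
    embeddings = np.load(shard["embeddings"], mmap_mode="r")
    print(shard["chunks"], embeddings.shape)
```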
6 changes: 3 additions & 3 deletions mouseion/mouseion/indexbuilder.py
@@ -50,14 +50,14 @@ class GPUIndexBuilder(IndexBuilder):
"An index builder which uses GPU acceleration."

def __init__(self, *args):
super(self).__init__(*args)
super(GPUIndexBuilder, self).__init__(*args)

co = faiss.GpuMultipleClonerOptions()
co.useFloat16 = True
co.usePrecomputed = True
co.shard = True
co.resources = [limitedGPUResource(2**20)
for _ in range(faiss.get_num_cpus())]
co.resources = [limitedGPUResource(2**32)
for _ in range(faiss.get_num_gpus())]

self.index = faiss.index_cpu_to_all_gpus(self.index, co)

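For reference, a minimal standalone sketch of the multi-GPU cloning pattern the corrected builder relies on; the dimension and index string are placeholders, and `limitedGPUResource` is a project helper not reproduced here:

```python
import faiss

index = faiss.index_factory(768, "IVF16384,PQ32")  # placeholder dimensionality and index type

co = faiss.GpuMultipleClonerOptions()
co.useFloat16 = True   # use float16 storage on the GPU where supported
co.shard = True        # shard the index across all visible GPUs

ngpus = faiss.get_num_gpus()
gpu_index = faiss.index_cpu_to_all_gpus(index, co, ngpu=ngpus)
```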
94 changes: 50 additions & 44 deletions mouseion/mouseion/train_and_build_index.py
@@ -4,68 +4,74 @@
 import tqdm
 import pathlib

-from .indexbuilder import GPUIndexBuilder, IndexBuilder
+from multiprocessing import Queue
+from indexbuilder import GPUIndexBuilder, IndexBuilder
+from tasks import Task, TaskProcessor, ProgressType

-def main(args):
+
+def function_train_index(size: int, progress_queue: Queue, **kwargs) -> list:
     """
     Determine if an output directory exists.
     """
     if not pathlib.Path(args.output).is_dir():
         raise ValueError(f"The output path is invalid {args.output}")

     """
-    Combines numpy embeddings from multiple npy files into a FAISS index.
+    Combines numpy embeddings from multiple numpy files into a FAISS index.
     """
-    try:
-        spec = json.load(args.spec.open("r"))
-    except Exception as e:
-        print(f"Error reading spec file: {e}")
-        return
-
-    total_embeddings = sum([numpy.load(shard["embeddings"], mmap_mode="r").shape[0] for shard in spec["shards"]])
-    is_trained = False
-
-    try:
-        first_shard = numpy.load(spec["shards"][0]["embeddings"], mmap_mode="r")
-    except Exception as e:
-        print(f"Error loading shard: {e}")
-        return
-
-    dim = first_shard.shape[1]
-    print(f'dimensionality is {dim}')
+    kwarg_output = kwargs['output']
+    kwargs_index_type = kwargs['index_type']
+    kwargs_batch_size = kwargs['batch_size']
+    kwargs_specs = kwargs['specs']

-    builder_cls = GPUIndexBuilder if args.use_gpus else IndexBuilder
-    builder = builder_cls(dim, args.index_type)
+    segments = list()
+    segment = numpy.load(specs[0]['embeddings'])

-    with tqdm.tqdm(total=total_embeddings, desc="Embeddings") as progress:
-        for shard_info in spec["shards"]:
-            try:
-                shard_embeddings = numpy.load(shard_info["embeddings"], mmap_mode="r")
-            except Exception as e:
-                print(f"Error loading shard: {e}")
-                continue
+    dimension = segment.shape[1]
+    split_size = segment.shape[0]

-            shard_size = shard_embeddings.shape[0]
+    builder_cls = GPUIndexBuilder
+    builder = builder_cls(dimension, kwargs_index_type)
+    builder.train(segment)

-            if not is_trained:
-                builder.train(shard_embeddings)
-                is_trained = True
+    shard_total_size=0
+    for i, spec in enumerate(specs):
+        segment = numpy.load(spec['embeddings'])
+        for j in range(0, split_size, kwargs_batch_size):
+            batch = segment[j:j+kwargs_batch_size]
+            builder.add(batch)
+            progress_queue.put(kwargs_batch_size)
+            shard_total_size += kwargs_batch_size

-            for i in range(0, shard_size, args.batch_size):
-                end = i + args.batch_size
-                batch = shard_embeddings[i:end]
-                builder.add(batch)
-                progress.update(batch.shape[0])
-
-    builder.write_path(args.output)
+    builder.write_path(kwarg_output)
+    return []

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--spec", type=pathlib.Path, required=True)
+    parser.add_argument("--specs", type=pathlib.Path, required=True)
     parser.add_argument("--index-type", type=str, required=True)
     parser.add_argument("--use-gpus", action="store_true")
     parser.add_argument("--output", type=str, required=True)
     parser.add_argument("--batch-size", type=int, default=32768)
     args = parser.parse_args()

-    main(args)
+    specs = json.load(args.specs.open("r"))
+    total_embeddings = sum([numpy.load(shard["embeddings"], mmap_mode="r").shape[0] for shard in specs])
+
+
+    tasks_train_index = [
+        Task(
+            id=0,
+            description="Train index",
+            size=total_embeddings,
+            progress_type=ProgressType.ITERATIONS_PER_SECOND,
+            function=function_train_index,
+            index_type=args.index_type,
+            use_gpus=args.use_gpus,
+            specs=specs,
+            output=args.output,
+            batch_size=args.batch_size,
+        )
+    ]
+
+    processor = TaskProcessor(command="RETRO Index Trainer", max_workers=50)
+    processor.add_tasks(tasks_train_index)
+    processor.process_tasks()
