LLMBlock concurrency #157
Changes from all commits: 0db3b25, 3a9a177, 1ce4edc, 0866168, 470ffbb, 68303f2, a27a1b8
@@ -1,11 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # Standard
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from importlib import resources
-from typing import Optional
+from typing import Iterable, Optional
+import math
 import os.path

 # Third Party
-from datasets import Dataset
+from datasets import Dataset, concatenate_datasets
 from openai import OpenAI
 import yaml

 # Local
@@ -22,16 +26,47 @@ class EmptyDatasetError(Exception):


 # This is part of the public API.
-class PipelineContext:
-    def __init__(
-        self, client, model_family, model_id, num_instructions_to_generate
-    ) -> None:
-        self.client = client
-        self.model_family = model_family
-        self.model_id = model_id
-        self.num_instructions_to_generate = num_instructions_to_generate
-        # FIXME: base this on the available number of CPUs
-        self.num_procs = 8
+@dataclass
+class PipelineContext:  # pylint: disable=too-many-instance-attributes
+    """
+    A PipelineContext holds the common attributes needed between blocks in a
+    pipeline
+
+    client: The OpenAI client handle.
+    model_id: The ID of the teacher model to be used for client calls.
+    model_family: The family identifier for the model being updated.
+    num_instructions_to_generate: The total number of instructions the user
+        wants to generate during this run.
+    batch_size: The size of the dataset batches for parallel generation. Set to
+        0 to disable batching.
+    batch_num_workers: The number of worker threads/processes to maintain in the
+        central executor pool.
+    dataset_num_procs: The number of processes to use when performing parallel
+        map operations on individual datasets.
+    """
+
+    # The default batch size of 8 has been determined as a good default for
+    # standard instructlab workloads when running with vllm batching.
+    DEFAULT_BATCH_SIZE = 8
+
+    # The default number of processes to use when performing parallel operations
+    # on individual datasets
+    DEFAULT_DATASET_NUM_PROCS = 8
+
+    client: OpenAI
+    model_family: str
+    model_id: str
+    num_instructions_to_generate: int
+    dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
+    batch_size: int = DEFAULT_BATCH_SIZE
+    batch_num_workers: Optional[int] = None
+
+    @property
+    def batching_enabled(self) -> bool:
+        """Batching is enabled IFF the batch size is specified and the number of
+        workers is not set explicitly to 1
+        """
+        return self.batch_size > 0 and self.batch_num_workers != 1


 # This is part of the public API.

Inline review thread on `batching_enabled`:

Comment: Another example of that confusion ... I want to move

Comment: Yes, this is definitely confusing and requires realizing the client/server relationship between this SDG code and the server.
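For readers following the diff, here is a minimal usage sketch of the new context (not part of the change itself). It assumes the `PipelineContext` above is importable from `instructlab.sdg.pipeline`, as the test fixtures later in this PR do, and passes a placeholder in place of a real OpenAI client:

```python
# Illustrative only: exercises the batching_enabled property added above.
# A real caller would pass an actual OpenAI client instead of None.
from instructlab.sdg.pipeline import PipelineContext

ctx = PipelineContext(
    client=None,
    model_family="test",
    model_id="test-model",
    num_instructions_to_generate=10,
)
assert ctx.batching_enabled  # defaults: batch_size=8, batch_num_workers=None

# Setting batch_size=0 disables batching entirely
assert not PipelineContext(
    client=None,
    model_family="test",
    model_id="test-model",
    num_instructions_to_generate=10,
    batch_size=0,
).batching_enabled

# Explicitly requesting a single worker also disables batching
assert not PipelineContext(
    client=None,
    model_family="test",
    model_id="test-model",
    num_instructions_to_generate=10,
    batch_num_workers=1,
).batching_enabled
```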
@@ -63,7 +98,12 @@ def exception_message(self) -> str:

 # This is part of the public API.
 class Pipeline:
-    def __init__(self, ctx, config_path, chained_blocks: list) -> None:
+    def __init__(
+        self,
+        ctx: PipelineContext,
+        config_path: str,
+        chained_blocks: list[dict],
+    ) -> None:
         """
         Initialize the Pipeline class with a configuration dictionary.
         config_dict: the run config py or yaml loaded into a dictionary
@@ -81,20 +121,40 @@ def from_file(cls, ctx, pipeline_yaml):
         pipeline_yaml = os.path.join(resources.files(__package__), pipeline_yaml)
         return cls(ctx, pipeline_yaml, _parse_pipeline_config_file(pipeline_yaml))

-    def _drop_duplicates(self, dataset, cols):
-        """
-        Drop duplicates from the dataset based on the columns provided.
-        """
-        df = dataset.to_pandas()
-        df = df.drop_duplicates(subset=cols).reset_index(drop=True)
-        ds = Dataset.from_pandas(df)
-        return ds
-
     def generate(self, dataset) -> Dataset:
         """
         Generate the dataset by running the pipeline steps.
         dataset: the input dataset
         """
+        # If not batching, simply delegate to _generate_single
+        if not self.ctx.batching_enabled:
+            logger.info("Running pipeline single-threaded")
+            return self._generate_single(dataset)
+
+        # Otherwise, split the dataset into batches and run each batch as a
+        # future in the thread pool
+        logger.info(
+            "Running pipeline with multi-threaded batching. Using %s workers for batches of size %s",
+            self.ctx.batch_num_workers,
+            self.ctx.batch_size,
+        )
+        input_splits = self._split_dataset(dataset)
+        with ThreadPoolExecutor(max_workers=self.ctx.batch_num_workers) as executor:
+            futures = [
+                executor.submit(self._generate_single, input_split)
+                for input_split in input_splits
+            ]
+
+            # Collect the results of each batch as they finish. This needs to
+            # wait for them all, so the order of waiting doesn't matter
+            output_splits = [future.result() for future in futures]
+
+        return concatenate_datasets(output_splits)
+
+    ## Implementation Details ##
+
+    def _generate_single(self, dataset) -> Dataset:
+        """Generate a single dataset by running the pipeline steps."""
         for block_prop in self.chained_blocks:
             # Initialize arguments for error handling to None
             block, block_name, block_type = None, None, None
@@ -134,6 +194,39 @@ def generate(self, dataset) -> Dataset:

         return dataset

+    def _drop_duplicates(self, dataset, cols):
+        """
+        Drop duplicates from the dataset based on the columns provided.
+        """
+        df = dataset.to_pandas()
+        df = df.drop_duplicates(subset=cols).reset_index(drop=True)
+        ds = Dataset.from_pandas(df)
+        return ds
+
+    def _split_dataset(self, dataset: Dataset) -> list[Dataset]:
+        """Split the dataset into smaller batches."""
+        assert (
+            self.ctx.batch_size is not None
+        ), "Programming Error: Should not call _split_dataset if batching disabled"
+        total_size = len(dataset)
+        num_batches = math.ceil(total_size / self.ctx.batch_size)
+        batches = [
+            dataset.select(self._get_batch_indices(i, total_size))
+            for i in range(num_batches)
+        ]
+        return batches
+
+    def _get_batch_indices(self, batch_index: int, total_size: int) -> Iterable[int]:
+        assert (
+            self.ctx.batch_size is not None
+        ), "Programming Error: Should not call _get_batch_indices if batching disabled"
+        return range(
+            # Start index offset by the batch size
+            batch_index * self.ctx.batch_size,
+            # End index is the next batch offset or the end of the dataset
+            min((batch_index + 1) * self.ctx.batch_size, total_size),
+        )


 _block_types = {
     "CombineColumnsBlock": utilblocks.CombineColumnsBlock,
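As a quick sanity check of the splitting arithmetic above, here is a standalone sketch (not code from this PR) showing how a 10-row dataset is carved into batches of 6:

```python
# Standalone sketch of the index math behind _split_dataset/_get_batch_indices.
import math


def batch_indices(batch_index: int, total_size: int, batch_size: int) -> range:
    return range(
        batch_index * batch_size,                         # start of this batch
        min((batch_index + 1) * batch_size, total_size),  # clamped to dataset end
    )


total_size, batch_size = 10, 6
num_batches = math.ceil(total_size / batch_size)  # -> 2
for i in range(num_batches):
    print(list(batch_indices(i, total_size, batch_size)))
# [0, 1, 2, 3, 4, 5]
# [6, 7, 8, 9]
```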
New file (common test fixtures):

@@ -0,0 +1,48 @@
+"""
+Common fixtures and testing utilities
+"""
+
+# Standard
+from unittest import mock
+
+# Third Party
+from datasets import Dataset
+import pytest
+
+# First Party
+from instructlab.sdg.pipeline import PipelineContext
+
+
+def get_ctx(**kwargs) -> PipelineContext:
+    kwargs.setdefault("client", mock.MagicMock())
+    kwargs.setdefault("model_family", "test")
+    kwargs.setdefault("model_id", "test-model")
+    kwargs.setdefault("num_instructions_to_generate", 10)
+    kwargs.setdefault("dataset_num_procs", 1)
+    return PipelineContext(**kwargs)
+
+
+def get_single_threaded_ctx(**kwargs) -> PipelineContext:
+    kwargs["batch_size"] = 0
+    return get_ctx(**kwargs)
+
+
+def get_threaded_ctx(**kwargs) -> PipelineContext:
+    kwargs["batch_size"] = 6
+    kwargs["batch_num_workers"] = 2
+    return get_ctx(**kwargs)
+
+
+@pytest.fixture
+def single_threaded_ctx() -> PipelineContext:
+    return get_single_threaded_ctx()
+
+
+@pytest.fixture
+def threaded_ctx() -> PipelineContext:
+    return get_threaded_ctx()
+
+
+@pytest.fixture
+def sample_dataset():
+    return Dataset.from_list([{"foo": i} for i in range(10)])
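A hypothetical test built on these fixtures could look like the following (illustrative only, not part of this PR; the test names are made up):

```python
# Hypothetical tests using the fixtures above; not part of this PR.
def test_threaded_ctx_enables_batching(threaded_ctx):
    assert threaded_ctx.batching_enabled


def test_single_threaded_ctx_disables_batching(single_threaded_ctx):
    assert not single_threaded_ctx.batching_enabled


def test_sample_dataset_has_ten_rows(sample_dataset):
    assert len(sample_dataset) == 10
```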
Comment: Since no concurrency is used unless batch size is set ... I feel like we need a sensible default here. We certainly shouldn't require users to set it to get sensible behavior.
Comment: Hm, good point. This could get tricky to estimate. I think the primary elements of a heuristic would be expected memory overhead per batch element (likely low for IO-bound work) and expected concurrency limits on the client side for IO-bound work. I'm not sure we would have any reasonable way of knowing these, though, and I have much less familiarity with the "real" workload. Is there some kind of fixed number that we think would make sense?
The other concern I have about enabling concurrency by default is the possibility for the sharding of the dataset to actually change the results. I think most blocks will treat rows of the `Dataset` as independent entities and process them as such, but I could certainly imagine a case where the logic of a block does not treat the rows as independent. In this case, enabling concurrency might degrade the overall results if the individual block executions have less of the full context to work with.
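To make that concern concrete, here is a contrived sketch (not from this PR) of a block-like operation that does not treat rows independently; running it per batch and concatenating gives a different answer than running it over the whole dataset:

```python
# Contrived illustration: an order-preserving de-duplication "block" produces
# different output when the data is sharded into batches first.
def dedup_block(rows: list[int]) -> list[int]:
    return list(dict.fromkeys(rows))


data = [1, 2, 2, 3, 1, 3, 4, 1]

whole = dedup_block(data)
batched = [row for batch in (data[:4], data[4:]) for row in dedup_block(batch)]

print(whole)    # [1, 2, 3, 4]
print(batched)  # [1, 2, 3, 1, 3, 4] -- duplicates across batch boundaries survive
```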
Comment: Our users certainly aren't going to be able to do a better job of judging these questions (whether sharding is safe, what batch size to use).
If we can't come up with a sane default, it should be in the pipeline config somehow.
@shivchander @xukai92 @aakankshaduggal @npalaska we need your help making a call on an appropriate default batch size, or whether it needs to be configured per-pipeline or per-block.
Comment: From @shivchander
Comment: Further from @shivchander - it can be tricky to debug with this concurrency enabled, so it would be nice to have an easy way to disable it for debugging (as a runtime parameter, not a pipeline config). It might make sense for `--num-cpus=1` to disable this?
Comment: Ok, so it sounds like we want:
- `batch_size=8` by default
- `num_cpus=None` by default (<- the Python default with `num_workers=None` applies here)
- With `num_cpus=1`, we fully disable using `ThreadPoolExecutor`, instead using a `SynchronousExecutor` that offers the same API. I wrote one in a different project that we can probably borrow (here). The alternative is to maintain separate code paths for the synchronous version, which can be pretty error prone.
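For reference, a minimal `SynchronousExecutor` along those lines might look like this (a sketch under the stated assumptions, not the implementation linked above):

```python
# Sketch of an executor with the same submit()/context-manager API as
# ThreadPoolExecutor, but which runs each task inline on the calling thread.
from concurrent.futures import Executor, Future


class SynchronousExecutor(Executor):
    def submit(self, fn, /, *args, **kwargs):
        future: Future = Future()
        try:
            future.set_result(fn(*args, **kwargs))
        except Exception as exc:  # pylint: disable=broad-except
            future.set_exception(exc)
        return future


# Because the API matches, the batching code path in Pipeline.generate would
# not need a separate synchronous branch:
with SynchronousExecutor() as executor:
    futures = [executor.submit(pow, 2, n) for n in range(4)]
    print([f.result() for f in futures])  # [1, 2, 4, 8]
```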
Comment: My sense is that this will inevitably lead to users experiencing sub-par performance, complaining (or just walking away), and us telling them "oh, you need this go-faster flag".
I don't think we should be taking that route out of a concern that it might be a little slower with small datasets.
Unless I'm misunderstanding something, I think we should enable the batching by default.
Comment: That's a fair point - I'm with you on making the distributed setting the default.
Comment: Also, just a note: the combination of batch size and concurrency also depends on the underlying accelerator being used and how the accelerators can handle the batches. (The total number of batches the server can handle depends on the compute profile and the memory available on the accelerator.)
The general consensus is to use smaller batches per thread and maximize the number of threads up to the point where we are not adding a lot of requests to the backend vLLM queue. As requests start to queue up on the backend vLLM server, the time to first token will shoot up and some requests will result in timeout errors. However, having a small batch means only a few requests get penalized in case of an error.
For example, using a batch size of 8 and num_workers of 32 (256 total concurrent requests) on 4x H100 yielded the best performance for me.
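For clarity, the concurrency figure in that example is just the product of the two knobs (back-of-the-envelope, assuming one in-flight request per row):

```python
# 32 workers, each handling a batch of 8 rows concurrently
batch_size = 8
num_workers = 32
print(batch_size * num_workers)  # 256 total concurrent requests to the vLLM server
```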
Comment: Thanks for the input! This makes sense and absolutely lines up with my assumptions about why a heuristic for this would be tough to nail down. Since we do have some strong ideas about the default environment where this would run, I think `batch_size = 8` and `num_cpus = None` (thereby using the Python default of `min(32, os.cpu_count() + 4)`) is probably a safe place to start.
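For reference, that default (used by `ThreadPoolExecutor` on Python 3.8+ when no `max_workers` is given) can be evaluated directly for a given machine:

```python
# The default max_workers used by ThreadPoolExecutor when none is given
# (Python 3.8+), evaluated for the current host.
import os

default_workers = min(32, (os.cpu_count() or 1) + 4)
print(default_workers)
```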