Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions torchspec/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,3 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from torchspec.inference.factory import (
create_inference_engines,
prepare_inference_engines,
)

__all__ = [
"create_inference_engines",
"prepare_inference_engines",
]
14 changes: 0 additions & 14 deletions torchspec/inference/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,3 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from torchspec.inference.engine.base import InferenceEngine
from torchspec.inference.engine.hf_engine import HFEngine
from torchspec.inference.engine.hf_runner import HFRunner
from torchspec.inference.engine.sgl_engine import SglEngine
from torchspec.inference.engine.vllm_engine import VllmEngine

__all__ = [
"InferenceEngine",
"HFEngine",
"HFRunner",
"SglEngine",
"VllmEngine",
]
9 changes: 6 additions & 3 deletions torchspec/inference/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from torchspec.inference.engine.hf_engine import HFEngine
from torchspec.inference.engine.sgl_engine import SglEngine
from torchspec.inference.engine.vllm_engine import VllmEngine
from torchspec.utils.env import get_torchspec_env_vars
from torchspec.utils.logging import logger

Expand Down Expand Up @@ -131,6 +128,8 @@ def _prepare_hf_engines(args, pg, mooncake_config=None, engine_group: int = 0) -

logger.info(f"Initializing {num_engines} HF engines ({num_gpus_per_engine} GPU(s) each)")

from torchspec.inference.engine.hf_engine import HFEngine

HFRayActor = ray.remote(HFEngine)
return _create_and_init_actors(
args,
Expand Down Expand Up @@ -188,6 +187,8 @@ def _prepare_sgl_engines(
)

pg_obj, reordered_bundle_indices, reordered_gpu_ids = pg
from torchspec.inference.engine.sgl_engine import SglEngine

SglRayActor = ray.remote(SglEngine)
env_vars = get_torchspec_env_vars()

Expand Down Expand Up @@ -328,6 +329,8 @@ def _prepare_vllm_engines(
)

pg_obj, reordered_bundle_indices, reordered_gpu_ids = pg
from torchspec.inference.engine.vllm_engine import VllmEngine

VllmRayActor = ray.remote(VllmEngine)
env_vars = get_torchspec_env_vars()

Expand Down
2 changes: 1 addition & 1 deletion torchspec/train_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
run_training_loop,
setup_async_training_with_engines,
)
from torchspec.inference import prepare_inference_engines
from torchspec.inference.factory import prepare_inference_engines
from torchspec.ray.placement_group import (
allocate_train_group,
create_placement_groups,
Expand Down
Loading