Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: reduce import time - draft #8831

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions haystack/telemetry/_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,31 +84,9 @@ def collect_system_specs() -> Dict[str, Any]:
"os.machine": platform.machine(),
"python.version": platform.python_version(),
"hardware.cpus": os.cpu_count(),
"hardware.gpus": 0,
"libraries.transformers": False,
"libraries.torch": False,
"libraries.cuda": False,
"libraries.pytest": sys.modules["pytest"].__version__ if "pytest" in sys.modules.keys() else False,
"libraries.ipython": sys.modules["ipython"].__version__ if "ipython" in sys.modules.keys() else False,
"libraries.colab": sys.modules["google.colab"].__version__ if "google.colab" in sys.modules.keys() else False,
}

# Try to find out transformer's version
try:
import transformers

specs["libraries.transformers"] = transformers.__version__
except ImportError:
pass

# Try to find out torch's version and info on potential GPU(s)
try:
import torch

specs["libraries.torch"] = torch.__version__
if torch.cuda.is_available():
specs["libraries.cuda"] = torch.version.cuda
specs["libraries.gpus"] = torch.cuda.device_count()
except ImportError:
pass
return specs
46 changes: 17 additions & 29 deletions haystack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,21 @@
#
# SPDX-License-Identifier: Apache-2.0

from .auth import Secret, deserialize_secrets_inplace
from .callable_serialization import deserialize_callable, serialize_callable
from .device import ComponentDevice, Device, DeviceMap, DeviceType
from .docstore_deserialization import deserialize_document_store_in_init_params_inplace
from .expit import expit
from .filters import document_matches_filter, raise_on_invalid_filter_syntax
from .jinja2_extensions import Jinja2TimeExtension
from .jupyter import is_in_jupyter
from .requests_utils import request_with_retry
from .type_serialization import deserialize_type, serialize_type
import sys

__all__ = [
"Secret",
"deserialize_secrets_inplace",
"ComponentDevice",
"Device",
"DeviceMap",
"DeviceType",
"expit",
"document_matches_filter",
"raise_on_invalid_filter_syntax",
"is_in_jupyter",
"request_with_retry",
"serialize_callable",
"deserialize_callable",
"serialize_type",
"deserialize_type",
"deserialize_document_store_in_init_params_inplace",
"Jinja2TimeExtension",
]
from lazy_imports import LazyImporter

_import_structure = {
".utils.device": ["ComponentDevice", "Device", "DeviceMap", "DeviceType"],
".utils.auth": ["Secret", "deserialize_secrets_inplace"],
".utils.callable_serialization": ["deserialize_callable", "serialize_callable"],
".utils.docstore_deserialization": ["deserialize_document_store_in_init_params_inplace"],
".utils.expit": ["expit"],
".utils.filters": ["document_matches_filter", "raise_on_invalid_filter_syntax"],
".utils.type_serialization": ["deserialize_type", "serialize_type"],
".utils.jinja2_extensions": ["Jinja2TimeExtension"],
".utils.jupyter": ["is_in_jupyter"],
".utils.requests_utils": ["request_with_retry"],
}

sys.modules[__name__] = LazyImporter(__name__, globals()["__file__"], _import_structure)
58 changes: 58 additions & 0 deletions measure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import re
import statistics
import subprocess


def measure_import_times(n_runs=30, module="haystack"):
"""
Measure the time it takes to import a module.
"""
user_times = []
sys_times = []

print(f"Running {n_runs} measurements...")

for i in range(n_runs):
# Run the import command and capture output
result = subprocess.run(["time", "python", "-c", f"import {module}"], capture_output=True, text=True)

# Check both stdout and stderr
time_output = result.stderr

# Extract times using regex - matches patterns like "3.21user 0.17system"
time_pattern = r"([\d.]+)user\s+([\d.]+)system"
match = re.search(time_pattern, time_output)

if match:
user_time = float(match.group(1))
sys_time = float(match.group(2))

user_times.append(user_time)
sys_times.append(sys_time)

# print(user_times)

if (i + 1) % 10 == 0:
print(f"Completed {i + 1} runs...")

# Calculate statistics
avg_user = statistics.mean(user_times)
avg_sys = statistics.mean(sys_times)
avg_total = avg_user + avg_sys

# Calculate standard deviations
std_user = statistics.stdev(user_times)
std_sys = statistics.stdev(sys_times)

print("\nResults:")
print(f"Average user time: {avg_user:.3f}s ± {std_user:.3f}s")
print(f"Average sys time: {avg_sys:.3f}s ± {std_sys:.3f}s")
print(f"Average total (user + sys): {avg_total:.3f}s")


if __name__ == "__main__":
measure_import_times()
64 changes: 64 additions & 0 deletions torch_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import importlib.abc
import importlib.util
import sys
import types
from pathlib import Path


class ImportTracker(importlib.abc.MetaPathFinder):
def find_spec(self, fullname, path, target=None):
"""
If the module name contains "torch", print the full name and the stack trace.
"""
if "torch" in fullname:
print(f"\nAttempting to import: {fullname}")
import traceback

for frame in traceback.extract_stack()[:-1]: # Exclude this frame
if "haystack" in frame.filename:
print(f" In Haystack file: {frame.filename}:{frame.lineno}")
print(f" {frame.line}")


# Install the import tracker
sys.meta_path.insert(0, ImportTracker())

# Record modules before import
print("Recording initial modules...")
modules_before = set(sys.modules.keys())

# Import haystack
print("Importing haystack...")
import haystack

# Find new modules after import
print("Analyzing new modules...")
modules_after = set(sys.modules.keys())
new_modules = modules_after - modules_before

# Filter for haystack modules that imported torch
haystack_importers = {}

for name in new_modules:
if name.startswith("haystack"):
module = sys.modules[name]
# Check if this module uses torch
module_dict = getattr(module, "__dict__", {})
for value in module_dict.values():
if isinstance(value, types.ModuleType) and "torch" in value.__name__:
if name not in haystack_importers:
haystack_importers[name] = set()
haystack_importers[name].add(value.__name__)

if haystack_importers:
print("\nFound haystack modules that imported torch:")
for module_name, torch_modules in sorted(haystack_importers.items()):
print(f"\n{module_name}:")
for torch_module in sorted(torch_modules):
print(f" - {torch_module}")
else:
print("\nNo haystack modules imported torch")
Loading