|
9 | 9 | import logging
|
10 | 10 | import os
|
11 | 11 | import types
|
12 |
| -from functools import wraps |
13 | 12 | from pathlib import Path
|
14 | 13 |
|
15 | 14 | import torch
|
| 15 | +from torchaudio._internal.module_utils import eval_env |
16 | 16 |
|
17 | 17 | _LG = logging.getLogger(__name__)
|
18 | 18 | _LIB_DIR = Path(__file__).parent.parent / "lib"
|
@@ -62,16 +62,49 @@ def _load_lib(lib: str) -> bool:
|
62 | 62 | return True
|
63 | 63 |
|
64 | 64 |
|
65 |
| -def _init_sox(): |
| 65 | +def _import_sox_ext(): |
| 66 | + if os.name == "nt": |
| 67 | + raise RuntimeError("sox extension is not supported on Windows") |
| 68 | + if not eval_env("TORCHAUDIO_USE_SOX", True): |
| 69 | + raise RuntimeError("sox extension is disabled. (TORCHAUDIO_USE_SOX=0)") |
| 70 | + |
| 71 | + ext = "torchaudio.lib._torchaudio_sox" |
| 72 | + |
| 73 | + if not importlib.util.find_spec(ext): |
| 74 | + raise RuntimeError( |
| 75 | + # fmt: off |
| 76 | + "TorchAudio is not built with sox extension. " |
| 77 | + "Please build TorchAudio with libsox support. (BUILD_SOX=1)" |
| 78 | + # fmt: on |
| 79 | + ) |
| 80 | + |
66 | 81 | _load_lib("libtorchaudio_sox")
|
67 |
| - import torchaudio.lib._torchaudio_sox # noqa |
| 82 | + return importlib.import_module(ext) |
68 | 83 |
|
69 |
| - torchaudio.lib._torchaudio_sox.set_verbosity(0) |
| 84 | + |
| 85 | +def _init_sox(): |
| 86 | + ext = _import_sox_ext() |
| 87 | + ext.set_verbosity(0) |
70 | 88 |
|
71 | 89 | import atexit
|
72 | 90 |
|
73 |
| - torch.ops.torchaudio.sox_effects_initialize_sox_effects() |
74 |
| - atexit.register(torch.ops.torchaudio.sox_effects_shutdown_sox_effects) |
| 91 | + torch.ops.torchaudio_sox.initialize_sox_effects() |
| 92 | + atexit.register(torch.ops.torchaudio_sox.shutdown_sox_effects) |
| 93 | + |
| 94 | + # Bundle functions registered with TORCH_LIBRARY into extension |
| 95 | + # so that they can also be accessed in the same (lazy) manner |
| 96 | + # from the extension. |
| 97 | + keys = [ |
| 98 | + "get_info", |
| 99 | + "load_audio_file", |
| 100 | + "save_audio_file", |
| 101 | + "apply_effects_tensor", |
| 102 | + "apply_effects_file", |
| 103 | + ] |
| 104 | + for key in keys: |
| 105 | + setattr(ext, key, getattr(torch.ops.torchaudio_sox, key)) |
| 106 | + |
| 107 | + return ext |
75 | 108 |
|
76 | 109 |
|
77 | 110 | _FFMPEG_VERS = ["6", "5", "4", ""]
|
@@ -197,22 +230,3 @@ def _check_cuda_version():
|
197 | 230 | "Please install the TorchAudio version that matches your PyTorch version."
|
198 | 231 | )
|
199 | 232 | return version
|
200 |
| - |
201 |
| - |
202 |
| -def _fail_since_no_sox(func): |
203 |
| - @wraps(func) |
204 |
| - def wrapped(*_args, **_kwargs): |
205 |
| - try: |
206 |
| - # Note: |
207 |
| - # We run _init_sox again just to show users the stacktrace. |
208 |
| - # _init_sox would not succeed here. |
209 |
| - _init_sox() |
210 |
| - except Exception as err: |
211 |
| - raise RuntimeError( |
212 |
| - f"{func.__name__} requires sox extension which is not available. " |
213 |
| - "Please refer to the stacktrace above for how to resolve this." |
214 |
| - ) from err |
215 |
| - # This should not happen in normal execution, but just in case. |
216 |
| - return func(*_args, **_kwargs) |
217 |
| - |
218 |
| - return wrapped |
|
0 commit comments