diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml index c86f87fc65..77172ff11f 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml @@ -112,7 +112,6 @@ task: valid: [2015-01-01, 2016-12-31] test: [2017-01-01, 2020-08-01] seq_len: 60 - horizon: 2 input_size: num_states: *num_states batch_size: 1024 diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml index 75f18f3ee6..66815de711 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml @@ -106,7 +106,6 @@ task: valid: [2015-01-01, 2016-12-31] test: [2017-01-01, 2020-08-01] seq_len: 60 - horizon: 2 input_size: num_states: *num_states batch_size: 1024 diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml index 9ab5b904ba..a4a2486a58 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml @@ -106,7 +106,6 @@ task: valid: [2015-01-01, 2016-12-31] test: [2017-01-01, 2020-08-01] seq_len: 60 - horizon: 2 input_size: 6 num_states: *num_states batch_size: 1024 diff --git a/qlib/contrib/data/dataset.py b/qlib/contrib/data/dataset.py index 9ce522cc06..ccc667ae91 100644 --- a/qlib/contrib/data/dataset.py +++ b/qlib/contrib/data/dataset.py @@ -6,6 +6,8 @@ import warnings import numpy as np import pandas as pd +from qlib.utils.data import guess_horizon +from qlib.utils import init_instance_by_config from qlib.data.dataset import DatasetH @@ -130,6 +132,12 @@ def __init__( input_size=None, **kwargs, ): + if horizon == 0: + # Try to guess horizon + if isinstance(handler, (dict, str)): + handler = init_instance_by_config(handler) + label = handler.data_loader.fields["label"][0][0] + horizon = guess_horizon(label) assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage" assert memory_mode in ["sample", "daily"], "unsupported memory mode" diff --git a/qlib/utils/data.py b/qlib/utils/data.py index 6c62f75583..7b196c50ec 100644 --- a/qlib/utils/data.py +++ b/qlib/utils/data.py @@ -3,10 +3,12 @@ """ This module covers some utility functions that operate on data or basic object """ +import re from copy import deepcopy from typing import List, Union -import pandas as pd + import numpy as np +import pandas as pd def robust_zscore(x: pd.Series, zscore=False): @@ -103,3 +105,19 @@ def update_config(base_config: dict, ext_config: Union[dict, List[dict]]): # one of then are not dict. Then replace base_config[key] = ec[key] return base_config + + +def guess_horizon(label): + """ + Try to guess the horizon by parsing label + """ + regex = r"Ref\(\s*\$[a-zA-Z]+,\s*-(\d+)\)" + horizon_list = [int(x) for x in re.findall(regex, label)] + + if len(horizon_list) == 0: + return 0 + max_horizon = max(horizon_list) + # Unlikely the label doesn't use future information + if max_horizon < 2: + return 0 + return max_horizon + 1 diff --git a/tests/misc/test_utils.py b/tests/misc/test_utils.py index 2be792faf7..53108a75e3 100644 --- a/tests/misc/test_utils.py +++ b/tests/misc/test_utils.py @@ -9,6 +9,7 @@ from qlib.log import TimeInspector from qlib.constant import REG_CN, REG_US, REG_TW from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal, CN_TIME, US_TIME, TW_TIME +from qlib.utils.data import guess_horizon REG_MAP = {REG_CN: CN_TIME, REG_US: US_TIME, REG_TW: TW_TIME} @@ -112,5 +113,24 @@ def gen_args(cal: List): cal_sam_minute_new(*args, region=region) +class DataUtils(TestCase): + @classmethod + def setUpClass(cls): + init() + + def test_guess_horizon(self): + label = "Ref($close, -2) / Ref($close, -1) - 1" + result = guess_horizon(label) + assert result == 3 + + label = "Ref($close, -5) / Ref($close, -1) - 1" + result = guess_horizon(label) + assert result == 6 + + label = "Ref($close, -1) / Ref($close, -1) - 1" + result = guess_horizon(label) + assert result == 0 + + if __name__ == "__main__": unittest.main()