2525
2626import numpy as np
2727import pandas as pd
28+ from skbase .utils .dependencies import _check_soft_dependencies
2829
2930from .expected_returns import returns_from_prices
3031
@@ -298,11 +299,14 @@ def min_cov_determinant(
298299 warnings .warn ("data is not in a dataframe" , RuntimeWarning )
299300 prices = pd .DataFrame (prices )
300301
301- # Extra dependency
302- try :
303- import sklearn .covariance
304- except (ModuleNotFoundError , ImportError ):
305- raise ImportError ("Please install scikit-learn via pip or poetry" )
302+ if not _check_soft_dependencies (["scikit-learn" ], severity = "none" ):
303+ raise ImportError (
304+ "scikit-learn is required to use min_cov_determinant. "
305+ "Please ensure that scikit-learn is installed in your environment,"
306+ " e.g via pip install scikit-learn"
307+ )
308+
309+ from sklearn .covariance import fast_mcd
306310
307311 assets = prices .columns
308312
@@ -312,7 +316,7 @@ def min_cov_determinant(
312316 X = returns_from_prices (prices , log_returns )
313317 # X = np.nan_to_num(X.values)
314318 X = X .dropna ().values
315- raw_cov_array = sklearn . covariance . fast_mcd (X , random_state = random_state )[1 ]
319+ raw_cov_array = fast_mcd (X , random_state = random_state )[1 ]
316320 cov = pd .DataFrame (raw_cov_array , index = assets , columns = assets ) * frequency
317321 return fix_nonpositive_semidefinite (cov , kwargs .get ("fix_method" , "spectral" ))
318322
@@ -379,13 +383,16 @@ def __init__(self, prices, returns_data=False, frequency=252, log_returns=False)
379383 :param log_returns: whether to compute using log returns
380384 :type log_returns: bool, defaults to False
381385 """
382- # Optional import
383- try :
384- from sklearn import covariance
386+ if not _check_soft_dependencies (["scikit-learn" ], severity = "none" ):
387+ raise ImportError (
388+ "scikit-learn is required to use CovarianceShrinkage. "
389+ "Please ensure that scikit-learn is installed in your environment,"
390+ " e.g via pip install scikit-learn"
391+ )
392+
393+ from sklearn import covariance
385394
386- self .covariance = covariance
387- except (ModuleNotFoundError , ImportError ): # pragma: no cover
388- raise ImportError ("Please install scikit-learn via pip or poetry" )
395+ self .covariance = covariance
389396
390397 if not isinstance (prices , pd .DataFrame ):
391398 warnings .warn ("data is not in a dataframe" , RuntimeWarning )
0 commit comments