2424from autoPyTorch .api .tabular_classification import TabularClassificationTask
2525
2626
27- if __name__ == '__main__' :
27+ ############################################################################
28+ # Data Loading
29+ # ============
30+ X , y = sklearn .datasets .fetch_openml (data_id = 40981 , return_X_y = True , as_frame = True )
31+ X_train , X_test , y_train , y_test = sklearn .model_selection .train_test_split (
32+ X ,
33+ y ,
34+ random_state = 42 ,
35+ )
2836
29- ############################################################################
30- # Data Loading
31- # ============
32- X , y = sklearn .datasets .fetch_openml (data_id = 40981 , return_X_y = True , as_frame = True )
33- X_train , X_test , y_train , y_test = sklearn .model_selection .train_test_split (
34- X ,
35- y ,
36- random_state = 42 ,
37- )
37+ ############################################################################
38+ # Build and fit a classifier
39+ # ==========================
40+ api = TabularClassificationTask (
41+ seed = 42 ,
42+ )
3843
39- ############################################################################
40- # Build and fit a classifier
41- # ==========================
42- api = TabularClassificationTask (
43- ensemble_size = 0 ,
44- seed = 42 ,
45- )
44+ ############################################################################
45+ # Search for the best neural network
46+ # ==================================
47+ api .search (
48+ X_train = X_train ,
49+ y_train = y_train ,
50+ X_test = X_test .copy (),
51+ y_test = y_test .copy (),
52+ optimize_metric = 'accuracy' ,
53+ total_walltime_limit = 100 ,
54+ func_eval_time_limit_secs = 50 ,
55+ ensemble_size = 0 ,
56+ )
4657
47- ############################################################################
48- # Search for the best neural network
49- # ==================================
50- api .search (
51- X_train = X_train ,
52- y_train = y_train ,
53- X_test = X_test .copy (),
54- y_test = y_test .copy (),
55- optimize_metric = 'accuracy' ,
56- total_walltime_limit = 100 ,
57- func_eval_time_limit_secs = 50
58- )
58+ ############################################################################
59+ # Print the final performance of the incumbent neural network
60+ # ===========================================================
61+ print (api .run_history , api .trajectory )
62+ y_pred = api .predict (X_test )
63+ score = api .score (y_pred , y_test )
64+ print (score )
5965
60- ############################################################################
61- # Print the final performance of the incumbent neural network
62- # ===========================================================
63- print (api .run_history , api .trajectory )
64- y_pred = api .predict (X_test )
65- score = api .score (y_pred , y_test )
66- print (score )
66+ ############################################################################
67+ # Fit an ensemble with the neural networks fitted during the search
68+ # =================================================================
6769
68- ############################################################################
69- # Fit an ensemble with the neural networks fitted during the search
70- # =================================================================
70+ api .fit_ensemble (ensemble_size = 5 ,
71+ # Set the enable_traditional_pipeline=True
72+ # to also include traditional models
73+ # in the ensemble
74+ enable_traditional_pipeline = False )
75+ # Print the final ensemble built by AutoPyTorch
76+ y_pred = api .predict (X_test )
77+ score = api .score (y_pred , y_test )
78+ print (score )
79+ print (api .show_models ())
7180
72- api .fit_ensemble (ensemble_size = 5 ,
73- # Set the enable_traditional_pipeline=True
74- # to also include traditional models
75- # in the ensemble
76- enable_traditional_pipeline = False )
77- # Print the final ensemble built by AutoPyTorch
78- y_pred = api .predict (X_test )
79- score = api .score (y_pred , y_test )
80- print (score )
81- print (api .show_models ())
82- api ._cleanup ()
81+ # Print statistics from search
82+ print (api .sprint_statistics ())
0 commit comments