@@ -576,7 +576,14 @@ def forward(self,
576
576
) -> ALL_NET_OUTPUT :
577
577
578
578
if isinstance (past_targets , dict ):
579
- past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
579
+ (
580
+ past_targets ,
581
+ past_features ,
582
+ future_features ,
583
+ past_observed_targets ,
584
+ future_targets ,
585
+ decoder_observed_values
586
+ ) = self ._unwrap_past_targets (past_targets )
580
587
581
588
x_past , x_future , x_static , loc , scale , static_context_initial_hidden , _ = self .pre_processing (
582
589
past_targets = past_targets ,
@@ -610,13 +617,12 @@ def forward(self,
610
617
def _unwrap_past_targets (
611
618
self ,
612
619
past_targets : dict
613
- ) -> Tuple [
614
- torch .Tensor ,
615
- Optional [torch .Tensor ],
616
- Optional [torch .Tensor ],
617
- Optional [torch .Tensor ],
618
- Optional [torch .BoolTensor ],
619
- Optional [torch .Tensor ]]:
620
+ ) -> Tuple [torch .Tensor ,
621
+ Optional [torch .Tensor ],
622
+ Optional [torch .Tensor ],
623
+ Optional [torch .Tensor ],
624
+ Optional [torch .BoolTensor ],
625
+ Optional [torch .Tensor ]]:
620
626
"""
621
627
Time series forecasting network requires multiple inputs for the forward pass which is different to how pytorch
622
628
networks usually work. SWA's update_bn in line #452 of trainer choice, does not unwrap the dictionary of the
@@ -637,7 +643,14 @@ def _unwrap_past_targets(
637
643
future_features = past_targets_copy .pop ('future_features' , None )
638
644
past_observed_targets = past_targets_copy .pop ('past_observed_targets' , None )
639
645
decoder_observed_values = past_targets_copy .pop ('decoder_observed_values' , None )
640
- return past_targets ,past_features ,future_features ,past_observed_targets
646
+ return (
647
+ past_targets ,
648
+ past_features ,
649
+ future_features ,
650
+ past_observed_targets ,
651
+ future_targets ,
652
+ decoder_observed_values
653
+ )
641
654
642
655
def pred_from_net_output (self , net_output : ALL_NET_OUTPUT ) -> torch .Tensor :
643
656
if self .output_type == 'regression' :
@@ -730,9 +743,16 @@ def forward(self,
730
743
future_features : Optional [torch .Tensor ] = None ,
731
744
past_observed_targets : Optional [torch .BoolTensor ] = None ,
732
745
decoder_observed_values : Optional [torch .Tensor ] = None , ) -> ALL_NET_OUTPUT :
733
-
746
+
734
747
if isinstance (past_targets , dict ):
735
- past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
748
+ (
749
+ past_targets ,
750
+ past_features ,
751
+ future_features ,
752
+ past_observed_targets ,
753
+ future_targets ,
754
+ decoder_observed_values
755
+ ) = self ._unwrap_past_targets (past_targets )
736
756
737
757
x_past , _ , x_static , loc , scale , static_context_initial_hidden , past_targets = self .pre_processing (
738
758
past_targets = past_targets ,
@@ -1025,7 +1045,14 @@ def forward(self,
1025
1045
decoder_observed_values : Optional [torch .Tensor ] = None , ) -> ALL_NET_OUTPUT :
1026
1046
1027
1047
if isinstance (past_targets , dict ):
1028
- past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
1048
+ (
1049
+ past_targets ,
1050
+ past_features ,
1051
+ future_features ,
1052
+ past_observed_targets ,
1053
+ future_targets ,
1054
+ decoder_observed_values
1055
+ ) = self ._unwrap_past_targets (past_targets )
1029
1056
1030
1057
encode_length = min (self .window_size , past_targets .shape [1 ])
1031
1058
@@ -1295,7 +1322,14 @@ def forward(self, # type: ignore[override]
1295
1322
Tuple [torch .Tensor , torch .Tensor ]]:
1296
1323
1297
1324
if isinstance (past_targets , dict ):
1298
- past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
1325
+ (
1326
+ past_targets ,
1327
+ past_features ,
1328
+ future_features ,
1329
+ past_observed_targets ,
1330
+ future_targets ,
1331
+ decoder_observed_values
1332
+ ) = self ._unwrap_past_targets (past_targets )
1299
1333
1300
1334
# Unlike other networks, NBEATS network is required to predict both past and future targets.
1301
1335
# Thereby, we return two tensors for backcast and forecast
0 commit comments