@@ -574,6 +574,10 @@ def forward(self,
                past_observed_targets: Optional[torch.BoolTensor] = None,
                decoder_observed_values: Optional[torch.Tensor] = None,
                ) -> ALL_NET_OUTPUT:
+
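+        # SWA's update_bn does not unwrap the input dictionary before calling
+        # forward, so the whole batch can arrive as past_targets; unwrap it here.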
+        if isinstance(past_targets, dict):
+            past_targets, past_features, future_features, past_observed_targets = self._unwrap_past_targets(past_targets)
+
        x_past, x_future, x_static, loc, scale, static_context_initial_hidden, _ = self.pre_processing(
            past_targets=past_targets,
            past_observed_targets=past_observed_targets,
@@ -603,6 +607,38 @@ def forward(self,

        return self.rescale_output(output, loc, scale, self.device)

+    def _unwrap_past_targets(
+        self,
+        past_targets: dict
+    ) -> Tuple[torch.Tensor,
+               Optional[torch.Tensor],
+               Optional[torch.Tensor],
+               Optional[torch.BoolTensor]]:
+        """
+        Unlike typical PyTorch networks, time series forecasting networks require multiple
+        inputs for their forward pass. SWA's update_bn (line #452 of the trainer choice)
+        does not unwrap the input dictionary when running the forward pass, so we need to
+        check for that here.
+
+        Args:
+            past_targets (dict):
+                Input batch mistakenly passed to the past_targets argument.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
+            Optional[torch.BoolTensor]]:
+                The unwrapped past_targets, past_features, future_features and
+                past_observed_targets.
+        """
+
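+        # All known keys are popped; future_targets and decoder_observed_values
+        # are removed from the dict but not returned, since the callers here do
+        # not consume them.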
+        past_targets_copy = past_targets.copy()
+        past_targets = past_targets_copy.pop('past_targets')
+        future_targets = past_targets_copy.pop('future_targets', None)
+        past_features = past_targets_copy.pop('past_features', None)
+        future_features = past_targets_copy.pop('future_features', None)
+        past_observed_targets = past_targets_copy.pop('past_observed_targets', None)
+        decoder_observed_values = past_targets_copy.pop('decoder_observed_values', None)
+        return past_targets, past_features, future_features, past_observed_targets
+
    def pred_from_net_output(self, net_output: ALL_NET_OUTPUT) -> torch.Tensor:
        if self.output_type == 'regression':
            return net_output
@@ -694,6 +730,10 @@ def forward(self,
                future_features: Optional[torch.Tensor] = None,
                past_observed_targets: Optional[torch.BoolTensor] = None,
                decoder_observed_values: Optional[torch.Tensor] = None, ) -> ALL_NET_OUTPUT:
+
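+        # Unwrap dict input passed by SWA's update_bn (see _unwrap_past_targets).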
+        if isinstance(past_targets, dict):
+            past_targets, past_features, future_features, past_observed_targets = self._unwrap_past_targets(past_targets)
+
        x_past, _, x_static, loc, scale, static_context_initial_hidden, past_targets = self.pre_processing(
            past_targets=past_targets,
            past_observed_targets=past_observed_targets,
@@ -983,6 +1023,10 @@ def forward(self,
                future_features: Optional[torch.Tensor] = None,
                past_observed_targets: Optional[torch.BoolTensor] = None,
                decoder_observed_values: Optional[torch.Tensor] = None, ) -> ALL_NET_OUTPUT:
+
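+        # Unwrap dict input passed by SWA's update_bn (see _unwrap_past_targets).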
+        if isinstance(past_targets, dict):
+            past_targets, past_features, future_features, past_observed_targets = self._unwrap_past_targets(past_targets)
+
        encode_length = min(self.window_size, past_targets.shape[1])

        if past_observed_targets is None:
@@ -1250,6 +1294,9 @@ def forward(self, # type: ignore[override]
                decoder_observed_values: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor,
                                                                                   Tuple[torch.Tensor, torch.Tensor]]:

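+        # Unwrap dict input passed by SWA's update_bn (see _unwrap_past_targets).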
+        if isinstance(past_targets, dict):
+            past_targets, past_features, future_features, past_observed_targets = self._unwrap_past_targets(past_targets)
+
        # Unlike other networks, the NBEATS network is required to predict both past and future targets.
        # Therefore, we return two tensors for backcast and forecast
        if past_observed_targets is None: