17
17
use std:: iter:: zip;
18
18
19
19
use crate :: {
20
- alloc:: { get2_mut , FlatAlloc } ,
20
+ alloc:: FlatAlloc ,
21
21
prelude:: { LatencyCountInferenceVarID , LatencyCountInferenceVarIDMarker } ,
22
22
} ;
23
23
@@ -705,61 +705,60 @@ fn solve_port_latencies(
705
705
) -> Result < Vec < Vec < SpecifiedLatency > > , LatencyCountingError > {
706
706
let mut bad_ports: Vec < ( usize , i64 , i64 ) > = Vec :: new ( ) ;
707
707
708
- let mut port_groups: Vec < Vec < SpecifiedLatency > > = ports
709
- . inputs ( )
710
- . iter ( )
711
- . map ( |input_port| {
712
- let mut result: Vec < SpecifiedLatency > = Vec :: new ( ) ;
713
- let start_node = SpecifiedLatency {
714
- node : * input_port,
715
- latency : 0 ,
716
- } ;
717
- result. push ( start_node) ;
708
+ let port_groups_iter = ports. inputs ( ) . iter ( ) . map ( |input_port| {
709
+ let mut result: Vec < SpecifiedLatency > = Vec :: new ( ) ;
710
+ let start_node = SpecifiedLatency {
711
+ node : * input_port,
712
+ latency : 0 ,
713
+ } ;
714
+ result. push ( start_node) ;
718
715
719
- let working_latencies =
720
- LatencyNode :: make_solution_and_count_latencies ( fanouts, & [ start_node] ) ;
716
+ let working_latencies =
717
+ LatencyNode :: make_solution_and_count_latencies ( fanouts, & [ start_node] ) ;
721
718
722
- for output in ports. outputs ( ) {
723
- if let Some ( latency) = working_latencies[ * output] . get_maybe ( ) {
724
- result. push ( SpecifiedLatency {
725
- node : * output,
726
- latency,
727
- } ) ;
728
- }
719
+ for output in ports. outputs ( ) {
720
+ if let Some ( latency) = working_latencies[ * output] . get_maybe ( ) {
721
+ result. push ( SpecifiedLatency {
722
+ node : * output,
723
+ latency,
724
+ } ) ;
729
725
}
726
+ }
730
727
731
- debug_assert ! ( !SpecifiedLatency :: has_duplicates( & result) ) ;
728
+ debug_assert ! ( !SpecifiedLatency :: has_duplicates( & result) ) ;
732
729
733
- result
734
- } )
735
- . collect ( ) ;
730
+ result
731
+ } ) ;
736
732
737
- merge_where_possible ( & mut port_groups, |merge_to, merge_from| {
738
- debug_assert ! ( !SpecifiedLatency :: has_duplicates( merge_to) ) ;
739
- debug_assert ! ( !SpecifiedLatency :: has_duplicates( merge_from) ) ;
733
+ let mut port_groups =
734
+ merge_iter_into_disjoint_groups ( port_groups_iter, |merge_to, merge_from| {
735
+ debug_assert ! ( !SpecifiedLatency :: has_duplicates( merge_to) ) ;
736
+ debug_assert ! ( !SpecifiedLatency :: has_duplicates( merge_from) ) ;
740
737
741
- let Some ( offset) = merge_to. iter ( ) . find_map ( |to| {
742
- SpecifiedLatency :: get_latency ( merge_from, to. node )
743
- . map ( |from_latency| to. latency - from_latency)
744
- } ) else {
745
- return false ;
746
- } ;
738
+ let Some ( offset) = merge_to. iter ( ) . find_map ( |to| {
739
+ SpecifiedLatency :: get_latency ( merge_from, to. node )
740
+ . map ( |from_latency| to. latency - from_latency)
741
+ } ) else {
742
+ return false ;
743
+ } ;
747
744
748
- for from_node in merge_from {
749
- from_node. latency += offset;
745
+ for from_node in merge_from {
746
+ from_node. latency += offset;
750
747
751
- if let Some ( to_node_latency) = SpecifiedLatency :: get_latency ( merge_to, from_node. node ) {
752
- if to_node_latency != from_node. latency {
753
- bad_ports. push ( ( from_node. node , to_node_latency, from_node. latency ) ) ;
748
+ if let Some ( to_node_latency) =
749
+ SpecifiedLatency :: get_latency ( merge_to, from_node. node )
750
+ {
751
+ if to_node_latency != from_node. latency {
752
+ bad_ports. push ( ( from_node. node , to_node_latency, from_node. latency ) ) ;
753
+ }
754
+ } else {
755
+ merge_to. push ( * from_node) ;
754
756
}
755
- } else {
756
- merge_to. push ( * from_node) ;
757
757
}
758
- }
759
758
760
- debug_assert ! ( !SpecifiedLatency :: has_duplicates( merge_to) ) ;
761
- true
762
- } ) ;
759
+ debug_assert ! ( !SpecifiedLatency :: has_duplicates( merge_to) ) ;
760
+ true
761
+ } ) ;
763
762
764
763
for output_port in ports. outputs ( ) {
765
764
if !port_groups
@@ -818,13 +817,53 @@ pub fn solve_latencies(
818
817
} ] ) ;
819
818
}
820
819
821
- let mut partial_solutions: Vec < PartialLatencyCountingSolution > = partial_solutions
822
- . iter ( )
823
- . map ( |seeds| PartialLatencyCountingSolution {
824
- latencies : LatencyNode :: make_solution_forwards_then_backwards ( & fanins, & fanouts, seeds) ,
825
- conflicting_nodes : Vec :: new ( ) ,
826
- } )
827
- . collect ( ) ;
820
+ let partial_solutions_iter =
821
+ partial_solutions
822
+ . iter ( )
823
+ . map ( |seeds| PartialLatencyCountingSolution {
824
+ latencies : LatencyNode :: make_solution_forwards_then_backwards (
825
+ & fanins, & fanouts, seeds,
826
+ ) ,
827
+ conflicting_nodes : Vec :: new ( ) ,
828
+ } ) ;
829
+
830
+ let mut partial_solutions =
831
+ merge_iter_into_disjoint_groups ( partial_solutions_iter, |merge_to, merge_from| {
832
+ // Find a node both share
833
+ let Some ( joining_node) = merge_to
834
+ . latencies
835
+ . iter ( )
836
+ . zip ( merge_from. latencies . iter ( ) )
837
+ . position ( |( a, b) | a. get_maybe ( ) . is_some ( ) && b. get_maybe ( ) . is_some ( ) )
838
+ else {
839
+ return false ;
840
+ } ;
841
+
842
+ // Offset the vector we're merging to bring it in line with the target
843
+ merge_from. offset_to_pin_node_to ( SpecifiedLatency {
844
+ node : joining_node,
845
+ latency : merge_to. latencies [ joining_node] . get_maybe ( ) . unwrap ( ) ,
846
+ } ) ;
847
+ merge_to
848
+ . conflicting_nodes
849
+ . append ( & mut merge_from. conflicting_nodes ) ;
850
+
851
+ for ( node, ( to, from) ) in
852
+ zip ( merge_to. latencies . iter_mut ( ) , merge_from. latencies . iter ( ) ) . enumerate ( )
853
+ {
854
+ match ( to. get_maybe ( ) , from. get_maybe ( ) ) {
855
+ ( _, None ) => { } // Do nothing
856
+ ( None , Some ( from) ) => to. abs_lat = from,
857
+ ( Some ( to) , Some ( from) ) => {
858
+ if to != from {
859
+ merge_to. conflicting_nodes . push ( ( node, from) ) ;
860
+ }
861
+ }
862
+ }
863
+ }
864
+
865
+ true
866
+ } ) ;
828
867
829
868
// Polish solution: if there were no specified latencies, then we make the latency of the first port '0
830
869
// This is to shift the whole solution to one canonical absolute latency. Prefer:
@@ -849,43 +888,6 @@ pub fn solve_latencies(
849
888
partial_solutions[ 0 ] . offset_to_pin_node_to ( reference_node) ;
850
889
}
851
890
852
- merge_where_possible ( & mut partial_solutions, |merge_to, merge_from| {
853
- // Find a node both share
854
- let Some ( joining_node) = merge_to
855
- . latencies
856
- . iter ( )
857
- . zip ( merge_from. latencies . iter ( ) )
858
- . position ( |( a, b) | a. get_maybe ( ) . is_some ( ) && b. get_maybe ( ) . is_some ( ) )
859
- else {
860
- return false ;
861
- } ;
862
-
863
- // Offset the vector we're merging to bring it in line with the target
864
- merge_from. offset_to_pin_node_to ( SpecifiedLatency {
865
- node : joining_node,
866
- latency : merge_to. latencies [ joining_node] . get_maybe ( ) . unwrap ( ) ,
867
- } ) ;
868
- merge_to
869
- . conflicting_nodes
870
- . append ( & mut merge_from. conflicting_nodes ) ;
871
-
872
- for ( node, ( to, from) ) in
873
- zip ( merge_to. latencies . iter_mut ( ) , merge_from. latencies . iter ( ) ) . enumerate ( )
874
- {
875
- match ( to. get_maybe ( ) , from. get_maybe ( ) ) {
876
- ( _, None ) => { } // Do nothing
877
- ( None , Some ( from) ) => to. abs_lat = from,
878
- ( Some ( to) , Some ( from) ) => {
879
- if to != from {
880
- merge_to. conflicting_nodes . push ( ( node, from) ) ;
881
- }
882
- }
883
- }
884
- }
885
-
886
- true
887
- } ) ;
888
-
889
891
let mut solution_iter = partial_solutions. into_iter ( ) ;
890
892
891
893
let first_solution = solution_iter. next ( ) . unwrap ( ) ;
@@ -914,21 +916,19 @@ pub fn solve_latencies(
914
916
}
915
917
}
916
918
917
- /// merge should return true if the second argument was merged into the first argument.
918
- fn merge_where_possible < T > ( parts : & mut Vec < T > , mut merge : impl FnMut ( & mut T , & mut T ) -> bool ) {
919
- let mut merge_to_idx = 0 ;
920
- while merge_to_idx < parts. len ( ) {
921
- let mut merge_from_idx = merge_to_idx + 1 ;
922
- while merge_from_idx < parts. len ( ) {
923
- let ( merge_to, merge_from) = get2_mut ( parts, merge_to_idx, merge_from_idx) . unwrap ( ) ;
924
- if merge ( merge_to, merge_from) {
925
- parts. swap_remove ( merge_from_idx) ;
926
- } else {
927
- merge_from_idx += 1 ;
928
- }
929
- }
930
- merge_to_idx += 1 ;
919
+ /// [try_merge] should return true if the second argument was merged into the first argument.
920
+ fn merge_iter_into_disjoint_groups < T > (
921
+ iter : impl Iterator < Item = T > ,
922
+ mut try_merge : impl FnMut ( & mut T , & mut T ) -> bool ,
923
+ ) -> Vec < T > {
924
+ let mut result = Vec :: new ( ) ;
925
+
926
+ for mut new_node in iter {
927
+ result. retain_mut ( |existing_elem| !try_merge ( & mut new_node, existing_elem) ) ;
928
+ result. push ( new_node) ;
931
929
}
930
+
931
+ result
932
932
}
933
933
934
934
/// A candidate for latency inference. Passed to [try_infer_value_for] as a list of possibilities.
0 commit comments