File tree Expand file tree Collapse file tree 3 files changed +17
-20
lines changed Expand file tree Collapse file tree 3 files changed +17
-20
lines changed Original file line number Diff line number Diff line change @@ -112,19 +112,13 @@ void HloProtoMap::AddHloProtosFromXSpace(const XSpace& space) {
112112 }
113113}
114114
115- std::vector<absl::string_view> HloProtoMap::GetModuleList () const {
115+ std::vector<absl::string_view> HloProtoMap::GetModuleList (
116+ bool is_original /* = false*/ ) const {
117+ const auto & hlo_protos_map =
118+ is_original ? original_hlo_protos_by_name_ : hlo_protos_by_name_;
116119 std::vector<absl::string_view> module_list;
117- module_list.reserve (hlo_protos_by_name_.size ());
118- for (const auto & [name, hlo_proto] : hlo_protos_by_name_) {
119- module_list.push_back (name);
120- }
121- return module_list;
122- }
123-
124- std::vector<absl::string_view> HloProtoMap::GetOriginalModuleList () const {
125- std::vector<absl::string_view> module_list;
126- module_list.reserve (original_hlo_protos_by_name_.size ());
127- for (const auto & [name, hlo_proto] : original_hlo_protos_by_name_) {
120+ module_list.reserve (hlo_protos_map.size ());
121+ for (const auto & [name, hlo_proto] : hlo_protos_map) {
128122 module_list.push_back (name);
129123 }
130124 return module_list;
Original file line number Diff line number Diff line change @@ -49,19 +49,22 @@ class HloProtoMap {
4949 auto begin () const { return hlo_protos_by_program_id_.begin (); }
5050 auto end () const { return hlo_protos_by_program_id_.end (); }
5151
52- bool contains (absl::string_view name) const {
52+ bool ContainsOptimizedModule (absl::string_view name) const {
5353 return hlo_protos_by_name_.contains (name);
5454 }
5555
56+ bool ContainsOriginalModule (absl::string_view name) const {
57+ return original_hlo_protos_by_name_.contains (name);
58+ }
59+
5660 bool contains (uint64_t program_id) const {
5761 return hlo_protos_by_program_id_.contains (program_id);
5862 }
5963
60- // Returns a list of module names (not sorted).
61- std::vector<absl::string_view> GetModuleList () const ;
62-
63- // Returns a list of unoptimized/original module names (not sorted).
64- std::vector<absl::string_view> GetOriginalModuleList () const ;
64+ // Returns a list of HLO module names. If `is_original` is true, returns the
65+ // names of the original/unoptimized modules; otherwise, returns the names of
66+ // the optimized modules. The order of the modules is not guaranteed.
67+ std::vector<absl::string_view> GetModuleList (bool is_original = false ) const ;
6568
6669 // Returns a list of module names sorted alphabetically.
6770 std::vector<absl::string_view> GetSortedModuleList () const ;
Original file line number Diff line number Diff line change @@ -19,7 +19,7 @@ using ::testing::status::StatusIs;
1919
2020TEST (HloProtoMapTest, GetOriginalModuleList) {
2121 HloProtoMap hlo_proto_map;
22- EXPECT_THAT (hlo_proto_map.GetOriginalModuleList ( ), IsEmpty ());
22+ EXPECT_THAT (hlo_proto_map.GetModuleList ( true ), IsEmpty ());
2323
2424 auto hlo_proto_1 = std::make_unique<xla::HloProto>();
2525 hlo_proto_1->mutable_hlo_module ()->set_name (" module1" );
@@ -29,7 +29,7 @@ TEST(HloProtoMapTest, GetOriginalModuleList) {
2929 hlo_proto_2->mutable_hlo_module ()->set_name (" module2" );
3030 hlo_proto_map.AddOriginalHloProto (2 , std::move (hlo_proto_2));
3131
32- EXPECT_THAT (hlo_proto_map.GetOriginalModuleList ( ),
32+ EXPECT_THAT (hlo_proto_map.GetModuleList ( true ),
3333 UnorderedElementsAre (" module1(1)" , " module2(2)" ));
3434}
3535
You can’t perform that action at this time.
0 commit comments