Skip to content

Commit e54450d

Browse files
Profiler Teamcopybara-github
authored andcommitted
Refactored HloProtoMap::contains into ContainsOptimizedModule and ContainsOriginalModule to clearly distinguish between optimized and original HLO module lookups.
PiperOrigin-RevId: 825412534
1 parent 067ede9 commit e54450d

File tree

3 files changed

+17
-20
lines changed

3 files changed

+17
-20
lines changed

xprof/utils/hlo_proto_map.cc

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -112,19 +112,13 @@ void HloProtoMap::AddHloProtosFromXSpace(const XSpace& space) {
112112
}
113113
}
114114

115-
std::vector<absl::string_view> HloProtoMap::GetModuleList() const {
115+
std::vector<absl::string_view> HloProtoMap::GetModuleList(
116+
bool is_original /*= false*/) const {
117+
const auto& hlo_protos_map =
118+
is_original ? original_hlo_protos_by_name_ : hlo_protos_by_name_;
116119
std::vector<absl::string_view> module_list;
117-
module_list.reserve(hlo_protos_by_name_.size());
118-
for (const auto& [name, hlo_proto] : hlo_protos_by_name_) {
119-
module_list.push_back(name);
120-
}
121-
return module_list;
122-
}
123-
124-
std::vector<absl::string_view> HloProtoMap::GetOriginalModuleList() const {
125-
std::vector<absl::string_view> module_list;
126-
module_list.reserve(original_hlo_protos_by_name_.size());
127-
for (const auto& [name, hlo_proto] : original_hlo_protos_by_name_) {
120+
module_list.reserve(hlo_protos_map.size());
121+
for (const auto& [name, hlo_proto] : hlo_protos_map) {
128122
module_list.push_back(name);
129123
}
130124
return module_list;

xprof/utils/hlo_proto_map.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,22 @@ class HloProtoMap {
4949
auto begin() const { return hlo_protos_by_program_id_.begin(); }
5050
auto end() const { return hlo_protos_by_program_id_.end(); }
5151

52-
bool contains(absl::string_view name) const {
52+
bool ContainsOptimizedModule(absl::string_view name) const {
5353
return hlo_protos_by_name_.contains(name);
5454
}
5555

56+
bool ContainsOriginalModule(absl::string_view name) const {
57+
return original_hlo_protos_by_name_.contains(name);
58+
}
59+
5660
bool contains(uint64_t program_id) const {
5761
return hlo_protos_by_program_id_.contains(program_id);
5862
}
5963

60-
// Returns a list of module names (not sorted).
61-
std::vector<absl::string_view> GetModuleList() const;
62-
63-
// Returns a list of unoptimized/original module names (not sorted).
64-
std::vector<absl::string_view> GetOriginalModuleList() const;
64+
// Returns a list of HLO module names. If `is_original` is true, returns the
65+
// names of the original/unoptimized modules; otherwise, returns the names of
66+
// the optimized modules. The order of the modules is not guaranteed.
67+
std::vector<absl::string_view> GetModuleList(bool is_original = false) const;
6568

6669
// Returns a list of module names sorted alphabetically.
6770
std::vector<absl::string_view> GetSortedModuleList() const;

xprof/utils/hlo_proto_map_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using ::testing::status::StatusIs;
1919

2020
TEST(HloProtoMapTest, GetOriginalModuleList) {
2121
HloProtoMap hlo_proto_map;
22-
EXPECT_THAT(hlo_proto_map.GetOriginalModuleList(), IsEmpty());
22+
EXPECT_THAT(hlo_proto_map.GetModuleList(true), IsEmpty());
2323

2424
auto hlo_proto_1 = std::make_unique<xla::HloProto>();
2525
hlo_proto_1->mutable_hlo_module()->set_name("module1");
@@ -29,7 +29,7 @@ TEST(HloProtoMapTest, GetOriginalModuleList) {
2929
hlo_proto_2->mutable_hlo_module()->set_name("module2");
3030
hlo_proto_map.AddOriginalHloProto(2, std::move(hlo_proto_2));
3131

32-
EXPECT_THAT(hlo_proto_map.GetOriginalModuleList(),
32+
EXPECT_THAT(hlo_proto_map.GetModuleList(true),
3333
UnorderedElementsAre("module1(1)", "module2(2)"));
3434
}
3535

0 commit comments

Comments
 (0)