
Commit 7ccb806

Profiler Team authored and copybara-github committed

Replaces the use of RemapInstructionIds with HloModule::CreateFromProto (without preserving unique IDs) in xprof.

PiperOrigin-RevId: 824583193

1 parent aa8ee31 · commit 7ccb806
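For orientation, the whole change reduces to one swap inside ConvertHloProtoToModule, shown authoritatively in the xprof/utils/hlo_proto_to_module.cc diff below. The sketch here only restates it with comments; the function name ConvertHloProtoToModuleSketch and the include paths are assumptions made for illustration, not part of this commit.

// Minimal sketch of the conversion path touched by this commit. Include
// paths are assumed from the openxla/xla source layout; the two
// TF_ASSIGN_OR_RETURN statements mirror the real file.
#include <memory>

#include "absl/status/statusor.h"
#include "tsl/platform/statusor.h"  // TF_ASSIGN_OR_RETURN (assumed path)
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo.pb.h"
#include "xla/xla.pb.h"  // xla::DebugOptions

absl::StatusOr<std::unique_ptr<xla::HloModule>> ConvertHloProtoToModuleSketch(
    const xla::HloProto& hlo_proto) {
  const xla::HloModuleProto& module_proto = hlo_proto.hlo_module();
  TF_ASSIGN_OR_RETURN(auto config, xla::HloModule::CreateModuleConfigFromProto(
                                       module_proto, xla::DebugOptions()));
  // Before this commit the proto was first rewritten by
  // xla::HloModule::RemapInstructionIds and the remapped proto was handed to
  // CreateFromProto. Now CreateFromProto is called directly and asked not to
  // preserve the original unique instruction ids, so ids are reassigned
  // consecutively while the module is built.
  TF_ASSIGN_OR_RETURN(auto module, xla::HloModule::CreateFromProto(
                                       module_proto, config,
                                       /*buffer_assignment_proto=*/nullptr,
                                       /*preserve_instruction_ids=*/false));
  return module;
}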

File tree: 3 files changed, +128 -29 lines


WORKSPACE

Lines changed: 28 additions & 25 deletions
@@ -30,14 +30,39 @@ http_archive(
     url = "https://github.com/bazelbuild/rules_java/releases/download/8.7.0/rules_java-8.7.0.tar.gz",
 )
 
+# Toolchains for ML projects
+# Details: https://github.com/google-ml-infra/rules_ml_toolchain
+http_archive(
+    name = "rules_ml_toolchain",
+    sha256 = "2a5591ec7543c8b37aead3cb681eb2b93c9616ce94abdf3aedcf391b372d4007",
+    strip_prefix = "rules_ml_toolchain-b2b08356ac30353c49587b0e8dfe65aabb35e78d",
+    urls = [
+        "https://github.com/google-ml-infra/rules_ml_toolchain/archive/b2b08356ac30353c49587b0e8dfe65aabb35e78d.tar.gz",
+    ],
+)
+
+load(
+    "@rules_ml_toolchain//cc/deps:cc_toolchain_deps.bzl",
+    "cc_toolchain_deps",
+)
+
+cc_toolchain_deps()
+
+register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64")
+
+register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64_cuda")
+
+load("@rules_ml_toolchain//gpu/sycl:sycl_configure.bzl", "sycl_configure")
+load("@rules_ml_toolchain//gpu/sycl:sycl_init_repository.bzl", "sycl_init_repository")
+
 http_archive(
     name = "xla",
     patch_args = ["-p1"],
     patches = ["//third_party:xla.patch"],
-    sha256 = "a106290c8a1f522d57feed0be31496c571c2a50545cc92a1cdb32aef2309270b",
-    strip_prefix = "xla-845061f0e1162559fabf5dc6555b85a31bd96cb9",
+    sha256 = "82160211319100b8c1d55e016c426b2999ccb9c3091f699ac55a2d536c784630",
+    strip_prefix = "xla-68b8314049f2a7256aea628a7e3377a00278345a",
     urls = [
-        "https://github.com/openxla/xla/archive/845061f0e1162559fabf5dc6555b85a31bd96cb9.zip",
+        "https://github.com/openxla/xla/archive/68b8314049f2a7256aea628a7e3377a00278345a.zip",
     ],
 )
 
@@ -105,28 +130,6 @@ load("@xla//:workspace0.bzl", "xla_workspace0")
 
 xla_workspace0()
 
-# Toolchains for ML projects
-# Details: https://github.com/google-ml-infra/rules_ml_toolchain
-http_archive(
-    name = "rules_ml_toolchain",
-    sha256 = "d1a64a54b1688446619364dac25ff5bcef65c6ffb6984f82128986f5f66129f6",
-    strip_prefix = "rules_ml_toolchain-b42dc53b80d7f4da1e12abca7503a264e96de98e",
-    urls = [
-        "https://github.com/google-ml-infra/rules_ml_toolchain/archive/b42dc53b80d7f4da1e12abca7503a264e96de98e.tar.gz",
-    ],
-)
-
-load(
-    "@rules_ml_toolchain//cc/deps:cc_toolchain_deps.bzl",
-    "cc_toolchain_deps",
-)
-
-cc_toolchain_deps()
-
-register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64")
-
-register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64_cuda")
-
 load(
     "@xla//third_party/py:python_wheel.bzl",
     "python_wheel_version_suffix_repository",

xprof/utils/hlo_proto_to_module.cc

Lines changed: 3 additions & 3 deletions
@@ -36,10 +36,10 @@ absl::StatusOr<std::unique_ptr<xla::HloModule>> ConvertHloProtoToModule(
   const xla::HloModuleProto& module_proto = hlo_proto.hlo_module();
   TF_ASSIGN_OR_RETURN(auto config, xla::HloModule::CreateModuleConfigFromProto(
                                        module_proto, xla::DebugOptions()));
-  TF_ASSIGN_OR_RETURN(xla::HloModuleProto remapped_module_proto,
-                      xla::HloModule::RemapInstructionIds(module_proto));
   TF_ASSIGN_OR_RETURN(auto module, xla::HloModule::CreateFromProto(
-                                       remapped_module_proto, config));
+                                       module_proto, config,
+                                       /*buffer_assignment_proto=*/nullptr,
+                                       /*preserve_instruction_ids=*/false));
   return module;
 }

xprof/utils/hlo_proto_to_module_test.cc

Lines changed: 97 additions & 1 deletion
@@ -51,7 +51,7 @@ TEST(HloProtoToModuleTest, FixNonConsecutiveInstructionIds) {
                 }
               }
               id: 4294967303
-              operand_ids: 6
+              operand_ids: 1
             }
             id: 1
             root_id: 4294967303
@@ -88,6 +88,102 @@ TEST(HloProtoToModuleTest, FixNonConsecutiveInstructionIds) {
               ElementsAre(Property(&xla::HloInstruction::local_id, 0),
                           Property(&xla::HloInstruction::local_id, 1),
                           Property(&xla::HloInstruction::local_id, 2)));
+  // Check correct operand translation
+  EXPECT_EQ(module->entry_computation()->parameter_instruction(0)->name(),
+            "arg0.1");
+  EXPECT_EQ(module->entry_computation()->parameter_instruction(0)->local_id(),
+            0);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction()->operands(),
+      ElementsAre(module->entry_computation()->parameter_instruction(0)));
+}
+
+TEST(HloProtoToModuleTest, FixNonConsecutiveInstructionIdsForModule) {
+  xla::HloProto hlo_proto;
+  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
+      R"pb(
+        hlo_module {
+          name: "some_module"
+          entry_computation_name: "some_module"
+          computations {
+            name: "some_module"
+            instructions {
+              name: "arg0.1"
+              opcode: "parameter"
+              shape {
+                element_type: S32
+                layout { tail_padding_alignment_in_elements: 1 }
+              }
+              id: 4294967297
+            }
+            instructions {
+              name: "arg1.1"
+              opcode: "parameter"
+              shape {
+                element_type: S32
+                layout { tail_padding_alignment_in_elements: 1 }
+              }
+              parameter_number: 1
+              id: 4294967298
+            }
+            instructions {
+              name: "XLA_Retvals.1"
+              opcode: "tuple"
+              shape {
+                element_type: TUPLE
+                tuple_shapes {
+                  element_type: S32
+                  layout { tail_padding_alignment_in_elements: 1 }
+                }
+              }
+              id: 4294967303
+              operand_ids: 1
+            }
+            id: 1
+            root_id: 4294967303
+          }
+          host_program_shape {
+            parameters {
+              element_type: S32
+              layout { tail_padding_alignment_in_elements: 1 }
+            }
+            parameters {
+              element_type: S32
+              layout { tail_padding_alignment_in_elements: 1 }
+            }
+            result {
+              element_type: TUPLE
+              tuple_shapes {
+                element_type: S32
+                layout { tail_padding_alignment_in_elements: 1 }
+              }
+            }
+            parameter_names: "arg0"
+            parameter_names: "arg1"
+          }
+          id: 1
+          entry_computation_id: 1
+        }
+      )pb",
+      &hlo_proto));
+
+
+  ASSERT_OK_AND_ASSIGN(auto module,
+                       ConvertHloProtoToModule(hlo_proto));
+  EXPECT_EQ(module->entry_computation()->instruction_count(), 3);
+  // Check that ids are consecutive
+  EXPECT_THAT(module->entry_computation()->instructions(),
+              ElementsAre(Property(&xla::HloInstruction::local_id, 0),
+                          Property(&xla::HloInstruction::local_id, 1),
+                          Property(&xla::HloInstruction::local_id, 2)));
+  // Check correct operand translation
+  EXPECT_EQ(module->entry_computation()->parameter_instruction(0)->name(),
+            "arg0.1");
+  EXPECT_EQ(module->entry_computation()->parameter_instruction(0)->local_id(),
+            0);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction()->operands(),
+      ElementsAre(module->entry_computation()->parameter_instruction(0)));
 }
 
 } // namespace
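For readers coming to this helper cold, a hypothetical caller follows. The header path xprof/utils/hlo_proto_to_module.h, the unqualified call (the declaring namespace is not shown in this commit), and the file-reading scaffolding are assumptions for illustration; ConvertHloProtoToModule itself and its HloProto-in / HloModule-out contract come from the diffs above.

// Hypothetical caller of ConvertHloProtoToModule. Adjust the include path and
// add the namespace qualifier used by the real declaration.
#include <fstream>
#include <iostream>
#include <memory>

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo.pb.h"
#include "xprof/utils/hlo_proto_to_module.h"  // assumed header path

int main(int argc, char** argv) {
  if (argc != 2) {
    std::cerr << "usage: " << argv[0] << " <serialized_hlo_proto>\n";
    return 1;
  }
  // Read a serialized xla::HloProto, e.g. one captured by the profiler.
  std::ifstream input(argv[1], std::ios::binary);
  xla::HloProto hlo_proto;
  if (!hlo_proto.ParseFromIstream(&input)) {
    std::cerr << "failed to parse HloProto\n";
    return 1;
  }
  // After this commit the module is rebuilt with
  // preserve_instruction_ids=false, so instruction local ids come out
  // consecutive (0, 1, 2, ...) even when the proto carries sparse 64-bit ids,
  // as the tests above demonstrate.
  absl::StatusOr<std::unique_ptr<xla::HloModule>> module =
      ConvertHloProtoToModule(hlo_proto);
  if (!module.ok()) {
    std::cerr << module.status().ToString() << "\n";
    return 1;
  }
  std::cout << (*module)->ToString() << "\n";
  return 0;
}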
