Skip to content

Commit 33a10e4

Browse files
muditgokhale2 authored and copybara-github committed
Add all_hosts information to the session_snapshot and move the device collision logic for trace_viewer to CreateTraceEventsContainer.
PiperOrigin-RevId: 822654197
1 parent b2d480d commit 33a10e4

18 files changed: +162 -97 lines changed

plugin/xprof/convert/raw_to_tool_data.py

Lines changed: 43 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -41,12 +41,13 @@ def process_raw_trace(raw_trace):
4141
return ''.join(trace_events_json.TraceEventsJsonStream(trace))
4242

4343

44-
def xspace_to_tools_data_from_byte_string(xspace_byte_list, filenames, tool,
45-
params):
44+
def xspace_to_tools_data_from_byte_string(xspace_byte_list, all_hosts,
45+
filenames, tool, params):
4646
"""Helper function for getting an XSpace tool from a bytes string.
4747
4848
Args:
4949
xspace_byte_list: A list of byte strings read from a XSpace proto file.
50+
all_hosts: A list of all hosts in the session.
5051
filenames: Names of the read files.
5152
tool: A string of tool name.
5253
params: user input parameters.
@@ -57,7 +58,7 @@ def xspace_to_tools_data_from_byte_string(xspace_byte_list, filenames, tool,
5758
# pylint:disable=dangerous-default-value
5859
def xspace_wrapper_func(xspace_arg, tool_arg, params={}):
5960
return _pywrap_profiler_plugin.xspace_to_tools_data_from_byte_string(
60-
xspace_arg, filenames, tool_arg, params)
61+
xspace_arg, all_hosts, filenames, tool_arg, params)
6162
# pylint:enable=dangerous-default-value
6263

6364
return xspace_to_tool_data(xspace_byte_list, tool, params,
@@ -73,22 +74,26 @@ def xspace_to_tool_names(xspace_paths):
7374
Returns:
7475
Returns a list of tool names.
7576
"""
77+
# xspace_to_tools_data expects all_hosts as the second argument, passing an
78+
# empty list.
7679
raw_data, success = _pywrap_profiler_plugin.xspace_to_tools_data(
77-
xspace_paths, 'tool_names')
80+
xspace_paths, [], 'tool_names', {})
7881
if success:
7982
return [tool for tool in raw_data.decode().split(',')]
8083
return []
8184

8285

8386
def xspace_to_tool_data(
8487
xspace_paths,
88+
all_hosts,
8589
tool,
8690
params,
8791
xspace_wrapper_func=_pywrap_profiler_plugin.xspace_to_tools_data):
8892
"""Converts XSpace to tool data string.
8993
9094
Args:
9195
xspace_paths: A list of XSpace paths.
96+
all_hosts: A list of all hosts in the session.
9297
tool: A string of tool name.
9398
params: user input parameters.
9499
xspace_wrapper_func: A callable that takes a list of strings and a tool and
@@ -112,26 +117,31 @@ def xspace_to_tool_data(
112117
if tool == 'trace_viewer':
113118
# Trace viewer handles one host at a time.
114119
assert len(xspace_paths) == 1
115-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
120+
raw_data, success = xspace_wrapper_func(
121+
xspace_paths, all_hosts, tool, options)
116122
if success:
117123
data = process_raw_trace(raw_data)
118124
elif tool == 'trace_viewer@':
119125
options = params.get('trace_viewer_options', {})
120126
options['use_saved_result'] = params.get('use_saved_result', True)
121-
options['hosts'] = params.get('hosts', [])
122-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
127+
options['hosts'] = all_hosts
128+
raw_data, success = xspace_wrapper_func(
129+
xspace_paths, all_hosts, tool, options)
123130
if success:
124131
data = raw_data
125132
elif tool == 'overview_page':
126-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
133+
json_data, success = xspace_wrapper_func(
134+
xspace_paths, all_hosts, tool, options)
127135
if success:
128136
data = json_data
129137
elif tool == 'input_pipeline_analyzer':
130-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
138+
json_data, success = xspace_wrapper_func(
139+
xspace_paths, all_hosts, tool, options)
131140
if success:
132141
data = json_data
133142
elif tool == 'framework_op_stats':
134-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
143+
json_data, success = xspace_wrapper_func(
144+
xspace_paths, all_hosts, tool, options)
135145
if success:
136146
if tqx == 'out:csv':
137147
data = csv_writer.json_to_csv(json_data)
@@ -142,15 +152,16 @@ def xspace_to_tool_data(
142152
# TODO(b/419013992): Remove this tool completely as it has been deprecated
143153
legacy_tool = 'tensorflow_stats'
144154
json_data, success = xspace_wrapper_func(
145-
xspace_paths, legacy_tool, options
155+
xspace_paths, all_hosts, legacy_tool, options
146156
)
147157
if success:
148158
if tqx == 'out:csv':
149159
data = csv_writer.json_to_csv(json_data)
150160
else:
151161
data = json_data
152162
elif tool == 'kernel_stats':
153-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
163+
json_data, success = xspace_wrapper_func(
164+
xspace_paths, all_hosts, tool, options)
154165
if success:
155166
if tqx == 'out:csv':
156167
data = csv_writer.json_to_csv(json_data)
@@ -159,37 +170,44 @@ def xspace_to_tool_data(
159170
elif tool == 'memory_profile':
160171
# Memory profile handles one host at a time.
161172
assert len(xspace_paths) == 1
162-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
173+
raw_data, success = xspace_wrapper_func(
174+
xspace_paths, all_hosts, tool, options)
163175
if success:
164176
data = raw_data
165177
elif tool == 'pod_viewer':
166-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
178+
raw_data, success = xspace_wrapper_func(
179+
xspace_paths, all_hosts, tool, options)
167180
if success:
168181
data = raw_data
169182
elif tool == 'op_profile':
170183
options['group_by'] = params.get('group_by', 'program')
171-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
184+
raw_data, success = xspace_wrapper_func(
185+
xspace_paths, all_hosts, tool, options)
172186
if success:
173187
data = raw_data
174188
elif tool == 'hlo_op_profile':
175189
options['group_by'] = params.get('group_by', 'program')
176-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
190+
raw_data, success = xspace_wrapper_func(
191+
xspace_paths, all_hosts, tool, options)
177192
if success:
178193
data = raw_data
179194
elif tool == 'hlo_stats':
180-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
195+
json_data, success = xspace_wrapper_func(
196+
xspace_paths, all_hosts, tool, options)
181197
if success:
182198
data = json_data
183199
elif tool == 'roofline_model':
184-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
200+
json_data, success = xspace_wrapper_func(
201+
xspace_paths, all_hosts, tool, options)
185202
if success:
186203
data = json_data
187204
elif tool == 'graph_viewer':
188205
download_hlo_types = ['pb', 'pbtxt', 'json', 'short_txt', 'long_txt']
189206
graph_html_type = 'graph'
190207
options = params.get('graph_viewer_options', {})
191208
options['use_saved_result'] = params.get('use_saved_result', True)
192-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
209+
raw_data, success = xspace_wrapper_func(
210+
xspace_paths, all_hosts, tool, options)
193211
if success:
194212
data = raw_data
195213
content_type = 'text/plain'
@@ -213,18 +231,21 @@ def xspace_to_tool_data(
213231
'view_memory_allocation_timeline': view_memory_allocation_timeline,
214232
'memory_space': params.get('memory_space', ''),
215233
}
216-
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
234+
raw_data, success = xspace_wrapper_func(
235+
xspace_paths, all_hosts, tool, options)
217236
if success:
218237
data = raw_data
219238
if view_memory_allocation_timeline:
220239
content_type = 'text/html'
221240
elif tool == 'megascale_stats':
222241
options = {'host_name': params.get('host')}
223-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
242+
json_data, success = xspace_wrapper_func(
243+
xspace_paths, all_hosts, tool, options)
224244
if success:
225245
data = json_data
226246
elif tool == 'inference_profile':
227-
json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
247+
json_data, success = xspace_wrapper_func(
248+
xspace_paths, all_hosts, tool, options)
228249
if success:
229250
data = json_data
230251
else:

plugin/xprof/convert/raw_to_tool_data_test.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,11 @@ def test_using_old_tool_format_maps_to_new_format(self):
2727
xspace_paths=["/path/to/xspace"],
2828
tool="trace_viewer@^",
2929
params={},
30-
xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True),
30+
all_hosts=[],
31+
xspace_wrapper_func=lambda paths, hosts, tool, options: (
32+
tool.encode(),
33+
True,
34+
),
3135
)
3236

3337
self.assertEqual(data, b"trace_viewer@")
@@ -38,7 +42,11 @@ def test_using_new_tool_format_does_not_map_to_old_format(self):
3842
xspace_paths=["/path/to/xspace"],
3943
tool="trace_viewer@",
4044
params={},
41-
xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True),
45+
all_hosts=[],
46+
xspace_wrapper_func=lambda paths, hosts, tool, options: (
47+
tool.encode(),
48+
True,
49+
),
4250
)
4351

4452
self.assertEqual(data, b"trace_viewer@")

plugin/xprof/profile_plugin.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -707,7 +707,7 @@ def hlo_module_list_route(
707707

708708
def _get_valid_hosts(
709709
self, run_dir: str, run: str, tool: str, hosts_param: str, host: str
710-
) -> tuple[List[str], List[epath.Path]]:
710+
) -> tuple[List[str], List[epath.Path], List[str]]:
711711
"""Retrieves and validates the hosts and asset paths for a run and tool.
712712
713713
Args:
@@ -718,7 +718,7 @@ def _get_valid_hosts(
718718
host: The single host parameter.
719719
720720
Returns:
721-
A tuple containing (selected_hosts, asset_paths).
721+
A tuple containing (selected_hosts, asset_paths, all_hosts).
722722
723723
Raises:
724724
FileNotFoundError: If a required xplane file for the specified host(s)
@@ -781,7 +781,9 @@ def _get_valid_hosts(
781781
'Host must be specified for tool %s in run %s' % (tool, run)
782782
)
783783

784-
return selected_hosts, asset_paths
784+
all_hosts = list(all_xplane_files.keys())
785+
786+
return selected_hosts, asset_paths, all_hosts
785787

786788
def data_impl(
787789
self, request: wrappers.Request
@@ -870,7 +872,7 @@ def data_impl(
870872

871873
_, content_encoding = None, None
872874
if use_xplane(tool):
873-
selected_hosts, asset_paths = self._get_valid_hosts(
875+
selected_hosts, asset_paths, all_hosts = self._get_valid_hosts(
874876
run_dir, run, tool, hosts_param, host
875877
)
876878
if not asset_paths:
@@ -879,7 +881,7 @@ def data_impl(
879881
params['hosts'] = selected_hosts
880882
try:
881883
data, content_type = convert.xspace_to_tool_data(
882-
asset_paths, tool, params)
884+
asset_paths, all_hosts, tool, params)
883885
except AttributeError as e:
884886
logger.warning('Error generating analysis results due to %s', e)
885887
raise AttributeError(

plugin/xprof/profile_plugin_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -465,7 +465,7 @@ def testDataImplTraceViewerOptions(self, mock_xspace_to_tool_data):
465465
)
466466

467467
mock_xspace_to_tool_data.assert_called_once_with(
468-
[mock.ANY], 'trace_viewer@', expected_params
468+
[mock.ANY], ['host0', 'host1'], 'trace_viewer@', expected_params
469469
)
470470
args, _ = mock_xspace_to_tool_data.call_args
471471
actual_path_list = args[0]

xprof/convert/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -190,6 +190,7 @@ cc_library(
190190
":repository",
191191
":tool_options",
192192
":xplane_to_trace_container",
193+
"@com_google_absl//absl/container:flat_hash_map",
193194
"@com_google_absl//absl/log",
194195
"@com_google_absl//absl/status",
195196
"@com_google_absl//absl/status:statusor",

xprof/convert/repository.cc

Lines changed: 26 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,8 @@ static auto* kHostDataSuffixes =
5858

5959
absl::StatusOr<SessionSnapshot> SessionSnapshot::Create(
6060
std::vector<std::string> xspace_paths,
61-
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces) {
61+
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces,
62+
std::optional<std::vector<std::string>> all_hosts) {
6263
if (xspace_paths.empty()) {
6364
return absl::InvalidArgumentError("Can not find XSpace path.");
6465
}
@@ -85,7 +86,26 @@ absl::StatusOr<SessionSnapshot> SessionSnapshot::Create(
8586
}
8687
}
8788

88-
return SessionSnapshot(std::move(xspace_paths), std::move(xspaces));
89+
return SessionSnapshot(std::move(xspace_paths), std::move(xspaces),
90+
std::move(all_hosts));
91+
}
92+
93+
SessionSnapshot::SessionSnapshot(
94+
std::vector<std::string> xspace_paths,
95+
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces,
96+
std::optional<std::vector<std::string>> all_hosts)
97+
: xspace_paths_(std::move(xspace_paths)),
98+
all_hosts_(std::move(all_hosts)),
99+
// If the snapshot was initialized by xspaces, the file path and run dir
100+
// is a path tensorflow can't read from or write to so any file IO
101+
// encapsulated in this class will be disabled in this mode.
102+
has_accessible_run_dir_(!xspaces.has_value()),
103+
xspaces_(std::move(xspaces)) {
104+
session_run_dir_ = tsl::io::Dirname(xspace_paths_.at(0));
105+
for (size_t i = 0; i < xspace_paths_.size(); ++i) {
106+
std::string host_name = GetHostname(i);
107+
hostname_map_[host_name] = i;
108+
}
89109
}
90110

91111
absl::StatusOr<XSpace*> SessionSnapshot::GetXSpace(size_t index,
@@ -126,6 +146,10 @@ std::string SessionSnapshot::GetHostname(size_t index) const {
126146
return GetHostnameByPath(xspace_paths_.at(index));
127147
}
128148

149+
std::optional<std::vector<std::string>> SessionSnapshot::GetAllHosts() const {
150+
return all_hosts_;
151+
}
152+
129153
std::optional<std::string> SessionSnapshot::GetFilePath(
130154
absl::string_view toolname, absl::string_view hostname) const {
131155
if (!has_accessible_run_dir_) return std::nullopt;

xprof/convert/repository.h

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,8 @@ class SessionSnapshot {
5959
// profiler plugin.
6060
static absl::StatusOr<SessionSnapshot> Create(
6161
std::vector<std::string> xspace_paths,
62-
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces);
62+
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces,
63+
std::optional<std::vector<std::string>> all_hosts = std::nullopt);
6364

6465
// Returns the number of XSpaces in the profile session.
6566
size_t XSpaceSize() const { return xspace_paths_.size(); }
@@ -76,6 +77,9 @@ class SessionSnapshot {
7677
// Gets host name.
7778
std::string GetHostname(size_t index) const;
7879

80+
// Gets all host names.
81+
std::optional<std::vector<std::string>> GetAllHosts() const;
82+
7983
// Gets the run directory of the profile session.
8084
absl::string_view GetSessionRunDir() const { return session_run_dir_; }
8185

@@ -142,22 +146,15 @@ class SessionSnapshot {
142146

143147
private:
144148
SessionSnapshot(std::vector<std::string> xspace_paths,
145-
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces)
146-
: xspace_paths_(std::move(xspace_paths)),
147-
// If the snapshot was initialized by xspaces, the file path and run dir
148-
// is a path tensorflow can't read from or write to so any file IO
149-
// encapsulated in this class will be disabled in this mode.
150-
has_accessible_run_dir_(!xspaces.has_value()),
151-
xspaces_(std::move(xspaces)) {
152-
session_run_dir_ = tsl::io::Dirname(xspace_paths_.at(0));
153-
for (size_t i = 0; i < xspace_paths_.size(); ++i) {
154-
std::string host_name = GetHostname(i);
155-
hostname_map_[host_name] = i;
156-
}
157-
}
149+
std::optional<std::vector<std::unique_ptr<XSpace>>> xspaces,
150+
std::optional<std::vector<std::string>> all_hosts);
158151

159152
// File paths to XSpace protos.
160153
std::vector<std::string> xspace_paths_;
154+
155+
// All hosts in the session.
156+
std::optional<std::vector<std::string>> all_hosts_;
157+
161158
// The run directory of the profile session.
162159
absl::string_view session_run_dir_;
163160

0 commit comments

Comments
 (0)