diff --git a/frontend/app/common/interfaces/navigation_event.ts b/frontend/app/common/interfaces/navigation_event.ts
index e7df483ac..2cd2b79e1 100644
--- a/frontend/app/common/interfaces/navigation_event.ts
+++ b/frontend/app/common/interfaces/navigation_event.ts
@@ -5,6 +5,8 @@ export declare interface NavigationEvent {
run?: string;
tag?: string;
host?: string;
+ // Added to support multi-host functionality for trace_viewer.
+ hosts?: string[];
// Graph Viewer crosslink params
opName?: string;
moduleName?: string;
diff --git a/frontend/app/components/sidenav/sidenav.ng.html b/frontend/app/components/sidenav/sidenav.ng.html
index 53761e457..81613e06c 100644
--- a/frontend/app/components/sidenav/sidenav.ng.html
+++ b/frontend/app/components/sidenav/sidenav.ng.html
@@ -32,10 +32,17 @@
- Hosts ({{hosts.length}})
+ Hosts ({{ isMultiHostsEnabled ? selectedHostsInternal.length : hosts.length }})
-
+
+
+
+ {{host}}
+
+
+
+
{{host}}
diff --git a/frontend/app/components/sidenav/sidenav.ts b/frontend/app/components/sidenav/sidenav.ts
index e4e759e54..db1cf9e48 100644
--- a/frontend/app/components/sidenav/sidenav.ts
+++ b/frontend/app/components/sidenav/sidenav.ts
@@ -32,8 +32,10 @@ export class SideNav implements OnInit, OnDestroy {
selectedRunInternal = '';
selectedTagInternal = '';
selectedHostInternal = '';
+ selectedHostsInternal: string[] = [];
selectedModuleInternal = '';
navigationParams: {[key: string]: string|boolean} = {};
+ multiHostEnabledTools: string[] = ['trace_viewer', 'trace_viewer@'];
hideCaptureProfileButton = false;
@@ -65,6 +67,11 @@ export class SideNav implements OnInit, OnDestroy {
return HLO_TOOLS.includes(this.selectedTag);
}
+ get isMultiHostsEnabled() {
+ const tag = this.selectedTag || '';
+ return this.multiHostEnabledTools.includes(tag);
+ }
+
// Getter for valid run given url router or user selection.
get selectedRun() {
return this.runs.find(validRun => validRun === this.selectedRunInternal) ||
@@ -90,6 +97,10 @@ export class SideNav implements OnInit, OnDestroy {
this.moduleList[0] || '';
}
+ get selectedHosts() {
+ return this.selectedHostsInternal;
+ }
+
// https://github.com/angular/angular/issues/11023#issuecomment-752228784
mergeRouteParams(): Map {
const params = new Map();
@@ -119,20 +130,21 @@ export class SideNav implements OnInit, OnDestroy {
const run = params.get('run') || '';
const tag = params.get('tool') || params.get('tag') || '';
const host = params.get('host') || '';
+ const hostsParam = params.get('hosts');
const opName = params.get('node_name') || params.get('opName') || '';
const moduleName = params.get('module_name') || '';
this.navigationParams['firstLoad'] = true;
if (opName) {
this.navigationParams['opName'] = opName;
}
- if (this.selectedRunInternal === run && this.selectedTagInternal === tag &&
- this.selectedHostInternal === host) {
- return;
- }
this.selectedRunInternal = run;
this.selectedTagInternal = tag;
- this.selectedHostInternal = host;
this.selectedModuleInternal = moduleName;
+
+ if (hostsParam) {
+ this.selectedHostsInternal = hostsParam.split(',');
+ }
+ this.selectedHostInternal = host;
this.update();
}
@@ -153,9 +165,13 @@ export class SideNav implements OnInit, OnDestroy {
const navigationEvent: NavigationEvent = {
run: this.selectedRun,
tag: this.selectedTag,
- host: this.selectedHost,
...this.navigationParams,
};
+ if (this.isMultiHostsEnabled) {
+ navigationEvent.hosts = this.selectedHosts;
+ } else {
+ navigationEvent.host = this.selectedHost;
+ }
if (this.is_hlo_tool) {
navigationEvent.moduleName = this.selectedModule;
}
@@ -255,6 +271,8 @@ export class SideNav implements OnInit, OnDestroy {
// Keep them under the same update function as initial step of the separation.
async updateHosts() {
this.hosts = await this.getHostsForSelectedTag();
+ this.selectedHostsInternal = [this.hosts[0]];
+ this.selectedHostInternal = this.hosts[0];
if (this.is_hlo_tool) {
this.moduleList = await this.getModuleListForSelectedTag();
}
@@ -262,8 +280,15 @@ export class SideNav implements OnInit, OnDestroy {
this.afterUpdateHost();
}
- onHostSelectionChange(host: string) {
- this.selectedHostInternal = host;
+ onHostSelectionChange(selection: string) {
+ this.selectedHostInternal = selection;
+ this.selectedHostsInternal = [];
+ this.navigateTools();
+ }
+
+ onHostsSelectionChange(selection: string[]) {
+ this.selectedHostsInternal = selection;
+ this.selectedHostInternal = ''; // Ensure single-host is empty
this.navigateTools();
}
@@ -276,26 +301,65 @@ export class SideNav implements OnInit, OnDestroy {
this.navigateTools();
}
+ // Helper function to serialize query parameters
+ private serializeQueryParams(
+ params: {[key: string]: string|string[]|boolean|undefined}): string {
+ const searchParams = new URLSearchParams();
+ for (const key in params) {
+ if (params.hasOwnProperty(key)) {
+ const value = params[key];
+ // Only include non-null/non-undefined values
+ if (value !== undefined && value !== null) {
+ if (Array.isArray(value)) {
+ // Arrays are handled as comma-separated strings (like 'hosts')
+ searchParams.set(key, value.join(','));
+ } else if (typeof value === 'boolean') {
+ // Only set boolean flags if they are explicitly true
+ if (value === true) {
+ searchParams.set(key, 'true');
+ }
+ } else {
+ searchParams.set(key, String(value));
+ }
+ }
+ }
+ }
+ const queryString = searchParams.toString();
+ return queryString ? `?${queryString}` : '';
+ }
+
updateUrlHistory() {
- // TODO(xprof): change to camel case when constructing url
- const toolQueryParams = Object.keys(this.navigationParams)
- .map(key => {
- return `${key}=${this.navigationParams[key]}`;
- })
- .join('&');
- const toolQueryParamsString =
- toolQueryParams.length ? `&${toolQueryParams}` : '';
- const moduleNameQuery =
- this.is_hlo_tool ? `&module_name=${this.selectedModule}` : '';
- const url = `${window.parent.location.origin}?tool=${
- this.selectedTag}&host=${this.selectedHost}&run=${this.selectedRun}${
- toolQueryParamsString}${moduleNameQuery}#profile`;
+ const navigationEvent = this.getNavigationEvent();
+ const queryParams: {[key: string]: string|string[]|boolean|
+ undefined} = {...navigationEvent};
+
+ if (this.isMultiHostsEnabled) {
+ // For Trace Viewer, ensure 'hosts' is a comma-separated string in the URL
+ if (queryParams['hosts'] && Array.isArray(queryParams['hosts'])) {
+ queryParams['hosts'] = (queryParams['hosts'] as string[]).join(',');
+ }
+ delete queryParams['host']; // Remove single host param
+ } else {
+ // For other tools, ensure 'host' is used
+ delete queryParams['hosts']; // Remove multi-host param
+ }
+
+ // Get current path to avoid changing the base URL
+ const pathname = window.parent.location.pathname;
+
+ // Use the custom serialization helper
+ const queryString = this.serializeQueryParams(queryParams);
+ const url = pathname + queryString;
+
window.parent.history.pushState({}, '', url);
}
navigateTools() {
const navigationEvent = this.getNavigationEvent();
this.communicationService.onNavigateReady(navigationEvent);
+
+ // This router.navigate call remains, as it's responsible for Angular
+ // routing
this.router.navigate(
[
this.selectedTag || 'empty',
diff --git a/frontend/app/components/trace_viewer/trace_viewer.ts b/frontend/app/components/trace_viewer/trace_viewer.ts
index 9eb5f0ebf..4ade3c9c4 100644
--- a/frontend/app/components/trace_viewer/trace_viewer.ts
+++ b/frontend/app/components/trace_viewer/trace_viewer.ts
@@ -1,5 +1,4 @@
import {PlatformLocation} from '@angular/common';
-import {HttpParams} from '@angular/common/http';
import {Component, inject, Injector, OnDestroy} from '@angular/core';
import {ActivatedRoute} from '@angular/router';
import {API_PREFIX, DATA_API, PLUGIN_NAME} from 'org_xprof/frontend/app/common/constants/constants';
@@ -38,11 +37,19 @@ export class TraceViewer implements OnDestroy {
update(event: NavigationEvent) {
const isStreaming = (event.tag === 'trace_viewer@');
- const params = new HttpParams()
- .set('run', event.run!)
- .set('tag', event.tag!)
- .set('host', event.host!);
- const traceDataUrl = this.pathPrefix + DATA_API + '?' + params.toString();
+ const run = event.run || '';
+ const tag = event.tag || '';
+
+ let queryString = `run=${run}&tag=${tag}`;
+
+ if (event.hosts && typeof event.hosts === 'string') {
+ // Since event.hosts is a comma-separated string, we can use it directly.
+ queryString += `&hosts=${event.hosts}`;
+ } else if (event.host) {
+ queryString += `&host=${event.host}`;
+ }
+
+ const traceDataUrl = `${this.pathPrefix}${DATA_API}?${queryString}`;
this.url = this.pathPrefix + API_PREFIX + PLUGIN_NAME +
'/trace_viewer_index.html' +
'?is_streaming=' + isStreaming.toString() + '&is_oss=true' +
diff --git a/plugin/xprof/convert/raw_to_tool_data.py b/plugin/xprof/convert/raw_to_tool_data.py
index 04b4f5b19..4e42e2348 100644
--- a/plugin/xprof/convert/raw_to_tool_data.py
+++ b/plugin/xprof/convert/raw_to_tool_data.py
@@ -41,29 +41,6 @@ def process_raw_trace(raw_trace):
return ''.join(trace_events_json.TraceEventsJsonStream(trace))
-def xspace_to_tools_data_from_byte_string(xspace_byte_list, filenames, tool,
- params):
- """Helper function for getting an XSpace tool from a bytes string.
-
- Args:
- xspace_byte_list: A list of byte strings read from a XSpace proto file.
- filenames: Names of the read files.
- tool: A string of tool name.
- params: user input parameters.
-
- Returns:
- Returns a string of tool data.
- """
-# pylint:disable=dangerous-default-value
- def xspace_wrapper_func(xspace_arg, tool_arg, params={}):
- return _pywrap_profiler_plugin.xspace_to_tools_data_from_byte_string(
- xspace_arg, filenames, tool_arg, params)
-# pylint:enable=dangerous-default-value
-
- return xspace_to_tool_data(xspace_byte_list, tool, params,
- xspace_wrapper_func)
-
-
def xspace_to_tool_names(xspace_paths):
"""Converts XSpace to all the available tool names.
@@ -73,8 +50,10 @@ def xspace_to_tool_names(xspace_paths):
Returns:
Returns a list of tool names.
"""
+ # xspace_to_tools_data expects all_hosts as the second argument, passing an
+ # empty list.
raw_data, success = _pywrap_profiler_plugin.xspace_to_tools_data(
- xspace_paths, 'tool_names')
+ xspace_paths, [], 'tool_names', {})
if success:
return [tool for tool in raw_data.decode().split(',')]
return []
@@ -82,6 +61,7 @@ def xspace_to_tool_names(xspace_paths):
def xspace_to_tool_data(
xspace_paths,
+ all_hosts,
tool,
params,
xspace_wrapper_func=_pywrap_profiler_plugin.xspace_to_tools_data):
@@ -89,6 +69,7 @@ def xspace_to_tool_data(
Args:
xspace_paths: A list of XSpace paths.
+ all_hosts: A list of all hosts in the session.
tool: A string of tool name.
params: user input parameters.
xspace_wrapper_func: A callable that takes a list of strings and a tool and
@@ -112,27 +93,31 @@ def xspace_to_tool_data(
if tool == 'trace_viewer':
# Trace viewer handles one host at a time.
assert len(xspace_paths) == 1
- raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ raw_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = process_raw_trace(raw_data)
elif tool == 'trace_viewer@':
- # Streaming trace viewer handles one host at a time.
- assert len(xspace_paths) == 1
options = params.get('trace_viewer_options', {})
options['use_saved_result'] = params.get('use_saved_result', True)
- raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ options['hosts'] = all_hosts
+ raw_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = raw_data
elif tool == 'overview_page':
- json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ json_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = json_data
elif tool == 'input_pipeline_analyzer':
- json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ json_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = json_data
elif tool == 'framework_op_stats':
- json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ json_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
if tqx == 'out:csv':
data = csv_writer.json_to_csv(json_data)
@@ -143,7 +128,7 @@ def xspace_to_tool_data(
# TODO(b/419013992): Remove this tool completely as it has been deprecated
legacy_tool = 'tensorflow_stats'
json_data, success = xspace_wrapper_func(
- xspace_paths, legacy_tool, options
+ xspace_paths, all_hosts, legacy_tool, options
)
if success:
if tqx == 'out:csv':
@@ -151,7 +136,8 @@ def xspace_to_tool_data(
else:
data = json_data
elif tool == 'kernel_stats':
- json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ json_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
if tqx == 'out:csv':
data = csv_writer.json_to_csv(json_data)
@@ -160,29 +146,35 @@ def xspace_to_tool_data(
elif tool == 'memory_profile':
# Memory profile handles one host at a time.
assert len(xspace_paths) == 1
- raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ raw_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = raw_data
elif tool == 'pod_viewer':
- raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ raw_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = raw_data
elif tool == 'op_profile':
options['group_by'] = params.get('group_by', 'program')
- raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ raw_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = raw_data
elif tool == 'hlo_op_profile':
options['group_by'] = params.get('group_by', 'program')
- raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ raw_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = raw_data
elif tool == 'hlo_stats':
- json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ json_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = json_data
elif tool == 'roofline_model':
- json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ json_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = json_data
elif tool == 'graph_viewer':
@@ -190,7 +182,8 @@ def xspace_to_tool_data(
graph_html_type = 'graph'
options = params.get('graph_viewer_options', {})
options['use_saved_result'] = params.get('use_saved_result', True)
- raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ raw_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = raw_data
content_type = 'text/plain'
@@ -214,18 +207,21 @@ def xspace_to_tool_data(
'view_memory_allocation_timeline': view_memory_allocation_timeline,
'memory_space': params.get('memory_space', ''),
}
- raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ raw_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = raw_data
if view_memory_allocation_timeline:
content_type = 'text/html'
elif tool == 'megascale_stats':
options = {'host_name': params.get('host')}
- json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ json_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = json_data
elif tool == 'inference_profile':
- json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
+ json_data, success = xspace_wrapper_func(
+ xspace_paths, all_hosts, tool, options)
if success:
data = json_data
else:
diff --git a/plugin/xprof/convert/raw_to_tool_data_test.py b/plugin/xprof/convert/raw_to_tool_data_test.py
index a0a3df88d..95e6e1ca3 100644
--- a/plugin/xprof/convert/raw_to_tool_data_test.py
+++ b/plugin/xprof/convert/raw_to_tool_data_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Tests for the raw_to_tool_data module."""
import tensorflow as tf
@@ -27,7 +26,11 @@ def test_using_old_tool_format_maps_to_new_format(self):
xspace_paths=["/path/to/xspace"],
tool="trace_viewer@^",
params={},
- xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True),
+ all_hosts=[],
+ xspace_wrapper_func=lambda paths, hosts, tool, options: (
+ tool.encode(),
+ True,
+ ),
)
self.assertEqual(data, b"trace_viewer@")
@@ -38,7 +41,11 @@ def test_using_new_tool_format_does_not_map_to_old_format(self):
xspace_paths=["/path/to/xspace"],
tool="trace_viewer@",
params={},
- xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True),
+ all_hosts=[],
+ xspace_wrapper_func=lambda paths, hosts, tool, options: (
+ tool.encode(),
+ True,
+ ),
)
self.assertEqual(data, b"trace_viewer@")
diff --git a/plugin/xprof/profile_plugin.py b/plugin/xprof/profile_plugin.py
index b949ea5a8..0c5a4a899 100644
--- a/plugin/xprof/profile_plugin.py
+++ b/plugin/xprof/profile_plugin.py
@@ -380,23 +380,6 @@ def filenames_to_hosts(filenames: list[str], tool: str) -> list[str]:
return sorted(hosts)
-def validate_xplane_asset_paths(asset_paths: List[str]) -> None:
- """Validates that all xplane asset paths that are provided are valid files.
-
- Args:
- asset_paths: A list of asset paths.
-
- Raises:
- FileNotFoundError: If any of the xplane asset paths do not exist.
- """
- for asset_path in asset_paths:
- if (
- str(asset_path).endswith(TOOLS['xplane'])
- and not epath.Path(asset_path).exists()
- ):
- raise FileNotFoundError(f'Invalid asset path: {asset_path}')
-
-
def _get_bool_arg(
args: Mapping[str, Any], arg_name: str, default: bool
) -> bool:
@@ -511,6 +494,10 @@ def is_active(self) -> bool:
self._is_active = any(self.generate_runs())
return self._is_active
+ def _does_tool_support_multi_hosts_processing(self, tool: str) -> bool:
+ """Returns true if the tool supports multi-hosts processing."""
+ return tool == 'trace_viewer@' or tool == 'trace_viewer'
+
def get_plugin_apps(
self,
) -> dict[str, Callable[[wrappers.Request], wrappers.Response]]:
@@ -718,6 +705,85 @@ def hlo_module_list_route(
module_names_str = self.hlo_module_list_impl(request)
return respond(module_names_str, 'text/plain')
+ def _get_valid_hosts(
+ self, run_dir: str, run: str, tool: str, hosts_param: str, host: str
+ ) -> tuple[List[str], List[epath.Path], List[str]]:
+ """Retrieves and validates the hosts and asset paths for a run and tool.
+
+ Args:
+ run_dir: The run directory.
+ run: The frontend run name.
+ tool: The requested tool.
+ hosts_param: Comma-separated list of selected hosts.
+ host: The single host parameter.
+
+ Returns:
+ A tuple containing (selected_hosts, asset_paths).
+
+ Raises:
+ FileNotFoundError: If a required xplane file for the specified host(s)
+ is not found.
+ IOError: If there is an error reading asset directories.
+ """
+ asset_paths = []
+ selected_hosts = []
+ all_xplane_files = {} # Map host to path
+
+ # Find all available xplane files for the run and map them by host.
+ file_pattern = make_filename('*', 'xplane')
+ try:
+ path = epath.Path(run_dir)
+ for xplane_path in path.glob(file_pattern):
+ host_name, _ = _parse_filename(xplane_path.name)
+ if host_name:
+ print('host_name: %s', host_name)
+ all_xplane_files[host_name] = xplane_path
+ except OSError as e:
+ print('Error')
+ logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e)
+ raise IOError(
+ 'Cannot read asset directory: %s, OpError %s' % (run_dir, e)
+ ) from e
+
+ if hosts_param and self._does_tool_support_multi_hosts_processing(tool):
+ selected_hosts = hosts_param.split(',')
+ for selected_host in selected_hosts:
+ if selected_host in all_xplane_files:
+ asset_paths.append(all_xplane_files[selected_host])
+ else:
+ raise FileNotFoundError(
+ 'No xplane file found for host: %s in run: %s'
+ % (selected_host, run)
+ )
+ logger.info('Inside trace_viewer@, asset_paths: %s')
+ elif host == ALL_HOSTS:
+ asset_paths = list(all_xplane_files.values())
+ selected_hosts = list(all_xplane_files.keys())
+ elif host and host in all_xplane_files:
+ selected_hosts = [host]
+ asset_paths = [all_xplane_files[host]]
+ elif host:
+ logger.warning('No xplane file found for host: %s in run: %s', host, run)
+ if host not in XPLANE_TOOLS_ALL_HOSTS_ONLY:
+ raise FileNotFoundError(
+ 'No xplane file found for host: %s in run: %s' % (host, run)
+ )
+
+ if not asset_paths:
+ logger.warning(
+ 'No matching asset paths found for run %s, tool %s, host(s) %s / %s',
+ run,
+ tool,
+ hosts_param,
+ host,
+ )
+ if not host and tool not in XPLANE_TOOLS_ALL_HOSTS_ONLY:
+ raise FileNotFoundError(
+ 'Host must be specified for tool %s in run %s' % (tool, run)
+ )
+
+ return selected_hosts, asset_paths, list(all_xplane_files.keys())
+
def data_impl(
self, request: wrappers.Request
) -> tuple[Optional[str], str, Optional[str]]:
@@ -729,9 +795,17 @@ def data_impl(
Returns:
A string that can be served to the frontend tool or None if tool,
run or host is invalid.
+
+ Raises:
+ FileNotFoundError: If a required xplane file for the specified host(s)
+ is not found.
+ IOError: If there is an error reading asset directories.
+ AttributeError: If there is an error during xplane to tool data conversion
+ ValueError: If xplane conversion fails due to invalid data.
"""
run = request.args.get('run')
tool = request.args.get('tag')
+ hosts_param = request.args.get('hosts')
host = request.args.get('host')
module_name = request.args.get('module_name')
tqx = request.args.get('tqx')
@@ -795,28 +869,19 @@ def data_impl(
options['search_prefix'] = request.args.get('search_prefix')
params['trace_viewer_options'] = options
- asset_path = os.path.join(run_dir, make_filename(host, tool))
-
_, content_encoding = None, None
if use_xplane(tool):
- if host == ALL_HOSTS:
- file_pattern = make_filename('*', 'xplane')
- try:
- path = epath.Path(run_dir)
- asset_paths = list(path.glob(file_pattern))
- except OSError as e:
- logger.warning('Cannot read asset directory: %s, OpError %s', run_dir,
- e)
- raise IOError(
- 'Cannot read asset directory: %s, OpError %s' % (run_dir, e)
- ) from e
- else:
- asset_paths = [asset_path]
+ selected_hosts, asset_paths, all_hosts = self._get_valid_hosts(
+ run_dir, run, tool, hosts_param, host
+ )
+ if not asset_paths:
+ return None, content_type, None
+ params['hosts'] = selected_hosts
try:
- validate_xplane_asset_paths(asset_paths)
data, content_type = convert.xspace_to_tool_data(
- asset_paths, tool, params)
+ asset_paths, all_hosts, tool, params
+ )
except AttributeError as e:
logger.warning('Error generating analysis results due to %s', e)
raise AttributeError(
diff --git a/plugin/xprof/profile_plugin_test.py b/plugin/xprof/profile_plugin_test.py
index 0f44db974..30dba19e8 100644
--- a/plugin/xprof/profile_plugin_test.py
+++ b/plugin/xprof/profile_plugin_test.py
@@ -330,7 +330,9 @@ def testData(self):
with self.assertRaises(FileNotFoundError):
self.plugin.data_impl(
utils.make_data_request(
- utils.DataRequestOptions(run='a', tool='trace_viewer', host='')
+ utils.DataRequestOptions(
+ run='a/foo', tool='trace_viewer', host=''
+ )
)
)
@@ -445,6 +447,7 @@ def testDataImplTraceViewerOptions(self, mock_xspace_to_tool_data):
'start_time_ms': '100',
'end_time_ms': '200',
},
+ 'hosts': ['host1'],
}
_, _, _ = self.plugin.data_impl(
@@ -462,8 +465,11 @@ def testDataImplTraceViewerOptions(self, mock_xspace_to_tool_data):
)
mock_xspace_to_tool_data.assert_called_once_with(
- [expected_asset_path], 'trace_viewer@', expected_params
+ [mock.ANY], ['host0', 'host1'], 'trace_viewer@', expected_params
)
+ actual_path_list = mock_xspace_to_tool_data.call_args[0][0]
+ self.assertLen(actual_path_list, 1)
+ self.assertEqual(str(actual_path_list[0]), expected_asset_path)
def testActive(self):
@@ -535,8 +541,10 @@ def test_generate_runs_from_path_params_with_run_path(self):
# run3 is a file, not a directory, and should be ignored.
with open(os.path.join(run_path, 'run3'), 'w') as f:
f.write('dummy file')
+ with open(os.path.join(run2_path, 'host2.xplane.pb'), 'w') as f:
+ f.write('dummy xplane data for run2')
runs = list(self.plugin._generate_runs_from_path_params(run_path=run_path))
- self.assertListEqual(['run1'], runs)
+ self.assertListEqual(['run1', 'run2'], sorted(runs))
self.assertEqual(run_path, self.plugin.logdir)
def test_runs_impl_with_session(self):
diff --git a/xprof/convert/BUILD b/xprof/convert/BUILD
index 87b0426a1..3dc218a04 100644
--- a/xprof/convert/BUILD
+++ b/xprof/convert/BUILD
@@ -190,6 +190,7 @@ cc_library(
":repository",
":tool_options",
":xplane_to_trace_container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@@ -1592,12 +1593,14 @@ cc_library(
srcs = ["xplane_to_trace_container.cc"],
hdrs = ["xplane_to_trace_container.h"],
deps = [
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@org_xprof//plugin/xprof/protobuf:trace_events_proto_cc",
"@org_xprof//plugin/xprof/protobuf:trace_events_raw_proto_cc",
"@org_xprof//xprof/convert/trace_viewer:trace_event_arguments_builder",
"@org_xprof//xprof/convert/trace_viewer:trace_events",
"@org_xprof//xprof/convert/trace_viewer:trace_events_util",
+ "@org_xprof//xprof/convert/trace_viewer:trace_utils",
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
"@xla//xla/tsl/profiler/utils:tf_xplane_visitor",
"@xla//xla/tsl/profiler/utils:timespan",
diff --git a/xprof/convert/repository.cc b/xprof/convert/repository.cc
index f98a76749..e6887e9e0 100644
--- a/xprof/convert/repository.cc
+++ b/xprof/convert/repository.cc
@@ -58,7 +58,8 @@ static auto* kHostDataSuffixes =
absl::StatusOr SessionSnapshot::Create(
std::vector xspace_paths,
- std::optional>> xspaces) {
+ std::optional>> xspaces,
+ std::optional> all_hosts) {
if (xspace_paths.empty()) {
return absl::InvalidArgumentError("Can not find XSpace path.");
}
@@ -85,7 +86,26 @@ absl::StatusOr SessionSnapshot::Create(
}
}
- return SessionSnapshot(std::move(xspace_paths), std::move(xspaces));
+ return SessionSnapshot(std::move(xspace_paths), std::move(xspaces),
+ std::move(all_hosts));
+}
+
+SessionSnapshot::SessionSnapshot(
+ std::vector xspace_paths,
+ std::optional>> xspaces,
+ std::optional> all_hosts)
+ : xspace_paths_(std::move(xspace_paths)),
+ all_hosts_(std::move(all_hosts)),
+ // If the snapshot was initialized by xspaces, the file path and run dir
+ // is a path tensorflow can't read from or write to so any file IO
+ // encapsulated in this class will be disabled in this mode.
+ has_accessible_run_dir_(!xspaces.has_value()),
+ xspaces_(std::move(xspaces)) {
+ session_run_dir_ = tsl::io::Dirname(xspace_paths_.at(0));
+ for (size_t i = 0; i < xspace_paths_.size(); ++i) {
+ std::string host_name = GetHostname(i);
+ hostname_map_[host_name] = i;
+ }
}
absl::StatusOr SessionSnapshot::GetXSpace(size_t index,
@@ -126,6 +146,10 @@ std::string SessionSnapshot::GetHostname(size_t index) const {
return GetHostnameByPath(xspace_paths_.at(index));
}
+std::optional> SessionSnapshot::GetAllHosts() const {
+ return all_hosts_;
+}
+
std::optional SessionSnapshot::GetFilePath(
absl::string_view toolname, absl::string_view hostname) const {
if (!has_accessible_run_dir_) return std::nullopt;
diff --git a/xprof/convert/repository.h b/xprof/convert/repository.h
index 07b649378..46d7d40f4 100644
--- a/xprof/convert/repository.h
+++ b/xprof/convert/repository.h
@@ -8,7 +8,7 @@ You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+WITHOUTHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
@@ -59,7 +59,8 @@ class SessionSnapshot {
// profiler plugin.
static absl::StatusOr Create(
std::vector xspace_paths,
- std::optional>> xspaces);
+ std::optional>> xspaces,
+ std::optional> all_hosts = std::nullopt);
// Returns the number of XSpaces in the profile session.
size_t XSpaceSize() const { return xspace_paths_.size(); }
@@ -76,6 +77,9 @@ class SessionSnapshot {
// Gets host name.
std::string GetHostname(size_t index) const;
+ // Gets all host names.
+ std::optional> GetAllHosts() const;
+
// Gets the run directory of the profile session.
absl::string_view GetSessionRunDir() const { return session_run_dir_; }
@@ -142,22 +146,15 @@ class SessionSnapshot {
private:
SessionSnapshot(std::vector xspace_paths,
- std::optional>> xspaces)
- : xspace_paths_(std::move(xspace_paths)),
- // If the snapshot was initialized by xspaces, the file path and run dir
- // is a path tensorflow can't read from or write to so any file IO
- // encapsulated in this class will be disabled in this mode.
- has_accessible_run_dir_(!xspaces.has_value()),
- xspaces_(std::move(xspaces)) {
- session_run_dir_ = tsl::io::Dirname(xspace_paths_.at(0));
- for (size_t i = 0; i < xspace_paths_.size(); ++i) {
- std::string host_name = GetHostname(i);
- hostname_map_[host_name] = i;
- }
- }
+ std::optional>> xspaces,
+ std::optional> all_hosts);
// File paths to XSpace protos.
std::vector xspace_paths_;
+
+ // All hosts in the session.
+ std::optional> all_hosts_;
+
// The run directory of the profile session.
absl::string_view session_run_dir_;
diff --git a/xprof/convert/streaming_trace_viewer_processor.cc b/xprof/convert/streaming_trace_viewer_processor.cc
index 954885cf7..c872d3431 100644
--- a/xprof/convert/streaming_trace_viewer_processor.cc
+++ b/xprof/convert/streaming_trace_viewer_processor.cc
@@ -5,6 +5,7 @@
#include
#include
+#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/numbers.h"
@@ -67,65 +68,78 @@ absl::StatusOr GetTraceViewOption(const ToolOptions& options) {
absl::Status StreamingTraceViewerProcessor::ProcessSession(
const SessionSnapshot& session_snapshot, const ToolOptions& options) {
- if (session_snapshot.XSpaceSize() != 1) {
- return tsl::errors::InvalidArgument(
- "Trace events tool expects only 1 XSpace path but gets ",
- session_snapshot.XSpaceSize());
- }
-
- google::protobuf::Arena arena;
- TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(0, &arena));
- PreprocessSingleHostXSpace(xspace, /*step_grouping=*/true,
- /*derived_timeline=*/true);
-
+ TraceEventsContainer merged_trace_container;
std::string tool_name = "trace_viewer@";
- std::string trace_viewer_json;
- std::string host_name = session_snapshot.GetHostname(0);
- auto sstable_path = session_snapshot.GetFilePath(tool_name, host_name);
- if (!sstable_path) {
- return tsl::errors::Unimplemented(
- "streaming trace viewer hasn't been supported in Cloud AI");
- }
- if (!tsl::Env::Default()->FileExists(*sstable_path).ok()) {
- ProcessMegascaleDcn(xspace);
- TraceEventsContainer trace_container;
- ConvertXSpaceToTraceEventsContainer(host_name, *xspace, &trace_container);
- std::unique_ptr file;
- TF_RETURN_IF_ERROR(
- tsl::Env::Default()->NewWritableFile(*sstable_path, &file));
- TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTable(std::move(file)));
- }
TF_ASSIGN_OR_RETURN(TraceViewOption trace_option,
GetTraceViewOption(options));
tensorflow::profiler::TraceOptions profiler_trace_options =
TraceOptionsFromToolOptions(options);
- auto visibility_filter = std::make_unique(
- tsl::profiler::MilliSpan(trace_option.start_time_ms,
- trace_option.end_time_ms),
- trace_option.resolution, profiler_trace_options);
- TraceEventsContainer trace_container;
- // Trace smaller than threshold will be disabled from streaming.
- constexpr int64_t kDisableStreamingThreshold = 500000;
- auto trace_events_filter =
- CreateTraceEventsFilterFromTraceOptions(profiler_trace_options);
- TraceEventsLevelDbFilePaths file_paths;
- file_paths.trace_events_file_path = *sstable_path;
- TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable(
- file_paths, std::move(trace_events_filter), std::move(visibility_filter),
- kDisableStreamingThreshold));
+
+ absl::flat_hash_map host_to_id_map;
+ if (auto all_hosts = session_snapshot.GetAllHosts()) {
+ for (int i = 0; i < all_hosts->size(); ++i) {
+ host_to_id_map[(*all_hosts)[i]] = i;
+ }
+ }
+
+ // TODO(b/452217676) : Optimize this to process hosts in parallel.
+ for (int i = 0; i < session_snapshot.XSpaceSize(); ++i) {
+ int host_id = host_to_id_map[session_snapshot.GetHostname(i)];
+ LOG(INFO) << "Processing host: " << session_snapshot.GetHostname(i)
+ << " with host_id: " << host_id;
+ google::protobuf::Arena arena;
+ TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(i, &arena));
+ PreprocessSingleHostXSpace(xspace, /*step_grouping=*/true,
+ /*derived_timeline=*/true);
+
+ std::string host_name = session_snapshot.GetHostname(i);
+ auto sstable_path = session_snapshot.GetFilePath(tool_name, host_name);
+ if (!sstable_path) {
+ return tsl::errors::Unimplemented(
+ "streaming trace viewer hasn't been supported in Cloud AI");
+ }
+ if (!tsl::Env::Default()->FileExists(*sstable_path).ok()) {
+ ProcessMegascaleDcn(xspace);
+ TraceEventsContainer trace_container;
+ ConvertXSpaceToTraceEventsContainer(host_name, host_id, *xspace,
+ &trace_container);
+ std::unique_ptr file;
+ TF_RETURN_IF_ERROR(
+ tsl::Env::Default()->NewWritableFile(*sstable_path, &file));
+ TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTable(std::move(file)));
+ }
+
+ auto visibility_filter = std::make_unique(
+ tsl::profiler::MilliSpan(trace_option.start_time_ms,
+ trace_option.end_time_ms),
+ trace_option.resolution, profiler_trace_options);
+ TraceEventsContainer trace_container;
+ // Trace smaller than threshold will be disabled from streaming.
+ constexpr int64_t kDisableStreamingThreshold = 500000;
+ auto trace_events_filter =
+ CreateTraceEventsFilterFromTraceOptions(profiler_trace_options);
+ TraceEventsLevelDbFilePaths file_paths;
+ file_paths.trace_events_file_path = *sstable_path;
+ TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable(
+ file_paths, std::move(trace_events_filter),
+ std::move(visibility_filter), kDisableStreamingThreshold));
+ merged_trace_container.Merge(std::move(trace_container));
+ }
+
+ std::string trace_viewer_json;
JsonTraceOptions json_trace_options;
tensorflow::profiler::TraceDeviceType device_type =
tensorflow::profiler::TraceDeviceType::kUnknownDevice;
- if (IsTpuTrace(trace_container.trace())) {
+ if (IsTpuTrace(merged_trace_container.trace())) {
device_type = TraceDeviceType::kTpu;
}
json_trace_options.details =
TraceOptionsToDetails(device_type, profiler_trace_options);
IOBufferAdapter adapter(&trace_viewer_json);
TraceEventsToJson(
- json_trace_options, trace_container, &adapter);
+ json_trace_options, merged_trace_container, &adapter);
SetOutput(trace_viewer_json, "application/json");
return absl::OkStatus();
diff --git a/xprof/convert/trace_viewer/BUILD b/xprof/convert/trace_viewer/BUILD
index d0020296e..54ca46f03 100644
--- a/xprof/convert/trace_viewer/BUILD
+++ b/xprof/convert/trace_viewer/BUILD
@@ -133,6 +133,7 @@ cc_library(
":trace_events_filter_interface",
":trace_events_util",
":trace_viewer_visibility",
+ "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:endian",
"@com_google_absl//absl/container:flat_hash_map",
diff --git a/xprof/convert/trace_viewer/trace_events.cc b/xprof/convert/trace_viewer/trace_events.cc
index fa5bce32d..650ac914c 100644
--- a/xprof/convert/trace_viewer/trace_events.cc
+++ b/xprof/convert/trace_viewer/trace_events.cc
@@ -19,13 +19,17 @@ limitations under the License.
#include
#include
#include
+#include
#include
#include
#include
#include
+#include
#include
+#include "absl/algorithm/container.h"
#include "absl/base/internal/endian.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
@@ -77,6 +81,15 @@ void MaybeAddEventUniqueId(std::vector& events) {
}
}
+// Appends all events from src into dst.
+inline void AppendEvents(TraceEventTrack&& src, TraceEventTrack* dst) {
+ if (dst->empty()) {
+ *dst = std::move(src);
+ } else {
+ absl::c_move(src, std::back_inserter(*dst));
+ }
+}
+
} // namespace
TraceEvent::EventType GetTraceEventType(const TraceEvent& event) {
@@ -293,5 +306,51 @@ void PurgeIrrelevantEntriesInTraceNameTable(
trace.mutable_name_table()->swap(new_name_table);
}
+template
+void TraceEventsContainerBase::MergeTrace(
+ const Trace& other_trace) {
+ trace_.mutable_tasks()->insert(other_trace.tasks().begin(),
+ other_trace.tasks().end());
+ trace_.mutable_name_table()->insert(other_trace.name_table().begin(),
+ other_trace.name_table().end());
+ if (other_trace.has_min_timestamp_ps() &&
+ other_trace.has_max_timestamp_ps()) {
+ ExpandTraceSpan(TraceSpan(other_trace), &trace_);
+ }
+ trace_.set_num_events(trace_.num_events() + other_trace.num_events());
+}
+
+template
+void TraceEventsContainerBase::Merge(
+ TraceEventsContainerBase&& other) {
+ if (this == &other) return;
+ if (other.NumEvents() == 0 && other.trace().devices().empty()) return;
+
+ auto& this_device_map = *trace_.mutable_devices();
+ for (const auto& [other_id, other_device] : other.trace().devices()) {
+ this_device_map.insert({other_id, other_device});
+ }
+
+ other.ForAllMutableTracks([this](uint32_t other_device_id,
+ ResourceValue resource_id_or_counter_name,
+ TraceEventTrack* track) {
+ DeviceEvents& device = this->events_by_device_[other_device_id];
+ if (uint64_t* resource_id =
+ std::get_if(&resource_id_or_counter_name)) {
+ AppendEvents(std::move(*track), &device.events_by_resource[*resource_id]);
+ } else if (absl::string_view* counter_name = std::get_if(
+ &resource_id_or_counter_name)) {
+ AppendEvents(std::move(*track),
+ &device.counter_events_by_name[*counter_name]);
+ }
+ });
+
+ MergeTrace(other.trace());
+ arenas_.insert(other.arenas_.begin(), other.arenas_.end());
+}
+
+// Explicit instantiations for the common case.
+template class TraceEventsContainerBase;
+
} // namespace profiler
} // namespace tensorflow
diff --git a/xprof/convert/trace_viewer/trace_events.h b/xprof/convert/trace_viewer/trace_events.h
index 99a68d3b9..d42589559 100644
--- a/xprof/convert/trace_viewer/trace_events.h
+++ b/xprof/convert/trace_viewer/trace_events.h
@@ -729,6 +729,8 @@ class TraceEventsContainerBase {
TraceEventsContainerBase(const TraceEventsContainerBase&) = delete;
TraceEventsContainerBase& operator=(const TraceEventsContainerBase&) = delete;
+ void Merge(TraceEventsContainerBase&& other);
+
// Creates a TraceEvent prefilled with the given values.
void AddCompleteEvent(absl::string_view name, uint64_t resource_id,
uint32_t device_id, tsl::profiler::Timespan timespan,
@@ -1075,6 +1077,9 @@ class TraceEventsContainerBase {
return copy;
}
+ // Helper function to merge top-level trace metadata.
+ void MergeTrace(const Trace& other_trace);
+
// Adds an event from arenas_ to events_by_device_.
void AddArenaEvent(TraceEvent* event) {
ExpandTraceSpan(EventSpan(*event), &trace_);
diff --git a/xprof/convert/trace_viewer/trace_utils.h b/xprof/convert/trace_viewer/trace_utils.h
index 783fd8420..eaa9b0d3d 100644
--- a/xprof/convert/trace_viewer/trace_utils.h
+++ b/xprof/convert/trace_viewer/trace_utils.h
@@ -37,6 +37,8 @@ inline bool MaybeTpuNonCoreDeviceName(absl::string_view device_name) {
IsTpuIciRouterDeviceName(device_name));
}
+static constexpr int kMaxDevicesPerHost = 1000;
+
} // namespace profiler
} // namespace tensorflow
diff --git a/xprof/convert/xplane_to_tools_data.cc b/xprof/convert/xplane_to_tools_data.cc
index 4ecce1108..986fa6a70 100644
--- a/xprof/convert/xplane_to_tools_data.cc
+++ b/xprof/convert/xplane_to_tools_data.cc
@@ -177,7 +177,10 @@ absl::StatusOr ConvertXSpaceToTraceEvents(
if (!tsl::Env::Default()->FileExists(*trace_events_sstable_path).ok()) {
ProcessMegascaleDcn(xspace);
TraceEventsContainer trace_container;
- ConvertXSpaceToTraceEventsContainer(host_name, *xspace, &trace_container);
+ // No-op method which will be deprecated in the future, thus added
+ // /*host_id=*/1 as a placeholder for now.
+ ConvertXSpaceToTraceEventsContainer(host_name, 1, *xspace,
+ &trace_container);
std::unique_ptr trace_events_file;
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
*trace_events_sstable_path, &trace_events_file));
diff --git a/xprof/convert/xplane_to_trace_container.cc b/xprof/convert/xplane_to_trace_container.cc
index aba9c186c..29919bd1f 100644
--- a/xprof/convert/xplane_to_trace_container.cc
+++ b/xprof/convert/xplane_to_trace_container.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include
#include
+#include "absl/base/optimization.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "xla/tsl/profiler/utils/tf_xplane_visitor.h"
@@ -32,6 +33,7 @@ limitations under the License.
#include "xla/tsl/profiler/utils/xplane_visitor.h"
#include "xprof/convert/trace_viewer/trace_event_arguments_builder.h"
#include "xprof/convert/trace_viewer/trace_events_util.h"
+#include "xprof/convert/trace_viewer/trace_utils.h"
#include "plugin/xprof/protobuf/trace_events.pb.h"
#include "plugin/xprof/protobuf/trace_events_raw.pb.h"
@@ -219,7 +221,7 @@ void ConvertXPlaneToTraceEventsContainer(uint64_t device_id,
} // namespace
void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname,
- const XSpace& space,
+ int host_id, const XSpace& space,
TraceEventsContainer* container) {
const XPlane* host_plane =
FindPlaneWithName(space, tsl::profiler::kHostThreadsPlaneName);
@@ -236,9 +238,13 @@ void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname,
}
for (const XPlane* device_plane : device_planes) {
- ConvertXPlaneToTraceEventsContainer(
- tsl::profiler::kFirstDeviceId + device_plane->id(), hostname,
- *device_plane, container);
+ uint32_t device_pid = tsl::profiler::kFirstDeviceId + device_plane->id();
+ if (ABSL_PREDICT_FALSE(device_pid > tsl::profiler::kLastDeviceId)) {
+ device_pid = tsl::profiler::kFirstDeviceId;
+ }
+ uint32_t final_device_pid = (host_id)*kMaxDevicesPerHost + device_pid;
+ ConvertXPlaneToTraceEventsContainer(final_device_pid, hostname,
+ *device_plane, container);
}
for (const XPlane* custom_plane :
FindPlanesWithPrefix(space, tsl::profiler::kCustomPlanePrefix)) {
diff --git a/xprof/convert/xplane_to_trace_container.h b/xprof/convert/xplane_to_trace_container.h
index c0f0272b3..01af9ad94 100644
--- a/xprof/convert/xplane_to_trace_container.h
+++ b/xprof/convert/xplane_to_trace_container.h
@@ -28,6 +28,7 @@ using TraceEventsContainer = TraceEventsContainerBase;
// Converts XEvents within the XSpace into trace_viewer events container.
void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname,
+ int host_id,
const XSpace& xspace,
TraceEventsContainer* container);
diff --git a/xprof/convert/xplane_to_trace_container_test.cc b/xprof/convert/xplane_to_trace_container_test.cc
index 4e0a5cf3e..cd9f97618 100644
--- a/xprof/convert/xplane_to_trace_container_test.cc
+++ b/xprof/convert/xplane_to_trace_container_test.cc
@@ -96,7 +96,7 @@ TEST(XPlaneToTraceContainerTest, CounterLine) {
tsl::profiler::UniToNano(1), tsl::profiler::UniToNano(500)),
&xspace));
TraceEventsContainer container;
- ConvertXSpaceToTraceEventsContainer("localhost", xspace, &container);
+ ConvertXSpaceToTraceEventsContainer("localhost", 0, xspace, &container);
absl::flat_hash_map>
counter_offset_to_values;
container.ForAllEvents([&counter_offset_to_values](const TraceEvent& event) {
@@ -142,7 +142,7 @@ TEST(XPlaneToTraceContainerTest, AsyncLine) {
tsl::profiler::kXlaAsyncOpLineName, kAsyncOpEventName),
&xspace));
TraceEventsContainer container;
- ConvertXSpaceToTraceEventsContainer("localhost", xspace, &container);
+ ConvertXSpaceToTraceEventsContainer("localhost", 0, xspace, &container);
bool async_event_found = false;
container.ForAllEvents(
[&async_event_found, &kAsyncOpEventName](const TraceEvent& event) {
diff --git a/xprof/pywrap/_pywrap_profiler_plugin.pyi b/xprof/pywrap/_pywrap_profiler_plugin.pyi
index 6be031344..84118493e 100644
--- a/xprof/pywrap/_pywrap_profiler_plugin.pyi
+++ b/xprof/pywrap/_pywrap_profiler_plugin.pyi
@@ -15,8 +15,8 @@
def monitor(arg0: str, arg1: int, arg2: int, arg3: bool) -> str: ...
def trace(arg0: str, arg1: str, arg2: str, arg3: bool, arg4: int, arg5: int, arg6: dict) -> None: ...
-def xspace_to_tools_data(arg0: list, arg1: str, arg2: dict = ...) -> tuple: ...
-def xspace_to_tools_data_from_byte_string(arg0: list, arg1: list, arg2: str, arg3: dict) -> tuple: ...
+def xspace_to_tools_data(arg0: list, arg1: list, arg2: str, arg3: dict = ...) -> tuple: ...
+def xspace_to_tools_data_from_byte_string(arg0: list, arg1: list, arg2: list, arg3: str, arg4: dict) -> tuple: ...
def start_grpc_server(port: int) -> None: ...
def initialize_stubs(worker_service_addresses: str) -> None: ...
diff --git a/xprof/pywrap/profiler_plugin_impl.cc b/xprof/pywrap/profiler_plugin_impl.cc
index ccf5c2a43..145ade5b3 100644
--- a/xprof/pywrap/profiler_plugin_impl.cc
+++ b/xprof/pywrap/profiler_plugin_impl.cc
@@ -124,16 +124,18 @@ void StartGrpcServer(int port) {
}
absl::StatusOr> XSpaceToToolsData(
- std::vector xspace_paths, const std::string& tool_name,
+ std::vector xspace_paths,
+ std::vector all_hosts, const std::string& tool_name,
const ToolOptions& tool_options) {
auto status_or_session_snapshot = SessionSnapshot::Create(
- std::move(xspace_paths), /*xspaces=*/std::nullopt);
+ std::move(xspace_paths), /*xspaces=*/std::nullopt, all_hosts);
return SessionSnapshotToToolsData(status_or_session_snapshot, tool_name,
tool_options);
}
absl::StatusOr> XSpaceToToolsDataFromByteString(
std::vector xspace_strings,
+ std::vector all_hosts,
std::vector xspace_paths, const std::string& tool_name,
const ToolOptions& tool_options) {
std::vector> xspaces;
@@ -154,7 +156,8 @@ absl::StatusOr> XSpaceToToolsDataFromByteString(
}
auto status_or_session_snapshot =
- SessionSnapshot::Create(std::move(xspace_paths), std::move(xspaces));
+ SessionSnapshot::Create(std::move(xspace_paths), std::move(xspaces),
+ all_hosts);
return SessionSnapshotToToolsData(status_or_session_snapshot, tool_name,
tool_options);
}
diff --git a/xprof/pywrap/profiler_plugin_impl.h b/xprof/pywrap/profiler_plugin_impl.h
index 03b13b23a..9c7562b13 100644
--- a/xprof/pywrap/profiler_plugin_impl.h
+++ b/xprof/pywrap/profiler_plugin_impl.h
@@ -36,11 +36,13 @@ absl::Status Monitor(const char* service_addr, int duration_ms,
tsl::string* result);
absl::StatusOr> XSpaceToToolsData(
- std::vector xspace_paths, const std::string& tool_name,
+ std::vector xspace_paths,
+ std::vector all_hosts, const std::string& tool_name,
const tensorflow::profiler::ToolOptions& tool_options);
absl::StatusOr> XSpaceToToolsDataFromByteString(
std::vector xspace_strings,
+ std::vector all_hosts,
std::vector xspace_paths, const std::string& tool_name,
const tensorflow::profiler::ToolOptions& tool_options);
diff --git a/xprof/pywrap/pywrap_profiler_plugin.cc b/xprof/pywrap/pywrap_profiler_plugin.cc
index 7e4dd69f1..612c39de4 100644
--- a/xprof/pywrap/pywrap_profiler_plugin.cc
+++ b/xprof/pywrap/pywrap_profiler_plugin.cc
@@ -88,22 +88,27 @@ PYBIND11_MODULE(_pywrap_profiler_plugin, m) {
m.def(
"xspace_to_tools_data",
- [](const py::list& xspace_path_list, const py::str& py_tool_name,
- const py::dict options = py::dict()) {
+ [](const py::list& xspace_path_list, const py::list& all_hosts_list,
+ const py::str& py_tool_name, const py::dict options = py::dict()) {
std::vector xspace_paths;
xspace_paths.reserve(xspace_path_list.size());
for (py::handle obj : xspace_path_list) {
std::string xspace_path = std::string(py::cast(obj));
xspace_paths.push_back(xspace_path);
}
+ std::vector all_hosts;
+ all_hosts.reserve(all_hosts_list.size());
+ for (py::handle obj : all_hosts_list) {
+ all_hosts.push_back(std::string(py::cast(obj)));
+ }
std::string tool_name = std::string(py_tool_name);
ToolOptions tool_options =
ToolOptionsFromPythonDict(options);
absl::StatusOr> result;
{
py::gil_scoped_release release;
- result = xprof::pywrap::XSpaceToToolsData(xspace_paths, tool_name,
- tool_options);
+ result = xprof::pywrap::XSpaceToToolsData(xspace_paths, all_hosts,
+ tool_name, tool_options);
}
if (!result.ok()) {
@@ -112,18 +117,25 @@ PYBIND11_MODULE(_pywrap_profiler_plugin, m) {
return py::make_tuple(py::bytes(result->first),
py::bool_(result->second));
},
- py::arg(), py::arg(), py::arg() = py::dict());
+ py::arg(), py::arg(), py::arg(), py::arg() = py::dict());
m.def(
"xspace_to_tools_data_from_byte_string",
- [](const py::list& xspace_string_list, const py::list& filenames_list,
- const py::str& py_tool_name, const py::dict options = py::dict()) {
+ [](const py::list& xspace_string_list, const py::list& all_hosts_list,
+ const py::list& filenames_list, const py::str& py_tool_name,
+ const py::dict options = py::dict()) {
std::vector xspace_strings;
xspace_strings.reserve(xspace_string_list.size());
for (py::handle obj : xspace_string_list) {
xspace_strings.push_back(std::string(py::cast(obj)));
}
+ std::vector all_hosts;
+ all_hosts.reserve(all_hosts_list.size());
+ for (py::handle obj : all_hosts_list) {
+ all_hosts.push_back(std::string(py::cast(obj)));
+ }
+
std::vector xspace_paths;
xspace_paths.reserve(filenames_list.size());
for (py::handle obj : filenames_list) {
@@ -138,7 +150,7 @@ PYBIND11_MODULE(_pywrap_profiler_plugin, m) {
{
py::gil_scoped_release release;
result = xprof::pywrap::XSpaceToToolsDataFromByteString(
- xspace_strings, xspace_paths, tool_name, tool_options);
+ xspace_strings, all_hosts, xspace_paths, tool_name, tool_options);
}
if (!result.ok()) {
@@ -147,7 +159,7 @@ PYBIND11_MODULE(_pywrap_profiler_plugin, m) {
return py::make_tuple(py::bytes(result->first),
py::bool_(result->second));
},
- py::arg(), py::arg(), py::arg(), py::arg() = py::dict());
+ py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::dict());
m.def("start_grpc_server", [](int port) {
py::gil_scoped_release release;