diff --git a/frontend/app/common/interfaces/navigation_event.ts b/frontend/app/common/interfaces/navigation_event.ts index e7df483ac..2cd2b79e1 100644 --- a/frontend/app/common/interfaces/navigation_event.ts +++ b/frontend/app/common/interfaces/navigation_event.ts @@ -5,6 +5,8 @@ export declare interface NavigationEvent { run?: string; tag?: string; host?: string; + // Added to support multi-host functionality for trace_viewer. + hosts?: string[]; // Graph Viewer crosslink params opName?: string; moduleName?: string; diff --git a/frontend/app/components/sidenav/sidenav.ng.html b/frontend/app/components/sidenav/sidenav.ng.html index 53761e457..81613e06c 100644 --- a/frontend/app/components/sidenav/sidenav.ng.html +++ b/frontend/app/components/sidenav/sidenav.ng.html @@ -32,10 +32,17 @@
- Hosts ({{hosts.length}}) + Hosts ({{ isMultiHostsEnabled ? selectedHostsInternal.length : hosts.length }})
- + + + + {{host}} + + + + {{host}} diff --git a/frontend/app/components/sidenav/sidenav.ts b/frontend/app/components/sidenav/sidenav.ts index e4e759e54..db1cf9e48 100644 --- a/frontend/app/components/sidenav/sidenav.ts +++ b/frontend/app/components/sidenav/sidenav.ts @@ -32,8 +32,10 @@ export class SideNav implements OnInit, OnDestroy { selectedRunInternal = ''; selectedTagInternal = ''; selectedHostInternal = ''; + selectedHostsInternal: string[] = []; selectedModuleInternal = ''; navigationParams: {[key: string]: string|boolean} = {}; + multiHostEnabledTools: string[] = ['trace_viewer', 'trace_viewer@']; hideCaptureProfileButton = false; @@ -65,6 +67,11 @@ export class SideNav implements OnInit, OnDestroy { return HLO_TOOLS.includes(this.selectedTag); } + get isMultiHostsEnabled() { + const tag = this.selectedTag || ''; + return this.multiHostEnabledTools.includes(tag); + } + // Getter for valid run given url router or user selection. get selectedRun() { return this.runs.find(validRun => validRun === this.selectedRunInternal) || @@ -90,6 +97,10 @@ export class SideNav implements OnInit, OnDestroy { this.moduleList[0] || ''; } + get selectedHosts() { + return this.selectedHostsInternal; + } + // https://github.com/angular/angular/issues/11023#issuecomment-752228784 mergeRouteParams(): Map { const params = new Map(); @@ -119,20 +130,21 @@ export class SideNav implements OnInit, OnDestroy { const run = params.get('run') || ''; const tag = params.get('tool') || params.get('tag') || ''; const host = params.get('host') || ''; + const hostsParam = params.get('hosts'); const opName = params.get('node_name') || params.get('opName') || ''; const moduleName = params.get('module_name') || ''; this.navigationParams['firstLoad'] = true; if (opName) { this.navigationParams['opName'] = opName; } - if (this.selectedRunInternal === run && this.selectedTagInternal === tag && - this.selectedHostInternal === host) { - return; - } this.selectedRunInternal = run; 
this.selectedTagInternal = tag; - this.selectedHostInternal = host; this.selectedModuleInternal = moduleName; + + if (hostsParam) { + this.selectedHostsInternal = hostsParam.split(','); + } + this.selectedHostInternal = host; this.update(); } @@ -153,9 +165,13 @@ export class SideNav implements OnInit, OnDestroy { const navigationEvent: NavigationEvent = { run: this.selectedRun, tag: this.selectedTag, - host: this.selectedHost, ...this.navigationParams, }; + if (this.isMultiHostsEnabled) { + navigationEvent.hosts = this.selectedHosts; + } else { + navigationEvent.host = this.selectedHost; + } if (this.is_hlo_tool) { navigationEvent.moduleName = this.selectedModule; } @@ -255,6 +271,8 @@ export class SideNav implements OnInit, OnDestroy { // Keep them under the same update function as initial step of the separation. async updateHosts() { this.hosts = await this.getHostsForSelectedTag(); + this.selectedHostsInternal = [this.hosts[0]]; + this.selectedHostInternal = this.hosts[0]; if (this.is_hlo_tool) { this.moduleList = await this.getModuleListForSelectedTag(); } @@ -262,8 +280,15 @@ export class SideNav implements OnInit, OnDestroy { this.afterUpdateHost(); } - onHostSelectionChange(host: string) { - this.selectedHostInternal = host; + onHostSelectionChange(selection: string) { + this.selectedHostInternal = selection; + this.selectedHostsInternal = []; + this.navigateTools(); + } + + onHostsSelectionChange(selection: string[]) { + this.selectedHostsInternal = selection; + this.selectedHostInternal = ''; // Ensure single-host is empty this.navigateTools(); } @@ -276,26 +301,65 @@ export class SideNav implements OnInit, OnDestroy { this.navigateTools(); } + // Helper function to serialize query parameters + private serializeQueryParams( + params: {[key: string]: string|string[]|boolean|undefined}): string { + const searchParams = new URLSearchParams(); + for (const key in params) { + if (params.hasOwnProperty(key)) { + const value = params[key]; + // Only include 
non-null/non-undefined values + if (value !== undefined && value !== null) { + if (Array.isArray(value)) { + // Arrays are handled as comma-separated strings (like 'hosts') + searchParams.set(key, value.join(',')); + } else if (typeof value === 'boolean') { + // Only set boolean flags if they are explicitly true + if (value === true) { + searchParams.set(key, 'true'); + } + } else { + searchParams.set(key, String(value)); + } + } + } + } + const queryString = searchParams.toString(); + return queryString ? `?${queryString}` : ''; + } + updateUrlHistory() { - // TODO(xprof): change to camel case when constructing url - const toolQueryParams = Object.keys(this.navigationParams) - .map(key => { - return `${key}=${this.navigationParams[key]}`; - }) - .join('&'); - const toolQueryParamsString = - toolQueryParams.length ? `&${toolQueryParams}` : ''; - const moduleNameQuery = - this.is_hlo_tool ? `&module_name=${this.selectedModule}` : ''; - const url = `${window.parent.location.origin}?tool=${ - this.selectedTag}&host=${this.selectedHost}&run=${this.selectedRun}${ - toolQueryParamsString}${moduleNameQuery}#profile`; + const navigationEvent = this.getNavigationEvent(); + const queryParams: {[key: string]: string|string[]|boolean| + undefined} = {...navigationEvent}; + + if (this.isMultiHostsEnabled) { + // For Trace Viewer, ensure 'hosts' is a comma-separated string in the URL + if (queryParams['hosts'] && Array.isArray(queryParams['hosts'])) { + queryParams['hosts'] = (queryParams['hosts'] as string[]).join(','); + } + delete queryParams['host']; // Remove single host param + } else { + // For other tools, ensure 'host' is used + delete queryParams['hosts']; // Remove multi-host param + } + + // Get current path to avoid changing the base URL + const pathname = window.parent.location.pathname; + + // Use the custom serialization helper + const queryString = this.serializeQueryParams(queryParams); + const url = pathname + queryString; + window.parent.history.pushState({}, 
'', url); } navigateTools() { const navigationEvent = this.getNavigationEvent(); this.communicationService.onNavigateReady(navigationEvent); + + // This router.navigate call remains, as it's responsible for Angular + // routing this.router.navigate( [ this.selectedTag || 'empty', diff --git a/frontend/app/components/trace_viewer/trace_viewer.ts b/frontend/app/components/trace_viewer/trace_viewer.ts index 9eb5f0ebf..4ade3c9c4 100644 --- a/frontend/app/components/trace_viewer/trace_viewer.ts +++ b/frontend/app/components/trace_viewer/trace_viewer.ts @@ -1,5 +1,4 @@ import {PlatformLocation} from '@angular/common'; -import {HttpParams} from '@angular/common/http'; import {Component, inject, Injector, OnDestroy} from '@angular/core'; import {ActivatedRoute} from '@angular/router'; import {API_PREFIX, DATA_API, PLUGIN_NAME} from 'org_xprof/frontend/app/common/constants/constants'; @@ -38,11 +37,19 @@ export class TraceViewer implements OnDestroy { update(event: NavigationEvent) { const isStreaming = (event.tag === 'trace_viewer@'); - const params = new HttpParams() - .set('run', event.run!) - .set('tag', event.tag!) - .set('host', event.host!); - const traceDataUrl = this.pathPrefix + DATA_API + '?' + params.toString(); + const run = event.run || ''; + const tag = event.tag || ''; + + let queryString = `run=${run}&tag=${tag}`; + + if (event.hosts && typeof event.hosts === 'string') { + // Since event.hosts is a comma-separated string, we can use it directly. 
+ queryString += `&hosts=${event.hosts}`; + } else if (event.host) { + queryString += `&host=${event.host}`; + } + + const traceDataUrl = `${this.pathPrefix}${DATA_API}?${queryString}`; this.url = this.pathPrefix + API_PREFIX + PLUGIN_NAME + '/trace_viewer_index.html' + '?is_streaming=' + isStreaming.toString() + '&is_oss=true' + diff --git a/plugin/xprof/convert/raw_to_tool_data.py b/plugin/xprof/convert/raw_to_tool_data.py index 04b4f5b19..4e42e2348 100644 --- a/plugin/xprof/convert/raw_to_tool_data.py +++ b/plugin/xprof/convert/raw_to_tool_data.py @@ -41,29 +41,6 @@ def process_raw_trace(raw_trace): return ''.join(trace_events_json.TraceEventsJsonStream(trace)) -def xspace_to_tools_data_from_byte_string(xspace_byte_list, filenames, tool, - params): - """Helper function for getting an XSpace tool from a bytes string. - - Args: - xspace_byte_list: A list of byte strings read from a XSpace proto file. - filenames: Names of the read files. - tool: A string of tool name. - params: user input parameters. - - Returns: - Returns a string of tool data. - """ -# pylint:disable=dangerous-default-value - def xspace_wrapper_func(xspace_arg, tool_arg, params={}): - return _pywrap_profiler_plugin.xspace_to_tools_data_from_byte_string( - xspace_arg, filenames, tool_arg, params) -# pylint:enable=dangerous-default-value - - return xspace_to_tool_data(xspace_byte_list, tool, params, - xspace_wrapper_func) - - def xspace_to_tool_names(xspace_paths): """Converts XSpace to all the available tool names. @@ -73,8 +50,10 @@ def xspace_to_tool_names(xspace_paths): Returns: Returns a list of tool names. """ + # xspace_to_tools_data expects all_hosts as the second argument, passing an + # empty list. 
raw_data, success = _pywrap_profiler_plugin.xspace_to_tools_data( - xspace_paths, 'tool_names') + xspace_paths, [], 'tool_names', {}) if success: return [tool for tool in raw_data.decode().split(',')] return [] @@ -82,6 +61,7 @@ def xspace_to_tool_names(xspace_paths): def xspace_to_tool_data( xspace_paths, + all_hosts, tool, params, xspace_wrapper_func=_pywrap_profiler_plugin.xspace_to_tools_data): @@ -89,6 +69,7 @@ def xspace_to_tool_data( Args: xspace_paths: A list of XSpace paths. + all_hosts: A list of all hosts in the session. tool: A string of tool name. params: user input parameters. xspace_wrapper_func: A callable that takes a list of strings and a tool and @@ -112,27 +93,31 @@ def xspace_to_tool_data( if tool == 'trace_viewer': # Trace viewer handles one host at a time. assert len(xspace_paths) == 1 - raw_data, success = xspace_wrapper_func(xspace_paths, tool, options) + raw_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = process_raw_trace(raw_data) elif tool == 'trace_viewer@': - # Streaming trace viewer handles one host at a time. 
- assert len(xspace_paths) == 1 options = params.get('trace_viewer_options', {}) options['use_saved_result'] = params.get('use_saved_result', True) - raw_data, success = xspace_wrapper_func(xspace_paths, tool, options) + options['hosts'] = all_hosts + raw_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = raw_data elif tool == 'overview_page': - json_data, success = xspace_wrapper_func(xspace_paths, tool, options) + json_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = json_data elif tool == 'input_pipeline_analyzer': - json_data, success = xspace_wrapper_func(xspace_paths, tool, options) + json_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = json_data elif tool == 'framework_op_stats': - json_data, success = xspace_wrapper_func(xspace_paths, tool, options) + json_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: if tqx == 'out:csv': data = csv_writer.json_to_csv(json_data) @@ -143,7 +128,7 @@ def xspace_to_tool_data( # TODO(b/419013992): Remove this tool completely as it has been deprecated legacy_tool = 'tensorflow_stats' json_data, success = xspace_wrapper_func( - xspace_paths, legacy_tool, options + xspace_paths, all_hosts, legacy_tool, options ) if success: if tqx == 'out:csv': @@ -151,7 +136,8 @@ def xspace_to_tool_data( else: data = json_data elif tool == 'kernel_stats': - json_data, success = xspace_wrapper_func(xspace_paths, tool, options) + json_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: if tqx == 'out:csv': data = csv_writer.json_to_csv(json_data) @@ -160,29 +146,35 @@ def xspace_to_tool_data( elif tool == 'memory_profile': # Memory profile handles one host at a time. 
assert len(xspace_paths) == 1 - raw_data, success = xspace_wrapper_func(xspace_paths, tool, options) + raw_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = raw_data elif tool == 'pod_viewer': - raw_data, success = xspace_wrapper_func(xspace_paths, tool, options) + raw_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = raw_data elif tool == 'op_profile': options['group_by'] = params.get('group_by', 'program') - raw_data, success = xspace_wrapper_func(xspace_paths, tool, options) + raw_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = raw_data elif tool == 'hlo_op_profile': options['group_by'] = params.get('group_by', 'program') - raw_data, success = xspace_wrapper_func(xspace_paths, tool, options) + raw_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = raw_data elif tool == 'hlo_stats': - json_data, success = xspace_wrapper_func(xspace_paths, tool, options) + json_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = json_data elif tool == 'roofline_model': - json_data, success = xspace_wrapper_func(xspace_paths, tool, options) + json_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = json_data elif tool == 'graph_viewer': @@ -190,7 +182,8 @@ def xspace_to_tool_data( graph_html_type = 'graph' options = params.get('graph_viewer_options', {}) options['use_saved_result'] = params.get('use_saved_result', True) - raw_data, success = xspace_wrapper_func(xspace_paths, tool, options) + raw_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = raw_data content_type = 'text/plain' @@ -214,18 +207,21 @@ def xspace_to_tool_data( 'view_memory_allocation_timeline': view_memory_allocation_timeline, 'memory_space': params.get('memory_space', ''), } - raw_data, 
success = xspace_wrapper_func(xspace_paths, tool, options) + raw_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = raw_data if view_memory_allocation_timeline: content_type = 'text/html' elif tool == 'megascale_stats': options = {'host_name': params.get('host')} - json_data, success = xspace_wrapper_func(xspace_paths, tool, options) + json_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = json_data elif tool == 'inference_profile': - json_data, success = xspace_wrapper_func(xspace_paths, tool, options) + json_data, success = xspace_wrapper_func( + xspace_paths, all_hosts, tool, options) if success: data = json_data else: diff --git a/plugin/xprof/convert/raw_to_tool_data_test.py b/plugin/xprof/convert/raw_to_tool_data_test.py index a0a3df88d..95e6e1ca3 100644 --- a/plugin/xprof/convert/raw_to_tool_data_test.py +++ b/plugin/xprof/convert/raw_to_tool_data_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== - """Tests for the raw_to_tool_data module.""" import tensorflow as tf @@ -27,7 +26,11 @@ def test_using_old_tool_format_maps_to_new_format(self): xspace_paths=["/path/to/xspace"], tool="trace_viewer@^", params={}, - xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True), + all_hosts=[], + xspace_wrapper_func=lambda paths, hosts, tool, options: ( + tool.encode(), + True, + ), ) self.assertEqual(data, b"trace_viewer@") @@ -38,7 +41,11 @@ def test_using_new_tool_format_does_not_map_to_old_format(self): xspace_paths=["/path/to/xspace"], tool="trace_viewer@", params={}, - xspace_wrapper_func=lambda paths, tool, options: (tool.encode(), True), + all_hosts=[], + xspace_wrapper_func=lambda paths, hosts, tool, options: ( + tool.encode(), + True, + ), ) self.assertEqual(data, b"trace_viewer@") diff --git a/plugin/xprof/profile_plugin.py b/plugin/xprof/profile_plugin.py index b949ea5a8..0c5a4a899 100644 --- a/plugin/xprof/profile_plugin.py +++ b/plugin/xprof/profile_plugin.py @@ -380,23 +380,6 @@ def filenames_to_hosts(filenames: list[str], tool: str) -> list[str]: return sorted(hosts) -def validate_xplane_asset_paths(asset_paths: List[str]) -> None: - """Validates that all xplane asset paths that are provided are valid files. - - Args: - asset_paths: A list of asset paths. - - Raises: - FileNotFoundError: If any of the xplane asset paths do not exist. 
- """ - for asset_path in asset_paths: - if ( - str(asset_path).endswith(TOOLS['xplane']) - and not epath.Path(asset_path).exists() - ): - raise FileNotFoundError(f'Invalid asset path: {asset_path}') - - def _get_bool_arg( args: Mapping[str, Any], arg_name: str, default: bool ) -> bool: @@ -511,6 +494,10 @@ def is_active(self) -> bool: self._is_active = any(self.generate_runs()) return self._is_active + def _does_tool_support_multi_hosts_processing(self, tool: str) -> bool: + """Returns true if the tool supports multi-hosts processing.""" + return tool == 'trace_viewer@' or tool == 'trace_viewer' + def get_plugin_apps( self, ) -> dict[str, Callable[[wrappers.Request], wrappers.Response]]: @@ -718,6 +705,85 @@ def hlo_module_list_route( module_names_str = self.hlo_module_list_impl(request) return respond(module_names_str, 'text/plain') + def _get_valid_hosts( + self, run_dir: str, run: str, tool: str, hosts_param: str, host: str + ) -> tuple[List[str], List[epath.Path], List[str]]: + """Retrieves and validates the hosts and asset paths for a run and tool. + + Args: + run_dir: The run directory. + run: The frontend run name. + tool: The requested tool. + hosts_param: Comma-separated list of selected hosts. + host: The single host parameter. + + Returns: + A tuple containing (selected_hosts, asset_paths). + + Raises: + FileNotFoundError: If a required xplane file for the specified host(s) + is not found. + IOError: If there is an error reading asset directories. + """ + asset_paths = [] + selected_hosts = [] + all_xplane_files = {} # Map host to path + + # Find all available xplane files for the run and map them by host. 
+ file_pattern = make_filename('*', 'xplane') + try: + path = epath.Path(run_dir) + for xplane_path in path.glob(file_pattern): + host_name, _ = _parse_filename(xplane_path.name) + if host_name: + print('host_name: %s', host_name) + all_xplane_files[host_name] = xplane_path + except OSError as e: + print('Error') + logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e) + raise IOError( + 'Cannot read asset directory: %s, OpError %s' % (run_dir, e) + ) from e + + if hosts_param and self._does_tool_support_multi_hosts_processing(tool): + selected_hosts = hosts_param.split(',') + for selected_host in selected_hosts: + if selected_host in all_xplane_files: + asset_paths.append(all_xplane_files[selected_host]) + else: + raise FileNotFoundError( + 'No xplane file found for host: %s in run: %s' + % (selected_host, run) + ) + logger.info('Inside trace_viewer@, asset_paths: %s') + elif host == ALL_HOSTS: + asset_paths = list(all_xplane_files.values()) + selected_hosts = list(all_xplane_files.keys()) + elif host and host in all_xplane_files: + selected_hosts = [host] + asset_paths = [all_xplane_files[host]] + elif host: + logger.warning('No xplane file found for host: %s in run: %s', host, run) + if host not in XPLANE_TOOLS_ALL_HOSTS_ONLY: + raise FileNotFoundError( + 'No xplane file found for host: %s in run: %s' % (host, run) + ) + + if not asset_paths: + logger.warning( + 'No matching asset paths found for run %s, tool %s, host(s) %s / %s', + run, + tool, + hosts_param, + host, + ) + if not host and tool not in XPLANE_TOOLS_ALL_HOSTS_ONLY: + raise FileNotFoundError( + 'Host must be specified for tool %s in run %s' % (tool, run) + ) + + return selected_hosts, asset_paths, list(all_xplane_files.keys()) + def data_impl( self, request: wrappers.Request ) -> tuple[Optional[str], str, Optional[str]]: @@ -729,9 +795,17 @@ def data_impl( Returns: A string that can be served to the frontend tool or None if tool, run or host is invalid. 
+ + Raises: + FileNotFoundError: If a required xplane file for the specified host(s) + is not found. + IOError: If there is an error reading asset directories. + AttributeError: If there is an error during xplane to tool data conversion + ValueError: If xplane conversion fails due to invalid data. """ run = request.args.get('run') tool = request.args.get('tag') + hosts_param = request.args.get('hosts') host = request.args.get('host') module_name = request.args.get('module_name') tqx = request.args.get('tqx') @@ -795,28 +869,19 @@ def data_impl( options['search_prefix'] = request.args.get('search_prefix') params['trace_viewer_options'] = options - asset_path = os.path.join(run_dir, make_filename(host, tool)) - _, content_encoding = None, None if use_xplane(tool): - if host == ALL_HOSTS: - file_pattern = make_filename('*', 'xplane') - try: - path = epath.Path(run_dir) - asset_paths = list(path.glob(file_pattern)) - except OSError as e: - logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, - e) - raise IOError( - 'Cannot read asset directory: %s, OpError %s' % (run_dir, e) - ) from e - else: - asset_paths = [asset_path] + selected_hosts, asset_paths, all_hosts = self._get_valid_hosts( + run_dir, run, tool, hosts_param, host + ) + if not asset_paths: + return None, content_type, None + params['hosts'] = selected_hosts try: - validate_xplane_asset_paths(asset_paths) data, content_type = convert.xspace_to_tool_data( - asset_paths, tool, params) + asset_paths, all_hosts, tool, params + ) except AttributeError as e: logger.warning('Error generating analysis results due to %s', e) raise AttributeError( diff --git a/plugin/xprof/profile_plugin_test.py b/plugin/xprof/profile_plugin_test.py index 0f44db974..30dba19e8 100644 --- a/plugin/xprof/profile_plugin_test.py +++ b/plugin/xprof/profile_plugin_test.py @@ -330,7 +330,9 @@ def testData(self): with self.assertRaises(FileNotFoundError): self.plugin.data_impl( utils.make_data_request( - 
utils.DataRequestOptions(run='a', tool='trace_viewer', host='') + utils.DataRequestOptions( + run='a/foo', tool='trace_viewer', host='' + ) ) ) @@ -445,6 +447,7 @@ def testDataImplTraceViewerOptions(self, mock_xspace_to_tool_data): 'start_time_ms': '100', 'end_time_ms': '200', }, + 'hosts': ['host1'], } _, _, _ = self.plugin.data_impl( @@ -462,8 +465,11 @@ def testDataImplTraceViewerOptions(self, mock_xspace_to_tool_data): ) mock_xspace_to_tool_data.assert_called_once_with( - [expected_asset_path], 'trace_viewer@', expected_params + [mock.ANY], ['host0', 'host1'], 'trace_viewer@', expected_params ) + actual_path_list = mock_xspace_to_tool_data.call_args[0][0] + self.assertLen(actual_path_list, 1) + self.assertEqual(str(actual_path_list[0]), expected_asset_path) def testActive(self): @@ -535,8 +541,10 @@ def test_generate_runs_from_path_params_with_run_path(self): # run3 is a file, not a directory, and should be ignored. with open(os.path.join(run_path, 'run3'), 'w') as f: f.write('dummy file') + with open(os.path.join(run2_path, 'host2.xplane.pb'), 'w') as f: + f.write('dummy xplane data for run2') runs = list(self.plugin._generate_runs_from_path_params(run_path=run_path)) - self.assertListEqual(['run1'], runs) + self.assertListEqual(['run1', 'run2'], sorted(runs)) self.assertEqual(run_path, self.plugin.logdir) def test_runs_impl_with_session(self): diff --git a/xprof/convert/BUILD b/xprof/convert/BUILD index 87b0426a1..3dc218a04 100644 --- a/xprof/convert/BUILD +++ b/xprof/convert/BUILD @@ -190,6 +190,7 @@ cc_library( ":repository", ":tool_options", ":xplane_to_trace_container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1592,12 +1593,14 @@ cc_library( srcs = ["xplane_to_trace_container.cc"], hdrs = ["xplane_to_trace_container.h"], deps = [ + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", 
"@org_xprof//plugin/xprof/protobuf:trace_events_proto_cc", "@org_xprof//plugin/xprof/protobuf:trace_events_raw_proto_cc", "@org_xprof//xprof/convert/trace_viewer:trace_event_arguments_builder", "@org_xprof//xprof/convert/trace_viewer:trace_events", "@org_xprof//xprof/convert/trace_viewer:trace_events_util", + "@org_xprof//xprof/convert/trace_viewer:trace_utils", "@tsl//tsl/profiler/protobuf:xplane_proto_cc", "@xla//xla/tsl/profiler/utils:tf_xplane_visitor", "@xla//xla/tsl/profiler/utils:timespan", diff --git a/xprof/convert/repository.cc b/xprof/convert/repository.cc index f98a76749..e6887e9e0 100644 --- a/xprof/convert/repository.cc +++ b/xprof/convert/repository.cc @@ -58,7 +58,8 @@ static auto* kHostDataSuffixes = absl::StatusOr SessionSnapshot::Create( std::vector xspace_paths, - std::optional>> xspaces) { + std::optional>> xspaces, + std::optional> all_hosts) { if (xspace_paths.empty()) { return absl::InvalidArgumentError("Can not find XSpace path."); } @@ -85,7 +86,26 @@ absl::StatusOr SessionSnapshot::Create( } } - return SessionSnapshot(std::move(xspace_paths), std::move(xspaces)); + return SessionSnapshot(std::move(xspace_paths), std::move(xspaces), + std::move(all_hosts)); +} + +SessionSnapshot::SessionSnapshot( + std::vector xspace_paths, + std::optional>> xspaces, + std::optional> all_hosts) + : xspace_paths_(std::move(xspace_paths)), + all_hosts_(std::move(all_hosts)), + // If the snapshot was initialized by xspaces, the file path and run dir + // is a path tensorflow can't read from or write to so any file IO + // encapsulated in this class will be disabled in this mode. 
+ has_accessible_run_dir_(!xspaces.has_value()), + xspaces_(std::move(xspaces)) { + session_run_dir_ = tsl::io::Dirname(xspace_paths_.at(0)); + for (size_t i = 0; i < xspace_paths_.size(); ++i) { + std::string host_name = GetHostname(i); + hostname_map_[host_name] = i; + } } absl::StatusOr SessionSnapshot::GetXSpace(size_t index, @@ -126,6 +146,10 @@ std::string SessionSnapshot::GetHostname(size_t index) const { return GetHostnameByPath(xspace_paths_.at(index)); } +std::optional> SessionSnapshot::GetAllHosts() const { + return all_hosts_; +} + std::optional SessionSnapshot::GetFilePath( absl::string_view toolname, absl::string_view hostname) const { if (!has_accessible_run_dir_) return std::nullopt; diff --git a/xprof/convert/repository.h b/xprof/convert/repository.h index 07b649378..46d7d40f4 100644 --- a/xprof/convert/repository.h +++ b/xprof/convert/repository.h @@ -8,7 +8,7 @@ You may obtain a copy of the License at Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ @@ -59,7 +59,8 @@ class SessionSnapshot { // profiler plugin. static absl::StatusOr Create( std::vector xspace_paths, - std::optional>> xspaces); + std::optional>> xspaces, + std::optional> all_hosts = std::nullopt); // Returns the number of XSpaces in the profile session. size_t XSpaceSize() const { return xspace_paths_.size(); } @@ -76,6 +77,9 @@ class SessionSnapshot { // Gets host name. std::string GetHostname(size_t index) const; + // Gets all host names. + std::optional> GetAllHosts() const; + // Gets the run directory of the profile session. 
absl::string_view GetSessionRunDir() const { return session_run_dir_; } @@ -142,22 +146,15 @@ class SessionSnapshot { private: SessionSnapshot(std::vector xspace_paths, - std::optional>> xspaces) - : xspace_paths_(std::move(xspace_paths)), - // If the snapshot was initialized by xspaces, the file path and run dir - // is a path tensorflow can't read from or write to so any file IO - // encapsulated in this class will be disabled in this mode. - has_accessible_run_dir_(!xspaces.has_value()), - xspaces_(std::move(xspaces)) { - session_run_dir_ = tsl::io::Dirname(xspace_paths_.at(0)); - for (size_t i = 0; i < xspace_paths_.size(); ++i) { - std::string host_name = GetHostname(i); - hostname_map_[host_name] = i; - } - } + std::optional>> xspaces, + std::optional> all_hosts); // File paths to XSpace protos. std::vector xspace_paths_; + + // All hosts in the session. + std::optional> all_hosts_; + // The run directory of the profile session. absl::string_view session_run_dir_; diff --git a/xprof/convert/streaming_trace_viewer_processor.cc b/xprof/convert/streaming_trace_viewer_processor.cc index 954885cf7..c872d3431 100644 --- a/xprof/convert/streaming_trace_viewer_processor.cc +++ b/xprof/convert/streaming_trace_viewer_processor.cc @@ -5,6 +5,7 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/numbers.h" @@ -67,65 +68,78 @@ absl::StatusOr GetTraceViewOption(const ToolOptions& options) { absl::Status StreamingTraceViewerProcessor::ProcessSession( const SessionSnapshot& session_snapshot, const ToolOptions& options) { - if (session_snapshot.XSpaceSize() != 1) { - return tsl::errors::InvalidArgument( - "Trace events tool expects only 1 XSpace path but gets ", - session_snapshot.XSpaceSize()); - } - - google::protobuf::Arena arena; - TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(0, &arena)); - PreprocessSingleHostXSpace(xspace, /*step_grouping=*/true, - 
/*derived_timeline=*/true); - + TraceEventsContainer merged_trace_container; std::string tool_name = "trace_viewer@"; - std::string trace_viewer_json; - std::string host_name = session_snapshot.GetHostname(0); - auto sstable_path = session_snapshot.GetFilePath(tool_name, host_name); - if (!sstable_path) { - return tsl::errors::Unimplemented( - "streaming trace viewer hasn't been supported in Cloud AI"); - } - if (!tsl::Env::Default()->FileExists(*sstable_path).ok()) { - ProcessMegascaleDcn(xspace); - TraceEventsContainer trace_container; - ConvertXSpaceToTraceEventsContainer(host_name, *xspace, &trace_container); - std::unique_ptr file; - TF_RETURN_IF_ERROR( - tsl::Env::Default()->NewWritableFile(*sstable_path, &file)); - TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTable(std::move(file))); - } TF_ASSIGN_OR_RETURN(TraceViewOption trace_option, GetTraceViewOption(options)); tensorflow::profiler::TraceOptions profiler_trace_options = TraceOptionsFromToolOptions(options); - auto visibility_filter = std::make_unique( - tsl::profiler::MilliSpan(trace_option.start_time_ms, - trace_option.end_time_ms), - trace_option.resolution, profiler_trace_options); - TraceEventsContainer trace_container; - // Trace smaller than threshold will be disabled from streaming. - constexpr int64_t kDisableStreamingThreshold = 500000; - auto trace_events_filter = - CreateTraceEventsFilterFromTraceOptions(profiler_trace_options); - TraceEventsLevelDbFilePaths file_paths; - file_paths.trace_events_file_path = *sstable_path; - TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable( - file_paths, std::move(trace_events_filter), std::move(visibility_filter), - kDisableStreamingThreshold)); + + absl::flat_hash_map host_to_id_map; + if (auto all_hosts = session_snapshot.GetAllHosts()) { + for (int i = 0; i < all_hosts->size(); ++i) { + host_to_id_map[(*all_hosts)[i]] = i; + } + } + + // TODO(b/452217676) : Optimize this to process hosts in parallel. 
+ for (int i = 0; i < session_snapshot.XSpaceSize(); ++i) { + int host_id = host_to_id_map[session_snapshot.GetHostname(i)]; + LOG(INFO) << "Processing host: " << session_snapshot.GetHostname(i) + << " with host_id: " << host_id; + google::protobuf::Arena arena; + TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(i, &arena)); + PreprocessSingleHostXSpace(xspace, /*step_grouping=*/true, + /*derived_timeline=*/true); + + std::string host_name = session_snapshot.GetHostname(i); + auto sstable_path = session_snapshot.GetFilePath(tool_name, host_name); + if (!sstable_path) { + return tsl::errors::Unimplemented( + "streaming trace viewer hasn't been supported in Cloud AI"); + } + if (!tsl::Env::Default()->FileExists(*sstable_path).ok()) { + ProcessMegascaleDcn(xspace); + TraceEventsContainer trace_container; + ConvertXSpaceToTraceEventsContainer(host_name, host_id, *xspace, + &trace_container); + std::unique_ptr file; + TF_RETURN_IF_ERROR( + tsl::Env::Default()->NewWritableFile(*sstable_path, &file)); + TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTable(std::move(file))); + } + + auto visibility_filter = std::make_unique( + tsl::profiler::MilliSpan(trace_option.start_time_ms, + trace_option.end_time_ms), + trace_option.resolution, profiler_trace_options); + TraceEventsContainer trace_container; + // Trace smaller than threshold will be disabled from streaming. 
+ constexpr int64_t kDisableStreamingThreshold = 500000; + auto trace_events_filter = + CreateTraceEventsFilterFromTraceOptions(profiler_trace_options); + TraceEventsLevelDbFilePaths file_paths; + file_paths.trace_events_file_path = *sstable_path; + TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable( + file_paths, std::move(trace_events_filter), + std::move(visibility_filter), kDisableStreamingThreshold)); + merged_trace_container.Merge(std::move(trace_container)); + } + + std::string trace_viewer_json; JsonTraceOptions json_trace_options; tensorflow::profiler::TraceDeviceType device_type = tensorflow::profiler::TraceDeviceType::kUnknownDevice; - if (IsTpuTrace(trace_container.trace())) { + if (IsTpuTrace(merged_trace_container.trace())) { device_type = TraceDeviceType::kTpu; } json_trace_options.details = TraceOptionsToDetails(device_type, profiler_trace_options); IOBufferAdapter adapter(&trace_viewer_json); TraceEventsToJson( - json_trace_options, trace_container, &adapter); + json_trace_options, merged_trace_container, &adapter); SetOutput(trace_viewer_json, "application/json"); return absl::OkStatus(); diff --git a/xprof/convert/trace_viewer/BUILD b/xprof/convert/trace_viewer/BUILD index d0020296e..54ca46f03 100644 --- a/xprof/convert/trace_viewer/BUILD +++ b/xprof/convert/trace_viewer/BUILD @@ -133,6 +133,7 @@ cc_library( ":trace_events_filter_interface", ":trace_events_util", ":trace_viewer_visibility", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:endian", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xprof/convert/trace_viewer/trace_events.cc b/xprof/convert/trace_viewer/trace_events.cc index fa5bce32d..650ac914c 100644 --- a/xprof/convert/trace_viewer/trace_events.cc +++ b/xprof/convert/trace_viewer/trace_events.cc @@ -19,13 +19,17 @@ limitations under the License. 
#include #include #include +#include #include #include #include #include +#include #include +#include "absl/algorithm/container.h" #include "absl/base/internal/endian.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -77,6 +81,15 @@ void MaybeAddEventUniqueId(std::vector& events) { } } +// Appends all events from src into dst. +inline void AppendEvents(TraceEventTrack&& src, TraceEventTrack* dst) { + if (dst->empty()) { + *dst = std::move(src); + } else { + absl::c_move(src, std::back_inserter(*dst)); + } +} + } // namespace TraceEvent::EventType GetTraceEventType(const TraceEvent& event) { @@ -293,5 +306,51 @@ void PurgeIrrelevantEntriesInTraceNameTable( trace.mutable_name_table()->swap(new_name_table); } +template +void TraceEventsContainerBase::MergeTrace( + const Trace& other_trace) { + trace_.mutable_tasks()->insert(other_trace.tasks().begin(), + other_trace.tasks().end()); + trace_.mutable_name_table()->insert(other_trace.name_table().begin(), + other_trace.name_table().end()); + if (other_trace.has_min_timestamp_ps() && + other_trace.has_max_timestamp_ps()) { + ExpandTraceSpan(TraceSpan(other_trace), &trace_); + } + trace_.set_num_events(trace_.num_events() + other_trace.num_events()); +} + +template +void TraceEventsContainerBase::Merge( + TraceEventsContainerBase&& other) { + if (this == &other) return; + if (other.NumEvents() == 0 && other.trace().devices().empty()) return; + + auto& this_device_map = *trace_.mutable_devices(); + for (const auto& [other_id, other_device] : other.trace().devices()) { + this_device_map.insert({other_id, other_device}); + } + + other.ForAllMutableTracks([this](uint32_t other_device_id, + ResourceValue resource_id_or_counter_name, + TraceEventTrack* track) { + DeviceEvents& device = this->events_by_device_[other_device_id]; + if (uint64_t* resource_id = + std::get_if(&resource_id_or_counter_name)) { + AppendEvents(std::move(*track), 
&device.events_by_resource[*resource_id]); + } else if (absl::string_view* counter_name = std::get_if( + &resource_id_or_counter_name)) { + AppendEvents(std::move(*track), + &device.counter_events_by_name[*counter_name]); + } + }); + + MergeTrace(other.trace()); + arenas_.insert(other.arenas_.begin(), other.arenas_.end()); +} + +// Explicit instantiations for the common case. +template class TraceEventsContainerBase; + } // namespace profiler } // namespace tensorflow diff --git a/xprof/convert/trace_viewer/trace_events.h b/xprof/convert/trace_viewer/trace_events.h index 99a68d3b9..d42589559 100644 --- a/xprof/convert/trace_viewer/trace_events.h +++ b/xprof/convert/trace_viewer/trace_events.h @@ -729,6 +729,8 @@ class TraceEventsContainerBase { TraceEventsContainerBase(const TraceEventsContainerBase&) = delete; TraceEventsContainerBase& operator=(const TraceEventsContainerBase&) = delete; + void Merge(TraceEventsContainerBase&& other); + // Creates a TraceEvent prefilled with the given values. void AddCompleteEvent(absl::string_view name, uint64_t resource_id, uint32_t device_id, tsl::profiler::Timespan timespan, @@ -1075,6 +1077,9 @@ class TraceEventsContainerBase { return copy; } + // Helper function to merge top-level trace metadata. + void MergeTrace(const Trace& other_trace); + // Adds an event from arenas_ to events_by_device_. 
void AddArenaEvent(TraceEvent* event) { ExpandTraceSpan(EventSpan(*event), &trace_); diff --git a/xprof/convert/trace_viewer/trace_utils.h b/xprof/convert/trace_viewer/trace_utils.h index 783fd8420..eaa9b0d3d 100644 --- a/xprof/convert/trace_viewer/trace_utils.h +++ b/xprof/convert/trace_viewer/trace_utils.h @@ -37,6 +37,8 @@ inline bool MaybeTpuNonCoreDeviceName(absl::string_view device_name) { IsTpuIciRouterDeviceName(device_name)); } +static constexpr int kMaxDevicesPerHost = 1000; + } // namespace profiler } // namespace tensorflow diff --git a/xprof/convert/xplane_to_tools_data.cc b/xprof/convert/xplane_to_tools_data.cc index 4ecce1108..986fa6a70 100644 --- a/xprof/convert/xplane_to_tools_data.cc +++ b/xprof/convert/xplane_to_tools_data.cc @@ -177,7 +177,10 @@ absl::StatusOr ConvertXSpaceToTraceEvents( if (!tsl::Env::Default()->FileExists(*trace_events_sstable_path).ok()) { ProcessMegascaleDcn(xspace); TraceEventsContainer trace_container; - ConvertXSpaceToTraceEventsContainer(host_name, *xspace, &trace_container); + // No-op method which will be deprecated in the future, thus added + // /*host_id=*/1 as a placeholder for now. + ConvertXSpaceToTraceEventsContainer(host_name, 1, *xspace, + &trace_container); std::unique_ptr trace_events_file; TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile( *trace_events_sstable_path, &trace_events_file)); diff --git a/xprof/convert/xplane_to_trace_container.cc b/xprof/convert/xplane_to_trace_container.cc index aba9c186c..29919bd1f 100644 --- a/xprof/convert/xplane_to_trace_container.cc +++ b/xprof/convert/xplane_to_trace_container.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/base/optimization.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" @@ -32,6 +33,7 @@ limitations under the License. 
#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "xprof/convert/trace_viewer/trace_event_arguments_builder.h" #include "xprof/convert/trace_viewer/trace_events_util.h" +#include "xprof/convert/trace_viewer/trace_utils.h" #include "plugin/xprof/protobuf/trace_events.pb.h" #include "plugin/xprof/protobuf/trace_events_raw.pb.h" @@ -219,7 +221,7 @@ void ConvertXPlaneToTraceEventsContainer(uint64_t device_id, } // namespace void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname, - const XSpace& space, + int host_id, const XSpace& space, TraceEventsContainer* container) { const XPlane* host_plane = FindPlaneWithName(space, tsl::profiler::kHostThreadsPlaneName); @@ -236,9 +238,13 @@ void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname, } for (const XPlane* device_plane : device_planes) { - ConvertXPlaneToTraceEventsContainer( - tsl::profiler::kFirstDeviceId + device_plane->id(), hostname, - *device_plane, container); + uint32_t device_pid = tsl::profiler::kFirstDeviceId + device_plane->id(); + if (ABSL_PREDICT_FALSE(device_pid > tsl::profiler::kLastDeviceId)) { + device_pid = tsl::profiler::kFirstDeviceId; + } + uint32_t final_device_pid = (host_id)*kMaxDevicesPerHost + device_pid; + ConvertXPlaneToTraceEventsContainer(final_device_pid, hostname, + *device_plane, container); } for (const XPlane* custom_plane : FindPlanesWithPrefix(space, tsl::profiler::kCustomPlanePrefix)) { diff --git a/xprof/convert/xplane_to_trace_container.h b/xprof/convert/xplane_to_trace_container.h index c0f0272b3..01af9ad94 100644 --- a/xprof/convert/xplane_to_trace_container.h +++ b/xprof/convert/xplane_to_trace_container.h @@ -28,6 +28,7 @@ using TraceEventsContainer = TraceEventsContainerBase; // Converts XEvents within the XSpace into trace_viewer events container. 
void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname, + int host_id, const XSpace& xspace, TraceEventsContainer* container); diff --git a/xprof/convert/xplane_to_trace_container_test.cc b/xprof/convert/xplane_to_trace_container_test.cc index 4e0a5cf3e..cd9f97618 100644 --- a/xprof/convert/xplane_to_trace_container_test.cc +++ b/xprof/convert/xplane_to_trace_container_test.cc @@ -96,7 +96,7 @@ TEST(XPlaneToTraceContainerTest, CounterLine) { tsl::profiler::UniToNano(1), tsl::profiler::UniToNano(500)), &xspace)); TraceEventsContainer container; - ConvertXSpaceToTraceEventsContainer("localhost", xspace, &container); + ConvertXSpaceToTraceEventsContainer("localhost", 0, xspace, &container); absl::flat_hash_map> counter_offset_to_values; container.ForAllEvents([&counter_offset_to_values](const TraceEvent& event) { @@ -142,7 +142,7 @@ TEST(XPlaneToTraceContainerTest, AsyncLine) { tsl::profiler::kXlaAsyncOpLineName, kAsyncOpEventName), &xspace)); TraceEventsContainer container; - ConvertXSpaceToTraceEventsContainer("localhost", xspace, &container); + ConvertXSpaceToTraceEventsContainer("localhost", 0, xspace, &container); bool async_event_found = false; container.ForAllEvents( [&async_event_found, &kAsyncOpEventName](const TraceEvent& event) { diff --git a/xprof/pywrap/_pywrap_profiler_plugin.pyi b/xprof/pywrap/_pywrap_profiler_plugin.pyi index 6be031344..84118493e 100644 --- a/xprof/pywrap/_pywrap_profiler_plugin.pyi +++ b/xprof/pywrap/_pywrap_profiler_plugin.pyi @@ -15,8 +15,8 @@ def monitor(arg0: str, arg1: int, arg2: int, arg3: bool) -> str: ... def trace(arg0: str, arg1: str, arg2: str, arg3: bool, arg4: int, arg5: int, arg6: dict) -> None: ... -def xspace_to_tools_data(arg0: list, arg1: str, arg2: dict = ...) -> tuple: ... -def xspace_to_tools_data_from_byte_string(arg0: list, arg1: list, arg2: str, arg3: dict) -> tuple: ... +def xspace_to_tools_data(arg0: list, arg1: list, arg2: str, arg3: dict = ...) -> tuple: ... 
+def xspace_to_tools_data_from_byte_string(arg0: list, arg1: list, arg2: list, arg3: str, arg4: dict) -> tuple: ... def start_grpc_server(port: int) -> None: ... def initialize_stubs(worker_service_addresses: str) -> None: ... diff --git a/xprof/pywrap/profiler_plugin_impl.cc b/xprof/pywrap/profiler_plugin_impl.cc index ccf5c2a43..145ade5b3 100644 --- a/xprof/pywrap/profiler_plugin_impl.cc +++ b/xprof/pywrap/profiler_plugin_impl.cc @@ -124,16 +124,18 @@ void StartGrpcServer(int port) { } absl::StatusOr> XSpaceToToolsData( - std::vector xspace_paths, const std::string& tool_name, + std::vector xspace_paths, + std::vector all_hosts, const std::string& tool_name, const ToolOptions& tool_options) { auto status_or_session_snapshot = SessionSnapshot::Create( - std::move(xspace_paths), /*xspaces=*/std::nullopt); + std::move(xspace_paths), /*xspaces=*/std::nullopt, all_hosts); return SessionSnapshotToToolsData(status_or_session_snapshot, tool_name, tool_options); } absl::StatusOr> XSpaceToToolsDataFromByteString( std::vector xspace_strings, + std::vector all_hosts, std::vector xspace_paths, const std::string& tool_name, const ToolOptions& tool_options) { std::vector> xspaces; @@ -154,7 +156,8 @@ absl::StatusOr> XSpaceToToolsDataFromByteString( } auto status_or_session_snapshot = - SessionSnapshot::Create(std::move(xspace_paths), std::move(xspaces)); + SessionSnapshot::Create(std::move(xspace_paths), std::move(xspaces), + all_hosts); return SessionSnapshotToToolsData(status_or_session_snapshot, tool_name, tool_options); } diff --git a/xprof/pywrap/profiler_plugin_impl.h b/xprof/pywrap/profiler_plugin_impl.h index 03b13b23a..9c7562b13 100644 --- a/xprof/pywrap/profiler_plugin_impl.h +++ b/xprof/pywrap/profiler_plugin_impl.h @@ -36,11 +36,13 @@ absl::Status Monitor(const char* service_addr, int duration_ms, tsl::string* result); absl::StatusOr> XSpaceToToolsData( - std::vector xspace_paths, const std::string& tool_name, + std::vector xspace_paths, + std::vector all_hosts, 
const std::string& tool_name, const tensorflow::profiler::ToolOptions& tool_options); absl::StatusOr> XSpaceToToolsDataFromByteString( std::vector xspace_strings, + std::vector all_hosts, std::vector xspace_paths, const std::string& tool_name, const tensorflow::profiler::ToolOptions& tool_options); diff --git a/xprof/pywrap/pywrap_profiler_plugin.cc b/xprof/pywrap/pywrap_profiler_plugin.cc index 7e4dd69f1..612c39de4 100644 --- a/xprof/pywrap/pywrap_profiler_plugin.cc +++ b/xprof/pywrap/pywrap_profiler_plugin.cc @@ -88,22 +88,27 @@ PYBIND11_MODULE(_pywrap_profiler_plugin, m) { m.def( "xspace_to_tools_data", - [](const py::list& xspace_path_list, const py::str& py_tool_name, - const py::dict options = py::dict()) { + [](const py::list& xspace_path_list, const py::list& all_hosts_list, + const py::str& py_tool_name, const py::dict options = py::dict()) { std::vector xspace_paths; xspace_paths.reserve(xspace_path_list.size()); for (py::handle obj : xspace_path_list) { std::string xspace_path = std::string(py::cast(obj)); xspace_paths.push_back(xspace_path); } + std::vector all_hosts; + all_hosts.reserve(all_hosts_list.size()); + for (py::handle obj : all_hosts_list) { + all_hosts.push_back(std::string(py::cast(obj))); + } std::string tool_name = std::string(py_tool_name); ToolOptions tool_options = ToolOptionsFromPythonDict(options); absl::StatusOr> result; { py::gil_scoped_release release; - result = xprof::pywrap::XSpaceToToolsData(xspace_paths, tool_name, - tool_options); + result = xprof::pywrap::XSpaceToToolsData(xspace_paths, all_hosts, + tool_name, tool_options); } if (!result.ok()) { @@ -112,18 +117,25 @@ PYBIND11_MODULE(_pywrap_profiler_plugin, m) { return py::make_tuple(py::bytes(result->first), py::bool_(result->second)); }, - py::arg(), py::arg(), py::arg() = py::dict()); + py::arg(), py::arg(), py::arg(), py::arg() = py::dict()); m.def( "xspace_to_tools_data_from_byte_string", - [](const py::list& xspace_string_list, const py::list& filenames_list, - const 
py::str& py_tool_name, const py::dict options = py::dict()) { + [](const py::list& xspace_string_list, const py::list& all_hosts_list, + const py::list& filenames_list, const py::str& py_tool_name, + const py::dict options = py::dict()) { std::vector xspace_strings; xspace_strings.reserve(xspace_string_list.size()); for (py::handle obj : xspace_string_list) { xspace_strings.push_back(std::string(py::cast(obj))); } + std::vector all_hosts; + all_hosts.reserve(all_hosts_list.size()); + for (py::handle obj : all_hosts_list) { + all_hosts.push_back(std::string(py::cast(obj))); + } + std::vector xspace_paths; xspace_paths.reserve(filenames_list.size()); for (py::handle obj : filenames_list) { @@ -138,7 +150,7 @@ PYBIND11_MODULE(_pywrap_profiler_plugin, m) { { py::gil_scoped_release release; result = xprof::pywrap::XSpaceToToolsDataFromByteString( - xspace_strings, xspace_paths, tool_name, tool_options); + xspace_strings, all_hosts, xspace_paths, tool_name, tool_options); } if (!result.ok()) { @@ -147,7 +159,7 @@ PYBIND11_MODULE(_pywrap_profiler_plugin, m) { return py::make_tuple(py::bytes(result->first), py::bool_(result->second)); }, - py::arg(), py::arg(), py::arg(), py::arg() = py::dict()); + py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::dict()); m.def("start_grpc_server", [](int port) { py::gil_scoped_release release;