11#include " xprof/convert/streaming_trace_viewer_processor.h"
22
3+ #include < cmath>
34#include < cstdint>
45#include < memory>
56#include < string>
@@ -46,6 +47,10 @@ struct TraceViewOption {
4647 uint64_t resolution = 0 ;
4748 double start_time_ms = 0.0 ;
4849 double end_time_ms = 0.0 ;
50+ std::string event_name = " " ;
51+ std::string search_prefix = " " ;
52+ double duration_ms = 0.0 ;
53+ uint64_t unique_id = 0 ;
4954};
5055
5156absl::StatusOr<TraceViewOption> GetTraceViewOption (const ToolOptions& options) {
@@ -56,10 +61,21 @@ absl::StatusOr<TraceViewOption> GetTraceViewOption(const ToolOptions& options) {
5661 GetParamWithDefault<std::string>(options, " end_time_ms" , " 0.0" );
5762 auto resolution_opt =
5863 GetParamWithDefault<std::string>(options, " resolution" , " 0" );
64+ trace_options.event_name =
65+ GetParamWithDefault<std::string>(options, " event_name" , " " );
66+ trace_options.search_prefix =
67+ GetParamWithDefault<std::string>(options, " search_prefix" , " " );
68+ auto duration_ms_opt =
69+ GetParamWithDefault<std::string>(options, " duration_ms" , " 0.0" );
70+ auto unique_id_opt =
71+ GetParamWithDefault<std::string>(options, " unique_id" , " 0" );
72+
5973
6074 if (!absl::SimpleAtoi (resolution_opt, &trace_options.resolution ) ||
6175 !absl::SimpleAtod (start_time_ms_opt, &trace_options.start_time_ms ) ||
62- !absl::SimpleAtod (end_time_ms_opt, &trace_options.end_time_ms )) {
76+ !absl::SimpleAtod (end_time_ms_opt, &trace_options.end_time_ms ) ||
77+ !absl::SimpleAtoi (unique_id_opt, &trace_options.unique_id ) ||
78+ !absl::SimpleAtod (duration_ms_opt, &trace_options.duration_ms )) {
6379 return tsl::errors::InvalidArgument (" wrong arguments" );
6480 }
6581 return trace_options;
@@ -84,36 +100,81 @@ absl::Status StreamingTraceViewerProcessor::ProcessSession(
84100 /* derived_timeline=*/ true );
85101
86102 std::string host_name = session_snapshot.GetHostname (i);
87- auto sstable_path = session_snapshot.GetFilePath (tool_name, host_name);
88- if (!sstable_path) {
103+ auto trace_events_sstable_path = session_snapshot.MakeHostDataFilePath (
104+ tensorflow::profiler::StoredDataType::TRACE_LEVELDB, host_name);
105+ auto trace_events_metadata_sstable_path =
106+ session_snapshot.MakeHostDataFilePath (
107+ tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB,
108+ host_name);
109+ auto trace_events_prefix_trie_sstable_path =
110+ session_snapshot.MakeHostDataFilePath (
111+ tensorflow::profiler::StoredDataType::
112+ TRACE_EVENTS_PREFIX_TRIE_LEVELDB,
113+ host_name);
114+ if (!trace_events_sstable_path || !trace_events_metadata_sstable_path ||
115+ !trace_events_prefix_trie_sstable_path) {
89116 return tsl::errors::Unimplemented (
90117 " streaming trace viewer hasn't been supported in Cloud AI" );
91118 }
92- if (!tsl::Env::Default ()->FileExists (*sstable_path ).ok ()) {
119+ if (!tsl::Env::Default ()->FileExists (*trace_events_sstable_path ).ok ()) {
93120 ProcessMegascaleDcn (xspace);
94121 TraceEventsContainer trace_container;
95122 ConvertXSpaceToTraceEventsContainer (host_name, *xspace,
96123 &trace_container);
97- std::unique_ptr<tsl::WritableFile> file;
98- TF_RETURN_IF_ERROR (
99- tsl::Env::Default ()->NewWritableFile (*sstable_path, &file));
100- TF_RETURN_IF_ERROR (trace_container.StoreAsLevelDbTable (std::move (file)));
124+ std::unique_ptr<tsl::WritableFile> trace_events_file;
125+ TF_RETURN_IF_ERROR (tsl::Env::Default ()->NewWritableFile (
126+ *trace_events_sstable_path, &trace_events_file));
127+ std::unique_ptr<tsl::WritableFile> trace_events_metadata_file;
128+ TF_RETURN_IF_ERROR (tsl::Env::Default ()->NewWritableFile (
129+ *trace_events_metadata_sstable_path, &trace_events_metadata_file));
130+ std::unique_ptr<tsl::WritableFile> trace_events_prefix_trie_file;
131+ TF_RETURN_IF_ERROR (tsl::Env::Default ()->NewWritableFile (
132+ *trace_events_prefix_trie_sstable_path,
133+ &trace_events_prefix_trie_file));
134+ TF_RETURN_IF_ERROR (trace_container.StoreAsLevelDbTables (
135+ std::move (trace_events_file),
136+ std::move (trace_events_metadata_file),
137+ std::move (trace_events_prefix_trie_file)
138+ ));
101139 }
102140
103- auto visibility_filter = std::make_unique<TraceVisibilityFilter>(
104- tsl::profiler::MilliSpan (trace_option.start_time_ms ,
105- trace_option.end_time_ms ),
106- trace_option.resolution , profiler_trace_options);
107- TraceEventsContainer trace_container;
108- // Trace smaller than threshold will be disabled from streaming.
109- constexpr int64_t kDisableStreamingThreshold = 500000 ;
110- auto trace_events_filter =
111- CreateTraceEventsFilterFromTraceOptions (profiler_trace_options);
112141 TraceEventsLevelDbFilePaths file_paths;
113- file_paths.trace_events_file_path = *sstable_path;
114- TF_RETURN_IF_ERROR (trace_container.LoadFromLevelDbTable (
115- file_paths, std::move (trace_events_filter),
116- std::move (visibility_filter), kDisableStreamingThreshold ));
142+ file_paths.trace_events_file_path = *trace_events_sstable_path;
143+ file_paths.trace_events_metadata_file_path =
144+ *trace_events_metadata_sstable_path;
145+ file_paths.trace_events_prefix_trie_file_path =
146+ *trace_events_prefix_trie_sstable_path;
147+
148+ TraceEventsContainer trace_container;
149+ if (!trace_option.event_name .empty ()) {
150+ TF_RETURN_IF_ERROR (trace_container.ReadFullEventFromLevelDbTable (
151+ *trace_events_metadata_sstable_path, *trace_events_sstable_path,
152+ trace_option.event_name ,
153+ static_cast <uint64_t >(std::round (trace_option.start_time_ms * 1E9 )),
154+ static_cast <uint64_t >(std::round (trace_option.duration_ms * 1E9 )),
155+ trace_option.unique_id ));
156+ } else if (!trace_option.search_prefix .empty ()) { // Search Events Request
157+ if (tsl::Env::Default ()
158+ ->FileExists (*trace_events_prefix_trie_sstable_path).ok ()) {
159+ auto trace_events_filter =
160+ CreateTraceEventsFilterFromTraceOptions (profiler_trace_options);
161+ TF_RETURN_IF_ERROR (trace_container.SearchInLevelDbTable (
162+ file_paths,
163+ trace_option.search_prefix , std::move (trace_events_filter)));
164+ }
165+ } else {
166+ auto visibility_filter = std::make_unique<TraceVisibilityFilter>(
167+ tsl::profiler::MilliSpan (trace_option.start_time_ms ,
168+ trace_option.end_time_ms ),
169+ trace_option.resolution , profiler_trace_options);
170+ // Trace smaller than threshold will be disabled from streaming.
171+ constexpr int64_t kDisableStreamingThreshold = 500000 ;
172+ auto trace_events_filter =
173+ CreateTraceEventsFilterFromTraceOptions (profiler_trace_options);
174+ TF_RETURN_IF_ERROR (trace_container.LoadFromLevelDbTable (
175+ file_paths, std::move (trace_events_filter),
176+ std::move (visibility_filter), kDisableStreamingThreshold ));
177+ }
117178 merged_trace_container.Merge (std::move (trace_container), host_id);
118179 }
119180
0 commit comments