Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 82 additions & 21 deletions xprof/convert/streaming_trace_viewer_processor.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "xprof/convert/streaming_trace_viewer_processor.h"

#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
Expand Down Expand Up @@ -46,6 +47,10 @@ struct TraceViewOption {
uint64_t resolution = 0;
double start_time_ms = 0.0;
double end_time_ms = 0.0;
std::string event_name = "";
std::string search_prefix = "";
double duration_ms = 0.0;
uint64_t unique_id = 0;
};

absl::StatusOr<TraceViewOption> GetTraceViewOption(const ToolOptions& options) {
Expand All @@ -56,10 +61,21 @@ absl::StatusOr<TraceViewOption> GetTraceViewOption(const ToolOptions& options) {
GetParamWithDefault<std::string>(options, "end_time_ms", "0.0");
auto resolution_opt =
GetParamWithDefault<std::string>(options, "resolution", "0");
trace_options.event_name =
GetParamWithDefault<std::string>(options, "event_name", "");
trace_options.search_prefix =
GetParamWithDefault<std::string>(options, "search_prefix", "");
auto duration_ms_opt =
GetParamWithDefault<std::string>(options, "duration_ms", "0.0");
auto unique_id_opt =
GetParamWithDefault<std::string>(options, "unique_id", "0");


if (!absl::SimpleAtoi(resolution_opt, &trace_options.resolution) ||
!absl::SimpleAtod(start_time_ms_opt, &trace_options.start_time_ms) ||
!absl::SimpleAtod(end_time_ms_opt, &trace_options.end_time_ms)) {
!absl::SimpleAtod(end_time_ms_opt, &trace_options.end_time_ms) ||
!absl::SimpleAtoi(unique_id_opt, &trace_options.unique_id) ||
!absl::SimpleAtod(duration_ms_opt, &trace_options.duration_ms)) {
return tsl::errors::InvalidArgument("wrong arguments");
}
return trace_options;
Expand All @@ -84,36 +100,81 @@ absl::Status StreamingTraceViewerProcessor::ProcessSession(
/*derived_timeline=*/true);

std::string host_name = session_snapshot.GetHostname(i);
auto sstable_path = session_snapshot.GetFilePath(tool_name, host_name);
if (!sstable_path) {
auto trace_events_sstable_path = session_snapshot.MakeHostDataFilePath(
tensorflow::profiler::StoredDataType::TRACE_LEVELDB, host_name);
auto trace_events_metadata_sstable_path =
session_snapshot.MakeHostDataFilePath(
tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB,
host_name);
auto trace_events_prefix_trie_sstable_path =
session_snapshot.MakeHostDataFilePath(
tensorflow::profiler::StoredDataType::
TRACE_EVENTS_PREFIX_TRIE_LEVELDB,
host_name);
if (!trace_events_sstable_path || !trace_events_metadata_sstable_path ||
!trace_events_prefix_trie_sstable_path) {
return tsl::errors::Unimplemented(
"streaming trace viewer hasn't been supported in Cloud AI");
}
if (!tsl::Env::Default()->FileExists(*sstable_path).ok()) {
if (!tsl::Env::Default()->FileExists(*trace_events_sstable_path).ok()) {
ProcessMegascaleDcn(xspace);
TraceEventsContainer trace_container;
ConvertXSpaceToTraceEventsContainer(host_name, *xspace,
&trace_container);
std::unique_ptr<tsl::WritableFile> file;
TF_RETURN_IF_ERROR(
tsl::Env::Default()->NewWritableFile(*sstable_path, &file));
TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTable(std::move(file)));
std::unique_ptr<tsl::WritableFile> trace_events_file;
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
*trace_events_sstable_path, &trace_events_file));
std::unique_ptr<tsl::WritableFile> trace_events_metadata_file;
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
*trace_events_metadata_sstable_path, &trace_events_metadata_file));
std::unique_ptr<tsl::WritableFile> trace_events_prefix_trie_file;
TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(
*trace_events_prefix_trie_sstable_path,
&trace_events_prefix_trie_file));
TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTables(
std::move(trace_events_file),
std::move(trace_events_metadata_file),
std::move(trace_events_prefix_trie_file)
));
}

auto visibility_filter = std::make_unique<TraceVisibilityFilter>(
tsl::profiler::MilliSpan(trace_option.start_time_ms,
trace_option.end_time_ms),
trace_option.resolution, profiler_trace_options);
TraceEventsContainer trace_container;
// Trace smaller than threshold will be disabled from streaming.
constexpr int64_t kDisableStreamingThreshold = 500000;
auto trace_events_filter =
CreateTraceEventsFilterFromTraceOptions(profiler_trace_options);
TraceEventsLevelDbFilePaths file_paths;
file_paths.trace_events_file_path = *sstable_path;
TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable(
file_paths, std::move(trace_events_filter),
std::move(visibility_filter), kDisableStreamingThreshold));
file_paths.trace_events_file_path = *trace_events_sstable_path;
file_paths.trace_events_metadata_file_path =
*trace_events_metadata_sstable_path;
file_paths.trace_events_prefix_trie_file_path =
*trace_events_prefix_trie_sstable_path;

TraceEventsContainer trace_container;
if (!trace_option.event_name.empty()) {
TF_RETURN_IF_ERROR(trace_container.ReadFullEventFromLevelDbTable(
*trace_events_metadata_sstable_path, *trace_events_sstable_path,
trace_option.event_name,
static_cast<uint64_t>(std::round(trace_option.start_time_ms * 1E9)),
static_cast<uint64_t>(std::round(trace_option.duration_ms * 1E9)),
trace_option.unique_id));
} else if (!trace_option.search_prefix.empty()) { // Search Events Request
if (tsl::Env::Default()
->FileExists(*trace_events_prefix_trie_sstable_path).ok()) {
auto trace_events_filter =
CreateTraceEventsFilterFromTraceOptions(profiler_trace_options);
TF_RETURN_IF_ERROR(trace_container.SearchInLevelDbTable(
file_paths,
trace_option.search_prefix, std::move(trace_events_filter)));
}
} else {
auto visibility_filter = std::make_unique<TraceVisibilityFilter>(
tsl::profiler::MilliSpan(trace_option.start_time_ms,
trace_option.end_time_ms),
trace_option.resolution, profiler_trace_options);
// Trace smaller than threshold will be disabled from streaming.
constexpr int64_t kDisableStreamingThreshold = 500000;
auto trace_events_filter =
CreateTraceEventsFilterFromTraceOptions(profiler_trace_options);
TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable(
file_paths, std::move(trace_events_filter),
std::move(visibility_filter), kDisableStreamingThreshold));
}
merged_trace_container.Merge(std::move(trace_container), host_id);
}

Expand Down