From 36306b5406b276891e78ae1238b823b88ac88991 Mon Sep 17 00:00:00 2001 From: Profiler Team Date: Thu, 2 Oct 2025 22:39:36 -0700 Subject: [PATCH] Support For Roofline Analysis Of Pallas Kernels PiperOrigin-RevId: 814542103 --- xprof/convert/BUILD | 1 + xprof/convert/xplane_to_op_stats.cc | 110 +++++++++ xprof/utils/BUILD | 24 +- xprof/utils/hlo_module_map.cc | 343 ++++++++++++++++++++++++++++ xprof/utils/hlo_module_map.h | 13 ++ 5 files changed, 488 insertions(+), 3 deletions(-) diff --git a/xprof/convert/BUILD b/xprof/convert/BUILD index 87b0426a1..601f5545a 100644 --- a/xprof/convert/BUILD +++ b/xprof/convert/BUILD @@ -884,6 +884,7 @@ cc_library( ":xplane_to_tf_functions", ":xprof_thread_pool_executor", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", diff --git a/xprof/convert/xplane_to_op_stats.cc b/xprof/convert/xplane_to_op_stats.cc index d644a4135..680087e34 100644 --- a/xprof/convert/xplane_to_op_stats.cc +++ b/xprof/convert/xplane_to_op_stats.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_set.h" @@ -66,6 +67,7 @@ limitations under the License. #include "xprof/utils/kernel_stats_utils.h" #include "xprof/utils/op_utils.h" #include "xprof/utils/xprof_gpu_cost_analysis_types.h" +#include "absl/container/flat_hash_map.h" namespace tensorflow { namespace profiler { @@ -77,6 +79,7 @@ using ::tsl::profiler::kGpuPlanePrefix; using ::tsl::profiler::kTpuPlanePrefix; using tsl::profiler::Timespan; using ::tsl::profiler::XPlaneBuilder; +using tsl::profiler::kXlaOpLineName; std::string Hostname(const XSpace& space) { if (space.hostnames().empty()) return "localhost"; @@ -380,6 +383,113 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, .has_value()) { op_metrics_db = ConvertTpuDeviceTraceXPlaneToOpMetricsDb(*device_plane); + XPlaneVisitor visitorSecond = + tsl::profiler::CreateTfXPlaneVisitor(device_plane); + std::queue custom_call_blocks; + visitorSecond.ForEachLine([&](const XLineVisitor& line) { + if (line.Name() == "XLA TraceMe") { + line.ForEachEvent([&](const XEventVisitor& event) { + if (absl::StartsWith(event.Name(), "__block_")){ + custom_call_blocks.push(event); + } + }); + } + }); + absl::flat_hash_map> + custom_call_to_block_count; + + XPlaneVisitor xlaEvents = + tsl::profiler::CreateTfXPlaneVisitor(device_plane); + xlaEvents.ForEachLine([&](const XLineVisitor& line) { + if (line.Name() == kXlaOpLineName) { + line.ForEachEvent([&](const XEventVisitor& event) { + tsl::profiler::Timespan custom_call_timespan = + GetDeviceEventTimespan(event); + bool custom_call = false; + event.Metadata().ForEachStat([&] + (const XStatVisitor& stat) { + if (stat.Type().has_value()) { + switch (static_cast(*stat.Type())) { + case StatType::kHloCategory: + custom_call = + (stat.StrOrRefValue() == "custom-call"); + break; + default: + break; + } + } + }); + if (custom_call){ + while (!custom_call_blocks.empty()){ + tsl::profiler::Timespan ccall_blck_timespan = + GetDeviceEventTimespan(custom_call_blocks.front()); + if ((custom_call_timespan.begin_ps() <= + ccall_blck_timespan.begin_ps()) && + (ccall_blck_timespan.end_ps() <= + custom_call_timespan.end_ps()) + ){ + custom_call_to_block_count[event.DisplayName()] + [std::string(custom_call_blocks.front(). + Name())] += 1; + custom_call_blocks.pop(); + }else{ + break; + } + } + } + }); + } + }); + for (OpMetrics& op_metrics : + *op_metrics_db.mutable_metrics_db()) { + const HloInstructionWrapper* instr_wrapper = + GetHloInstruction(hlo_module_map, + op_metrics.hlo_module_id(), op_metrics.name()); + if (instr_wrapper != nullptr) { + if (instr_wrapper->Category() == "custom-call"){ + uint64 total_flops = 0; + uint64 total_bytes_accessed = 0; + bool has_block_costs = + custom_call_to_block_count.contains(op_metrics.name()); + if (has_block_costs){ + for (auto&[block_name, occurrence] : + custom_call_to_block_count[op_metrics.name()]){ + auto block_cost_pair = instr_wrapper-> + GetCustomCallBlockCosts(block_name); + if (block_cost_pair.has_value()){ + OpMetrics* child_metric = + op_metrics.mutable_children()->add_metrics_db(); + child_metric->set_name(block_name); + child_metric->set_occurrences(occurrence); + child_metric->set_flops( + block_cost_pair.value().first); + child_metric->set_model_flops( + block_cost_pair.value().first); + child_metric->set_bytes_accessed( + block_cost_pair.value().second); + total_flops += + (occurrence*block_cost_pair.value().first); + total_bytes_accessed += + (occurrence*block_cost_pair.value().second); + }else{ + LOG(WARNING) << "No Costs Found for : " << block_name; + } + } + if (instr_wrapper->FusedChildren().empty()){ + LOG(INFO) << "Custom - Call Name: " + << op_metrics.name() << " Total Flops: " + << total_flops + << " Total Bytes Accessed: " << + total_bytes_accessed; + op_metrics.set_flops(total_flops); + op_metrics.set_bytes_accessed(total_bytes_accessed); + op_metrics.set_model_flops(total_flops); + } + } + } + } + } UpdateOpMetricsDbFromHloModuleMap(op_metrics_db, hlo_module_map); } } diff --git a/xprof/utils/BUILD b/xprof/utils/BUILD index f3faba97c..29bb6aaf9 100644 --- a/xprof/utils/BUILD +++ b/xprof/utils/BUILD @@ -398,23 +398,41 @@ cc_library( srcs = ["hlo_module_map.cc"], hdrs = ["hlo_module_map.h"], deps = [ + ":backend_configs_proto_cc", ":hlo_cost_analysis_wrapper", ":hlo_module_utils", ":hlo_proto_map", ":hlo_proto_to_module", ":performance_info_wrapper", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:path", - "@tsl//tsl/profiler/lib:traceme_encode", - "@tsl//tsl/profiler/protobuf:xplane_proto_cc", + "//third_party/llvm/llvm-project/mlir:ArithDialect", + "//third_party/llvm/llvm-project/mlir:DataLayoutInterfaces", + "//third_party/llvm/llvm-project/mlir:FuncDialect", + "//third_party/llvm/llvm-project/mlir:IR", + "//third_party/llvm/llvm-project/mlir:LLVMDialect", + "//third_party/llvm/llvm-project/mlir:MathDialect", + "//third_party/llvm/llvm-project/mlir:MemRefDialect", + "//third_party/llvm/llvm-project/mlir:Pass", + "//third_party/llvm/llvm-project/mlir:SCFDialect", + "//third_party/llvm/llvm-project/mlir:Support", + "//third_party/llvm/llvm-project/mlir:VectorDialect", + "//third_party/protobuf/json", + "//third_party/protobuf/util:json_util", + # "//third_party/py/jax/jaxlib/mosaic:tpu_dialect", "@xla//xla/hlo/ir:hlo", + "@xla//xla/mlir_hlo", + "@xla//xla/pjrt:mlir_to_hlo", "@xla//xla/service:hlo_cost_analysis", "@xla//xla/service:hlo_proto_cc", "@xla//xla/tsl/profiler/convert:xla_op_utils", + "@tsl//tsl/platform:path", + "@tsl//tsl/profiler/lib:traceme_encode", + "@tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) diff --git a/xprof/utils/hlo_module_map.cc b/xprof/utils/hlo_module_map.cc index 26e58f4b1..3ae322682 100644 --- a/xprof/utils/hlo_module_map.cc +++ b/xprof/utils/hlo_module_map.cc @@ -20,29 +20,369 @@ limitations under the License. #include #include #include +#include +#include +#include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" +#include "google/protobuf/json/json.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" #include "tsl/platform/path.h" #include "tsl/profiler/lib/traceme_encode.h" #include "xprof/utils/hlo_cost_analysis_wrapper.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" #include "xprof/utils/hlo_module_utils.h" #include "xprof/utils/hlo_proto_map.h" #include "xprof/utils/hlo_proto_to_module.h" #include "xprof/utils/performance_info_wrapper.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinOps.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OwningOpRef.h" + +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinAttributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinTypeInterfaces.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Func/IR/FuncOps.h" +#include "xprof/utils/backend_configs.pb.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassManager.h" +// #include "third_party/py/jax/jaxlib/mosaic/dialect/tpu/transforms/serde.h" +// #include "third_party/py/jax/jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/IR/SCF.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Math/IR/Math.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Arith/IR/Arith.h" +#include "google/protobuf/util/json_util.h" + + +namespace CustomCallCostEstimator { + +// Represents the computational cost of an operation. +struct OperationCost { + uint64_t flops = 0; + uint64_t bytes_consumed = 0; +}; + +// Base class for operation cost estimators. +class OperationCostEstimator { + public: + virtual ~OperationCostEstimator() = default; + virtual OperationCost Estimate(mlir::Operation* op) const = 0; +}; + + // Estimator for element-wise operations. +class ElementWiseOpEstimator : public OperationCostEstimator { + public: + OperationCost Estimate(mlir::Operation* op) const override { + OperationCost cost; + if (op->getNumResults() > 0) { + mlir::Value result = op->getResult(0); + mlir::Type result_type = result.getType(); + const auto& data_layout = mlir::DataLayout::closest(op); + cost.bytes_consumed = data_layout.getTypeSize(result_type); + if (auto shaped_type = mlir::dyn_cast(result_type)) { + if (shaped_type.hasStaticShape()) { + cost.flops = shaped_type.getNumElements(); + } + // For dynamic shapes, we cannot calculate flops or memory statically. + } else { + cost.flops = 1; // Scalar type + } + } + return cost; + } +}; + + // Estimator for memory-only operations like constant or load. +class MemoryOnlyOpEstimator : public OperationCostEstimator { + public: + OperationCost Estimate(mlir::Operation* op) const override { + OperationCost cost; + if (op->getNumResults() > 0) { + mlir::Value result = op->getResult(0); + const auto& data_layout = mlir::DataLayout::closest(op); + cost.bytes_consumed = data_layout.getTypeSize(result.getType()); + } + return cost; + } +}; + +// Estimator for matrix multiplication operations. +class MatmulOpEstimator : public OperationCostEstimator { + public: + OperationCost Estimate(mlir::Operation* op) const override { + std::string op_str; + OperationCost cost; + // auto matmul_op = mlir::dyn_cast(op); + + // if (!matmul_op) { + // LOG(WARNING) << "Matmul Casting Failed"; + // return cost; + // } + + // if (op->getNumOperands() < 2 || op->getNumResults() != 1) { + // return cost; + // } + + // auto lhs_st = mlir::dyn_cast + // (op->getOperand(0).getType()); + // auto rhs_st = mlir::dyn_cast + // (op->getOperand(1).getType()); + // auto result_st = + // mlir::dyn_cast(op->getResult(0).getType()); + + // if (!lhs_st || !lhs_st.hasStaticShape() || !rhs_st || + // !rhs_st.hasStaticShape() || !result_st || + // !result_st.hasStaticShape()) { + // return cost; // Need static shapes for cost estimation. + // } + + // uint64_t contracting_size_prod = 1; + // auto dims = + // op->getAttrOfType( + // "dimension_numbers"); + // if (dims) { + // LOG(INFO) << "Found dimension_numbers attribute."; + // auto contracting_dims = dims.getLhsContractingDims(); + // if (contracting_dims.empty()) { + // LOG(WARNING) << "dimension_numbers + // found but has no contracting dims."; + // return cost; + // } + // for (int64_t d : contracting_dims) { + // if (d >= lhs_st.getRank()) return cost; // Invalid dim index. + // contracting_size_prod *= lhs_st.getShape()[d]; + // } + // } else { + // LOG(INFO) << "No dimension_numbers attribute, assuming contraction on " + // "last dim of LHS."; + // if (lhs_st.getRank() < 1) return cost; + // // Fallback for standard matmul: contract last dim of LHS. + // contracting_size_prod = lhs_st.getShape()[lhs_st.getRank() - 1]; + // } + + // cost.flops = 2 * result_st.getNumElements() * contracting_size_prod; + // const auto& data_layout = mlir::DataLayout::closest(op); + // cost.bytes_consumed = + // data_layout.getTypeSize(op->getResult(0).getType()); + // for (mlir::Value operand : op->getOperands()) { + // cost.bytes_consumed += data_layout.getTypeSize(operand.getType()); + // } + return cost; + } +}; + +// Estimator for multi-reduction operations. +class MultiReductionOpEstimator : public OperationCostEstimator { + public: + OperationCost Estimate(mlir::Operation* op) const override { + OperationCost cost; + if (op->getNumOperands() > 0) { + mlir::Value input = op->getOperand(0); + if (auto shaped_type = + mlir::dyn_cast(input.getType())) { + if (shaped_type.hasStaticShape()) { + cost.flops = shaped_type.getNumElements(); + } + } + } + if (op->getNumResults() > 0) { + mlir::Value result = op->getResult(0); + const auto& data_layout = mlir::DataLayout::closest(op); + cost.bytes_consumed = data_layout.getTypeSize(result.getType()); + } + return cost; + } +}; + +// Estimator for store operations like vector.store or tpu.store. +class StoreOpEstimator : public OperationCostEstimator { + public: + OperationCost Estimate(mlir::Operation* op) const override { + OperationCost cost; + if (op->getNumOperands() > 0) { + mlir::Value input = op->getOperand(0); + const auto& data_layout = mlir::DataLayout::closest(op); + cost.bytes_consumed = data_layout.getTypeSize(input.getType()); + } + return cost; + } +}; + +// A singleton class to manage and dispatch cost estimations for MLIR ops. +class CostModel { + public: + static const CostModel& GetInstance() { + static absl::NoDestructor instance; + return *instance; + } + + OperationCost GetOperationCost(mlir::Operation* op) const { + auto it = estimators_.find(op->getName().getStringRef()); + if (it != estimators_.end()) { + return it->second->Estimate(op); + } + // Return zero cost for unknown or no-cost ops like control flow, yield etc. + return {}; + } + + private: + friend class absl::NoDestructor; + CostModel() { + // Register estimators for different operation kinds. + // NOLINTBEGIN + RegisterEstimator( + {"arith.cmpi", "arith.extui", + "vector.broadcast", "arith.muli", // NOLINT + "arith.index_cast", "arith.maximumf", + "arith.subf", "math.exp", + "arith.addf", "arith.divf", + "arith.cmpf", "arith.select", + "arith.mulf", "arith.truncf", + "arith.extsi", "arith.trunci", "tpu.iota"}); + // NOLINTEND + RegisterEstimator( + {"arith.constant", "vector.load", "tpu.load", + "tpu.bitcast", "vector.shape_cast"}); + + RegisterEstimator({"tpu.matmul"}); + + RegisterEstimator( + {"vector.multi_reduction"}); + RegisterEstimator({"tpu.store", "vector.store"}); + } + + template + void RegisterEstimator(std::initializer_list op_names) { + auto estimator = std::make_unique(); + for (const auto& op_name : op_names) { + estimators_[op_name] = estimator.get(); + } + owned_estimators_.push_back(std::move(estimator)); + } + + absl::flat_hash_map + estimators_; + std::vector> owned_estimators_; +}; + +void calculateOperationCost(mlir::Operation* op, + uint64_t& block_bytes_consumed, + uint64_t& block_flops) { + const auto& cost_model = CostModel::GetInstance(); + OperationCost cost = cost_model.GetOperationCost(op); + block_bytes_consumed += cost.bytes_consumed; + block_flops += cost.flops; +} + +void calculateCustomCallCost(const xla::HloInstruction& hlo_instruction, + absl::flat_hash_map>& custom_call_block_costs) { + if (!hlo_instruction.has_backend_config()) { + LOG(INFO) << "Backend config not found For Custom Call " + << hlo_instruction.name(); + return; + } + google::protobuf::json::ParseOptions options; + options.ignore_unknown_fields = true; + xprof::BackendConfig config; + + auto status = google::protobuf::util::JsonStringToMessage( + hlo_instruction.raw_backend_config_string(), &config, options); + if ( (!status.ok()) || (!config.has_custom_call_config()) ) { + LOG(INFO) << "Custom call config not found " + << hlo_instruction.name(); + } + xprof::CustomCallConfig custom_call_config = config.custom_call_config(); + // OLD CODE + mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); + context.allowUnregisteredDialects(true); + context.loadDialect(); + absl::StatusOr> mlir_op_ref + = xla::ParseMlirModuleString( + static_cast(custom_call_config.body()), + context); + + if (!mlir_op_ref.ok()) { + LOG(INFO) << "Failed to parse MLIR module for custom call " + << hlo_instruction.name() << " with status: " + << mlir_op_ref.status(); + return; + } + bool verify = false; + mlir::OwningOpRef& module_op = mlir_op_ref.value(); + auto manager = + mlir::PassManager::on(module_op->getContext()); + manager.enableVerifier(verify); + // manager.addPass(mlir::tpu::createMosaicSerdePass( + // mlir::tpu::MosaicSerdePassOptions{.serialize = false})); + if (mlir::failed(manager.run(module_op.get()))) { + LOG(WARNING) << "Skipping MosaicSerdePass for custom call " + << hlo_instruction.name(); + } + mlir::Operation* module_operation = module_op->getOperation(); + std::deque queue{&(module_operation->getRegion(0))}; + int64_t block_counter = 0; + uint64_t block_bytes_consumed = 0; + uint64_t block_flops = 0; + while (!queue.empty()) { + mlir::Region* region = queue.front(); + queue.pop_front(); + for (mlir::Block& block : *region) { + if (block.empty()) { + continue; + } + for (mlir::Operation& op : block) { + calculateOperationCost(&op, block_bytes_consumed, block_flops); + } + auto block_name = absl::StrCat("__block_", block_counter++); + custom_call_block_costs[block_name] = + std::make_pair(block_flops, block_bytes_consumed); + block_bytes_consumed = block_flops = 0; + for (mlir::Operation& op : block.without_terminator()) { + for (mlir::Region& region : op.getRegions()) { + if (!region.empty()) { + queue.push_back(®ion); + } + } + } + } + } +} + +} // namespace CustomCallCostEstimator + namespace tensorflow { namespace profiler { +void HloInstructionWrapper::SetCustomCallBlockCosts(){ + CustomCallCostEstimator::calculateCustomCallCost + (*instr_, custom_call_block_costs_); +} + HloInstructionWrapper::HloInstructionWrapper( const xla::HloInstruction* instr, const HloCostAnalysisWrapper* cost_analysis) @@ -60,6 +400,9 @@ HloInstructionWrapper::HloInstructionWrapper( instr); ProcessXlaCostAnalysis(cost_analysis->GetXlaCostAnalysis()); } + if (category_ == "custom-call") { + SetCustomCallBlockCosts(); + } } HloModuleWrapper::HloModuleWrapper( diff --git a/xprof/utils/hlo_module_map.h b/xprof/utils/hlo_module_map.h index bf7982696..d6ac5984a 100644 --- a/xprof/utils/hlo_module_map.h +++ b/xprof/utils/hlo_module_map.h @@ -140,9 +140,22 @@ class HloInstructionWrapper : public HloInstructionInterface { return performance_info_wrapper_.get(); } + std::optional> + GetCustomCallBlockCosts(std::string_view custom_call_name) const { + if (custom_call_block_costs_.contains(custom_call_name)) { + return custom_call_block_costs_.at(custom_call_name); + }else { + return std::nullopt; + } + } + + void SetCustomCallBlockCosts(); + private: const xla::HloInstruction* instr_; std::vector fused_children_; + absl::flat_hash_map> + custom_call_block_costs_; std::string op_full_name_; std::string tf_op_name_; size_t flops_ = 0;