diff --git a/applications/flash_attention_v2/kernel/tile_scheduler.hpp b/applications/flash_attention_v2/kernel/tile_scheduler.hpp
index 951d1784bf..85b27fe820 100644
--- a/applications/flash_attention_v2/kernel/tile_scheduler.hpp
+++ b/applications/flash_attention_v2/kernel/tile_scheduler.hpp
@@ -39,6 +39,11 @@ namespace cutlass::flash_attention {

+struct XeFlashRowTile {
+  int bh;      // = b * num_heads_q + h
+  int m_tile;  // row tile index along Q
+};
+
 namespace kernel {

 struct XeFlashIndividualTileScheduler {
@@ -92,6 +97,62 @@ struct XeFlashIndividualTileScheduler {
   }
 };

+// Only schedule valid (non-empty) work groups.
+struct XeFlashIndividualValidOnlyTileScheduler {
+
+  struct Params {
+    dim3 grid;
+    FastDivmod divmod_num_heads;
+    const XeFlashRowTile* tiles = nullptr;
+    int total_tiles = 0;
+  };
+
+  bool valid_ = true;
+  Params params;
+
+  CUTLASS_DEVICE
+  XeFlashIndividualValidOnlyTileScheduler(Params const& params) : params(params) {}
+
+  template <class ProblemSize, class TileShape>
+  static Params to_underlying_arguments(
+      ProblemSize const& problem_size, KernelHardwareInfo hw_info,
+      TileShape const& tile_shape,
+      const XeFlashRowTile* tiles_dev, int total_tiles) {
+
+    using namespace cute;
+    // problem_size = [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo]
+    dim3 grid(size(ceil_div(shape<6>(problem_size), shape<1>(tile_shape))), total_tiles, 1);
+
+    return Params{grid, {shape<1>(problem_size)}, tiles_dev, total_tiles};
+  }
+
+  template <int Num_SGs>
+  static dim3 get_grid_shape(Params const& params) {
+    return params.grid;
+  }
+
+  CUTLASS_DEVICE
+  bool is_valid() {
+    return valid_;
+  }
+
+  CUTLASS_DEVICE
+  auto get_block_coord() {
+    using namespace cute;
+
+    int t = BlockIdxY();
+    XeFlashRowTile tile = params.tiles[t];
+    int bidb = 0, bidh = 0;
+    params.divmod_num_heads(bidb, bidh, tile.bh);
+    return make_coord(BlockIdxX(), tile.m_tile, bidb, bidh);
+  }
+
+  CUTLASS_DEVICE
+  XeFlashIndividualValidOnlyTileScheduler& operator++() {
+    valid_ = false;
+    return *this;
+  }
+};
+
 struct XeFlashDecodeIndividualTileScheduler {

   struct Params {
@@ -230,6 +291,7 @@ struct XeFlashPersistentTileScheduler {
 } // namespace kernel

 struct IndividualScheduler{};
+struct IndividualValidOnlyScheduler{};
 struct PersistentScheduler{};
 struct FlashDecodeIndividualScheduler{};

@@ -267,6 +329,15 @@ struct XeFlashPersistentTileScheduler {
    using Scheduler = kernel::XeFlashIndividualTileScheduler;
  };

+  template <class ArchTag>
+  struct TileSchedulerSelector<
+    IndividualValidOnlyScheduler,
+    ArchTag,
+    cute::enable_if_t>>
+  {
+    using Scheduler = kernel::XeFlashIndividualValidOnlyTileScheduler;
+  };
+
  template <class ArchTag>
  struct TileSchedulerSelector<
    PersistentScheduler,
diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp
index 88c1b2042c..f5fa1f729e 100644
--- a/applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp
+++ b/applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp
@@ -78,10 +78,12 @@ class FMHAPrefill {
   using SoftmaxParams = typename CollectiveSoftmaxEpilogue::Params;
   static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, IndividualScheduler> or
+                cute::is_same_v<TileScheduler_, IndividualValidOnlyScheduler> or
                 cute::is_same_v<TileScheduler_, PersistentScheduler>, "Unsupported TileScheduler for Intel Xe.");
   using TileSchedulerTag = TileScheduler_;
   using TileScheduler = typename detail::TileSchedulerSelector<TileSchedulerTag, ArchTag>::Scheduler;
   using TileSchedulerParams = typename TileScheduler::Params;
+  using XeFlashRowTile = cutlass::flash_attention::XeFlashRowTile;

   // Epilogue derived types
   using CollectiveEpilogue = CollectiveEpilogue_;
@@ -147,6 +149,8 @@ class FMHAPrefill {
     SoftmaxArguments softmax{};
     EpilogueArguments epilogue{};
     KernelHardwareInfo hw_info{};
+    const XeFlashRowTile* tiles = nullptr;
+    int total_tiles = 0;
   };

   // Kernel entry point API
@@ -166,11 +170,19 @@ class FMHAPrefill {
   // Convert to underlying arguments. In this case, a simple copy for the aliased type.
   static Params to_underlying_arguments(Arguments const &args, void *workspace) {
     (void)workspace;
-    return {args.mode, args.problem_shape,
-            CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
-            CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax),
-            CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
-            TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, TileShapeOutput{})};
+    if constexpr (cute::is_same_v<TileSchedulerTag, IndividualValidOnlyScheduler>) {
+      return {args.mode, args.problem_shape,
+              CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
+              CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax),
+              CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
+              TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, TileShapeOutput{}, args.tiles, args.total_tiles)};
+    } else {
+      return {args.mode, args.problem_shape,
+              CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
+              CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax),
+              CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
+              TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, TileShapeOutput{})};
+    }
   }

   static bool can_implement(Arguments const &args) {
@@ -224,6 +236,7 @@ class FMHAPrefill {

     TileScheduler tile_scheduler{params.scheduler};

+    CUTLASS_PRAGMA_NO_UNROLL
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord(); // head_size_blk_idx, seq_len_blk_idx, batch_blk_idx, num_heads_blk_idx
diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp
index 6d7b5e0401..06a9f9110b 100644
--- a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp
+++ b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp
@@ -59,13 +59,14 @@ struct Options {
   bool error;
   bool is_causal;
   bool varlen = false;
+  std::string varlen_dist;
   std::string scheduler;

   int batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo, iterations;
   float softmax_scale;

   Options()
-      : help(false), error(false), is_causal(false), varlen(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(512), head_size_qk(128),
+      : help(false), error(false), is_causal(false), varlen(false), varlen_dist("normal"), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(512), head_size_qk(128),
         seq_len_kv(512), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual") {}

   // Parses the command line
@@ -86,6 +87,7 @@ struct Options {
     }

     cmd.get_cmd_line_argument("scheduler", scheduler, std::string("Individual"));
+    cmd.get_cmd_line_argument("varlen_dist", varlen_dist, std::string("normal"));

     cmd.get_cmd_line_argument("batch", batch, 32);
     cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 16);
@@ -107,7 +109,8 @@ struct Options {
       << "  --help                      If specified, displays this usage statement\n\n"
       << "  --is_causal                 Apply Causal Mask to the output of first Matmul\n"
       << "  --varlen                    Enable variable sequence length\n"
-      << "  --scheduler=\"Value\"         Choose between Individual or Persistent Scheduler\n"
+      << "  --varlen_dist=\"Value\"       Choose normal or zipf for varlen init distribution\n"
+      << "  --scheduler=\"Value\"         Choose between Individual, ValidOnlyIndividual or Persistent Scheduler\n"
       << "  --batch=<int>               Sets the Batch Size of the Multi-Head Self Attention module\n"
       << "  --num_heads_q=<int>         Sets the Number of Attention Heads for Key-Value pair the Multi-Head Self Attention module\n"
       << "  --num_heads_kv=<int>        Sets the Number of Attention Heads for Query input in the Multi-Head Self Attention module\n"
@@ -147,6 +150,7 @@ template struct ExampleRunner {

   using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
   using ProblemShapeType = typename FMHAPrefillKernel::ProblemShape;
+  using XeFlashRowTile = typename FMHAPrefillKernel::XeFlashRowTile;

   //
   // Data members
@@ -401,9 +405,17 @@ template struct ExampleRunner {
     int max_seqlen_q = 0;
     int max_seqlen_kv = 0;

+    std::cout << "[VarLen Init] batches=" << num_batches
+              << " mean_q=" << get<3>(problem_size)
+              << " mean_kv=" << get<4>(problem_size)
+              << " AlignQ(elems)=" << AlignmentQ
+              << " AlignKV(elems)=" << AlignmentKV << "\n";
+
     for (int i = 0; i < num_batches; i++) {
       int seqlen_q = cutlass::round_up(generate_positive_int(dist_q, rng), AlignmentQ);
-      int seqlen_kv = cutlass::round_up(generate_positive_int(dist_kv, rng), AlignmentKV);
+      int seqlen_kv = seqlen_q;
+
+      std::cout << "  batch " << i << " : Q=" << seqlen_q << ", KV=" << seqlen_kv << "\n";

       total_seqlen_q += seqlen_q;
       total_seqlen_kv += seqlen_kv;
@@ -415,6 +427,13 @@ template struct ExampleRunner {
       cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv);
     }

+    std::cout << "  cumulative_seqlen_q : ";
+    for (auto v : cumulative_seqlen_q) std::cout << v << " ";
+    std::cout << "\n  cumulative_seqlen_kv: ";
+    for (auto v : cumulative_seqlen_kv) std::cout << v << " ";
+    std::cout << "\n  totals: Q=" << total_seqlen_q << " (max " << max_seqlen_q << ")"
+              << ", KV=" << total_seqlen_kv << " (max " << max_seqlen_kv << ")\n";
+
     ProblemShape problem_size_for_init = problem_size;
     get<0>(problem_size_for_init) = 1;
     get<3>(problem_size_for_init) = total_seqlen_q;
@@ -434,6 +453,104 @@ template struct ExampleRunner {
     return cute::make_tuple(problem_size_for_init, problem_size_for_launch);
   }
+
+
+  template <class ProblemShape>
+  auto initialize_varlen_zipf(const ProblemShape& problem_size) {
+    int num_batches = get<0>(problem_size);
+
+    constexpr int Nmax = 16384;
+    constexpr double s = 1.098;
+
+    std::mt19937 rng(0x202305151552ull);
+    std::uniform_real_distribution<double> uni(0.0, 1.0);
+
+    struct ZipfCDF {
+      std::vector<double> cdf;
+      ZipfCDF() {
+        cdf.resize(Nmax);
+        double Z = 0.0;
+        for (int k = 1; k <= Nmax; ++k) Z += std::pow(static_cast<double>(k), -s);
+        double acc = 0.0;
+        for (int k = 1; k <= Nmax; ++k) {
+          acc += std::pow(static_cast<double>(k), -s) / Z;
+          cdf[k - 1] = acc;
+        }
+        cdf.back() = 1.0;
+      }
+      int sample(std::mt19937& gen, std::uniform_real_distribution<double>& u) const {
+        double x = u(gen);
+        auto it = std::upper_bound(cdf.begin(), cdf.end(), x);
+        int idx = static_cast<int>(it - cdf.begin());
+        return idx + 1;
+      }
+    };
+    static const ZipfCDF zipf;
+
+    constexpr int cacheline_bytes = 64;
+    constexpr int AlignmentQ = cacheline_bytes / sizeof(ElementQ);
+    constexpr int AlignmentKV = cacheline_bytes / sizeof(ElementK);
+
+    auto round_pos = [](int v, int align) {
+      return cutlass::round_up(v, align);
+    };
+
+    cumulative_seqlen_q = {0};
+    cumulative_seqlen_kv = {0};
+
+    int total_seqlen_q = 0;
+    int total_seqlen_kv = 0;
+    int max_seqlen_q = 0;
+    int max_seqlen_kv = 0;
+
+    std::cout << "[VarLen Init/Zipf] batches=" << num_batches
+              << " Zipf(Nmax=" << Nmax << ", s=" << s << ")"
+              << " AlignQ(elems)=" << AlignmentQ
+              << " AlignKV(elems)=" << AlignmentKV << "\n";
+
+    for (int i = 0; i < num_batches; i++) {
+      int raw_q = zipf.sample(rng, uni);
+      int raw_kv = raw_q;
+
+      int seqlen_q = round_pos(raw_q, AlignmentQ);
+      int seqlen_kv = round_pos(raw_kv, AlignmentKV);
+
+      std::cout << "  batch " << i << " : Q=" << seqlen_q << ", KV=" << seqlen_kv << "\n";
+
+      total_seqlen_q += seqlen_q;
+      total_seqlen_kv += seqlen_kv;
+
+      max_seqlen_q = std::max(max_seqlen_q, seqlen_q);
+      max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv);
+
+      cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q);
+      cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv);
+    }
+
+    std::cout << "  cumulative_seqlen_q : ";
+    for (auto v : cumulative_seqlen_q) std::cout << v << " ";
+    std::cout << "\n  cumulative_seqlen_kv: ";
+    for (auto v : cumulative_seqlen_kv) std::cout << v << " ";
+    std::cout << "\n  totals: Q=" << total_seqlen_q << " (max " << max_seqlen_q << ")"
+              << ", KV=" << total_seqlen_kv << " (max " << max_seqlen_kv << ")\n";
+
+    ProblemShape problem_size_for_init = problem_size;
+    get<0>(problem_size_for_init) = 1;
+    get<3>(problem_size_for_init) = total_seqlen_q;
+    get<4>(problem_size_for_init) = total_seqlen_kv;
+
+    ProblemShapeType problem_size_for_launch;
+    get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_q};
+    get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_kv};
+    get<5>(problem_size_for_launch) = get<5>(problem_size);
+    get<6>(problem_size_for_launch) = get<6>(problem_size);
+    get<0>(problem_size_for_launch) = get<0>(problem_size);
+    get<1>(problem_size_for_launch) = get<1>(problem_size);
+    get<2>(problem_size_for_launch) = get<2>(problem_size);
+
+    return cute::make_tuple(problem_size_for_init, problem_size_for_launch);
+  }
+
   /// Initialize operands to be used in the GEMM and reference GEMM
   ProblemShapeType initialize(const Options &options) {
     auto problem_shape_in =
@@ -443,9 +560,15 @@ template struct ExampleRunner {
     decltype(problem_shape_in) problem_size;

     if constexpr (isVarLen) {
-      auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in);
-      problem_shape = problem_shape_launch;
-      problem_size = problem_shape_init;
+      if (options.varlen_dist == "zipf") {
+        auto [problem_shape_init, problem_shape_launch] = initialize_varlen_zipf(problem_shape_in);
+        problem_size = problem_shape_init;
+        problem_shape = problem_shape_launch;
+      } else { // "normal" (default)
+        auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in);
+        problem_size = problem_shape_init;
+        problem_shape = problem_shape_launch;
+      }
     } else {
       problem_size = problem_shape_in;
@@ -521,10 +644,42 @@ template struct ExampleRunner {
       EventManager::getInstance().addEvent(event);
     }
+
+  // Build a compact list of row tiles for variable-length prefill.
+  // seqlen_q[b] = valid Q length for batch b (not max; the actual length for that sample)
+  inline int build_row_tile_list(std::vector<XeFlashRowTile>& tiles,
+                                 const int* seqlen_q,
+                                 int batch,
+                                 int num_heads_q,
+                                 int M_TILE) {
+    tiles.clear();
+    int total = 0;
+    for (int b = 0; b < batch; ++b) {
+      const int mt = (seqlen_q[b] + M_TILE - 1) / M_TILE;
+      for (int h = 0; h < num_heads_q; ++h) {
+        const int bh = b * num_heads_q + h;
+        for (int m = 0; m < mt; ++m) {
+          tiles.push_back({bh, m});
+        }
+      }
+      total += mt * num_heads_q;
+    }
+    std::cout << "Total tile number: " << total << std::endl;
+    return total;
+  }
+
   cutlass::Status run(const Options &options, const cutlass::KernelHardwareInfo &hw_info) {
     ProblemShapeType problem_size = initialize(options);

+    using TileShapeOut = typename FMHAPrefillKernel::TileShapeOutput;
+    const int batch = cutlass::get<0>(problem_size);
+    const int num_heads_q = cutlass::get<1>(problem_size);
+    const int M_TILE = cutlass::get<0>(TileShapeOut{});
+
+    std::vector<int> seqlen_q_host(batch);
+    cutlass::device_memory::allocation<XeFlashRowTile> tiles_dev;
+
     typename FMHAPrefillKernel::Arguments arguments{
         cutlass::gemm::GemmUniversalMode::kGemm,
         problem_size,
@@ -533,6 +688,29 @@ template struct ExampleRunner {
         {block_O.get(), stride_O},
         hw_info};

+    arguments.tiles = nullptr;
+    arguments.total_tiles = 0;
+
+    if (options.varlen && options.scheduler == "ValidOnlyIndividual") {
+      for (int b = 0; b < batch; ++b) {
+        seqlen_q_host[b] = cumulative_seqlen_q[b + 1] - cumulative_seqlen_q[b];
+      }
+
+      // Build compact (bh, m_tile) list
+      std::vector<XeFlashRowTile> tiles_host;
+      int total_tiles = build_row_tile_list(tiles_host, seqlen_q_host.data(), batch, num_heads_q, M_TILE);
+
+      if (total_tiles > 0) {
+        tiles_dev = cutlass::device_memory::allocation<XeFlashRowTile>(total_tiles);
+        syclcompat::memcpy(tiles_dev.get(), tiles_host.data(),
+                           sizeof(XeFlashRowTile) * total_tiles);
+        syclcompat::wait();
+      }
+
+      arguments.tiles = (total_tiles > 0) ? tiles_dev.get() : nullptr;
+      arguments.total_tiles = total_tiles;
+    }
+
     // Define device-global scratch memory
     size_t workspace_size = FMHAPrefillKernel::get_workspace_size(arguments);
     cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
@@ -570,23 +748,66 @@ template struct ExampleRunner {
       run(params);
     }
     syclcompat::wait();
-    // when seq_len_qo is not equal to seq_len_kv we use bottom up approach for the masking.
-    // Following changes will adjust the effective_seq_len_kv when masking applied for such cases
-    auto offset = cute::min(options.seq_len_qo, options.seq_len_kv);
-    auto discard_seq_coord = options.seq_len_qo - offset;
-    auto full_tile_offset = options.seq_len_kv - offset;
-    // offset + 1 is going to be ceil_div
-    auto effective_seq_len_kv = options.is_causal ? full_tile_offset + ((offset + 1) / 2.0): options.seq_len_kv;
-    auto effective_seq_len_qo = options.is_causal ? options.seq_len_qo - discard_seq_coord : options.seq_len_qo;
+    double cute_time = timer.seconds() / options.iterations;

-    double flops_qk = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * effective_seq_len_kv * options.head_size_qk;
-    double flops_pv = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo * effective_seq_len_kv;
-    double tflops = ((flops_qk + flops_pv) * 1e-12) / cute_time;
-    double gbps_qk = options.batch * (sizeof(ElementQ) * options.num_heads_q * effective_seq_len_qo * options.head_size_qk +
-                     sizeof(ElementK) * options.num_heads_kv * effective_seq_len_kv * options.head_size_qk);
-    double gbps_pv = sizeof(ElementV) * options.batch * options.num_heads_kv * effective_seq_len_kv * options.head_size_vo +
-                     sizeof(ElementOutput) * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo;
-    double gbps = ((gbps_qk + gbps_pv) * 1e-9) / (cute_time);
+    double flops_qk = 0.0;
+    double flops_pv = 0.0;
+    double tflops = 0.0;
+    double gbps = 0.0;
+
+    if (!options.varlen) {
+      // when seq_len_qo is not equal to seq_len_kv we use bottom up approach for the masking.
+      // Following changes will adjust the effective_seq_len_kv when masking applied for such cases
+      auto offset = cute::min(options.seq_len_qo, options.seq_len_kv);
+      auto discard_seq_coord = options.seq_len_qo - offset;
+      auto full_tile_offset = options.seq_len_kv - offset;
+      // offset + 1 is going to be ceil_div
+      auto effective_seq_len_kv = options.is_causal ? full_tile_offset + ((offset + 1) / 2.0): options.seq_len_kv;
+      auto effective_seq_len_qo = options.is_causal ? options.seq_len_qo - discard_seq_coord : options.seq_len_qo;
+
+      flops_qk = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * effective_seq_len_kv * options.head_size_qk;
+      flops_pv = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo * effective_seq_len_kv;
+      tflops = ((flops_qk + flops_pv) * 1e-12) / cute_time;
+      double gbps_qk = options.batch * (sizeof(ElementQ) * options.num_heads_q * effective_seq_len_qo * options.head_size_qk +
+                       sizeof(ElementK) * options.num_heads_kv * effective_seq_len_kv * options.head_size_qk);
+      double gbps_pv = sizeof(ElementV) * options.batch * options.num_heads_kv * effective_seq_len_kv * options.head_size_vo +
+                       sizeof(ElementOutput) * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo;
+      gbps = ((gbps_qk + gbps_pv) * 1e-9) / (cute_time);
+    } else {
+      long seqlen_qo_sum = 0;
+      long seqlen_kv_sum = 0;
+      long effective_seq_len_qo_sum = 0;
+      double effective_seq_len_kv_sum = 0.0;
+      double effective_pairs_sum = 0.0;
+
+      for (int b = 0; b < options.batch; ++b) {
+        int seqlen_qo = cumulative_seqlen_q[b + 1] - cumulative_seqlen_q[b];
+        int seqlen_kv = cumulative_seqlen_kv[b + 1] - cumulative_seqlen_kv[b];
+        seqlen_qo_sum += seqlen_qo;
+        seqlen_kv_sum += seqlen_kv;
+
+        auto offset = cute::min(seqlen_qo, seqlen_kv);
+        auto discard_seq_coord = seqlen_qo - offset;
+        auto full_tile_offset = seqlen_kv - offset;
+        // offset + 1 is going to be ceil_div
+        auto effective_seq_len_kv_b = options.is_causal ? full_tile_offset + ((offset + 1) / 2.0) : (double)seqlen_kv;
+        auto effective_seq_len_qo_b = options.is_causal ? (seqlen_qo - discard_seq_coord) : seqlen_qo;
+
+        effective_seq_len_qo_sum += effective_seq_len_qo_b;
+        effective_seq_len_kv_sum += effective_seq_len_kv_b;
+        effective_pairs_sum += effective_seq_len_qo_b * effective_seq_len_kv_b;
+      }
+
+      flops_qk = 2.0 * options.num_heads_q * options.head_size_qk * effective_pairs_sum;
+      flops_pv = 2.0 * options.num_heads_q * options.head_size_vo * effective_pairs_sum;
+      tflops = ((flops_qk + flops_pv) * 1e-12) / cute_time;
+
+      double gbps_qk = sizeof(ElementQ) * options.num_heads_q * effective_seq_len_qo_sum * options.head_size_qk +
+                       sizeof(ElementK) * options.num_heads_kv * effective_seq_len_kv_sum * options.head_size_qk;
+      double gbps_pv = sizeof(ElementV) * options.num_heads_kv * effective_seq_len_kv_sum * options.head_size_vo +
+                       sizeof(ElementOutput) * options.num_heads_q * effective_seq_len_qo_sum * options.head_size_vo;
+      gbps = ((gbps_qk + gbps_pv) * 1e-9) / cute_time;
+    }

     std::cout << "Batch: " << options.batch << "\tNumHeads_q: " << options.num_heads_q << "\tNumHeads_kv: " << options.num_heads_kv
               << "\tSeq Length QO: " << options.seq_len_qo << "\tSeq Length KV: " << options.seq_len_kv
               << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo
               << "\tCausal Mask: " << (options.is_causal ? "true" : "false")
               << "\tVariable Sequence Length: " << (options.varlen ? "true" : "false")
@@ -659,10 +880,15 @@ template
-      return run(options);
+    if (options.scheduler == "ValidOnlyIndividual") {
+      return options.varlen ? run(options)
+                            : run(options);
+    } else if (options.scheduler == "Persistent") {
+      return options.varlen ? run(options)
+                            : run(options);
     } else {
-      return run(options);
+      return options.varlen ? run(options)
+                            : run(options);
     }
   }
 };
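Editorial note, not part of the patch: the sketch below illustrates the idea behind the compact (bh, m_tile) work list that build_row_tile_list produces and XeFlashIndividualValidOnlyTileScheduler consumes, compared against the dense grid a plain per-tile launch would need. It is a standalone, host-only example with made-up sequence lengths; RowTile is a local stand-in for XeFlashRowTile and nothing here depends on CUTLASS.

// Illustrative only -- not part of the patch. Counts valid-only work-groups
// versus the dense grid sized for the longest sequence in the batch.
#include <algorithm>
#include <iostream>
#include <vector>

struct RowTile {
  int bh;      // b * num_heads_q + h
  int m_tile;  // row tile index along Q
};

int main() {
  const int num_heads_q = 4;
  const int m_tile_size = 128;                              // Q rows per work-group
  const std::vector<int> seqlen_q = {64, 4096, 128, 256};   // hypothetical per-batch Q lengths

  std::vector<RowTile> tiles;
  for (int b = 0; b < static_cast<int>(seqlen_q.size()); ++b) {
    const int mt = (seqlen_q[b] + m_tile_size - 1) / m_tile_size;  // ceil_div
    for (int h = 0; h < num_heads_q; ++h) {
      for (int m = 0; m < mt; ++m) {
        tiles.push_back({b * num_heads_q + h, m});
      }
    }
  }

  // A dense launch must cover ceil_div(max_seqlen, m_tile_size) row tiles for
  // every (batch, head), even when most sequences are much shorter.
  const int max_seqlen = *std::max_element(seqlen_q.begin(), seqlen_q.end());
  const int dense = static_cast<int>(seqlen_q.size()) * num_heads_q *
                    ((max_seqlen + m_tile_size - 1) / m_tile_size);

  std::cout << "valid-only work-groups: " << tiles.size()
            << ", dense work-groups: " << dense << "\n";
  return 0;
}

With the patch applied, the corresponding behaviour in the example binary is exercised via the new options added above, e.g. --varlen --varlen_dist=zipf --scheduler=ValidOnlyIndividual, where the skewed Zipf lengths make the gap between the valid-only work list and the dense grid most visible.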