Skip to content

Commit 4f5108b

Browse files
committed
Block mapjoin refactor
- Transition to block lists for the block map join's right input (BlockMapJoinCore/BlockMapJoinIndex/BlockStorage nodes affected) * Right key columns/key drops are now addressed by name - Optimizers which fuse ListToBlocks over ListFromBlocks and vice versa commit_hash:bdcee24edd1e5298c038716d4d205858a199d0db
1 parent 6157adb commit 4f5108b

File tree

13 files changed

+372
-267
lines changed

13 files changed

+372
-267
lines changed

yql/essentials/core/peephole_opt/yql_opt_peephole_physical.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,25 @@ TExprNode::TPtr OptimizeWideToBlocks(const TExprNode::TPtr& node, TExprContext&
191191
return node;
192192
}
193193

194+
// Peephole rule for ListToBlocks: when the argument is a ListFromBlocks
// callable, both conversions are dropped and the inner argument of
// ListFromBlocks is returned. Any other input leaves the node untouched.
TExprNode::TPtr OptimizeListToBlocks(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& types) {
    Y_UNUSED(ctx);
    Y_UNUSED(types);

    const auto& child = node->Head();
    if (!child.IsCallable("ListFromBlocks")) {
        return node;
    }

    // TODO: strictly speaking, (ListToBlocks (ListFromBlocks a)) equals (ReplicateScalars a).
    // Block lists have no sources of scalars at the moment, so instead of replicating
    // we only ensure that every member is a block, the sole permitted scalar being
    // the block length column.
    const auto listType = child.Head().GetTypeAnn()->Cast<TListExprType>();
    const auto rowType = listType->GetItemType()->Cast<TStructExprType>();
    for (const auto item : rowType->GetItems()) {
        // The condition is kept verbatim: YQL_ENSURE receives the expression text as its argument.
        YQL_ENSURE(item->GetItemType()->IsBlock() || (item->GetItemType()->IsScalar() && item->GetName() == BlockLengthColumnName));
    }

    YQL_CLOG(DEBUG, CorePeepHole) << "Drop " << node->Content() << " over " << child.Content();
    return child.HeadPtr();
}
212+
194213
TExprNode::TPtr OptimizeWideFromBlocks(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& types) {
195214
Y_UNUSED(types);
196215
const auto& input = node->Head();
@@ -207,6 +226,18 @@ TExprNode::TPtr OptimizeWideFromBlocks(const TExprNode::TPtr& node, TExprContext
207226
return node;
208227
}
209228

229+
// Peephole rule for ListFromBlocks: when the argument is a ListToBlocks
// callable, the two conversions cancel out and the inner argument of
// ListToBlocks is returned. Any other input leaves the node untouched.
TExprNode::TPtr OptimizeListFromBlocks(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& types) {
    Y_UNUSED(ctx);
    Y_UNUSED(types);

    const auto& child = node->Head();
    if (!child.IsCallable("ListToBlocks")) {
        return node;
    }

    YQL_CLOG(DEBUG, CorePeepHole) << "Drop " << node->Content() << " over " << child.Content();
    return child.HeadPtr();
}
240+
210241
TExprNode::TPtr OptimizeWideTakeSkipBlocks(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& types) {
211242
Y_UNUSED(types);
212243
const auto& input = node->HeadPtr();
@@ -8995,7 +9026,9 @@ struct TPeepHoleRules {
89959026
{"NarrowMap", &OptimizeWideMapBlocks},
89969027
{"WideFilter", &OptimizeWideFilterBlocks},
89979028
{"WideToBlocks", &OptimizeWideToBlocks},
9029+
{"ListToBlocks", &OptimizeListToBlocks},
89989030
{"WideFromBlocks", &OptimizeWideFromBlocks},
9031+
{"ListFromBlocks", &OptimizeListFromBlocks},
89999032
{"WideTakeBlocks", &OptimizeWideTakeSkipBlocks},
90009033
{"WideSkipBlocks", &OptimizeWideTakeSkipBlocks},
90019034
{"BlockCompress", &OptimizeBlockCompress},

yql/essentials/core/type_ann/type_ann_join.cpp

Lines changed: 91 additions & 84 deletions
Large diffs are not rendered by default.

yql/essentials/minikql/comp_nodes/mkql_block_map_join.cpp

Lines changed: 87 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -315,57 +315,66 @@ class TBlockStorage : public TComputationValue<TBlockStorage> {
315315

316316
TBlockStorage(
317317
TMemoryUsageInfo* memInfo,
318-
const TVector<TType*>& itemTypes,
319-
NUdf::TUnboxedValue stream,
318+
const TVector<TType*>& types,
319+
size_t blockLengthIndex,
320+
NUdf::TUnboxedValue listIter,
320321
TStringBuf resourceTag,
321322
arrow::MemoryPool* pool
322323
)
323324
: TBase(memInfo)
324-
, InputsDescr_(ToValueDescr(itemTypes))
325-
, Stream_(std::move(stream))
326-
, Inputs_(itemTypes.size())
325+
, InputsDescr_(ToValueDescr(types))
326+
, Readers_(types.size())
327+
, Hashers_(types.size())
328+
, Comparators_(types.size())
329+
, Trimmers_(types.size())
330+
, ListIter_(std::move(listIter))
331+
, BlockLengthIndex_(blockLengthIndex)
327332
, ResourceTag_(std::move(resourceTag))
328333
{
329334
TBlockTypeHelper helper;
330-
for (size_t i = 0; i < itemTypes.size(); i++) {
331-
TType* blockItemType = AS_TYPE(TBlockType, itemTypes[i])->GetItemType();
332-
Readers_.push_back(MakeBlockReader(TTypeInfoHelper(), blockItemType));
333-
Hashers_.push_back(helper.MakeHasher(blockItemType));
334-
Comparators_.push_back(helper.MakeComparator(blockItemType));
335-
Trimmers_.push_back(MakeBlockTrimmer(TTypeInfoHelper(), blockItemType, pool));
335+
for (size_t i = 0; i < types.size(); i++) {
336+
if (i == BlockLengthIndex_) {
337+
continue;
338+
}
339+
340+
TType* blockItemType = AS_TYPE(TBlockType, types[i])->GetItemType();
341+
Readers_[i] = MakeBlockReader(TTypeInfoHelper(), blockItemType);
342+
Hashers_[i] = helper.MakeHasher(blockItemType);
343+
Comparators_[i] = helper.MakeComparator(blockItemType);
344+
Trimmers_[i] = MakeBlockTrimmer(TTypeInfoHelper(), blockItemType, pool);
336345
}
337346
}
338347

339-
NUdf::EFetchStatus FetchStream() {
340-
switch (Stream_.WideFetch(Inputs_.data(), Inputs_.size())) {
341-
case NUdf::EFetchStatus::Yield:
342-
return NUdf::EFetchStatus::Yield;
343-
case NUdf::EFetchStatus::Finish:
348+
bool FetchNextBlock() {
349+
if (!ListIter_.Next(Block_)) {
344350
IsFinished_ = true;
345-
return NUdf::EFetchStatus::Finish;
346-
case NUdf::EFetchStatus::Ok:
347-
break;
351+
return false;
348352
}
353+
BlockItems_ = Block_.GetElements();
349354

350355
Y_ENSURE(!IsFinished_, "Got data on finished stream");
351356

352-
std::vector<arrow::Datum> blockColumns;
353-
for (size_t i = 0; i < Inputs_.size() - 1; i++) {
354-
auto& datum = TArrowBlock::From(Inputs_[i]).GetDatum();
357+
std::vector<arrow::Datum> blockColumns(Readers_.size());
358+
for (size_t i = 0; i < Readers_.size(); i++) {
359+
if (i == BlockLengthIndex_) {
360+
continue;
361+
}
362+
363+
auto& datum = TArrowBlock::From(BlockItems_[i]).GetDatum();
355364
ARROW_DEBUG_CHECK_DATUM_TYPES(InputsDescr_[i], datum.descr());
356365
if (datum.is_scalar()) {
357-
blockColumns.push_back(datum);
366+
blockColumns[i] = datum;
358367
} else {
359368
MKQL_ENSURE(datum.is_array(), "Expecting array");
360-
blockColumns.push_back(Trimmers_[i]->Trim(datum.array()));
369+
blockColumns[i] = Trimmers_[i]->Trim(datum.array());
361370
}
362371
}
363372

364-
auto blockSize = ::GetBlockCount(Inputs_[Inputs_.size() - 1]);
373+
auto blockSize = ::GetBlockCount(BlockItems_[BlockLengthIndex_]);
365374
Data_.emplace_back(blockSize, std::move(blockColumns));
366375
RowCount_ += blockSize;
367376

368-
return NUdf::EFetchStatus::Ok;
377+
return true;
369378
}
370379

371380
const TBlock& GetBlock(size_t blockOffset) const {
@@ -392,11 +401,11 @@ class TBlockStorage : public TComputationValue<TBlockStorage> {
392401
}
393402

394403
TBlockItem GetItem(TRowEntry entry, ui32 columnIdx) const {
395-
Y_ENSURE(columnIdx < Inputs_.size() - 1);
396404
return GetItemFromBlock(GetBlock(entry.BlockOffset), columnIdx, entry.ItemOffset);
397405
}
398406

399407
TBlockItem GetItemFromBlock(const TBlock& block, ui32 columnIdx, size_t offset) const {
408+
Y_ENSURE(columnIdx < Readers_.size() && columnIdx != BlockLengthIndex_);
400409
Y_ENSURE(offset < block.Size);
401410
const auto& datum = block.Columns[columnIdx];
402411
if (datum.is_scalar()) {
@@ -447,8 +456,11 @@ class TBlockStorage : public TComputationValue<TBlockStorage> {
447456
size_t RowCount_ = 0;
448457
bool IsFinished_ = false;
449458

450-
NUdf::TUnboxedValue Stream_;
451-
TUnboxedValueVector Inputs_;
459+
NUdf::TUnboxedValue ListIter_;
460+
NUdf::TUnboxedValue Block_;
461+
const NUdf::TUnboxedValue* BlockItems_ = nullptr;
462+
463+
size_t BlockLengthIndex_ = 0;
452464

453465
const TStringBuf ResourceTag_;
454466
};
@@ -459,33 +471,44 @@ class TBlockStorageWrapper : public TMutableComputationNode<TBlockStorageWrapper
459471
public:
460472
TBlockStorageWrapper(
461473
TComputationMutables& mutables,
462-
TVector<TType*>&& itemTypes,
463-
IComputationNode* stream,
474+
TStructType* structType,
475+
IComputationNode* list,
464476
const TStringBuf& resourceTag
465477
)
466478
: TBaseComputation(mutables, EValueRepresentation::Boxed)
467-
, ItemTypes_(std::move(itemTypes))
468-
, Stream_(stream)
479+
, List_(list)
469480
, ResourceTag_(resourceTag)
470-
{}
481+
{
482+
for (size_t i = 0; i < structType->GetMembersCount(); i++) {
483+
if (structType->GetMemberName(i) == NYql::BlockLengthColumnName) {
484+
BlockLengthIndex_ = i;
485+
Types_.push_back(nullptr);
486+
continue;
487+
}
488+
Types_.push_back(structType->GetMemberType(i));
489+
}
490+
}
471491

472492
NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
473493
return ctx.HolderFactory.Create<TBlockStorage>(
474-
ItemTypes_,
475-
std::move(Stream_->GetValue(ctx)),
494+
Types_,
495+
BlockLengthIndex_,
496+
List_->GetValue(ctx).GetListIterator(),
476497
ResourceTag_,
477498
&ctx.ArrowMemoryPool
478499
);
479500
}
480501

481502
private:
482503
void RegisterDependencies() const final {
483-
DependsOn(Stream_);
504+
DependsOn(List_);
484505
}
485506

486507
private:
487-
const TVector<TType*> ItemTypes_;
488-
IComputationNode* const Stream_;
508+
TVector<TType*> Types_;
509+
size_t BlockLengthIndex_ = 0;
510+
511+
IComputationNode* const List_;
489512

490513
const TString ResourceTag_;
491514
};
@@ -992,17 +1015,13 @@ class TBlockMapJoinCoreWraper : public TMutableComputationNode<TBlockMapJoinCore
9921015
auto& indexState = *static_cast<TIndexState*>(RightBlockIndex_.GetResource());
9931016
auto& storageState = *static_cast<TStorageState*>(indexState.GetBlockStorage().GetResource());
9941017

995-
if (!RightStreamConsumed_) {
996-
auto fetchStatus = NUdf::EFetchStatus::Ok;
997-
while (fetchStatus != NUdf::EFetchStatus::Finish) {
998-
fetchStatus = storageState.FetchStream();
999-
if (fetchStatus == NUdf::EFetchStatus::Yield) {
1000-
return NUdf::EFetchStatus::Yield;
1001-
}
1018+
if (!RightInputConsumed_) {
1019+
while (storageState.FetchNextBlock()) {
1020+
// Fetch entire data from the right input
10021021
}
10031022

10041023
indexState.BuildIndex();
1005-
RightStreamConsumed_ = true;
1024+
RightInputConsumed_ = true;
10061025
}
10071026

10081027
auto* inputFields = joinState.GetRawInputFields();
@@ -1104,7 +1123,7 @@ class TBlockMapJoinCoreWraper : public TMutableComputationNode<TBlockMapJoinCore
11041123
const TVector<ui32>& LeftKeyColumns_;
11051124

11061125
const TVector<ui32>& RightIOMap_;
1107-
bool RightStreamConsumed_ = false;
1126+
bool RightInputConsumed_ = false;
11081127

11091128
std::array<typename TIndexState::TIterator, PrefetchBatchSize> LookupBatchIterators_;
11101129
ui32 LookupBatchCurrent_ = 0;
@@ -1204,16 +1223,12 @@ class TBlockCrossJoinCoreWraper : public TMutableComputationNode<TBlockCrossJoin
12041223
auto& joinState = *static_cast<TJoinState*>(JoinState_.AsBoxed().Get());
12051224
auto& storageState = *static_cast<TStorageState*>(RightBlockStorage_.GetResource());
12061225

1207-
if (!RightStreamConsumed_) {
1208-
auto fetchStatus = NUdf::EFetchStatus::Ok;
1209-
while (fetchStatus != NUdf::EFetchStatus::Finish) {
1210-
fetchStatus = storageState.FetchStream();
1211-
if (fetchStatus == NUdf::EFetchStatus::Yield) {
1212-
return NUdf::EFetchStatus::Yield;
1213-
}
1226+
if (!RightInputConsumed_) {
1227+
while (storageState.FetchNextBlock()) {
1228+
// Fetch entire data from the right input
12141229
}
12151230

1216-
RightStreamConsumed_ = true;
1231+
RightInputConsumed_ = true;
12171232
RightRowIterator_ = storageState.GetRowIterator();
12181233
}
12191234

@@ -1270,7 +1285,7 @@ class TBlockCrossJoinCoreWraper : public TMutableComputationNode<TBlockCrossJoin
12701285
NUdf::TUnboxedValue JoinState_;
12711286

12721287
const TVector<ui32>& RightIOMap_;
1273-
bool RightStreamConsumed_ = false;
1288+
bool RightInputConsumed_ = false;
12741289

12751290
TStorageState::TRowIterator RightRowIterator_;
12761291

@@ -1310,19 +1325,15 @@ IComputationNode* WrapBlockStorage(TCallable& callable, const TComputationNodeFa
13101325
MKQL_ENSURE(resultResourceType->GetTag().StartsWith(BlockStorageResourcePrefix), "Expected block storage resource");
13111326

13121327
const auto inputType = callable.GetInput(0).GetStaticType();
1313-
MKQL_ENSURE(inputType->IsStream(), "Expected WideStream as an input stream");
1314-
const auto inputStreamType = AS_TYPE(TStreamType, inputType);
1315-
MKQL_ENSURE(inputStreamType->GetItemType()->IsMulti(),
1316-
"Expected Multi as a left stream item type");
1317-
const auto inputStreamComponents = GetWideComponents(inputStreamType);
1318-
MKQL_ENSURE(inputStreamComponents.size() > 0, "Expected at least one column");
1319-
TVector<TType*> inputStreamItems(inputStreamComponents.cbegin(), inputStreamComponents.cend());
1328+
MKQL_ENSURE(inputType->IsList(), "Expected List as an input stream");
1329+
const auto inputItemType = AS_TYPE(TListType, inputType)->GetItemType();;
1330+
MKQL_ENSURE(inputItemType->IsStruct(), "Expected Struct as a list item type");
13201331

1321-
const auto inputStream = LocateNode(ctx.NodeLocator, callable, 0);
1332+
const auto list = LocateNode(ctx.NodeLocator, callable, 0);
13221333
return new TBlockStorageWrapper(
13231334
ctx.Mutables,
1324-
std::move(inputStreamItems),
1325-
inputStream,
1335+
AS_TYPE(TStructType, inputItemType),
1336+
list,
13261337
resultResourceType->GetTag()
13271338
);
13281339
}
@@ -1341,9 +1352,9 @@ IComputationNode* WrapBlockMapJoinIndex(TCallable& callable, const TComputationN
13411352
MKQL_ENSURE(inputResourceType->GetTag().StartsWith(BlockStorageResourcePrefix), "Expected block storage resource");
13421353

13431354
auto origInputItemType = AS_VALUE(TTypeType, callable.GetInput(1));
1344-
MKQL_ENSURE(origInputItemType->IsMulti(), "Expected Multi as an input item type");
1345-
const auto streamComponents = AS_TYPE(TMultiType, origInputItemType)->GetElements();
1346-
MKQL_ENSURE(streamComponents.size() > 0, "Expected at least one column");
1355+
MKQL_ENSURE(origInputItemType->IsStruct(), "Expected Struct as an input item type");
1356+
const auto origInputItemStructType = AS_TYPE(TStructType, origInputItemType);
1357+
MKQL_ENSURE(origInputItemStructType->GetMembersCount() > 0, "Expected at least one column");
13471358

13481359
const auto keyColumnsLiteral = callable.GetInput(2);
13491360
const auto keyColumnsTuple = AS_VALUE(TTupleLiteral, keyColumnsLiteral);
@@ -1355,7 +1366,7 @@ IComputationNode* WrapBlockMapJoinIndex(TCallable& callable, const TComputationN
13551366
}
13561367

13571368
for (ui32 keyColumn : keyColumns) {
1358-
MKQL_ENSURE(keyColumn < streamComponents.size() - 1, "Key column out of range");
1369+
MKQL_ENSURE(keyColumn < origInputItemStructType->GetMembersCount(), "Key column out of range");
13591370
}
13601371

13611372
const auto anyNode = callable.GetInput(3);
@@ -1408,10 +1419,9 @@ IComputationNode* WrapBlockMapJoinCore(TCallable& callable, const TComputationNo
14081419
}
14091420

14101421
auto origRightItemType = AS_VALUE(TTypeType, callable.GetInput(2));
1411-
MKQL_ENSURE(origRightItemType->IsMulti(), "Expected Multi as a right stream item type");
1412-
const auto rightStreamComponents = AS_TYPE(TMultiType, origRightItemType)->GetElements();
1413-
MKQL_ENSURE(rightStreamComponents.size() > 0, "Expected at least one column");
1414-
const TVector<TType*> rightStreamItems(rightStreamComponents.cbegin(), rightStreamComponents.cend());
1422+
MKQL_ENSURE(origRightItemType->IsStruct(), "Expected Struct as a right stream item type");
1423+
const auto origRightItemStructType = AS_TYPE(TStructType, origRightItemType);
1424+
MKQL_ENSURE(origRightItemStructType->GetMembersCount() > 0, "Expected at least one column");
14151425

14161426
const auto leftKeyColumnsLiteral = callable.GetInput(4);
14171427
const auto leftKeyColumnsTuple = AS_VALUE(TTupleLiteral, leftKeyColumnsLiteral);
@@ -1479,8 +1489,8 @@ IComputationNode* WrapBlockMapJoinCore(TCallable& callable, const TComputationNo
14791489
// XXX: Mind the block length column: it is skipped by name (BlockLengthColumnName), not by position.
14801490
TVector<ui32> rightIOMap;
14811491
if (joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left || joinKind == EJoinKind::Cross) {
1482-
for (size_t i = 0; i < rightStreamItems.size() - 1; i++) {
1483-
if (rightKeyDrops.contains(i)) {
1492+
for (size_t i = 0; i < origRightItemStructType->GetMembersCount(); i++) {
1493+
if (rightKeyDrops.contains(i) || origRightItemStructType->GetMemberName(i) == NYql::BlockLengthColumnName) {
14841494
continue;
14851495
}
14861496
rightIOMap.push_back(i);

0 commit comments

Comments
 (0)