Skip to content

Commit 9fe94ad

Browse files
committed
Add multi-label asp
1 parent b13e5b8 commit 9fe94ad

File tree

8 files changed

+132
-68
lines changed

8 files changed

+132
-68
lines changed

benchmark/serializer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def serialize(dataset_name, dataset_path, serialized_graph_path):
5151
try:
5252
# Run kuzu shell one query at a time. This ensures a new process is
5353
# created for each query to avoid memory leaks.
54-
subprocess.run([kuzu_exec_path, '-i', serialized_graph_path],
54+
subprocess.run([kuzu_exec_path, serialized_graph_path],
5555
input=(s + ";" + "\n").encode("ascii"), check=True)
5656
except subprocess.CalledProcessError as e:
5757
logging.error('Error executing query: %s', s)

src/include/processor/operator/scan_node_id.h

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ struct Mask {
2424
};
2525

2626
// Note: This class is not thread-safe.
27-
struct ScanNodeIDSemiMask {
27+
struct NodeTableSemiMask {
2828
public:
29-
explicit ScanNodeIDSemiMask() : numMaskers{0} {}
29+
NodeTableSemiMask() : numMaskers{0} {}
3030

3131
inline void initializeMaskData(common::offset_t maxNodeOffset, common::offset_t maxMorselIdx) {
3232
if (nodeMask != nullptr) {
@@ -59,11 +59,12 @@ struct ScanNodeIDSemiMask {
5959
};
6060

6161
// Note: This class is not thread-safe. It relies on its caller to correctly synchronize its state.
62-
class ScanTableNodeIDSharedState {
62+
class NodeTableState {
6363
public:
64-
explicit ScanTableNodeIDSharedState(storage::NodeTable* table)
65-
: table{table}, maxNodeOffset{UINT64_MAX}, maxMorselIdx{UINT64_MAX}, currentNodeOffset{0} {
66-
semiMask = std::make_unique<ScanNodeIDSemiMask>();
64+
explicit NodeTableState(storage::NodeTable* table)
65+
: table{table}, maxNodeOffset{common::INVALID_NODE_OFFSET}, maxMorselIdx{UINT64_MAX},
66+
currentNodeOffset{0} {
67+
semiMask = std::make_unique<NodeTableSemiMask>();
6768
}
6869

6970
inline storage::NodeTable* getTable() { return table; }
@@ -83,7 +84,7 @@ class ScanTableNodeIDSharedState {
8384
semiMask->initializeMaskData(maxNodeOffset, maxMorselIdx);
8485
}
8586
inline bool isSemiMaskEnabled() { return semiMask->getNumMaskers() > 0; }
86-
inline ScanNodeIDSemiMask* getSemiMask() { return semiMask.get(); }
87+
inline NodeTableSemiMask* getSemiMask() { return semiMask.get(); }
8788
inline uint8_t getNumMaskers() const { return semiMask->getNumMaskers(); }
8889
inline void incrementNumMaskers() { semiMask->incrementNumMaskers(); }
8990

@@ -94,33 +95,30 @@ class ScanTableNodeIDSharedState {
9495
uint64_t maxNodeOffset;
9596
uint64_t maxMorselIdx;
9697
uint64_t currentNodeOffset;
97-
std::unique_ptr<ScanNodeIDSemiMask> semiMask;
98+
std::unique_ptr<NodeTableSemiMask> semiMask;
9899
};
99100

100101
class ScanNodeIDSharedState {
101102
public:
102103
ScanNodeIDSharedState() : currentStateIdx{0} {};
103104

104105
inline void addTableState(storage::NodeTable* table) {
105-
tableStates.push_back(std::make_unique<ScanTableNodeIDSharedState>(table));
106+
tableStates.push_back(std::make_unique<NodeTableState>(table));
106107
}
107108
inline uint32_t getNumTableStates() const { return tableStates.size(); }
108-
inline ScanTableNodeIDSharedState* getTableState(uint32_t idx) const {
109-
return tableStates[idx].get();
110-
}
109+
inline NodeTableState* getTableState(uint32_t idx) const { return tableStates[idx].get(); }
111110

112111
inline void initialize(transaction::Transaction* transaction) {
113112
for (auto& tableState : tableStates) {
114113
tableState->initializeMaxOffset(transaction);
115114
}
116115
}
117116

118-
std::tuple<ScanTableNodeIDSharedState*, common::offset_t, common::offset_t>
119-
getNextRangeToRead();
117+
std::tuple<NodeTableState*, common::offset_t, common::offset_t> getNextRangeToRead();
120118

121119
private:
122120
std::mutex mtx;
123-
std::vector<std::unique_ptr<ScanTableNodeIDSharedState>> tableStates;
121+
std::vector<std::unique_ptr<NodeTableState>> tableStates;
124122
uint32_t currentStateIdx;
125123
};
126124

@@ -148,8 +146,8 @@ class ScanNodeID : public PhysicalOperator {
148146
sharedState->initialize(context->transaction);
149147
}
150148

151-
void setSelVector(ScanTableNodeIDSharedState* tableState, common::offset_t startOffset,
152-
common::offset_t endOffset);
149+
void setSelVector(
150+
NodeTableState* tableState, common::offset_t startOffset, common::offset_t endOffset);
153151

154152
private:
155153
DataPos outDataPos;

src/include/processor/operator/semi_masker.h

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,67 @@
66
namespace kuzu {
77
namespace processor {
88

9-
class SemiMasker : public PhysicalOperator {
10-
public:
11-
SemiMasker(const DataPos& keyDataPos, std::unique_ptr<PhysicalOperator> child, uint32_t id,
12-
const std::string& paramsString)
9+
class BaseSemiMasker : public PhysicalOperator {
10+
protected:
11+
BaseSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
12+
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
1313
: PhysicalOperator{PhysicalOperatorType::SEMI_MASKER, std::move(child), id, paramsString},
14-
keyDataPos{keyDataPos}, maskerIdx{0}, scanTableNodeIDSharedState{nullptr} {}
14+
keyDataPos{keyDataPos}, scanNodeIDSharedState{scanNodeIDSharedState} {}
1515

16-
SemiMasker(const SemiMasker& other)
17-
: PhysicalOperator{PhysicalOperatorType::SEMI_MASKER, other.children[0]->clone(), other.id,
18-
other.paramsString},
19-
keyDataPos{other.keyDataPos}, maskerIdx{other.maskerIdx},
20-
scanTableNodeIDSharedState{other.scanTableNodeIDSharedState} {}
16+
void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
2117

22-
inline void setSharedState(ScanTableNodeIDSharedState* sharedState) {
23-
scanTableNodeIDSharedState = sharedState;
24-
}
18+
protected:
19+
DataPos keyDataPos;
20+
ScanNodeIDSharedState* scanNodeIDSharedState;
21+
std::shared_ptr<common::ValueVector> keyValueVector;
22+
};
23+
24+
class SingleTableSemiMasker : public BaseSemiMasker {
25+
public:
26+
SingleTableSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
27+
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
28+
: BaseSemiMasker{keyDataPos, scanNodeIDSharedState, std::move(child), id, paramsString} {}
29+
30+
void initGlobalStateInternal(kuzu::processor::ExecutionContext* context) override;
2531

2632
bool getNextTuplesInternal() override;
2733

2834
inline std::unique_ptr<PhysicalOperator> clone() override {
29-
return std::make_unique<SemiMasker>(*this);
35+
auto result = std::make_unique<SingleTableSemiMasker>(
36+
keyDataPos, scanNodeIDSharedState, children[0]->clone(), id, paramsString);
37+
result->maskerIdxAndMask = maskerIdxAndMask;
38+
return result;
3039
}
3140

3241
private:
33-
void initGlobalStateInternal(ExecutionContext* context) override;
42+
// Multiple maskers can point to the same SemiMask, thus we associate each masker with an idx
43+
// to indicate the execution sequence of its pipeline. Also, the maskerIdx is used as a flag to
44+
// indicate if a value in the mask is masked or not, as each masker will increment the selected
45+
// value in the mask by 1. More details are described in NodeTableSemiMask.
46+
std::pair<uint8_t, NodeTableSemiMask*> maskerIdxAndMask;
47+
};
3448

35-
void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
49+
class MultiTableSemiMasker : public BaseSemiMasker {
50+
public:
51+
MultiTableSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
52+
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
53+
: BaseSemiMasker{keyDataPos, scanNodeIDSharedState, std::move(child), id, paramsString} {}
54+
55+
void initGlobalStateInternal(kuzu::processor::ExecutionContext* context) override;
56+
57+
bool getNextTuplesInternal() override;
58+
59+
inline std::unique_ptr<PhysicalOperator> clone() override {
60+
auto result = std::make_unique<MultiTableSemiMasker>(
61+
keyDataPos, scanNodeIDSharedState, children[0]->clone(), id, paramsString);
62+
result->maskerIdxAndMasks = maskerIdxAndMasks;
63+
return result;
64+
}
3665

3766
private:
38-
DataPos keyDataPos;
39-
// Multiple maskers can point to the same scanNodeID, thus we associate each masker with an idx
40-
// to indicate the execution sequence of its pipeline. Also, the maskerIdx is used as a flag to
41-
// indicate if a value in the mask is masked or not, as each masker will increment the selected
42-
// value in the mask by 1. More details are described in ScanNodeIDSemiMask.
43-
uint8_t maskerIdx;
44-
std::shared_ptr<common::ValueVector> keyValueVector;
45-
ScanTableNodeIDSharedState* scanTableNodeIDSharedState;
67+
std::unordered_map<common::table_id_t, std::pair<uint8_t, NodeTableSemiMask*>>
68+
maskerIdxAndMasks;
4669
};
70+
4771
} // namespace processor
4872
} // namespace kuzu

src/optimizer/asp_optimizer.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,6 @@ std::vector<planner::LogicalOperator*> ASPOptimizer::resolveScanNodesToApplySemi
8181
scanNodesCollector.collect(buildRoot);
8282
for (auto& op : scanNodesCollector.getOperators()) {
8383
auto scanNode = (LogicalScanNode*)op;
84-
if (scanNode->getNode()->isMultiLabeled()) {
85-
// We don't push semi mask to multi-labeled scan. This can be solved.
86-
continue;
87-
}
8884
auto nodeID = scanNode->getNode()->getInternalIDProperty();
8985
if (!nodeIDToScanOperatorsMap.contains(nodeID)) {
9086
nodeIDToScanOperatorsMap.insert({nodeID, std::vector<LogicalOperator*>{}});

src/processor/mapper/map_asp.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,15 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalSemiMaskerToPhysical(
3333
auto physicalScanNode = (ScanNodeID*)logicalOpToPhysicalOpMap.at(logicalScanNode);
3434
auto keyDataPos =
3535
DataPos(inSchema->getExpressionPos(*logicalScanNode->getNode()->getInternalIDProperty()));
36-
auto semiMasker = make_unique<SemiMasker>(keyDataPos, std::move(prevOperator), getOperatorID(),
37-
logicalSemiMasker->getExpressionsForPrinting());
38-
assert(physicalScanNode->getSharedState()->getNumTableStates() == 1);
39-
semiMasker->setSharedState(physicalScanNode->getSharedState()->getTableState(0));
40-
return semiMasker;
36+
if (physicalScanNode->getSharedState()->getNumTableStates() > 1) {
37+
return std::make_unique<MultiTableSemiMasker>(keyDataPos,
38+
physicalScanNode->getSharedState(), std::move(prevOperator), getOperatorID(),
39+
logicalSemiMasker->getExpressionsForPrinting());
40+
} else {
41+
return std::make_unique<SingleTableSemiMasker>(keyDataPos,
42+
physicalScanNode->getSharedState(), std::move(prevOperator), getOperatorID(),
43+
logicalSemiMasker->getExpressionsForPrinting());
44+
}
4145
}
4246

4347
} // namespace processor

src/processor/operator/scan_node_id.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace processor {
77

88
// Note: blindly updating the mask does not parallelize well, so we minimize writes by first checking
99
// if the mask is set to true (mask value is equal to the expected currentMaskValue) or not.
10-
void ScanNodeIDSemiMask::incrementMaskValue(uint64_t nodeOffset, uint8_t currentMaskValue) {
10+
void NodeTableSemiMask::incrementMaskValue(uint64_t nodeOffset, uint8_t currentMaskValue) {
1111
if (nodeMask->isMasked(nodeOffset, currentMaskValue)) {
1212
nodeMask->setMask(nodeOffset, currentMaskValue + 1);
1313
}
@@ -17,7 +17,7 @@ void ScanNodeIDSemiMask::incrementMaskValue(uint64_t nodeOffset, uint8_t current
1717
}
1818
}
1919

20-
std::pair<offset_t, offset_t> ScanTableNodeIDSharedState::getNextRangeToRead() {
20+
std::pair<offset_t, offset_t> NodeTableState::getNextRangeToRead() {
2121
// Note: we use maxNodeOffset=INVALID_NODE_OFFSET to represent an empty table.
2222
if (currentNodeOffset > maxNodeOffset || maxNodeOffset == INVALID_NODE_OFFSET) {
2323
return std::make_pair(currentNodeOffset, currentNodeOffset);
@@ -36,8 +36,7 @@ std::pair<offset_t, offset_t> ScanTableNodeIDSharedState::getNextRangeToRead() {
3636
return std::make_pair(startOffset, startOffset + range);
3737
}
3838

39-
std::tuple<ScanTableNodeIDSharedState*, offset_t, offset_t>
40-
ScanNodeIDSharedState::getNextRangeToRead() {
39+
std::tuple<NodeTableState*, offset_t, offset_t> ScanNodeIDSharedState::getNextRangeToRead() {
4140
std::unique_lock lck{mtx};
4241
if (currentStateIdx == tableStates.size()) {
4342
return std::make_tuple(nullptr, INVALID_NODE_OFFSET, INVALID_NODE_OFFSET);
@@ -81,7 +80,7 @@ bool ScanNodeID::getNextTuplesInternal() {
8180
}
8281

8382
void ScanNodeID::setSelVector(
84-
ScanTableNodeIDSharedState* tableState, offset_t startOffset, offset_t endOffset) {
83+
NodeTableState* tableState, offset_t startOffset, offset_t endOffset) {
8584
if (tableState->isSemiMaskEnabled()) {
8685
outValueVector->state->selVector->resetSelectorToValuePosBuffer();
8786
// Fill selected positions based on node mask for nodes between the given startOffset and

src/processor/operator/semi_masker.cpp

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,65 @@ using namespace kuzu::common;
55
namespace kuzu {
66
namespace processor {
77

8-
void SemiMasker::initGlobalStateInternal(kuzu::processor::ExecutionContext* context) {
9-
scanTableNodeIDSharedState->initSemiMask(context->transaction);
10-
maskerIdx = scanTableNodeIDSharedState->getNumMaskers();
8+
void BaseSemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
9+
keyValueVector = resultSet->getValueVector(keyDataPos);
10+
assert(keyValueVector->dataType.typeID == INTERNAL_ID);
11+
}
12+
13+
static std::pair<uint8_t, NodeTableSemiMask*> initSemiMaskForTableState(
14+
NodeTableState* tableState, transaction::Transaction* trx) {
15+
tableState->initSemiMask(trx);
16+
auto maskerIdx = tableState->getNumMaskers();
1117
assert(maskerIdx < UINT8_MAX);
12-
scanTableNodeIDSharedState->incrementNumMaskers();
18+
tableState->incrementNumMaskers();
19+
return std::make_pair(maskerIdx, tableState->getSemiMask());
1320
}
1421

15-
void SemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
16-
keyValueVector = resultSet->getValueVector(keyDataPos);
17-
assert(keyValueVector->dataType.typeID == INTERNAL_ID);
22+
void SingleTableSemiMasker::initGlobalStateInternal(kuzu::processor::ExecutionContext* context) {
23+
assert(scanNodeIDSharedState->getNumTableStates() == 1);
24+
auto tableState = scanNodeIDSharedState->getTableState(0);
25+
maskerIdxAndMask = initSemiMaskForTableState(tableState, context->transaction);
26+
}
27+
28+
bool SingleTableSemiMasker::getNextTuplesInternal() {
29+
if (!children[0]->getNextTuple()) {
30+
return false;
31+
}
32+
auto [maskerIdx, mask] = maskerIdxAndMask;
33+
auto numValues =
34+
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize;
35+
for (auto i = 0u; i < numValues; i++) {
36+
auto pos = keyValueVector->state->selVector->selectedPositions[i];
37+
auto nodeID = keyValueVector->getValue<nodeID_t>(pos);
38+
mask->incrementMaskValue(nodeID.offset, maskerIdx);
39+
}
40+
metrics->numOutputTuple.increase(numValues);
41+
return true;
42+
}
43+
44+
void MultiTableSemiMasker::initGlobalStateInternal(kuzu::processor::ExecutionContext* context) {
45+
assert(scanNodeIDSharedState->getNumTableStates() > 1);
46+
for (auto i = 0u; i < scanNodeIDSharedState->getNumTableStates(); ++i) {
47+
auto tableState = scanNodeIDSharedState->getTableState(i);
48+
auto maskerIdxAndMask = initSemiMaskForTableState(tableState, context->transaction);
49+
maskerIdxAndMasks.insert(
50+
{tableState->getTable()->getTableID(), std::move(maskerIdxAndMask)});
51+
}
1852
}
1953

20-
bool SemiMasker::getNextTuplesInternal() {
54+
bool MultiTableSemiMasker::getNextTuplesInternal() {
2155
if (!children[0]->getNextTuple()) {
2256
return false;
2357
}
2458
auto numValues =
2559
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize;
2660
for (auto i = 0u; i < numValues; i++) {
2761
auto pos = keyValueVector->state->selVector->selectedPositions[i];
28-
scanTableNodeIDSharedState->getSemiMask()->incrementMaskValue(
29-
keyValueVector->getValue<nodeID_t>(pos).offset, maskerIdx);
62+
auto nodeID = keyValueVector->getValue<nodeID_t>(pos);
63+
auto [maskerIdx, mask] = maskerIdxAndMasks.at(nodeID.tableID);
64+
mask->incrementMaskValue(nodeID.offset, maskerIdx);
3065
}
31-
metrics->numOutputTuple.increase(
32-
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize);
66+
metrics->numOutputTuple.increase(numValues);
3367
return true;
3468
}
3569

test/test_files/tinysnb/asp/asp.test

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ Alice
66
Bob
77
Dan
88

9+
-NAME AspMultiLabel
10+
-QUERY MATCH (a:person)-[e1:knows|:studyAt|:workAt]->(b:person:organisation) WHERE a.age > 35 RETURN b.fName, b.name
11+
-ENCODED_JOIN HJ(b._id){E(b)S(a)}{S(b)}
12+
---- 4
13+
Alice|
14+
Bob|
15+
Dan|
16+
|CsWork
17+
918
-NAME AspMultiKey
1019
-QUERY MATCH (a:person)-[e1:knows]->(b:person)-[e2:knows]->(c:person), (a)-[e3:knows]->(c) WHERE a.fName='Alice' RETURN b.fName, c.fName
1120
#-ENCODED_JOIN HJ(c._id,b._id){E(b)E(c)S(a)}{HJ(b._id){S(b)}{E(b)S(c)}}

0 commit comments

Comments
 (0)