|
6 | 6 | namespace kuzu { |
7 | 7 | namespace processor { |
8 | 8 |
|
9 | | -class SemiMasker : public PhysicalOperator {
10 | | -public:
11 | | - SemiMasker(const DataPos& keyDataPos, std::unique_ptr<PhysicalOperator> child, uint32_t id,
12 | | - const std::string& paramsString)
// BaseSemiMasker (the `+` rows) takes the place of the removed concrete SemiMasker (the `-` rows).
// It holds the state declared below: the key position, a ScanNodeIDSharedState pointer and the
// key value vector.
| 9 | +class BaseSemiMasker : public PhysicalOperator {
// The constructor is protected, so this class is only constructible from derived classes.
| 10 | +protected:
// Forwards child, id and paramsString to PhysicalOperator with type SEMI_MASKER and stores
// keyDataPos and scanNodeIDSharedState as given. keyValueVector is not in the initializer list,
// so it starts out as an empty shared_ptr.
| 11 | + BaseSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
| 12 | + std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
13 | 13 | : PhysicalOperator{PhysicalOperatorType::SEMI_MASKER, std::move(child), id, paramsString},
14 | | - keyDataPos{keyDataPos}, maskerIdx{0}, scanTableNodeIDSharedState{nullptr} {}
| 14 | + keyDataPos{keyDataPos}, scanNodeIDSharedState{scanNodeIDSharedState} {}
15 | 15 |
|
16 | | - SemiMasker(const SemiMasker& other)
17 | | - : PhysicalOperator{PhysicalOperatorType::SEMI_MASKER, other.children[0]->clone(), other.id,
18 | | - other.paramsString},
19 | | - keyDataPos{other.keyDataPos}, maskerIdx{other.maskerIdx},
20 | | - scanTableNodeIDSharedState{other.scanTableNodeIDSharedState} {}
// Declared here, defined out of line. Presumably resolves keyValueVector from resultSet using
// keyDataPos -- TODO confirm against the implementation file.
| 16 | + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
21 | 17 |
|
22 | | - inline void setSharedState(ScanTableNodeIDSharedState* sharedState) {
23 | | - scanTableNodeIDSharedState = sharedState;
24 | | - }
| 18 | +protected:
// Position of the key vector; presumably an index into the ResultSet -- verify against DataPos.
| 19 | + DataPos keyDataPos;
// Raw pointer stored exactly as passed to the constructor. Ownership and lifetime are not visible
// in this header -- NOTE(review): confirm the shared state outlives this operator and its clones.
| 20 | + ScanNodeIDSharedState* scanNodeIDSharedState;
// Not set by the constructor; empty until assigned elsewhere.
| 21 | + std::shared_ptr<common::ValueVector> keyValueVector;
| 22 | +};
| 23 | + |
// Semi masker that tracks exactly one (masker idx, NodeTableSemiMask*) pair.
| 24 | +class SingleTableSemiMasker : public BaseSemiMasker {
| 25 | +public:
// Passes every argument straight through to BaseSemiMasker. maskerIdxAndMask is not initialized
// here, so it starts value-initialized: idx 0 and a null mask pointer.
| 26 | + SingleTableSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
| 27 | + std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
| 28 | + : BaseSemiMasker{keyDataPos, scanNodeIDSharedState, std::move(child), id, paramsString} {}
| 29 | +
// Defined out of line. Presumably this is where maskerIdxAndMask gets populated -- TODO confirm.
| 30 | + void initGlobalStateInternal(kuzu::processor::ExecutionContext* context) override;
25 | 31 |
|
26 | 32 | bool getNextTuplesInternal() override;
27 | 33 |
|
// Builds a new masker from the same keyDataPos, shared-state pointer, id and paramsString plus a
// clone of the first child, then copies maskerIdxAndMask. The copy holds the same
// NodeTableSemiMask pointer as this object; the mask itself is not duplicated.
28 | 34 | inline std::unique_ptr<PhysicalOperator> clone() override {
29 | | - return std::make_unique<SemiMasker>(*this);
| 35 | + auto result = std::make_unique<SingleTableSemiMasker>(
| 36 | + keyDataPos, scanNodeIDSharedState, children[0]->clone(), id, paramsString);
| 37 | + result->maskerIdxAndMask = maskerIdxAndMask;
| 38 | + return result;
30 | 39 | }
31 | 40 |
|
32 | 41 | private:
33 | | - void initGlobalStateInternal(ExecutionContext* context) override;
| 42 | + // Multiple maskers can point to the same SemiMask, thus we associate each masker with an idx
| 43 | + // to indicate the execution sequence of its pipeline. Also, the maskerIdx is used as a flag to
| 44 | + // indicate if a value in the mask is masked or not, as each masker will increment the selected
| 45 | + // value in the mask by 1. More details are described in NodeTableSemiMask.
| 46 | + std::pair<uint8_t, NodeTableSemiMask*> maskerIdxAndMask;
| 47 | +};
34 | 48 |
|
35 | | - void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
// Semi masker that keeps one (masker idx, NodeTableSemiMask*) pair per table id.
| 49 | +class MultiTableSemiMasker : public BaseSemiMasker {
| 50 | +public:
// Passes every argument straight through to BaseSemiMasker; maskerIdxAndMasks starts empty.
| 51 | + MultiTableSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
| 52 | + std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
| 53 | + : BaseSemiMasker{keyDataPos, scanNodeIDSharedState, std::move(child), id, paramsString} {}
| 54 | +
// Defined out of line. Presumably this is where maskerIdxAndMasks gets populated -- TODO confirm.
| 55 | + void initGlobalStateInternal(kuzu::processor::ExecutionContext* context) override;
| 56 | +
| 57 | + bool getNextTuplesInternal() override;
| 58 | +
// Builds a new masker from the same keyDataPos, shared-state pointer, id and paramsString plus a
// clone of the first child, then copies the whole map. The copied entries hold the same
// NodeTableSemiMask pointers as this object; the masks themselves are not duplicated.
| 59 | + inline std::unique_ptr<PhysicalOperator> clone() override {
| 60 | + auto result = std::make_unique<MultiTableSemiMasker>(
| 61 | + keyDataPos, scanNodeIDSharedState, children[0]->clone(), id, paramsString);
| 62 | + result->maskerIdxAndMasks = maskerIdxAndMasks;
| 63 | + return result;
| 64 | + }
36 | 65 |
|
37 | 66 | private:
38 | | - DataPos keyDataPos;
39 | | - // Multiple maskers can point to the same scanNodeID, thus we associate each masker with an idx
40 | | - // to indicate the execution sequence of its pipeline. Also, the maskerIdx is used as a flag to
41 | | - // indicate if a value in the mask is masked or not, as each masker will increment the selected
42 | | - // value in the mask by 1. More details are described in ScanNodeIDSemiMask.
43 | | - uint8_t maskerIdx;
44 | | - std::shared_ptr<common::ValueVector> keyValueVector;
45 | | - ScanTableNodeIDSharedState* scanTableNodeIDSharedState;
// Keyed by table id; each value pairs a masker idx with a mask pointer, the same shape as
// SingleTableSemiMasker::maskerIdxAndMask.
| 67 | + std::unordered_map<common::table_id_t, std::pair<uint8_t, NodeTableSemiMask*>>
| 68 | + maskerIdxAndMasks;
46 | 69 | };
| 70 | + |
47 | 71 | } // namespace processor |
48 | 72 | } // namespace kuzu |
0 commit comments