Skip to content

Commit

Permalink
core: protect epoch checksum with its own mutex (decred#2977)
Browse files Browse the repository at this point in the history
* match checksum to separate mutex and add dc cancel tracking
  • Loading branch information
buck54321 committed Oct 17, 2024
1 parent 58a27f8 commit 1030097
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 93 deletions.
113 changes: 77 additions & 36 deletions client/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ type dexConnection struct {
// processed by a dex server.
inFlightOrders map[uint64]*InFlightOrder

// A map linking cancel order IDs to trade order IDs.
cancelsMtx sync.RWMutex
cancels map[order.OrderID]order.OrderID

blindCancelsMtx sync.Mutex
blindCancels map[order.OrderID]order.Preimage

Expand Down Expand Up @@ -253,6 +257,25 @@ func (dc *dexConnection) bondAssets() (map[uint32]*BondAsset, uint64) {
return bondAssets, cfg.BondExpiry
}

func (dc *dexConnection) registerCancelLink(cid, oid order.OrderID) {
dc.cancelsMtx.Lock()
dc.cancels[cid] = oid
dc.cancelsMtx.Unlock()
}

func (dc *dexConnection) deleteCancelLink(cid order.OrderID) {
dc.cancelsMtx.Lock()
delete(dc.cancels, cid)
dc.cancelsMtx.Unlock()
}

func (dc *dexConnection) cancelTradeID(cid order.OrderID) (order.OrderID, bool) {
dc.cancelsMtx.RLock()
defer dc.cancelsMtx.RUnlock()
oid, found := dc.cancels[cid]
return oid, found
}

// marketConfig is the market's configuration, as returned by the server in the
// 'config' response.
func (dc *dexConnection) marketConfig(mktID string) *msgjson.Market {
Expand Down Expand Up @@ -577,17 +600,19 @@ func (dc *dexConnection) activeOrders() ([]*Order, []*InFlightOrder) {

// findOrder returns the tracker and preimage for an order ID, and a boolean
// indicating whether this is a cancel order.
func (dc *dexConnection) findOrder(oid order.OrderID) (tracker *trackedTrade, preImg order.Preimage, isCancel bool) {
func (dc *dexConnection) findOrder(oid order.OrderID) (tracker *trackedTrade, isCancel bool) {
dc.tradeMtx.RLock()
defer dc.tradeMtx.RUnlock()
// Try to find the order as a trade.
if tracker, found := dc.trades[oid]; found {
return tracker, tracker.preImg, false
return tracker, false
}
// Search the cancel order IDs.
for _, tracker := range dc.trades {
if tracker.cancel != nil && tracker.cancel.ID() == oid {
return tracker, tracker.cancel.preImg, true

if tid, found := dc.cancelTradeID(oid); found {
if tracker, found := dc.trades[tid]; found {
return tracker, true
} else {
dc.log.Errorf("Did not find trade for cancel order ID %s", oid)
}
}
return
Expand Down Expand Up @@ -645,7 +670,7 @@ func (c *Core) sendCancelOrder(dc *dexConnection, oid order.OrderID, base, quote
// tryCancel will look for an order with the specified order ID, and attempt to
// cancel the order. It is not an error if the order is not found.
func (c *Core) tryCancel(dc *dexConnection, oid order.OrderID) (found bool, err error) {
tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker == nil {
return // false, nil
}
Expand Down Expand Up @@ -771,7 +796,7 @@ func (dc *dexConnection) parseMatches(msgMatches []*msgjson.Match, checkSigs boo
for _, msgMatch := range msgMatches {
var oid order.OrderID
copy(oid[:], msgMatch.OrderID)
tracker, _, isCancel := dc.findOrder(oid)
tracker, isCancel := dc.findOrder(oid)
if tracker == nil {
dc.blindCancelsMtx.Lock()
_, found := dc.blindCancels[oid]
Expand Down Expand Up @@ -4953,7 +4978,7 @@ func (c *Core) Order(oidB dex.Bytes) (*Order, error) {
}
// See if it's an active order first.
for _, dc := range c.dexConnections() {
tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker != nil {
return tracker.coreOrder(), nil
}
Expand Down Expand Up @@ -8104,6 +8129,7 @@ func (c *Core) newDEXConnection(acctInfo *db.AccountInfo, flag connectDEXFlag) (
ticker: newDexTicker(defaultTickInterval), // updated when server config obtained
books: make(map[string]*bookie),
trades: make(map[order.OrderID]*trackedTrade),
cancels: make(map[order.OrderID]order.OrderID),
inFlightOrders: make(map[uint64]*InFlightOrder),
blindCancels: make(map[order.OrderID]order.Preimage),
apiVer: -1,
Expand Down Expand Up @@ -8509,7 +8535,7 @@ func handleRevokeOrderMsg(c *Core, dc *dexConnection, msg *msgjson.Message) erro
var oid order.OrderID
copy(oid[:], revocation.OrderID)

tracker, _, isCancel := dc.findOrder(oid)
tracker, isCancel := dc.findOrder(oid)
if tracker == nil {
return fmt.Errorf("no order found with id %s", oid.String())
}
Expand Down Expand Up @@ -8551,7 +8577,7 @@ func handleRevokeMatchMsg(c *Core, dc *dexConnection, msg *msgjson.Message) erro
var oid order.OrderID
copy(oid[:], revocation.OrderID)

tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker == nil {
return fmt.Errorf("no order found with id %s (not an error if you've completed your side of the swap)", oid.String())
}
Expand Down Expand Up @@ -8916,7 +8942,7 @@ func handlePreimageRequest(c *Core, dc *dexConnection, msg *msgjson.Message) err
}

if len(req.Commitment) != order.CommitmentSize {
return fmt.Errorf("received preimage request for %v with no corresponding order submission response.", oid)
return fmt.Errorf("received preimage request for %s with no corresponding order submission response", oid)
}

// See if we recognize that commitment, and if we do, just wait for the
Expand Down Expand Up @@ -8950,7 +8976,8 @@ func handlePreimageRequest(c *Core, dc *dexConnection, msg *msgjson.Message) err
}

func processPreimageRequest(c *Core, dc *dexConnection, reqID uint64, oid order.OrderID, commitChecksum dex.Bytes) error {
tracker, preImg, isCancel := dc.findOrder(oid)
tracker, isCancel := dc.findOrder(oid)
var preImg order.Preimage
if tracker == nil {
var found bool
dc.blindCancelsMtx.Lock()
Expand All @@ -8962,7 +8989,8 @@ func processPreimageRequest(c *Core, dc *dexConnection, reqID uint64, oid order.
} else {
// Record the csum if this preimage request is novel, and deny it if
// this is a duplicate request with an altered csum.
if !acceptCsum(tracker, isCancel, commitChecksum) {
var accept bool
if accept, preImg = acceptCsum(tracker, isCancel, commitChecksum); !accept {
csumErr := errors.New("invalid csum in duplicate preimage request")
resp, err := msgjson.NewResponse(reqID, nil,
msgjson.NewError(msgjson.InvalidRequestError, "%v", csumErr))
Expand Down Expand Up @@ -9005,26 +9033,25 @@ func processPreimageRequest(c *Core, dc *dexConnection, reqID uint64, oid order.
// the server may have used the knowledge of this preimage we are sending them
// now to alter the epoch shuffle. The return value is false if a previous
// checksum has been recorded that differs from the provided one.
func acceptCsum(tracker *trackedTrade, isCancel bool, commitChecksum dex.Bytes) bool {
func acceptCsum(tracker *trackedTrade, isCancel bool, commitChecksum dex.Bytes) (bool, order.Preimage) {
// Do not allow csum to be changed once it has been committed to
// (initialized to something other than `nil`) because it is probably a
// malicious behavior by the server.
tracker.mtx.Lock()
defer tracker.mtx.Unlock()

tracker.csumMtx.Lock()
defer tracker.csumMtx.Unlock()
if isCancel {
if tracker.cancel.csum == nil {
tracker.cancel.csum = commitChecksum
return true
if tracker.cancelCsum == nil {
tracker.cancelCsum = commitChecksum
return true, tracker.cancelPreimg
}
return bytes.Equal(commitChecksum, tracker.cancel.csum)
return bytes.Equal(commitChecksum, tracker.cancelCsum), tracker.cancelPreimg
}
if tracker.csum == nil {
tracker.csum = commitChecksum
return true
return true, tracker.preImg
}

return bytes.Equal(commitChecksum, tracker.csum)
return bytes.Equal(commitChecksum, tracker.csum), tracker.preImg
}

// handleMatchRoute processes the DEX-originating match route request,
Expand Down Expand Up @@ -9103,7 +9130,7 @@ func handleNoMatchRoute(c *Core, dc *dexConnection, msg *msgjson.Message) error
var oid order.OrderID
copy(oid[:], nomatchMsg.OrderID)

tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker == nil {
dc.blindCancelsMtx.Lock()
_, found := dc.blindCancels[oid]
Expand Down Expand Up @@ -9177,7 +9204,7 @@ func handleAuditRoute(c *Core, dc *dexConnection, msg *msgjson.Message) error {
var oid order.OrderID
copy(oid[:], audit.OrderID)

tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker == nil {
return fmt.Errorf("audit request received for unknown order: %s", string(msg.Payload))
}
Expand All @@ -9202,7 +9229,7 @@ func handleRedemptionRoute(c *Core, dc *dexConnection, msg *msgjson.Message) err
var oid order.OrderID
copy(oid[:], redemption.OrderID)

tracker, _, isCancel := dc.findOrder(oid)
tracker, isCancel := dc.findOrder(oid)
if tracker != nil {
if isCancel {
return fmt.Errorf("redemption request received for cancel order %v, match %v (you ok server?)",
Expand Down Expand Up @@ -10253,7 +10280,7 @@ func (c *Core) RemoveWalletPeer(assetID uint32, address string) error {
// id. An error is returned if it cannot be found.
func (c *Core) findActiveOrder(oid order.OrderID) (*trackedTrade, error) {
for _, dc := range c.dexConnections() {
tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker != nil {
return tracker, nil
}
Expand Down Expand Up @@ -10640,7 +10667,7 @@ func (c *Core) handleRetryRedemptionAction(actionB []byte) error {
copy(oid[:], req.OrderID)
var tracker *trackedTrade
for _, dc := range c.dexConnections() {
tracker, _, _ = dc.findOrder(oid)
tracker, _ = dc.findOrder(oid)
if tracker != nil {
break
}
Expand Down Expand Up @@ -10738,25 +10765,39 @@ func (c *Core) checkEpochResolution(host string, mktID string) {
}
currentEpoch := dc.marketEpoch(mktID, time.Now())
lastEpoch := currentEpoch - 1

// Short path if we're already resolved.
dc.epochMtx.RLock()
resolvedEpoch := dc.resolvedEpoch[mktID]
dc.epochMtx.RUnlock()
if lastEpoch == resolvedEpoch {
return
}

ts, inFlights := dc.marketTrades(mktID)
for _, ord := range inFlights {
if ord.Epoch == lastEpoch {
return
}
}
for _, t := range ts {
// Is this order from the last epoch and still not booked or executed?
if t.epochIdx() == lastEpoch && t.status() == order.OrderStatusEpoch {
return
}
if t.cancel != nil && t.cancelEpochIdx() == lastEpoch {
t.mtx.RLock()
matched := t.cancel.matches.taker != nil
t.mtx.RUnlock()
if !matched {
return
}
// Does this order have an in-flight cancel order that is not yet
// resolved?
t.mtx.RLock()
unresolvedCancel := t.cancel != nil && t.cancelEpochIdx() == lastEpoch && t.cancel.matches.taker == nil
t.mtx.RUnlock()
if unresolvedCancel {
return
}
}

// We don't have any unresolved orders or cancel orders from the last epoch.
// Just make sure that not other thread has resolved the epoch and then send
// the notification.
dc.epochMtx.Lock()
sendUpdate := lastEpoch > dc.resolvedEpoch[mktID]
dc.resolvedEpoch[mktID] = lastEpoch
Expand Down
Loading

0 comments on commit 1030097

Please sign in to comment.