Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: protect epoch checksum with its own mutex #2977

Merged
merged 3 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 77 additions & 36 deletions client/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ type dexConnection struct {
// processed by a dex server.
inFlightOrders map[uint64]*InFlightOrder

// A map linking cancel order IDs to trade order IDs.
cancelsMtx sync.RWMutex
cancels map[order.OrderID]order.OrderID

blindCancelsMtx sync.Mutex
blindCancels map[order.OrderID]order.Preimage

Expand Down Expand Up @@ -253,6 +257,25 @@ func (dc *dexConnection) bondAssets() (map[uint32]*BondAsset, uint64) {
return bondAssets, cfg.BondExpiry
}

func (dc *dexConnection) registerCancelLink(cid, oid order.OrderID) {
dc.cancelsMtx.Lock()
dc.cancels[cid] = oid
dc.cancelsMtx.Unlock()
}

func (dc *dexConnection) deleteCancelLink(cid order.OrderID) {
dc.cancelsMtx.Lock()
delete(dc.cancels, cid)
dc.cancelsMtx.Unlock()
}

func (dc *dexConnection) cancelTradeID(cid order.OrderID) (order.OrderID, bool) {
dc.cancelsMtx.RLock()
defer dc.cancelsMtx.RUnlock()
oid, found := dc.cancels[cid]
return oid, found
}

// marketConfig is the market's configuration, as returned by the server in the
// 'config' response.
func (dc *dexConnection) marketConfig(mktID string) *msgjson.Market {
Expand Down Expand Up @@ -577,17 +600,19 @@ func (dc *dexConnection) activeOrders() ([]*Order, []*InFlightOrder) {

// findOrder returns the tracker and preimage for an order ID, and a boolean
// indicating whether this is a cancel order.
func (dc *dexConnection) findOrder(oid order.OrderID) (tracker *trackedTrade, preImg order.Preimage, isCancel bool) {
func (dc *dexConnection) findOrder(oid order.OrderID) (tracker *trackedTrade, isCancel bool) {
dc.tradeMtx.RLock()
defer dc.tradeMtx.RUnlock()
// Try to find the order as a trade.
if tracker, found := dc.trades[oid]; found {
return tracker, tracker.preImg, false
return tracker, false
}
// Search the cancel order IDs.
for _, tracker := range dc.trades {
if tracker.cancel != nil && tracker.cancel.ID() == oid {
return tracker, tracker.cancel.preImg, true

if tid, found := dc.cancelTradeID(oid); found {
if tracker, found := dc.trades[tid]; found {
return tracker, true
} else {
dc.log.Errorf("Did not find trade for cancel order ID %s", oid)
}
}
return
Expand Down Expand Up @@ -645,7 +670,7 @@ func (c *Core) sendCancelOrder(dc *dexConnection, oid order.OrderID, base, quote
// tryCancel will look for an order with the specified order ID, and attempt to
// cancel the order. It is not an error if the order is not found.
func (c *Core) tryCancel(dc *dexConnection, oid order.OrderID) (found bool, err error) {
tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker == nil {
return // false, nil
}
Expand Down Expand Up @@ -771,7 +796,7 @@ func (dc *dexConnection) parseMatches(msgMatches []*msgjson.Match, checkSigs boo
for _, msgMatch := range msgMatches {
var oid order.OrderID
copy(oid[:], msgMatch.OrderID)
tracker, _, isCancel := dc.findOrder(oid)
tracker, isCancel := dc.findOrder(oid)
if tracker == nil {
dc.blindCancelsMtx.Lock()
_, found := dc.blindCancels[oid]
Expand Down Expand Up @@ -4951,7 +4976,7 @@ func (c *Core) Order(oidB dex.Bytes) (*Order, error) {
}
// See if it's an active order first.
for _, dc := range c.dexConnections() {
tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker != nil {
return tracker.coreOrder(), nil
}
Expand Down Expand Up @@ -8087,6 +8112,7 @@ func (c *Core) newDEXConnection(acctInfo *db.AccountInfo, flag connectDEXFlag) (
ticker: newDexTicker(defaultTickInterval), // updated when server config obtained
books: make(map[string]*bookie),
trades: make(map[order.OrderID]*trackedTrade),
cancels: make(map[order.OrderID]order.OrderID),
inFlightOrders: make(map[uint64]*InFlightOrder),
blindCancels: make(map[order.OrderID]order.Preimage),
apiVer: -1,
Expand Down Expand Up @@ -8492,7 +8518,7 @@ func handleRevokeOrderMsg(c *Core, dc *dexConnection, msg *msgjson.Message) erro
var oid order.OrderID
copy(oid[:], revocation.OrderID)

tracker, _, isCancel := dc.findOrder(oid)
tracker, isCancel := dc.findOrder(oid)
if tracker == nil {
return fmt.Errorf("no order found with id %s", oid.String())
}
Expand Down Expand Up @@ -8534,7 +8560,7 @@ func handleRevokeMatchMsg(c *Core, dc *dexConnection, msg *msgjson.Message) erro
var oid order.OrderID
copy(oid[:], revocation.OrderID)

tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker == nil {
return fmt.Errorf("no order found with id %s (not an error if you've completed your side of the swap)", oid.String())
}
Expand Down Expand Up @@ -8899,7 +8925,7 @@ func handlePreimageRequest(c *Core, dc *dexConnection, msg *msgjson.Message) err
}

if len(req.Commitment) != order.CommitmentSize {
return fmt.Errorf("received preimage request for %v with no corresponding order submission response.", oid)
return fmt.Errorf("received preimage request for %s with no corresponding order submission response", oid)
}

// See if we recognize that commitment, and if we do, just wait for the
Expand Down Expand Up @@ -8933,7 +8959,8 @@ func handlePreimageRequest(c *Core, dc *dexConnection, msg *msgjson.Message) err
}

func processPreimageRequest(c *Core, dc *dexConnection, reqID uint64, oid order.OrderID, commitChecksum dex.Bytes) error {
tracker, preImg, isCancel := dc.findOrder(oid)
tracker, isCancel := dc.findOrder(oid)
var preImg order.Preimage
if tracker == nil {
var found bool
dc.blindCancelsMtx.Lock()
Expand All @@ -8945,7 +8972,8 @@ func processPreimageRequest(c *Core, dc *dexConnection, reqID uint64, oid order.
} else {
// Record the csum if this preimage request is novel, and deny it if
// this is a duplicate request with an altered csum.
if !acceptCsum(tracker, isCancel, commitChecksum) {
var accept bool
if accept, preImg = acceptCsum(tracker, isCancel, commitChecksum); !accept {
csumErr := errors.New("invalid csum in duplicate preimage request")
resp, err := msgjson.NewResponse(reqID, nil,
msgjson.NewError(msgjson.InvalidRequestError, "%v", csumErr))
Expand Down Expand Up @@ -8988,26 +9016,25 @@ func processPreimageRequest(c *Core, dc *dexConnection, reqID uint64, oid order.
// the server may have used the knowledge of this preimage we are sending them
// now to alter the epoch shuffle. The return value is false if a previous
// checksum has been recorded that differs from the provided one.
func acceptCsum(tracker *trackedTrade, isCancel bool, commitChecksum dex.Bytes) bool {
func acceptCsum(tracker *trackedTrade, isCancel bool, commitChecksum dex.Bytes) (bool, order.Preimage) {
// Do not allow csum to be changed once it has been committed to
// (initialized to something other than `nil`) because it is probably a
// malicious behavior by the server.
tracker.mtx.Lock()
defer tracker.mtx.Unlock()

tracker.csumMtx.Lock()
defer tracker.csumMtx.Unlock()
if isCancel {
if tracker.cancel.csum == nil {
tracker.cancel.csum = commitChecksum
return true
if tracker.cancelCsum == nil {
tracker.cancelCsum = commitChecksum
return true, tracker.cancelPreimg
}
return bytes.Equal(commitChecksum, tracker.cancel.csum)
return bytes.Equal(commitChecksum, tracker.cancelCsum), tracker.cancelPreimg
}
if tracker.csum == nil {
tracker.csum = commitChecksum
return true
return true, tracker.preImg
}

return bytes.Equal(commitChecksum, tracker.csum)
return bytes.Equal(commitChecksum, tracker.csum), tracker.preImg
}

// handleMatchRoute processes the DEX-originating match route request,
Expand Down Expand Up @@ -9086,7 +9113,7 @@ func handleNoMatchRoute(c *Core, dc *dexConnection, msg *msgjson.Message) error
var oid order.OrderID
copy(oid[:], nomatchMsg.OrderID)

tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker == nil {
dc.blindCancelsMtx.Lock()
_, found := dc.blindCancels[oid]
Expand Down Expand Up @@ -9160,7 +9187,7 @@ func handleAuditRoute(c *Core, dc *dexConnection, msg *msgjson.Message) error {
var oid order.OrderID
copy(oid[:], audit.OrderID)

tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker == nil {
return fmt.Errorf("audit request received for unknown order: %s", string(msg.Payload))
}
Expand All @@ -9185,7 +9212,7 @@ func handleRedemptionRoute(c *Core, dc *dexConnection, msg *msgjson.Message) err
var oid order.OrderID
copy(oid[:], redemption.OrderID)

tracker, _, isCancel := dc.findOrder(oid)
tracker, isCancel := dc.findOrder(oid)
if tracker != nil {
if isCancel {
return fmt.Errorf("redemption request received for cancel order %v, match %v (you ok server?)",
Expand Down Expand Up @@ -10236,7 +10263,7 @@ func (c *Core) RemoveWalletPeer(assetID uint32, address string) error {
// id. An error is returned if it cannot be found.
func (c *Core) findActiveOrder(oid order.OrderID) (*trackedTrade, error) {
for _, dc := range c.dexConnections() {
tracker, _, _ := dc.findOrder(oid)
tracker, _ := dc.findOrder(oid)
if tracker != nil {
return tracker, nil
}
Expand Down Expand Up @@ -10623,7 +10650,7 @@ func (c *Core) handleRetryRedemptionAction(actionB []byte) error {
copy(oid[:], req.OrderID)
var tracker *trackedTrade
for _, dc := range c.dexConnections() {
tracker, _, _ = dc.findOrder(oid)
tracker, _ = dc.findOrder(oid)
if tracker != nil {
break
}
Expand Down Expand Up @@ -10721,25 +10748,39 @@ func (c *Core) checkEpochResolution(host string, mktID string) {
}
currentEpoch := dc.marketEpoch(mktID, time.Now())
lastEpoch := currentEpoch - 1

// Short path if we're already resolved.
dc.epochMtx.RLock()
resolvedEpoch := dc.resolvedEpoch[mktID]
dc.epochMtx.RUnlock()
if lastEpoch == resolvedEpoch {
return
}

ts, inFlights := dc.marketTrades(mktID)
for _, ord := range inFlights {
if ord.Epoch == lastEpoch {
return
}
}
for _, t := range ts {
// Is this order from the last epoch and still not booked or executed?
if t.epochIdx() == lastEpoch && t.status() == order.OrderStatusEpoch {
return
}
if t.cancel != nil && t.cancelEpochIdx() == lastEpoch {
t.mtx.RLock()
matched := t.cancel.matches.taker != nil
t.mtx.RUnlock()
if !matched {
return
}
// Does this order have an in-flight cancel order that is not yet
// resolved?
t.mtx.RLock()
unresolvedCancel := t.cancel != nil && t.cancelEpochIdx() == lastEpoch && t.cancel.matches.taker == nil
t.mtx.RUnlock()
if unresolvedCancel {
return
}
}

// We don't have any unresolved orders or cancel orders from the last epoch.
// Just make sure that not other thread has resolved the epoch and then send
// the notification.
dc.epochMtx.Lock()
sendUpdate := lastEpoch > dc.resolvedEpoch[mktID]
dc.resolvedEpoch[mktID] = lastEpoch
Expand Down
Loading
Loading