Skip to content

Commit

Permalink
ICACallbacks Refactor (#810)
Browse files Browse the repository at this point in the history
  • Loading branch information
sampocs authored Jul 4, 2023
1 parent 4b5d80a commit 4db410e
Show file tree
Hide file tree
Showing 23 changed files with 520 additions and 524 deletions.
40 changes: 14 additions & 26 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,6 @@ func NewStrideApp(
)

stakeibcModule := stakeibcmodule.NewAppModule(appCodec, app.StakeibcKeeper, app.AccountKeeper, app.BankKeeper)
stakeibcIBCModule := stakeibcmodule.NewIBCModule(app.StakeibcKeeper)

app.AutopilotKeeper = *autopilotkeeper.NewKeeper(
appCodec,
Expand Down Expand Up @@ -576,19 +575,13 @@ func NewStrideApp(
epochsModule := epochsmodule.NewAppModule(appCodec, app.EpochsKeeper)

icacallbacksModule := icacallbacksmodule.NewAppModule(appCodec, app.IcacallbacksKeeper, app.AccountKeeper, app.BankKeeper)
icacallbacksIBCModule := icacallbacksmodule.NewIBCModule(app.IcacallbacksKeeper)

// Register ICA calllbacks
// NOTE: The icacallbacks struct implemented below provides a mapping from ICA channel owner to ICACallback handler,
// where the callback handler stores and routes to the various callback functions for a particular module.
// However, as of ibc-go v6, the icacontroller module owns the ICA channel. A consequence of this is that there can
// be no more than one module that implements ICA callbacks. Should we add an new module with ICA support in the future,
// we'll need to refactor this
err = app.IcacallbacksKeeper.SetICACallbackHandler(icacontrollertypes.SubModuleName, app.StakeibcKeeper.ICACallbackHandler())
if err != nil {
return nil
}
err = app.IcacallbacksKeeper.SetICACallbackHandler(ibctransfertypes.ModuleName, app.RecordsKeeper.ICACallbackHandler())
if err != nil {
// Register IBC calllbacks
if err := app.IcacallbacksKeeper.SetICACallbacks(
app.StakeibcKeeper.Callbacks(),
app.RecordsKeeper.Callbacks(),
); err != nil {
return nil
}

Expand All @@ -605,20 +598,23 @@ func NewStrideApp(
app.MsgServiceRouter(),
)
icaModule := ica.NewAppModule(&app.ICAControllerKeeper, &app.ICAHostKeeper)

// Create the middleware stacks
// Stack one (ICAHost Stack) contains:
// - IBC
// - ICAHost
// - base app
icaHostIBCModule := icahost.NewIBCModule(app.ICAHostKeeper)

// Stack two (Stakeibc Stack) contains
// Stack two (ICACallbacks Stack) contains
// - IBC
// - ICA
// - stakeibc
// - ICACallbacks
// - base app
var stakeibcStack porttypes.IBCModule = stakeibcIBCModule
stakeibcStack = icacontroller.NewIBCMiddleware(stakeibcStack, app.ICAControllerKeeper)
var icacallbacksStack porttypes.IBCModule = icacallbacksIBCModule
icacallbacksStack = stakeibcmodule.NewIBCMiddleware(icacallbacksStack, app.StakeibcKeeper)
icacallbacksStack = icacontroller.NewIBCMiddleware(icacallbacksStack, app.ICAControllerKeeper)

// Stack three contains
// - IBC
Expand All @@ -633,20 +629,12 @@ func NewStrideApp(
transferStack = autopilot.NewIBCModule(app.AutopilotKeeper, transferStack)

// Create static IBC router, add transfer route, then set and seal it
// Two routes are included for the ICAController because of the following procedure when registering an ICA
// 1. RegisterInterchainAccount binds the new portId to the icacontroller module and initiates a channel opening
// 2. MsgChanOpenInit is invoked from the IBC message server. The message server identifies that the
// icacontroller module owns the portID and routes to the stakeibc stack (the "icacontroller" route below)
// 3. The stakeibc stack works top-down, first in the ICAController's OnChanOpenInit, and then in stakeibc's OnChanOpenInit
// 4. In stakeibc's OnChanOpenInit, the stakeibc module steals the portId from the icacontroller module
// 5. Now in OnChanOpenAck and any other subsequent IBC callback, the message server will identify
// the portID owner as stakeibc and route to the same stakeibcStack, this time using the "stakeibc" route instead
ibcRouter := porttypes.NewRouter()
ibcRouter.
// ICAHost Stack
AddRoute(icahosttypes.SubModuleName, icaHostIBCModule).
// Stakeibc Stack
AddRoute(icacontrollertypes.SubModuleName, stakeibcStack).
// ICACallbacks Stack
AddRoute(icacontrollertypes.SubModuleName, icacallbacksStack).
// Transfer stack
AddRoute(ibctransfertypes.ModuleName, transferStack)

Expand Down
4 changes: 2 additions & 2 deletions app/upgrades/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ package {upgradeVersion}

import (
sdk "github.com/cosmos/cosmos-sdk/types"
{new-consensus-version} "github.com/Stride-Labs/stride/v9/x/records/migrations/{new-consensus-version}"
{new-consensus-version} "github.com/Stride-Labs/stride/v11/x/records/migrations/{new-consensus-version}"
)

// TODO: Add migration logic to deserialize with old protos and re-serialize with new ones
Expand All @@ -98,7 +98,7 @@ func MigrateStore(ctx sdk.Context) error {
// app/upgrades/{upgradeVersion}/upgrades.go

import (
{module}migration "github.com/Stride-Labs/stride/v9/x/{module}/migrations/{new-consensus-version}"
{module}migration "github.com/Stride-Labs/stride/v11/x/{module}/migrations/{new-consensus-version}"
)

// CreateUpgradeHandler creates an SDK upgrade handler for {upgradeVersion}
Expand Down
173 changes: 173 additions & 0 deletions x/icacallbacks/ibc_module.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package icacallbacks

import (
"fmt"

errorsmod "cosmossdk.io/errors"
sdk "github.com/cosmos/cosmos-sdk/types"
capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types"
channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types"
porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types"
ibcexported "github.com/cosmos/ibc-go/v7/modules/core/exported"

"github.com/Stride-Labs/stride/v11/x/icacallbacks/keeper"
"github.com/Stride-Labs/stride/v11/x/icacallbacks/types"
)

var _ porttypes.IBCModule = &IBCModule{}

type IBCModule struct {
keeper keeper.Keeper
}

func NewIBCModule(k keeper.Keeper) IBCModule {
return IBCModule{
keeper: k,
}
}

// No custom logic is necessary in OnChanOpenInit
func (im IBCModule) OnChanOpenInit(
ctx sdk.Context,
order channeltypes.Order,
connectionHops []string,
portID string,
channelID string,
channelCap *capabilitytypes.Capability,
counterparty channeltypes.Counterparty,
version string,
) (string, error) {
return version, nil
}

// OnChanOpenTry should not be executed in the ICA stack
func (im IBCModule) OnChanOpenTry(
ctx sdk.Context,
order channeltypes.Order,
connectionHops []string,
portID,
channelID string,
chanCap *capabilitytypes.Capability,
counterparty channeltypes.Counterparty,
counterpartyVersion string,
) (string, error) {
panic("UNIMPLEMENTED")
}

// No custom logic is necessary in OnChanOpenAck
func (im IBCModule) OnChanOpenAck(
ctx sdk.Context,
portID,
channelID string,
counterpartyChannelID string,
counterpartyVersion string,
) error {
return nil
}

// OnChanOpenConfirm should not be executed in the ICA stack
func (im IBCModule) OnChanOpenConfirm(
ctx sdk.Context,
portID,
channelID string,
) error {
panic("UNIMPLEMENTED")
}

// OnChanCloseInit should not be executed in the ICA stack
func (im IBCModule) OnChanCloseInit(
ctx sdk.Context,
portID,
channelID string,
) error {
panic("UNIMPLEMENTED")
}

// No custom logic is necessary in OnChanCloseConfirm
func (im IBCModule) OnChanCloseConfirm(
ctx sdk.Context,
portID,
channelID string,
) error {
return nil
}

// OnChanOpenAck routes the packet to the relevant callback function
func (im IBCModule) OnAcknowledgementPacket(
ctx sdk.Context,
modulePacket channeltypes.Packet,
acknowledgement []byte,
relayer sdk.AccAddress,
) error {
im.keeper.Logger(ctx).Info(fmt.Sprintf("OnAcknowledgementPacket (ICACallbacks) - packet: %+v, relayer: %v", modulePacket, relayer))

ackResponse, err := UnpackAcknowledgementResponse(ctx, im.keeper.Logger(ctx), acknowledgement, true)
if err != nil {
errMsg := fmt.Sprintf("Unable to unpack message data from acknowledgement, Sequence %d, from %s %s, to %s %s: %s",
modulePacket.Sequence, modulePacket.SourceChannel, modulePacket.SourcePort, modulePacket.DestinationChannel, modulePacket.DestinationPort, err.Error())
im.keeper.Logger(ctx).Error(errMsg)
return errorsmod.Wrapf(types.ErrInvalidAcknowledgement, errMsg)
}

ackInfo := fmt.Sprintf("sequence #%d, from %s %s, to %s %s",
modulePacket.Sequence, modulePacket.SourceChannel, modulePacket.SourcePort, modulePacket.DestinationChannel, modulePacket.DestinationPort)
im.keeper.Logger(ctx).Info(fmt.Sprintf("Acknowledgement was successfully unmarshalled: ackInfo: %s", ackInfo))

eventType := "ack"
ctx.EventManager().EmitEvent(
sdk.NewEvent(
eventType,
sdk.NewAttribute(sdk.AttributeKeyModule, types.ModuleName),
sdk.NewAttribute(types.AttributeKeyAck, ackInfo),
),
)

if err := im.keeper.CallRegisteredICACallback(ctx, modulePacket, ackResponse); err != nil {
errMsg := fmt.Sprintf("Unable to call registered ICACallback from OnAcknowledgePacket | Sequence %d, from %s %s, to %s %s",
modulePacket.Sequence, modulePacket.SourceChannel, modulePacket.SourcePort, modulePacket.DestinationChannel, modulePacket.DestinationPort)
im.keeper.Logger(ctx).Error(errMsg)
return errorsmod.Wrapf(types.ErrCallbackFailed, errMsg)
}
return nil
}

// OnTimeoutPacket routes the timeout to the relevant callback function
func (im IBCModule) OnTimeoutPacket(
ctx sdk.Context,
packet channeltypes.Packet,
relayer sdk.AccAddress,
) error {
im.keeper.Logger(ctx).Info(fmt.Sprintf("OnTimeoutPacket (ICACallbacks): packet %v, relayer %v", packet, relayer))

ackResponse := types.AcknowledgementResponse{
Status: types.AckResponseStatus_TIMEOUT,
}

if err := im.keeper.CallRegisteredICACallback(ctx, packet, &ackResponse); err != nil {
errMsg := fmt.Sprintf("Unable to call registered ICACallback from OnTimeoutPacket, Packet: %+v", packet)
im.keeper.Logger(ctx).Error(errMsg)
return errorsmod.Wrapf(types.ErrCallbackFailed, errMsg)
}
return nil
}

// OnRecvPacket should not be executed in the ICA stack
func (im IBCModule) OnRecvPacket(
ctx sdk.Context,
modulePacket channeltypes.Packet,
relayer sdk.AccAddress,
) ibcexported.Acknowledgement {
panic("UNIMPLEMENTED")
}

// No custom logic required in NegotiateAppVersion
func (im IBCModule) NegotiateAppVersion(
ctx sdk.Context,
order channeltypes.Order,
connectionID string,
portID string,
counterparty channeltypes.Counterparty,
proposedVersion string,
) (version string, err error) {
return proposedVersion, nil
}
85 changes: 23 additions & 62 deletions x/icacallbacks/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type (
storeKey storetypes.StoreKey
memKey storetypes.StoreKey
paramstore paramtypes.Subspace
icacallbacks map[string]types.ICACallbackHandler
icacallbacks map[string]types.ICACallback
IBCKeeper ibckeeper.Keeper
}
)
Expand All @@ -48,7 +48,7 @@ func NewKeeper(
storeKey: storeKey,
memKey: memKey,
paramstore: ps,
icacallbacks: make(map[string]types.ICACallbackHandler),
icacallbacks: make(map[string]types.ICACallback),
IBCKeeper: ibcKeeper,
}
}
Expand All @@ -57,77 +57,38 @@ func (k Keeper) Logger(ctx sdk.Context) log.Logger {
return ctx.Logger().With("module", fmt.Sprintf("x/%s", types.ModuleName))
}

// Should we add a `AddICACallback`
func (k *Keeper) SetICACallbackHandler(module string, handler types.ICACallbackHandler) error {
_, found := k.icacallbacks[module]
if found {
return fmt.Errorf("callback handler already set for %s", module)
func (k Keeper) SetICACallbacks(moduleCallbacks ...types.ModuleCallbacks) error {
for _, callbacks := range moduleCallbacks {
for _, callback := range callbacks {
if _, found := k.icacallbacks[callback.CallbackId]; found {
return fmt.Errorf("callback for ID %s already registered", callback.CallbackId)
}
k.icacallbacks[callback.CallbackId] = callback
}
}
k.icacallbacks[module] = handler.RegisterICACallbacks()
return nil
}

func (k *Keeper) GetICACallbackHandler(module string) (types.ICACallbackHandler, error) {
callback, found := k.icacallbacks[module]
if !found {
return nil, fmt.Errorf("no callback handler found for %s", module)
}
return callback, nil
}

func (k Keeper) GetCallbackDataFromPacket(ctx sdk.Context, modulePacket channeltypes.Packet, callbackDataKey string) (cbd *types.CallbackData, found bool) {
// get the relevant module from the channel and port
portID := modulePacket.GetSourcePort()
channelID := modulePacket.GetSourceChannel()
// fetch the callback data
func (k Keeper) CallRegisteredICACallback(ctx sdk.Context, packet channeltypes.Packet, ackResponse *types.AcknowledgementResponse) error {
// Get the callback key and associated callback data from the packet
callbackDataKey := types.PacketID(packet.GetSourcePort(), packet.GetSourceChannel(), packet.Sequence)
callbackData, found := k.GetCallbackData(ctx, callbackDataKey)
if !found {
k.Logger(ctx).Info(fmt.Sprintf("callback data not found for portID: %s, channelID: %s, sequence: %d", portID, channelID, modulePacket.Sequence))
return nil, false
} else {
k.Logger(ctx).Info(fmt.Sprintf("callback data found for portID: %s, channelID: %s, sequence: %d", portID, channelID, modulePacket.Sequence))
}
return &callbackData, true
}

func (k Keeper) GetICACallbackHandlerFromPacket(ctx sdk.Context, modulePacket channeltypes.Packet) (*types.ICACallbackHandler, error) {
module, _, err := k.IBCKeeper.ChannelKeeper.LookupModuleByChannel(ctx, modulePacket.GetSourcePort(), modulePacket.GetSourceChannel())
if err != nil {
k.Logger(ctx).Error(fmt.Sprintf("error LookupModuleByChannel for portID: %s, channelID: %s, sequence: %d", modulePacket.GetSourcePort(), modulePacket.GetSourceChannel(), modulePacket.Sequence))
return nil, err
}
// fetch the callback function
callbackHandler, err := k.GetICACallbackHandler(module)
if err != nil {
return nil, errorsmod.Wrapf(types.ErrCallbackHandlerNotFound, "Callback handler does not exist for module %s | err: %s", module, err.Error())
k.Logger(ctx).Info(fmt.Sprintf("callback data not found for portID: %s, channelID: %s, sequence: %d",
packet.SourcePort, packet.SourceChannel, packet.Sequence))
return nil
}
return &callbackHandler, nil
}

func (k Keeper) CallRegisteredICACallback(ctx sdk.Context, modulePacket channeltypes.Packet, ackResponse *types.AcknowledgementResponse) error {
callbackDataKey := types.PacketID(modulePacket.GetSourcePort(), modulePacket.GetSourceChannel(), modulePacket.Sequence)
callbackData, found := k.GetCallbackDataFromPacket(ctx, modulePacket, callbackDataKey)
// If there's an associated callback function, execute it
callback, found := k.icacallbacks[callbackData.CallbackId]
if !found {
k.Logger(ctx).Info(fmt.Sprintf("No associated callback with callback data %v", callbackData))
return nil
}
callbackHandler, err := k.GetICACallbackHandlerFromPacket(ctx, modulePacket)
if err != nil {
k.Logger(ctx).Error(fmt.Sprintf("GetICACallbackHandlerFromPacket %s", err.Error()))
return err
}

// call the callback
if (*callbackHandler).HasICACallback(callbackData.CallbackId) {
k.Logger(ctx).Info(fmt.Sprintf("Calling callback for %s", callbackData.CallbackId))
// if acknowledgement is empty, then it is a timeout
err := (*callbackHandler).CallICACallback(ctx, callbackData.CallbackId, modulePacket, ackResponse, callbackData.CallbackArgs)
if err != nil {
errMsg := fmt.Sprintf("Error occured while calling ICACallback (%s) | err: %s", callbackData.CallbackId, err.Error())
k.Logger(ctx).Error(errMsg)
return errorsmod.Wrapf(types.ErrCallbackFailed, errMsg)
}
} else {
k.Logger(ctx).Error(fmt.Sprintf("Callback %v has no associated callback", callbackData))
if err := callback.CallbackFunc(ctx, packet, ackResponse, callbackData.CallbackArgs); err != nil {
errMsg := fmt.Sprintf("Error occured while calling ICACallback (%s) | err: %s", callbackData.CallbackId, err.Error())
k.Logger(ctx).Error(errMsg)
return errorsmod.Wrapf(types.ErrCallbackFailed, errMsg)
}

// remove the callback data
Expand Down
Loading

0 comments on commit 4db410e

Please sign in to comment.