diff --git a/x/ibc/04-channel/keeper/handshake_test.go b/x/ibc/04-channel/keeper/handshake_test.go new file mode 100644 index 000000000000..5a8ce10ceb4a --- /dev/null +++ b/x/ibc/04-channel/keeper/handshake_test.go @@ -0,0 +1,329 @@ +package keeper_test + +import ( + "fmt" + + "github.com/cosmos/cosmos-sdk/store/prefix" + sdk "github.com/cosmos/cosmos-sdk/types" + clienttypestm "github.com/cosmos/cosmos-sdk/x/ibc/02-client/types/tendermint" + connection "github.com/cosmos/cosmos-sdk/x/ibc/03-connection" + "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" + commitment "github.com/cosmos/cosmos-sdk/x/ibc/23-commitment" + ibctypes "github.com/cosmos/cosmos-sdk/x/ibc/types" + abci "github.com/tendermint/tendermint/abci/types" +) + +func (suite *KeeperTestSuite) createClient() { + suite.app.Commit() + commitID := suite.app.LastCommitID() + + suite.app.BeginBlock(abci.RequestBeginBlock{Header: abci.Header{Height: suite.app.LastBlockHeight() + 1}}) + suite.ctx = suite.app.BaseApp.NewContext(false, abci.Header{}) + + consensusState := clienttypestm.ConsensusState{ + ChainID: testChainID, + Height: uint64(commitID.Version), + Root: commitment.NewRoot(commitID.Hash), + } + + _, err := suite.app.IBCKeeper.ClientKeeper.CreateClient(suite.ctx, testClient, testClientType, consensusState) + suite.NoError(err) +} + +func (suite *KeeperTestSuite) updateClient() { + // always commit and begin a new block on updateClient + suite.app.Commit() + commitID := suite.app.LastCommitID() + + suite.app.BeginBlock(abci.RequestBeginBlock{Header: abci.Header{Height: suite.app.LastBlockHeight() + 1}}) + suite.ctx = suite.app.BaseApp.NewContext(false, abci.Header{}) + + state := clienttypestm.ConsensusState{ + ChainID: testChainID, + Height: uint64(commitID.Version), + Root: commitment.NewRoot(commitID.Hash), + } + + suite.app.IBCKeeper.ClientKeeper.SetConsensusState(suite.ctx, testClient, state) + suite.app.IBCKeeper.ClientKeeper.SetVerifiedRoot(suite.ctx, testClient, state.GetHeight(), state.GetRoot()) +} + +func (suite *KeeperTestSuite) createConnection(state connection.State) { + connection := connection.ConnectionEnd{ + State: state, + ClientID: testClient, + Counterparty: connection.Counterparty{ + ClientID: testClient, + ConnectionID: testConnection, + Prefix: suite.app.IBCKeeper.ConnectionKeeper.GetCommitmentPrefix(), + }, + Versions: connection.GetCompatibleVersions(), + } + + suite.app.IBCKeeper.ConnectionKeeper.SetConnection(suite.ctx, testConnection, connection) +} + +func (suite *KeeperTestSuite) createChannel(portID string, chanID string, connID string, counterpartyPort string, counterpartyChan string, state types.State) { + channel := types.Channel{ + State: state, + Ordering: testChannelOrder, + Counterparty: types.Counterparty{ + PortID: counterpartyPort, + ChannelID: counterpartyChan, + }, + ConnectionHops: []string{connID}, + Version: testChannelVersion, + } + + suite.app.IBCKeeper.ChannelKeeper.SetChannel(suite.ctx, portID, chanID, channel) +} + +func (suite *KeeperTestSuite) deleteChannel(portID string, chanID string) { + store := prefix.NewStore(suite.ctx.KVStore(suite.app.GetKey(ibctypes.StoreKey)), []byte{}) + store.Delete(types.KeyChannel(portID, chanID)) +} + +func (suite *KeeperTestSuite) bindPort(portID string) sdk.CapabilityKey { + return suite.app.IBCKeeper.PortKeeper.BindPort(portID) +} + +func (suite *KeeperTestSuite) queryProof(key string) (proof commitment.Proof, height int64) { + res := suite.app.Query(abci.RequestQuery{ + Path: fmt.Sprintf("store/%s/key", ibctypes.StoreKey), + Data: []byte(key), + Prove: true, + }) + + height = res.Height + proof = commitment.Proof{ + Proof: res.Proof, + } + + return +} + +func (suite *KeeperTestSuite) TestChanOpenInit() { + counterparty := types.NewCounterparty(testPort2, testChannel2) + + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.INIT) + err := suite.app.IBCKeeper.ChannelKeeper.ChanOpenInit(suite.ctx, testChannelOrder, []string{testConnection}, testPort1, testChannel1, counterparty, testChannelVersion) + suite.NotNil(err) // channel has already exist + + suite.deleteChannel(testPort1, testChannel1) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenInit(suite.ctx, testChannelOrder, []string{testConnection}, testPort1, testChannel1, counterparty, testChannelVersion) + suite.NotNil(err) // connection does not exist + + suite.createConnection(connection.NONE) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenInit(suite.ctx, testChannelOrder, []string{testConnection}, testPort1, testChannel1, counterparty, testChannelVersion) + suite.NotNil(err) // invalid connection state + + suite.createConnection(connection.OPEN) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenInit(suite.ctx, testChannelOrder, []string{testConnection}, testPort1, testChannel1, counterparty, testChannelVersion) + suite.Nil(err) // successfully executed + + channel, found := suite.app.IBCKeeper.ChannelKeeper.GetChannel(suite.ctx, testPort1, testChannel1) + suite.True(found) + suite.Equal(types.INIT, channel.State) +} + +func (suite *KeeperTestSuite) TestChanOpenTry() { + counterparty := types.NewCounterparty(testPort1, testChannel1) + suite.bindPort(testPort2) + channelKey := types.ChannelPath(testPort1, testChannel1) + + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.INIT) + suite.createChannel(testPort2, testChannel2, testConnection, testPort1, testChannel1, types.INIT) + suite.updateClient() + proofInit, proofHeight := suite.queryProof(channelKey) + err := suite.app.IBCKeeper.ChannelKeeper.ChanOpenTry(suite.ctx, testChannelOrder, []string{testConnection}, testPort2, testChannel2, counterparty, testChannelVersion, testChannelVersion, proofInit, uint64(proofHeight)) + suite.NotNil(err) // channel has already exist + + suite.deleteChannel(testPort2, testChannel2) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenTry(suite.ctx, testChannelOrder, []string{testConnection}, testPort1, testChannel2, counterparty, testChannelVersion, testChannelVersion, proofInit, uint64(proofHeight)) + suite.NotNil(err) // unauthenticated port + + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenTry(suite.ctx, testChannelOrder, []string{testConnection}, testPort2, testChannel2, counterparty, testChannelVersion, testChannelVersion, proofInit, uint64(proofHeight)) + suite.NotNil(err) // connection does not exist + + suite.createConnection(connection.NONE) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenTry(suite.ctx, testChannelOrder, []string{testConnection}, testPort2, testChannel2, counterparty, testChannelVersion, testChannelVersion, proofInit, uint64(proofHeight)) + suite.NotNil(err) // invalid connection state + + suite.createConnection(connection.OPEN) + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.OPENTRY) + suite.updateClient() + proofInit, proofHeight = suite.queryProof(channelKey) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenTry(suite.ctx, testChannelOrder, []string{testConnection}, testPort2, testChannel2, counterparty, testChannelVersion, testChannelVersion, proofInit, uint64(proofHeight)) + suite.NotNil(err) // channel membership verification failed due to invalid counterparty + + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.INIT) + suite.updateClient() + proofInit, proofHeight = suite.queryProof(channelKey) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenTry(suite.ctx, testChannelOrder, []string{testConnection}, testPort2, testChannel2, counterparty, testChannelVersion, testChannelVersion, proofInit, uint64(proofHeight)) + suite.Nil(err) // successfully executed + + channel, found := suite.app.IBCKeeper.ChannelKeeper.GetChannel(suite.ctx, testPort2, testChannel2) + suite.True(found) + suite.Equal(types.OPENTRY, channel.State) +} + +func (suite *KeeperTestSuite) TestChanOpenAck() { + suite.bindPort(testPort1) + channelKey := types.ChannelPath(testPort2, testChannel2) + + suite.createChannel(testPort2, testChannel2, testConnection, testPort1, testChannel1, types.OPENTRY) + suite.updateClient() + proofTry, proofHeight := suite.queryProof(channelKey) + err := suite.app.IBCKeeper.ChannelKeeper.ChanOpenAck(suite.ctx, testPort1, testChannel1, testChannelVersion, proofTry, uint64(proofHeight)) + suite.NotNil(err) // channel does not exist + + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.CLOSED) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenAck(suite.ctx, testPort1, testChannel1, testChannelVersion, proofTry, uint64(proofHeight)) + suite.NotNil(err) // invalid channel state + + suite.createChannel(testPort2, testChannel1, testConnection, testPort1, testChannel2, types.INIT) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenAck(suite.ctx, testPort2, testChannel1, testChannelVersion, proofTry, uint64(proofHeight)) + suite.NotNil(err) // unauthenticated port + + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.INIT) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenAck(suite.ctx, testPort1, testChannel1, testChannelVersion, proofTry, uint64(proofHeight)) + suite.NotNil(err) // connection does not exist + + suite.createConnection(connection.NONE) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenAck(suite.ctx, testPort1, testChannel1, testChannelVersion, proofTry, uint64(proofHeight)) + suite.NotNil(err) // invalid connection state + + suite.createConnection(connection.OPEN) + suite.createChannel(testPort2, testChannel2, testConnection, testPort1, testChannel1, types.OPEN) + suite.updateClient() + proofTry, proofHeight = suite.queryProof(channelKey) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenAck(suite.ctx, testPort1, testChannel1, testChannelVersion, proofTry, uint64(proofHeight)) + suite.NotNil(err) // channel membership verification failed due to invalid counterparty + + suite.createChannel(testPort2, testChannel2, testConnection, testPort1, testChannel1, types.OPENTRY) + suite.updateClient() + proofTry, proofHeight = suite.queryProof(channelKey) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenAck(suite.ctx, testPort1, testChannel1, testChannelVersion, proofTry, uint64(proofHeight)) + suite.Nil(err) // successfully executed + + channel, found := suite.app.IBCKeeper.ChannelKeeper.GetChannel(suite.ctx, testPort1, testChannel1) + suite.True(found) + suite.Equal(types.OPEN, channel.State) +} + +func (suite *KeeperTestSuite) TestChanOpenConfirm() { + suite.bindPort(testPort2) + channelKey := types.ChannelPath(testPort1, testChannel1) + + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.OPEN) + suite.updateClient() + proofAck, proofHeight := suite.queryProof(channelKey) + err := suite.app.IBCKeeper.ChannelKeeper.ChanOpenConfirm(suite.ctx, testPort2, testChannel2, proofAck, uint64(proofHeight)) + suite.NotNil(err) // channel does not exist + + suite.createChannel(testPort2, testChannel2, testConnection, testPort1, testChannel1, types.OPEN) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenConfirm(suite.ctx, testPort2, testChannel2, proofAck, uint64(proofHeight)) + suite.NotNil(err) // invalid channel state + + suite.createChannel(testPort1, testChannel2, testConnection, testPort2, testChannel1, types.OPENTRY) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenConfirm(suite.ctx, testPort1, testChannel2, proofAck, uint64(proofHeight)) + suite.NotNil(err) // unauthenticated port + + suite.createChannel(testPort2, testChannel2, testConnection, testPort1, testChannel1, types.OPENTRY) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenConfirm(suite.ctx, testPort2, testChannel2, proofAck, uint64(proofHeight)) + suite.NotNil(err) // connection does not exist + + suite.createConnection(connection.NONE) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenConfirm(suite.ctx, testPort2, testChannel2, proofAck, uint64(proofHeight)) + suite.NotNil(err) // invalid connection state + + suite.createConnection(connection.OPEN) + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.OPENTRY) + suite.updateClient() + proofAck, proofHeight = suite.queryProof(channelKey) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenConfirm(suite.ctx, testPort2, testChannel2, proofAck, uint64(proofHeight)) + suite.NotNil(err) // channel membership verification failed due to invalid counterparty + + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.OPEN) + suite.updateClient() + proofAck, proofHeight = suite.queryProof(channelKey) + err = suite.app.IBCKeeper.ChannelKeeper.ChanOpenConfirm(suite.ctx, testPort2, testChannel2, proofAck, uint64(proofHeight)) + suite.Nil(err) // successfully executed + + channel, found := suite.app.IBCKeeper.ChannelKeeper.GetChannel(suite.ctx, testPort2, testChannel2) + suite.True(found) + suite.Equal(types.OPEN, channel.State) +} + +func (suite *KeeperTestSuite) TestChanCloseInit() { + suite.bindPort(testPort1) + + err := suite.app.IBCKeeper.ChannelKeeper.ChanCloseInit(suite.ctx, testPort2, testChannel1) + suite.NotNil(err) // authenticated port + + err = suite.app.IBCKeeper.ChannelKeeper.ChanCloseInit(suite.ctx, testPort1, testChannel1) + suite.NotNil(err) // channel does not exist + + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.CLOSED) + err = suite.app.IBCKeeper.ChannelKeeper.ChanCloseInit(suite.ctx, testPort1, testChannel1) + suite.NotNil(err) // channel is already closed + + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.OPEN) + err = suite.app.IBCKeeper.ChannelKeeper.ChanCloseInit(suite.ctx, testPort1, testChannel1) + suite.NotNil(err) // connection does not exist + + suite.createConnection(connection.TRYOPEN) + err = suite.app.IBCKeeper.ChannelKeeper.ChanCloseInit(suite.ctx, testPort1, testChannel1) + suite.NotNil(err) // invalid connection state + + suite.createConnection(connection.OPEN) + err = suite.app.IBCKeeper.ChannelKeeper.ChanCloseInit(suite.ctx, testPort1, testChannel1) + suite.Nil(err) // successfully executed + + channel, found := suite.app.IBCKeeper.ChannelKeeper.GetChannel(suite.ctx, testPort1, testChannel1) + suite.True(found) + suite.Equal(types.CLOSED, channel.State) +} + +func (suite *KeeperTestSuite) TestChanCloseConfirm() { + suite.bindPort(testPort2) + channelKey := types.ChannelPath(testPort1, testChannel1) + + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.CLOSED) + suite.updateClient() + proofInit, proofHeight := suite.queryProof(channelKey) + err := suite.app.IBCKeeper.ChannelKeeper.ChanCloseConfirm(suite.ctx, testPort1, testChannel2, proofInit, uint64(proofHeight)) + suite.NotNil(err) // unauthenticated port + + err = suite.app.IBCKeeper.ChannelKeeper.ChanCloseConfirm(suite.ctx, testPort2, testChannel2, proofInit, uint64(proofHeight)) + suite.NotNil(err) // channel does not exist + + suite.createChannel(testPort2, testChannel2, testConnection, testPort1, testChannel1, types.CLOSED) + err = suite.app.IBCKeeper.ChannelKeeper.ChanCloseConfirm(suite.ctx, testPort2, testChannel2, proofInit, uint64(proofHeight)) + suite.NotNil(err) // channel is already closed + + suite.createChannel(testPort2, testChannel2, testConnection, testPort1, testChannel1, types.OPEN) + err = suite.app.IBCKeeper.ChannelKeeper.ChanCloseConfirm(suite.ctx, testPort2, testChannel2, proofInit, uint64(proofHeight)) + suite.NotNil(err) // connection does not exist + + suite.createConnection(connection.TRYOPEN) + err = suite.app.IBCKeeper.ChannelKeeper.ChanCloseConfirm(suite.ctx, testPort2, testChannel2, proofInit, uint64(proofHeight)) + suite.NotNil(err) // invalid connection state + + suite.createConnection(connection.OPEN) + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.OPEN) + suite.updateClient() + proofInit, proofHeight = suite.queryProof(channelKey) + err = suite.app.IBCKeeper.ChannelKeeper.ChanCloseConfirm(suite.ctx, testPort2, testChannel2, proofInit, uint64(proofHeight)) + suite.NotNil(err) // channel membership verification failed due to invalid counterparty + + suite.createChannel(testPort1, testChannel1, testConnection, testPort2, testChannel2, types.CLOSED) + suite.updateClient() + proofInit, proofHeight = suite.queryProof(channelKey) + err = suite.app.IBCKeeper.ChannelKeeper.ChanCloseConfirm(suite.ctx, testPort2, testChannel2, proofInit, uint64(proofHeight)) + suite.Nil(err) // successfully executed + + channel, found := suite.app.IBCKeeper.ChannelKeeper.GetChannel(suite.ctx, testPort2, testChannel2) + suite.True(found) + suite.Equal(types.CLOSED, channel.State) +} diff --git a/x/ibc/04-channel/keeper/keeper.go b/x/ibc/04-channel/keeper/keeper.go index 6169c72ffc3a..dc6bbd3fb021 100644 --- a/x/ibc/04-channel/keeper/keeper.go +++ b/x/ibc/04-channel/keeper/keeper.go @@ -144,3 +144,13 @@ func (k Keeper) SetPacketAcknowledgement(ctx sdk.Context, portID, channelID stri store := prefix.NewStore(ctx.KVStore(k.storeKey), k.prefix) store.Set(types.KeyPacketAcknowledgement(portID, channelID, sequence), ackHash) } + +// GetPacketAcknowledgement gets the packet ack hash from the store +func (k Keeper) GetPacketAcknowledgement(ctx sdk.Context, portID, channelID string, sequence uint64) ([]byte, bool) { + store := prefix.NewStore(ctx.KVStore(k.storeKey), k.prefix) + bz := store.Get(types.KeyPacketAcknowledgement(portID, channelID, sequence)) + if bz == nil { + return nil, false + } + return bz, true +} diff --git a/x/ibc/04-channel/keeper/keeper_test.go b/x/ibc/04-channel/keeper/keeper_test.go new file mode 100644 index 000000000000..f05cc56db841 --- /dev/null +++ b/x/ibc/04-channel/keeper/keeper_test.go @@ -0,0 +1,132 @@ +package keeper_test + +import ( + "testing" + + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/simapp" + sdk "github.com/cosmos/cosmos-sdk/types" + clientexported "github.com/cosmos/cosmos-sdk/x/ibc/02-client/exported" + "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" + "github.com/stretchr/testify/suite" + abci "github.com/tendermint/tendermint/abci/types" +) + +// define constants used for testing +const ( + testChainID = "test-chain-id" + testClient = "test-client" + testClientType = clientexported.Tendermint + + testConnection = "testconnection" + testPort1 = "firstport" + testPort2 = "secondport" + testChannel1 = "firstchannel" + testChannel2 = "secondchannel" + + testChannelOrder = types.ORDERED + testChannelVersion = "1.0" +) + +type KeeperTestSuite struct { + suite.Suite + + cdc *codec.Codec + ctx sdk.Context + app *simapp.SimApp +} + +func (suite *KeeperTestSuite) SetupTest() { + isCheckTx := false + app := simapp.Setup(isCheckTx) + + suite.cdc = app.Codec() + suite.ctx = app.BaseApp.NewContext(isCheckTx, abci.Header{}) + suite.app = app + + suite.createClient() +} + +func (suite *KeeperTestSuite) TestSetChannel() { + _, found := suite.app.IBCKeeper.ChannelKeeper.GetChannel(suite.ctx, testPort1, testChannel1) + suite.False(found) + + channel := types.Channel{ + State: types.OPEN, + Ordering: testChannelOrder, + Counterparty: types.Counterparty{ + PortID: testPort1, + ChannelID: testChannel1, + }, + ConnectionHops: []string{testConnection}, + Version: testChannelVersion, + } + suite.app.IBCKeeper.ChannelKeeper.SetChannel(suite.ctx, testPort1, testChannel1, channel) + + storedChannel, found := suite.app.IBCKeeper.ChannelKeeper.GetChannel(suite.ctx, testPort1, testChannel1) + suite.True(found) + suite.Equal(channel, storedChannel) +} + +func (suite *KeeperTestSuite) TestSetChannelCapability() { + _, found := suite.app.IBCKeeper.ChannelKeeper.GetChannelCapability(suite.ctx, testPort1, testChannel1) + suite.False(found) + + channelCap := "test-channel-capability" + suite.app.IBCKeeper.ChannelKeeper.SetChannelCapability(suite.ctx, testPort1, testChannel1, channelCap) + + storedChannelCap, found := suite.app.IBCKeeper.ChannelKeeper.GetChannelCapability(suite.ctx, testPort1, testChannel1) + suite.True(found) + suite.Equal(channelCap, storedChannelCap) +} + +func (suite *KeeperTestSuite) TestSetSequence() { + _, found := suite.app.IBCKeeper.ChannelKeeper.GetNextSequenceSend(suite.ctx, testPort1, testChannel1) + suite.False(found) + + _, found = suite.app.IBCKeeper.ChannelKeeper.GetNextSequenceRecv(suite.ctx, testPort1, testChannel1) + suite.False(found) + + nextSeqSend, nextSeqRecv := uint64(10), uint64(10) + suite.app.IBCKeeper.ChannelKeeper.SetNextSequenceSend(suite.ctx, testPort1, testChannel1, nextSeqSend) + suite.app.IBCKeeper.ChannelKeeper.SetNextSequenceRecv(suite.ctx, testPort1, testChannel1, nextSeqRecv) + + storedNextSeqSend, found := suite.app.IBCKeeper.ChannelKeeper.GetNextSequenceSend(suite.ctx, testPort1, testChannel1) + suite.True(found) + suite.Equal(nextSeqSend, storedNextSeqSend) + + storedNextSeqRecv, found := suite.app.IBCKeeper.ChannelKeeper.GetNextSequenceSend(suite.ctx, testPort1, testChannel1) + suite.True(found) + suite.Equal(nextSeqRecv, storedNextSeqRecv) +} + +func (suite *KeeperTestSuite) TestPackageCommitment() { + seq := uint64(10) + storedCommitment := suite.app.IBCKeeper.ChannelKeeper.GetPacketCommitment(suite.ctx, testPort1, testChannel1, seq) + suite.Equal([]byte(nil), storedCommitment) + + commitment := []byte("commitment") + suite.app.IBCKeeper.ChannelKeeper.SetPacketCommitment(suite.ctx, testPort1, testChannel1, seq, commitment) + + storedCommitment = suite.app.IBCKeeper.ChannelKeeper.GetPacketCommitment(suite.ctx, testPort1, testChannel1, seq) + suite.Equal(commitment, storedCommitment) +} + +func (suite *KeeperTestSuite) TestSetPacketAcknowledgement() { + seq := uint64(10) + + storedAckHash, found := suite.app.IBCKeeper.ChannelKeeper.GetPacketAcknowledgement(suite.ctx, testPort1, testChannel1, seq) + suite.False(found) + suite.Nil(storedAckHash) + + ackHash := []byte("ackhash") + suite.app.IBCKeeper.ChannelKeeper.SetPacketAcknowledgement(suite.ctx, testPort1, testChannel1, seq, ackHash) + + storedAckHash, found = suite.app.IBCKeeper.ChannelKeeper.GetPacketAcknowledgement(suite.ctx, testPort1, testChannel1, seq) + suite.True(found) + suite.Equal(ackHash, storedAckHash) +} + +func TestKeeperTestSuite(t *testing.T) { + suite.Run(t, new(KeeperTestSuite)) +} diff --git a/x/ibc/04-channel/types/msgs_test.go b/x/ibc/04-channel/types/msgs_test.go new file mode 100644 index 000000000000..c57a35af74ad --- /dev/null +++ b/x/ibc/04-channel/types/msgs_test.go @@ -0,0 +1,304 @@ +package types + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + commitment "github.com/cosmos/cosmos-sdk/x/ibc/23-commitment" + "github.com/stretchr/testify/require" +) + +// define constants used for testing +const ( + invalidPort = "invalidport1" + invalidShortPort = "p" + invalidLongPort = "invalidlongportinvalidlongport" + + invalidChannel = "invalidchannel1" + invalidShortChannel = "invalidch" + invalidLongChannel = "invalidlongchannelinvalidlongchannel" + + invalidConnection = "invalidconnection1" + invalidShortConnection = "invalidcn" + invalidLongConnection = "invalidlongconnection" +) + +// define variables used for testing +var ( + connHops = []string{"testconnection"} + invalidConnHops = []string{"testconnection", "testconnection"} + invalidShortConnHops = []string{invalidShortConnection} + invalidLongConnHops = []string{invalidLongConnection} + + proof = commitment.Proof{} + + addr = sdk.AccAddress("testaddr") +) + +// TestMsgChannelOpenInit tests ValidateBasic for MsgChannelOpenInit +func TestMsgChannelOpenInit(t *testing.T) { + testMsgs := []MsgChannelOpenInit{ + NewMsgChannelOpenInit("testport", "testchannel", "1.0", ORDERED, connHops, "testcpport", "testcpchannel", addr), // valid msg + NewMsgChannelOpenInit(invalidShortPort, "testchannel", "1.0", ORDERED, connHops, "testcpport", "testcpchannel", addr), // too short port id + NewMsgChannelOpenInit(invalidLongPort, "testchannel", "1.0", ORDERED, connHops, "testcpport", "testcpchannel", addr), // too long port id + NewMsgChannelOpenInit(invalidPort, "testchannel", "1.0", ORDERED, connHops, "testcpport", "testcpchannel", addr), // port id contains non-alpha + NewMsgChannelOpenInit("testport", invalidShortChannel, "1.0", ORDERED, connHops, "testcpport", "testcpchannel", addr), // too short channel id + NewMsgChannelOpenInit("testport", invalidLongChannel, "1.0", ORDERED, connHops, "testcpport", "testcpchannel", addr), // too long channel id + NewMsgChannelOpenInit("testport", invalidChannel, "1.0", ORDERED, connHops, "testcpport", "testcpchannel", addr), // channel id contains non-alpha + NewMsgChannelOpenInit("testport", "testchannel", "1.0", Order(3), connHops, "testcpport", "testcpchannel", addr), // invalid channel order + NewMsgChannelOpenInit("testport", "testchannel", "1.0", ORDERED, invalidConnHops, "testcpport", "testcpchannel", addr), // connection hops more than 1 + NewMsgChannelOpenInit("testport", "testchannel", "1.0", UNORDERED, invalidShortConnHops, "testcpport", "testcpchannel", addr), // too short connection id + NewMsgChannelOpenInit("testport", "testchannel", "1.0", UNORDERED, invalidLongConnHops, "testcpport", "testcpchannel", addr), // too long connection id + NewMsgChannelOpenInit("testport", "testchannel", "1.0", UNORDERED, []string{invalidConnection}, "testcpport", "testcpchannel", addr), // connection id contains non-alpha + NewMsgChannelOpenInit("testport", "testchannel", "", UNORDERED, connHops, "testcpport", "testcpchannel", addr), // empty channel version + NewMsgChannelOpenInit("testport", "testchannel", "1.0", UNORDERED, connHops, invalidPort, "testcpchannel", addr), // invalid counterparty port id + NewMsgChannelOpenInit("testport", "testchannel", "1.0", UNORDERED, connHops, "testcpport", invalidChannel, addr), // invalid counterparty channel id + } + + testCases := []struct { + msg MsgChannelOpenInit + expPass bool + errMsg string + }{ + {testMsgs[0], true, ""}, + {testMsgs[1], false, "too short port id"}, + {testMsgs[2], false, "too long port id"}, + {testMsgs[3], false, "port id contains non-alpha"}, + {testMsgs[4], false, "too short channel id"}, + {testMsgs[5], false, "too long channel id"}, + {testMsgs[6], false, "channel id contains non-alpha"}, + {testMsgs[7], false, "invalid channel order"}, + {testMsgs[8], false, "connection hops more than 1 "}, + {testMsgs[9], false, "too short connection id"}, + {testMsgs[10], false, "too long connection id"}, + {testMsgs[11], false, "connection id contains non-alpha"}, + {testMsgs[12], false, "empty channel version"}, + {testMsgs[13], false, "invalid counterparty port id"}, + {testMsgs[14], false, "invalid counterparty channel id"}, + } + + for i, tc := range testCases { + err := tc.msg.ValidateBasic() + if tc.expPass { + require.Nil(t, err, "Msg %d failed: %v", i, err) + } else { + require.NotNil(t, err, "Invalid Msg %d passed: %s", i, tc.errMsg) + } + } +} + +// TestMsgChannelOpenTry tests ValidateBasic for MsgChannelOpenTry +func TestMsgChannelOpenTry(t *testing.T) { + testMsgs := []MsgChannelOpenTry{ + NewMsgChannelOpenTry("testport", "testchannel", "1.0", ORDERED, connHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // valid msg + NewMsgChannelOpenTry(invalidShortPort, "testchannel", "1.0", ORDERED, connHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // too short port id + NewMsgChannelOpenTry(invalidLongPort, "testchannel", "1.0", ORDERED, connHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // too long port id + NewMsgChannelOpenTry(invalidPort, "testchannel", "1.0", ORDERED, connHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // port id contains non-alpha + NewMsgChannelOpenTry("testport", invalidShortChannel, "1.0", ORDERED, connHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // too short channel id + NewMsgChannelOpenTry("testport", invalidLongChannel, "1.0", ORDERED, connHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // too long channel id + NewMsgChannelOpenTry("testport", invalidChannel, "1.0", ORDERED, connHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // channel id contains non-alpha + NewMsgChannelOpenTry("testport", "testchannel", "1.0", ORDERED, connHops, "testcpport", "testcpchannel", "", proof, 1, addr), // empty counterparty version + NewMsgChannelOpenTry("testport", "testchannel", "1.0", ORDERED, connHops, "testcpport", "testcpchannel", "1.0", nil, 1, addr), // empty proof + NewMsgChannelOpenTry("testport", "testchannel", "1.0", ORDERED, connHops, "testcpport", "testcpchannel", "1.0", proof, 0, addr), // proof height is zero + NewMsgChannelOpenTry("testport", "testchannel", "1.0", Order(4), connHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // invalid channel order + NewMsgChannelOpenTry("testport", "testchannel", "1.0", UNORDERED, invalidConnHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // connection hops more than 1 + NewMsgChannelOpenTry("testport", "testchannel", "1.0", UNORDERED, invalidShortConnHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // too short connection id + NewMsgChannelOpenTry("testport", "testchannel", "1.0", UNORDERED, invalidLongConnHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // too long connection id + NewMsgChannelOpenTry("testport", "testchannel", "1.0", UNORDERED, []string{invalidConnection}, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // connection id contains non-alpha + NewMsgChannelOpenTry("testport", "testchannel", "", UNORDERED, connHops, "testcpport", "testcpchannel", "1.0", proof, 1, addr), // empty channel version + NewMsgChannelOpenTry("testport", "testchannel", "1.0", UNORDERED, connHops, invalidPort, "testcpchannel", "1.0", proof, 1, addr), // invalid counterparty port id + NewMsgChannelOpenTry("testport", "testchannel", "1.0", UNORDERED, connHops, "testcpport", invalidChannel, "1.0", proof, 1, addr), // invalid counterparty channel id + } + + testCases := []struct { + msg MsgChannelOpenTry + expPass bool + errMsg string + }{ + {testMsgs[0], true, ""}, + {testMsgs[1], false, "too short port id"}, + {testMsgs[2], false, "too long port id"}, + {testMsgs[3], false, "port id contains non-alpha"}, + {testMsgs[4], false, "too short channel id"}, + {testMsgs[5], false, "too long channel id"}, + {testMsgs[6], false, "channel id contains non-alpha"}, + {testMsgs[7], false, "empty counterparty version"}, + {testMsgs[8], false, "empty proof"}, + {testMsgs[9], false, "proof height is zero"}, + {testMsgs[10], false, "invalid channel order"}, + {testMsgs[11], false, "connection hops more than 1 "}, + {testMsgs[12], false, "too short connection id"}, + {testMsgs[13], false, "too long connection id"}, + {testMsgs[14], false, "connection id contains non-alpha"}, + {testMsgs[15], false, "empty channel version"}, + {testMsgs[16], false, "invalid counterparty port id"}, + {testMsgs[17], false, "invalid counterparty channel id"}, + } + + for i, tc := range testCases { + err := tc.msg.ValidateBasic() + if tc.expPass { + require.Nil(t, err, "Msg %d failed: %v", i, err) + } else { + require.NotNil(t, err, "Invalid Msg %d passed: %s", i, tc.errMsg) + } + } +} + +// TestMsgChannelOpenAck tests ValidateBasic for MsgChannelOpenAck +func TestMsgChannelOpenAck(t *testing.T) { + testMsgs := []MsgChannelOpenAck{ + NewMsgChannelOpenAck("testport", "testchannel", "1.0", proof, 1, addr), // valid msg + NewMsgChannelOpenAck(invalidShortPort, "testchannel", "1.0", proof, 1, addr), // too short port id + NewMsgChannelOpenAck(invalidLongPort, "testchannel", "1.0", proof, 1, addr), // too long port id + NewMsgChannelOpenAck(invalidPort, "testchannel", "1.0", proof, 1, addr), // port id contains non-alpha + NewMsgChannelOpenAck("testport", invalidShortChannel, "1.0", proof, 1, addr), // too short channel id + NewMsgChannelOpenAck("testport", invalidLongChannel, "1.0", proof, 1, addr), // too long channel id + NewMsgChannelOpenAck("testport", invalidChannel, "1.0", proof, 1, addr), // channel id contains non-alpha + NewMsgChannelOpenAck("testport", "testchannel", "", proof, 1, addr), // empty counterparty version + NewMsgChannelOpenAck("testport", "testchannel", "1.0", nil, 1, addr), // empty proof + NewMsgChannelOpenAck("testport", "testchannel", "1.0", proof, 0, addr), // proof height is zero + } + + testCases := []struct { + msg MsgChannelOpenAck + expPass bool + errMsg string + }{ + {testMsgs[0], true, ""}, + {testMsgs[1], false, "too short port id"}, + {testMsgs[2], false, "too long port id"}, + {testMsgs[3], false, "port id contains non-alpha"}, + {testMsgs[4], false, "too short channel id"}, + {testMsgs[5], false, "too long channel id"}, + {testMsgs[6], false, "channel id contains non-alpha"}, + {testMsgs[7], false, "empty counterparty version"}, + {testMsgs[8], false, "empty proof"}, + {testMsgs[9], false, "proof height is zero"}, + } + + for i, tc := range testCases { + err := tc.msg.ValidateBasic() + if tc.expPass { + require.Nil(t, err, "Msg %d failed: %v", i, err) + } else { + require.NotNil(t, err, "Invalid Msg %d passed: %s", i, tc.errMsg) + } + } +} + +// TestMsgChannelOpenConfirm tests ValidateBasic for MsgChannelOpenConfirm +func TestMsgChannelOpenConfirm(t *testing.T) { + testMsgs := []MsgChannelOpenConfirm{ + NewMsgChannelOpenConfirm("testport", "testchannel", proof, 1, addr), // valid msg + NewMsgChannelOpenConfirm(invalidShortPort, "testchannel", proof, 1, addr), // too short port id + NewMsgChannelOpenConfirm(invalidLongPort, "testchannel", proof, 1, addr), // too long port id + NewMsgChannelOpenConfirm(invalidPort, "testchannel", proof, 1, addr), // port id contains non-alpha + NewMsgChannelOpenConfirm("testport", invalidShortChannel, proof, 1, addr), // too short channel id + NewMsgChannelOpenConfirm("testport", invalidLongChannel, proof, 1, addr), // too long channel id + NewMsgChannelOpenConfirm("testport", invalidChannel, proof, 1, addr), // channel id contains non-alpha + NewMsgChannelOpenConfirm("testport", "testchannel", nil, 1, addr), // empty proof + NewMsgChannelOpenConfirm("testport", "testchannel", proof, 0, addr), // proof height is zero + } + + testCases := []struct { + msg MsgChannelOpenConfirm + expPass bool + errMsg string + }{ + {testMsgs[0], true, ""}, + {testMsgs[1], false, "too short port id"}, + {testMsgs[2], false, "too long port id"}, + {testMsgs[3], false, "port id contains non-alpha"}, + {testMsgs[4], false, "too short channel id"}, + {testMsgs[5], false, "too long channel id"}, + {testMsgs[6], false, "channel id contains non-alpha"}, + {testMsgs[7], false, "empty proof"}, + {testMsgs[8], false, "proof height is zero"}, + } + + for i, tc := range testCases { + err := tc.msg.ValidateBasic() + if tc.expPass { + require.Nil(t, err, "Msg %d failed: %v", i, err) + } else { + require.NotNil(t, err, "Invalid Msg %d passed: %s", i, tc.errMsg) + } + } +} + +// TestMsgChannelCloseInit tests ValidateBasic for MsgChannelCloseInit +func TestMsgChannelCloseInit(t *testing.T) { + testMsgs := []MsgChannelCloseInit{ + NewMsgChannelCloseInit("testport", "testchannel", addr), // valid msg + NewMsgChannelCloseInit(invalidShortPort, "testchannel", addr), // too short port id + NewMsgChannelCloseInit(invalidLongPort, "testchannel", addr), // too long port id + NewMsgChannelCloseInit(invalidPort, "testchannel", addr), // port id contains non-alpha + NewMsgChannelCloseInit("testport", invalidShortChannel, addr), // too short channel id + NewMsgChannelCloseInit("testport", invalidLongChannel, addr), // too long channel id + NewMsgChannelCloseInit("testport", invalidChannel, addr), // channel id contains non-alpha + } + + testCases := []struct { + msg MsgChannelCloseInit + expPass bool + errMsg string + }{ + {testMsgs[0], true, ""}, + {testMsgs[1], false, "too short port id"}, + {testMsgs[2], false, "too long port id"}, + {testMsgs[3], false, "port id contains non-alpha"}, + {testMsgs[4], false, "too short channel id"}, + {testMsgs[5], false, "too long channel id"}, + {testMsgs[6], false, "channel id contains non-alpha"}, + } + + for i, tc := range testCases { + err := tc.msg.ValidateBasic() + if tc.expPass { + require.Nil(t, err, "Msg %d failed: %v", i, err) + } else { + require.NotNil(t, err, "Invalid Msg %d passed: %s", i, tc.errMsg) + } + } +} + +// TestMsgChannelCloseConfirm tests ValidateBasic for MsgChannelCloseConfirm +func TestMsgChannelCloseConfirm(t *testing.T) { + testMsgs := []MsgChannelCloseConfirm{ + NewMsgChannelCloseConfirm("testport", "testchannel", proof, 1, addr), // valid msg + NewMsgChannelCloseConfirm(invalidShortPort, "testchannel", proof, 1, addr), // too short port id + NewMsgChannelCloseConfirm(invalidLongPort, "testchannel", proof, 1, addr), // too long port id + NewMsgChannelCloseConfirm(invalidPort, "testchannel", proof, 1, addr), // port id contains non-alpha + NewMsgChannelCloseConfirm("testport", invalidShortChannel, proof, 1, addr), // too short channel id + NewMsgChannelCloseConfirm("testport", invalidLongChannel, proof, 1, addr), // too long channel id + NewMsgChannelCloseConfirm("testport", invalidChannel, proof, 1, addr), // channel id contains non-alpha + NewMsgChannelCloseConfirm("testport", "testchannel", nil, 1, addr), // empty proof + NewMsgChannelCloseConfirm("testport", "testchannel", proof, 0, addr), // proof height is zero + } + + testCases := []struct { + msg MsgChannelCloseConfirm + expPass bool + errMsg string + }{ + {testMsgs[0], true, ""}, + {testMsgs[1], false, "too short port id"}, + {testMsgs[2], false, "too long port id"}, + {testMsgs[3], false, "port id contains non-alpha"}, + {testMsgs[4], false, "too short channel id"}, + {testMsgs[5], false, "too long channel id"}, + {testMsgs[6], false, "channel id contains non-alpha"}, + {testMsgs[7], false, "empty proof"}, + {testMsgs[8], false, "proof height is zero"}, + } + + for i, tc := range testCases { + err := tc.msg.ValidateBasic() + if tc.expPass { + require.Nil(t, err, "Msg %d failed: %v", i, err) + } else { + require.NotNil(t, err, "Invalid Msg %d passed: %s", i, tc.errMsg) + } + } +} diff --git a/x/ibc/05-port/keeper/keeper.go b/x/ibc/05-port/keeper/keeper.go index 5a7a9fb05bae..cdb48010237f 100644 --- a/x/ibc/05-port/keeper/keeper.go +++ b/x/ibc/05-port/keeper/keeper.go @@ -15,7 +15,7 @@ type Keeper struct { cdc *codec.Codec codespace sdk.CodespaceType prefix []byte // prefix bytes for accessing the store - ports map[sdk.CapabilityKey]string + ports map[string]string bound []string } @@ -27,7 +27,7 @@ func NewKeeper(cdc *codec.Codec, key sdk.StoreKey, codespace sdk.CodespaceType) codespace: sdk.CodespaceType(fmt.Sprintf("%s/%s", codespace, types.DefaultCodespace)), // "ibc/port", prefix: []byte{}, // prefix: []byte(types.SubModuleName + "/"), // "port/" - ports: make(map[sdk.CapabilityKey]string), // map of capabilities to port ids + ports: make(map[string]string), // map of capability key names to port ids } } @@ -39,7 +39,7 @@ func (k Keeper) GetPorts() []string { // GetPort retrieves a given port ID from keeper map func (k Keeper) GetPort(ck sdk.CapabilityKey) (string, bool) { - portID, found := k.ports[ck] + portID, found := k.ports[ck.Name()] return portID, found } @@ -59,7 +59,7 @@ func (k Keeper) BindPort(portID string) sdk.CapabilityKey { } key := sdk.NewKVStoreKey(portID) - k.ports[key] = portID + k.ports[key.Name()] = portID k.bound = append(k.bound, portID) return key } @@ -73,5 +73,5 @@ func (k Keeper) Authenticate(key sdk.CapabilityKey, portID string) bool { panic(err.Error()) } - return k.ports[key] == portID + return k.ports[key.Name()] == portID }