Skip to content

Commit

Permalink
- refund tax on contract execution
Browse files Browse the repository at this point in the history
  • Loading branch information
StrathCole committed Oct 14, 2024
1 parent 80a0d2d commit e322cd9
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 31 deletions.
58 changes: 38 additions & 20 deletions custom/auth/ante/fee.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,22 @@ func (fd FeeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, nex

msgs := feeTx.GetMsgs()
// Compute taxes
taxes := FilterMsgAndComputeTax(ctx, fd.treasuryKeeper, fd.taxKeeper, simulate, msgs...)
taxes, nonTaxableTaxes := FilterMsgAndComputeTax(ctx, fd.treasuryKeeper, fd.taxKeeper, simulate, msgs...)

// check if the tx has paid fees for both(!) fee and tax
// if not, then set the tax to zero at this point as it then is handled in the message route
reverseCharge := false
refundNonTaxableTax := false

if !simulate {
priority, reverseCharge, err = fd.checkTxFee(ctx, tx, taxes)
priority, reverseCharge, refundNonTaxableTax, err = fd.checkTxFee(ctx, tx, taxes, nonTaxableTaxes)
if err != nil {
return ctx, err
}

if !refundNonTaxableTax {
nonTaxableTaxes = sdk.Coins{}
}
}

if reverseCharge {
Expand All @@ -75,7 +80,7 @@ func (fd FeeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, nex
taxes = sdk.Coins{}
}

newCtx, err := fd.checkDeductFee(ctx, feeTx, taxes, simulate)
newCtx, err := fd.checkDeductFee(ctx, feeTx, taxes, nonTaxableTaxes, simulate)
if err != nil {
return newCtx, err
}
Expand All @@ -85,7 +90,7 @@ func (fd FeeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, nex
return next(newCtx, tx, simulate)
}

func (fd FeeDecorator) checkDeductFee(ctx sdk.Context, feeTx sdk.FeeTx, taxes sdk.Coins, simulate bool) (sdk.Context, error) {
func (fd FeeDecorator) checkDeductFee(ctx sdk.Context, feeTx sdk.FeeTx, taxes sdk.Coins, nonTaxableTaxes sdk.Coins, simulate bool) (sdk.Context, error) {
if addr := fd.accountKeeper.GetModuleAddress(types.FeeCollectorName); addr == nil {
return ctx, fmt.Errorf("fee collector module account (%s) has not been set", types.FeeCollectorName)
}
Expand Down Expand Up @@ -159,11 +164,29 @@ func (fd FeeDecorator) checkDeductFee(ctx sdk.Context, feeTx sdk.FeeTx, taxes sd
}
}

events := sdk.Events{
sdk.NewEvent(
sdk.EventTypeTx,
sdk.NewAttribute(sdk.AttributeKeyFee, fee.String()),
sdk.NewAttribute(sdk.AttributeKeyFeePayer, deductFeesFrom.String()),
),
}

if !feesOrTax.IsZero() {
// we will only deduct the fees from the account, not the tax
// the tax will be deducted in the message route for reverse charge
// or in the post handler for normal tax charge
deductFees := feesOrTax.Sub(taxes...) // feesOrTax can never be lower than taxes
if !nonTaxableTaxes.IsZero() {
// if we have non-taxable taxes, we need to subtract them from the fees to be deducted
deductFees = deductFees.Sub(nonTaxableTaxes...)

// add the non-taxable taxes to the events
events = append(events, sdk.NewEvent(
taxtypes.EventTypeTaxRefund,
sdk.NewAttribute(taxtypes.AttributeKeyTaxAmount, nonTaxableTaxes.String()),
))
}

ctx = ctx.WithValue(taxtypes.ContextKeyTaxDue, taxes).WithValue(taxtypes.ContextKeyTaxPayer, deductFeesFrom.String())

Expand All @@ -175,13 +198,6 @@ func (fd FeeDecorator) checkDeductFee(ctx sdk.Context, feeTx sdk.FeeTx, taxes sd
}
}

events := sdk.Events{
sdk.NewEvent(
sdk.EventTypeTx,
sdk.NewAttribute(sdk.AttributeKeyFee, fee.String()),
sdk.NewAttribute(sdk.AttributeKeyFeePayer, deductFeesFrom.String()),
),
}
ctx.EventManager().EmitEvents(events)

return ctx, nil
Expand All @@ -205,10 +221,10 @@ func DeductFees(bankKeeper types.BankKeeper, ctx sdk.Context, acc types.AccountI
// unit of gas is fixed and set by each validator, can the tx priority is computed from the gas price.
// Transaction with only oracle messages will skip gas fee check and will have the most priority.
// It also checks enough fee for treasury tax
func (fd FeeDecorator) checkTxFee(ctx sdk.Context, tx sdk.Tx, taxes sdk.Coins) (int64, bool, error) {
func (fd FeeDecorator) checkTxFee(ctx sdk.Context, tx sdk.Tx, taxes sdk.Coins, nonTaxableTaxes sdk.Coins) (int64, bool, bool, error) {
feeTx, ok := tx.(sdk.FeeTx)
if !ok {
return 0, false, errorsmod.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx")
return 0, false, false, errorsmod.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx")
}

feeCoins := feeTx.GetFee()
Expand All @@ -217,6 +233,7 @@ func (fd FeeDecorator) checkTxFee(ctx sdk.Context, tx sdk.Tx, taxes sdk.Coins) (
isOracleTx := isOracleTx(msgs)
minGasPrices := fd.taxKeeper.GetEffectiveGasPrices(ctx)
reverseCharge := false
refundNonTaxableTaxes := false

// Ensure that the provided fees meet a minimum threshold for the validator,
// if this is a CheckTx. This is only for local mempool purposes, and thus
Expand All @@ -236,21 +253,22 @@ func (fd FeeDecorator) checkTxFee(ctx sdk.Context, tx sdk.Tx, taxes sdk.Coins) (
}

requiredFees := requiredGasFees.Add(taxes...)

// fmt.Println("requiredFees", requiredFees, "feeCoins", feeCoins, "requiredGasFees", requiredGasFees, "taxes", taxes, "minGasPrices", minGasPrices)
allFees := requiredFees.Add(nonTaxableTaxes...)

// Check required fees
if !requiredFees.IsZero() && !feeCoins.IsAnyGTE(requiredFees) {
// we don't have enough for tax and gas fees. But do we have enough for gas alone?
if !requiredGasFees.IsZero() && !feeCoins.IsAnyGTE(requiredGasFees) {
return 0, false, errorsmod.Wrapf(sdkerrors.ErrInsufficientFee, "insufficient fees; got: %q, required: %q = %q(gas) + %q(stability)", feeCoins, requiredFees, requiredGasFees, taxes)
return 0, false, false, errorsmod.Wrapf(sdkerrors.ErrInsufficientFee, "insufficient fees; got: %q, required: %q = %q(gas) + %q(stability)", feeCoins, requiredFees, requiredGasFees, taxes)
}

// we have enough for gas fees but not for tax fees
reverseCharge = true
// ctx.Logger().Info("Insufficient fees to pay for gas and taxes (doing reverse charge)", "sentFee", feeCoins, "taxes", taxes, "requiredGasFees", requiredGasFees, "requiredFees", requiredFees)
// } else {
// ctx.Logger().Info("Sufficient fees to pay for gas and taxes (doing normal tax charge)", "sentFee", feeCoins, "taxes", taxes, "requiredGasFees", requiredGasFees, "requiredFees", requiredFees)
}

if !allFees.IsZero() && feeCoins.IsAnyGTE(allFees) {
// we have enough for all fees
refundNonTaxableTaxes = true
}
}

Expand All @@ -260,7 +278,7 @@ func (fd FeeDecorator) checkTxFee(ctx sdk.Context, tx sdk.Tx, taxes sdk.Coins) (
priority = getTxPriority(feeCoins, int64(gas))
}

return priority, reverseCharge, nil
return priority, reverseCharge, refundNonTaxableTaxes, nil
}

// getTxPriority returns a naive tx priority based on the amount of the smallest denomination of the gas price
Expand Down
19 changes: 11 additions & 8 deletions custom/auth/ante/fee_tax.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"regexp"
"strings"

wasmtypes "github.com/CosmWasm/wasmd/x/wasm/types"
sdk "github.com/cosmos/cosmos-sdk/types"
authz "github.com/cosmos/cosmos-sdk/x/authz"
banktypes "github.com/cosmos/cosmos-sdk/x/bank/types"
Expand All @@ -19,8 +20,9 @@ func isIBCDenom(denom string) bool {
}

// FilterMsgAndComputeTax computes the stability tax on messages.
func FilterMsgAndComputeTax(ctx sdk.Context, tk TreasuryKeeper, th TaxKeeper, simulate bool, msgs ...sdk.Msg) sdk.Coins {
func FilterMsgAndComputeTax(ctx sdk.Context, tk TreasuryKeeper, th TaxKeeper, simulate bool, msgs ...sdk.Msg) (sdk.Coins, sdk.Coins) {
taxes := sdk.Coins{}
nonTaxableTaxes := sdk.Coins{}

for _, msg := range msgs {
switch msg := msg.(type) {
Expand Down Expand Up @@ -55,26 +57,27 @@ func FilterMsgAndComputeTax(ctx sdk.Context, tk TreasuryKeeper, th TaxKeeper, si

// The contract messages were disabled to remove double-taxation
// whenever a contract sends funds to a wallet, it is taxed (deducted from sent amount)
/*case *wasmtypes.MsgInstantiateContract:
taxes = taxes.Add(computeTax(ctx, tk, msg.Funds, simulate)...)
case *wasmtypes.MsgInstantiateContract:
nonTaxableTaxes = nonTaxableTaxes.Add(computeTax(ctx, tk, th, msg.Funds, simulate)...)

case *wasmtypes.MsgInstantiateContract2:
taxes = taxes.Add(computeTax(ctx, tk, msg.Funds, simulate)...)
nonTaxableTaxes = nonTaxableTaxes.Add(computeTax(ctx, tk, th, msg.Funds, simulate)...)

case *wasmtypes.MsgExecuteContract:
if !tk.HasBurnTaxExemptionContract(ctx, msg.Contract) {
taxes = taxes.Add(computeTax(ctx, tk, msg.Funds, simulate)...)
nonTaxableTaxes = nonTaxableTaxes.Add(computeTax(ctx, tk, th, msg.Funds, simulate)...)
}
*/
case *authz.MsgExec:
messages, err := msg.GetMessages()
if err == nil {
taxes = taxes.Add(FilterMsgAndComputeTax(ctx, tk, th, simulate, messages...)...)
execTaxes, execNonTaxable := FilterMsgAndComputeTax(ctx, tk, th, simulate, messages...)
taxes = taxes.Add(execTaxes...)
nonTaxableTaxes = nonTaxableTaxes.Add(execNonTaxable...)
}
}
}

return taxes
return taxes, nonTaxableTaxes
}

// computes the stability tax according to tax-rate and tax-cap
Expand Down
3 changes: 2 additions & 1 deletion custom/auth/ante/fee_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -877,13 +877,14 @@ func (s *AnteTestSuite) runBurnSplitTaxTest(burnSplitRate sdk.Dec, oracleSplitRa
newCtx, err := antehandler(s.ctx, tx, false)
require.NoError(err)
_, err = postHandler(newCtx, tx, false, true)
require.NoError(err)

// burn the burn account
tk.BurnCoinsFromBurnAccount(s.ctx)

feeCollectorAfter := bk.GetAllBalances(s.ctx, ak.GetModuleAddress(authtypes.FeeCollectorName))
oracleAfter := bk.GetAllBalances(s.ctx, ak.GetModuleAddress(oracletypes.ModuleName))
taxes := ante.FilterMsgAndComputeTax(s.ctx, tk, th, false, msg)
taxes, _ := ante.FilterMsgAndComputeTax(s.ctx, tk, th, false, msg)
communityPoolAfter, _ := dk.GetFeePoolCommunityCoins(s.ctx).TruncateDecimal()
if communityPoolAfter.IsZero() {
communityPoolAfter = sdk.NewCoins(sdk.NewCoin(core.MicroSDRDenom, sdk.ZeroInt()))
Expand Down
2 changes: 1 addition & 1 deletion custom/auth/tx/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (ts txServer) ComputeTax(c context.Context, req *ComputeTaxRequest) (*Compu
return nil, status.Errorf(codes.InvalidArgument, "empty txBytes is not allowed")
}

taxAmount := customante.FilterMsgAndComputeTax(ctx, ts.treasuryKeeper, ts.taxKeeper, false, msgs...)
taxAmount, _ := customante.FilterMsgAndComputeTax(ctx, ts.treasuryKeeper, ts.taxKeeper, false, msgs...)
return &ComputeTaxResponse{
TaxAmount: taxAmount,
}, nil
Expand Down
1 change: 0 additions & 1 deletion x/tax/post/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ func NewTaxDecorator(tk taxkeeper.Keeper, bk bankkeeper.Keeper, ak accountkeeper
}

func (dd TaxDecorator) PostHandle(ctx sdk.Context, tx sdk.Tx, simulate, success bool, next sdk.PostHandler) (sdk.Context, error) {

value := ctx.Value(taxtypes.ContextKeyTaxDue)
dueTax, ok := value.(sdk.Coins)
if !ok {
Expand Down
1 change: 1 addition & 0 deletions x/tax/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const (
ContextKeyTaxPayer = "tax.payer"

EventTypeTax = "tax_payment"
EventTypeTaxRefund = "tax_refund"
AttributeKeyReverseCharge = "reverse_charge"
AttributeValueReverseCharge = "true"
AttributeValueNoReverseCharge = "false"
Expand Down

0 comments on commit e322cd9

Please sign in to comment.