diff --git a/x/tax/handlers/bank_msg_server.go b/x/tax/handlers/bank_msg_server.go index 9dd5886b..45e71e33 100644 --- a/x/tax/handlers/bank_msg_server.go +++ b/x/tax/handlers/bank_msg_server.go @@ -38,7 +38,7 @@ func (s *BankMsgServer) Send(ctx context.Context, msg *banktypes.MsgSend) (*bank fromAddr := sdk.MustAccAddressFromBech32(msg.FromAddress) if !s.treasuryKeeper.HasBurnTaxExemptionAddress(sdkCtx, msg.FromAddress, msg.ToAddress) { - netAmount, err := s.taxKeeper.DeductTax(sdkCtx, fromAddr, msg.Amount) + netAmount, err := s.taxKeeper.DeductTax(sdkCtx, fromAddr, msg.Amount, false) if err != nil { return nil, err } @@ -76,7 +76,7 @@ func (s *BankMsgServer) MultiSend(ctx context.Context, msg *banktypes.MsgMultiSe if tainted { for i, input := range msg.Inputs { fromAddr := sdk.MustAccAddressFromBech32(input.Address) - netCoins, err := s.taxKeeper.DeductTax(sdkCtx, fromAddr, input.Coins) + netCoins, err := s.taxKeeper.DeductTax(sdkCtx, fromAddr, input.Coins, false) if err != nil { return nil, err } @@ -85,7 +85,7 @@ func (s *BankMsgServer) MultiSend(ctx context.Context, msg *banktypes.MsgMultiSe for i, output := range msg.Outputs { toAddr := sdk.MustAccAddressFromBech32(output.Address) - netCoins, err := s.taxKeeper.DeductTax(sdkCtx, toAddr, output.Coins) + netCoins, err := s.taxKeeper.DeductTax(sdkCtx, toAddr, output.Coins, true) if err != nil { return nil, err } diff --git a/x/tax/handlers/market_msg_server.go b/x/tax/handlers/market_msg_server.go index 510c232f..1656fe26 100644 --- a/x/tax/handlers/market_msg_server.go +++ b/x/tax/handlers/market_msg_server.go @@ -37,7 +37,7 @@ func (s *MarketMsgServer) SwapSend(ctx context.Context, msg *markettypes.MsgSwap sender := sdk.MustAccAddressFromBech32(msg.FromAddress) - netOfferCoin, err := s.taxKeeper.DeductTax(sdkCtx, sender, sdk.NewCoins(msg.OfferCoin)) + netOfferCoin, err := s.taxKeeper.DeductTax(sdkCtx, sender, sdk.NewCoins(msg.OfferCoin), false) if err != nil { return nil, err } diff --git a/x/tax/handlers/wasm_msg_server.go b/x/tax/handlers/wasm_msg_server.go index f77a7d3a..6f990f98 100644 --- a/x/tax/handlers/wasm_msg_server.go +++ b/x/tax/handlers/wasm_msg_server.go @@ -59,7 +59,7 @@ func (s *WasmMsgServer) ExecuteContract(ctx context.Context, msg *wasmtypes.MsgE sender := sdk.MustAccAddressFromBech32(msg.Sender) if !s.treasuryKeeper.HasBurnTaxExemptionContract(sdkCtx, msg.Contract) { - netFunds, err := s.taxKeeper.DeductTax(sdkCtx, sender, msg.Funds) + netFunds, err := s.taxKeeper.DeductTax(sdkCtx, sender, msg.Funds, false) if err != nil { return nil, err } @@ -89,7 +89,7 @@ func (s *WasmMsgServer) InstantiateContract(ctx context.Context, msg *wasmtypes. sender := sdk.MustAccAddressFromBech32(msg.Sender) - netFunds, err := s.taxKeeper.DeductTax(sdkCtx, sender, msg.Funds) + netFunds, err := s.taxKeeper.DeductTax(sdkCtx, sender, msg.Funds, false) if err != nil { return nil, err } @@ -117,7 +117,7 @@ func (s *WasmMsgServer) InstantiateContract2(ctx context.Context, msg *wasmtypes sender := sdk.MustAccAddressFromBech32(msg.Sender) - netFunds, err := s.taxKeeper.DeductTax(sdkCtx, sender, msg.Funds) + netFunds, err := s.taxKeeper.DeductTax(sdkCtx, sender, msg.Funds, false) if err != nil { return nil, err } diff --git a/x/tax/keeper/keeper.go b/x/tax/keeper/keeper.go index 8686007c..1e2bb217 100644 --- a/x/tax/keeper/keeper.go +++ b/x/tax/keeper/keeper.go @@ -96,6 +96,7 @@ func (k Keeper) DeductTax( ctx sdk.Context, sender sdk.AccAddress, amount sdk.Coins, + skipDeduct bool, ) (sdk.Coins, error) { ctx.Logger().Info("Deducting tax", "sender", sender, "amount", amount, ctx.Value(types.ContextKeyTaxReverseCharge)) @@ -107,7 +108,7 @@ func (k Keeper) DeductTax( taxes := k.ComputeTax(ctx, amount) netAmount := amount.Sub(taxes...) - if !taxes.IsZero() { + if !taxes.IsZero() && !skipDeduct { // Deduct the total tax amount from the sender and send to FeeCollector if err := k.bankKeeper.SendCoinsFromAccountToModule(ctx, sender, authtypes.FeeCollectorName, taxes); err != nil { return nil, err