Skip to content

Commit

Permalink
Uts
Browse files Browse the repository at this point in the history
  • Loading branch information
goran-ethernal committed Jun 3, 2024
1 parent c8eead9 commit 807bc91
Show file tree
Hide file tree
Showing 2 changed files with 354 additions and 0 deletions.
34 changes: 34 additions & 0 deletions zk/txpool/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ func (p Policy) ToByteArray() []byte {
return []byte{byte(p)}
}

// IsSupportedPolicy checks if the given policy is supported
func IsSupportedPolicy(policy Policy) bool {
switch policy {
case SendTx, Deploy:
return true
default:
return false
}
}

func ResolvePolicy(policy string) (Policy, error) {
switch policy {
case "sendTx":
Expand All @@ -45,6 +55,10 @@ func containsPolicy(policies []byte, policy Policy) bool {

// CheckPolicy checks if the given address has the given policy for the online ACL mode
func CheckPolicy(ctx context.Context, aclDB kv.RwDB, addr common.Address, policy Policy) (bool, error) {
if !IsSupportedPolicy(policy) {
return false, errUnknownPolicy
}

// Retrieve the mode configuration
var hasPolicy bool
err := aclDB.View(ctx, func(tx kv.Tx) error {
Expand Down Expand Up @@ -127,6 +141,10 @@ func UpdatePolicies(ctx context.Context, aclDB kv.RwDB, aclType string, addrs []

// AddPolicy adds a policy to the ACL of given address
func AddPolicy(ctx context.Context, aclDB kv.RwDB, aclType string, addr common.Address, policy Policy) error {
if !IsSupportedPolicy(policy) {
return errUnknownPolicy
}

table, err := resolveTable(aclType)
if err != nil {
return err
Expand Down Expand Up @@ -199,6 +217,22 @@ func SetMode(ctx context.Context, aclDB kv.RwDB, mode string) error {
})
}

// GetMode gets the mode of the ACL
func GetMode(ctx context.Context, aclDB kv.RwDB) (ACLMode, error) {
var mode ACLMode
err := aclDB.View(ctx, func(tx kv.Tx) error {
value, err := tx.GetOne(Config, []byte(modeKey))
if err != nil {
return err
}

mode = ACLMode(value)
return nil
})

return mode, err
}

// resolveTable resolves the ACL table based on aclType
func resolveTable(aclType string) (string, error) {
at, err := ResolveACLType(aclType)
Expand Down
320 changes: 320 additions & 0 deletions zk/txpool/policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
package txpool

import (
"context"
"fmt"
"os"
"testing"
"time"

"github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/stretchr/testify/require"
)

// newTestState creates new instance of state used by tests.
func newTestACLDB(tb testing.TB) kv.RwDB {
tb.Helper()

dir := fmt.Sprintf("/tmp/acl-db-temp_%v", time.Now().UTC().Format(time.RFC3339Nano))
err := os.Mkdir(dir, 0775)

if err != nil {
tb.Fatal(err)
}

state, err := OpenACLDB(context.Background(), dir)
if err != nil {
tb.Fatal(err)
}

tb.Cleanup(func() {
if err := os.RemoveAll(dir); err != nil {
tb.Fatal(err)
}
})

return state
}

func TestSetMode(t *testing.T) {
t.Parallel()

db := newTestACLDB(t)
ctx := context.Background()

t.Run("SetMode - Valid Mode", func(t *testing.T) {
t.Parallel()

mode := AllowlistMode

err := SetMode(ctx, db, mode)
require.NoError(t, err)

// Check if the mode is set correctly
modeInDB, err := GetMode(ctx, db)
require.NoError(t, err)
require.Equal(t, string(mode), string(modeInDB))
})

t.Run("SetMode - Invalid Mode", func(t *testing.T) {
t.Parallel()

mode := "invalid_mode"

err := SetMode(ctx, db, mode)
require.ErrorIs(t, err, errInvalidMode)
})
}

func TestRemovePolicy(t *testing.T) {
t.Parallel()

db := newTestACLDB(t)
ctx := context.Background()

SetMode(ctx, db, BlocklistMode)

t.Run("RemovePolicy - Policy Exists", func(t *testing.T) {
t.Parallel()

// Create a test address and policy
addr := common.HexToAddress("0x1234567890abcdef")
policy := SendTx

// Add the policy to the ACL
require.NoError(t, AddPolicy(ctx, db, "blocklist", addr, policy))

// Remove the policy from the ACL
err := RemovePolicy(ctx, db, "blocklist", addr, policy)
require.NoError(t, err)

// Check if the policy is removed from the ACL
hasPolicy, err := CheckPolicy(ctx, db, addr, policy)
require.NoError(t, err)
require.False(t, hasPolicy)
})

t.Run("RemovePolicy - Policy Does Not Exist", func(t *testing.T) {
t.Parallel()

// Create a test address and policy
addr := common.HexToAddress("0x1234567890abcdef")
policy := SendTx

// Add some different policy to the ACL
require.NoError(t, AddPolicy(ctx, db, "blocklist", addr, Deploy))

// Remove the policy from the ACL
err := RemovePolicy(ctx, db, "blocklist", addr, policy)
require.NoError(t, err)

// Check if the policy is still not present in the ACL
hasPolicy, err := CheckPolicy(ctx, db, addr, policy)
require.NoError(t, err)
require.False(t, hasPolicy)
})

t.Run("RemovePolicy - Address Not Found", func(t *testing.T) {
t.Parallel()

// Create a test address and policy
addr := common.HexToAddress("0x1234567890abcdef")
policy := SendTx

// Remove the policy from the ACL
err := RemovePolicy(ctx, db, "blocklist", addr, policy)
require.NoError(t, err)

// Check if the policy is still not present in the ACL
hasPolicy, err := CheckPolicy(ctx, db, addr, policy)
require.NoError(t, err)
require.False(t, hasPolicy)
})

t.Run("RemovePolicy - Unsupported acl type", func(t *testing.T) {
t.Parallel()

// Create a test address and policy
addr := common.HexToAddress("0x1234567890abcdef")
policy := SendTx

err := RemovePolicy(ctx, db, "unknown_acl_type", addr, policy)
require.ErrorIs(t, err, errUnsupportedACLType)
})
}

func TestAddPolicy(t *testing.T) {
t.Parallel()

db := newTestACLDB(t)
ctx := context.Background()

SetMode(ctx, db, BlocklistMode)

t.Run("AddPolicy - Policy Does Not Exist", func(t *testing.T) {
t.Parallel()

// Create a test address and policy
addr := common.HexToAddress("0x1234567890abcdef")
policy := SendTx

err := AddPolicy(ctx, db, "blocklist", addr, policy)
require.NoError(t, err)

// Check if the policy exists in the ACL
hasPolicy, err := CheckPolicy(ctx, db, addr, policy)
require.NoError(t, err)
require.True(t, hasPolicy)
})

t.Run("AddPolicy - Policy Already Exists", func(t *testing.T) {
t.Parallel()

// Create a test address and policy
addr := common.HexToAddress("0x1234567890abcdef")
policy := SendTx

// Add the policy to the ACL
require.NoError(t, AddPolicy(ctx, db, "blocklist", addr, policy))

// Add the policy again
err := AddPolicy(ctx, db, "blocklist", addr, policy)
require.NoError(t, err)

// Check if the policy still exists in the ACL
hasPolicy, err := CheckPolicy(ctx, db, addr, policy)
require.NoError(t, err)
require.True(t, hasPolicy)
})

t.Run("AddPolicy - Unsupported Policy", func(t *testing.T) {
t.Parallel()

// Create a test address and policy
addr := common.HexToAddress("0x1234567890abcdef")
policy := Policy(33) // Assume Policy(33) is not supported

err := AddPolicy(ctx, db, "blocklist", addr, policy)
require.ErrorIs(t, err, errUnknownPolicy)
})

t.Run("AddPolicy - Unsupported acl type", func(t *testing.T) {
t.Parallel()

// Create a test address and policy
addr := common.HexToAddress("0x1234567890abcdef")
policy := SendTx

err := AddPolicy(ctx, db, "unknown_acl_type", addr, policy)
require.ErrorIs(t, err, errUnsupportedACLType)
})
}

func TestUpdatePolicies(t *testing.T) {
t.Parallel()

db := newTestACLDB(t)
ctx := context.Background()

SetMode(ctx, db, BlocklistMode)

t.Run("UpdatePolicies - Add Policies", func(t *testing.T) {
t.Parallel()

// Create test addresses and policies
addr1 := common.HexToAddress("0x1234567890abcdef")
addr2 := common.HexToAddress("0xabcdef1234567890")
policies := [][]Policy{
{SendTx, Deploy},
{SendTx},
}

err := UpdatePolicies(ctx, db, "blocklist", []common.Address{addr1, addr2}, policies)
require.NoError(t, err)

// Check if the policies are added correctly
hasPolicy, err := CheckPolicy(ctx, db, addr1, SendTx)
require.NoError(t, err)
require.True(t, hasPolicy)

hasPolicy, err = CheckPolicy(ctx, db, addr1, Deploy)
require.NoError(t, err)
require.True(t, hasPolicy)

hasPolicy, err = CheckPolicy(ctx, db, addr2, SendTx)
require.NoError(t, err)
require.True(t, hasPolicy)
})

t.Run("UpdatePolicies - Remove Policies", func(t *testing.T) {
t.Parallel()

// Create test addresses and policies
addr1 := common.HexToAddress("0x1234567890abcdef")
addr2 := common.HexToAddress("0xabcdef1234567890")
policies := [][]Policy{
{},
{SendTx},
}

err := UpdatePolicies(ctx, db, "blocklist", []common.Address{addr1, addr2}, policies)
require.NoError(t, err)

// Check if the policies are removed correctly
hasPolicy, err := CheckPolicy(ctx, db, addr1, SendTx)
require.NoError(t, err)
require.False(t, hasPolicy)

hasPolicy, err = CheckPolicy(ctx, db, addr1, Deploy)
require.NoError(t, err)
require.False(t, hasPolicy)

hasPolicy, err = CheckPolicy(ctx, db, addr2, SendTx)
require.NoError(t, err)
require.True(t, hasPolicy)
})

t.Run("UpdatePolicies - Empty Policies", func(t *testing.T) {
t.Parallel()

// Create test addresses and policies
addr1 := common.HexToAddress("0x1234567890abcdef")
addr2 := common.HexToAddress("0xabcdef1234567890")
policies := [][]Policy{
{},
{},
}

err := UpdatePolicies(ctx, db, "blocklist", []common.Address{addr1, addr2}, policies)
require.NoError(t, err)

// Check if the policies are removed correctly
hasPolicy, err := CheckPolicy(ctx, db, addr1, SendTx)
require.NoError(t, err)
require.False(t, hasPolicy)

hasPolicy, err = CheckPolicy(ctx, db, addr1, Deploy)
require.NoError(t, err)
require.False(t, hasPolicy)

hasPolicy, err = CheckPolicy(ctx, db, addr2, SendTx)
require.NoError(t, err)
require.False(t, hasPolicy)
})

t.Run("UpdatePolicies - Unsupported acl type", func(t *testing.T) {
t.Parallel()

// Create test addresses and policies
addr1 := common.HexToAddress("0x1234567890abcdef")
addr2 := common.HexToAddress("0xabcdef1234567890")
policies := [][]Policy{
{SendTx, Deploy},
{SendTx},
}

err := UpdatePolicies(ctx, db, "unknown_acl_type", []common.Address{addr1, addr2}, policies)
require.ErrorIs(t, err, errUnsupportedACLType)
})
}

0 comments on commit 807bc91

Please sign in to comment.