Skip to content

Commit

Permalink
✨() bt-72 reset module state
Browse files Browse the repository at this point in the history
  • Loading branch information
pgonday committed Jan 25, 2025
1 parent 07b45bf commit 85dd459
Show file tree
Hide file tree
Showing 26 changed files with 771 additions and 189 deletions.
11 changes: 11 additions & 0 deletions contracts/compliance/modular/modules/AbstractModuleUpgradeable.sol
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ abstract contract AbstractModuleUpgradeable is IModule, Initializable, OwnableOn
struct AbstractModuleStorage {
/// compliance contract binding status
mapping(address => bool) complianceBound;

/// nonce for the module
mapping(address => uint256) nonces;
}

// keccak256(abi.encode(uint256(keccak256("ERC3643.storage.AbstractModule")) - 1)) & ~bytes32(uint256(0xff))
Expand Down Expand Up @@ -119,7 +122,10 @@ abstract contract AbstractModuleUpgradeable is IModule, Initializable, OwnableOn
AbstractModuleStorage storage s = _getAbstractModuleStorage();
require(_compliance != address(0), ZeroAddress());
require(msg.sender == _compliance, OnlyComplianceContractCanCall());

s.complianceBound[_compliance] = false;
s.nonces[_compliance] ++;

emit ComplianceUnbound(_compliance);
}

Expand All @@ -131,6 +137,11 @@ abstract contract AbstractModuleUpgradeable is IModule, Initializable, OwnableOn
return s.complianceBound[_compliance];
}

function getNonce(address _compliance) public view returns (uint256) {
AbstractModuleStorage storage s = _getAbstractModuleStorage();
return s.nonces[_compliance];
}

/**
* @dev See {IERC165-supportsInterface}.
*/
Expand Down
38 changes: 20 additions & 18 deletions contracts/compliance/modular/modules/ConditionalTransferModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ import "./AbstractModuleUpgradeable.sol";
/// @param _to is the address of transfer recipient.
/// @param _amount is the token amount to be sent (take care of decimals).
/// @param _token is the token address of the token concerned by the approval.
event TransferApproved(address _from, address _to, uint _amount, address _token);
event TransferApproved(address _from, address _to, uint256 _amount, address _token);


/// @dev This event is emitted whenever a transfer approval is removed.
/// @param _from is the address of transfer sender.
/// @param _to is the address of transfer recipient.
/// @param _amount is the token amount to be sent (take care of decimals).
/// @param _token is the token address of the token concerned by the approval.
event ApprovalRemoved(address _from, address _to, uint _amount, address _token);
event ApprovalRemoved(address _from, address _to, uint256 _amount, address _token);


/// Errors
Expand All @@ -97,15 +97,15 @@ event ApprovalRemoved(address _from, address _to, uint _amount, address _token);
/// @param _from the address of the transfer sender.
/// @param _to the address of the transfer receiver.
/// @param _amount the amount of tokens that `_from` was allowed to send to `_to`.
error TransferNotApproved(address _from, address _to, uint _amount);
error TransferNotApproved(address _from, address _to, uint256 _amount);


/**
* this module allows to require the pre-validation of a transfer before allowing it to be executed
*/
contract ConditionalTransferModule is AbstractModuleUpgradeable {
/// Mapping between transfer details and their approval status (amount of transfers approved) per compliance
mapping(address => mapping(bytes32 => uint)) private _transfersApproved;
mapping(address compliance => mapping(uint256 nonce => mapping(bytes32 => uint256))) private _transfersApproved;

/**
* @dev initializes the contract and sets the initial state.
Expand All @@ -126,7 +126,7 @@ contract ConditionalTransferModule is AbstractModuleUpgradeable {
* Only a bound compliance can call this function
* emits `_from.length` `TransferApproved` events
*/
function batchApproveTransfers(address[] calldata _from, address[] calldata _to, uint[] calldata _amount)
function batchApproveTransfers(address[] calldata _from, address[] calldata _to, uint256[] calldata _amount)
external onlyComplianceCall {
for (uint256 i = 0; i < _from.length; i++){
approveTransfer(_from[i], _to[i], _amount[i]);
Expand All @@ -145,7 +145,7 @@ contract ConditionalTransferModule is AbstractModuleUpgradeable {
* Only a bound compliance can call this function
* emits `_from.length` `ApprovalRemoved` events
*/
function batchUnApproveTransfers(address[] calldata _from, address[] calldata _to, uint[] calldata _amount)
function batchUnApproveTransfers(address[] calldata _from, address[] calldata _to, uint256[] calldata _amount)
external onlyComplianceCall {
for (uint256 i = 0; i < _from.length; i++){
unApproveTransfer(_from[i], _to[i], _amount[i]);
Expand All @@ -163,8 +163,9 @@ contract ConditionalTransferModule is AbstractModuleUpgradeable {
uint256 _value)
external override onlyComplianceCall {
bytes32 transferHash = calculateTransferHash(_from, _to, _value, IModularCompliance(msg.sender).getTokenBound());
if(_transfersApproved[msg.sender][transferHash] > 0) {
_transfersApproved[msg.sender][transferHash]--;
uint256 nonce = getNonce(msg.sender);
if(_transfersApproved[msg.sender][nonce][transferHash] > 0) {
_transfersApproved[msg.sender][nonce][transferHash]--;
emit ApprovalRemoved(_from, _to, _value, IModularCompliance(msg.sender).getTokenBound());
}
}
Expand Down Expand Up @@ -200,7 +201,7 @@ contract ConditionalTransferModule is AbstractModuleUpgradeable {
/**
* @dev See {IModule-canComplianceBind}.
*/
function canComplianceBind(address /*_compliance*/) external view override returns (bool) {
function canComplianceBind(address /*_compliance*/) external pure override returns (bool) {
return true;
}

Expand All @@ -220,9 +221,9 @@ contract ConditionalTransferModule is AbstractModuleUpgradeable {
* Only a bound compliance can call this function
* emits a `TransferApproved` event
*/
function approveTransfer(address _from, address _to, uint _amount) public onlyComplianceCall {
function approveTransfer(address _from, address _to, uint256 _amount) public onlyComplianceCall {
bytes32 transferHash = calculateTransferHash(_from, _to, _amount, IModularCompliance(msg.sender).getTokenBound());
_transfersApproved[msg.sender][transferHash]++;
_transfersApproved[msg.sender][getNonce(msg.sender)][transferHash]++;
emit TransferApproved(_from, _to, _amount, IModularCompliance(msg.sender).getTokenBound());
}

Expand All @@ -236,10 +237,11 @@ contract ConditionalTransferModule is AbstractModuleUpgradeable {
* Only a bound compliance can call this function
* emits an `ApprovalRemoved` event
*/
function unApproveTransfer(address _from, address _to, uint _amount) public onlyComplianceCall {
function unApproveTransfer(address _from, address _to, uint256 _amount) public onlyComplianceCall {
bytes32 transferHash = calculateTransferHash(_from, _to, _amount, IModularCompliance(msg.sender).getTokenBound());
require(_transfersApproved[msg.sender][transferHash] > 0, TransferNotApproved(_from, _to, _amount));
_transfersApproved[msg.sender][transferHash]--;
uint256 nonce = getNonce(msg.sender);
require(_transfersApproved[msg.sender][nonce][transferHash] > 0, TransferNotApproved(_from, _to, _amount));
_transfersApproved[msg.sender][nonce][transferHash]--;
emit ApprovalRemoved(_from, _to, _amount, IModularCompliance(msg.sender).getTokenBound());

}
Expand All @@ -251,7 +253,7 @@ contract ConditionalTransferModule is AbstractModuleUpgradeable {
* requires `_compliance` to be bound to this module
*/
function isTransferApproved(address _compliance, bytes32 _transferHash) public view returns (bool) {
if (((_transfersApproved[_compliance])[_transferHash]) > 0) {
if (((_transfersApproved[_compliance][getNonce(_compliance)])[_transferHash]) > 0) {
return true;
}
return false;
Expand All @@ -263,8 +265,8 @@ contract ConditionalTransferModule is AbstractModuleUpgradeable {
* @param _transferHash, bytes corresponding to the transfer details, hashed
* requires `_compliance` to be bound to this module
*/
function getTransferApprovals(address _compliance, bytes32 _transferHash) public view returns (uint) {
return (_transfersApproved[_compliance])[_transferHash];
function getTransferApprovals(address _compliance, bytes32 _transferHash) public view returns (uint256) {
return (_transfersApproved[_compliance][getNonce(_compliance)])[_transferHash];
}

/**
Expand All @@ -278,7 +280,7 @@ contract ConditionalTransferModule is AbstractModuleUpgradeable {
function calculateTransferHash (
address _from,
address _to,
uint _amount,
uint256 _amount,
address _token
) public pure returns (bytes32){
bytes32 transferHash = keccak256(abi.encode(_from, _to, _amount, _token));
Expand Down
23 changes: 14 additions & 9 deletions contracts/compliance/modular/modules/CountryAllowModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ error CountryNotAllowed(address _compliance, uint16 _country);

contract CountryAllowModule is AbstractModuleUpgradeable {
/// Mapping between country and their allowance status per compliance contract
mapping(address => mapping(uint16 => bool)) private _allowedCountries;
mapping(address compliance => mapping(uint256 nonce => mapping(uint16 => bool))) private _allowedCountries;

/// functions

Expand All @@ -122,8 +122,9 @@ contract CountryAllowModule is AbstractModuleUpgradeable {
* emits an `AddedAllowedCountry` event
*/
function batchAllowCountries(uint16[] calldata _countries) external onlyComplianceCall {
uint256 nonce = getNonce(msg.sender);
for (uint256 i = 0; i < _countries.length; i++) {
(_allowedCountries[msg.sender])[_countries[i]] = true;
(_allowedCountries[msg.sender][nonce])[_countries[i]] = true;
emit CountryAllowed(msg.sender, _countries[i]);
}
}
Expand All @@ -137,8 +138,9 @@ contract CountryAllowModule is AbstractModuleUpgradeable {
* emits an `RemoveAllowedCountry` event
*/
function batchDisallowCountries(uint16[] calldata _countries) external onlyComplianceCall {
uint256 nonce = getNonce(msg.sender);
for (uint256 i = 0; i < _countries.length; i++) {
(_allowedCountries[msg.sender])[_countries[i]] = false;
(_allowedCountries[msg.sender][nonce])[_countries[i]] = false;
emit CountryUnallowed(msg.sender, _countries[i]);
}
}
Expand All @@ -151,9 +153,10 @@ contract CountryAllowModule is AbstractModuleUpgradeable {
* emits an `AddedAllowedCountry` event
*/
function addAllowedCountry(uint16 _country) external onlyComplianceCall {
require(!(_allowedCountries[msg.sender])[_country], CountryAlreadyAllowed(msg.sender, _country));
uint256 nonce = getNonce(msg.sender);
require(!(_allowedCountries[msg.sender][nonce])[_country], CountryAlreadyAllowed(msg.sender, _country));

(_allowedCountries[msg.sender])[_country] = true;
(_allowedCountries[msg.sender][nonce])[_country] = true;
emit CountryAllowed(msg.sender, _country);
}

Expand All @@ -166,9 +169,10 @@ contract CountryAllowModule is AbstractModuleUpgradeable {
* emits an `RemoveAllowedCountry` event
*/
function removeAllowedCountry(uint16 _country) external onlyComplianceCall {
require((_allowedCountries[msg.sender])[_country], CountryNotAllowed(msg.sender, _country));
uint256 nonce = getNonce(msg.sender);
require((_allowedCountries[msg.sender][nonce])[_country], CountryNotAllowed(msg.sender, _country));

(_allowedCountries[msg.sender])[_country] = false;
(_allowedCountries[msg.sender][nonce])[_country] = false;
emit CountryUnallowed(msg.sender, _country);
}

Expand Down Expand Up @@ -212,7 +216,7 @@ contract CountryAllowModule is AbstractModuleUpgradeable {
/**
* @dev See {IModule-canComplianceBind}.
*/
function canComplianceBind(address /*_compliance*/) external view override returns (bool) {
function canComplianceBind(address /*_compliance*/) external pure override returns (bool) {
return true;
}

Expand All @@ -225,10 +229,11 @@ contract CountryAllowModule is AbstractModuleUpgradeable {

/**
* @dev Returns true if country is Allowed
* @param _compliance compliance contract address contract
* @param _country, numeric ISO 3166-1 standard of the country to be checked
*/
function isCountryAllowed(address _compliance, uint16 _country) public view returns (bool) {
return _allowedCountries[_compliance][_country];
return _allowedCountries[_compliance][getNonce(_compliance)][_country];
}

/**
Expand Down
30 changes: 19 additions & 11 deletions contracts/compliance/modular/modules/CountryRestrictModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ error MaxCountriesInBatchReached(uint256 _max);
contract CountryRestrictModule is AbstractModuleUpgradeable {

/// Mapping between country and their restriction status per compliance contract
mapping(address => mapping(uint16 => bool)) private _restrictedCountries;
mapping(address compliance => mapping(uint256 nonce => mapping(uint16 => bool))) private _restrictedCountries;

/**
* @dev initializes the contract and sets the initial state.
Expand All @@ -125,8 +125,9 @@ contract CountryRestrictModule is AbstractModuleUpgradeable {
* emits an `AddedRestrictedCountry` event
*/
function addCountryRestriction(uint16 _country) external onlyComplianceCall {
require((_restrictedCountries[msg.sender])[_country] == false, CountryAlreadyRestricted(msg.sender, _country));
(_restrictedCountries[msg.sender])[_country] = true;
uint256 nonce = getNonce(msg.sender);
require((_restrictedCountries[msg.sender][nonce])[_country] == false, CountryAlreadyRestricted(msg.sender, _country));
_restrictedCountries[msg.sender][nonce][_country] = true;
emit AddedRestrictedCountry(msg.sender, _country);
}

Expand All @@ -139,8 +140,9 @@ contract CountryRestrictModule is AbstractModuleUpgradeable {
* emits an `RemovedRestrictedCountry` event
*/
function removeCountryRestriction(uint16 _country) external onlyComplianceCall {
require((_restrictedCountries[msg.sender])[_country] == true, CountryNotRestricted(msg.sender, _country));
(_restrictedCountries[msg.sender])[_country] = false;
uint256 nonce = getNonce(msg.sender);
require((_restrictedCountries[msg.sender][nonce])[_country] == true, CountryNotRestricted(msg.sender, _country));
(_restrictedCountries[msg.sender][nonce])[_country] = false;
emit RemovedRestrictedCountry(msg.sender, _country);
}

Expand All @@ -155,9 +157,13 @@ contract CountryRestrictModule is AbstractModuleUpgradeable {
*/
function batchRestrictCountries(uint16[] calldata _countries) external onlyComplianceCall {
require(_countries.length < 195, MaxCountriesInBatchReached(195));
uint256 nonce = getNonce(msg.sender);
for (uint256 i = 0; i < _countries.length; i++) {
require(!(_restrictedCountries[msg.sender])[_countries[i]], CountryAlreadyRestricted(msg.sender, _countries[i]));
(_restrictedCountries[msg.sender])[_countries[i]] = true;
require(
!(_restrictedCountries[msg.sender][nonce])[_countries[i]],
CountryAlreadyRestricted(msg.sender, _countries[i])
);
_restrictedCountries[msg.sender][nonce][_countries[i]] = true;
emit AddedRestrictedCountry(msg.sender, _countries[i]);
}
}
Expand All @@ -173,9 +179,10 @@ contract CountryRestrictModule is AbstractModuleUpgradeable {
*/
function batchUnrestrictCountries(uint16[] calldata _countries) external onlyComplianceCall {
require(_countries.length < 195, MaxCountriesInBatchReached(195));
uint256 nonce = getNonce(msg.sender);
for (uint256 i = 0; i < _countries.length; i++) {
require((_restrictedCountries[msg.sender])[_countries[i]], CountryNotRestricted(msg.sender, _countries[i]));
(_restrictedCountries[msg.sender])[_countries[i]] = false;
require((_restrictedCountries[msg.sender][nonce])[_countries[i]], CountryNotRestricted(msg.sender, _countries[i]));
_restrictedCountries[msg.sender][nonce][_countries[i]] = false;
emit RemovedRestrictedCountry(msg.sender, _countries[i]);
}
}
Expand Down Expand Up @@ -223,7 +230,7 @@ contract CountryRestrictModule is AbstractModuleUpgradeable {
/**
* @dev See {IModule-canComplianceBind}.
*/
function canComplianceBind(address /*_compliance*/) external view override returns (bool) {
function canComplianceBind(address /*_compliance*/) external pure override returns (bool) {
return true;
}

Expand All @@ -240,7 +247,8 @@ contract CountryRestrictModule is AbstractModuleUpgradeable {
*/
function isCountryRestricted(address _compliance, uint16 _country) public view
returns (bool) {
return ((_restrictedCountries[_compliance])[_country]);
uint256 nonce = getNonce(_compliance);
return _restrictedCountries[_compliance][nonce][_country];
}

/**
Expand Down
Loading

0 comments on commit 85dd459

Please sign in to comment.