diff --git a/contracts/core/GuardableModifier.sol b/contracts/core/GuardableModifier.sol index ed84333..dd559c0 100644 --- a/contracts/core/GuardableModifier.sol +++ b/contracts/core/GuardableModifier.sol @@ -3,7 +3,7 @@ pragma solidity >=0.7.0 <0.9.0; import {Guardable} from "../guard/Guardable.sol"; import {IAvatar} from "../interfaces/IAvatar.sol"; -import {IGuard} from "../interfaces/IGuard.sol"; +import {IModuleGuard} from "../interfaces/IGuard.sol"; import {Modifier} from "./Modifier.sol"; import {Module} from "./Module.sol"; @@ -22,21 +22,14 @@ abstract contract GuardableModifier is Module, Guardable, Modifier { bytes memory data, Operation operation ) internal virtual override returns (bool success) { + bytes32 moduleTxHash; address currentGuard = guard; if (currentGuard != address(0)) { - IGuard(currentGuard).checkTransaction( - /// Transaction info used by module transactions. + moduleTxHash = IModuleGuard(currentGuard).checkModuleTransaction( to, value, data, operation, - /// Zero out the redundant transaction information only used for Safe multisig transctions. - 0, - 0, - 0, - address(0), - payable(0), - "", sentOrSignedByModule() ); } @@ -47,7 +40,10 @@ abstract contract GuardableModifier is Module, Guardable, Modifier { operation ); if (currentGuard != address(0)) { - IGuard(currentGuard).checkAfterExecution(bytes32(0), success); + IModuleGuard(currentGuard).checkAfterModuleExecution( + moduleTxHash, + success + ); } } @@ -63,22 +59,15 @@ abstract contract GuardableModifier is Module, Guardable, Modifier { bytes memory data, Operation operation ) internal virtual override returns (bool success, bytes memory returnData) { + bytes32 moduleTxHash; address currentGuard = guard; if (currentGuard != address(0)) { - IGuard(currentGuard).checkTransaction( - /// Transaction info used by module transactions. + moduleTxHash = IModuleGuard(currentGuard).checkModuleTransaction( to, value, data, operation, - /// Zero out the redundant transaction information only used for Safe multisig transctions. - 0, - 0, - 0, - address(0), - payable(0), - "", - sentOrSignedByModule() + address(this) ); } @@ -90,7 +79,10 @@ abstract contract GuardableModifier is Module, Guardable, Modifier { ); if (currentGuard != address(0)) { - IGuard(currentGuard).checkAfterExecution(bytes32(0), success); + IModuleGuard(currentGuard).checkAfterModuleExecution( + moduleTxHash, + success + ); } } } diff --git a/contracts/core/GuardableModule.sol b/contracts/core/GuardableModule.sol index aba72b3..895733b 100644 --- a/contracts/core/GuardableModule.sol +++ b/contracts/core/GuardableModule.sol @@ -1,7 +1,7 @@ // SPDX-License-Identifier: LGPL-3.0-only pragma solidity >=0.7.0 <0.9.0; -import {IGuard} from "../interfaces/IGuard.sol"; +import {IModuleGuard} from "../interfaces/IGuard.sol"; import {Guardable} from "../guard/Guardable.sol"; import {Module} from "./Module.sol"; import {IAvatar} from "../interfaces/IAvatar.sol"; @@ -21,22 +21,15 @@ abstract contract GuardableModule is Module, Guardable { bytes memory data, Operation operation ) internal override returns (bool success) { + bytes32 moduleTxHash; address currentGuard = guard; if (currentGuard != address(0)) { - IGuard(currentGuard).checkTransaction( - /// Transaction info used by module transactions. + moduleTxHash = IModuleGuard(currentGuard).checkModuleTransaction( to, value, data, operation, - /// Zero out the redundant transaction information only used for Safe multisig transctions. - 0, - 0, - 0, - address(0), - payable(0), - "", - msg.sender + address(this) ); } success = IAvatar(target).execTransactionFromModule( @@ -46,7 +39,10 @@ abstract contract GuardableModule is Module, Guardable { operation ); if (currentGuard != address(0)) { - IGuard(currentGuard).checkAfterExecution(bytes32(0), success); + IModuleGuard(currentGuard).checkAfterModuleExecution( + moduleTxHash, + success + ); } } @@ -62,22 +58,15 @@ abstract contract GuardableModule is Module, Guardable { bytes memory data, Operation operation ) internal virtual override returns (bool success, bytes memory returnData) { + bytes32 moduleTxHash; address currentGuard = guard; if (currentGuard != address(0)) { - IGuard(currentGuard).checkTransaction( - /// Transaction info used by module transactions. + moduleTxHash = IModuleGuard(currentGuard).checkModuleTransaction( to, value, data, operation, - /// Zero out the redundant transaction information only used for Safe multisig transctions. - 0, - 0, - 0, - address(0), - payable(0), - "", - msg.sender + address(this) ); } @@ -89,7 +78,10 @@ abstract contract GuardableModule is Module, Guardable { ); if (currentGuard != address(0)) { - IGuard(currentGuard).checkAfterExecution(bytes32(0), success); + IModuleGuard(currentGuard).checkAfterModuleExecution( + moduleTxHash, + success + ); } } } diff --git a/contracts/guard/BaseGuard.sol b/contracts/guard/BaseGuard.sol index 40af22b..a59a50e 100644 --- a/contracts/guard/BaseGuard.sol +++ b/contracts/guard/BaseGuard.sol @@ -3,11 +3,11 @@ pragma solidity >=0.7.0 <0.9.0; import {IERC165} from "../interfaces/IERC165.sol"; -import {IGuard} from "../interfaces/IGuard.sol"; +import {IGuard, IModuleGuard} from "../interfaces/IGuard.sol"; import "../core/Operation.sol"; -abstract contract BaseGuard is IERC165 { +abstract contract BaseGuard is IGuard, IERC165 { function supportsInterface( bytes4 interfaceId ) external pure override returns (bool) { @@ -35,3 +35,26 @@ abstract contract BaseGuard is IERC165 { function checkAfterExecution(bytes32 txHash, bool success) external virtual; } + +abstract contract BaseModuleGuard is IModuleGuard, IERC165 { + function checkModuleTransaction( + address to, + uint256 value, + bytes memory data, + Operation operation, + address module + ) external virtual returns (bytes32 moduleTxHash); + + function checkAfterModuleExecution( + bytes32 txHash, + bool success + ) external virtual; + + function supportsInterface( + bytes4 interfaceId + ) external pure override returns (bool) { + return + interfaceId == type(IModuleGuard).interfaceId || // 0x58401ed8 + interfaceId == type(IERC165).interfaceId; // 0x01ffc9a7 + } +} diff --git a/contracts/guard/Guardable.sol b/contracts/guard/Guardable.sol index 51d81cd..eb49096 100644 --- a/contracts/guard/Guardable.sol +++ b/contracts/guard/Guardable.sol @@ -3,8 +3,8 @@ pragma solidity >=0.7.0 <0.9.0; import {Ownable} from "../factory/Ownable.sol"; -import {BaseGuard} from "../guard/BaseGuard.sol"; -import {IGuard} from "../interfaces/IGuard.sol"; +import {BaseModuleGuard} from "../guard/BaseGuard.sol"; +import {IModuleGuard} from "../interfaces/IGuard.sol"; /// @title Guardable - A contract that manages fallback calls made to this contract contract Guardable is Ownable { @@ -19,8 +19,11 @@ contract Guardable is Ownable { /// @param _guard The address of the guard to be used or the 0 address to disable the guard. function setGuard(address _guard) external onlyOwner { if (_guard != address(0)) { - if (!BaseGuard(_guard).supportsInterface(type(IGuard).interfaceId)) - revert NotIERC165Compliant(_guard); + if ( + !BaseModuleGuard(_guard).supportsInterface( + type(IModuleGuard).interfaceId + ) + ) revert NotIERC165Compliant(_guard); } guard = _guard; emit ChangedGuard(guard); diff --git a/contracts/interfaces/IGuard.sol b/contracts/interfaces/IGuard.sol index aeabedc..27fa435 100644 --- a/contracts/interfaces/IGuard.sol +++ b/contracts/interfaces/IGuard.sol @@ -20,3 +20,18 @@ interface IGuard { function checkAfterExecution(bytes32 txHash, bool success) external; } + +interface IModuleGuard { + function checkModuleTransaction( + address to, + uint256 value, + bytes memory data, + Operation operation, + address module + ) external returns (bytes32 moduleTxHash); + + function checkAfterModuleExecution( + bytes32 moduleTxHash, + bool success + ) external; +} diff --git a/contracts/signature/SignatureChecker.sol b/contracts/signature/SignatureChecker.sol index e33535b..f8f1ff6 100644 --- a/contracts/signature/SignatureChecker.sol +++ b/contracts/signature/SignatureChecker.sol @@ -52,16 +52,18 @@ abstract contract SignatureChecker { if (start < 4 || start > end) { return (bytes32(0), address(0)); } + bytes32 hash = moduleTxHash(data[:start], salt); address signer = address(uint160(uint256(r))); - bytes32 hash = moduleTxHash(data[:start], salt); return _isValidContractSignature(signer, hash, data[start:end]) ? (hash, signer) : (bytes32(0), address(0)); } else { bytes32 hash = moduleTxHash(data[:end], salt); - return (hash, ecrecover(hash, v, r, s)); + address signer = ecrecover(hash, v, r, s); + + return signer != address(0) ? (hash, signer) : (bytes32(0), address(0)); } } diff --git a/contracts/test/TestGuard.sol b/contracts/test/TestGuard.sol index 8ff2ea0..1f03cee 100644 --- a/contracts/test/TestGuard.sol +++ b/contracts/test/TestGuard.sol @@ -3,7 +3,7 @@ pragma solidity >=0.7.0 <0.9.0; import {IERC165} from "../interfaces/IERC165.sol"; -import {BaseGuard} from "../guard/BaseGuard.sol"; +import {BaseModuleGuard} from "../guard/BaseGuard.sol"; import {FactoryFriendly} from "../factory/FactoryFriendly.sol"; import {GuardableModule} from "../core/GuardableModule.sol"; @@ -11,8 +11,8 @@ import "../core/Operation.sol"; /* solhint-disable */ -contract TestGuard is FactoryFriendly, BaseGuard { - event PreChecked(address sender); +contract TestGuard is FactoryFriendly, BaseModuleGuard { + event PreChecked(address module); event PostChecked(bool checked); address public module; @@ -26,27 +26,22 @@ contract TestGuard is FactoryFriendly, BaseGuard { module = _module; } - function checkTransaction( + function checkModuleTransaction( address to, uint256 value, bytes memory data, Operation operation, - uint256, - uint256, - uint256, - address, - address payable, - bytes memory, - address sender - ) public override { + address _module + ) public override returns (bytes32) { require(to != address(0), "Cannot send to zero address"); require(value != 1337, "Cannot send 1337"); require(bytes3(data) != bytes3(0xbaddad), "Cannot call 0xbaddad"); require(operation != Operation(1), "No delegate calls"); - emit PreChecked(sender); + emit PreChecked(_module); + return keccak256(abi.encodePacked(to, value, data, operation, _module)); } - function checkAfterExecution(bytes32, bool) public override { + function checkAfterModuleExecution(bytes32, bool) public override { require( GuardableModule(module).guard() == address(this), "Module cannot remove its own guard." @@ -64,20 +59,4 @@ contract TestNonCompliantGuard is IERC165 { function supportsInterface(bytes4) external pure returns (bool) { return false; } - - function checkTransaction( - address, - uint256, - bytes memory, - Operation, - uint256, - uint256, - uint256, - address, - address, - bytes memory, - address - ) public {} - - function checkAfterExecution(bytes32, bool) public {} } diff --git a/test/04_Guard.spec.ts b/test/04_Guard.spec.ts index 197ac93..51cee9f 100644 --- a/test/04_Guard.spec.ts +++ b/test/04_Guard.spec.ts @@ -37,23 +37,7 @@ async function setupTests() { const GuardNonCompliant = await hre.ethers.getContractFactory( "TestNonCompliantGuard" ); - const guardNonCompliant = TestGuard__factory.connect( - await (await GuardNonCompliant.deploy()).getAddress(), - hre.ethers.provider - ); - - const tx = { - to: await avatar.getAddress(), - value: 0, - data: "0x", - operation: 0, - avatarTxGas: 0, - baseGas: 0, - gasPrice: 0, - gasToken: ZeroAddress, - refundReceiver: ZeroAddress, - signatures: "0x", - }; + const guardNonCompliant = await GuardNonCompliant.deploy(); return { owner, @@ -61,7 +45,6 @@ async function setupTests() { module, guard, guardNonCompliant, - tx, }; } @@ -137,90 +120,85 @@ describe("Guardable", async () => { }); }); -describe("BaseGuard", async () => { - const txHash = +describe("BaseModuleGuard", async () => { + const moduleTxHash = "0x0000000000000000000000000000000000000000000000000000000000000001"; /** * Tests support for interfaces. * Verifies that the guard supports the required interfaces. */ - it("supports interface", async () => { + it("supports IModuleGuard interface", async () => { const { guard } = await loadFixture(setupTests); - expect(await guard.supportsInterface("0xe6d7a83a")).to.be.true; + // IModuleGuard interface ID + const iModuleGuardId = "0x58401ed8"; + expect(await guard.supportsInterface(iModuleGuardId)).to.be.true; expect(await guard.supportsInterface("0x01ffc9a7")).to.be.true; }); - describe("checkTransaction", async () => { + it("does not support IGuard interface", async () => { + const { guard } = await loadFixture(setupTests); + expect(await guard.supportsInterface("0xe6d7a83a")).to.be.false; + }); + + describe("checkModuleTransaction", async () => { /** - * Tests checking a transaction. + * Tests checking a module transaction. * Verifies that checking the transaction reverts if the test fails. */ it("reverts if test fails", async () => { - const { guard, tx } = await loadFixture(setupTests); + const { guard, module } = await loadFixture(setupTests); await expect( - guard.checkTransaction( - tx.to, + guard.checkModuleTransaction( + await module.getAddress(), 1337, - tx.data, - tx.operation, - tx.avatarTxGas, - tx.baseGas, - tx.gasPrice, - tx.gasToken, - tx.refundReceiver, - tx.signatures, - ZeroAddress + "0x", + 0, + await module.getAddress() ) ).to.be.revertedWith("Cannot send 1337"); }); /** - * Tests checking a transaction. + * Tests checking a module transaction. * Verifies that the transaction can be checked successfully. */ - it("checks transaction", async () => { - const { guard, tx } = await loadFixture(setupTests); + it("checks module transaction", async () => { + const { guard, module } = await loadFixture(setupTests); await expect( - guard.checkTransaction( - tx.to, - tx.value, - tx.data, - tx.operation, - tx.avatarTxGas, - tx.baseGas, - tx.gasPrice, - tx.gasToken, - tx.refundReceiver, - tx.signatures, - ZeroAddress + guard.checkModuleTransaction( + await module.getAddress(), + 0, + "0x", + 0, + await module.getAddress() ) ).to.emit(guard, "PreChecked"); }); }); - describe("checkAfterExecution", async () => { + describe("checkAfterModuleExecution", async () => { /** - * Tests checking the state after execution. + * Tests checking the state after module execution. * Verifies that checking the state after execution reverts if the test fails. */ it("reverts if test fails", async () => { const { guard } = await loadFixture(setupTests); - await expect(guard.checkAfterExecution(txHash, true)).to.be.revertedWith( - "Module cannot remove its own guard." - ); + await expect( + guard.checkAfterModuleExecution(moduleTxHash, true) + ).to.be.revertedWith("Module cannot remove its own guard."); }); /** - * Tests checking the state after execution. + * Tests checking the state after module execution. * Verifies that the state can be checked successfully after execution. */ - it("checks state after execution", async () => { + it("checks state after module execution", async () => { const { module, guard } = await loadFixture(setupTests); await expect(module.setGuard(await guard.getAddress())) .to.emit(module, "ChangedGuard") .withArgs(await guard.getAddress()); - await expect(guard.checkAfterExecution(txHash, true)) + await expect(guard.checkAfterModuleExecution(moduleTxHash, true)) .to.emit(guard, "PostChecked") .withArgs(true); }); diff --git a/test/07_GuardableModifier.spec.ts b/test/07_GuardableModifier.spec.ts index 9fc4001..0e0c4ba 100644 --- a/test/07_GuardableModifier.spec.ts +++ b/test/07_GuardableModifier.spec.ts @@ -76,7 +76,7 @@ describe("GuardableModifier", async () => { /** * Tests executing a transaction with a guard set. - * Verifies that the guard's pre-check is called and emits the PreChecked event. + * Verifies that the guard's pre-check is called and emits the PreChecked event with the executor address. */ it("pre-checks transaction if guard is set", async () => { const { avatar, executor, modifier, guard } = @@ -94,9 +94,9 @@ describe("GuardableModifier", async () => { /** * Tests executing a relayed transaction with a guard set. - * Verifies that the guard's pre-check is called with the signer's address. + * Verifies that the guard's pre-check is called with the signer's address (via sentOrSignedByModule). */ - it("pre-check gets called with signer when transaction is relayed", async () => { + it("pre-check gets called with signer address when transaction is relayed", async () => { const { signer, modifier, relayer, avatar, guard } = await loadFixture(setupTests); @@ -223,14 +223,14 @@ describe("GuardableModifier", async () => { ) ) .to.emit(guard, "PreChecked") - .withArgs(await executor.getAddress()); + .withArgs(await modifier.getAddress()); }); /** * Tests executing a relayed transaction that returns data with a guard set. - * Verifies that the guard's pre-check is called with the signer's address. + * Verifies that the guard's pre-check is called with the modifier's address. */ - it("pre-check gets called with signer when transaction is relayed", async () => { + it("pre-check gets called with modifier address when transaction is relayed", async () => { const { signer, modifier, relayer, avatar, guard } = await loadFixture(setupTests); @@ -264,7 +264,7 @@ describe("GuardableModifier", async () => { await expect(await relayer.sendTransaction(transactionWithSig)) .to.emit(guard, "PreChecked") - .withArgs(signer.address); + .withArgs(await modifier.getAddress()); }); /**