Skip to content

[audit-05] fix: [TRST-M-2] shared collection window logic #1202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: ma/indexing-payments-audit-fixes-04-M-1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions packages/horizon/contracts/interfaces/IRecurringCollector.sol
Original file line number Diff line number Diff line change
Expand Up @@ -290,14 +290,6 @@ interface IRecurringCollector is IAuthorizable, IPaymentsCollector {
*/
error RecurringCollectorInvalidCollectData(bytes invalidData);

/**
* @notice Thrown when calling collect() on a payer canceled agreement
* where the final collection has already been done
* @param agreementId The agreement ID
* @param finalCollectionAt The timestamp when the final collection was done
*/
error RecurringCollectorFinalCollectionDone(bytes16 agreementId, uint256 finalCollectionAt);

/**
* @notice Thrown when interacting with an agreement that has an incorrect state
* @param agreementId The agreement ID
Expand Down Expand Up @@ -420,11 +412,13 @@ interface IRecurringCollector is IAuthorizable, IPaymentsCollector {
function getAgreement(bytes16 agreementId) external view returns (AgreementData memory);

/**
* @notice Checks if an agreement is collectable.
* @dev "Collectable" means the agreement is in a valid state that allows collection attempts,
* not that there are necessarily funds available to collect.
* @notice Get collection info for an agreement
* @param agreement The agreement data
* @return The boolean indicating if the agreement is collectable
* @return isCollectable Whether the agreement is in a valid state that allows collection attempts,
* not that there are necessarily funds available to collect.
* @return collectionSeconds The valid collection duration in seconds (0 if not collectable)
*/
function isCollectable(AgreementData memory agreement) external view returns (bool);
function getCollectionInfo(
AgreementData memory agreement
) external view returns (bool isCollectable, uint256 collectionSeconds);
}
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,10 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
}

/// @inheritdoc IRecurringCollector
function isCollectable(AgreementData memory agreement) external pure returns (bool) {
return _isCollectable(agreement);
function getCollectionInfo(
AgreementData memory agreement
) external view returns (bool isCollectable, uint256 collectionSeconds) {
return _getCollectionInfo(agreement);
}

/**
Expand All @@ -274,9 +276,14 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
CollectParams memory _params
) private returns (uint256) {
AgreementData storage agreement = _getAgreementStorage(_params.agreementId);

// Check if agreement exists first (for unknown agreements)
(bool isCollectable, uint256 collectionSeconds) = _getCollectionInfo(agreement);
require(isCollectable, RecurringCollectorAgreementIncorrectState(_params.agreementId, agreement.state));

require(
_isCollectable(agreement),
RecurringCollectorAgreementIncorrectState(_params.agreementId, agreement.state)
collectionSeconds > 0,
RecurringCollectorZeroCollectionSeconds(_params.agreementId, block.timestamp, agreement.lastCollectionAt)
);

require(
Expand All @@ -297,7 +304,7 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC

uint256 tokensToCollect = 0;
if (_params.tokens != 0) {
tokensToCollect = _requireValidCollect(agreement, _params.agreementId, _params.tokens);
tokensToCollect = _requireValidCollect(agreement, _params.agreementId, _params.tokens, collectionSeconds);

_graphPaymentsEscrow().collect(
_paymentType,
Expand Down Expand Up @@ -374,53 +381,37 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
* @param _agreement The agreement data
* @param _agreementId The ID of the agreement
* @param _tokens The number of tokens to collect
* @param _collectionSeconds Collection duration from _getCollectionInfo()
* @return The number of tokens that can be collected
*/
function _requireValidCollect(
AgreementData memory _agreement,
bytes16 _agreementId,
uint256 _tokens
uint256 _tokens,
uint256 _collectionSeconds
) private view returns (uint256) {
bool canceledOrElapsed = _agreement.state == AgreementState.CanceledByPayer ||
block.timestamp > _agreement.endsAt;
uint256 canceledOrNow = _agreement.state == AgreementState.CanceledByPayer
? _agreement.canceledAt
: block.timestamp;

// if canceled by the payer allow collection till canceledAt
// if elapsed allow collection till endsAt
// if both are true, use the earlier one
uint256 collectionEnd = canceledOrElapsed ? Math.min(canceledOrNow, _agreement.endsAt) : block.timestamp;
uint256 collectionStart = _agreementCollectionStartAt(_agreement);
require(
collectionEnd != collectionStart,
RecurringCollectorZeroCollectionSeconds(_agreementId, block.timestamp, uint64(collectionStart))
);
require(collectionEnd > collectionStart, RecurringCollectorFinalCollectionDone(_agreementId, collectionStart));

uint256 collectionSeconds = collectionEnd - collectionStart;
// Check that the collection window is long enough
// If the agreement is canceled or elapsed, allow a shorter collection window
if (!canceledOrElapsed) {
require(
collectionSeconds >= _agreement.minSecondsPerCollection,
_collectionSeconds >= _agreement.minSecondsPerCollection,
RecurringCollectorCollectionTooSoon(
_agreementId,
uint32(collectionSeconds),
uint32(_collectionSeconds),
_agreement.minSecondsPerCollection
)
);
}
require(
collectionSeconds <= _agreement.maxSecondsPerCollection,
_collectionSeconds <= _agreement.maxSecondsPerCollection,
RecurringCollectorCollectionTooLate(
_agreementId,
uint64(collectionSeconds),
uint64(_collectionSeconds),
_agreement.maxSecondsPerCollection
)
);

uint256 maxTokens = _agreement.maxOngoingTokensPerSecond * collectionSeconds;
uint256 maxTokens = _agreement.maxOngoingTokensPerSecond * _collectionSeconds;
maxTokens += _agreement.lastCollectionAt == 0 ? _agreement.maxInitialTokens : 0;

return Math.min(_tokens, maxTokens);
Expand Down Expand Up @@ -546,20 +537,47 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
}

/**
* @notice Gets the start time for the collection of an agreement.
* @notice Internal function to get collection info for an agreement
* @dev This is the single source of truth for collection window logic
* @param _agreement The agreement data
* @return The start time for the collection of the agreement
* @return isCollectable Whether the agreement can be collected from
* @return collectionSeconds The valid collection duration in seconds (0 if not collectable)
*/
function _agreementCollectionStartAt(AgreementData memory _agreement) private pure returns (uint256) {
return _agreement.lastCollectionAt > 0 ? _agreement.lastCollectionAt : _agreement.acceptedAt;
function _getCollectionInfo(
AgreementData memory _agreement
) private view returns (bool isCollectable, uint256 collectionSeconds) {
// Check if agreement is in collectable state
isCollectable =
_agreement.state == AgreementState.Accepted ||
_agreement.state == AgreementState.CanceledByPayer;

if (!isCollectable) {
return (false, 0);
}

bool canceledOrElapsed = _agreement.state == AgreementState.CanceledByPayer ||
block.timestamp > _agreement.endsAt;
uint256 canceledOrNow = _agreement.state == AgreementState.CanceledByPayer
? _agreement.canceledAt
: block.timestamp;

uint256 collectionEnd = canceledOrElapsed ? Math.min(canceledOrNow, _agreement.endsAt) : block.timestamp;
uint256 collectionStart = _agreementCollectionStartAt(_agreement);

if (collectionEnd < collectionStart) {
return (false, 0);
}

collectionSeconds = collectionEnd - collectionStart;
return (isCollectable, collectionSeconds);
}

/**
* @notice Requires that the agreement is collectable.
* @notice Gets the start time for the collection of an agreement.
* @param _agreement The agreement data
* @return The boolean indicating if the agreement is collectable
* @return The start time for the collection of the agreement
*/
function _isCollectable(AgreementData memory _agreement) private pure returns (bool) {
return _agreement.state == AgreementState.Accepted || _agreement.state == AgreementState.CanceledByPayer;
function _agreementCollectionStartAt(AgreementData memory _agreement) private pure returns (uint256) {
return _agreement.lastCollectionAt > 0 ? _agreement.lastCollectionAt : _agreement.acceptedAt;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
(IRecurringCollector.SignedRCA memory accepted, ) = _sensibleAuthorizeAndAccept(fuzzy.fuzzyTestAccept);
IRecurringCollector.CollectParams memory collectParams = fuzzy.collectParams;

skip(1);

collectParams.agreementId = accepted.rca.agreementId;
bytes memory data = _generateCollectData(collectParams);

Expand All @@ -53,6 +55,8 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
collectParams.tokens = bound(collectParams.tokens, 1, type(uint256).max);
bytes memory data = _generateCollectData(collectParams);

skip(1);

// Set up the scenario where service provider has no tokens staked with data service
// This simulates an unauthorized data service attack
_horizonStaking.setProvision(
Expand Down
32 changes: 13 additions & 19 deletions packages/subgraph-service/contracts/libraries/IndexingAgreement.sol
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,11 @@ library IndexingAgreement {
allocation.indexer == params.indexer,
IndexingAgreementNotAuthorized(params.agreementId, params.indexer)
);
require(_isCollectable(wrapper), IndexingAgreementNotCollectable(params.agreementId));
// Get collection info from RecurringCollector (single source of truth for temporal logic)
(bool isCollectable, uint256 collectionSeconds) = _directory().recurringCollector().getCollectionInfo(
wrapper.collectorAgreement
);
require(_isValid(wrapper) && isCollectable, IndexingAgreementNotCollectable(params.agreementId));

require(
wrapper.agreement.version == IndexingAgreementVersion.V1,
Expand All @@ -540,7 +544,7 @@ library IndexingAgreement {

uint256 expectedTokens = (data.entities == 0 && data.poi == bytes32(0))
? 0
: _tokensToCollect(self, params.agreementId, wrapper.collectorAgreement, data.entities);
: _tokensToCollect(self, params.agreementId, data.entities, collectionSeconds);

// `tokensCollected` <= `expectedTokens` because the recurring collector will further narrow
// down the tokens allowed, based on the RCA terms.
Expand Down Expand Up @@ -677,28 +681,21 @@ library IndexingAgreement {
}

/**
* @notice Calculate the number of tokens to collect for an indexing agreement.
*
* @dev This function calculates the number of tokens to collect based on the agreement terms and the collection time.
*
* @param _manager The indexing agreement storage manager
* @param _agreementId The id of the agreement
* @param _agreement The collector agreement data
* @notice Calculate tokens to collect based on pre-validated duration
* @param _manager The storage manager
* @param _agreementId The agreement ID
* @param _entities The number of entities indexed
* @param _collectionSeconds Pre-calculated valid collection duration
* @return The number of tokens to collect
*/
function _tokensToCollect(
StorageManager storage _manager,
bytes16 _agreementId,
IRecurringCollector.AgreementData memory _agreement,
uint256 _entities
uint256 _entities,
uint256 _collectionSeconds
) private view returns (uint256) {
IndexingAgreementTermsV1 memory termsV1 = _manager.termsV1[_agreementId];

uint256 collectionSeconds = block.timestamp;
collectionSeconds -= _agreement.lastCollectionAt > 0 ? _agreement.lastCollectionAt : _agreement.acceptedAt;

return collectionSeconds * (termsV1.tokensPerSecond + termsV1.tokensPerEntityPerSecond * _entities);
return _collectionSeconds * (termsV1.tokensPerSecond + termsV1.tokensPerEntityPerSecond * _entities);
}

/**
Expand All @@ -721,9 +718,6 @@ library IndexingAgreement {
* @param wrapper The agreement wrapper containing the indexing agreement and collector agreement data
* @return True if the agreement is collectable, false otherwise
**/
function _isCollectable(AgreementWrapper memory wrapper) private view returns (bool) {
return _isValid(wrapper) && _directory().recurringCollector().isCollectable(wrapper.collectorAgreement);
}

/**
* @notice Checks if the agreement is valid
Expand Down