From 2fc70f5b90180383ce1dc1339ec213df960d9ef0 Mon Sep 17 00:00:00 2001 From: Trevor Richard Date: Wed, 3 Jul 2024 15:50:44 +0000 Subject: [PATCH] improve overflow protection for shutdown portion calculation --- src/PrizePool.sol | 30 +++++++++++++----------------- test/PrizePool.t.sol | 36 +++++++++++++++++++++++------------- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/src/PrizePool.sol b/src/PrizePool.sol index 2bffd6c..4f632ec 100644 --- a/src/PrizePool.sol +++ b/src/PrizePool.sol @@ -5,7 +5,7 @@ import { SafeCast } from "openzeppelin/utils/math/SafeCast.sol"; import { IERC20 } from "openzeppelin/token/ERC20/IERC20.sol"; import { SafeERC20 } from "openzeppelin/token/ERC20/utils/SafeERC20.sol"; import { SD59x18, convert, sd } from "prb-math/SD59x18.sol"; -import { SD1x18, unwrap, UNIT } from "prb-math/SD1x18.sol"; +import { UD60x18, convert } from "prb-math/UD60x18.sol"; import { TwabController } from "pt-v5-twab-controller/TwabController.sol"; import { DrawAccumulatorLib, Observation, MAX_OBSERVATION_CARDINALITY } from "./libraries/DrawAccumulatorLib.sol"; @@ -167,14 +167,6 @@ struct ConstructorParams { uint24 drawTimeout; } -/// @notice A struct to represent a shutdown portion of liquidity for a vault and account -/// @param numerator The numerator of the portion -/// @param denominator The denominator of the portion -struct ShutdownPortion { - uint256 numerator; - uint256 denominator; -} - /// @title PoolTogether V5 Prize Pool /// @author G9 Software Inc. & PoolTogether Inc. Team /// @notice The Prize Pool holds the prize liquidity and allows vaults to claim prizes. @@ -326,7 +318,7 @@ contract PrizePool is TieredLiquidityDistributor { mapping(address vault => mapping(address account => Observation lastWithdrawalTotalContributedObservation)) internal _withdrawalObservations; /// @notice The shutdown portion of liquidity for a vault and account - mapping(address vault => mapping(address account => ShutdownPortion shutdownPortion)) internal _shutdownPortions; + mapping(address vault => mapping(address account => UD60x18 shutdownPortion)) internal _shutdownPortions; /* ============ Constructor ============ */ @@ -879,7 +871,7 @@ contract PrizePool is TieredLiquidityDistributor { /// @param _vault The vault whose contributions are measured /// @param _account The account whose vault twab is measured /// @return The portion of the shutdown balance that the account is entitled to. - function computeShutdownPortion(address _vault, address _account) public view returns (ShutdownPortion memory) { + function computeShutdownPortion(address _vault, address _account) public view returns (UD60x18) { uint24 drawIdPriorToShutdown = getShutdownDrawId() - 1; uint24 startDrawIdInclusive = computeRangeStartDrawIdInclusive(drawIdPriorToShutdown, grandPrizePeriodDraws); @@ -896,11 +888,15 @@ contract PrizePool is TieredLiquidityDistributor { drawIdPriorToShutdown ); - if (_vaultTwabTotalSupply == 0) { - return ShutdownPortion(0, 0); + if (_vaultTwabTotalSupply == 0 || totalContrib == 0) { + return UD60x18.wrap(0); } - return ShutdownPortion(vaultContrib * _userTwab, totalContrib * _vaultTwabTotalSupply); + // first division purposely done before multiplication to avoid overflow + return convert(vaultContrib) + .div(convert(totalContrib)) + .mul(convert(_userTwab)) + .div(convert(_vaultTwabTotalSupply)); } /// @notice Returns the shutdown balance for a given vault and account. The prize pool must already be shutdown. @@ -916,7 +912,7 @@ contract PrizePool is TieredLiquidityDistributor { } Observation memory withdrawalObservation = _withdrawalObservations[_vault][_account]; - ShutdownPortion memory shutdownPortion; + UD60x18 shutdownPortion; uint256 balance; // if we haven't withdrawn yet, add the portion of the shutdown balance @@ -928,7 +924,7 @@ contract PrizePool is TieredLiquidityDistributor { shutdownPortion = _shutdownPortions[_vault][_account]; } - if (shutdownPortion.denominator == 0) { + if (shutdownPortion.unwrap() == 0) { return 0; } @@ -937,7 +933,7 @@ contract PrizePool is TieredLiquidityDistributor { Observation memory newestObs = _totalAccumulator.newestObservation(); balance += (newestObs.available + newestObs.disbursed) - (withdrawalObservation.available + withdrawalObservation.disbursed); - return (shutdownPortion.numerator * balance) / shutdownPortion.denominator; + return convert(convert(balance).mul(shutdownPortion)); } /// @notice Withdraws the shutdown balance for a given vault and sender diff --git a/test/PrizePool.t.sol b/test/PrizePool.t.sol index 9c1d74a..70d7d37 100644 --- a/test/PrizePool.t.sol +++ b/test/PrizePool.t.sol @@ -6,6 +6,7 @@ import "forge-std/Test.sol"; import { ERC20 } from "openzeppelin/token/ERC20/ERC20.sol"; import { IERC20 } from "openzeppelin/token/ERC20/IERC20.sol"; import { sd, SD59x18 } from "prb-math/SD59x18.sol"; +import { UD60x18, convert } from "prb-math/UD60x18.sol"; import { UD2x18, ud2x18 } from "prb-math/UD2x18.sol"; import { SD1x18, sd1x18 } from "prb-math/SD1x18.sol"; import { TwabController } from "pt-v5-twab-controller/TwabController.sol"; @@ -44,8 +45,7 @@ import { IncompatibleTwabPeriodOffset, ClaimPeriodExpired, PrizePoolShutdown, - Observation, - ShutdownPortion + Observation } from "../src/PrizePool.sol"; import { ERC20Mintable } from "./mocks/ERC20Mintable.sol"; @@ -918,8 +918,8 @@ contract PrizePoolTest is Test { uint bobShutdownBalance = 630e18/6; uint aliceShutdownBalance = 630e18/3; - assertEq(prizePool.shutdownBalanceOf(vault, bob), bobShutdownBalance, "bob balance"); - assertEq(prizePool.shutdownBalanceOf(vault2, alice), aliceShutdownBalance, "alice balance"); + assertApproxEqAbs(prizePool.shutdownBalanceOf(vault, bob), bobShutdownBalance, 1000, "bob balance"); + assertApproxEqAbs(prizePool.shutdownBalanceOf(vault2, alice), aliceShutdownBalance, 1000, "alice balance"); assertEq(prizePool.rewardBalance(bob), 0.1e18, "bob rewards"); assertEq(prizePool.rewardBalance(alice), remainder, "alice rewards"); @@ -927,22 +927,33 @@ contract PrizePoolTest is Test { prizePool.withdrawRewards(bob, 0.1e18); vm.prank(bob); prizePool.withdrawShutdownBalance(vault, bob); - assertEq(prizeToken.balanceOf(bob), bobShutdownBalance + 0.1e18, "bob token balance"); + assertApproxEqAbs(prizeToken.balanceOf(bob), bobShutdownBalance + 0.1e18, 1000, "bob token balance"); vm.prank(alice); prizePool.withdrawShutdownBalance(vault2, alice); vm.prank(alice); prizePool.withdrawRewards(alice, remainder); - assertEq(prizeToken.balanceOf(alice), aliceShutdownBalance + remainder, "alice token balance"); + assertApproxEqAbs(prizeToken.balanceOf(alice), aliceShutdownBalance + remainder, 1000, "alice token balance"); - assertEq(prizePool.accountedBalance(), 660e18 - (630e18/6 + 630e18/3) - 0.1e18 - remainder, "final balance"); + assertApproxEqAbs(prizePool.accountedBalance(), 660e18 - (630e18/6 + 630e18/3) - 0.1e18 - remainder, 1000, "final balance"); + } + + function test_shutdownBalanceOf_noOverflow() public { + // The contribution, TWAB, and available prize token balance would overflow if multiplied together, + // but we should not see this overflow happen in the shutdown logic. + contribute(type(uint96).max, vault); + uint newTime = prizePool.shutdownAt(); + vm.warp(newTime); + mockShutdownTwab(type(uint96).max, type(uint96).max, bob, vault); + UD60x18 portion = prizePool.computeShutdownPortion(vault, bob); + assertEq(portion.unwrap(), 1e18); + assertEq(prizePool.shutdownBalanceOf(vault, bob), type(uint96).max); } function test_computeShutdownPortion_empty() public { vm.warp(prizePool.shutdownAt()); - ShutdownPortion memory portion = prizePool.computeShutdownPortion(address(this), bob); - assertEq(portion.numerator, 0); - assertEq(portion.denominator, 0); + UD60x18 portion = prizePool.computeShutdownPortion(address(this), bob); + assertEq(portion.unwrap(), 0); } function test_computeShutdownPortion_nonZero() public { @@ -950,9 +961,8 @@ contract PrizePoolTest is Test { uint newTime = prizePool.shutdownAt(); vm.warp(newTime); mockShutdownTwab(0.5e18, 1e18, bob, vault); - ShutdownPortion memory portion = prizePool.computeShutdownPortion(vault, bob); - assertEq(portion.numerator, 220e18 * 0.5e18); - assertEq(portion.denominator, 220e18 * 1e18); + UD60x18 portion = prizePool.computeShutdownPortion(vault, bob); + assertEq(portion.unwrap(), 0.5e18); } function test_withdrawShutdownBalance_notShutdown() public {