diff --git a/src/jug.sol b/src/jug.sol index 1188c9cf..fc45e07a 100644 --- a/src/jug.sol +++ b/src/jug.sol @@ -38,27 +38,30 @@ contract Jug is DSNote { // --- Math --- function rpow(uint x, uint n, uint b) internal pure returns (uint z) { - assembly { - switch x case 0 {switch n case 0 {z := b} default {z := 0}} - default { - switch mod(n, 2) case 0 { z := b } default { z := x } - let half := div(b, 2) // for rounding. - for { n := div(n, 2) } n { n := div(n,2) } { - let xx := mul(x, x) - if iszero(eq(div(xx, x), x)) { revert(0,0) } - let xxRound := add(xx, half) - if lt(xxRound, xx) { revert(0,0) } - x := div(xxRound, b) - if mod(n,2) { - let zx := mul(z, x) - if and(iszero(iszero(x)), iszero(eq(div(zx, x), z))) { revert(0,0) } - let zxRound := add(zx, half) - if lt(zxRound, zx) { revert(0,0) } - z := div(zxRound, b) + assembly { + switch n case 0 { z := b } + default { + switch x case 0 { z := 0 } + default { + switch mod(n, 2) case 0 { z := b } default { z := x } + let half := div(b, 2) // for rounding. + for { n := div(n, 2) } n { n := div(n,2) } { + let xx := mul(x, x) + if shr(128, x) { revert(0,0) } + let xxRound := add(xx, half) + if lt(xxRound, xx) { revert(0,0) } + x := div(xxRound, b) + if mod(n,2) { + let zx := mul(z, x) + if and(iszero(iszero(x)), iszero(eq(div(zx, x), z))) { revert(0,0) } + let zxRound := add(zx, half) + if lt(zxRound, zx) { revert(0,0) } + z := div(zxRound, b) + } + } + } } - } } - } } uint256 constant ONE = 10 ** 27; function add(uint x, uint y) internal pure returns (uint z) { diff --git a/src/pot.sol b/src/pot.sol index 0b18bc8f..4b070f1a 100644 --- a/src/pot.sol +++ b/src/pot.sol @@ -77,22 +77,25 @@ contract Pot is DSNote { uint256 constant ONE = 10 ** 27; function rpow(uint x, uint n, uint base) internal pure returns (uint z) { assembly { - switch x case 0 {switch n case 0 {z := base} default {z := 0}} + switch n case 0 { z := base } default { - switch mod(n, 2) case 0 { z := base } default { z := x } - let half := div(base, 2) // for rounding. - for { n := div(n, 2) } n { n := div(n,2) } { - let xx := mul(x, x) - if iszero(eq(div(xx, x), x)) { revert(0,0) } - let xxRound := add(xx, half) - if lt(xxRound, xx) { revert(0,0) } - x := div(xxRound, base) - if mod(n,2) { - let zx := mul(z, x) - if and(iszero(iszero(x)), iszero(eq(div(zx, x), z))) { revert(0,0) } - let zxRound := add(zx, half) - if lt(zxRound, zx) { revert(0,0) } - z := div(zxRound, base) + switch x case 0 { z := 0 } + default { + switch mod(n, 2) case 0 { z := base } default { z := x } + let half := div(base, 2) // for rounding. + for { n := div(n, 2) } n { n := div(n,2) } { + let xx := mul(x, x) + if shr(128, x) { revert(0,0) } + let xxRound := add(xx, half) + if lt(xxRound, xx) { revert(0,0) } + x := div(xxRound, base) + if mod(n,2) { + let zx := mul(z, x) + if and(iszero(iszero(x)), iszero(eq(div(zx, x), z))) { revert(0,0) } + let zxRound := add(zx, half) + if lt(zxRound, zx) { revert(0,0) } + z := div(zxRound, base) + } } } }