-
Notifications
You must be signed in to change notification settings - Fork 132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement SafeRatio
#121
Merged
Implement SafeRatio
#121
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from typing import List, Tuple, TYPE_CHECKING | ||
|
||
from ..types import TealType, require_type | ||
from ..errors import TealInputError, TealInternalError, TealCompileError | ||
from ..ir import TealOp, Op, TealSimpleBlock, TealBlock | ||
from .expr import Expr | ||
|
||
if TYPE_CHECKING: | ||
from ..compiler import CompileOptions | ||
|
||
|
||
def multiplyFactors( | ||
expr: Expr, factors: List[Expr], options: "CompileOptions" | ||
) -> Tuple[TealSimpleBlock, TealSimpleBlock]: | ||
if len(factors) == 0: | ||
raise TealInternalError("Received 0 factors") | ||
|
||
start = TealSimpleBlock([]) | ||
|
||
fac0Start, fac0End = factors[0].__teal__(options) | ||
|
||
if len(factors) == 1: | ||
# need to use 0 as high word | ||
highword = TealSimpleBlock([TealOp(expr, Op.int, 0)]) | ||
|
||
start.setNextBlock(highword) | ||
highword.setNextBlock(fac0Start) | ||
|
||
end = fac0End | ||
else: | ||
start.setNextBlock(fac0Start) | ||
|
||
fac1Start, fac1End = factors[1].__teal__(options) | ||
fac0End.setNextBlock(fac1Start) | ||
|
||
multiplyFirst2 = TealSimpleBlock([TealOp(expr, Op.mulw)]) | ||
fac1End.setNextBlock(multiplyFirst2) | ||
|
||
end = multiplyFirst2 | ||
for factor in factors[2:]: | ||
facXStart, facXEnd = factor.__teal__(options) | ||
end.setNextBlock(facXStart) | ||
|
||
# stack is [..., A, B, C], where C is current factor | ||
# need to pop all A,B,C from stack and push X,Y, where X and Y are: | ||
# X * 2**64 + Y = (A * 2**64 + B) * C | ||
# <=> X * 2**64 + Y = A * C * 2**64 + B * C | ||
# <=> X = A * C + highword(B * C) | ||
# Y = lowword(B * C) | ||
multiply = TealSimpleBlock( | ||
[ | ||
TealOp(expr, Op.uncover, 2), # stack: [..., B, C, A] | ||
TealOp(expr, Op.dig, 1), # stack: [..., B, C, A, C] | ||
TealOp(expr, Op.mul), # stack: [..., B, C, A*C] | ||
TealOp(expr, Op.cover, 2), # stack: [..., A*C, B, C] | ||
TealOp( | ||
expr, Op.mulw | ||
), # stack: [..., A*C, highword(B*C), lowword(B*C)] | ||
TealOp( | ||
expr, Op.cover, 2 | ||
), # stack: [..., lowword(B*C), A*C, highword(B*C)] | ||
TealOp( | ||
expr, Op.add | ||
), # stack: [..., lowword(B*C), A*C+highword(B*C)] | ||
TealOp( | ||
expr, Op.swap | ||
), # stack: [..., A*C+highword(B*C), lowword(B*C)] | ||
] | ||
) | ||
|
||
facXEnd.setNextBlock(multiply) | ||
end = multiply | ||
|
||
return start, end | ||
|
||
|
||
class SafeRatio(Expr): | ||
"""A class used to calculate expressions of the form :code:`(N_1 * N_2 * N_3 * ...) / (D_1 * D_2 * D_3 * ...)` | ||
|
||
Use this class if all inputs to the expression are uint64s, the output fits in a uint64, and all | ||
intermediate values fit in a uint128. | ||
""" | ||
|
||
def __init__( | ||
self, numeratorFactors: List[Expr], denominatorFactors: List[Expr] | ||
) -> None: | ||
"""Create a new SafeRatio expression with the given numerator and denominator factors. | ||
|
||
This will calculate :code:`(N_1 * N_2 * N_3 * ...) / (D_1 * D_2 * D_3 * ...)`, where each | ||
:code:`N_i` represents an element in :code:`numeratorFactors` and each :code:`D_i` | ||
represents an element in :code:`denominatorFactors`. | ||
|
||
Requires TEAL version 5 or higher. | ||
|
||
Args: | ||
numeratorFactors: The factors in the numerator of the ratio. This list must have at | ||
least 1 element. If this list has exactly 1 element, then denominatorFactors must | ||
have more than 1 element (otherwise basic division should be used). | ||
denominatorFactors: The factors in the denominator of the ratio. This list must have at | ||
least 1 element. | ||
""" | ||
super().__init__() | ||
if len(numeratorFactors) == 0 or len(denominatorFactors) == 0: | ||
raise TealInternalError( | ||
"At least 1 factor must be present in the numerator and denominator" | ||
) | ||
if len(numeratorFactors) == 1 and len(denominatorFactors) == 1: | ||
raise TealInternalError( | ||
"There is only a single factor in the numerator and denominator. Use basic division instead." | ||
) | ||
self.numeratorFactors = numeratorFactors | ||
self.denominatorFactors = denominatorFactors | ||
|
||
def __teal__(self, options: "CompileOptions"): | ||
if options.version < Op.cover.min_version: | ||
raise TealCompileError( | ||
"SafeRatio requires TEAL version {} or higher".format( | ||
Op.cover.min_version | ||
), | ||
self, | ||
) | ||
|
||
numStart, numEnd = multiplyFactors(self, self.numeratorFactors, options) | ||
denomStart, denomEnd = multiplyFactors(self, self.denominatorFactors, options) | ||
numEnd.setNextBlock(denomStart) | ||
|
||
combine = TealSimpleBlock( | ||
[ | ||
TealOp(self, Op.divmodw), | ||
TealOp(self, Op.pop), # pop remainder low word | ||
TealOp(self, Op.pop), # pop remainder high word | ||
TealOp(self, Op.swap), # swap quotient high and low words | ||
TealOp(self, Op.logic_not), | ||
TealOp(self, Op.assert_), # assert quotient high word is 0 | ||
# end with quotient low word remaining on the stack | ||
] | ||
) | ||
denomEnd.setNextBlock(combine) | ||
|
||
return numStart, combine | ||
|
||
def __str__(self): | ||
ret_str = "(SafeRatio (*" | ||
for f in self.numeratorFactors: | ||
ret_str += " " + str(f) | ||
ret_str += ") (*" | ||
for f in self.denominatorFactors: | ||
ret_str += " " + str(f) | ||
ret_str += ")" | ||
return ret_str | ||
|
||
def type_of(self): | ||
return TealType.uint64 | ||
|
||
def has_return(self): | ||
return False | ||
|
||
|
||
SafeRatio.__module__ = "pyteal" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1337,3 +1337,157 @@ def storeValue(key: Expr, t1: Expr, t2: Expr, t3: Expr) -> Expr: | |
""".strip() | ||
actual = compileTeal(program, Mode.Application, version=4, assembleConstants=True) | ||
assert actual == expected | ||
|
||
|
||
def test_compile_safe_ratio(): | ||
cases = ( | ||
( | ||
SafeRatio([Int(2), Int(100)], [Int(5)]), | ||
"""#pragma version 5 | ||
int 2 | ||
int 100 | ||
mulw | ||
int 0 | ||
int 5 | ||
divmodw | ||
pop | ||
pop | ||
swap | ||
! | ||
assert | ||
return | ||
""", | ||
Comment on lines
+1345
to
+1359
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This convinces me about the "no overhead" on A*B/C |
||
), | ||
( | ||
SafeRatio([Int(2), Int(100)], [Int(10), Int(5)]), | ||
"""#pragma version 5 | ||
int 2 | ||
int 100 | ||
mulw | ||
int 10 | ||
int 5 | ||
mulw | ||
divmodw | ||
pop | ||
pop | ||
swap | ||
! | ||
assert | ||
return | ||
""", | ||
), | ||
( | ||
SafeRatio([Int(2), Int(100), Int(3)], [Int(10), Int(5)]), | ||
"""#pragma version 5 | ||
int 2 | ||
int 100 | ||
mulw | ||
int 3 | ||
uncover 2 | ||
dig 1 | ||
* | ||
cover 2 | ||
mulw | ||
cover 2 | ||
+ | ||
swap | ||
int 10 | ||
int 5 | ||
mulw | ||
divmodw | ||
pop | ||
pop | ||
swap | ||
! | ||
assert | ||
return | ||
""", | ||
), | ||
( | ||
SafeRatio([Int(2), Int(100), Int(3)], [Int(10), Int(5), Int(6)]), | ||
"""#pragma version 5 | ||
int 2 | ||
int 100 | ||
mulw | ||
int 3 | ||
uncover 2 | ||
dig 1 | ||
* | ||
cover 2 | ||
mulw | ||
cover 2 | ||
+ | ||
swap | ||
int 10 | ||
int 5 | ||
mulw | ||
int 6 | ||
uncover 2 | ||
dig 1 | ||
* | ||
cover 2 | ||
mulw | ||
cover 2 | ||
+ | ||
swap | ||
divmodw | ||
pop | ||
pop | ||
swap | ||
! | ||
assert | ||
return | ||
""", | ||
), | ||
( | ||
SafeRatio([Int(2), Int(100), Int(3), Int(4)], [Int(10), Int(5), Int(6)]), | ||
"""#pragma version 5 | ||
int 2 | ||
int 100 | ||
mulw | ||
int 3 | ||
uncover 2 | ||
dig 1 | ||
* | ||
cover 2 | ||
mulw | ||
cover 2 | ||
+ | ||
swap | ||
int 4 | ||
uncover 2 | ||
dig 1 | ||
* | ||
cover 2 | ||
mulw | ||
cover 2 | ||
+ | ||
swap | ||
int 10 | ||
int 5 | ||
mulw | ||
int 6 | ||
uncover 2 | ||
dig 1 | ||
* | ||
cover 2 | ||
mulw | ||
cover 2 | ||
+ | ||
swap | ||
divmodw | ||
pop | ||
pop | ||
swap | ||
! | ||
assert | ||
return | ||
""", | ||
), | ||
) | ||
|
||
for program, expected in cases: | ||
actual = compileTeal( | ||
program, Mode.Application, version=5, assembleConstants=False | ||
) | ||
assert actual == expected.strip() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All the numerator factors together have to fit in u128? I suppose it would be better to do the math in different orders sometimes, multiply then divide, and so on, so the intermediates are small enough. If we're not doing that (and it's data dependent, so I don't think we can) then I'm not sure of the value of this compared to A*B/C. Did you run into the need to do this all at once?
If there's no overhead in the generated code in the simple A*B/C case, I suppose it's fine though. I'll try to see if that's the case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason to allow > 2 numerators is because I believe an interface that only allows
A*B/C
is pretty restrictive. If we only exposed that interface and you needed to have 3 numerators, then you would have to do(A_1*A_2)*B/C
, which is strictly worse sinceA_1*A_2
has to fit in a uint64, instead ofA_1*A_2*B
having to fit in a uint128.Plus, due to integer rounding behavior during division, I'm not sure it would be possible to implement the optimization that you described.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could do A_1; A_2; mulw; C; divmodw; B mul. (taking some liberties with stack management in there). That would keep intermediates smaller. But I don't know how badly rounding would compound imprecision.
Anyway, since this does A*B/C efficiently, I don't see any reason to complain much about it doing more if you can ask it to do more. Though I guess I'd argue that the docs should be clearer about the order of the steps, so it's clearer what is meant by "intermediate values".