Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SafeRatio #121

Merged
merged 3 commits into from
Sep 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyteal/ast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@

# more ops
from .naryexpr import NaryExpr, And, Or, Concat
from .safemath import SafeRatio

# control flow
from .if_ import If
Expand Down Expand Up @@ -194,6 +195,7 @@
"And",
"Or",
"Concat",
"SafeRatio",
"If",
"Cond",
"Seq",
Expand Down
159 changes: 159 additions & 0 deletions pyteal/ast/safemath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from typing import List, Tuple, TYPE_CHECKING

from ..types import TealType, require_type
from ..errors import TealInputError, TealInternalError, TealCompileError
from ..ir import TealOp, Op, TealSimpleBlock, TealBlock
from .expr import Expr

if TYPE_CHECKING:
from ..compiler import CompileOptions


def multiplyFactors(
expr: Expr, factors: List[Expr], options: "CompileOptions"
) -> Tuple[TealSimpleBlock, TealSimpleBlock]:
if len(factors) == 0:
raise TealInternalError("Received 0 factors")

start = TealSimpleBlock([])

fac0Start, fac0End = factors[0].__teal__(options)

if len(factors) == 1:
# need to use 0 as high word
highword = TealSimpleBlock([TealOp(expr, Op.int, 0)])

start.setNextBlock(highword)
highword.setNextBlock(fac0Start)

end = fac0End
else:
start.setNextBlock(fac0Start)

fac1Start, fac1End = factors[1].__teal__(options)
fac0End.setNextBlock(fac1Start)

multiplyFirst2 = TealSimpleBlock([TealOp(expr, Op.mulw)])
fac1End.setNextBlock(multiplyFirst2)

end = multiplyFirst2
for factor in factors[2:]:
facXStart, facXEnd = factor.__teal__(options)
end.setNextBlock(facXStart)

# stack is [..., A, B, C], where C is current factor
# need to pop all A,B,C from stack and push X,Y, where X and Y are:
# X * 2**64 + Y = (A * 2**64 + B) * C
# <=> X * 2**64 + Y = A * C * 2**64 + B * C
# <=> X = A * C + highword(B * C)
# Y = lowword(B * C)
multiply = TealSimpleBlock(
[
TealOp(expr, Op.uncover, 2), # stack: [..., B, C, A]
TealOp(expr, Op.dig, 1), # stack: [..., B, C, A, C]
TealOp(expr, Op.mul), # stack: [..., B, C, A*C]
TealOp(expr, Op.cover, 2), # stack: [..., A*C, B, C]
TealOp(
expr, Op.mulw
), # stack: [..., A*C, highword(B*C), lowword(B*C)]
TealOp(
expr, Op.cover, 2
), # stack: [..., lowword(B*C), A*C, highword(B*C)]
TealOp(
expr, Op.add
), # stack: [..., lowword(B*C), A*C+highword(B*C)]
TealOp(
expr, Op.swap
), # stack: [..., A*C+highword(B*C), lowword(B*C)]
]
)

facXEnd.setNextBlock(multiply)
end = multiply

return start, end


class SafeRatio(Expr):
"""A class used to calculate expressions of the form :code:`(N_1 * N_2 * N_3 * ...) / (D_1 * D_2 * D_3 * ...)`

Use this class if all inputs to the expression are uint64s, the output fits in a uint64, and all
intermediate values fit in a uint128.
Comment on lines +80 to +81
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the numerator factors together have to fit in u128? I suppose it would be better to do the math in different orders sometimes, multiply then divide, and so on, so the intermediates are small enough. If we're not doing that (and it's data dependent, so I don't think we can) then I'm not sure of the value of this compared to A*B/C. Did you run into the need to do this all at once?

If there's no overhead in the generated code in the simple A*B/C case, I suppose it's fine though. I'll try to see if that's the case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason to allow > 2 numerators is because I believe an interface that only allows A*B/C is pretty restrictive. If we only exposed that interface and you needed to have 3 numerators, then you would have to do (A_1*A_2)*B/C, which is strictly worse since A_1*A_2 has to fit in a uint64, instead of A_1*A_2*B having to fit in a uint128.

Plus, due to integer rounding behavior during division, I'm not sure it would be possible to implement the optimization that you described.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could do A_1; A_2; mulw; C; divmodw; B mul. (taking some liberties with stack management in there). That would keep intermediates smaller. But I don't know how badly rounding would compound imprecision.

Anyway, since this does A*B/C efficiently, I don't see any reason to complain much about it doing more if you can ask it to do more. Though I guess I'd argue that the docs should be clearer about the order of the steps, so it's clearer what is meant by "intermediate values".

"""

def __init__(
self, numeratorFactors: List[Expr], denominatorFactors: List[Expr]
) -> None:
"""Create a new SafeRatio expression with the given numerator and denominator factors.

This will calculate :code:`(N_1 * N_2 * N_3 * ...) / (D_1 * D_2 * D_3 * ...)`, where each
:code:`N_i` represents an element in :code:`numeratorFactors` and each :code:`D_i`
represents an element in :code:`denominatorFactors`.

Requires TEAL version 5 or higher.

Args:
numeratorFactors: The factors in the numerator of the ratio. This list must have at
least 1 element. If this list has exactly 1 element, then denominatorFactors must
have more than 1 element (otherwise basic division should be used).
denominatorFactors: The factors in the denominator of the ratio. This list must have at
least 1 element.
"""
super().__init__()
if len(numeratorFactors) == 0 or len(denominatorFactors) == 0:
raise TealInternalError(
"At least 1 factor must be present in the numerator and denominator"
)
if len(numeratorFactors) == 1 and len(denominatorFactors) == 1:
raise TealInternalError(
"There is only a single factor in the numerator and denominator. Use basic division instead."
)
self.numeratorFactors = numeratorFactors
self.denominatorFactors = denominatorFactors

def __teal__(self, options: "CompileOptions"):
if options.version < Op.cover.min_version:
raise TealCompileError(
"SafeRatio requires TEAL version {} or higher".format(
Op.cover.min_version
),
self,
)

numStart, numEnd = multiplyFactors(self, self.numeratorFactors, options)
denomStart, denomEnd = multiplyFactors(self, self.denominatorFactors, options)
numEnd.setNextBlock(denomStart)

combine = TealSimpleBlock(
[
TealOp(self, Op.divmodw),
TealOp(self, Op.pop), # pop remainder low word
TealOp(self, Op.pop), # pop remainder high word
TealOp(self, Op.swap), # swap quotient high and low words
TealOp(self, Op.logic_not),
TealOp(self, Op.assert_), # assert quotient high word is 0
# end with quotient low word remaining on the stack
]
)
denomEnd.setNextBlock(combine)

return numStart, combine

def __str__(self):
ret_str = "(SafeRatio (*"
for f in self.numeratorFactors:
ret_str += " " + str(f)
ret_str += ") (*"
for f in self.denominatorFactors:
ret_str += " " + str(f)
ret_str += ")"
return ret_str

def type_of(self):
return TealType.uint64

def has_return(self):
return False


SafeRatio.__module__ = "pyteal"
154 changes: 154 additions & 0 deletions pyteal/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,3 +1337,157 @@ def storeValue(key: Expr, t1: Expr, t2: Expr, t3: Expr) -> Expr:
""".strip()
actual = compileTeal(program, Mode.Application, version=4, assembleConstants=True)
assert actual == expected


def test_compile_safe_ratio():
cases = (
(
SafeRatio([Int(2), Int(100)], [Int(5)]),
"""#pragma version 5
int 2
int 100
mulw
int 0
int 5
divmodw
pop
pop
swap
!
assert
return
""",
Comment on lines +1345 to +1359
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This convinces me about the "no overhead" on A*B/C

),
(
SafeRatio([Int(2), Int(100)], [Int(10), Int(5)]),
"""#pragma version 5
int 2
int 100
mulw
int 10
int 5
mulw
divmodw
pop
pop
swap
!
assert
return
""",
),
(
SafeRatio([Int(2), Int(100), Int(3)], [Int(10), Int(5)]),
"""#pragma version 5
int 2
int 100
mulw
int 3
uncover 2
dig 1
*
cover 2
mulw
cover 2
+
swap
int 10
int 5
mulw
divmodw
pop
pop
swap
!
assert
return
""",
),
(
SafeRatio([Int(2), Int(100), Int(3)], [Int(10), Int(5), Int(6)]),
"""#pragma version 5
int 2
int 100
mulw
int 3
uncover 2
dig 1
*
cover 2
mulw
cover 2
+
swap
int 10
int 5
mulw
int 6
uncover 2
dig 1
*
cover 2
mulw
cover 2
+
swap
divmodw
pop
pop
swap
!
assert
return
""",
),
(
SafeRatio([Int(2), Int(100), Int(3), Int(4)], [Int(10), Int(5), Int(6)]),
"""#pragma version 5
int 2
int 100
mulw
int 3
uncover 2
dig 1
*
cover 2
mulw
cover 2
+
swap
int 4
uncover 2
dig 1
*
cover 2
mulw
cover 2
+
swap
int 10
int 5
mulw
int 6
uncover 2
dig 1
*
cover 2
mulw
cover 2
+
swap
divmodw
pop
pop
swap
!
assert
return
""",
),
)

for program, expected in cases:
actual = compileTeal(
program, Mode.Application, version=5, assembleConstants=False
)
assert actual == expected.strip()