Skip to content

Commit 892a2d3

Browse files
committed
Refactor for Dialect.
1 parent 1a5cb17 commit 892a2d3

13 files changed

+453
-215
lines changed

clvm/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .SExp import SExp
2+
from .dialect import Dialect # noqa
23
from .operators import ( # noqa
3-
QUOTE_ATOM,
4+
QUOTE_ATOM, # deprecated
45
KEYWORD_TO_ATOM,
56
KEYWORD_FROM_ATOM,
67
)

clvm/chainable_multi_op_fn.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from dataclasses import dataclass
2+
from typing import Optional, Tuple
3+
4+
from .types import CLVMObjectType, MultiOpFn, OperatorDict
5+
6+
7+
@dataclass
class ChainableMultiOpFn:
    """
    This structure handles clvm operators. Given an atom, it looks it up in a `dict`, then
    falls back to calling `unknown_op_handler`.

    Registered operators may take either `(arguments)` or `(arguments, max_cost)`;
    the one-argument form is tried first.
    """

    # maps an operator atom to the callable implementing it
    op_lookup: OperatorDict
    # called as `unknown_op_handler(op, arguments, max_cost)` when `op` is not registered
    unknown_op_handler: MultiOpFn

    def __call__(
        self, op: bytes, arguments: CLVMObjectType, max_cost: Optional[int] = None
    ) -> Tuple[int, CLVMObjectType]:
        """
        Run the operator registered for `op`, or `unknown_op_handler` if there is none.

        Returns whatever the selected callable returns. A `TypeError` raised from
        inside a python-implemented operator is propagated unchanged.
        """
        f = self.op_lookup.get(op)
        if not f:
            return self.unknown_op_handler(op, arguments, max_cost)

        try:
            return f(arguments)
        except TypeError as e:
            # An arity mismatch is raised at the call site, so its traceback holds
            # only this frame. If a deeper frame exists, the operator itself started
            # running and failed: re-raise instead of calling it a second time with
            # a different signature (which would hide the real error).
            if e.__traceback__ is not None and e.__traceback__.tb_next is not None:
                raise

        # some operators require `max_cost`
        return f(arguments, max_cost)

clvm/chia_dialect.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from .casts import int_to_bytes
2+
from .dialect import ConversionFn, Dialect, new_dialect, opcode_table_for_backend
3+
4+
# Operator keywords, listed so that a keyword's position in the list is its
# opcode number. "." fills positions that have no keyword.
KEYWORDS = [
    keyword
    for group in (
        # core opcodes 0x01-0x08 (position 0x00 is a "." placeholder)
        ". q a i c f r l x",
        # opcodes on atoms as strings 0x09-0x0f
        "= >s sha256 substr strlen concat .",
        # opcodes on atoms as ints 0x10-0x17
        "+ - * / divmod > ash lsh",
        # opcodes on atoms as vectors of bools 0x18-0x1c
        "logand logior logxor lognot .",
        # opcodes for BLS12-381 0x1d-0x1f
        "point_add pubkey_for_exp .",
        # bool opcodes 0x20-0x23
        "not any all .",
        # misc 0x24
        "softfork",
    )
    for keyword in group.split()
]
26+
27+
# atom (the opcode number encoded by `int_to_bytes`) -> keyword
KEYWORD_FROM_ATOM = {
    int_to_bytes(opcode): keyword for opcode, keyword in enumerate(KEYWORDS)
}
# keyword -> atom; the "." placeholder occurs several times in KEYWORDS, so only
# the last of its atoms survives here
KEYWORD_TO_ATOM = {keyword: atom for atom, keyword in KEYWORD_FROM_ATOM.items()}
29+
30+
31+
def chia_dialect(strict: bool, to_python: ConversionFn, backend=None) -> Dialect:
    """
    Create a dialect that uses the keyword table defined in this module.

    The "q" and "a" atoms become the quote and apply keywords. `strict`,
    `to_python` and `backend` are forwarded to `new_dialect`, and the opcode
    table for the same `backend` is then loaded into the new dialect.
    """
    dialect = new_dialect(
        KEYWORD_TO_ATOM["q"],
        KEYWORD_TO_ATOM["a"],
        strict,
        to_python,
        backend=backend,
    )
    dialect.update(opcode_table_for_backend(KEYWORD_TO_ATOM, backend=backend))
    return dialect

clvm/dialect.py

+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from typing import Callable, Optional, Tuple
2+
3+
try:
4+
import clvm_rs
5+
except ImportError:
6+
clvm_rs = None
7+
8+
from . import core_ops, more_ops
9+
from .chainable_multi_op_fn import ChainableMultiOpFn
10+
from .handle_unknown_op import (
11+
handle_unknown_op_softfork_ready,
12+
handle_unknown_op_strict,
13+
)
14+
from .run_program import _run_program
15+
from .types import CLVMObjectType, ConversionFn, MultiOpFn, OperatorDict
16+
17+
18+
# Keywords whose implementation name is not simply "op_<keyword>": maps the
# keyword to the suffix used after "op_" when looking up the implementation.
OP_REWRITE = {
    # arithmetic
    "+": "add",
    "-": "subtract",
    "*": "multiply",
    "/": "div",
    # single-letter core keywords
    "i": "if",
    "c": "cons",
    "f": "first",
    "r": "rest",
    "l": "listp",
    "x": "raise",
    # comparisons
    "=": "eq",
    ">": "gr",
    ">s": "gr_bytes",
}
33+
34+
35+
def op_table_for_module(mod):
    """
    Collect every attribute of `mod` whose name starts with "op_" into a dict.

    Each value is wrapped so that it can be called either as `op(sexp)` or as
    `op(sexp, max_cost)`, whichever of the two forms the wrapped callable takes.
    """
    import inspect

    # python-implemented operators don't take `max_cost` and rust-implemented operators do
    # So we make the `max_cost` operator optional with this trick
    # TODO: have python-implemented ops also take `max_cost` and unify the API.

    def elide_max_cost(f):
        try:
            signature = inspect.signature(f, follow_wrapped=False)
        except (TypeError, ValueError):
            # No introspectable signature (or not callable at all): fall back to
            # probing at call time.
            def inner_op(sexp, max_cost=None):
                try:
                    return f(sexp, max_cost)
                except TypeError:
                    return f(sexp)

            return inner_op

        # Decide the calling convention once, up front. This avoids a failed call
        # on every invocation, and keeps a `TypeError` raised inside the operator
        # from being mistaken for an arity mismatch.
        try:
            signature.bind(None, None)
        except TypeError:

            def inner_op(sexp, max_cost=None):
                return f(sexp)

        else:

            def inner_op(sexp, max_cost=None):
                return f(sexp, max_cost)

        return inner_op

    return {
        name: elide_max_cost(value)
        for name, value in mod.__dict__.items()
        if name.startswith("op_")
    }
50+
51+
52+
def op_imp_table_for_backend(backend):
    """
    Return the operator implementations for `backend`.

    With `backend=None` the native table is used when `clvm_rs` was imported.
    Any backend other than "native" gets the python operators from `core_ops`
    and `more_ops`.

    Raises `RuntimeError` if "native" is requested but `clvm_rs` is not installed.
    """
    use_native = backend == "native" or (backend is None and bool(clvm_rs))

    if not use_native:
        # entries from `more_ops` win over same-named entries from `core_ops`
        return {**op_table_for_module(core_ops), **op_table_for_module(more_ops)}

    if clvm_rs is None:
        raise RuntimeError("native backend not installed")
    return clvm_rs.native_opcodes_dict()
65+
66+
67+
def op_atom_to_imp_table(op_imp_table, keyword_to_atom, op_rewrite=OP_REWRITE):
    """
    Map operator atoms to their implementations.

    For each keyword in `keyword_to_atom`, the implementation is looked up in
    `op_imp_table` under "op_<name>", where <name> is the keyword after applying
    `op_rewrite`. Keywords with no (truthy) implementation are left out.
    """
    candidates = (
        (atom, op_imp_table.get("op_%s" % op_rewrite.get(keyword, keyword)))
        for keyword, atom in keyword_to_atom.items()
    )
    return {atom: implementation for atom, implementation in candidates if implementation}
75+
76+
77+
def opcode_table_for_backend(keyword_to_atom, backend):
    """
    Return an atom -> implementation table for `backend`, covering the keywords
    in `keyword_to_atom`.
    """
    return op_atom_to_imp_table(op_imp_table_for_backend(backend), keyword_to_atom)
80+
81+
82+
class Dialect:
    """
    A set of operators plus the quote/apply keywords needed to run a program.

    Operators registered with `update` are tried first; atoms with no
    registered operator are passed to the `multi_op_fn` given at construction.
    """

    def __init__(
        self,
        quote_kw: bytes,
        apply_kw: bytes,
        multi_op_fn: MultiOpFn,
        to_python: ConversionFn,
    ):
        self.quote_kw = quote_kw
        self.apply_kw = apply_kw
        self.to_python = to_python
        # The dispatcher below holds a reference to this very dict, so `update`
        # and `clear` must mutate it in place rather than rebind the attribute.
        self.opcode_lookup = {}
        self.multi_op_fn = ChainableMultiOpFn(self.opcode_lookup, multi_op_fn)

    def update(self, d: OperatorDict) -> None:
        """Register (or replace) the operators in `d`."""
        self.opcode_lookup.update(d)

    def clear(self) -> None:
        """Remove every registered operator."""
        self.opcode_lookup.clear()

    def run_program(
        self,
        program: CLVMObjectType,
        env: CLVMObjectType,
        max_cost: int,
        pre_eval_f: Optional[
            Callable[[CLVMObjectType, CLVMObjectType], Tuple[int, CLVMObjectType]]
        ] = None,
    ) -> Tuple[int, CLVMObjectType]:
        """
        Run `program` against `env` with this dialect's operators and keywords.

        Returns `(cost, result)`, where `result` has been passed through
        `to_python`.
        """
        outcome = _run_program(
            program,
            env,
            self.multi_op_fn,
            self.quote_kw,
            self.apply_kw,
            max_cost,
            pre_eval_f,
        )
        cost, raw_result = outcome
        return cost, self.to_python(raw_result)
121+
122+
123+
def native_new_dialect(
    quote_kw: bytes, apply_kw: bytes, strict: bool, to_python: ConversionFn
) -> Dialect:
    """
    Create a dialect backed by `clvm_rs`.

    `strict` selects which of the `clvm_rs` unknown-operator callbacks is used.

    Raises `RuntimeError` if `clvm_rs` could not be imported.
    """
    # Fail with the same error `op_imp_table_for_backend` uses, rather than an
    # `AttributeError` on `None`.
    if clvm_rs is None:
        raise RuntimeError("native backend not installed")

    unknown_op_callback = (
        clvm_rs.NATIVE_OP_UNKNOWN_STRICT
        if strict
        else clvm_rs.NATIVE_OP_UNKNOWN_NON_STRICT
    )
    dialect = clvm_rs.Dialect(
        quote_kw,
        apply_kw,
        unknown_op_callback,
        to_python=to_python,
    )
    return dialect
138+
139+
140+
def python_new_dialect(
    quote_kw: bytes, apply_kw: bytes, strict: bool, to_python: ConversionFn
) -> Dialect:
    """
    Create a python-implemented `Dialect`.

    Unknown operators go to `handle_unknown_op_strict` when `strict` is true,
    and to `handle_unknown_op_softfork_ready` otherwise.
    """
    if strict:
        unknown_op_callback = handle_unknown_op_strict
    else:
        unknown_op_callback = handle_unknown_op_softfork_ready
    return Dialect(quote_kw, apply_kw, unknown_op_callback, to_python=to_python)
153+
154+
155+
def new_dialect(
    quote_kw: bytes,
    apply_kw: bytes,
    strict: bool,
    to_python: ConversionFn,
    backend=None,
):
    """
    Create a dialect for `backend`.

    With `backend=None`, "native" is chosen when `clvm_rs` was imported and
    "python" otherwise. Any value other than "native" selects the python
    implementation.
    """
    if backend is None:
        backend = "native" if clvm_rs is not None else "python"

    if backend == "native":
        return native_new_dialect(quote_kw, apply_kw, strict, to_python)
    return python_new_dialect(quote_kw, apply_kw, strict, to_python)

clvm/handle_unknown_op.py

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from typing import Tuple
2+
3+
from .CLVMObject import CLVMObject
4+
from .EvalError import EvalError
5+
6+
from .costs import (
7+
ARITH_BASE_COST,
8+
ARITH_COST_PER_BYTE,
9+
ARITH_COST_PER_ARG,
10+
MUL_BASE_COST,
11+
MUL_COST_PER_OP,
12+
MUL_LINEAR_COST_PER_BYTE,
13+
MUL_SQUARE_COST_PER_BYTE_DIVIDER,
14+
CONCAT_BASE_COST,
15+
CONCAT_COST_PER_ARG,
16+
CONCAT_COST_PER_BYTE,
17+
)
18+
19+
20+
def handle_unknown_op_strict(op, arguments, _max_cost=None):
    """Reject every unknown operator: always raises `EvalError`."""
    offending_op = arguments.to(op)
    raise EvalError("unimplemented operator", offending_op)
22+
23+
24+
def args_len(op_name, args):
    """
    Yield the byte length of each atom in `args`.

    Raises `EvalError` (naming `op_name`) on reaching an argument that is a pair.
    """
    for operand in args.as_iter():
        if operand.pair:
            raise EvalError("%s requires int args" % op_name, operand)
        atom = operand.as_atom()
        yield len(atom)
29+
30+
31+
# Unknown ops are reserved if they start with 0xffff.
# Otherwise, unknown ops are no-ops, but they have costs. The cost is computed
# like this:

# byte index (reverse):
# | 4 | 3 | 2 | 1 | 0          |
# +---+---+---+---+------------+
# | multiplier    |XX | XXXXXX |
# +---+---+---+---+---+--------+
#  ^               ^    ^
#  |               |    + 6 bits ignored when computing cost
# cost_multiplier  |
#                  + 2 bits
#                    cost_function

# 1 is always added to the multiplier before it is used to multiply the cost,
# because the cost must not be 0.

# cost_function is 2 bits and defines how cost is computed based on arguments:
# 0: constant, cost is 1 * (multiplier + 1)
# 1: computed like operator add, multiplied by (multiplier + 1)
# 2: computed like operator mul, multiplied by (multiplier + 1)
# 3: computed like operator concat, multiplied by (multiplier + 1)

# This means that unknown ops where cost_function is 1, 2, or 3 may still be
# fatal errors if the arguments passed are not atoms.
57+
58+
59+
def handle_unknown_op_softfork_ready(
    op: bytes, args: CLVMObject, max_cost: int
) -> Tuple[int, CLVMObject]:
    """
    Treat an unknown operator as a no-op whose cost is encoded in its opcode.

    The top two bits of the last opcode byte select the cost function; the
    bytes before the last one, plus 1, give the cost multiplier. `max_cost` is
    not used here.

    Returns `(cost, args.to(b""))`.

    Raises `EvalError` if the opcode is empty, starts with 0xffff, is longer
    than 5 bytes, if an argument is a pair under cost functions 1-3, or if the
    final cost is 2**32 or more.
    """
    # any opcode starting with ffff is reserved (i.e. fatal error)
    # opcodes are not allowed to be empty
    if not op or op.startswith(b"\xff\xff"):
        raise EvalError("reserved operator", args.to(op))

    # all other unknown opcodes are no-ops
    # the cost of the no-ops is determined by the opcode number, except the
    # 6 least significant bits.
    cost_function = op[-1] >> 6

    if len(op) > 5:
        raise EvalError("invalid operator", args.to(op))

    # the multiplier cannot be 0. it starts at 1
    cost_multiplier = 1 + int.from_bytes(op[:-1], "big", signed=False)

    # 0 = constant
    # 1 = like op_add/op_sub
    # 2 = like op_multiply
    # 3 = like op_concat
    if cost_function == 0:
        cost = 1
    elif cost_function == 1:
        # like op_add
        cost = ARITH_BASE_COST
        sizes = list(args_len("unknown op", args))
        cost += len(sizes) * ARITH_COST_PER_ARG
        cost += sum(sizes) * ARITH_COST_PER_BYTE
    elif cost_function == 2:
        # like op_multiply
        cost = MUL_BASE_COST
        sizes = args_len("unknown op", args)
        running_size = next(sizes, None)
        if running_size is not None:
            for size in sizes:
                cost += MUL_COST_PER_OP
                cost += (size + running_size) * MUL_LINEAR_COST_PER_BYTE
                cost += (size * running_size) // MUL_SQUARE_COST_PER_BYTE_DIVIDER
                # this is an estimate, since we don't want to actually multiply the
                # values
                running_size += size
    else:
        # cost_function == 3: like concat
        cost = CONCAT_BASE_COST
        total_bytes = 0
        for operand in args.as_iter():
            if operand.pair:
                raise EvalError("unknown op on list", operand)
            cost += CONCAT_COST_PER_ARG
            total_bytes += len(operand.atom)
        cost += total_bytes * CONCAT_COST_PER_BYTE

    total_cost = cost * cost_multiplier
    if total_cost >= 2**32:
        raise EvalError("invalid operator", args.to(op))

    return (total_cost, args.to(b""))

0 commit comments

Comments
 (0)