Skip to content

Commit f5e56fb

Browse files
authored
fix: update/enable bn128 tests (bluealloy#1242)
1 parent b6bda4c commit f5e56fb

File tree

1 file changed

+79
-84
lines changed

1 file changed

+79
-84
lines changed

crates/precompile/src/bn128.rs

+79-84
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,22 @@ use crate::{
33
Address, Error, Precompile, PrecompileResult, PrecompileWithAddress,
44
};
55
use bn::{AffineG1, AffineG2, Fq, Fq2, Group, Gt, G1, G2};
6-
use revm_primitives::Bytes;
76

87
pub mod add {
98
use super::*;
109

1110
const ADDRESS: Address = crate::u64_to_address(6);
1211

12+
pub const ISTANBUL_ADD_GAS_COST: u64 = 150;
1313
pub const ISTANBUL: PrecompileWithAddress = PrecompileWithAddress(
1414
ADDRESS,
15-
Precompile::Standard(|input, gas_limit| {
16-
if 150 > gas_limit {
17-
return Err(Error::OutOfGas);
18-
}
19-
Ok((150, super::run_add(input)?))
20-
}),
15+
Precompile::Standard(|input, gas_limit| run_add(input, ISTANBUL_ADD_GAS_COST, gas_limit)),
2116
);
2217

18+
pub const BYZANTIUM_ADD_GAS_COST: u64 = 500;
2319
pub const BYZANTIUM: PrecompileWithAddress = PrecompileWithAddress(
2420
ADDRESS,
25-
Precompile::Standard(|input, gas_limit| {
26-
if 500 > gas_limit {
27-
return Err(Error::OutOfGas);
28-
}
29-
Ok((500, super::run_add(input)?))
30-
}),
21+
Precompile::Standard(|input, gas_limit| run_add(input, BYZANTIUM_ADD_GAS_COST, gas_limit)),
3122
);
3223
}
3324

@@ -36,24 +27,16 @@ pub mod mul {
3627

3728
const ADDRESS: Address = crate::u64_to_address(7);
3829

30+
pub const ISTANBUL_MUL_GAS_COST: u64 = 6_000;
3931
pub const ISTANBUL: PrecompileWithAddress = PrecompileWithAddress(
4032
ADDRESS,
41-
Precompile::Standard(|input, gas_limit| {
42-
if 6_000 > gas_limit {
43-
return Err(Error::OutOfGas);
44-
}
45-
Ok((6_000, super::run_mul(input)?))
46-
}),
33+
Precompile::Standard(|input, gas_limit| run_mul(input, ISTANBUL_MUL_GAS_COST, gas_limit)),
4734
);
4835

36+
pub const BYZANTIUM_MUL_GAS_COST: u64 = 40_000;
4937
pub const BYZANTIUM: PrecompileWithAddress = PrecompileWithAddress(
5038
ADDRESS,
51-
Precompile::Standard(|input, gas_limit| {
52-
if 40_000 > gas_limit {
53-
return Err(Error::OutOfGas);
54-
}
55-
Ok((40_000, super::run_mul(input)?))
56-
}),
39+
Precompile::Standard(|input, gas_limit| run_mul(input, BYZANTIUM_MUL_GAS_COST, gas_limit)),
5740
);
5841
}
5942

@@ -67,7 +50,7 @@ pub mod pair {
6750
pub const ISTANBUL: PrecompileWithAddress = PrecompileWithAddress(
6851
ADDRESS,
6952
Precompile::Standard(|input, gas_limit| {
70-
super::run_pair(
53+
run_pair(
7154
input,
7255
ISTANBUL_PAIR_PER_POINT,
7356
ISTANBUL_PAIR_BASE,
@@ -81,7 +64,7 @@ pub mod pair {
8164
pub const BYZANTIUM: PrecompileWithAddress = PrecompileWithAddress(
8265
ADDRESS,
8366
Precompile::Standard(|input, gas_limit| {
84-
super::run_pair(
67+
run_pair(
8568
input,
8669
BYZANTIUM_PAIR_PER_POINT,
8770
BYZANTIUM_PAIR_BASE,
@@ -137,7 +120,11 @@ pub fn new_g1_point(px: Fq, py: Fq) -> Result<G1, Error> {
137120
}
138121
}
139122

140-
pub fn run_add(input: &[u8]) -> Result<Bytes, Error> {
123+
pub fn run_add(input: &[u8], gas_cost: u64, gas_limit: u64) -> PrecompileResult {
124+
if gas_cost > gas_limit {
125+
return Err(Error::OutOfGas);
126+
}
127+
141128
let input = right_pad::<ADD_INPUT_LEN>(input);
142129

143130
let p1 = read_point(&input[..64])?;
@@ -148,10 +135,14 @@ pub fn run_add(input: &[u8]) -> Result<Bytes, Error> {
148135
sum.x().to_big_endian(&mut output[..32]).unwrap();
149136
sum.y().to_big_endian(&mut output[32..]).unwrap();
150137
}
151-
Ok(output.into())
138+
Ok((gas_cost, output.into()))
152139
}
153140

154-
pub fn run_mul(input: &[u8]) -> Result<Bytes, Error> {
141+
pub fn run_mul(input: &[u8], gas_cost: u64, gas_limit: u64) -> PrecompileResult {
142+
if gas_cost > gas_limit {
143+
return Err(Error::OutOfGas);
144+
}
145+
155146
let input = right_pad::<MUL_INPUT_LEN>(input);
156147

157148
let p = read_point(&input[..64])?;
@@ -164,7 +155,7 @@ pub fn run_mul(input: &[u8]) -> Result<Bytes, Error> {
164155
mul.x().to_big_endian(&mut output[..32]).unwrap();
165156
mul.y().to_big_endian(&mut output[32..]).unwrap();
166157
}
167-
Ok(output.into())
158+
Ok((gas_cost, output.into()))
168159
}
169160

170161
pub fn run_pair(
@@ -223,10 +214,12 @@ pub fn run_pair(
223214
Ok((gas_used, bool_to_bytes32(success)))
224215
}
225216

226-
/*
227217
#[cfg(test)]
228218
mod tests {
229-
use crate::test_utils::new_context;
219+
use crate::bn128::add::BYZANTIUM_ADD_GAS_COST;
220+
use crate::bn128::mul::BYZANTIUM_MUL_GAS_COST;
221+
use crate::bn128::pair::{BYZANTIUM_PAIR_BASE, BYZANTIUM_PAIR_PER_POINT};
222+
use revm_primitives::hex;
230223

231224
use super::*;
232225

@@ -247,9 +240,7 @@ mod tests {
247240
)
248241
.unwrap();
249242

250-
let res = Bn128Add::<Byzantium>::run(&input, 500, &new_context(), false)
251-
.unwrap()
252-
.output;
243+
let (_, res) = run_add(&input, BYZANTIUM_ADD_GAS_COST, 500).unwrap();
253244
assert_eq!(res, expected);
254245

255246
// zero sum test
@@ -268,9 +259,7 @@ mod tests {
268259
)
269260
.unwrap();
270261

271-
let res = Bn128Add::<Byzantium>::run(&input, 500, &new_context(), false)
272-
.unwrap()
273-
.output;
262+
let (_, res) = run_add(&input, BYZANTIUM_ADD_GAS_COST, 500).unwrap();
274263
assert_eq!(res, expected);
275264

276265
// out of gas test
@@ -282,8 +271,10 @@ mod tests {
282271
0000000000000000000000000000000000000000000000000000000000000000",
283272
)
284273
.unwrap();
285-
let res = Bn128Add::<Byzantium>::run(&input, 499, &new_context(), false);
286-
assert!(matches!(res, Err(Return::OutOfGas)));
274+
275+
let res = run_add(&input, BYZANTIUM_ADD_GAS_COST, 499);
276+
println!("{:?}", res);
277+
assert!(matches!(res, Err(Error::OutOfGas)));
287278

288279
// no input test
289280
let input = [0u8; 0];
@@ -294,9 +285,7 @@ mod tests {
294285
)
295286
.unwrap();
296287

297-
let res = Bn128Add::<Byzantium>::run(&input, 500, &new_context(), false)
298-
.unwrap()
299-
.output;
288+
let (_, res) = run_add(&input, BYZANTIUM_ADD_GAS_COST, 500).unwrap();
300289
assert_eq!(res, expected);
301290

302291
// point not on curve fail
@@ -309,11 +298,8 @@ mod tests {
309298
)
310299
.unwrap();
311300

312-
let res = Bn128Add::<Byzantium>::run(&input, 500, &new_context(), false);
313-
assert!(matches!(
314-
res,
315-
Err(Return::Other(Cow::Borrowed("ERR_BN128_INVALID_POINT")))
316-
));
301+
let res = run_add(&input, BYZANTIUM_ADD_GAS_COST, 500);
302+
assert!(matches!(res, Err(Error::Bn128AffineGFailedToCreate)));
317303
}
318304

319305
#[test]
@@ -332,9 +318,7 @@ mod tests {
332318
)
333319
.unwrap();
334320

335-
let res = Bn128Mul::<Byzantium>::run(&input, 40_000, &new_context(), false)
336-
.unwrap()
337-
.output;
321+
let (_, res) = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 40_000).unwrap();
338322
assert_eq!(res, expected);
339323

340324
// out of gas test
@@ -345,8 +329,9 @@ mod tests {
345329
0200000000000000000000000000000000000000000000000000000000000000",
346330
)
347331
.unwrap();
348-
let res = Bn128Mul::<Byzantium>::run(&input, 39_999, &new_context(), false);
349-
assert!(matches!(res, Err(Return::OutOfGas)));
332+
333+
let res = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 39_999);
334+
assert!(matches!(res, Err(Error::OutOfGas)));
350335

351336
// zero multiplication test
352337
let input = hex::decode(
@@ -363,9 +348,7 @@ mod tests {
363348
)
364349
.unwrap();
365350

366-
let res = Bn128Mul::<Byzantium>::run(&input, 40_000, &new_context(), false)
367-
.unwrap()
368-
.output;
351+
let (_, res) = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 40_000).unwrap();
369352
assert_eq!(res, expected);
370353

371354
// no input test
@@ -377,9 +360,7 @@ mod tests {
377360
)
378361
.unwrap();
379362

380-
let res = Bn128Mul::<Byzantium>::run(&input, 40_000, &new_context(), false)
381-
.unwrap()
382-
.output;
363+
let (_, res) = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 40_000).unwrap();
383364
assert_eq!(res, expected);
384365

385366
// point not on curve fail
@@ -391,11 +372,8 @@ mod tests {
391372
)
392373
.unwrap();
393374

394-
let res = Bn128Mul::<Byzantium>::run(&input, 40_000, &new_context(), false);
395-
assert!(matches!(
396-
res,
397-
Err(Return::Other(Cow::Borrowed("ERR_BN128_INVALID_POINT")))
398-
));
375+
let res = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 40_000);
376+
assert!(matches!(res, Err(Error::Bn128AffineGFailedToCreate)));
399377
}
400378

401379
#[test]
@@ -420,9 +398,13 @@ mod tests {
420398
hex::decode("0000000000000000000000000000000000000000000000000000000000000001")
421399
.unwrap();
422400

423-
let res = Bn128Pair::<Byzantium>::run(&input, 260_000, &new_context(), false)
424-
.unwrap()
425-
.output;
401+
let (_, res) = run_pair(
402+
&input,
403+
BYZANTIUM_PAIR_PER_POINT,
404+
BYZANTIUM_PAIR_BASE,
405+
260_000,
406+
)
407+
.unwrap();
426408
assert_eq!(res, expected);
427409

428410
// out of gas test
@@ -442,18 +424,28 @@ mod tests {
442424
12c85ea5db8c6deb4aab71808dcb408fe3d1e7690c43d37b4ce6cc0166fa7daa",
443425
)
444426
.unwrap();
445-
let res = Bn128Pair::<Byzantium>::run(&input, 259_999, &new_context(), false);
446-
assert!(matches!(res, Err(Return::OutOfGas)));
427+
428+
let res = run_pair(
429+
&input,
430+
BYZANTIUM_PAIR_PER_POINT,
431+
BYZANTIUM_PAIR_BASE,
432+
259_999,
433+
);
434+
assert!(matches!(res, Err(Error::OutOfGas)));
447435

448436
// no input test
449437
let input = [0u8; 0];
450438
let expected =
451439
hex::decode("0000000000000000000000000000000000000000000000000000000000000001")
452440
.unwrap();
453441

454-
let res = Bn128Pair::<Byzantium>::run(&input, 260_000, &new_context(), false)
455-
.unwrap()
456-
.output;
442+
let (_, res) = run_pair(
443+
&input,
444+
BYZANTIUM_PAIR_PER_POINT,
445+
BYZANTIUM_PAIR_BASE,
446+
260_000,
447+
)
448+
.unwrap();
457449
assert_eq!(res, expected);
458450

459451
// point not on curve fail
@@ -468,11 +460,13 @@ mod tests {
468460
)
469461
.unwrap();
470462

471-
let res = Bn128Pair::<Byzantium>::run(&input, 260_000, &new_context(), false);
472-
assert!(matches!(
473-
res,
474-
Err(Return::Other(Cow::Borrowed("ERR_BN128_INVALID_A")))
475-
));
463+
let res = run_pair(
464+
&input,
465+
BYZANTIUM_PAIR_PER_POINT,
466+
BYZANTIUM_PAIR_BASE,
467+
260_000,
468+
);
469+
assert!(matches!(res, Err(Error::Bn128AffineGFailedToCreate)));
476470

477471
// invalid input length
478472
let input = hex::decode(
@@ -484,11 +478,12 @@ mod tests {
484478
)
485479
.unwrap();
486480

487-
let res = Bn128Pair::<Byzantium>::run(&input, 260_000, &new_context(), false);
488-
assert!(matches!(
489-
res,
490-
Err(Return::Other(Cow::Borrowed("ERR_BN128_INVALID_LEN",)))
491-
));
481+
let res = run_pair(
482+
&input,
483+
BYZANTIUM_PAIR_PER_POINT,
484+
BYZANTIUM_PAIR_BASE,
485+
260_000,
486+
);
487+
assert!(matches!(res, Err(Error::Bn128PairLength)));
492488
}
493489
}
494-
*/

0 commit comments

Comments
 (0)