From 19b76907d12a327944b64aa18f39230191ad309a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Tue, 25 Feb 2025 15:34:13 -0800 Subject: [PATCH] compile boolean operators, add JumpIfTrue instruction --- bbq/compiler/codegen.go | 4 + bbq/compiler/compiler.go | 46 +++++++++ bbq/compiler/compiler_test.go | 114 ++++++++++++++++++++++ bbq/opcode/instructions.go | 32 +++++++ bbq/opcode/instructions.yml | 13 +++ bbq/opcode/opcode.go | 2 +- bbq/opcode/opcode_string.go | 9 +- bbq/vm/test/vm_test.go | 176 ++++++++++++++++++++++++++++++++-- bbq/vm/value_conversions.go | 4 + bbq/vm/vm.go | 9 ++ 10 files changed, 394 insertions(+), 15 deletions(-) diff --git a/bbq/compiler/codegen.go b/bbq/compiler/codegen.go index 0e93a7fd1..4a0a3062e 100644 --- a/bbq/compiler/codegen.go +++ b/bbq/compiler/codegen.go @@ -82,6 +82,10 @@ func (g *InstructionCodeGen) PatchJump(offset int, newTarget uint16) { ins.Target = newTarget (*g.target)[offset] = ins + case opcode.InstructionJumpIfTrue: + ins.Target = newTarget + (*g.target)[offset] = ins + case opcode.InstructionJumpIfNil: ins.Target = newTarget (*g.target)[offset] = ins diff --git a/bbq/compiler/compiler.go b/bbq/compiler/compiler.go index 38a852ea3..8a6fd071b 100644 --- a/bbq/compiler/compiler.go +++ b/bbq/compiler/compiler.go @@ -262,6 +262,12 @@ func (c *Compiler[_]) emitUndefinedJumpIfFalse() int { return offset } +func (c *Compiler[_]) emitUndefinedJumpIfTrue() int { + offset := c.codeGen.Offset() + c.codeGen.Emit(opcode.InstructionJumpIfTrue{Target: math.MaxUint16}) + return offset +} + func (c *Compiler[_]) emitUndefinedJumpIfNil() int { offset := c.codeGen.Offset() c.codeGen.Emit(opcode.InstructionJumpIfNil{Target: math.MaxUint16}) @@ -1282,6 +1288,44 @@ func (c *Compiler[_]) VisitBinaryExpression(expression *ast.BinaryExpression) (_ // End c.patchJump(thenJump) + case ast.OperationOr: + // TODO: optimize chains of ors / ands + + leftTrueJump := c.emitUndefinedJumpIfTrue() + + c.compileExpression(expression.Right) + rightFalseJump := c.emitUndefinedJumpIfFalse() + + // Left or right is true + c.patchJump(leftTrueJump) + c.codeGen.Emit(opcode.InstructionTrue{}) + trueJump := c.emitUndefinedJump() + + // Left and right are false + c.patchJump(rightFalseJump) + c.codeGen.Emit(opcode.InstructionFalse{}) + + c.patchJump(trueJump) + + case ast.OperationAnd: + // TODO: optimize chains of ors / ands + + leftFalseJump := c.emitUndefinedJumpIfFalse() + + c.compileExpression(expression.Right) + rightFalseJump := c.emitUndefinedJumpIfFalse() + + // Left and right are true + c.codeGen.Emit(opcode.InstructionTrue{}) + trueJump := c.emitUndefinedJump() + + // Left or right is false + c.patchJump(leftFalseJump) + c.patchJump(rightFalseJump) + c.codeGen.Emit(opcode.InstructionFalse{}) + + c.patchJump(trueJump) + default: c.compileExpression(expression.Right) @@ -1296,10 +1340,12 @@ func (c *Compiler[_]) VisitBinaryExpression(expression *ast.BinaryExpression) (_ c.codeGen.Emit(opcode.InstructionDivide{}) case ast.OperationMod: c.codeGen.Emit(opcode.InstructionMod{}) + case ast.OperationEqual: c.codeGen.Emit(opcode.InstructionEqual{}) case ast.OperationNotEqual: c.codeGen.Emit(opcode.InstructionNotEqual{}) + case ast.OperationLess: c.codeGen.Emit(opcode.InstructionLess{}) case ast.OperationLessEqual: diff --git a/bbq/compiler/compiler_test.go b/bbq/compiler/compiler_test.go index 6ab173036..d537e3c0a 100644 --- a/bbq/compiler/compiler_test.go +++ b/bbq/compiler/compiler_test.go @@ -3637,3 +3637,117 @@ func TestCompileConditional(t *testing.T) { program.Constants, ) } + +func TestCompileOr(t *testing.T) { + + t.Parallel() + + checker, err := ParseAndCheck(t, ` + fun test(x: Bool, y: Bool): Bool { + return x || y + } + `) + require.NoError(t, err) + + comp := compiler.NewInstructionCompiler(checker) + program := comp.Compile() + + const parameterCount = 2 + + const ( + // xIndex is the index of the parameter `x`, which is the first parameter + xIndex = iota + // yIndex is the index of the parameter `y`, which is the second parameter + yIndex + ) + + // resultIndex is the index of the $result variable + const resultIndex = parameterCount + + require.Len(t, program.Functions, 1) + + functions := comp.ExportFunctions() + require.Equal(t, len(program.Functions), len(functions)) + + assert.Equal(t, + []opcode.Instruction{ + // return x || y + opcode.InstructionGetLocal{LocalIndex: xIndex}, + opcode.InstructionJumpIfTrue{Target: 4}, + + opcode.InstructionGetLocal{LocalIndex: yIndex}, + opcode.InstructionJumpIfFalse{Target: 6}, + + opcode.InstructionTrue{}, + opcode.InstructionJump{Target: 7}, + + opcode.InstructionFalse{}, + + // assign to temp $result + opcode.InstructionTransfer{TypeIndex: 0}, + opcode.InstructionSetLocal{LocalIndex: resultIndex}, + + // return $result + opcode.InstructionGetLocal{LocalIndex: resultIndex}, + opcode.InstructionReturnValue{}, + }, + functions[0].Code, + ) +} + +func TestCompileAnd(t *testing.T) { + + t.Parallel() + + checker, err := ParseAndCheck(t, ` + fun test(x: Bool, y: Bool): Bool { + return x && y + } + `) + require.NoError(t, err) + + comp := compiler.NewInstructionCompiler(checker) + program := comp.Compile() + + const parameterCount = 2 + + const ( + // xIndex is the index of the parameter `x`, which is the first parameter + xIndex = iota + // yIndex is the index of the parameter `y`, which is the second parameter + yIndex + ) + + // resultIndex is the index of the $result variable + const resultIndex = parameterCount + + require.Len(t, program.Functions, 1) + + functions := comp.ExportFunctions() + require.Equal(t, len(program.Functions), len(functions)) + + assert.Equal(t, + []opcode.Instruction{ + // return x && y + opcode.InstructionGetLocal{LocalIndex: xIndex}, + opcode.InstructionJumpIfFalse{Target: 6}, + + opcode.InstructionGetLocal{LocalIndex: yIndex}, + opcode.InstructionJumpIfFalse{Target: 6}, + + opcode.InstructionTrue{}, + opcode.InstructionJump{Target: 7}, + + opcode.InstructionFalse{}, + + // assign to temp $result + opcode.InstructionTransfer{TypeIndex: 0}, + opcode.InstructionSetLocal{LocalIndex: resultIndex}, + + // return $result + opcode.InstructionGetLocal{LocalIndex: resultIndex}, + opcode.InstructionReturnValue{}, + }, + functions[0].Code, + ) +} diff --git a/bbq/opcode/instructions.go b/bbq/opcode/instructions.go index e27752999..3d4cf4718 100644 --- a/bbq/opcode/instructions.go +++ b/bbq/opcode/instructions.go @@ -841,6 +841,36 @@ func DecodeJumpIfFalse(ip *uint16, code []byte) (i InstructionJumpIfFalse) { return i } +// InstructionJumpIfTrue +// +// Pops a value off the stack. If it is `true`, jumps to the target instruction. +type InstructionJumpIfTrue struct { + Target uint16 +} + +var _ Instruction = InstructionJumpIfTrue{} + +func (InstructionJumpIfTrue) Opcode() Opcode { + return JumpIfTrue +} + +func (i InstructionJumpIfTrue) String() string { + var sb strings.Builder + sb.WriteString(i.Opcode().String()) + printfArgument(&sb, "target", i.Target) + return sb.String() +} + +func (i InstructionJumpIfTrue) Encode(code *[]byte) { + emitOpcode(code, i.Opcode()) + emitUint16(code, i.Target) +} + +func DecodeJumpIfTrue(ip *uint16, code []byte) (i InstructionJumpIfTrue) { + i.Target = decodeUint16(ip, code) + return i +} + // InstructionJumpIfNil // // Pops a value off the stack. If it is `nil`, jumps to the target instruction. @@ -1303,6 +1333,8 @@ func DecodeInstruction(ip *uint16, code []byte) Instruction { return DecodeJump(ip, code) case JumpIfFalse: return DecodeJumpIfFalse(ip, code) + case JumpIfTrue: + return DecodeJumpIfTrue(ip, code) case JumpIfNil: return DecodeJumpIfNil(ip, code) case Return: diff --git a/bbq/opcode/instructions.yml b/bbq/opcode/instructions.yml index 9f1953531..a2138d940 100644 --- a/bbq/opcode/instructions.yml +++ b/bbq/opcode/instructions.yml @@ -400,6 +400,19 @@ - name: "value" type: "value" +- name: "jumpIfTrue" + description: + Pops a value off the stack. If it is `true`, jumps to the target instruction. + operands: + - name: "target" + type: "index" + controlEffects: + - jump: "target" + valueEffects: + pop: + - name: "value" + type: "value" + - name: "jumpIfNil" description: Pops a value off the stack. If it is `nil`, jumps to the target instruction. diff --git a/bbq/opcode/opcode.go b/bbq/opcode/opcode.go index 8a3862fef..c2b52b3b2 100644 --- a/bbq/opcode/opcode.go +++ b/bbq/opcode/opcode.go @@ -31,12 +31,12 @@ const ( ReturnValue Jump JumpIfFalse + JumpIfTrue JumpIfNil _ _ _ _ - _ // Int operations diff --git a/bbq/opcode/opcode_string.go b/bbq/opcode/opcode_string.go index c532c7290..a05d41ee3 100644 --- a/bbq/opcode/opcode_string.go +++ b/bbq/opcode/opcode_string.go @@ -13,7 +13,8 @@ func _() { _ = x[ReturnValue-2] _ = x[Jump-3] _ = x[JumpIfFalse-4] - _ = x[JumpIfNil-5] + _ = x[JumpIfTrue-5] + _ = x[JumpIfNil-6] _ = x[Add-11] _ = x[Subtract-12] _ = x[Multiply-13] @@ -60,7 +61,7 @@ func _() { } const ( - _Opcode_name_0 = "UnknownReturnReturnValueJumpJumpIfFalseJumpIfNil" + _Opcode_name_0 = "UnknownReturnReturnValueJumpJumpIfFalseJumpIfTrueJumpIfNil" _Opcode_name_1 = "AddSubtractMultiplyDivideMod" _Opcode_name_2 = "LessGreaterLessOrEqualGreaterOrEqual" _Opcode_name_3 = "EqualNotEqualNot" @@ -73,7 +74,7 @@ const ( ) var ( - _Opcode_index_0 = [...]uint8{0, 7, 13, 24, 28, 39, 48} + _Opcode_index_0 = [...]uint8{0, 7, 13, 24, 28, 39, 49, 58} _Opcode_index_1 = [...]uint8{0, 3, 11, 19, 25, 28} _Opcode_index_2 = [...]uint8{0, 4, 11, 22, 36} _Opcode_index_3 = [...]uint8{0, 5, 13, 16} @@ -87,7 +88,7 @@ var ( func (i Opcode) String() string { switch { - case i <= 5: + case i <= 6: return _Opcode_name_0[_Opcode_index_0[i]:_Opcode_index_0[i+1]] case 11 <= i && i <= 15: i -= 11 diff --git a/bbq/vm/test/vm_test.go b/bbq/vm/test/vm_test.go index a5f338ed8..b1e2b6ed7 100644 --- a/bbq/vm/test/vm_test.go +++ b/bbq/vm/test/vm_test.go @@ -1442,7 +1442,7 @@ func TestTransaction(t *testing.T) { checker, err := ParseAndCheck(t, ` transaction { - var a: String + var a: String prepare() { self.a = "Hello!" } @@ -1496,7 +1496,7 @@ func TestTransaction(t *testing.T) { checker, err := ParseAndCheck(t, ` transaction(param1: String, param2: String) { - var a: String + var a: String prepare() { self.a = param1 } @@ -4512,16 +4512,16 @@ func TestCompileIf(t *testing.T) { test := func(t *testing.T, argument vm.Value) vm.Value { result, err := compileAndInvoke(t, ` - fun test(x: Bool): Int { + fun test(x: Bool): Int { var y = 0 - if x { + if x { y = 1 } else { y = 2 } return y - } - `, + } + `, "test", argument, ) @@ -4551,10 +4551,10 @@ func TestCompileConditional(t *testing.T) { test := func(t *testing.T, argument vm.Value) vm.Value { result, err := compileAndInvoke(t, ` - fun test(x: Bool): Int { - return x ? 1 : 2 - } - `, + fun test(x: Bool): Int { + return x ? 1 : 2 + } + `, "test", argument, ) @@ -4576,3 +4576,159 @@ func TestCompileConditional(t *testing.T) { require.Equal(t, vm.NewIntValue(2), actual) }) } + +func TestCompileOr(t *testing.T) { + + t.Parallel() + + test := func(t *testing.T, x, y vm.Value) vm.Value { + result, err := compileAndInvoke(t, + ` + struct Tester { + let x: Bool + let y: Bool + var z: Int + + init(x: Bool, y: Bool) { + self.x = x + self.y = y + self.z = 0 + } + + fun a(): Bool { + self.z = 1 + return self.x + } + + fun b(): Bool { + self.z = 2 + return self.y + } + + fun test(): Int { + if self.a() || self.b() { + return self.z + 10 + } else { + return self.z + 20 + } + } + } + + fun test(x: Bool, y: Bool): Int { + return Tester(x: x, y: y).test() + } + `, + "test", + x, + y, + ) + require.NoError(t, err) + return result + } + + t.Run("true, true", func(t *testing.T) { + t.Parallel() + + actual := test(t, vm.BoolValue(true), vm.BoolValue(true)) + require.Equal(t, vm.NewIntValue(11), actual) + }) + + t.Run("true, false", func(t *testing.T) { + t.Parallel() + + actual := test(t, vm.BoolValue(true), vm.BoolValue(false)) + require.Equal(t, vm.NewIntValue(11), actual) + }) + + t.Run("false, true", func(t *testing.T) { + t.Parallel() + + actual := test(t, vm.BoolValue(false), vm.BoolValue(true)) + require.Equal(t, vm.NewIntValue(12), actual) + }) + + t.Run("false, false", func(t *testing.T) { + t.Parallel() + + actual := test(t, vm.BoolValue(false), vm.BoolValue(false)) + require.Equal(t, vm.NewIntValue(22), actual) + }) +} + +func TestCompileAnd(t *testing.T) { + + t.Parallel() + + test := func(t *testing.T, x, y vm.Value) vm.Value { + result, err := compileAndInvoke(t, + ` + struct Tester { + let x: Bool + let y: Bool + var z: Int + + init(x: Bool, y: Bool) { + self.x = x + self.y = y + self.z = 0 + } + + fun a(): Bool { + self.z = 1 + return self.x + } + + fun b(): Bool { + self.z = 2 + return self.y + } + + fun test(): Int { + if self.a() && self.b() { + return self.z + 10 + } else { + return self.z + 20 + } + } + } + + fun test(x: Bool, y: Bool): Int { + return Tester(x: x, y: y).test() + } + `, + "test", + x, + y, + ) + require.NoError(t, err) + return result + } + + t.Run("true, true", func(t *testing.T) { + t.Parallel() + + actual := test(t, vm.BoolValue(true), vm.BoolValue(true)) + require.Equal(t, vm.NewIntValue(12), actual) + }) + + t.Run("true, false", func(t *testing.T) { + t.Parallel() + + actual := test(t, vm.BoolValue(true), vm.BoolValue(false)) + require.Equal(t, vm.NewIntValue(22), actual) + }) + + t.Run("false, true", func(t *testing.T) { + t.Parallel() + + actual := test(t, vm.BoolValue(false), vm.BoolValue(true)) + require.Equal(t, vm.NewIntValue(21), actual) + }) + + t.Run("false, false", func(t *testing.T) { + t.Parallel() + + actual := test(t, vm.BoolValue(false), vm.BoolValue(false)) + require.Equal(t, vm.NewIntValue(21), actual) + }) +} diff --git a/bbq/vm/value_conversions.go b/bbq/vm/value_conversions.go index 175afb6d4..fa0b910a5 100644 --- a/bbq/vm/value_conversions.go +++ b/bbq/vm/value_conversions.go @@ -31,6 +31,8 @@ func InterpreterValueToVMValue(storage interpreter.Storage, value interpreter.Va switch value := value.(type) { case nil: return nil + case interpreter.BoolValue: + return BoolValue(value) case interpreter.NilValue: return Nil case interpreter.IntValue: @@ -99,6 +101,8 @@ func VMValueToInterpreterValue(config *Config, value Value) interpreter.Value { switch value := value.(type) { case nil: return nil + case BoolValue: + return interpreter.BoolValue(value) case NilValue: return interpreter.Nil case IntValue: diff --git a/bbq/vm/vm.go b/bbq/vm/vm.go index e9f0ec0ac..b093c82d5 100644 --- a/bbq/vm/vm.go +++ b/bbq/vm/vm.go @@ -348,6 +348,13 @@ func opJumpIfFalse(vm *VM, ins opcode.InstructionJumpIfFalse) { } } +func opJumpIfTrue(vm *VM, ins opcode.InstructionJumpIfTrue) { + value := vm.pop().(BoolValue) + if value { + vm.ip = ins.Target + } +} + func opJumpIfNil(vm *VM, ins opcode.InstructionJumpIfNil) { _, ok := vm.pop().(NilValue) if ok { @@ -820,6 +827,8 @@ func (vm *VM) run() { opJump(vm, ins) case opcode.InstructionJumpIfFalse: opJumpIfFalse(vm, ins) + case opcode.InstructionJumpIfTrue: + opJumpIfTrue(vm, ins) case opcode.InstructionJumpIfNil: opJumpIfNil(vm, ins) case opcode.InstructionAdd: