Skip to content

Commit

Permalink
feat(core): add union type (sum type)
Browse files Browse the repository at this point in the history
  • Loading branch information
RIvance committed Nov 6, 2024
1 parent f1f42aa commit 421d53b
Show file tree
Hide file tree
Showing 14 changed files with 183 additions and 36 deletions.
4 changes: 3 additions & 1 deletion saki-concrete/src/main/antlr/Saki.g4
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ expr
| func=expr '[' NL* argList NL* ']' # exprImplicitCall
| subject=expr NL* '.' member=Identifier ('[' implicitArgList=argList ']')? # exprElimination
| inductive=expr '::' constructor=Identifier # exprConstructor
| '^' expr # exprTypeOf
| '(' '|'? types+=expr ('|' types+=expr)+ ')' # exprUnionType
| lhs=expr rhs=atom # exprSpine
| lhs=expr op=OptSymbol rhs=expr # exprSpineInfixOp
| lhs=expr op=OptSymbol # exprSpinePostfixOp
| op=OptSymbol rhs=expr # exprSpinePrefixOp
| '(' value=blockExpr ')' # exprParen
| '\'(' elements+=expr ',' NL* elements+=expr ')' # exprTuple
| '^(' types+=expr ',' NL* types+=expr ')' # exprTupleType
| '(' types+=expr ',' NL* types+=expr ')' # exprTupleType
| '(' NL* lambdaParamList=paramList NL* ')' (':' returnType=expr)? '=>' body=blockExpr # exprLambda
| func=expr ('|' lambdaParamList=untypedParamList '|' (':' returnType=expr)?)? body=block # exprCallWithLambda
// Control
Expand Down
13 changes: 9 additions & 4 deletions saki-concrete/src/main/scala/saki/concrete/Visitor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ class Visitor extends SakiBaseVisitor[SyntaxTree[?] | Seq[SyntaxTree[?]]] {
case ctx: ExprCallContext => visitExprCall(ctx)
case ctx: ExprImplicitCallContext => visitExprImplicitCall(ctx)
case ctx: ExprParenContext => visitExprParen(ctx)
case ctx: ExprUnionTypeContext => visitExprUnionType(ctx)
case ctx: ExprTupleTypeContext => visitExprTupleType(ctx)
case ctx: ExprTupleContext => visitExprTuple(ctx)
case ctx: ExprConstructorContext => visitExprConstructor(ctx)
Expand Down Expand Up @@ -245,6 +246,10 @@ class Visitor extends SakiBaseVisitor[SyntaxTree[?] | Seq[SyntaxTree[?]]] {

override def visitExprParen(ctx: ExprParenContext): ExprTree = ctx.value.visit

override def visitExprUnionType(ctx: ExprUnionTypeContext): ExprTree = {
ExprTree.Union(ctx.types.asScala.map(_.visit))(ctx)
}

override def visitExprTuple(ctx: ExprTupleContext): ExprTree = {
UnsupportedFeature.raise(ctx.span) { "Tuple is not supported yet" }
}
Expand Down Expand Up @@ -374,11 +379,11 @@ class Visitor extends SakiBaseVisitor[SyntaxTree[?] | Seq[SyntaxTree[?]]] {
val body = caseCtx.body.visit
caseCtx.clauses.asScala.map {
case clause: MatchClauseSingleContext => {
val pattern = clause.pattern.visit.get
if clause.`type` != null then UnsupportedFeature.raise(clause.span) {
"Type annotation in match clause is not supported yet"
val pattern: Pattern[ExprTree] = clause.pattern.visit.get
val typedPattern = if clause.`type` == null then pattern else {
Pattern.Typed(pattern, clause.`type`.visit)(ctx.span)
}
Clause(Seq(pattern), body)
Clause(Seq(typedPattern), body)
}
case clause: MatchClauseTupleContext => {
val patterns = clause.patternList.patterns.asScala.map(_.visit.get)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ enum ExprTree(implicit ctx: ParserRuleContext) extends SyntaxTree[CoreExpr] with
`type`: LiteralType,
)(implicit ctx: ParserRuleContext)

case Union(
types: Seq[ExprTree],
)(implicit ctx: ParserRuleContext)

case TypeOf(
value: ExprTree,
)(implicit ctx: ParserRuleContext)
Expand Down Expand Up @@ -96,6 +100,8 @@ enum ExprTree(implicit ctx: ParserRuleContext) extends SyntaxTree[CoreExpr] with
case PrimitiveValue(value) => CoreExpr.Primitive(value)
case PrimitiveType(ty) => CoreExpr.PrimitiveType(ty)

case Union(types) => CoreExpr.Union(types.map(_.emit))

case TypeOf(value) => CoreExpr.TypeOf(value.emit)

case Lambda(paramExpr, bodyExpr, returnTypeExpr) => {
Expand Down
23 changes: 21 additions & 2 deletions saki-concrete/src/test/scala/saki/core/DefinitionTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,34 @@ class DefinitionTest extends AnyFunSuite with should.Matchers with SakiTestExt {
test("overloaded 2") {
val code = {
"""
def concat(a b: Int): Int = a * 10 + b
def ceilDec(x: Int): Int = if x == 0 then 1 else 10 * ceilDec(x / 10)
def concat(a: Int, b: Int): Int = a * ceilDec(b) + b
def concat(a b: String): String = a ++ b
"""
}
val module = compileModule(code)
module.eval("concat(1, 2)") should be (module.eval("12"))
module.eval("concat(123, 456)") should be (module.eval("123456"))
module.eval("concat(\"It's \", \"mygo!!!!!\")") should be (module.eval("\"It's mygo!!!!!\""))
}

test("sum type") {
val code = {
"""
def describeValue(value: (Bool | ℤ | String)): String = match value {
case true => "It's true!"
case false => "It's false!"
case n: ℤ => "It's an integer: " ++ n.toString ++ "!"
case s: String => "It's " ++ s ++ "!!!!!"
}
"""
}
val module = compileModule(code)
module.eval("describeValue(true)") should be (module.eval("\"It's true!\""))
module.eval("describeValue(false)") should be (module.eval("\"It's false!\""))
module.eval("describeValue(114)") should be (module.eval("\"It's an integer: 114!\""))
module.eval("describeValue(\"mygo\")") should be (module.eval("\"It's mygo!!!!!\""))
}

test("mutual recursive") {
val code = {
"""
Expand Down
2 changes: 2 additions & 0 deletions saki-core/src/main/scala/saki/core/Entity.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ trait EntityFactory[T <: Entity, D <: Entity] {

def variable(ident: Var.Local, ty: T): T

def typeBarrier(value: T, ty: T): T

def inductiveType(inductive: Var.Defined[D, Inductive], args: Seq[T]): T

def functionInvoke(function: Var.Defined[D, Function], args: Seq[T]): T
Expand Down
11 changes: 8 additions & 3 deletions saki-core/src/main/scala/saki/core/domain/NeutralValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ enum NeutralValue {

case Variable(ident: Var.Local, `type`: Type)

case TypeBarrier(value: Value, `type`: Type)

case Apply(fn: NeutralValue, arg: Value)

case Projection(record: Value, field: String)
Expand All @@ -30,6 +32,8 @@ enum NeutralValue {

case Variable(_, ty) => ty

case TypeBarrier(_, ty) => ty

case Apply(fn, arg) => fn.infer match {
case Value.Pi(paramType, codomain) => {
val argType = arg.infer
Expand Down Expand Up @@ -119,6 +123,7 @@ enum NeutralValue {
def readBack(implicit env: Environment.Typed[Value]): Term = this match {
case Variable(ident, _) => Term.Variable(ident)
case Apply(fn, arg) => Term.Apply(fn.readBack, arg.readBack)
case TypeBarrier(value, ty) => Term.TypeBarrier(value.readBack, ty.readBack)
case Projection(record, field) => Term.Projection(record.readBack, field)
case FunctionInvoke(fnRef, args) => Term.FunctionInvoke(fnRef, args.map(_.readBack))
case Match(scrutinees, clauses) => {
Expand All @@ -141,6 +146,7 @@ enum NeutralValue {

def containsMatching(implicit env: Environment.Typed[Value]): Boolean = this match {
case Variable(_, _) => false
case TypeBarrier(value, _) => value.containsMatching
case Apply(fn, arg) => fn.containsMatching || arg.containsMatching
case Projection(record, _) => record.containsMatching
case FunctionInvoke(_, args) => args.exists(_.containsMatching)
Expand All @@ -151,6 +157,7 @@ enum NeutralValue {
implicit env: Environment.Typed[Value]
): Boolean = this match {
case Variable(ident, _) => variables.contains(ident)
case TypeBarrier(value, _) => value.isFinal(variables)
case Apply(fn, arg) => fn.isFinal(variables) && arg.isFinal(variables)
case Projection(record, _) => record.isFinal(variables)
case FunctionInvoke(_, args) => args.forall(_.isFinal(variables))
Expand All @@ -161,9 +168,7 @@ enum NeutralValue {
}
}

infix def unify(that: NeutralValue)(
implicit env: Environment.Typed[Value]
): Boolean = (this, that) match {
infix def unify(that: NeutralValue)(implicit env: Environment.Typed[Value]): Boolean = (this, that) match {

case (lhs, rhs) if lhs == rhs => true

Expand Down
51 changes: 43 additions & 8 deletions saki-core/src/main/scala/saki/core/domain/Value.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ enum Value extends RuntimeEntity[Type] {

case Neutral(value: NeutralValue)

case Union(types: Set[Type])

case Intersection(types: Set[Type])

case Pi(paramType: Type, closure: CodomainClosure)

case OverloadedPi(
Expand All @@ -35,6 +39,8 @@ enum Value extends RuntimeEntity[Type] {
states: Map[Type, CodomainClosure]
) extends Value with OverloadedLambdaLike[OverloadedLambda]

case Pair(first: Value, second: Value)

case Record(fields: Map[String, Value])

case RecordType(fields: Map[String, Type])
Expand Down Expand Up @@ -62,6 +68,10 @@ enum Value extends RuntimeEntity[Type] {

case Neutral(neutral) => neutral.infer

case Union(types) => types.map(_.infer).reduce(_ <:> _)

case Intersection(_) => ??? // TODO: find the greatest lower bound

case Pi(_, _) => Universe

case OverloadedPi(_) => Universe
Expand All @@ -81,6 +91,7 @@ enum Value extends RuntimeEntity[Type] {
}
(paramType, _ => bodyType): (Type, CodomainClosure)
})
case Pair(first, second) => Sigma(first.infer, _ => second.infer)
case Record(fields) => RecordType(fields.map((name, value) => (name, value.infer)))
case RecordType(_) => Universe
case InductiveType(_, _) => Universe
Expand All @@ -102,6 +113,10 @@ enum Value extends RuntimeEntity[Type] {

case Neutral(value) => value.readBack

case Union(types) => Term.Union(types.map(_.readBack))

case Intersection(types) => Term.Intersection(types.map(_.readBack))

case Pi(paramType, codomainClosure) => {
val (param, codomain) = Value.readBackClosure(paramType, codomainClosure)
Term.Pi(param, codomain)
Expand All @@ -121,6 +136,8 @@ enum Value extends RuntimeEntity[Type] {

case lambda: OverloadedLambda => Term.OverloadedLambda(lambda.readBackStates)

case Pair(first, second) => Term.Pair(first.readBack, second.readBack)

case Record(fields) => Term.Record(fields.map((name, value) => (name, value.readBack)))

case RecordType(fields) => Term.RecordType(fields.map((name, ty) => (name, ty.readBack)))
Expand All @@ -137,6 +154,8 @@ enum Value extends RuntimeEntity[Type] {
def containsMatching(implicit env: Environment.Typed[Value]): Boolean = this match {
case Universe | Primitive(_) | PrimitiveType(_) => false
case Neutral(neutral) => neutral.containsMatching
case Union(types) => types.exists(_.containsMatching)
case Intersection(types) => types.exists(_.containsMatching)
case Pi(paramType, closure) => {
env.withNewUnique(paramType) { (env, ident, ty) =>
closure(Value.variable(ident, ty)).containsMatching(env)
Expand All @@ -160,6 +179,7 @@ enum Value extends RuntimeEntity[Type] {
}
}
}
case Pair(first, second) => first.containsMatching || second.containsMatching
case Record(fields) => fields.valuesIterator.exists(_.containsMatching)
case RecordType(fields) => fields.valuesIterator.exists(_.containsMatching)
case InductiveType(_, args) => args.exists(_.containsMatching)
Expand Down Expand Up @@ -276,6 +296,12 @@ enum Value extends RuntimeEntity[Type] {
// Neutral values
case (Neutral(value1), Neutral(value2)) => value1 <:< value2

case (Union(types1), Union(types2)) => types1.forall(t1 => types2.exists(t1 <:< _))

case (Union(types), rhs) => types.exists(_ <:< rhs)

case (lhs, Union(types)) => types.forall(lhs <:< _)

// Function type (Pi type) subtyping: Covariant in return type, contravariant in parameter type
case (Pi(paramType1, closure1), Pi(paramType2, closure2)) => {
paramType2 <:< paramType1 && env.withNewUnique(paramType2 <:> paramType1) {
Expand Down Expand Up @@ -352,11 +378,20 @@ enum Value extends RuntimeEntity[Type] {
case (Value.Universe, Value.Universe) => Value.Universe

// LUB for Primitive types: If they match, return it, otherwise no common LUB
case (Value.PrimitiveType(t1), Value.PrimitiveType(t2)) => {
if (t1 == t2) Value.PrimitiveType(t1)
else NoLeastUpperBound.raise {
s"No least upper bound for incompatible primitive types: $t1 and $t2"
}
case (Value.PrimitiveType(t1), Value.PrimitiveType(t2)) if t1 == t2 => {
Value.PrimitiveType(t1)
}

case (Value.Union(types1), Value.Union(types2)) => (types1 ++ types2).reduce(_ <:> _)

case (lhs, Value.Union(types)) => {
if types.exists(lhs <:< _) then lhs
else Value.Union(types + lhs)
}

case (Value.Union(types), rhs) => {
if types.exists(_ <:< rhs) then Value.Union(types)
else Value.Union(types + rhs)
}

// LUB for Pi types: contravariant parameter type, covariant return type
Expand Down Expand Up @@ -429,9 +464,7 @@ enum Value extends RuntimeEntity[Type] {
}

// No common LUB for incompatible types
case _ => NoLeastUpperBound.raise {
s"No least upper bound exists between: $this and $that"
}
case (lhs, rhs) => Value.Union(Set(lhs, rhs))
}

@deprecatedOverriding("For debugging purposes only, don't call it in production code")
Expand All @@ -452,6 +485,8 @@ object Value extends RuntimeEntityFactory[Value] {

override def variable(ident: Var.Local, ty: Type): Value = Neutral(NeutralValue.Variable(ident, ty))

override def typeBarrier(value: Value, ty: Type): Type = Neutral(NeutralValue.TypeBarrier(value, ty))

override def functionInvoke(function: Var.Defined[Term, Function], args: Seq[Type]): Type = {
Neutral(NeutralValue.FunctionInvoke(function, args))
}
Expand Down
10 changes: 10 additions & 0 deletions saki-core/src/main/scala/saki/core/elaborate/Resolve.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ object Resolve {
case None => UnresolvedReference.raise(s"Unresolved variable: $name")
}

case Expr.Union(types) => {
val (resolvedTypes, typesCtx) = types.foldLeft((List.empty[Expr], ctx)) {
case ((resolvedTypes, ctx), ty) => {
val (resolved, newCtx) = ty.resolve(ctx)
(resolvedTypes :+ resolved, newCtx)
}
}
(Expr.Union(resolvedTypes), typesCtx)
}

case Expr.Hole(_) => {
val resolved = Expr.Hole(ctx.variables.flatMap {
case local: Var.Local => Some(local)
Expand Down
13 changes: 9 additions & 4 deletions saki-core/src/main/scala/saki/core/elaborate/Synthesis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ object Synthesis:

case Expr.PrimitiveType(ty) => Synth(Term.PrimitiveType(ty), Value.Universe)

case Expr.TypeOf(value) => value.synth(env).unpack match {
case (_, ty: Value) => Synth(ty.readBack, Value.Universe)
}

case Expr.Variable(ref) => ref match {
// Converting a definition reference to a lambda, enabling curry-style function application
case definitionVar: Var.Defined[Term@unchecked, ?] => env.getSymbol(definitionVar) match {
Expand All @@ -62,6 +58,15 @@ object Synthesis:
}
}

case Expr.Union(types) => {
val synthTypes: Seq[Synth] = types.map(_.synth(env))
Synth(Term.Union(synthTypes.map(_.term).toSet), Value.Universe)
}

case Expr.TypeOf(value) => value.synth(env).unpack match {
case (_, ty: Value) => Synth(ty.readBack, Value.Universe)
}

case Expr.Elimination(obj, member) => obj.synth(env).normalize.unpack match {
// This is a project operation
// `obj.field`
Expand Down
11 changes: 11 additions & 0 deletions saki-core/src/main/scala/saki/core/syntax/Clause.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ extension (clauses: Seq[Clause[Term]]) {
case Some(value) => (variable, value)
case None => (variable, Typed[Value](Value.variable(variable, ty), ty))
}
}.map {
case (variable, Typed(value, expectedType)) => {
val valueType = value.infer
// If the expected type is a subtype of the value type but not vice versa,
// then a type override is needed.
if !(expectedType <:< valueType) && (valueType <:< expectedType) then {
(variable, Typed[Value](Value.typeBarrier(value, expectedType), expectedType))
} else {
(variable, Typed[Value](value, expectedType))
}
}
}
val body = env.withLocals(bindings.toMap) {
implicit env => clause.body.eval(evalMode)
Expand Down
5 changes: 5 additions & 0 deletions saki-core/src/main/scala/saki/core/syntax/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ enum Expr(val span: SourceSpan) extends Entity {
`type`: LiteralType
)(implicit span: SourceSpan) extends Expr(span)

case Union(
types: Seq[Expr]
)(implicit span: SourceSpan) extends Expr(span)

case TypeOf(
value: Expr
)(implicit span: SourceSpan) extends Expr(span)
Expand Down Expand Up @@ -145,6 +149,7 @@ enum Expr(val span: SourceSpan) extends Entity {
case PrimitiveType(ty) => ty.toString
case Unresolved(name) => name
case Variable(ref) => ref.toString
case Union(types) => s"(${types.map(_.toString).mkString(" | ")})"
case TypeOf(value) => s"^($value)"
case Hole(_) => "_"
case Pi(param, result) => s"Π(${param.ident} : ${param.`type`}) -> $result"
Expand Down
Loading

0 comments on commit 421d53b

Please sign in to comment.