Skip to content

Commit

Permalink
feat(sol-thir-lowering): implement infer_lam
Browse files Browse the repository at this point in the history
  • Loading branch information
aripiprazole committed May 30, 2024
1 parent 3e65a4d commit 9189a3a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 32 deletions.
17 changes: 0 additions & 17 deletions sol-thir-lowering/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,6 @@ fn lam_pi(
Ok(Term::Lam(definition, pi.implicitness, elab_term.into()))
}

enum Curried {
Lam(Definition, Box<Curried>),
Expr(Expr),
}

fn new_curried_function(db: &dyn ThirLoweringDb, abs: LamExpr) -> Curried {
let mut acc = Curried::Expr(*abs.value);
for parameter in abs.parameters.into_iter() {
let parameter = extract_parameter_definition(db, parameter);
acc = Curried::Lam(parameter, Box::new(acc));
}
if let Curried::Expr(_) = acc {
todo!("handle: no parameters")
}
acc
}

#[rustfmt::skip]
fn lam_thir_check(db: &dyn ThirLoweringDb, ctx: Context, expr: Curried, type_repr: Type, icit: Implicitness) -> sol_diagnostic::Result<Term> {
match (&expr, &type_repr) {
Expand Down
52 changes: 37 additions & 15 deletions sol-thir-lowering/src/infer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use sol_diagnostic::fail;
use sol_diagnostic::{fail, Result};
use sol_thir::{
find_reference_type, infer_constructor,
shared::{Constructor, ConstructorKind},
Expand All @@ -7,6 +7,15 @@ use sol_thir::{

use super::*;

#[derive(Debug, thiserror::Error, miette::Diagnostic)]
#[error("unsupported term")]
#[diagnostic(code(sol::thir::unsupported_term))]
pub struct UnsupportedTermError {
#[source_code]
#[label = "here"]
pub location: Location,
}

fn create_from_type(definition: sol_hir::source::expr::Type, location: Location) -> Term {
use sol_hir::source::expr::Type::*;

Expand All @@ -31,29 +40,42 @@ fn create_from_type(definition: sol_hir::source::expr::Type, location: Location)
})
}

#[derive(Debug, thiserror::Error, miette::Diagnostic)]
#[error("unsupported term")]
#[diagnostic(code(sol::thir::unsupported_term))]
pub struct UnsupportedTerm {
#[source_code]
#[label = "here"]
pub location: Location,
fn infer_lam(db: &dyn ThirLoweringDb, ctx: Context, fun: Curried) -> Result<ElaboratedTerm> {
match fun {
Curried::Lam(domain, codomain) => {
let domain_type = Value::default();
let codomain_ctx = ctx.create_new_value(db, domain, domain_type.clone());
let ElaboratedTerm(codomain_term, codomain_type) =
infer_lam(db, codomain_ctx, *codomain)?;
let term = Term::Lam(domain, Implicitness::Explicit, codomain_term.clone().into());

Ok(ElaboratedTerm(
term,
Value::Pi(Pi {
name: Some(domain),
implicitness: Implicitness::Explicit,
domain: Box::new(domain_type),
codomain: Closure {
env: ctx.locals(db),
expr: db.thir_quote(ctx.lvl(db), codomain_type)?,
},
}),
))
}
Curried::Expr(expr) => thir_infer(db, ctx, expr),
}
}

/// The infer function to infer the type of the term.
#[salsa::tracked]
pub fn thir_infer(
db: &dyn ThirLoweringDb,
ctx: Context,
expr: Expr,
) -> sol_diagnostic::Result<ElaboratedTerm> {
pub fn thir_infer(db: &dyn ThirLoweringDb, ctx: Context, expr: Expr) -> Result<ElaboratedTerm> {
use sol_hir::source::expr::Pi as EPi;
use sol_hir::source::pattern::Pattern;
use Expr::*;

Ok(ElaboratedTerm::from(match expr {
Empty | Error(_) | Match(_) | Sigma(_) => {
return fail(UnsupportedTerm {
return fail(UnsupportedTermError {
location: expr.location(db),
})
}
Expand Down Expand Up @@ -86,7 +108,7 @@ pub fn thir_infer(
(term, actual_type)
}
Call(_) => todo!(),
Lam(_) => todo!(),
Lam(lam) => return infer_lam(db, ctx, new_curried_function(db, lam)),
Pi(EPi {
parameters, value, ..
}) => {
Expand Down
17 changes: 17 additions & 0 deletions sol-thir-lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,23 @@ pub fn thir_quote(
.unwrap_or_else(|| thir_quote_impl(db, None, lvl, value))
}

enum Curried {
Lam(Definition, Box<Curried>),
Expr(Expr),
}

fn new_curried_function(db: &dyn ThirLoweringDb, abs: LamExpr) -> Curried {
let mut acc = Curried::Expr(*abs.value);
for parameter in abs.parameters.into_iter() {
let parameter = extract_parameter_definition(db, parameter);
acc = Curried::Lam(parameter, Box::new(acc));
}
if let Curried::Expr(_) = acc {
todo!("handle: no parameters")
}
acc
}

pub fn extract_parameter_definition(db: &dyn ThirLoweringDb, pattern: Pattern) -> Definition {
let location = pattern.location(db);
let hole = HirPath::create(db, "_");
Expand Down

0 comments on commit 9189a3a

Please sign in to comment.