use core::{cmp::Ordering, iter, ops, ptr};
use self::fns::{MonoTypeTransformer, ParamMapping};
use crate::{
alloc::{vec, HashMap, HashSet, String, Vec},
arith::{CompleteConstraints, Constraint},
error::{ErrorKind, ErrorPathFragment, OpErrors, TupleContext},
visit::{self, Visit, VisitMut},
Function, Object, PrimitiveType, Tuple, TupleLen, Type, TypeVar, UnknownLen,
};
mod fns;
#[cfg(test)]
mod tests;
/// Internal classification of failures when unifying or constraining tuple lengths.
///
/// Converted into `ErrorKind` values by `apply_static_len` and `unify_lengths`.
#[derive(Clone, Copy, Debug)]
enum LenErrorKind {
    /// A length variable involved in the operation is not free.
    UnresolvedParam,
    /// The two lengths cannot be made equal.
    Mismatch,
    /// A length required to be static contains a dynamic component (the offending length).
    Dynamic(TupleLen),
}
/// Solver state for unification: counters for allocated type / length variables,
/// equations recorded for them, per-variable constraints, and the set of length
/// variables that must stay static.
#[derive(Clone, Debug)]
pub struct Substitutions<Prim: PrimitiveType> {
    /// Number of type variables allocated so far (the index of the next fresh one).
    type_var_count: usize,
    /// Type equations: free type var index -> type.
    eqs: HashMap<usize, Type<Prim>>,
    /// Constraints attached to free type variables, keyed by variable index.
    constraints: HashMap<usize, CompleteConstraints<Prim>>,
    /// Number of length variables allocated so far (the index of the next fresh one).
    len_var_count: usize,
    /// Length equations: free length var index -> tuple length.
    length_eqs: HashMap<usize, TupleLen>,
    /// Indexes of length variables that must not resolve to a dynamic length.
    static_lengths: HashSet<usize>,
}
impl<Prim: PrimitiveType> Default for Substitutions<Prim> {
    /// Creates an empty solver: no variables allocated, no equations or constraints.
    fn default() -> Self {
        let (eqs, constraints) = (HashMap::new(), HashMap::new());
        Self {
            type_var_count: 0,
            len_var_count: 0,
            eqs,
            constraints,
            length_eqs: HashMap::new(),
            static_lengths: HashSet::new(),
        }
    }
}
impl<Prim: PrimitiveType> Substitutions<Prim> {
/// Adds `constraint` to the type variable `var_idx` and to every variable reached
/// from it through variable-to-variable equations (see `equivalent_vars`).
///
/// Errors raised while merging the constraint are pushed into `errors`.
pub fn insert_constraint<C>(
    &mut self,
    var_idx: usize,
    constraint: &C,
    mut errors: OpErrors<'_, Prim>,
) where
    C: Constraint<Prim> + Clone,
{
    let targets = self.equivalent_vars(var_idx);
    for target in targets {
        // The set is taken out of the map because `insert` needs `self` as an argument.
        let mut merged = self.constraints.remove(&target).unwrap_or_default();
        merged.insert(constraint.clone(), self, errors.by_ref());
        self.constraints.insert(target, merged);
    }
}
/// Returns the object constraint of a free type variable with all known substitutions
/// applied.
///
/// Returns `None` if `var` is not free, has no constraints, or has no object constraint.
pub(crate) fn object_constraint(&self, var: TypeVar) -> Option<Object<Prim>> {
    if !var.is_free() {
        return None;
    }
    let var_constraints = self.constraints.get(&var.index())?;
    let mut object = var_constraints.object.clone()?;
    self.resolver().visit_object_mut(&mut object);
    Some(object)
}
/// Attaches an object `constraint` to `var_idx` and all of its equivalent variables.
///
/// If the constraint mentions any of those variables (directly or via equations), a
/// recursive-type error is reported instead and no constraint is inserted.
pub(crate) fn insert_obj_constraint(
    &mut self,
    var_idx: usize,
    constraint: &Object<Prim>,
    mut errors: OpErrors<'_, Prim>,
) {
    let recursive_var = {
        let mut checker = OccurrenceChecker::new(self, self.equivalent_vars(var_idx));
        checker.visit_object(constraint);
        checker.recursive_var
    };

    match recursive_var {
        Some(var) => {
            let recursive_ty = Type::Object(constraint.clone());
            self.handle_recursive_type(recursive_ty, var, &mut errors);
        }
        None => {
            for idx in self.equivalent_vars(var_idx) {
                let mut merged = self.constraints.remove(&idx).unwrap_or_default();
                merged.insert_obj_constraint(constraint.clone(), self, errors.by_ref());
                self.constraints.insert(idx, merged);
            }
        }
    }
}
/// Reports `ty` as a recursive type.
///
/// Known substitutions are applied first, then every occurrence of `recursive_var`
/// is replaced by `TypeSanitizer`; the result is pushed as `ErrorKind::RecursiveType`.
fn handle_recursive_type(
    &self,
    mut ty: Type<Prim>,
    recursive_var: usize,
    errors: &mut OpErrors<'_, Prim>,
) {
    self.resolver().visit_type_mut(&mut ty);
    let mut sanitizer = TypeSanitizer::new(recursive_var);
    sanitizer.visit_type_mut(&mut ty);
    errors.push(ErrorKind::RecursiveType(ty));
}
/// Returns `var_idx` followed by every type variable reached by repeatedly following
/// equations that resolve to another variable.
///
/// The walk stops at the first equation whose right-hand side is not a variable, or at
/// a variable without an equation.
fn equivalent_vars(&self, var_idx: usize) -> Vec<usize> {
    let mut chain = vec![var_idx];
    let mut next = self.eqs.get(&var_idx);
    while let Some(Type::Var(var)) = next {
        debug_assert!(var.is_free());
        chain.push(var.index());
        next = self.eqs.get(&var.index());
    }
    chain
}
#[allow(clippy::missing_panics_doc)]
pub fn apply_static_len(&mut self, len: TupleLen) -> Result<(), ErrorKind<Prim>> {
let resolved = self.resolve_len(len);
self.apply_static_len_inner(resolved)
.map_err(|err| match err {
LenErrorKind::UnresolvedParam => ErrorKind::UnresolvedParam,
LenErrorKind::Dynamic(len) => ErrorKind::DynamicLen(len),
LenErrorKind::Mismatch => unreachable!(),
})
}
fn apply_static_len_inner(&mut self, len: TupleLen) -> Result<(), LenErrorKind> {
match len.components().0 {
None => Ok(()),
Some(UnknownLen::Dynamic) => Err(LenErrorKind::Dynamic(len)),
Some(UnknownLen::Var(var)) => {
if var.is_free() {
self.static_lengths.insert(var.index());
Ok(())
} else {
Err(LenErrorKind::UnresolvedParam)
}
}
}
}
/// Follows type equations from `ty` while it is a free variable with a recorded
/// equation, and returns the type where the chain ends.
///
/// Only the top level is resolved; nested types are returned untouched.
pub fn fast_resolve<'a>(&'a self, mut ty: &'a Type<Prim>) -> &'a Type<Prim> {
    loop {
        let next = match ty {
            Type::Var(var) if var.is_free() => self.eqs.get(&var.index()),
            _ => None,
        };
        match next {
            Some(resolved) => ty = resolved,
            None => return ty,
        }
    }
}
/// Returns a visitor that rewrites visited types in place.
///
/// Each visited type is replaced with its `fast_resolve` result, and each middle tuple
/// length with its `resolve_len` result.
pub fn resolver(&self) -> impl VisitMut<Prim> + '_ {
    TypeResolver { substitutions: self }
}
/// Resolves `len` by repeatedly substituting the equation of its free length
/// variable, keeping the exact offset.
///
/// Stops at an exact length, a dynamic length, a non-free variable, or a variable
/// without an equation.
pub(crate) fn resolve_len(&self, len: TupleLen) -> TupleLen {
    let mut current = len;
    loop {
        let (var, offset) = match current.components() {
            (Some(UnknownLen::Var(var)), offset) if var.is_free() => (var, offset),
            _ => return current,
        };
        match self.length_eqs.get(&var.index()) {
            Some(eq_rhs) => current = *eq_rhs + offset,
            None => return current,
        }
    }
}
/// Allocates a fresh free type variable with the next unused index.
pub fn new_type_var(&mut self) -> Type<Prim> {
    let idx = self.type_var_count;
    self.type_var_count = idx + 1;
    Type::free_var(idx)
}
/// Allocates a fresh free length variable with the next unused index.
pub(crate) fn new_len_var(&mut self) -> UnknownLen {
    let idx = self.len_var_count;
    self.len_var_count = idx + 1;
    UnknownLen::free_var(idx)
}
/// Unifies `lhs` with `rhs`, recording type equations in `self` and pushing any
/// failures into `errors`.
///
/// Both sides are first resolved with `fast_resolve`. The sides are not
/// interchangeable:
/// - a variable found on the LHS is passed to `unify_var` with `is_lhs = true`;
/// - a variable found on the RHS is passed with `is_lhs = false`;
/// - `Dyn` constraints are only checked when `Dyn` is on the LHS.
pub fn unify(&mut self, lhs: &Type<Prim>, rhs: &Type<Prim>, mut errors: OpErrors<'_, Prim>) {
// Cloned so that the resolved types do not keep `self` borrowed while it is mutated below.
let resolved_lhs = self.fast_resolve(lhs).clone();
let resolved_rhs = self.fast_resolve(rhs).clone();
// NB: arm order is significant. An LHS var is handled before the equality check,
// and the equality check precedes the `Dyn` and RHS-var arms.
match (&resolved_lhs, &resolved_rhs) {
(Type::Var(var), ty) => {
if var.is_free() {
self.unify_var(var.index(), ty, true, errors);
} else {
// Non-free (bound) params cannot be unified here.
errors.push(ErrorKind::UnresolvedParam);
}
}
(ty, other_ty) if ty == other_ty => {
// Types are already equal; nothing to record.
}
(Type::Dyn(constraints), ty) => {
// `Dyn` on the LHS: check the RHS against the `Dyn` constraints.
constraints.inner.apply_all(ty, self, errors);
}
(ty, Type::Var(var)) => {
if var.is_free() {
self.unify_var(var.index(), ty, false, errors);
} else {
errors.push(ErrorKind::UnresolvedParam);
}
}
(Type::Tuple(lhs_tuple), Type::Tuple(rhs_tuple)) => {
self.unify_tuples(lhs_tuple, rhs_tuple, TupleContext::Generic, errors);
}
(Type::Object(lhs_obj), Type::Object(rhs_obj)) => {
self.unify_objects(lhs_obj, rhs_obj, errors);
}
(Type::Function(lhs_fn), Type::Function(rhs_fn)) => {
self.unify_fn_types(lhs_fn, rhs_fn, errors);
}
(ty, other_ty) => {
// Incompatible types: report both with all known substitutions applied.
let mut resolver = self.resolver();
let mut ty = ty.clone();
resolver.visit_type_mut(&mut ty);
let mut other_ty = other_ty.clone();
resolver.visit_type_mut(&mut other_ty);
errors.push(ErrorKind::TypeMismatch(ty, other_ty));
}
}
}
/// Unifies two tuples: lengths first, then elements.
///
/// - If the lengths cannot be unified, elements are unified on a best-effort basis
///   (`unify_tuples_after_error`) and the length error is reported afterwards.
/// - If the unified length is exact, all elements are unified pairwise with indexed
///   error paths.
/// - Otherwise only the element pairs produced by `equal_elements_dyn` are unified,
///   with index-less error paths.
fn unify_tuples(
    &mut self,
    lhs: &Tuple<Prim>,
    rhs: &Tuple<Prim>,
    context: TupleContext,
    mut errors: OpErrors<'_, Prim>,
) {
    let unified_len = match self.unify_lengths(lhs.len(), rhs.len(), context) {
        Ok(len) => len,
        Err(len_err) => {
            self.unify_tuples_after_error(lhs, rhs, &len_err, context, errors.by_ref());
            errors.push(len_err);
            return;
        }
    };

    match unified_len.components() {
        (None, exact) => {
            self.unify_tuple_elements(lhs.iter(exact), rhs.iter(exact), context, errors);
        }
        _ => {
            for (lhs_elem, rhs_elem) in lhs.equal_elements_dyn(rhs) {
                let fragment = match context {
                    TupleContext::Generic => ErrorPathFragment::TupleElement(None),
                    TupleContext::FnArgs => ErrorPathFragment::FnArg(None),
                };
                self.unify(lhs_elem, rhs_elem, errors.join_path(fragment));
            }
        }
    }
}
/// Unifies two element sequences pairwise and attaches each element's position
/// (via `context.element`) to its error path.
///
/// The sequences are zipped, so unification stops at the shorter one.
#[inline]
fn unify_tuple_elements<'it>(
    &mut self,
    lhs_elements: impl Iterator<Item = &'it Type<Prim>>,
    rhs_elements: impl Iterator<Item = &'it Type<Prim>>,
    context: TupleContext,
    mut errors: OpErrors<'_, Prim>,
) {
    let pairs = lhs_elements.zip(rhs_elements);
    for (index, (lhs_elem, rhs_elem)) in pairs.enumerate() {
        let elem_errors = errors.join_path(context.element(index));
        self.unify(lhs_elem, rhs_elem, elem_errors);
    }
}
/// Best-effort element unification after the tuple lengths failed to unify.
///
/// Only acts on `ErrorKind::TupleLenMismatch` where:
/// - both lengths are exact, or
/// - the LHS is exact and the RHS is dynamic.
///
/// In these cases the elements are unified pairwise; all other errors are ignored here.
fn unify_tuples_after_error(
    &mut self,
    lhs: &Tuple<Prim>,
    rhs: &Tuple<Prim>,
    err: &ErrorKind<Prim>,
    context: TupleContext,
    errors: OpErrors<'_, Prim>,
) {
    let (lhs_len, rhs_len) = if let ErrorKind::TupleLenMismatch {
        lhs: mismatched_lhs,
        rhs: mismatched_rhs,
        ..
    } = err
    {
        (*mismatched_lhs, *mismatched_rhs)
    } else {
        return;
    };
    let (lhs_var, lhs_exact) = lhs_len.components();
    let (rhs_var, rhs_exact) = rhs_len.components();

    let should_unify_elements = match (lhs_var, rhs_var) {
        // Two different exact lengths.
        (None, None) => {
            debug_assert_ne!(lhs_exact, rhs_exact);
            true
        }
        // Exact LHS against a dynamic RHS.
        (None, Some(UnknownLen::Dynamic)) => true,
        _ => false,
    };
    if should_unify_elements {
        self.unify_tuple_elements(lhs.iter(lhs_exact), rhs.iter(rhs_exact), context, errors);
    }
}
fn unify_lengths(
&mut self,
lhs: TupleLen,
rhs: TupleLen,
context: TupleContext,
) -> Result<TupleLen, ErrorKind<Prim>> {
let resolved_lhs = self.resolve_len(lhs);
let resolved_rhs = self.resolve_len(rhs);
self.unify_lengths_inner(resolved_lhs, resolved_rhs)
.map_err(|err| match err {
LenErrorKind::UnresolvedParam => ErrorKind::UnresolvedParam,
LenErrorKind::Mismatch => ErrorKind::TupleLenMismatch {
lhs: resolved_lhs,
rhs: resolved_rhs,
context,
},
LenErrorKind::Dynamic(len) => ErrorKind::DynamicLen(len),
})
}
/// Unifies two already-resolved lengths. Each length is split into an optional unknown
/// part and an exact offset.
///
/// - If exactly one side has an unknown part, that part is unified with the difference
///   of the offsets. The guards require the difference to be non-negative; otherwise
///   the result is `Mismatch`.
/// - If neither side has an unknown part, the offsets must be equal.
/// - If both sides have unknown parts, the smaller offset is cancelled and the unknown
///   part on the side with the smaller offset is unified with the other side's remainder.
///
/// On success, returns the unified length with the cancelled offset added back.
fn unify_lengths_inner(
&mut self,
resolved_lhs: TupleLen,
resolved_rhs: TupleLen,
) -> Result<TupleLen, LenErrorKind> {
let (lhs_var, lhs_exact) = resolved_lhs.components();
let (rhs_var, rhs_exact) = resolved_rhs.components();
// Handle the cases where at least one side is exact; fall through only when both
// sides have unknown parts.
let (lhs_var, rhs_var) = match (lhs_var, rhs_var) {
(Some(lhs_var), Some(rhs_var)) => (lhs_var, rhs_var),
(Some(lhs_var), None) if rhs_exact >= lhs_exact => {
return self
.unify_simple_length(lhs_var, TupleLen::from(rhs_exact - lhs_exact), true)
.map(|len| len + lhs_exact);
}
(None, Some(rhs_var)) if lhs_exact >= rhs_exact => {
return self
.unify_simple_length(rhs_var, TupleLen::from(lhs_exact - rhs_exact), false)
.map(|len| len + rhs_exact);
}
(None, None) if lhs_exact == rhs_exact => return Ok(TupleLen::from(lhs_exact)),
// An exact side cannot be matched (offsets incompatible).
_ => return Err(LenErrorKind::Mismatch),
};
// Both sides have unknown parts: cancel the common offset and unify the side with
// the smaller offset against the other side's remainder.
match lhs_exact.cmp(&rhs_exact) {
Ordering::Equal => self.unify_simple_length(lhs_var, TupleLen::from(rhs_var), true),
Ordering::Greater => {
let reduced = lhs_var + (lhs_exact - rhs_exact);
self.unify_simple_length(rhs_var, reduced, false)
.map(|len| len + rhs_exact)
}
Ordering::Less => {
let reduced = rhs_var + (rhs_exact - lhs_exact);
self.unify_simple_length(lhs_var, reduced, true)
.map(|len| len + lhs_exact)
}
}
}
fn unify_simple_length(
&mut self,
simple_len: UnknownLen,
source: TupleLen,
is_lhs: bool,
) -> Result<TupleLen, LenErrorKind> {
match simple_len {
UnknownLen::Var(var) if var.is_free() => self.unify_var_length(var.index(), source),
UnknownLen::Dynamic => self.unify_dyn_length(source, is_lhs),
_ => Err(LenErrorKind::UnresolvedParam),
}
}
#[inline]
fn unify_var_length(
&mut self,
var_idx: usize,
source: TupleLen,
) -> Result<TupleLen, LenErrorKind> {
match source.components() {
(Some(UnknownLen::Var(var)), _) if !var.is_free() => Err(LenErrorKind::UnresolvedParam),
(Some(UnknownLen::Var(var)), offset) if var.index() == var_idx => {
if offset == 0 {
Ok(source)
} else {
Err(LenErrorKind::Mismatch)
}
}
_ => {
if self.static_lengths.contains(&var_idx) {
self.apply_static_len_inner(source)?;
}
self.length_eqs.insert(var_idx, source);
Ok(source)
}
}
}
/// Unifies a dynamic length with `source`.
///
/// - A dynamic LHS accepts any `source` unchanged.
/// - A dynamic RHS only matches another dynamic length or a free length variable,
///   both without an offset. In the variable case, the variable is equated to the
///   dynamic length.
#[inline]
fn unify_dyn_length(
    &mut self,
    source: TupleLen,
    is_lhs: bool,
) -> Result<TupleLen, LenErrorKind> {
    if is_lhs {
        return Ok(source);
    }
    match source.components() {
        (Some(UnknownLen::Dynamic), 0) => Ok(source),
        (Some(UnknownLen::Var(var)), 0) if var.is_free() => {
            self.unify_var_length(var.index(), UnknownLen::Dynamic.into())
        }
        _ => Err(LenErrorKind::Mismatch),
    }
}
/// Unifies two object types field by field.
///
/// Both objects must have exactly the same field names. Otherwise a `FieldsMismatch`
/// error listing both field sets is reported and no fields are unified.
fn unify_objects(
    &mut self,
    lhs: &Object<Prim>,
    rhs: &Object<Prim>,
    mut errors: OpErrors<'_, Prim>,
) {
    let lhs_fields: HashSet<_> = lhs.field_names().collect();
    let rhs_fields: HashSet<_> = rhs.field_names().collect();

    if lhs_fields != rhs_fields {
        let lhs_fields = lhs_fields.into_iter().map(String::from).collect();
        let rhs_fields = rhs_fields.into_iter().map(String::from).collect();
        errors.push(ErrorKind::FieldsMismatch {
            lhs_fields,
            rhs_fields,
        });
        return;
    }

    for (name, field_ty) in lhs.iter() {
        self.unify(field_ty, &rhs[name], errors.join_path(name));
    }
}
/// Unifies two function types.
///
/// - A parametric LHS is rejected with `UnsupportedParam`.
/// - A parametric RHS is instantiated with fresh variables first.
/// - Argument tuples are unified with sides swapped (RHS args as LHS).
/// - Return types keep the original sides.
fn unify_fn_types(
    &mut self,
    lhs: &Function<Prim>,
    rhs: &Function<Prim>,
    mut errors: OpErrors<'_, Prim>,
) {
    if lhs.is_parametric() {
        errors.push(ErrorKind::UnsupportedParam);
        return;
    }

    let lhs_mono = self.instantiate_function(lhs);
    let rhs_mono = self.instantiate_function(rhs);

    // Sides are swapped for arguments: values flow into arguments in the opposite
    // direction from return values (cf. the variance flip in `TypeSpecifier`).
    self.unify_tuples(
        &rhs_mono.args,
        &lhs_mono.args,
        TupleContext::FnArgs,
        errors.by_ref(),
    );

    let return_errors = errors.join_path(ErrorPathFragment::FnReturnType);
    self.unify(&lhs_mono.return_type, &rhs_mono.return_type, return_errors);
}
/// Returns a monomorphic copy of `fn_type`. Non-parametric functions are simply cloned.
///
/// For a parametric function:
/// - its type and length params are replaced with freshly allocated free variables,
///   taken as contiguous ranges starting at the current counters;
/// - static length params mark the corresponding new length vars as static;
/// - type param constraints are copied, transformed, onto the new type vars.
fn instantiate_function(&mut self, fn_type: &Function<Prim>) -> Function<Prim> {
    if !fn_type.is_parametric() {
        return fn_type.clone();
    }
    let fn_params = fn_type.params.as_ref().expect("fn with params");

    let first_type_var = self.type_var_count;
    let first_len_var = self.len_var_count;
    let mapping = ParamMapping {
        types: fn_params
            .type_params
            .iter()
            .enumerate()
            .map(|(offset, (param_idx, _))| (*param_idx, first_type_var + offset))
            .collect(),
        lengths: fn_params
            .len_params
            .iter()
            .enumerate()
            .map(|(offset, (param_idx, _))| (*param_idx, first_len_var + offset))
            .collect(),
    };
    self.type_var_count = first_type_var + fn_params.type_params.len();
    self.len_var_count = first_len_var + fn_params.len_params.len();

    let mut mono_fn = fn_type.clone();
    MonoTypeTransformer::transform(&mapping, &mut mono_fn);

    for (param_idx, is_static) in &fn_params.len_params {
        if !*is_static {
            continue;
        }
        self.static_lengths.insert(mapping.lengths[param_idx]);
    }
    for (param_idx, constraints) in &fn_params.type_params {
        let mono_constraints = MonoTypeTransformer::transform_constraints(&mapping, constraints);
        self.constraints.insert(mapping.types[param_idx], mono_constraints);
    }

    mono_fn
}
/// Binds the free type variable `var_idx` to `ty`.
///
/// - If `var_idx` occurs in `ty` (directly or via equations), a recursive-type error is
///   reported through `handle_recursive_type` and no equation is recorded.
/// - When `is_lhs` is `false` (the variable came from the RHS of `unify`), `ty` is
///   first rewritten by `TypeSpecifier`.
/// - After the equation is recorded, any constraints on `var_idx` are applied to `ty`.
fn unify_var(
&mut self,
var_idx: usize,
ty: &Type<Prim>,
is_lhs: bool,
mut errors: OpErrors<'_, Prim>,
) {
// Debug-only preconditions:
// - an RHS var is never unified with `Any` / `Dyn`;
// - `var_idx` has no equation yet;
// - a var `ty` has no equation yet.
debug_assert!(is_lhs || !matches!(ty, Type::Any | Type::Dyn(_)));
debug_assert!(!self.eqs.contains_key(&var_idx));
debug_assert!(if let Type::Var(var) = ty {
!self.eqs.contains_key(&var.index())
} else {
true
});
if let Type::Var(var) = ty {
if !var.is_free() {
errors.push(ErrorKind::UnresolvedParam);
return;
} else if var.index() == var_idx {
// Unifying a variable with itself is a no-op.
return;
}
}
// Occurs check: `var_idx` must not appear in `ty`, following existing equations.
let mut checker = OccurrenceChecker::new(self, iter::once(var_idx));
checker.visit_type(ty);
if let Some(var) = checker.recursive_var {
self.handle_recursive_type(ty.clone(), var, &mut errors);
} else {
let mut ty = ty.clone();
if !is_lhs {
// Replace `Any` / `Dyn` / dynamic lengths in covariant positions with fresh vars.
TypeSpecifier::new(self).visit_type_mut(&mut ty);
}
self.eqs.insert(var_idx, ty.clone());
// NOTE(review): constraints are applied only after the equation is recorded;
// presumably so that constraints referring back to `var_idx` see the binding.
// Keep this order.
if let Some(constraints) = self.constraints.get(&var_idx).cloned() {
constraints.apply_all(&ty, self, errors);
}
}
}
}
/// Visitor that checks whether a type mentions any of the tracked free type
/// variables, following recorded equations for other free variables.
#[derive(Debug)]
struct OccurrenceChecker<'a, Prim: PrimitiveType> {
    /// Source of type equations, used to look through untracked variables.
    substitutions: &'a Substitutions<Prim>,
    /// Indexes of the variables whose occurrence makes the visited type recursive.
    var_indexes: HashSet<usize>,
    /// A tracked variable found during the visit, if any.
    recursive_var: Option<usize>,
}
impl<'a, Prim: PrimitiveType> OccurrenceChecker<'a, Prim> {
fn new(
substitutions: &'a Substitutions<Prim>,
var_indexes: impl IntoIterator<Item = usize>,
) -> Self {
Self {
substitutions,
var_indexes: var_indexes.into_iter().collect(),
recursive_var: None,
}
}
}
impl<Prim: PrimitiveType> Visit<Prim> for OccurrenceChecker<'_, Prim> {
    /// Descends into `ty` unless a recursive variable has already been found.
    fn visit_type(&mut self, ty: &Type<Prim>) {
        if self.recursive_var.is_none() {
            visit::visit_type(self, ty);
        }
    }

    /// Records `var` if it is tracked; otherwise visits its equation, if any.
    /// Non-free variables are ignored.
    fn visit_var(&mut self, var: TypeVar) {
        if var.is_free() {
            let idx = var.index();
            if self.var_indexes.contains(&idx) {
                self.recursive_var = Some(idx);
            } else if let Some(resolved) = self.substitutions.eqs.get(&idx) {
                self.visit_type(resolved);
            }
        }
    }
}
/// Visitor that replaces the type variable with index `fixed_idx` by `Type::param(0)`.
#[derive(Debug)]
struct TypeSanitizer {
    /// Index of the variable to replace.
    fixed_idx: usize,
}
impl TypeSanitizer {
fn new(fixed_idx: usize) -> Self {
Self { fixed_idx }
}
}
impl<Prim: PrimitiveType> VisitMut<Prim> for TypeSanitizer {
    /// Replaces the *free* type variable `fixed_idx` with `Type::param(0)` and recurses
    /// into every other type.
    ///
    /// Only free variables are matched. `fixed_idx` always comes from
    /// `OccurrenceChecker`, which ignores non-free variables, so a bound param that
    /// merely shares the same index must be left untouched.
    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
        match ty {
            Type::Var(var) if var.is_free() && var.index() == self.fixed_idx => {
                *ty = Type::param(0);
            }
            _ => visit::visit_type_mut(self, ty),
        }
    }
}
/// Visitor that substitutes recorded equations into types and middle tuple lengths.
#[derive(Clone, Copy, Debug)]
struct TypeResolver<'a, Prim: PrimitiveType> {
    /// Source of the type and length equations.
    substitutions: &'a Substitutions<Prim>,
}
impl<Prim: PrimitiveType> VisitMut<Prim> for TypeResolver<'_, Prim> {
    /// Replaces `ty` with its top-level resolution, then resolves nested types.
    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
        let replacement = {
            let resolved = self.substitutions.fast_resolve(ty);
            // `fast_resolve` hands back its argument when nothing resolves; skip the clone then.
            if ptr::eq(ty, resolved) {
                None
            } else {
                Some(resolved.clone())
            }
        };
        if let Some(replacement) = replacement {
            *ty = replacement;
        }
        visit::visit_type_mut(self, ty);
    }

    /// Replaces `len` with its resolution from the length equations.
    fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
        let resolved = self.substitutions.resolve_len(*len);
        *len = resolved;
    }
}
/// Variance of the position currently visited by `TypeSpecifier`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Variance {
    /// Covariant position (the starting state; kept for return types).
    Co,
    /// Contravariant position (entered for function arguments).
    Contra,
}

impl ops::Not for Variance {
    type Output = Self;

    /// Flips the variance.
    fn not(self) -> Self {
        if self == Variance::Co {
            Variance::Contra
        } else {
            Variance::Co
        }
    }
}
/// Visitor that, in covariant positions, replaces `Any` / `Dyn` types and dynamic
/// middle lengths with freshly allocated variables.
#[derive(Debug)]
struct TypeSpecifier<'a, Prim: PrimitiveType> {
    /// Allocator of fresh variables and store for their constraints.
    substitutions: &'a mut Substitutions<Prim>,
    /// Variance of the position currently being visited.
    variance: Variance,
}
impl<'a, Prim: PrimitiveType> TypeSpecifier<'a, Prim> {
    /// Creates a specifier that starts in a covariant position.
    fn new(substitutions: &'a mut Substitutions<Prim>) -> Self {
        let variance = Variance::Co;
        Self {
            variance,
            substitutions,
        }
    }
}
impl<Prim: PrimitiveType> VisitMut<Prim> for TypeSpecifier<'_, Prim> {
    /// Handles `ty` according to the current variance:
    /// - covariant `Any`: replaced with a fresh type variable;
    /// - covariant `Dyn(..)`: replaced with a fresh type variable carrying the same
    ///   constraints;
    /// - anything else: recursed into.
    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
        if self.variance != Variance::Co {
            visit::visit_type_mut(self, ty);
            return;
        }
        match ty {
            Type::Any => {
                *ty = self.substitutions.new_type_var();
            }
            Type::Dyn(constraints) => {
                let inherited = constraints.inner.clone();
                let var_idx = self.substitutions.type_var_count;
                self.substitutions.constraints.insert(var_idx, inherited);
                // `new_type_var` allocates exactly `var_idx`, which now owns `inherited`.
                *ty = self.substitutions.new_type_var();
            }
            _ => visit::visit_type_mut(self, ty),
        }
    }

    /// In covariant positions, replaces a dynamic middle length with a fresh length variable.
    fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
        if self.variance == Variance::Co {
            if let (Some(var_len @ UnknownLen::Dynamic), _) = len.components_mut() {
                *var_len = self.substitutions.new_len_var();
            }
        }
    }

    /// Visits the return type with the current variance and the arguments with the
    /// flipped variance, then restores the original variance.
    fn visit_function_mut(&mut self, function: &mut Function<Prim>) {
        self.visit_type_mut(&mut function.return_type);

        let saved = self.variance;
        self.variance = !saved;
        self.visit_tuple_mut(&mut function.args);
        self.variance = saved;
    }
}