From a8ec38633fbf1c4c631ebf9498ad397f64364c9e Mon Sep 17 00:00:00 2001 From: Vidush Singhal Date: Mon, 2 Dec 2024 18:05:24 -0500 Subject: [PATCH] extended convertTy to handle SoA case --- .../src/Gibbon/Passes/InferLocations.hs | 332 ++++++++++-------- 1 file changed, 195 insertions(+), 137 deletions(-) diff --git a/gibbon-compiler/src/Gibbon/Passes/InferLocations.hs b/gibbon-compiler/src/Gibbon/Passes/InferLocations.hs index c37d16a4..ee4fc4c4 100644 --- a/gibbon-compiler/src/Gibbon/Passes/InferLocations.hs +++ b/gibbon-compiler/src/Gibbon/Passes/InferLocations.hs @@ -103,6 +103,7 @@ import qualified Gibbon.L1.Syntax as L1 import Gibbon.L2.Syntax as L2 hiding (extendVEnv, extendsVEnv, lookupVEnv, lookupFEnv) import Gibbon.Passes.InlineTriv (inlineTriv) import Gibbon.Passes.Flatten (flattenL1) +import Gibbon.DynFlags -------------------------------------------------------------------------------- -- Environments @@ -134,14 +135,16 @@ lookupFEnv v FullEnv{funEnv} = funEnv # v -- If we assume output regions are disjoint from input ones, then we -- can instantiate an L1 function type into a polymorphic L2 one, -- mechanically. -convertFunTy :: ([Ty1],Ty1,Bool) -> PassM (ArrowTy2 Ty2) -convertFunTy (from,to,isPar) = do - from' <- mapM convertTy from - to' <- convertTy to +convertFunTy :: DDefs1 -> ([Ty1],Ty1,Bool) -> PassM (ArrowTy2 Ty2) +convertFunTy ddefs (from,to,isPar) = do + dflags <- getDynFlags + let useSoA = gopt Opt_Packed_SoA dflags + from' <- mapM (convertTy ddefs useSoA) from + to' <- convertTy ddefs useSoA to -- For this simple version, we assume every location is in a separate region: lrm1 <- concat <$> mapM (toLRM Input) from' lrm2 <- toLRM Output to' - return $ ArrowTy2 { locVars = lrm1 ++ lrm2 + dbgTraceIt "convertFunTy: " dbgTraceIt (sdoc (from', to', lrm1, lrm2, useSoA)) dbgTraceIt "\n" return $ ArrowTy2 { locVars = lrm1 ++ lrm2 , arrIns = from' , arrEffs = S.empty , arrOut = to' @@ -153,15 +156,70 @@ convertFunTy (from,to,isPar) = do return $ LRM v (AoSR $ VarR (unwrapLocVar r)) md) (F.toList ls) -convertTy :: Ty1 -> PassM Ty2 -convertTy ty = traverse (const (freshLocVar "loc")) ty +convertTy :: DDefs1 -> Bool -> Ty1 -> PassM Ty2 +convertTy ddefs useSoA ty = case useSoA of + False -> traverse (const (freshLocVar "loc")) ty + True -> case ty of + PackedTy tycon _ -> do + dconBuff <- freshLocVar "loc" + let ddef = lookupDDef ddefs tycon + let dcons = getConOrdering ddefs tycon + locsForFields <- convertTyHelperSoAParent tycon ddefs dcons + let soaLocation = SoA (unwrapLocVar dconBuff) locsForFields + return $ PackedTy tycon soaLocation + _ -> traverse (const (freshLocVar "loc")) ty + +convertTyHelperSoAParent :: TyCon -> DDefs1 -> [DataCon] -> PassM [((DataCon, Int), Var)] +convertTyHelperSoAParent tycon ddefs dcons = do + case dcons of + [] -> return [] + d:rst -> do + out <- convertTyHelperSoAChild tycon ddefs d + outRst <- convertTyHelperSoAParent tycon ddefs rst + return $ out ++ outRst + + + + +convertTyHelperSoAChild :: TyCon -> DDefs1 -> DataCon -> PassM [((DataCon, Int), Var)] +convertTyHelperSoAChild tycon ddefs dcon = do + let fields = lookupDataCon ddefs dcon + let fields' = P.concatMap (\f -> case f of + PackedTy tycon' _ -> if tycon == tycon' + then [] + else [f] + _ -> [f] + + ) fields + let numFields = L.length fields' + let indices = [0 .. numFields] + let namesOfLocs = P.map show fields' + let zipped = zip namesOfLocs indices + out <- convertTyHelperGetLocForField dcon zipped + return out + +convertTyHelperGetLocForField :: DataCon -> [(String, Int)] -> PassM [((DataCon, Int), Var)] +convertTyHelperGetLocForField dcon zipped = do + case zipped of + [] -> return [] + (x, y):xs -> do + elem <- convertTyHelperGetLocForField' dcon y x + rst <- convertTyHelperGetLocForField dcon xs + return $ [elem] ++ rst + +convertTyHelperGetLocForField' :: DataCon -> Int -> String -> PassM ((DataCon, Int), Var) +convertTyHelperGetLocForField' dcon index nameForLoc = do + loc' <- freshLocVar $ "loc_" ++ nameForLoc + return ((dcon, index), (unwrapLocVar loc')) + + convertDDefs :: DDefs Ty1 -> PassM (DDefs Ty2) convertDDefs ddefs = traverse f ddefs where f (DDef tyargs n dcs) = do dcs' <- forM dcs $ \(dc,bnds) -> do bnds' <- forM bnds $ \(isb,ty) -> do - ty' <- convertTy ty + ty' <- convertTy ddefs False ty return (isb, ty') return (dc,bnds') return $ DDef tyargs n dcs' @@ -224,13 +282,13 @@ inferLocs initPrg = do dfs' <- lift $ lift $ convertDDefs dfs fenv <- forM fds $ \(FunDef _ _ (intys, outty) bod _meta) -> do let has_par = hasSpawns bod - lift $ lift $ convertFunTy (intys,outty,has_par) + lift $ lift $ convertFunTy dfs (intys,outty,has_par) let fe = FullEnv dfs' M.empty fenv me' <- case me of -- We ignore the type of the main expression inferred in L1.. -- Probably should add a small check here Just (me,_ty) -> do - (me',ty') <- inferExp' fe me [] NoDest + (me',ty') <- inferExp' dfs fe me [] NoDest return $ Just (me',ty') Nothing -> return Nothing fds' <- forM fds $ \(FunDef fn fa (intty,outty) fbod meta) -> do @@ -239,7 +297,7 @@ inferLocs initPrg = do boundLocs = concat $ map locsInTy (arrIns arrty ++ [arrOut arrty]) dest <- destFromType (arrOut arrty) mapM_ fixType_ (arrIns arrty) - (fbod',_) <- inferExp' fe' fbod boundLocs dest + (fbod',_) <- inferExp' dfs fe' fbod boundLocs dest return $ FunDef fn fa arrty fbod' meta return $ Prog dfs' fds' me' prg <- St.runStateT (runExceptT m) M.empty @@ -291,8 +349,8 @@ fixType_ ty = _ -> return () -- | Wrap the inferExp procedure, and consume all remaining constraints -inferExp' :: FullEnv -> Exp1 -> [LocVar] -> Dest -> TiM (L2.Exp2, L2.Ty2) -inferExp' env exp bound dest= +inferExp' :: DDefs1 -> FullEnv -> Exp1 -> [LocVar] -> Dest -> TiM (L2.Exp2, L2.Ty2) +inferExp' ddefs env exp bound dest= let -- TODO: These should not be necessary, eventually @@ -326,7 +384,7 @@ inferExp' env exp bound dest= in LetE (v',[],copyRetTy, AppE f lvs [VarE v1]) $ Ext (LetLocE lv1 (AfterVariableLE v' lv2 True) a') - in do res <- inferExp env exp dest + in do res <- inferExp ddefs env exp dest (e,ty,cs) <- bindAllLocations res e' <- finishExp e let (e'',s) = cleanExp e' @@ -336,8 +394,8 @@ inferExp' env exp bound dest= -- | We proceed in a destination-passing style given the target region -- into which we must produce the resulting value. -inferExp :: FullEnv -> Exp1 -> Dest -> TiM Result -inferExp env@FullEnv{dataDefs} ex0 dest = +inferExp :: DDefs1 -> FullEnv -> Exp1 -> Dest -> TiM Result +inferExp ddefs env@FullEnv{dataDefs} ex0 dest = let -- | Check if there are any StartRegion constraints that can be dischaged here. @@ -533,7 +591,7 @@ inferExp env@FullEnv{dataDefs} ex0 dest = let contys = lookupDataCon ddfs con newtys = L.map (\(ty,(_,lv)) -> fmap (const lv) ty) $ zip contys vars' env' = L.foldr (\(v,ty) a -> extendVEnv v ty a) env $ zip (L.map fst vars') newtys - res <- inferExp env' rhs dst + res <- inferExp ddefs env' rhs dst (rhs',ty',cs') <- bindAfterLocs (orderOfVarsOutputDataConE rhs) res -- let cs'' = removeLocs (L.map snd vars') cs' -- TODO: check constraints are correct and fail/repair if they're not!!! @@ -568,7 +626,7 @@ inferExp env@FullEnv{dataDefs} ex0 dest = ProjE i w -> do (e', ty) <- case w of VarE v -> pure (ProjE i (VarE v), let ProdTy tys = lookupVEnv v env in tys !! i) - w' -> (\(e, ProdTy bs, _) -> (ProjE i e, bs !! i)) <$> inferExp env w dest + w' -> (\(e, ProdTy bs, _) -> (ProjE i e, bs !! i)) <$> inferExp ddefs env w dest case dest of NoDest -> return (e', ty, []) TupleDest ds -> err "TODO: handle tuple of destinations for ProjE" @@ -585,24 +643,24 @@ inferExp env@FullEnv{dataDefs} ex0 dest = MkProdE ls -> case dest of - NoDest -> do results <- mapM (\e -> inferExp env e NoDest) ls + NoDest -> do results <- mapM (\e -> inferExp ddefs env e NoDest) ls let pty = case results of [(_,ty,_)] -> ty _ -> ProdTy ([b | (_,b,_) <- results]) return (MkProdE ([a | (a,_,_) <- results]), pty, concat $ [c | (_,_,c) <- results]) SingleDest d -> case ls of - [e] -> do (e',ty,les) <- inferExp env e dest + [e] -> do (e',ty,les) <- inferExp ddefs env e dest return (MkProdE [e'], ty, les) _ -> err $ "Cannot match single destination to tuple: " ++ show ex0 - TupleDest ds -> do results <- mapM (\(e,d) -> inferExp env e d) $ zip ls ds + TupleDest ds -> do results <- mapM (\(e,d) -> inferExp ddefs env e d) $ zip ls ds return (MkProdE ([a | (a,_,_) <- results]), ProdTy ([b | (_,b,_) <- results]), concat $ [c | (_,_,c) <- results]) SpawnE f _ args -> do - (ex0', ty, acs) <- inferExp env (AppE f [] args) dest + (ex0', ty, acs) <- inferExp ddefs env (AppE f [] args) dest case ex0' of AppE f' locs args' -> pure (SpawnE f' locs args', ty, acs) oth -> err $ "SpawnE: " ++ sdoc oth @@ -621,7 +679,7 @@ inferExp env@FullEnv{dataDefs} ex0 dest = -- /cc @vollmerm argTys <- mapM freshTyLocs $ arrIns arrty argDests <- mapM destFromType' argTys - (args', atys, acss) <- L.unzip3 <$> mapM (uncurry $ inferExp env) (zip args argDests) + (args', atys, acss) <- L.unzip3 <$> mapM (uncurry $ inferExp ddefs env) (zip args argDests) let acs = concat acss case dest of SingleDest d -> do @@ -642,11 +700,11 @@ inferExp env@FullEnv{dataDefs} ex0 dest = _ -> err$ "(AppE) Cannot unify NoDest with " ++ sdoc valTy ++ ". This might be caused by a main expression having a packed type." ++ sdoc ex0 TimeIt e t b -> - do (e',ty',cs') <- inferExp env e dest + do (e',ty',cs') <- inferExp ddefs env e dest return (TimeIt e' ty' b, ty', cs') WithArenaE v e -> - do (e',ty',cs') <- inferExp (extendVEnv v ArenaTy env) e dest + do (e',ty',cs') <- inferExp ddefs (extendVEnv v ArenaTy env) e dest return (WithArenaE v e', ty', cs') DataConE () k [] -> do @@ -663,14 +721,14 @@ inferExp env@FullEnv{dataDefs} ex0 dest = NoDest -> do -- CSK: Should this really be an error ? loc <- lift $ lift $ freshLocVar "datacon" - (e',ty,cs) <- inferExp env (DataConE () k ls) (SingleDest loc) + (e',ty,cs) <- inferExp ddefs env (DataConE () k ls) (SingleDest loc) fcs <- tryInRegion cs tryBindReg (e', ty, fcs) TupleDest _ds -> err $ "Expected single location destination for DataConE" ++ sdoc ex0 SingleDest d -> do locs <- sequence $ replicate (length ls) fresh mapM_ fixLoc locs -- Don't allow argument locations to freely unify - ls' <- mapM (\(e,lv) -> (inferExp env e $ SingleDest lv)) $ zip ls locs + ls' <- mapM (\(e,lv) -> (inferExp ddefs env e $ SingleDest lv)) $ zip ls locs -- let ls'' = L.map unNestLet ls' -- bnds = catMaybes $ L.map pullBnds ls' -- env' = addCopyVarToEnv ls' env @@ -745,17 +803,17 @@ inferExp env@FullEnv{dataDefs} ex0 dest = IfE a b c@ce -> do -- Here we blithely assume BoolTy because L1 typechecking has already passed: - (a',bty,acs) <- inferExp env a NoDest + (a',bty,acs) <- inferExp ddefs env a NoDest assumeEq bty BoolTy -- Here BOTH branches are unified into the destination, so -- there is no need to unify with eachother. - res <- inferExp env b dest + res <- inferExp ddefs env b dest -- bind variables after if branch -- This ensures that the location bindings are not freely floated up to the upper level expressions (b',tyb,csb) <- bindAfterLocs (removeDuplicates (orderOfVarsOutputDataConE b)) res -- Else branch - res' <- inferExp env c dest + res' <- inferExp ddefs env c dest -- bind variables after else branch -- This ensures that the location bindings are not freely floated up to the upper level expressions (c',tyc,csc) <- bindAfterLocs (removeDuplicates (orderOfVarsOutputDataConE c)) res' @@ -766,22 +824,22 @@ inferExp env@FullEnv{dataDefs} ex0 dest = case dest of SingleDest _ -> err "Cannot unify DictInsert with destination" TupleDest _ -> err "Cannot unify DictInsert with destination" - NoDest -> do (d',SymDictTy ar dty',_dcs) <- inferExp env d NoDest - (k',_,_kcs) <- inferExp env k NoDest - dty'' <- lift $ lift $ convertTy dty + NoDest -> do (d',SymDictTy ar dty',_dcs) <- inferExp ddefs env d NoDest + (k',_,_kcs) <- inferExp ddefs env k NoDest + dty'' <- lift $ lift $ convertTy ddefs False dty r <- lift $ lift $ freshRegVar loc <- lift $ lift $ freshLocVar "ins" -- _ <- fixLoc loc - (v',vty,vcs) <- inferExp env v $ SingleDest loc + (v',vty,vcs) <- inferExp ddefs env v $ SingleDest loc let cs = vcs -- (StartRegionL loc r) : vcs dummyDty <- dummyTyLocs dty' return (PrimAppE (DictInsertP dummyDty) [(VarE var),d',k',v'], SymDictTy (Just var) $ stripTyLocs dty'', cs) PrimAppE (DictLookupP dty) [d,k] -> case dest of - SingleDest loc -> do (d',SymDictTy _ _dty,_dcs) <- inferExp env d NoDest - (k',_,_kcs) <- inferExp env k NoDest - dty' <- lift $ lift $ convertTy dty + SingleDest loc -> do (d',SymDictTy _ _dty,_dcs) <- inferExp ddefs env d NoDest + (k',_,_kcs) <- inferExp ddefs env k NoDest + dty' <- lift $ lift $ convertTy ddefs False dty let loc' = locOfTy dty' _ <- fixLoc loc' let e' = PrimAppE (DictLookupP dty') [d',k'] @@ -796,15 +854,15 @@ inferExp env@FullEnv{dataDefs} ex0 dest = case dest of SingleDest _ -> err "Cannot unify DictEmpty with destination" TupleDest _ -> err "Cannot unify DictEmpty with destination" - NoDest -> do dty' <- lift $ lift $ convertTy dty + NoDest -> do dty' <- lift $ lift $ convertTy ddefs False dty return (PrimAppE (DictEmptyP dty') [(VarE var)], SymDictTy (Just var) $ stripTyLocs dty', []) PrimAppE (DictHasKeyP dty) [d,k] -> case dest of SingleDest _ -> err "Cannot unify DictEmpty with destination" TupleDest _ -> err "Cannot unify DictEmpty with destination" - NoDest -> do (d',SymDictTy _ dty',_dcs) <- inferExp env d NoDest - (k',_,_kcs) <- inferExp env k NoDest + NoDest -> do (d',SymDictTy _ dty',_dcs) <- inferExp ddefs env d NoDest + (k',_,_kcs) <- inferExp ddefs env k NoDest dummyDty <- dummyTyLocs dty' return (PrimAppE (DictHasKeyP dummyDty) [d',k'], BoolTy, []) @@ -814,11 +872,11 @@ inferExp env@FullEnv{dataDefs} ex0 dest = case dest of SingleDest d -> err $ "Cannot unify primop " ++ sdoc pr ++ " with destination " ++ sdoc d TupleDest d -> err $ "Cannot unify primop " ++ sdoc pr ++ " with destination " ++ sdoc d - NoDest -> do results <- mapM (\e -> inferExp env e NoDest) [VarE ls] + NoDest -> do results <- mapM (\e -> inferExp ddefs env e NoDest) [VarE ls] -- Assume arguments to PrimAppE are trivial -- so there's no need to deal with constraints or locations - ty <- lift $ lift $ convertTy $ primRetTy pr - pr' <- lift $ lift $ prim pr + ty <- lift $ lift $ convertTy ddefs False $ primRetTy pr + pr' <- lift $ lift $ prim ddefs pr let args = [a | (a,_,_) <- results] ++ [VarE fp] return (PrimAppE pr' args, ty, []) @@ -827,23 +885,23 @@ inferExp env@FullEnv{dataDefs} ex0 dest = SingleDest d -> err $ "Cannot unify primop " ++ sdoc pr ++ " with destination " ++ sdoc dest ++ "in " ++ sdoc ex0 TupleDest d -> case pr of - PrintInt -> inferExp env ex0 NoDest - PrintFloat -> inferExp env ex0 NoDest - PrintBool -> inferExp env ex0 NoDest - PrintSym -> inferExp env ex0 NoDest - VNthP{} -> inferExp env ex0 NoDest + PrintInt -> inferExp ddefs env ex0 NoDest + PrintFloat -> inferExp ddefs env ex0 NoDest + PrintBool -> inferExp ddefs env ex0 NoDest + PrintSym -> inferExp ddefs env ex0 NoDest + VNthP{} -> inferExp ddefs env ex0 NoDest _ -> err $ "Cannot unify primop " ++ sdoc pr ++ " with destination " ++ sdoc dest ++ "in " ++ sdoc ex0 - NoDest -> do results <- mapM (\e -> inferExp env e NoDest) es + NoDest -> do results <- mapM (\e -> inferExp ddefs env e NoDest) es -- Assume arguments to PrimAppE are trivial -- so there's no need to deal with constraints or locations - ty <- lift $ lift $ convertTy $ primRetTy pr - pr' <- lift $ lift $ prim pr + ty <- lift $ lift $ convertTy ddefs False $ primRetTy pr + pr' <- lift $ lift $ prim ddefs pr return (PrimAppE pr' [a | (a,_,_) <- results], ty, []) CaseE ex ls -> do -- Case expressions introduce fresh destinations for the scrutinee: loc <- lift $ lift $ freshLocVar "scrut" - (ex',ty2,cs) <- inferExp env ex (SingleDest loc) + (ex',ty2,cs) <- inferExp ddefs env ex (SingleDest loc) let src = locOfTy ty2 pairs <- mapM (doCase dataDefs env src dest) ls return (CaseE ex' ([a | (a,_,_) <- pairs]), @@ -863,10 +921,10 @@ inferExp env@FullEnv{dataDefs} ex0 dest = -- /cc @vollmerm argTys <- mapM freshTyLocs $ arrIns arrty argDests <- mapM destFromType' argTys - (args', atys, acss) <- L.unzip3 <$> mapM (uncurry $ inferExp env) (zip args argDests) + (args', atys, acss) <- L.unzip3 <$> mapM (uncurry $ inferExp ddefs env) (zip args argDests) let acs = concat acss tupBod <- projTups valTy (VarE vr) bod - res <- inferExp (extendVEnv vr valTy env) tupBod dest + res <- inferExp ddefs (extendVEnv vr valTy env) tupBod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr res vcs <- tryNeedRegion (locsInTy valTy) ty'' $ acs ++ cs'' fcs <- tryInRegion vcs @@ -880,7 +938,7 @@ inferExp env@FullEnv{dataDefs} ex0 dest = let _ret_ty = arrOut $ lookupFEnv f env -- if isScalarTy ret_ty || isPackedTy ret_ty -- then do - (ex0', ty, cs) <- inferExp env (LetE (vr,locs,bty,(AppE f [] args)) bod) dest + (ex0', ty, cs) <- inferExp ddefs env (LetE (vr,locs,bty,(AppE f [] args)) bod) dest -- Assume that all args are VarE's let args2 = map (\e -> case e of (VarE v) -> VarE v @@ -894,36 +952,36 @@ inferExp env@FullEnv{dataDefs} ex0 dest = pure (ex0'', ty, cs) SyncE -> do - (bod',ty,cs) <- inferExp env bod dest + (bod',ty,cs) <- inferExp ddefs env bod dest pure (LetE (vr,[],ProdTy [],SyncE) bod', ty, cs) IfE a b c -> do - (boda,tya,csa) <- inferExp env a NoDest + (boda,tya,csa) <- inferExp ddefs env a NoDest -- just assuming tyb == tyc - res <- inferExp env b NoDest + res <- inferExp ddefs env b NoDest (bodb,tyb,csb) <- bindAfterLocs (removeDuplicates (orderOfVarsOutputDataConE b)) res - res' <- inferExp env c NoDest + res' <- inferExp ddefs env c NoDest (bodc,tyc,csc) <- bindAfterLocs (removeDuplicates (orderOfVarsOutputDataConE c)) res' - (bod',ty',cs') <- inferExp (extendVEnv vr tyc env) bod dest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr tyc env) bod dest let cs = L.nub $ csa ++ csb ++ csc ++ cs' return (L2.LetE (vr,[],tyc,L2.IfE boda bodb bodc) bod', ty', cs) LetE{} -> err $ "Expected let spine, encountered nested lets: " ++ sdoc ex0 LitE i -> do - (bod',ty',cs') <- inferExp (extendVEnv vr IntTy env) bod dest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr IntTy env) bod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr (bod', ty', cs') fcs <- tryInRegion cs'' tryBindReg (L2.LetE (vr,[],IntTy,L2.LitE i) bod'', ty'', fcs) CharE i -> do - (bod',ty',cs') <- inferExp (extendVEnv vr CharTy env) bod dest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr CharTy env) bod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr (bod', ty', cs') fcs <- tryInRegion cs'' tryBindReg (L2.LetE (vr,[],CharTy,L2.CharE i) bod'', ty'', fcs) FloatE i -> do - (bod',ty',cs') <- inferExp (extendVEnv vr FloatTy env) bod dest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr FloatTy env) bod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr (bod', ty', cs') fcs <- tryInRegion cs'' tryBindReg (L2.LetE (vr,[],FloatTy,L2.FloatE i) bod'', ty'', fcs) @@ -933,7 +991,7 @@ inferExp env@FullEnv{dataDefs} ex0 dest = r <- lift $ lift $ gensym "r" loc <- lift $ lift $ freshLocVar "mmap_file" let rhs' = PrimAppE (ReadPackedFile fp tycon (Just r) (PackedTy tycon loc)) [] - (bod',ty',cs') <- inferExp (extendVEnv vr (PackedTy tycon loc) env) bod dest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr (PackedTy tycon loc) env) bod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr (bod', ty', cs') fcs <- tryInRegion cs' tryBindReg ( Ext$ LetRegionE (MMapR r) Undefined Nothing $ Ext $ LetLocE loc (StartOfRegionLE (MMapR r)) $ @@ -942,8 +1000,8 @@ inferExp env@FullEnv{dataDefs} ex0 dest = PrimAppE (WritePackedFile fp _ty0) [VarE packd] -> do - bty' <- lift $ lift $ convertTy bty - (bod',ty',cs') <- inferExp (extendVEnv vr bty' env) bod dest + bty' <- lift $ lift $ convertTy ddefs False bty + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr bty' env) bod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr (bod', ty', cs') fcs <- tryInRegion cs'' let (PackedTy tycon loc) = lookupVEnv packd env @@ -956,43 +1014,43 @@ inferExp env@FullEnv{dataDefs} ex0 dest = PrimAppE (ReadArrayFile fp ty0) [] -> do - ty <- lift $ lift $ convertTy bty - ty0' <- lift $ lift $ convertTy ty0 - (bod',ty',cs') <- inferExp (extendVEnv vr ty env) bod dest + ty <- lift $ lift $ convertTy ddefs False bty + ty0' <- lift $ lift $ convertTy ddefs False ty0 + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr ty env) bod dest (bod'',ty'',cs''') <- handleTrailingBindLoc vr (bod', ty', cs') fcs <- tryInRegion cs''' tryBindReg (L2.LetE (vr,[],ty, L2.PrimAppE (ReadArrayFile fp ty0') []) bod'', ty'', fcs) -- Don't process the StartOf or SizeOf operation at all, just recur through it PrimAppE RequestSizeOf [(VarE v)] -> do - (bod',ty',cs') <- inferExp (extendVEnv vr CursorTy env) bod dest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr CursorTy env) bod dest return (L2.LetE (vr,[],IntTy, L2.PrimAppE RequestSizeOf [(L2.VarE v)]) bod', ty', cs') PrimAppE (DictInsertP dty) ls -> do - (e,ty,cs) <- inferExp env (PrimAppE (DictInsertP dty) ls) NoDest - (bod',ty',cs') <- inferExp (extendVEnv vr ty env) bod dest + (e,ty,cs) <- inferExp ddefs env (PrimAppE (DictInsertP dty) ls) NoDest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr ty env) bod dest (bod'',ty'',cs''') <- handleTrailingBindLoc vr (bod',ty', L.nub $ cs' ++ cs) fcs <- tryInRegion cs''' tryBindReg (L2.LetE (vr,[],ty,e) bod'', ty'', fcs) PrimAppE (DictLookupP dty) ls -> do loc <- lift $ lift $ freshLocVar "dict" - (e,ty,cs) <- inferExp env (PrimAppE (DictLookupP dty) ls) $ SingleDest loc - (bod',ty',cs') <- inferExp (extendVEnv vr ty env) bod dest + (e,ty,cs) <- inferExp ddefs env (PrimAppE (DictLookupP dty) ls) $ SingleDest loc + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr ty env) bod dest (bod'',ty'',cs''') <- handleTrailingBindLoc vr (bod', ty', L.nub $ cs ++ cs') fcs <- tryInRegion cs''' tryBindReg (L2.LetE (vr,[],ty,e) bod'',ty'', fcs) PrimAppE (DictEmptyP dty) ls -> do - (e,ty,cs) <- inferExp env (PrimAppE (DictEmptyP dty) ls) NoDest - (bod',ty',cs') <- inferExp (extendVEnv vr ty env) bod dest + (e,ty,cs) <- inferExp ddefs env (PrimAppE (DictEmptyP dty) ls) NoDest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr ty env) bod dest (bod'',ty'',cs''') <- handleTrailingBindLoc vr (bod',ty',L.nub $ cs' ++ cs) fcs <- tryInRegion cs''' tryBindReg (L2.LetE (vr,[],ty,e) bod'', ty'', fcs) PrimAppE (DictHasKeyP dty) ls -> do - (e,ty,cs) <- inferExp env (PrimAppE (DictHasKeyP dty) ls) NoDest - (bod',ty',cs') <- inferExp (extendVEnv vr ty env) bod dest + (e,ty,cs) <- inferExp ddefs env (PrimAppE (DictHasKeyP dty) ls) NoDest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr ty env) bod dest (bod'',ty'',cs''') <- handleTrailingBindLoc vr (bod',ty',L.nub $ cs' ++ cs) fcs <- tryInRegion cs''' tryBindReg (L2.LetE (vr,[],ty,e) bod'', ty'', fcs) @@ -1000,45 +1058,45 @@ inferExp env@FullEnv{dataDefs} ex0 dest = -- Special case for VSortP because we don't want to lookup fp in -- the type environment. PrimAppE p@(VSortP ty) [VarE ls, VarE fp] -> do - lsrec <- mapM (\e -> inferExp env e NoDest) [VarE ls] - ty <- lift $ lift $ convertTy bty - (bod',ty',cs') <- inferExp (extendVEnv vr ty env) bod dest + lsrec <- mapM (\e -> inferExp ddefs env e NoDest) [VarE ls] + ty <- lift $ lift $ convertTy ddefs False bty + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr ty env) bod dest let ls' = [a | (a,_,_) <- lsrec] ++ [VarE fp] cs'' = concat $ [c | (_,_,c) <- lsrec] (bod'',ty'',cs''') <- handleTrailingBindLoc vr (bod', ty', L.nub $ cs' ++ cs'') fcs <- tryInRegion cs''' - p' <- lift $ lift $ prim p + p' <- lift $ lift $ prim ddefs p tryBindReg (L2.LetE (vr,[],ty, L2.PrimAppE p' ls') bod'', ty'', fcs) PrimAppE p ls -> do - lsrec <- mapM (\e -> inferExp env e NoDest) ls - ty <- lift $ lift $ convertTy bty - (bod',ty',cs') <- inferExp (extendVEnv vr ty env) bod dest + lsrec <- mapM (\e -> inferExp ddefs env e NoDest) ls + ty <- lift $ lift $ convertTy ddefs False bty + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr ty env) bod dest let ls' = [a | (a,_,_) <- lsrec] cs'' = concat $ [c | (_,_,c) <- lsrec] (bod'',ty'',cs''') <- handleTrailingBindLoc vr (bod', ty', L.nub $ cs' ++ cs'') fcs <- tryInRegion cs''' - p' <- lift $ lift $ prim p + p' <- lift $ lift $ prim ddefs p tryBindReg (L2.LetE (vr,[],ty, L2.PrimAppE p' ls') bod'', ty'', fcs) DataConE _loc k ls -> do loc <- lift $ lift $ freshLocVar "datacon" - (rhs',rty,rcs) <- inferExp env (DataConE () k ls) $ SingleDest loc - (bod',ty',cs') <- inferExp (extendVEnv vr (PackedTy (getTyOfDataCon dataDefs k) loc) env) bod dest + (rhs',rty,rcs) <- inferExp ddefs env (DataConE () k ls) $ SingleDest loc + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr (PackedTy (getTyOfDataCon dataDefs k) loc) env) bod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr (bod', ty', L.nub $ cs' ++ rcs) fcs <- tryInRegion cs'' tryBindReg (L2.LetE (vr,[],PackedTy (getTyOfDataCon dataDefs k) loc,rhs') bod'', ty', fcs) LitSymE x -> do - (bod',ty',cs') <- inferExp (extendVEnv vr IntTy env) bod dest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr IntTy env) bod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr (bod', ty', cs') fcs <- tryInRegion cs'' tryBindReg (L2.LetE (vr,[],SymTy,L2.LitSymE x) bod'', ty'', fcs) ProjE i arg -> do - (e,ProdTy tys,cs) <- inferExp env arg NoDest - (bod',ty',cs') <- inferExp (extendVEnv vr (tys !! i) env) bod dest + (e,ProdTy tys,cs) <- inferExp ddefs env arg NoDest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr (tys !! i) env) bod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr (bod', ty', L.nub $ cs ++ cs') fcs <- tryInRegion cs'' tryBindReg (L2.LetE (vr,[],tys !! i,L2.ProjE i e) bod'', @@ -1046,12 +1104,12 @@ inferExp env@FullEnv{dataDefs} ex0 dest = CaseE ex ls -> do loc <- lift $ lift $ freshLocVar "scrut" - (ex',ty2,cs) <- inferExp env ex (SingleDest loc) + (ex',ty2,cs) <- inferExp ddefs env ex (SingleDest loc) let src = locOfTy ty2 - rhsTy <- lift $ lift $ convertTy bty + rhsTy <- lift $ lift $ convertTy ddefs False bty caseDest <- destFromType' rhsTy pairs <- mapM (doCase dataDefs env src caseDest) ls - (bod',ty',cs') <- inferExp (extendVEnv vr rhsTy env) bod dest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr rhsTy env) bod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr (bod', ty', cs') fcs <- tryInRegion cs'' let ccs = L.nub $ cs ++ fcs ++ (concat $ [c | (_,_,c) <- pairs]) @@ -1066,10 +1124,10 @@ inferExp env@FullEnv{dataDefs} ex0 dest = -- there's an assumption that things in a MkProdE will always be a -- variable reference (because of ANF), and the AppE/DataConE cases -- above will do the right thing. - lsrec <- mapM (\e -> inferExp env e NoDest) ls - ty@(ProdTy tys) <- lift $ lift $ convertTy bty + lsrec <- mapM (\e -> inferExp ddefs env e NoDest) ls + ty@(ProdTy tys) <- lift $ lift $ convertTy ddefs False bty let env' = extendVEnv vr ty env - (bod',ty',cs') <- inferExp env' bod dest + (bod',ty',cs') <- inferExp ddefs env' bod dest let als = [a | (a,_,_) <- lsrec] acs = concat $ [c | (_,_,c) <- lsrec] aty = [b | (_,b,_) <- lsrec] @@ -1093,8 +1151,8 @@ inferExp env@FullEnv{dataDefs} ex0 dest = tryBindReg (L2.LetE bind bod'', ty'', fcs) WithArenaE v e -> do - (e',ty,cs) <- inferExp (extendVEnv v ArenaTy env) e NoDest - (bod',ty',cs') <- inferExp (extendVEnv vr ty env) bod dest + (e',ty,cs) <- inferExp ddefs (extendVEnv v ArenaTy env) e NoDest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr ty env) bod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr (bod', ty', L.nub $ cs ++ cs') vcs <- tryNeedRegion (locsInTy ty) ty'' cs'' fcs <- tryInRegion vcs @@ -1106,8 +1164,8 @@ inferExp env@FullEnv{dataDefs} ex0 dest = let subdest = case bty of PackedTy _ _ -> SingleDest lv _ -> NoDest - (e',ty,cs) <- inferExp env e subdest - (bod',ty',cs') <- inferExp (extendVEnv vr ty env) bod dest + (e',ty,cs) <- inferExp ddefs env e subdest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr ty env) bod dest (bod'',ty'',cs'') <- handleTrailingBindLoc vr (bod', ty', L.nub $ cs ++ cs') vcs <- tryNeedRegion (locsInTy ty) ty'' cs'' fcs <- tryInRegion vcs @@ -1118,11 +1176,11 @@ inferExp env@FullEnv{dataDefs} ex0 dest = FoldE{} -> err$ "FoldE unsupported" Ext (L1.AddFixed cur i) -> do - (bod',ty',cs') <- inferExp (extendVEnv vr CursorTy env) bod dest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr CursorTy env) bod dest return (L2.LetE (vr,[],L2.CursorTy,L2.Ext (L2.AddFixed cur i)) bod', ty', cs') Ext (L1.StartOfPkdCursor cur) -> do - (bod',ty',cs') <- inferExp (extendVEnv vr CursorTy env) bod dest + (bod',ty',cs') <- inferExp ddefs (extendVEnv vr CursorTy env) bod dest return (L2.LetE (vr,[],L2.CursorTy,L2.Ext (L2.StartOfPkdCursor cur)) bod', ty', cs') Ext(BenchE{}) -> error "inferExp: BenchE not handled." @@ -1136,7 +1194,7 @@ inferExp env@FullEnv{dataDefs} ex0 dest = retty :: Ty2 retty = outTy fn_ty e' = TimeIt (AppE fn locs args) (stripTyLocs retty) b - in inferExp env e' dest + in inferExp ddefs env e' dest -- TODO: Should eventually allow src and dest regions to be the same @@ -1678,8 +1736,8 @@ assumeEq a1 a2 = else err $ "Expected these to be equal: " ++ (show a1) ++ ", " ++ (show a2) -- | Convert a prim from L1 to L2 -prim :: Prim Ty1 -> PassM (Prim Ty2) -prim p = case p of +prim :: DDefs1 -> Prim Ty1 -> PassM (Prim Ty2) +prim ddefs p = case p of AddP -> return AddP SubP -> return SubP MulP -> return MulP @@ -1724,40 +1782,40 @@ prim p = case p of PrintSym -> return PrintSym ReadInt -> return PrintInt RequestSizeOf -> return RequestSizeOf - ErrorP sty ty -> convertTy ty >>= \ty -> return (ErrorP sty ty) - DictEmptyP dty -> convertTy dty >>= return . DictEmptyP - DictInsertP dty -> convertTy dty >>= return . DictInsertP - DictLookupP dty -> convertTy dty >>= return . DictLookupP - DictHasKeyP dty -> convertTy dty >>= return . DictHasKeyP - VAllocP elty -> convertTy elty >>= return . VAllocP - VFreeP elty -> convertTy elty >>= return . VFreeP - VFree2P elty -> convertTy elty >>= return . VFree2P - VLengthP elty -> convertTy elty >>= return . VLengthP - VNthP elty -> convertTy elty >>= return . VNthP - VSliceP elty -> convertTy elty >>= return . VSliceP - InplaceVUpdateP elty -> convertTy elty >>= return . InplaceVUpdateP - VConcatP elty -> convertTy elty >>= return . VConcatP - VSortP elty -> convertTy elty >>= return . VSortP - VMergeP elty -> convertTy elty >>= return . VMergeP - PDictAllocP k v -> convertTy k >>= (\k' -> convertTy v >>= \v' -> return $ PDictAllocP k' v') - PDictInsertP k v -> convertTy k >>= (\k' -> convertTy v >>= \v' -> return $ PDictInsertP k' v') - PDictLookupP k v -> convertTy k >>= (\k' -> convertTy v >>= \v' -> return $ PDictLookupP k' v') - PDictHasKeyP k v -> convertTy k >>= (\k' -> convertTy v >>= \v' -> return $ PDictHasKeyP k' v') - PDictForkP k v -> convertTy k >>= (\k' -> convertTy v >>= \v' -> return $ PDictForkP k' v') - PDictJoinP k v -> convertTy k >>= (\k' -> convertTy v >>= \v' -> return $ PDictJoinP k' v') - LLAllocP elty -> convertTy elty >>= return . LLAllocP - LLIsEmptyP elty -> convertTy elty >>= return . LLIsEmptyP - LLConsP elty -> convertTy elty >>= return . LLConsP - LLHeadP elty -> convertTy elty >>= return . LLHeadP - LLTailP elty -> convertTy elty >>= return . LLTailP - LLFreeP elty -> convertTy elty >>= return . LLFreeP - LLFree2P elty -> convertTy elty >>= return . LLFree2P - LLCopyP elty -> convertTy elty >>= return . LLCopyP - InplaceVSortP elty -> convertTy elty >>= return . InplaceVSortP + ErrorP sty ty -> convertTy ddefs False ty >>= \ty -> return (ErrorP sty ty) + DictEmptyP dty -> convertTy ddefs False dty >>= return . DictEmptyP + DictInsertP dty -> convertTy ddefs False dty >>= return . DictInsertP + DictLookupP dty -> convertTy ddefs False dty >>= return . DictLookupP + DictHasKeyP dty -> convertTy ddefs False dty >>= return . DictHasKeyP + VAllocP elty -> convertTy ddefs False elty >>= return . VAllocP + VFreeP elty -> convertTy ddefs False elty >>= return . VFreeP + VFree2P elty -> convertTy ddefs False elty >>= return . VFree2P + VLengthP elty -> convertTy ddefs False elty >>= return . VLengthP + VNthP elty -> convertTy ddefs False elty >>= return . VNthP + VSliceP elty -> convertTy ddefs False elty >>= return . VSliceP + InplaceVUpdateP elty -> convertTy ddefs False elty >>= return . InplaceVUpdateP + VConcatP elty -> convertTy ddefs False elty >>= return . VConcatP + VSortP elty -> convertTy ddefs False elty >>= return . VSortP + VMergeP elty -> convertTy ddefs False elty >>= return . VMergeP + PDictAllocP k v -> convertTy ddefs False k >>= (\k' -> convertTy ddefs False v >>= \v' -> return $ PDictAllocP k' v') + PDictInsertP k v -> convertTy ddefs False k >>= (\k' -> convertTy ddefs False v >>= \v' -> return $ PDictInsertP k' v') + PDictLookupP k v -> convertTy ddefs False k >>= (\k' -> convertTy ddefs False v >>= \v' -> return $ PDictLookupP k' v') + PDictHasKeyP k v -> convertTy ddefs False k >>= (\k' -> convertTy ddefs False v >>= \v' -> return $ PDictHasKeyP k' v') + PDictForkP k v -> convertTy ddefs False k >>= (\k' -> convertTy ddefs False v >>= \v' -> return $ PDictForkP k' v') + PDictJoinP k v -> convertTy ddefs False k >>= (\k' -> convertTy ddefs False v >>= \v' -> return $ PDictJoinP k' v') + LLAllocP elty -> convertTy ddefs False elty >>= return . LLAllocP + LLIsEmptyP elty -> convertTy ddefs False elty >>= return . LLIsEmptyP + LLConsP elty -> convertTy ddefs False elty >>= return . LLConsP + LLHeadP elty -> convertTy ddefs False elty >>= return . LLHeadP + LLTailP elty -> convertTy ddefs False elty >>= return . LLTailP + LLFreeP elty -> convertTy ddefs False elty >>= return . LLFreeP + LLFree2P elty -> convertTy ddefs False elty >>= return . LLFree2P + LLCopyP elty -> convertTy ddefs False elty >>= return . LLCopyP + InplaceVSortP elty -> convertTy ddefs False elty >>= return . InplaceVSortP GetNumProcessors -> pure GetNumProcessors ReadPackedFile{} -> err $ "Can't handle this primop yet in InferLocations:\n"++show p ReadArrayFile{} -> err $ "Can't handle this primop yet in InferLocations:\n"++show p - WritePackedFile fp ty -> convertTy ty >>= return . (WritePackedFile fp) + WritePackedFile fp ty -> convertTy ddefs False ty >>= return . (WritePackedFile fp) SymSetEmpty{} -> return SymSetEmpty SymSetInsert{} -> return SymSetInsert SymSetContains{} -> return SymSetContains