Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace pyrsistent.PMap, immutables.Map with immutabledict #884

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"pyopencl": ("https://documen.tician.de/pyopencl", None),
"cgen": ("https://documen.tician.de/cgen", None),
"pymbolic": ("https://documen.tician.de/pymbolic", None),
"pyrsistent": ("https://pyrsistent.readthedocs.io/en/latest/", None),
"immutabledict": ("https://immutabledict.corenting.fr/", None),
}

# Some modules need to import things just so that sphinx can resolve symbols in
Expand Down
8 changes: 4 additions & 4 deletions loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
Union,
)

from immutables import Map
from immutabledict import immutabledict

from loopy.codegen.result import CodeGenerationResult
from loopy.library.reduction import ReductionOpFunction
Expand Down Expand Up @@ -207,7 +207,7 @@ class CodeGenerationState:
seen_functions: Set[SeenFunction]
seen_atomic_dtypes: Set[LoopyType]

var_subst_map: Map[str, Expression]
var_subst_map: immutabledict[str, Expression]
allow_complex: bool
callables_table: CallablesTable
is_entrypoint: bool
Expand Down Expand Up @@ -418,7 +418,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
seen_dtypes=seen_dtypes,
seen_functions=seen_functions,
seen_atomic_dtypes=seen_atomic_dtypes,
var_subst_map=Map(),
var_subst_map=immutabledict(),
allow_complex=allow_complex,
var_name_generator=kernel.get_var_name_generator(),
is_generating_device_code=False,
Expand Down Expand Up @@ -519,7 +519,7 @@ def diverge_callee_entrypoints(program):

new_callables[name] = clbl

return program.copy(callables_table=Map(new_callables))
return program.copy(callables_table=immutabledict(new_callables))


@dataclass(frozen=True)
Expand Down
4 changes: 2 additions & 2 deletions loopy/frontend/fortran/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from warnings import warn

import numpy as np
from immutables import Map
from immutabledict import immutabledict

import islpy as isl
from islpy import dim_type
Expand Down Expand Up @@ -331,7 +331,7 @@ def specialize_fortran_division(t_unit):

new_callables[name] = clbl

return t_unit.copy(callables_table=Map(new_callables))
return t_unit.copy(callables_table=immutabledict(new_callables))

# }}}

Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from warnings import warn

import numpy as np
from immutables import Map
from immutabledict import immutabledict

import islpy as isl
from islpy import dim_type
Expand Down Expand Up @@ -178,7 +178,7 @@ class LoopKernel(Taggable):
Callable[["LoopKernel", str], Optional[Tuple[LoopyType, str]]]] = ()
linearization: Optional[Sequence[ScheduleItem]] = None
iname_slab_increments: Mapping[InameStr, Tuple[int, int]] = field(
default_factory=Map)
default_factory=immutabledict)
"""
A mapping from inames to (lower_incr,
upper_incr) tuples that will be separated out in the execution to generate
Expand Down
5 changes: 2 additions & 3 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
THE SOFTWARE.
"""


from collections.abc import Mapping
from dataclasses import dataclass, replace
from enum import IntEnum
from sys import intern
Expand All @@ -43,7 +43,6 @@

import numpy # FIXME: imported as numpy to allow sphinx to resolve things
import numpy as np
from immutables import Map

from pymbolic import ArithmeticExpression, Variable
from pytools import ImmutableRecord
Expand Down Expand Up @@ -434,7 +433,7 @@ class _ArraySeparationInfo:
should be used to realize this array.
"""
sep_axis_indices_set: FrozenSet[int]
subarray_names: Map[Tuple[int, ...], str]
subarray_names: Mapping[Tuple[int, ...], str]


class ArrayArg(ArrayBase, KernelArgument):
Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,7 +2084,7 @@ def get_call_graph(t_unit, only_kernel_callables=False):

:arg t_unit: An instance of :class:`TranslationUnit`.
"""
from pyrsistent import pmap
from immutabledict import immutabledict

from loopy.kernel import KernelState

Expand All @@ -2111,7 +2111,7 @@ def get_call_graph(t_unit, only_kernel_callables=False):
call_graph[name] = clbl.get_called_callables(t_unit.callables_table,
recursive=False)

return pmap(call_graph)
return immutabledict(call_graph)

# }}}

Expand Down
12 changes: 5 additions & 7 deletions loopy/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from functools import partial

import numpy as np
from immutables import Map
from immutabledict import immutabledict

from pytools import ProcessLogger

Expand Down Expand Up @@ -191,7 +191,7 @@ def make_arrays_for_sep_arrays(kernel: LoopKernel) -> LoopKernel:

sep_info = _ArraySeparationInfo(
sep_axis_indices_set=sep_axis_indices_set,
subarray_names=Map({
subarray_names=immutabledict({
ind: vng(f"{arg.name}_s{'_'.join(str(i) for i in ind)}")
for ind in np.ndindex(*cast(List[int], sep_shape))}))

Expand Down Expand Up @@ -599,8 +599,6 @@ def map_call_with_kwargs(self, expr):
raise NotImplementedError

def __call__(self, expr, kernel, insn, assignees=None):
import immutables

from loopy.kernel.data import InstructionBase
from loopy.symbolic import ExpansionState, UncachedIdentityMapper
assert insn is None or isinstance(insn, InstructionBase)
Expand All @@ -610,7 +608,7 @@ def __call__(self, expr, kernel, insn, assignees=None):
kernel=kernel,
instruction=insn,
stack=(),
arg_context=immutables.Map()), assignees=assignees)
arg_context=immutabledict()), assignees=assignees)

def map_kernel(self, kernel):

Expand Down Expand Up @@ -744,7 +742,7 @@ def filter_reachable_callables(t_unit):
t_unit.entrypoints)
new_callables = {name: clbl for name, clbl in t_unit.callables_table.items()
if name in (reachable_function_ids | t_unit.entrypoints)}
return t_unit.copy(callables_table=Map(new_callables))
return t_unit.copy(callables_table=immutabledict(new_callables))


def _preprocess_single_kernel(kernel: LoopKernel, is_entrypoint: bool) -> LoopKernel:
Expand Down Expand Up @@ -869,7 +867,7 @@ def preprocess_program(t_unit: TranslationUnit) -> TranslationUnit:

new_callables[func_id] = in_knl_callable

t_unit = t_unit.copy(callables_table=Map(new_callables))
t_unit = t_unit.copy(callables_table=immutabledict(new_callables))

# }}}

Expand Down
4 changes: 2 additions & 2 deletions loopy/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
TypeVar,
)

from immutables import Map
from immutabledict import immutabledict

import islpy as isl
from pytools import ImmutableRecord, MinRecursionLimit, ProcessLogger
Expand Down Expand Up @@ -2482,7 +2482,7 @@ def linearize(t_unit: TranslationUnit) -> TranslationUnit:
else:
raise NotImplementedError(type(clbl))

return t_unit.copy(callables_table=Map(new_callables))
return t_unit.copy(callables_table=immutabledict(new_callables))


# vim: foldmethod=marker
4 changes: 2 additions & 2 deletions loopy/schedule/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from functools import cached_property, reduce
from typing import AbstractSet, Dict, FrozenSet, List, Sequence, Set, Tuple

from immutables import Map
from immutabledict import immutabledict
from typing_extensions import TypeAlias

import islpy as isl
Expand Down Expand Up @@ -1050,7 +1050,7 @@ def _get_iname_to_tree_node_id_from_partial_loop_nest_tree(
for iname in node:
iname_to_tree_node_id[iname] = node

return Map(iname_to_tree_node_id)
return immutabledict(iname_to_tree_node_id)


def get_loop_tree(kernel: LoopKernel) -> LoopTree:
Expand Down
16 changes: 8 additions & 8 deletions loopy/schedule/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ def add_node(self, node: NodeT, parent: NodeT) -> Tree[NodeT]:

siblings = self._parent_to_children[parent]

_parent_to_children_mut = self._parent_to_children.mutate()
_parent_to_children_mut[parent] = (*siblings, node)
_parent_to_children_mut[node] = ()
parent_to_children_mut = self._parent_to_children.mutate()
parent_to_children_mut[parent] = (*siblings, node)
parent_to_children_mut[node] = ()

return Tree(_parent_to_children_mut.finish(),
return Tree(parent_to_children_mut.finish(),
self._child_to_parent.set(node, parent))

def replace_node(self, node: NodeT, new_node: NodeT) -> Tree[NodeT]:
Expand Down Expand Up @@ -223,11 +223,11 @@ def move_node(self, node: NodeT, new_parent: NodeT | None) -> Tree[NodeT]:
parents_new_children = tuple(frozenset(siblings) - frozenset([node]))
new_parents_children = (*self.children(new_parent), node)

_parent_to_children_mut = self._parent_to_children.mutate()
_parent_to_children_mut[parent] = parents_new_children
_parent_to_children_mut[new_parent] = new_parents_children
parent_to_children_mut = self._parent_to_children.mutate()
parent_to_children_mut[parent] = parents_new_children
parent_to_children_mut[new_parent] = new_parents_children

return Tree(_parent_to_children_mut.finish(),
return Tree(parent_to_children_mut.finish(),
self._child_to_parent.set(node, new_parent))

def __str__(self) -> str:
Expand Down
12 changes: 6 additions & 6 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
)
from warnings import warn

import immutables
import numpy as np
from immutabledict import immutabledict

import islpy as isl
import pymbolic.primitives # FIXME: also import by full name to allow sphinx to resolve
Expand Down Expand Up @@ -1036,12 +1036,12 @@ class ExpansionState(ImmutableRecord):
a dict representing current argument values
"""
def __init__(self, kernel, instruction, stack, arg_context):
if not isinstance(arg_context, immutables.Map):
if not isinstance(arg_context, immutabledict):
warn(f"Got a {type(arg_context)} for arg_context,"
" expected `immutables.Map`. This is deprecated"
" expected `immutabledict`. This is deprecated"
" and will result in an error from 2023.",
DeprecationWarning, stacklevel=2)
arg_context = immutables.Map(arg_context)
arg_context = immutabledict(arg_context)
super().__init__(kernel=kernel,
instruction=instruction,
stack=stack,
Expand Down Expand Up @@ -1274,7 +1274,7 @@ def make_new_arg_context(

from pymbolic.mapper.substitutor import make_subst_func
arg_subst_map = SubstitutionMapper(make_subst_func(arg_context))
return immutables.Map({
return immutabledict({
formal_arg_name: arg_subst_map(arg_value)
for formal_arg_name, arg_value in zip(arg_names, arguments)})

Expand Down Expand Up @@ -1317,7 +1317,7 @@ def __call__(self, expr, kernel, insn):
kernel=kernel,
instruction=insn,
stack=(),
arg_context=immutables.Map()))
arg_context=immutabledict()))

def map_instruction(self, kernel, insn):
return insn
Expand Down
4 changes: 2 additions & 2 deletions loopy/target/c/c_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import logging
import os
import tempfile
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Optional, Sequence, Tuple, Union

import numpy as np
from codepy.jit import compile_from_string
from codepy.toolchain import GCCToolchain, ToolchainGuessError, guess_toolchain
from immutables import Map

from pytools import memoize_method
from pytools.codegen import CodeGenerator, Indentation
Expand Down Expand Up @@ -493,7 +493,7 @@ def get_wrapper_generator(self):

@memoize_method
def translation_unit_info(self,
arg_to_dtype: Optional[Map[str, LoopyType]] = None) -> _KernelInfo:
arg_to_dtype: Optional[Mapping[str, LoopyType]] = None) -> _KernelInfo:
t_unit = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)

from loopy.codegen import generate_code_v2
Expand Down
28 changes: 14 additions & 14 deletions loopy/target/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
cast,
)

from immutables import Map
from immutabledict import immutabledict

from pymbolic import Variable, var
from pytools.codegen import CodeGenerator, Indentation
Expand Down Expand Up @@ -817,7 +817,7 @@ def check_for_required_array_arguments(self, input_args):
"your argument.")

def get_typed_and_scheduled_translation_unit_uncached(
self, arg_to_dtype: Optional[Map[str, LoopyType]]
self, arg_to_dtype: Optional[Mapping[str, LoopyType]]
) -> TranslationUnit:
t_unit = self.t_unit

Expand All @@ -827,15 +827,15 @@ def get_typed_and_scheduled_translation_unit_uncached(
# FIXME: This is not so nice. This transfers types from the
# subarrays of sep-tagged arrays to the 'main' array, because
# type inference fails otherwise.
with arg_to_dtype.mutate() as mm:
for name, sep_info in self.sep_info.items():
if entry_knl.arg_dict[name].dtype is None:
for sep_name in sep_info.subarray_names.values():
if sep_name in arg_to_dtype:
mm.set(name, arg_to_dtype[sep_name])
del mm[sep_name]
mm = dict(arg_to_dtype)
for name, sep_info in self.sep_info.items():
if entry_knl.arg_dict[name].dtype is None:
for sep_name in sep_info.subarray_names.values():
if sep_name in arg_to_dtype:
mm[name] = arg_to_dtype[sep_name]
del mm[sep_name]

arg_to_dtype = mm.finish()
arg_to_dtype = immutabledict(mm)

from loopy.kernel.tools import add_dtypes
t_unit = t_unit.with_kernel(add_dtypes(entry_knl, arg_to_dtype))
Expand All @@ -854,7 +854,7 @@ def get_typed_and_scheduled_translation_unit_uncached(
return t_unit

def get_typed_and_scheduled_translation_unit(
self, arg_to_dtype: Optional[Map[str, LoopyType]]
self, arg_to_dtype: Optional[Mapping[str, LoopyType]]
) -> TranslationUnit:
from loopy import CACHING_ENABLED

Expand All @@ -876,7 +876,7 @@ def get_typed_and_scheduled_translation_unit(

return t_unit

def arg_to_dtype(self, kwargs) -> Optional[Map[str, LoopyType]]:
def arg_to_dtype(self, kwargs) -> Optional[immutabledict[str, LoopyType]]:
if not self.has_runtime_typed_args:
return None

Expand All @@ -893,7 +893,7 @@ def arg_to_dtype(self, kwargs) -> Optional[Map[str, LoopyType]]:
else:
arg_to_dtype[arg_name] = NumpyType(dtype)

return Map(arg_to_dtype)
return immutabledict(arg_to_dtype)

# {{{ debugging aids

Expand All @@ -904,7 +904,7 @@ def get_highlighted_code(self, entrypoint, arg_to_dtype=None, code=None):

def get_code(
self, entrypoint: str,
arg_to_dtype: Optional[Map[str, LoopyType]] = None) -> str:
arg_to_dtype: Optional[Mapping[str, LoopyType]] = None) -> str:
kernel = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)

from loopy.codegen import generate_code_v2
Expand Down
Loading
Loading