Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 191 additions & 27 deletions meshmode/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@
_PytestPytatoPyOpenCLArrayContextFactory,
register_pytest_array_context_factory,
)
from meshmode.transform_metadata import (
FaceMassOperatorTag,
MassInverseOperatorTag,
TensorProductDOFAxisTag,
TensorProductMassInverseOperatorTag,
TensorProductOperatorAxisTag,
TensorProductOperatorTag
)
from loopy.translation_unit import for_each_kernel

from loopy.tools import memoize_on_disk
Expand Down Expand Up @@ -901,6 +909,8 @@
# transforming it.
orig_knl = knl

import time

knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationFaceAxisTag,
"iface",
False,
Expand All @@ -915,6 +925,7 @@
"idof",
False,
orig_knl)

knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag,
"idim",
False,
Expand All @@ -924,15 +935,29 @@
"iface",
True,
orig_knl)

knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDOFAxisTag,
"idof",
True,
orig_knl)

knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag,
"idim",
True,
orig_knl)

knl = _fuse_loops_over_a_discr_entity(knl,
TensorProductDOFAxisTag,
"idof_tp",
False,
orig_knl)

knl = _fuse_loops_over_a_discr_entity(knl,
TensorProductDOFAxisTag,
"idof_tp",
True,
orig_knl)

return knl


Expand Down Expand Up @@ -1028,7 +1053,10 @@
for iname in kernel.all_inames()
if (kernel
.inames[iname]
.tags_of_type(DiscretizationDOFAxisTag))
.tags_of_type(DiscretizationDOFAxisTag) or
kernel
.inames[iname]
.tags_of_type(TensorProductDOFAxisTag))
}
iface_inames = {iname
for iname in kernel.all_inames()
Expand Down Expand Up @@ -1058,6 +1086,9 @@
raise NotImplementedError(f"The <iel> loop {insn.within_inames}"
" does not appear as a singly nested"
" loop.")

# {{{ <iel, idof> loop (simplicial)

elif ((len(insn.within_inames) == 2)
and (len(insn.within_inames & iel_inames) == 1)
and (len(insn.within_inames & idof_inames) == 1)):
Expand All @@ -1068,9 +1099,41 @@
for dof_insn in kernel.iname_to_insns()[idof]):
pass
else:
for dof_insn in kernel.iname_to_insns()[idof]:
if iel not in kernel.id_to_insn[dof_insn].within_inames:
print(f"_get_iel_to_idofs: {str(kernel.id_to_insn[dof_insn])=}")
raise NotImplementedError("The <iel,idof> loop "
f"'{insn.within_inames}' has the idof-loop"
" that's not nested within the iel-loop.")
# }}}

# {{{ <iel, idof, ...> loop (tensor product)

elif ((len(insn.within_inames) > 2)
and (len(insn.within_inames & iel_inames) == 1)
and (len(insn.within_inames & idof_inames) > 1)):

iel, = insn.within_inames & iel_inames
for idof in insn.within_inames & idof_inames:
iel_to_idofs[iel].add(idof)

if all((iel in kernel.id_to_insn[dof_insn].within_inames)
for dof_insn in kernel.iname_to_insns()[idof]):
pass
else:
for dof_insn in kernel.iname_to_insns()[idof]:
if iel not in kernel.id_to_insn[dof_insn].within_inames:
print("_get_iel_to_idofs: "
f"{str(kernel.id_to_insn[dof_insn])=}")
raise NotImplementedError("The <iel,idof> loop "
f"'{insn.within_inames}' has the "
"idof-loop that's not nested "
"within the iel-loop.")

# }}}

# {{{ <iel, idof, iface> loop

elif ((len(insn.within_inames) > 2)
and (len(insn.within_inames & iel_inames) == 1)
and (len(insn.within_inames & idof_inames) == 1)
Expand All @@ -1086,7 +1149,11 @@
else:
raise NotImplementedError("Could not fit into <iel,idof,iface>"
" loop nest pattern.")

# }}}

else:
print(f"_get_iel_to_idofs: {str(insn)=}")
raise NotImplementedError(f"Cannot fit loop nest '{insn.within_inames}'"
" into known set of loop-nest patterns.")

Expand Down Expand Up @@ -1126,8 +1193,10 @@


def _prepare_kernel_for_parallelization(kernel):
from meshmode.transform_metadata import TensorProductDOFAxisTag
discr_tag_to_prefix = {DiscretizationElementAxisTag: "iel",
DiscretizationDOFAxisTag: "idof",
TensorProductDOFAxisTag: "idof_tp",
DiscretizationDimAxisTag: "idim",
DiscretizationAmbientDimAxisTag: "idim",
DiscretizationTopologicalDimAxisTag: "idim",
Expand All @@ -1136,7 +1205,7 @@
import loopy as lp
from loopy.match import ObjTagged

# A mapping from inames that the instruction accesses to

Check warning on line 1208 in meshmode/array_context.py

View workflow job for this annotation

GitHub Actions / Typos

"accesss" should be "access".
# the instruction ids within that iname.
ensm_buckets = {}
vng = kernel.get_var_name_generator()
Expand Down Expand Up @@ -1252,11 +1321,13 @@


from pytools.persistent_dict import WriteOncePersistentDict
from pytato.analysis import PytatoKeyBuilder
from pytato.analysis import PytatoKeyBuilder, get_num_nodes

class FusionContractorArrayContext(
SingleGridWorkBalancingPytatoArrayContext):

t_units = []

def __init__(
self, queue: "cl.CommandQueue", allocator=None, *,
use_memory_pool: Optional[bool] = None,
Expand All @@ -1282,6 +1353,8 @@
def transform_dag(self, dag):
import pytato as pt

initial_node_count = get_num_nodes(dag)

# {{{ Remove FEMEinsumTags that might have been propagated

# TODO: Is this too hacky?
Expand Down Expand Up @@ -1355,6 +1428,28 @@
with ProcessLogger(logger, "transform_dag.deduplicate_data_wrappers"):
dag = pt.transform.deduplicate_data_wrappers(dag)

# {{{ freeze and thaw tensor product operators

# FIXME: this is a hack
def thaw_freeze_tp_operators(expr):
    """Collapse a tagged tensor-product operator einsum into one array.

    An :class:`pytato.Einsum` carrying :class:`TensorProductOperatorTag`
    is unpacked into its two operands, both are converted with
    ``self.to_numpy``, multiplied with ``@`` and handed back through
    ``self.from_numpy``.  Both axes of the result are tagged with
    :class:`TensorProductOperatorAxisTag`, and the result itself is tagged
    with :class:`TensorProductOperatorTag` and ``PrefixNamed("diff_op")``.

    Any other expression is returned unchanged.
    """
    if not (isinstance(expr, pt.Einsum)
            and expr.tags_of_type(TensorProductOperatorTag)):
        return expr

    # Exactly two operands are expected; unpacking raises otherwise.
    ref_mass_inv, stiff_t = expr.args
    host_product = self.to_numpy(ref_mass_inv) @ self.to_numpy(stiff_t)

    operator_axis_tags = frozenset((TensorProductOperatorAxisTag(),))
    tagged_axes = tuple(pt.Axis(tags=operator_axis_tags) for _ in range(2))

    fused_operator = self.from_numpy(host_product).copy(axes=tagged_axes)
    fused_operator = fused_operator.tagged(TensorProductOperatorTag())
    return fused_operator.tagged(pt.tags.PrefixNamed("diff_op"))

dag = pt.transform.map_and_copy(dag, thaw_freeze_tp_operators)

# }}}

# {{{ get rid of copies for different views of a cl-array

def eliminate_reshapes_of_data_wrappers(ary):
Expand All @@ -1379,12 +1474,24 @@
expr,
"ifj,fej,fej->ei")):
mat, jac, vec = expr.args
return (pt.einsum("ifj,fej,fej->ei",
mat,
jac,
vec.tagged(pt.tags.ImplStored()))
.tagged((pt.tags.ImplStored(),
pt.tags.PrefixNamed("face_mass"))))
if mat.tags_of_type(FaceMassOperatorTag):
return (pt.einsum("ifj,fej,fej->ei",
mat,
jac,
vec.tagged(pt.tags.ImplStored()))
.tagged((pt.tags.ImplStored(),
pt.tags.PrefixNamed("face_mass_result"))))
elif (isinstance(expr, pt.Einsum)
and pt.analysis.is_einsum_similar_to_subscript(
expr,
"ifj,fej->ei")):
mat, vec = expr.args
if mat.tags_of_type(FaceMassOperatorTag):
return (pt.einsum("ifj,fej->ei",
mat,
vec.tagged(pt.tags.ImplStored()))
.tagged((pt.tags.ImplStored(),
pt.tags.PrefixNamed("face_mass_result"))))
else:
return expr

Expand All @@ -1398,17 +1505,41 @@
# {{{ materialize inverse mass inputs

def materialize_inverse_mass_inputs(expr):
    """Tag the input vector of inverse-mass einsums for materialization.

    Two einsum shapes are recognized:

    * ``"ij,ej->ei"`` whose first operand carries
      :class:`MassInverseOperatorTag`; the second operand is given the
      prefix ``"input_vec"``.
    * one of the three tensor-product contractions ``"il,eljk->eijk"``,
      ``"jl,eilk->eijk"``, ``"kl,eijl->eijk"`` whose first operand carries
      :class:`TensorProductMassInverseOperatorTag`; the second operand is
      given the prefix ``"input_vec_tp"``.

    In both cases the second operand additionally receives
    ``pt.tags.ImplStored``.  A tag is only added when the operand does not
    already carry one of that type.

    Every other expression is returned unchanged.  That includes an einsum
    that matches one of the subscripts but whose operator lacks the
    expected tag: such expressions previously fell off the end of the
    function and yielded ``None``.
    """
    tp_subscripts = ("il,eljk->eijk", "jl,eilk->eijk", "kl,eijl->eijk")

    def is_tp_einsum(expr):
        # True if *expr* matches any of the tensor-product contractions.
        return any(
            pt.analysis.is_einsum_similar_to_subscript(expr, subscript)
            for subscript in tp_subscripts)

    def with_stored_input(operator_tag, prefix):
        # Shared tagging logic for both operator kinds.
        mat, vec = expr.args
        if not mat.tags_of_type(operator_tag):
            # Not an inverse-mass application: leave the node alone.
            return expr
        if not vec.tags_of_type(pt.tags.PrefixNamed):
            vec = vec.tagged(pt.tags.PrefixNamed(prefix))
        if not vec.tags_of_type(pt.tags.ImplStored):
            vec = vec.tagged(pt.tags.ImplStored())
        return expr.copy(args=(mat, vec))

    if not isinstance(expr, pt.Einsum):
        return expr

    if pt.analysis.is_einsum_similar_to_subscript(expr, "ij,ej->ei"):
        return with_stored_input(MassInverseOperatorTag, "input_vec")

    if is_tp_einsum(expr):
        return with_stored_input(TensorProductMassInverseOperatorTag,
                                 "input_vec_tp")

    return expr

Expand Down Expand Up @@ -1638,6 +1769,15 @@

# }}}

final_node_count = get_num_nodes(dag)
with ProcessLogger(logger, "final node count"):
logger.info(
"Final DAG size: %d nodes, started with %d nodes, %s %d nodes",
final_node_count, initial_node_count,
("added" if initial_node_count < final_node_count else
"removed"), abs(initial_node_count - final_node_count)
)

return dag

def transform_loopy_program(self, t_unit):
Expand All @@ -1660,6 +1800,7 @@
# from loopy.transform.instruction import simplify_indices
# t_unit = simplify_indices(t_unit)

knl = t_unit.default_entrypoint

logger.info(f"Transforming kernel '{knl.name}' with {len(knl.instructions)} statements.")

Expand Down Expand Up @@ -1820,7 +1961,8 @@
t_unit = t_unit.with_kernel(knl)
del knl

if False and t_unit.default_entrypoint.tags_of_type(FromArrayContextCompile):
if False and t_unit.default_entrypoint.tags_of_type(
FromArrayContextCompile):
# FIXME: Enable this branch, WIP for now and hence disabled it.
from loopy.match import ObjTagged
import feinsum as fnsm
Expand Down Expand Up @@ -1863,17 +2005,40 @@
knl = t_unit.default_entrypoint
for iel, idofs in sorted(iel_to_idofs.items()):
if idofs:
nunit_dofs = {knl.get_constant_iname_length(idof)
for idof in idofs}
idof, = idofs
if len(idofs) == 1:
nunit_dofs = {
knl.get_constant_iname_length(idof)
for idof in idofs
}
l_one, l_zero = _get_group_size_for_dof_array_loop(
nunit_dofs)

idof, = idofs

knl = lp.split_iname(knl, iel, l_one,
inner_tag="l.1",
outer_tag="g.0")
knl = lp.split_iname(knl, idof, l_zero,
inner_tag="for",
outer_tag="l.0")

l_one_size, l_zero_size = _get_group_size_for_dof_array_loop(
nunit_dofs)
else:
def idof_tp_sort_key(idof):
    """Return the ``iaxis`` of the tensor-product DOF tag on *idof*.

    The iname must carry exactly one :class:`TensorProductDOFAxisTag`;
    the single-element unpacking raises otherwise.
    """
    tp_tags = knl.inames[idof].tags_of_type(TensorProductDOFAxisTag)
    (tp_tag,) = tp_tags
    return tp_tag.iaxis

inames_to_tags = {iel: "g.0"}
inames_to_tags.update({
idof: f"l.{i}"
for i, idof in enumerate(
sorted(idofs, key=idof_tp_sort_key)[:-1]
)
})

knl = lp.tag_inames(knl, inames_to_tags)

knl = lp.split_iname(knl, iel, l_one_size,
inner_tag="l.1", outer_tag="g.0")
knl = lp.split_iname(knl, idof, l_zero_size,
inner_tag="l.0", outer_tag="unr")
else:
knl = lp.split_iname(knl, iel, 32,
outer_tag="g.0", inner_tag="l.0")
Expand All @@ -1886,5 +2051,4 @@

return t_unit


# vim: foldmethod=marker
Loading
Loading