Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 325 additions & 12 deletions pytential/symbolic/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,15 @@ def map_common_subexpression(self, expr: p.CommonSubexpression):
if child is expr.child:
return expr

if isinstance(expr, (pp.InterpolationUnit, pp.DistributeInterpolation)):
return type(expr)(
child,
expr.prefix,
expr.scope,
**expr.get_extra_properties())

return pp.cse(
child,
expr.prefix,
expr.scope)
child, expr.prefix, expr.scope)

# }}}

Expand Down Expand Up @@ -758,10 +763,29 @@ class EarlyInterpolationAdder(
"""
from_dd: DOFDescriptor
to_dd: DOFDescriptor
variable_from_dd: DOFDescriptor | None = None

@override
def map_variable(self, expr: p.Variable):
return pp.interpolate(expr, self.from_dd, self.to_dd)
from_dd = self.from_dd
if self.variable_from_dd is not None:
from_dd = self.variable_from_dd
return pp.interpolate(expr, from_dd, self.to_dd)

@override
def map_subscript(self, expr: p.Subscript):
if isinstance(expr.aggregate, p.Variable):
from_dd = self.from_dd
if self.variable_from_dd is not None:
from_dd = self.variable_from_dd
return pp.interpolate(expr, from_dd, self.to_dd)

return super().map_subscript(expr)

def map_q_weight(self, expr: pp.QWeight):
raise ValueError(
"EarlyInterpolationAdder must not interpolate a QWeight: "
"quadrature weights are intrinsic to their discretization stage")

@override
def map_call(self,
Expand Down Expand Up @@ -793,6 +817,277 @@ def map_common_subexpression_uncached(self,
**expr.get_extra_properties())


class _QWeightCollector(Collector[pp.QWeight]):
@override
def map_q_weight(self, expr: pp.QWeight):
return {expr}


class _VariableCollector(Collector[p.Variable]):
@override
def map_variable(self, expr: p.Variable):
return {expr}

@override
def map_call(self, expr: p.Call):
result: set[p.Variable] = set()
for child in expr.parameters:
result |= self.rec(child)
return result

@override
def map_call_with_kwargs(self, expr: p.CallWithKwargs):
result: set[p.Variable] = set()
for child in expr.parameters:
result |= self.rec(child)
for child in expr.kw_parameters.values():
result |= self.rec(child)
return result


class _DensityInterpolationUnitCollector(Collector[p.CommonSubexpression]):
@override
def map_common_subexpression(self, expr: p.CommonSubexpression):
if isinstance(expr, pp.InterpolationUnit):
return {expr}

return self.rec(expr.child)


class _DistributedInterpolationCollector(Collector[pp.DistributeInterpolation]):
@override
def map_common_subexpression(self, expr: p.CommonSubexpression):
if isinstance(expr, pp.DistributeInterpolation):
return {expr}

return self.rec(expr.child)


class _DensityInterpolationSplitter:
def __init__(self,
preprocessor: InterpolationPreprocessor,
geometry_from_dd: DOFDescriptor,
variable_from_dd: DOFDescriptor,
to_dd: DOFDescriptor,
) -> None:
self.preprocessor = preprocessor
self.variable_from_dd = variable_from_dd
self.to_dd = to_dd
self.geometry_adder = EarlyInterpolationAdder(
geometry_from_dd, to_dd, variable_from_dd=variable_from_dd)
self.qweight_collector = _QWeightCollector()
self.variable_collector = _VariableCollector()
self.interpolation_unit_collector = _DensityInterpolationUnitCollector()

# Marked densities may contain both geometry quantities and density
# unknowns. For example, normal*u/sqrt_jac_q_weight should become
# normal formed from geometry_from_dd and interpolated to to_dd, multiplied
# by Interp(u/sqrt_jac_q_weight) from variable_from_dd to to_dd. The
# QWeight-scaled unknown is kept atomic.

def __call__(self, expr: ArithmeticExpression) -> ArithmeticExpression:
return self.rec_arith(expr)

def _contains_qweight(self, expr: ArithmeticExpression) -> bool:
return bool(self.qweight_collector(expr))

def _contains_variable(self, expr: ArithmeticExpression) -> bool:
return bool(self.variable_collector(expr))

def _contains_interpolation_unit(self, expr: ArithmeticExpression) -> bool:
return bool(self.interpolation_unit_collector(expr))

def _is_interpolation_unit(self, expr: Expression) -> bool:
return isinstance(expr, pp.InterpolationUnit)

def _is_distributed_interpolation(self, expr: Expression) -> bool:
return isinstance(expr, pp.DistributeInterpolation)

def _is_pure_geometry(self, expr: ArithmeticExpression) -> bool:
return not (
self._contains_qweight(expr)
or self._contains_variable(expr)
or self._contains_interpolation_unit(expr))

def _is_variable_leaf(self, expr: ArithmeticExpression) -> bool:
return (
isinstance(expr, p.Variable)
or (
isinstance(expr, p.Subscript)
and isinstance(expr.aggregate, p.Variable)))

def _factorize(self, expr: ArithmeticExpression) -> ArithmeticExpression:
return self.geometry_adder.rec_arith(
self.preprocessor.rec_arith(
self.preprocessor.tagger.rec_arith(expr)))

def _interpolate_as_unit(self,
expr: ArithmeticExpression,
) -> ArithmeticExpression:
return pp.interpolate(
self.preprocessor.rec_arith(expr),
self.variable_from_dd,
self.to_dd)

def _distribute_interpolation(
self,
expr: pp.DistributeInterpolation,
) -> ArithmeticExpression:
return self.rec_arith(expr.child)

def _flatten_product(self,
expr: ArithmeticExpression,
) -> list[ArithmeticExpression]:
if isinstance(expr, p.Product):
result = []
for child in expr.children:
result.extend(self._flatten_product(child))
return result

return [expr]

def _flatten_distributed_product(self,
expr: ArithmeticExpression,
) -> list[ArithmeticExpression]:
result = []
for factor in self._flatten_product(expr):
if self._is_distributed_interpolation(factor):
result.extend(self._flatten_distributed_product(factor.child))
else:
result.append(factor)

return result

def _make_product(self,
factors: Iterable[ArithmeticExpression],
) -> ArithmeticExpression:
factors = tuple(factors)
if not factors:
return 1
if len(factors) == 1:
return factors[0]

return p.Product(factors)

def _partition_factors(self,
factors: Iterable[ArithmeticExpression],
) -> tuple[list[ArithmeticExpression], list[ArithmeticExpression]]:
geometry_factors = []
residual_factors = []

for factor in factors:
if self._is_pure_geometry(factor):
geometry_factors.append(self._factorize(factor))
else:
residual_factors.append(factor)

return geometry_factors, residual_factors

def rec_arith(self, expr: ArithmeticExpression) -> ArithmeticExpression:
if self._is_interpolation_unit(expr):
return self._interpolate_as_unit(expr)
if self._is_distributed_interpolation(expr):
return self._distribute_interpolation(expr)

if isinstance(expr, p.CommonSubexpression):
result = self.rec_arith(expr.child)
if result is expr.child:
return expr

return type(expr)(
result,
expr.prefix,
expr.scope,
**expr.get_extra_properties())

if isinstance(expr, p.Sum):
return self.map_sum(expr)

if isinstance(expr, p.Product):
return self.map_product(expr)

if isinstance(expr, p.Quotient):
return self.map_quotient(expr)

if self._contains_qweight(expr) or self._contains_interpolation_unit(expr):
return self._interpolate_as_unit(expr)

if self._contains_variable(expr) and not self._is_variable_leaf(expr):
return self._interpolate_as_unit(expr)

return self._factorize(expr)

def map_product(self, expr: p.Product) -> ArithmeticExpression:
factors = self._flatten_distributed_product(expr)
geometry_factors, residual_factors = self._partition_factors(factors)

if self._contains_qweight(expr):
result_factors = list(geometry_factors)
if residual_factors:
result_factors.append(
self._interpolate_as_unit(
self._make_product(residual_factors)))

return self._make_product(result_factors)

result_factors = list(geometry_factors)
if len(residual_factors) == 1:
result_factors.append(self.rec_arith(residual_factors[0]))
elif residual_factors:
result_factors.append(
self._interpolate_as_unit(
self._make_product(residual_factors)))

return self._make_product(result_factors)

def map_sum(self, expr: p.Sum) -> ArithmeticExpression:
children = tuple(self.rec_arith(child) for child in expr.children)
if all(
child is orig_child
for child, orig_child in zip(children, expr.children, strict=True)):
return expr

from pymbolic.primitives import flattened_sum
return flattened_sum(children)

def map_quotient(self, expr: p.Quotient) -> ArithmeticExpression:
if self._contains_qweight(expr):
if self._contains_qweight(expr.denominator):
numerator_factors = self._flatten_distributed_product(expr.numerator)
geometry_factors, residual_numerator_factors = (
self._partition_factors(numerator_factors))

residual = p.Quotient(
self._make_product(residual_numerator_factors),
expr.denominator)
return self._make_product([
*geometry_factors,
self._interpolate_as_unit(residual)])

if self._is_pure_geometry(expr.denominator):
return p.Quotient(
self._interpolate_as_unit(expr.numerator),
self._factorize(expr.denominator))

return self._interpolate_as_unit(expr)

if self._is_pure_geometry(expr.denominator):
return p.Quotient(
self.rec_arith(expr.numerator),
self._factorize(expr.denominator))

numerator_factors = self._flatten_distributed_product(expr.numerator)
geometry_factors, residual_numerator_factors = (
self._partition_factors(numerator_factors))

residual = p.Quotient(
self._make_product(residual_numerator_factors),
expr.denominator)
return self._make_product([
*geometry_factors,
self._interpolate_as_unit(residual)])


class InterpolationPreprocessor(IdentityMapper):
"""Handle expressions that require upsampling or downsampling by inserting
a :class:`~pytential.symbolic.primitives.Interpolation`. This is used to
Expand All @@ -804,6 +1099,12 @@ class InterpolationPreprocessor(IdentityMapper):
:attr:`~pytential.symbolic.dof_desc.QBX_SOURCE_QUAD_STAGE2`, if a
stage is not already assigned to the source descriptor.

Unmarked layer-potential densities are interpolated as whole units. Only
densities marked with :func:`pytential.symbolic.primitives.geo_density` or
:func:`pytential.symbolic.primitives.distribute_interpolation` are split.
These markers are only honored in layer-potential densities, not in kernel
arguments.

.. attribute:: from_discr_stage
.. automethod:: __init__
"""
Expand Down Expand Up @@ -852,15 +1153,27 @@ def map_int_g(self, expr: pp.IntG):
if not isinstance(lpot_source, QBXLayerPotentialSource):
return expr

from_dd = expr.source.to_stage1()
to_dd = from_dd.to_quad_stage2()
interp_adder = EarlyInterpolationAdder(from_dd, to_dd)
densities = tuple(
interp_adder.rec_arith(self.rec_arith(density))
for density in expr.densities)
variable_from_dd = expr.source.to_stage1()
to_dd = variable_from_dd.to_quad_stage2()

# stage1 density discretization for geometry can give wrong values for
# quantities that depend on the stage2 element parameterization, such as
# area_element.
geometry_from_dd = variable_from_dd.copy(discr_stage=self.from_discr_stage)

density_splitter = _DensityInterpolationSplitter(
self, geometry_from_dd, variable_from_dd, to_dd)
distributed_interpolation_collector = _DistributedInterpolationCollector()

def process_density(density: ArithmeticExpression) -> ArithmeticExpression:
if distributed_interpolation_collector(density):
return density_splitter(density)
return pp.interpolate(
self.rec_arith(density), variable_from_dd, to_dd)

densities = tuple(process_density(density) for density in expr.densities)

from_dd = from_dd.copy(discr_stage=self.from_discr_stage)
interp_adder = EarlyInterpolationAdder(from_dd, to_dd)
interp_adder = EarlyInterpolationAdder(geometry_from_dd, to_dd)
kernel_arguments = constantdict({
name: componentwise(
lambda aexpr: interp_adder.rec_arith(
Expand Down
Loading
Loading