diff --git a/pytential/symbolic/mappers.py b/pytential/symbolic/mappers.py index 43dbb85a4..8e021209e 100644 --- a/pytential/symbolic/mappers.py +++ b/pytential/symbolic/mappers.py @@ -366,10 +366,15 @@ def map_common_subexpression(self, expr: p.CommonSubexpression): if child is expr.child: return expr + if isinstance(expr, (pp.InterpolationUnit, pp.DistributeInterpolation)): + return type(expr)( + child, + expr.prefix, + expr.scope, + **expr.get_extra_properties()) + return pp.cse( - child, - expr.prefix, - expr.scope) + child, expr.prefix, expr.scope) # }}} @@ -758,10 +763,29 @@ class EarlyInterpolationAdder( """ from_dd: DOFDescriptor to_dd: DOFDescriptor + variable_from_dd: DOFDescriptor | None = None @override def map_variable(self, expr: p.Variable): - return pp.interpolate(expr, self.from_dd, self.to_dd) + from_dd = self.from_dd + if self.variable_from_dd is not None: + from_dd = self.variable_from_dd + return pp.interpolate(expr, from_dd, self.to_dd) + + @override + def map_subscript(self, expr: p.Subscript): + if isinstance(expr.aggregate, p.Variable): + from_dd = self.from_dd + if self.variable_from_dd is not None: + from_dd = self.variable_from_dd + return pp.interpolate(expr, from_dd, self.to_dd) + + return super().map_subscript(expr) + + def map_q_weight(self, expr: pp.QWeight): + raise ValueError( + "EarlyInterpolationAdder must not interpolate a QWeight: " + "quadrature weights are intrinsic to their discretization stage") @override def map_call(self, @@ -793,6 +817,277 @@ def map_common_subexpression_uncached(self, **expr.get_extra_properties()) +class _QWeightCollector(Collector[pp.QWeight]): + @override + def map_q_weight(self, expr: pp.QWeight): + return {expr} + + +class _VariableCollector(Collector[p.Variable]): + @override + def map_variable(self, expr: p.Variable): + return {expr} + + @override + def map_call(self, expr: p.Call): + result: set[p.Variable] = set() + for child in expr.parameters: + result |= self.rec(child) + return result + + @override + def map_call_with_kwargs(self, expr: p.CallWithKwargs): + result: set[p.Variable] = set() + for child in expr.parameters: + result |= self.rec(child) + for child in expr.kw_parameters.values(): + result |= self.rec(child) + return result + + +class _DensityInterpolationUnitCollector(Collector[p.CommonSubexpression]): + @override + def map_common_subexpression(self, expr: p.CommonSubexpression): + if isinstance(expr, pp.InterpolationUnit): + return {expr} + + return self.rec(expr.child) + + +class _DistributedInterpolationCollector(Collector[pp.DistributeInterpolation]): + @override + def map_common_subexpression(self, expr: p.CommonSubexpression): + if isinstance(expr, pp.DistributeInterpolation): + return {expr} + + return self.rec(expr.child) + + +class _DensityInterpolationSplitter: + def __init__(self, + preprocessor: InterpolationPreprocessor, + geometry_from_dd: DOFDescriptor, + variable_from_dd: DOFDescriptor, + to_dd: DOFDescriptor, + ) -> None: + self.preprocessor = preprocessor + self.variable_from_dd = variable_from_dd + self.to_dd = to_dd + self.geometry_adder = EarlyInterpolationAdder( + geometry_from_dd, to_dd, variable_from_dd=variable_from_dd) + self.qweight_collector = _QWeightCollector() + self.variable_collector = _VariableCollector() + self.interpolation_unit_collector = _DensityInterpolationUnitCollector() + + # Marked densities may contain both geometry quantities and density + # unknowns. For example, normal*u/sqrt_jac_q_weight should become + # normal formed from geometry_from_dd and interpolated to to_dd, multiplied + # by Interp(u/sqrt_jac_q_weight) from variable_from_dd to to_dd. The + # QWeight-scaled unknown is kept atomic. + + def __call__(self, expr: ArithmeticExpression) -> ArithmeticExpression: + return self.rec_arith(expr) + + def _contains_qweight(self, expr: ArithmeticExpression) -> bool: + return bool(self.qweight_collector(expr)) + + def _contains_variable(self, expr: ArithmeticExpression) -> bool: + return bool(self.variable_collector(expr)) + + def _contains_interpolation_unit(self, expr: ArithmeticExpression) -> bool: + return bool(self.interpolation_unit_collector(expr)) + + def _is_interpolation_unit(self, expr: Expression) -> bool: + return isinstance(expr, pp.InterpolationUnit) + + def _is_distributed_interpolation(self, expr: Expression) -> bool: + return isinstance(expr, pp.DistributeInterpolation) + + def _is_pure_geometry(self, expr: ArithmeticExpression) -> bool: + return not ( + self._contains_qweight(expr) + or self._contains_variable(expr) + or self._contains_interpolation_unit(expr)) + + def _is_variable_leaf(self, expr: ArithmeticExpression) -> bool: + return ( + isinstance(expr, p.Variable) + or ( + isinstance(expr, p.Subscript) + and isinstance(expr.aggregate, p.Variable))) + + def _factorize(self, expr: ArithmeticExpression) -> ArithmeticExpression: + return self.geometry_adder.rec_arith( + self.preprocessor.rec_arith( + self.preprocessor.tagger.rec_arith(expr))) + + def _interpolate_as_unit(self, + expr: ArithmeticExpression, + ) -> ArithmeticExpression: + return pp.interpolate( + self.preprocessor.rec_arith(expr), + self.variable_from_dd, + self.to_dd) + + def _distribute_interpolation( + self, + expr: pp.DistributeInterpolation, + ) -> ArithmeticExpression: + return self.rec_arith(expr.child) + + def _flatten_product(self, + expr: ArithmeticExpression, + ) -> list[ArithmeticExpression]: + if isinstance(expr, p.Product): + result = [] + for child in expr.children: + result.extend(self._flatten_product(child)) + return result + + return [expr] + + def _flatten_distributed_product(self, + expr: ArithmeticExpression, + ) -> list[ArithmeticExpression]: + result = [] + for factor in self._flatten_product(expr): + if self._is_distributed_interpolation(factor): + result.extend(self._flatten_distributed_product(factor.child)) + else: + result.append(factor) + + return result + + def _make_product(self, + factors: Iterable[ArithmeticExpression], + ) -> ArithmeticExpression: + factors = tuple(factors) + if not factors: + return 1 + if len(factors) == 1: + return factors[0] + + return p.Product(factors) + + def _partition_factors(self, + factors: Iterable[ArithmeticExpression], + ) -> tuple[list[ArithmeticExpression], list[ArithmeticExpression]]: + geometry_factors = [] + residual_factors = [] + + for factor in factors: + if self._is_pure_geometry(factor): + geometry_factors.append(self._factorize(factor)) + else: + residual_factors.append(factor) + + return geometry_factors, residual_factors + + def rec_arith(self, expr: ArithmeticExpression) -> ArithmeticExpression: + if self._is_interpolation_unit(expr): + return self._interpolate_as_unit(expr) + if self._is_distributed_interpolation(expr): + return self._distribute_interpolation(expr) + + if isinstance(expr, p.CommonSubexpression): + result = self.rec_arith(expr.child) + if result is expr.child: + return expr + + return type(expr)( + result, + expr.prefix, + expr.scope, + **expr.get_extra_properties()) + + if isinstance(expr, p.Sum): + return self.map_sum(expr) + + if isinstance(expr, p.Product): + return self.map_product(expr) + + if isinstance(expr, p.Quotient): + return self.map_quotient(expr) + + if self._contains_qweight(expr) or self._contains_interpolation_unit(expr): + return self._interpolate_as_unit(expr) + + if self._contains_variable(expr) and not self._is_variable_leaf(expr): + return self._interpolate_as_unit(expr) + + return self._factorize(expr) + + def map_product(self, expr: p.Product) -> ArithmeticExpression: + factors = self._flatten_distributed_product(expr) + geometry_factors, residual_factors = self._partition_factors(factors) + + if self._contains_qweight(expr): + result_factors = list(geometry_factors) + if residual_factors: + result_factors.append( + self._interpolate_as_unit( + self._make_product(residual_factors))) + + return self._make_product(result_factors) + + result_factors = list(geometry_factors) + if len(residual_factors) == 1: + result_factors.append(self.rec_arith(residual_factors[0])) + elif residual_factors: + result_factors.append( + self._interpolate_as_unit( + self._make_product(residual_factors))) + + return self._make_product(result_factors) + + def map_sum(self, expr: p.Sum) -> ArithmeticExpression: + children = tuple(self.rec_arith(child) for child in expr.children) + if all( + child is orig_child + for child, orig_child in zip(children, expr.children, strict=True)): + return expr + + from pymbolic.primitives import flattened_sum + return flattened_sum(children) + + def map_quotient(self, expr: p.Quotient) -> ArithmeticExpression: + if self._contains_qweight(expr): + if self._contains_qweight(expr.denominator): + numerator_factors = self._flatten_distributed_product(expr.numerator) + geometry_factors, residual_numerator_factors = ( + self._partition_factors(numerator_factors)) + + residual = p.Quotient( + self._make_product(residual_numerator_factors), + expr.denominator) + return self._make_product([ + *geometry_factors, + self._interpolate_as_unit(residual)]) + + if self._is_pure_geometry(expr.denominator): + return p.Quotient( + self._interpolate_as_unit(expr.numerator), + self._factorize(expr.denominator)) + + return self._interpolate_as_unit(expr) + + if self._is_pure_geometry(expr.denominator): + return p.Quotient( + self.rec_arith(expr.numerator), + self._factorize(expr.denominator)) + + numerator_factors = self._flatten_distributed_product(expr.numerator) + geometry_factors, residual_numerator_factors = ( + self._partition_factors(numerator_factors)) + + residual = p.Quotient( + self._make_product(residual_numerator_factors), + expr.denominator) + return self._make_product([ + *geometry_factors, + self._interpolate_as_unit(residual)]) + + class InterpolationPreprocessor(IdentityMapper): """Handle expressions that require upsampling or downsampling by inserting a :class:`~pytential.symbolic.primitives.Interpolation`. This is used to @@ -804,6 +1099,12 @@ class InterpolationPreprocessor(IdentityMapper): :attr:`~pytential.symbolic.dof_desc.QBX_SOURCE_QUAD_STAGE2`, if a stage is not already assigned to the source descriptor. + Unmarked layer-potential densities are interpolated as whole units. Only + densities marked with :func:`pytential.symbolic.primitives.geo_density` or + :func:`pytential.symbolic.primitives.distribute_interpolation` are split. + These markers are only honored in layer-potential densities, not in kernel + arguments. + .. attribute:: from_discr_stage .. automethod:: __init__ """ @@ -852,15 +1153,27 @@ def map_int_g(self, expr: pp.IntG): if not isinstance(lpot_source, QBXLayerPotentialSource): return expr - from_dd = expr.source.to_stage1() - to_dd = from_dd.to_quad_stage2() - interp_adder = EarlyInterpolationAdder(from_dd, to_dd) - densities = tuple( - interp_adder.rec_arith(self.rec_arith(density)) - for density in expr.densities) + variable_from_dd = expr.source.to_stage1() + to_dd = variable_from_dd.to_quad_stage2() + + # stage1 density discretization for geometry can give wrong values for + # quantities that depend on the stage2 element parameterization, such as + # area_element. + geometry_from_dd = variable_from_dd.copy(discr_stage=self.from_discr_stage) + + density_splitter = _DensityInterpolationSplitter( + self, geometry_from_dd, variable_from_dd, to_dd) + distributed_interpolation_collector = _DistributedInterpolationCollector() + + def process_density(density: ArithmeticExpression) -> ArithmeticExpression: + if distributed_interpolation_collector(density): + return density_splitter(density) + return pp.interpolate( + self.rec_arith(density), variable_from_dd, to_dd) + + densities = tuple(process_density(density) for density in expr.densities) - from_dd = from_dd.copy(discr_stage=self.from_discr_stage) - interp_adder = EarlyInterpolationAdder(from_dd, to_dd) + interp_adder = EarlyInterpolationAdder(geometry_from_dd, to_dd) kernel_arguments = constantdict({ name: componentwise( lambda aexpr: interp_adder.rec_arith( diff --git a/pytential/symbolic/primitives.py b/pytential/symbolic/primitives.py index ccc4ec2d4..0cc932916 100644 --- a/pytential/symbolic/primitives.py +++ b/pytential/symbolic/primitives.py @@ -51,10 +51,11 @@ NablaComponent, ) from pymbolic.primitives import ( # noqa: N813 + CommonSubexpression, Variable as var, cse_scope as cse_scope_base, expr_dataclass, - make_common_subexpression as cse, + make_common_subexpression as _cse, make_sym_vector, ) from pymbolic.typing import ArithmeticExpression @@ -92,7 +93,7 @@ import modepy as mp from pymbolic.mapper.stringifier import StringifyMapper - from pymbolic.primitives import CommonSubexpression, Quotient + from pymbolic.primitives import Quotient __doc__ = """ @@ -231,6 +232,7 @@ .. autofunction:: area_element .. autofunction:: sqrt_jac_q_weight .. autofunction:: normal +.. autofunction:: geo_density .. autofunction:: mean_curvature .. autofunction:: first_fundamental_form .. autofunction:: second_fundamental_form @@ -306,6 +308,22 @@ Operators ^^^^^^^^^ +.. autofunction:: cse + +.. autoclass:: InterpolationUnit + :show-inheritance: + :undoc-members: + :members: mapper_method + +.. autofunction:: interpolation_unit + +.. autoclass:: DistributeInterpolation + :show-inheritance: + :undoc-members: + :members: mapper_method + +.. autofunction:: distribute_interpolation + .. autoclass:: Interpolation :show-inheritance: :undoc-members: @@ -404,7 +422,7 @@ "IsShapeClass", "QWeight", "nodes", "parametrization_derivative", "parametrization_derivative_matrix", "pseudoscalar", "area_element", - "sqrt_jac_q_weight", "normal", "mean_curvature", + "sqrt_jac_q_weight", "normal", "geo_density", "mean_curvature", "first_fundamental_form", "second_fundamental_form", "shape_operator", "expansion_radii", "expansion_centers", "h_max", "weights_and_area_elements", @@ -413,7 +431,8 @@ "ElementwiseMax", "integral", "Ones", "ones_vec", "area", "mean", "IterativeInverse", - "Interpolation", "interpolate", + "Interpolation", "interpolate", "InterpolationUnit", "interpolation_unit", + "DistributeInterpolation", "distribute_interpolation", "Derivative", @@ -527,6 +546,87 @@ class ErrorExpression(ExpressionNode): """The error message to raise when this expression is encountered.""" +@expr_dataclass() +class InterpolationUnit(CommonSubexpression): + """A common subexpression whose value should be interpolated as a unit. + + This is used for quantities whose expanded symbolic form is not a valid + place to insert interpolation. For example, the result of + :func:`tangential_to_xyz` is Cartesian, while its inputs can be coefficients + in an element-local tangential basis. + """ + + mapper_method = "map_common_subexpression" + + +@for_each_expression +def interpolation_unit( + operand: ArithmeticExpression, + prefix: str | None = "interp_unit", + scope: str = cse_scope.EVALUATION, + ) -> ArithmeticExpression: + """Mark *operand* as a quantity that should be interpolated as a whole.""" + + return InterpolationUnit(operand, prefix, scope) + + +@expr_dataclass() +class DistributeInterpolation(CommonSubexpression): + """A common subexpression whose source interpolation may be distributed. + + This is an opt-in marker for densities such as ``normal*h`` where the + geometry factors should be formed using the geometry-stage element + parameterization and the unknown variables are interpolated from the source + discretization. + + """ + + mapper_method = "map_common_subexpression" + + +@for_each_expression +def distribute_interpolation( + operand: ArithmeticExpression, + prefix: str | None = "distribute_interp", + scope: str = cse_scope.EVALUATION, + ) -> ArithmeticExpression: + """Mark *operand* so source-to-quad interpolation may be pushed into it.""" + + return DistributeInterpolation(operand, prefix, scope) + + +def cse( + expr: Operand, + prefix: str | None = None, + scope: str | None = None, + *, + wrap_vars: bool = True, + ) -> Operand: + """Wrap *expr* in a common subexpression. + + This is the usual :func:`pymbolic.primitives.make_common_subexpression`, + except that top-level rewrapping of an :class:`InterpolationUnit` or + :class:`DistributeInterpolation` updates its name hint while preserving the + marker. + """ + + if isinstance(expr, (InterpolationUnit, DistributeInterpolation)): + if prefix is None: + prefix = expr.prefix + if scope is None: + scope = expr.scope + return cast("Operand", type(expr)( + expr.child, + prefix, + scope, + **expr.get_extra_properties())) + + if scope is None: + scope = cse_scope.EVALUATION + + return cast("Operand", _cse(expr, prefix, scope, wrap_vars=wrap_vars)) + + def make_sym_mv(name: str, num_components: int) -> MultiVector[ArithmeticExpression]: return MultiVector(make_sym_vector(name, num_components)) @@ -894,6 +994,23 @@ def normal( scope=cse_scope.DISCRETIZATION) +def geo_density( + geometry: Operand, + density: Operand, + prefix: str | None = "geo_density", + ) -> Operand: + """Return ``geometry * density`` with source interpolation distributed. + + This is a helper for densities whose semantic form is a geometric + factor multiplying a density unknown. It should not be used for + coordinate transforms such as :func:`tangential_to_xyz`, whose inputs are + local coordinate components rather than independent scalar densities. + """ + + return cast("Operand", distribute_interpolation( + geometry * density, prefix)) + + def mean_curvature( ambient_dim: int, dim: int | None = None, @@ -2498,7 +2615,8 @@ def tangential_to_xyz( tonb = tangential_onb(ambient_dim, dofdesc=dofdesc) result = sum(tonb[:, i] * tangential_vec[i] for i in range(ambient_dim - 1)) - return cast("ObjectArray1D[ArithmeticExpression]", result) + return cast("ObjectArray1D[ArithmeticExpression]", + interpolation_unit(result, "tangential_to_xyz")) def project_to_tangential(