Skip to content

FEAT - Add chromatic psf and the corresponding tests (Issue #251)#253

Draft
MaxRonce wants to merge 5 commits into
GalSim-developers:mainfrom
MaxRonce:chromatic_psf
Draft

FEAT - Add chromatic psf and the corresponding tests (Issue #251)#253
MaxRonce wants to merge 5 commits into
GalSim-developers:mainfrom
MaxRonce:chromatic_psf

Conversation

@MaxRonce
Copy link
Copy Markdown

@MaxRonce MaxRonce commented May 22, 2026

Should go with the issue #251

python -m pytest -c /home/maxime/src/galsim/JAX-GalSim_chromaticPSF/pyproject.toml tests/GalSim/tests/test_sed.py tests/GalSim/tests/test_bandpass.py tests/GalSim/tests/test_chromatic.py -q

Tests seams to pass and results are similar to original galsim

image

Element by element testing and benchmark was done using this script :

bench_dsps_chromatic.py

image

The 10e-5 error on the bottom left plot seams to be only numerical bias due to the NWAVES for the fft and not an actuel implementation bias

Some class diagram before and after :

Regular Galsim

image

Jax Galsim

image workflow

@MaxRonce
Copy link
Copy Markdown
Author

Know issue

ChromaticSum incorrectly treated as separable when summing separable components with different spatial profiles

ChromaticSum currently marks itself as separable when all child components are separable:

self._separable = all(o._separable for o in self.obj_list)

This is not generally correct. A sum of separable chromatic objects is not necessarily separable.

Ex : physical bulge + disk model:

disk(x, y) * SED_disk(lambda) + bulge(x, y) * SED_bulge(lambda)

is not separable in general, even though each component is individually separable. It is separable only in special cases, e.g. if the spatial profiles are identical or if the SEDs are identical up to a scalar factor.

Observed behavior

When building a model like:

disk = disk_profile * sed_disk
bulge = bulge_profile * sed_bulge
source = disk + bulge
model = ChromaticConvolution([source, psf])
image = model.drawImage(bandpass, ...)

source becomes a ChromaticSum.

Because both disk and bulge are separable, ChromaticSum sets:

self._separable = True

Then ChromaticConvolution.drawImage() takes the optimized separable path and calls:

o._sed_value(wave)

on the ChromaticSum.

But ChromaticSum does not implement _sed_value(), so the base implementation raises:

NotImplementedError

Expected behavior

A sum of separable chromatic components should only be marked separable if the sum itself can be written as:

spatial_profile(x, y) * spectral_weight(lambda)

A generic bulge + disk model cannot be represented this way.

Fixes ??

Always treat ChromaticSum as non-separable:

class ChromaticSum(ChromaticObject):
    _separable = False

or inside __init__:

self._separable = False

This is simple and robust, but may lose optimization opportunities for rare truly separable sums.

More precise fix

Only mark ChromaticSum as separable when the sum can actually be factored into:

g(x, y) * h(lambda)

That would require implementing compatible _sed_value() and _static_spatial_profile() behavior only for cases where the children share an equivalent spatial profile or equivalent SED structure.

However, for a generic bulge + disk model with distinct spatial profiles and distinct SEDs, ChromaticSum should remain non-separable.

Current fix

For testing purpose I used this solution :

(disk + bulge) (*) psf = disk (*) psf + bulge (*) psf

In code:

disk_image = ChromaticConvolution([disk, psf]).drawImage(bandpass, ...)
bulge_image = ChromaticConvolution([bulge, psf]).drawImage(bandpass, ...)
image = disk_image + bulge_image

This is mathematically correct, but it would be preferable for ChromaticSum / ChromaticConvolution to handle the generic case properly

@MaxRonce MaxRonce changed the title FEAT - Add chromatic psf and the corresponding tests FEAT - Add chromatic psf and the corresponding tests (Issue #251) May 22, 2026
Comment thread docs/api/chromatic.rst
Comment on lines +1 to +2
Chromatic Profiles
==================
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Chromatic Profiles
==================
Wavelength-dependent Profiles
=============================

Let's use this title to match the upstream documentation. The galsim docs further use separate pages for the SEDS, Bandpasses, and Chromatic objects, but we can skip that bit.

Comment thread jax_galsim/bandpass.py


@register_pytree_node_class
class Bandpass:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not repeat doc strings from upstream and instead use the implements decorator, with any caveats put into the lax_description keyword.

Comment thread jax_galsim/bandpass.py

if self._wave.ndim != 1 or len(self._wave) < 2:
raise ValueError("wave must be a 1-D array with at least 2 elements.")
if len(self._throughput) != len(self._wave):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if len(self._throughput) != len(self._wave):
if self._throughput.shape != self._wave.shape:

It may be more robust to check the shape directly?

Comment thread jax_galsim/bandpass.py Outdated
Comment on lines +60 to +61
self._blue_limit = float(blue_limit) if blue_limit is not None else float(self._wave[0])
self._red_limit = float(red_limit) if red_limit is not None else float(self._wave[-1])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a utility function cast_to_float in jax_galsim.core.utils. I'd use that here.

Comment thread jax_galsim/bandpass.py
Comment on lines +63 to +72
# Precompute effective wavelength at construction time so it is a
# concrete Python float and can be used as a static value under JIT.
_w = jnp.linspace(self._blue_limit, self._red_limit, 512)
_t = jnp.interp(_w, self._wave, self._throughput)
_norm = jnp.trapezoid(_t, _w)
self._effective_wavelength_val = (
float(jnp.trapezoid(_w * _t, _w) / _norm)
if float(_norm) > 0
else float(0.5 * (self._blue_limit + self._red_limit))
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here self._throughput could be a traced array in which case this precompute step won't work.

Comment thread jax_galsim/bandpass.py
# Precompute effective wavelength at construction time so it is a
# concrete Python float and can be used as a static value under JIT.
_w = jnp.linspace(self._blue_limit, self._red_limit, 512)
_t = jnp.interp(_w, self._wave, self._throughput)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This interpolation is linear. There is a slightly better akima interpolant that is fast to compute on-the-fly. See jax_galsim.core.interpolate.

Comment thread jax_galsim/sed.py
Comment on lines +238 to +239
and jnp.array_equal(self._flux, other._flux)
and jnp.array_equal(self._redshift, other._redshift)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
and jnp.array_equal(self._flux, other._flux)
and jnp.array_equal(self._redshift, other._redshift)
& jnp.array_equal(self._flux, other._flux)
& jnp.array_equal(self._redshift, other._redshift)

Copy link
Copy Markdown
Collaborator

@beckermr beckermr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a first pass at review. There are a few big structural things that need addressing before we can review it again.

  1. All classes, methods, and functions need to be wrapped with the @implements decorator. Any changes specific to JAX need to go in the lax_description keyword.
  2. Please add the new objects to the API tests in test_api.py. From my read of the code, many of those API tests will fail.
  3. The Bandpass class cannot cache the effective wavelength on init as a python float since the throughput could be traced.
  4. The JAX tests for Bandpass don't appear the test any gradients or JIT with respect to throughput and instead always use a tophat. I think this is possibly why the init stuff from 3 works in the tests, but won't work in general.
  5. If an array/computation is meant to only be used at compile time, then we should use numpy and wrap it in an ensure_compile_time_eval block.

@beckermr
Copy link
Copy Markdown
Collaborator

pre-commit.ci autofix

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file needs to be renamed to test_chromatic_jax.py. Otherwise, pytest barfs.

Comment on lines +20 to +21
# Enable float64 for accuracy
jax.config.update("jax_enable_x64", True)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be removed. We set float64 or float32 globally for different testing regimes.

Suggested change
# Enable float64 for accuracy
jax.config.update("jax_enable_x64", True)

Comment on lines +23 to +24
import jax_galsim as jgal
from jax_galsim.chromatic import ChromaticAtmosphere, ChromaticConvolution
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use the jgs idiom for importing jax_galsim.

Suggested change
import jax_galsim as jgal
from jax_galsim.chromatic import ChromaticAtmosphere, ChromaticConvolution
import jax_galsim as jgs
from jax_galsim.chromatic import ChromaticAtmosphere, ChromaticConvolution

Comment on lines +117 to +118
- "could not convert string to float"
- "is not a valid JAX array type"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These errors should not be ignored. Instead, we need to adjust the JAX code and/or test suite to account for them.

@codspeed-hq
Copy link
Copy Markdown

codspeed-hq Bot commented May 28, 2026

Merging this PR will improve performance by 34.71%

⚠️ Unknown Walltime execution environment detected

Using the Walltime instrument on standard Hosted Runners will lead to inconsistent data.

For the most accurate results, we recommend using CodSpeed Macro Runners: bare-metal machines fine-tuned for performance measurement consistency.

⚠️ Different runtime environments detected

Some benchmarks with significant performance changes were compared across different runtime environments,
which may affect the accuracy of the results.

Open the report in CodSpeed to investigate

⚡ 1 improved benchmark
✅ 35 untouched benchmarks

Performance Changes

Mode Benchmark BASE HEAD Efficiency
Simulation test_benchmark_invert_ab_noraise[run] 1,084.2 µs 804.9 µs +34.71%

Tip

Curious why this is faster? Comment @codspeedbot explain why this is faster on this PR, or directly use the CodSpeed MCP with your agent.


Comparing MaxRonce:chromatic_psf (539826b) with main (4c1bf06)

Open in CodSpeed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants