FEAT - Add chromatic psf and the corresponding tests (Issue #251)#253
FEAT - Add chromatic psf and the corresponding tests (Issue #251)#253MaxRonce wants to merge 5 commits into
Conversation
Know issue
|
| Chromatic Profiles | ||
| ================== |
There was a problem hiding this comment.
| Chromatic Profiles | |
| ================== | |
| Wavelength-dependent Profiles | |
| ============================= |
Let's use this title to match the upstream documentation. The galsim docs further use separate pages for the SEDS, Bandpasses, and Chromatic objects, but we can skip that bit.
|
|
||
|
|
||
| @register_pytree_node_class | ||
| class Bandpass: |
There was a problem hiding this comment.
Let's not repeat doc strings from upstream and instead use the implements decorator, with any caveats put into the lax_description keyword.
|
|
||
| if self._wave.ndim != 1 or len(self._wave) < 2: | ||
| raise ValueError("wave must be a 1-D array with at least 2 elements.") | ||
| if len(self._throughput) != len(self._wave): |
There was a problem hiding this comment.
| if len(self._throughput) != len(self._wave): | |
| if self._throughput.shape != self._wave.shape: |
It may be more robust to check the shape directly?
| self._blue_limit = float(blue_limit) if blue_limit is not None else float(self._wave[0]) | ||
| self._red_limit = float(red_limit) if red_limit is not None else float(self._wave[-1]) |
There was a problem hiding this comment.
There is a utility function cast_to_float in jax_galsim.core.utils. I'd use that here.
| # Precompute effective wavelength at construction time so it is a | ||
| # concrete Python float and can be used as a static value under JIT. | ||
| _w = jnp.linspace(self._blue_limit, self._red_limit, 512) | ||
| _t = jnp.interp(_w, self._wave, self._throughput) | ||
| _norm = jnp.trapezoid(_t, _w) | ||
| self._effective_wavelength_val = ( | ||
| float(jnp.trapezoid(_w * _t, _w) / _norm) | ||
| if float(_norm) > 0 | ||
| else float(0.5 * (self._blue_limit + self._red_limit)) | ||
| ) |
There was a problem hiding this comment.
Here self._throughput could be a traced array in which case this precompute step won't work.
| # Precompute effective wavelength at construction time so it is a | ||
| # concrete Python float and can be used as a static value under JIT. | ||
| _w = jnp.linspace(self._blue_limit, self._red_limit, 512) | ||
| _t = jnp.interp(_w, self._wave, self._throughput) |
There was a problem hiding this comment.
This interpolation is linear. There is a slightly better akima interpolant that is fast to compute on-the-fly. See jax_galsim.core.interpolate.
| and jnp.array_equal(self._flux, other._flux) | ||
| and jnp.array_equal(self._redshift, other._redshift) |
There was a problem hiding this comment.
| and jnp.array_equal(self._flux, other._flux) | |
| and jnp.array_equal(self._redshift, other._redshift) | |
| & jnp.array_equal(self._flux, other._flux) | |
| & jnp.array_equal(self._redshift, other._redshift) |
beckermr
left a comment
There was a problem hiding this comment.
I took a first pass at review. There are a few big structural things that need addressing before we can review it again.
- All classes, methods, and functions need to be wrapped with the
@implementsdecorator. Any changes specific to JAX need to go in thelax_descriptionkeyword. - Please add the new objects to the API tests in
test_api.py. From my read of the code, many of those API tests will fail. - The Bandpass class cannot cache the effective wavelength on init as a python float since the throughput could be traced.
- The JAX tests for Bandpass don't appear the test any gradients or JIT with respect to throughput and instead always use a tophat. I think this is possibly why the init stuff from 3 works in the tests, but won't work in general.
- If an array/computation is meant to only be used at compile time, then we should use
numpyand wrap it in anensure_compile_time_evalblock.
|
pre-commit.ci autofix |
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
This file needs to be renamed to test_chromatic_jax.py. Otherwise, pytest barfs.
| # Enable float64 for accuracy | ||
| jax.config.update("jax_enable_x64", True) |
There was a problem hiding this comment.
This needs to be removed. We set float64 or float32 globally for different testing regimes.
| # Enable float64 for accuracy | |
| jax.config.update("jax_enable_x64", True) |
| import jax_galsim as jgal | ||
| from jax_galsim.chromatic import ChromaticAtmosphere, ChromaticConvolution |
There was a problem hiding this comment.
We use the jgs idiom for importing jax_galsim.
| import jax_galsim as jgal | |
| from jax_galsim.chromatic import ChromaticAtmosphere, ChromaticConvolution | |
| import jax_galsim as jgs | |
| from jax_galsim.chromatic import ChromaticAtmosphere, ChromaticConvolution |
| - "could not convert string to float" | ||
| - "is not a valid JAX array type" |
There was a problem hiding this comment.
These errors should not be ignored. Instead, we need to adjust the JAX code and/or test suite to account for them.
Merging this PR will improve performance by 34.71%
|
| Mode | Benchmark | BASE |
HEAD |
Efficiency | |
|---|---|---|---|---|---|
| ⚡ | Simulation | test_benchmark_invert_ab_noraise[run] |
1,084.2 µs | 804.9 µs | +34.71% |
Tip
Curious why this is faster? Comment @codspeedbot explain why this is faster on this PR, or directly use the CodSpeed MCP with your agent.
Comparing MaxRonce:chromatic_psf (539826b) with main (4c1bf06)
Should go with the issue #251
python -m pytest -c /home/maxime/src/galsim/JAX-GalSim_chromaticPSF/pyproject.toml tests/GalSim/tests/test_sed.py tests/GalSim/tests/test_bandpass.py tests/GalSim/tests/test_chromatic.py -q
Tests seams to pass and results are similar to original galsim
Element by element testing and benchmark was done using this script :
bench_dsps_chromatic.py
The 10e-5 error on the bottom left plot seams to be only numerical bias due to the NWAVES for the fft and not an actuel implementation bias
Some class diagram before and after :
Regular Galsim
Jax Galsim