Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions src/vip_hci/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
from . import preproc
from . import config
from . import fits
from . import invprob
from . import psfsub
from . import fm
from . import metrics
from . import stats
from . import var
from . import objects
from .vip_ds9 import *
import importlib as _importlib

_submodules = [
"config",
"fits",
"fm",
"greedy",
"invprob",
"metrics",
"objects",
"preproc",
"psfsub",
"stats",
"var",
"vip_ds9",
]


def __getattr__(name: str):
if name == '__version__':
if name in _submodules:
return _importlib.import_module(f".{name}", __name__)
if name == "__version__":
from importlib.metadata import version
return version('vip_hci')
return version("vip_hci")
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
return _submodules + ["__version__"]
55 changes: 49 additions & 6 deletions src/vip_hci/preproc/rescaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
"scale_fft",
]

import numpy as np
import warnings
from multiprocessing import cpu_count

import numpy as np

try:
import cv2
Expand All @@ -27,6 +29,7 @@

from scipy.ndimage import geometric_transform, zoom
from scipy.optimize import minimize
from ..config.utils_conf import pool_map, iterable
from ..var import frame_center, get_square, cube_filter_highpass
from .subsampling import cube_collapse
from .recentering import frame_shift
Expand Down Expand Up @@ -329,6 +332,7 @@ def cube_rescaling_wavelengths(
interpolation="lanczos4",
collapse="median",
pad_mode="reflect",
nproc=1,
):
"""
Scale/Descale a cube by scal_list, with padding. Can deal with NaN values.
Expand Down Expand Up @@ -404,6 +408,9 @@ def cube_rescaling_wavelengths(
pads with the wrap of the vector along the axis. The first
values are used to pad the end and the end values are used to
pad the beginning
nproc : int, optional
Number of processes to use for parallel frame rescaling. Default is 1.
Set to None to use half the available CPUs.

Returns
-------
Expand Down Expand Up @@ -446,7 +453,7 @@ def cube_rescaling_wavelengths(

# (de)scale the cube, so that a planet would now move radially
cube = cube_rescaling(big_cube, scal_list, ref_xy=(cx, cy), imlib=imlib,
interpolation=interpolation)
interpolation=interpolation, nproc=nproc)
frame = cube_collapse(cube, collapse)

if inverse and max_sc > 1:
Expand Down Expand Up @@ -488,6 +495,14 @@ def _scale_func(output_coords, ref_xy=0, scaling=1.0, scale_y=None,
)


def _frame_rescaling_mp(frame, scale, ref_xy, imlib, interpolation,
scale_y, scale_x):
"""Multiprocessing helper for cube_rescaling."""
return frame_rescaling(frame, ref_xy=ref_xy, scale=scale, imlib=imlib,
interpolation=interpolation, scale_y=scale_y,
scale_x=scale_x)


def frame_rescaling(
array,
ref_xy=None,
Expand Down Expand Up @@ -675,6 +690,7 @@ def cube_rescaling(
interpolation="lanczos4",
scaling_y=None,
scaling_x=None,
nproc=1,
):
"""
Rescale a cube by factors from ``scaling_list`` wrt a position.
Expand All @@ -699,22 +715,27 @@ def cube_rescaling(
scaling_x : 1D-array or list
Scaling factor only for x axis. If provided, it takes priority on
scaling_list.
nproc : int, optional
Number of processes to use for parallel frame rescaling. Default is 1.
Set to None to use half the available CPUs.

Returns
-------
array_sc : numpy ndarray
Resulting cube with rescaled frames.

"""

if array.ndim != 3:
raise TypeError("Input array is not a cube or 3d array")

array_sc = []
if scaling_list is None:
scaling_list = [None] * array.shape[0]
for i in range(array.shape[0]):
array_sc.append(

if nproc is None:
nproc = cpu_count() // 2

if nproc == 1:
array_sc = [
frame_rescaling(
array[i],
ref_xy=ref_xy,
Expand All @@ -724,7 +745,22 @@ def cube_rescaling(
scale_y=scaling_y,
scale_x=scaling_x,
)
for i in range(array.shape[0])
]
else:
array_sc = pool_map(
nproc,
_frame_rescaling_mp,
iterable(array),
iterable(scaling_list),
ref_xy,
imlib,
interpolation,
scaling_y,
scaling_x,
reuse_pool=True, # keep pool open because we are often iterating over an ADI cube
)

return np.array(array_sc)


Expand Down Expand Up @@ -1179,3 +1215,10 @@ def scale_fft(array, scale, ori_dim=False):
array_resc = scaled

return array_resc







9 changes: 6 additions & 3 deletions src/vip_hci/psfsub/pca_fullfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,7 @@ def _adimsdi_singlepca(
print("Rescaling the spectral channels to align the speckles")
for i in Progressbar(range(n), verbose=verbose):
cube_resc = scwave(cube[:, i, :, :], scale_list, imlib=imlib2,
interpolation=interpolation)[0]
interpolation=interpolation, nproc=nproc)[0]
if crop_ifs:
cube_resc = cube_crop_frames(cube_resc, size=y_in, verbose=False)
big_cube.append(cube_resc)
Expand All @@ -1104,7 +1104,7 @@ def _adimsdi_singlepca(
print(msg)
for i in Progressbar(range(nr), verbose=verbose):
cube_resc = scwave(cube_ref[:, i, :, :], scale_list, imlib=imlib2,
interpolation=interpolation)[0]
interpolation=interpolation, nproc=nproc)[0]
if crop_ifs:
cube_resc = cube_crop_frames(cube_resc, size=y_in,
verbose=False)
Expand Down Expand Up @@ -1181,6 +1181,7 @@ def _adimsdi_singlepca(
imlib=imlib2,
interpolation=interpolation,
collapse=collapse_ifs,
nproc=nproc
)
cube_desc_residuals[:, i] = res_i[0]
resadi_cube[i] = res_i[1]
Expand Down Expand Up @@ -1490,7 +1491,8 @@ def _adimsdi_doublepca_ifs(
frame_i = cube_collapse(multispec_fr[idx_ini:idx_fin])
else:
cube_resc = scwave(
multispec_fr, scale_list, imlib=imlib, interpolation=interpolation
multispec_fr, scale_list, imlib=imlib, interpolation=interpolation,
nproc=1, # already inside a pool_map worker so we don't want nested multiprocessing
)[0]

if mask_rdi is None:
Expand Down Expand Up @@ -1523,6 +1525,7 @@ def _adimsdi_doublepca_ifs(
imlib=imlib,
interpolation=interpolation,
collapse=collapse,
nproc=1, # already inside a pool_map worker so we don't want nested multiprocessing
)
if mask_center_px:
frame_i = mask_circle(frame_i, mask_center_px)
Expand Down
Loading