"""Source code for yt_xarray.accessor.accessor: the ``ds.yt`` xarray Dataset accessor."""

from typing import Any, List, Union

import numpy as np
import xarray as xr
import yt
from numpy.typing import ArrayLike, NDArray
from unyt import unyt_quantity
from yt.data_objects.static_output import Dataset as ytDataset

from yt_xarray.accessor import _xr_to_yt
from yt_xarray.accessor._readers import _get_xarray_reader
from yt_xarray.accessor._xr_to_yt import _load_full_field_from_xr
from yt_xarray.utilities._grid_decomposition import ChunkInfo
from yt_xarray.utilities.logging import ytxr_log


@xr.register_dataset_accessor("yt")
class YtAccessor:
    """The ``ds.yt`` accessor: builds yt datasets and yt plots from an xarray Dataset."""

    def __init__(self, xarray_obj):
        self._obj = xarray_obj
        self._bbox_cart = {}
        self._bbox = {}
        self._active_yt_ds = None

    @property
    def _yt_ds(self):
        # a property for storing the last-loaded yt dataset.
        return self._active_yt_ds

    @_yt_ds.setter
    def _yt_ds(self, yt_ds):
        self._active_yt_ds = yt_ds

    def load_grid(
        self,
        fields: Union[str, List[str]] | None = None,
        geometry: str | None = None,
        use_callable: bool = True,
        sel_dict: dict[str, Any] | None = None,
        sel_dict_type: str = "isel",
        chunksizes: Union[int, ArrayLike] | None = None,
        **kwargs,
    ):
        """
        Initializes a yt gridded dataset for the supplied fields.

        Parameters
        ----------
        fields : str, list[str]
            list of fields to include. If None, will try to use all fields
        geometry : str
            the geometry to pass to yt.load_uniform_grid. If not provided, will
            attempt to infer.
        use_callable : bool
            if True (default), then the yt dataset will utilize links to the open
            xarray Dataset handle to avoid copying data.
        sel_dict: dict
            an optional selection dictionary to apply to the fields before yt
            dataset initialization
        sel_dict_type: str
            either "isel" (default) or "sel" to indicate index or value selection
            for sel_dict.
        chunksizes: int or ArrayLike
            if set, will decompose the grid into multiple grids with grid
            dimensions of chunksizes. Can be a single integer (same chunksize in
            each dimension) or an ArrayLike object of the same length as the
            number of dimensions.
        kwargs :
            any additional keyword arguments to pass to yt.load_uniform_grid

        Returns
        -------
        yt StreamDataset

        Raises
        ------
        ValueError
            if the yt geometry or the length_unit cannot be determined.
        NotImplementedError
            if chunksizes is set for a stretched grid.
        """
        if fields is None:
            # might as well try!
            fields = list(self._obj.data_vars)
        if isinstance(fields, str):
            fields = [fields]

        sel_info = _xr_to_yt.Selection(
            self._obj,
            fields=fields,
            sel_dict=sel_dict,
            sel_dict_type=sel_dict_type,
        )

        if geometry is None:
            geometry = self.geometry
        geometry = _xr_to_yt._determine_yt_geomtype(geometry, sel_info.selected_coords)
        if geometry is None:
            raise ValueError(
                "Cannot determine yt geometry type, please provide "
                "geometry = 'geographic', 'internal_geographic' or 'cartesian'"
            )

        length_unit: str | float | None
        if "length_unit" in kwargs:
            length_unit = kwargs.pop("length_unit")
        else:
            length_unit = self._infer_length_unit()
        if length_unit is None:
            raise ValueError(
                "cannot determine length_unit, please provide as"
                " a keyword argument."
            )

        axis_order = sel_info.yt_coord_names
        geom = (geometry, axis_order)

        simtime = sel_info.selected_time
        if isinstance(simtime, np.datetime64):
            # yt cannot ingest a np.datetime64 directly, so express the time as
            # nanoseconds before/after 1970-01-01. Cast to ns explicitly first:
            # int() of a datetime64 counts in the value's *own* unit (e.g. days
            # for datetime64[D]), which would otherwise be mislabeled as ns.
            # Note that the roundtrip will not match the original exactly.
            ns_since_epoch = simtime.astype("datetime64[ns]").astype(np.int64)
            simtime = unyt_quantity(int(ns_since_epoch), "ns")
        kwargs.update({"sim_time": simtime})

        if chunksizes is None:
            ds_yt = _load_single_grid(
                self._obj, sel_info, geom, use_callable, fields, length_unit, **kwargs
            )
        elif sel_info.grid_type == _xr_to_yt._GridType.STRETCHED:
            raise NotImplementedError(
                "Stretched grids cannot set the chunksizes argument."
            )
        else:
            ds_yt = _load_chunked_grid(
                self._obj,
                sel_info,
                geom,
                use_callable,
                fields,
                length_unit,
                chunksizes,
                **kwargs,
            )

        self._yt_ds = ds_yt
        return ds_yt

    def _infer_length_unit(self) -> str | float | None:
        # geodetic datasets use a unitless length of 1.0; otherwise fall back
        # to the netcdf-style global attribute when present.
        if self.geometry == "geodetic":
            return 1.0
        elif hasattr(self._obj, "geospatial_vertical_units"):
            # some netcdf conventions have this!
            return str(self._obj.geospatial_vertical_units)
        return None

    # cached geometry; None until set_geometry is called or it is inferred
    _geometry = None

    def set_geometry(self, geometry: str):
        """
        Set the geometry for the dataset.

        Parameters
        ----------
        geometry: str
            the dataset's coordinate type. See yt_xarray.valid_geometries
            for possible values.
        """
        self._geometry = _xr_to_yt._validate_geometry(geometry)

    @property
    def geometry(self) -> str:
        """the dataset geometry"""
        if self._geometry is None:
            self._geometry = self._infer_geometry()
        return self._geometry

    def _infer_geometry(self) -> str:
        # try to infer if we have a geodetic dataset. the differentiation
        # between internal and not used by yt (internal_geographic vs geographic)
        # is not applied here, but is when converting to a yt dataset (see
        # _xr_to_yt._determine_yt_geomtype). Default is to assume cartesian.
        geodetic_names = (
            _xr_to_yt._coord_aliases["latitude"] + _xr_to_yt._coord_aliases["longitude"]
        )
        ctype = "cartesian"
        for coord in self._obj.coords:
            if coord.lower() in geodetic_names:
                # a single latitude/longitude alias is enough to decide
                ctype = "geodetic"
                break
        ytxr_log.info(
            f"Inferred geometry type is {ctype}. To override, use ds.yt.set_geometry"
        )
        return ctype

    @property
    def _coord_list(self):
        # a list of all dataset coordinates. Note that dataset fields
        # may use a different ordering!!!
        return list(self._obj.coords.keys())

    def get_bbox(
        self,
        field: str,
        sel_dict: dict[str, Any] | None = None,
        sel_dict_type: str = "isel",
    ) -> NDArray:
        """
        return the bounding box array for a field, with possible selections

        Parameters
        ----------
        field
            the field to check the bounding box for
        sel_dict: dict
            an optional selection dictionary to apply to the fields before yt
            dataset initialization
        sel_dict_type: str
            either "isel" (default) or "sel" to indicate index or value
            selection for sel_dict.

        Returns
        -------
        np.ndarray
            an array with shape (3, 2) with the min, max values for each
            dimension of the coordinates of a field.
        """
        sel_info = _xr_to_yt.Selection(
            self._obj, fields=[field], sel_dict=sel_dict, sel_dict_type=sel_dict_type
        )
        return sel_info.selected_bbox

    def SlicePlot(self, normal, field, **im_kwargs):
        """
        Wrapper of ``yt.SlicePlot``. For off-axis plots, first construct a yt
        dataset object with ``ds.yt.load_grid`` and then use ``yt.SlicePlot``

        Parameters
        ----------
        normal: str or int
            The normal to the slice.
        field: str
            The field to plot
        im_kwargs
            any additional keyword arguments accepted by ``yt.SlicePlot``

        Returns
        -------
        yt PlotWindow
        """
        return _yt_2D_plot(yt.SlicePlot, self._obj, normal, field, **im_kwargs)

    def ProjectionPlot(self, normal, field, **im_kwargs):
        """
        Wrapper of ``yt.ProjectionPlot``. For off-axis plots, first construct a
        yt dataset object with ``ds.yt.load_grid`` and then use
        ``yt.ProjectionPlot``

        Parameters
        ----------
        normal: str or int or 3-element tuple
            The normal to the slice.
        field: str
            The field to plot
        im_kwargs
            any of the keyword arguments accepted by ``yt.ProjectionPlot``

        Returns
        -------
        yt PlotWindow
        """
        return _yt_2D_plot(yt.ProjectionPlot, self._obj, normal, field, **im_kwargs)

    def _fields_for_plot(self, fields: List[str], weight_field) -> List[str]:
        """Return the unique xarray variables to load for a profile-style plot.

        ``weight_field`` is appended only when it names a data variable of the
        wrapped Dataset; any other weight (e.g. a yt field tuple) is left for
        yt to resolve.
        """
        needed = list(set(fields))
        if (
            weight_field is not None
            and weight_field in self._obj.data_vars
            and weight_field not in needed
        ):
            needed.append(weight_field)
        return needed

    def PhasePlot(
        self,
        x_field: str,
        y_field: str,
        z_fields: Union[str, List[str]],
        weight_field: str | None = None,
        x_bins: int = 128,
        y_bins: int = 128,
        accumulation: bool = False,
        fractional: bool = False,
        fontsize: int | float = 18,
        figure_size: int | float = 8.0,
        shading: str = "nearest",
    ):
        """
        Construct a `yt.PhasePlot`.

        Parameters
        ----------
        x_field : str
            The x binning field for the profile.
        y_field : str
            The y binning field for the profile.
        z_fields : str or list
            The field or fields to be profiled.
        weight_field : str
            The weight field for calculating weighted averages. If None, the
            profile values are the sum of the field values within the bin.
            Otherwise, the values are a weighted average. If it names a
            variable of this Dataset, that variable is loaded alongside the
            other fields. Default : None
        x_bins : int
            The number of bins in x field for the profile. Default: 128.
        y_bins : int
            The number of bins in y field for the profile. Default: 128.
        accumulation : bool or list of bools
            If True, the profile values for a bin n are the cumulative sum of
            all the values from bin 0 to n. If -True, the sum is reversed so
            that the value for bin n is the cumulative sum from bin N (total
            bins) to n. A list of values can be given to control the summation
            in each dimension independently. Default: False.
        fractional :
            If True the profile values are divided by the sum of all the
            profile data such that the profile represents a probability
            distribution function.
        fontsize : int
            Font size for all text in the plot. Default: 18.
        figure_size : int
            Size in inches of the image. Default: 8 (8x8)
        shading : str
            This argument is directly passed down to
            matplotlib.axes.Axes.pcolormesh see
            https://matplotlib.org/3.3.1/gallery/images_contours_and_fields/pcolormesh_grids.html#sphx-glr-gallery-images-contours-and-fields-pcolormesh-grids-py  # noqa
            Default: 'nearest'

        Returns
        -------
        yt PhasePlot
        """
        if isinstance(z_fields, str):
            z_fields = [z_fields]

        fields_needed = self._fields_for_plot([x_field, y_field] + z_fields, weight_field)
        ds = _get_default_ds(self._obj, fields_needed)

        pp_x_field = ("stream", x_field)
        pp_y_field = ("stream", y_field)
        pp_zfields = [("stream", z_f) for z_f in z_fields]

        return yt.PhasePlot(
            ds,
            pp_x_field,
            pp_y_field,
            pp_zfields,
            weight_field=weight_field,
            x_bins=x_bins,
            y_bins=y_bins,
            accumulation=accumulation,
            fractional=fractional,
            fontsize=fontsize,
            figure_size=figure_size,
            shading=shading,
        )

    def ProfilePlot(
        self,
        x_field,
        y_fields,
        weight_field=None,
        n_bins=64,
        accumulation=False,
        fractional=False,
        label=None,
        plot_spec=None,
        x_log=True,
        y_log=True,
    ):
        """
        Construct a `yt.ProfilePlot`.

        Parameters
        ----------
        x_field : str
            The binning field for the profile.
        y_fields : str or list
            The field or fields to be profiled.
        weight_field : str
            The weight field for calculating weighted averages. If None, the
            profile values are the sum of the field values within the bin.
            Otherwise, the values are a weighted average. If it names a
            variable of this Dataset, that variable is loaded alongside the
            other fields. Default : None
        n_bins : int
            The number of bins in the profile. Default: 64.
        accumulation : bool
            If True, the profile values for a bin N are the cumulative sum of
            all the values from bin 0 to N. Default: False.
        fractional :
            If True the profile values are divided by the sum of all the
            profile data such that the profile represents a probability
            distribution function.
        label : str or list of strings
            If a string, the label to be put on the line plotted. If a list,
            this should be a list of labels for each profile to be overplotted.
            Default: None.
        plot_spec : dict or list of dicts
            A dictionary or list of dictionaries containing plot keyword
            arguments. For example, dict(color="red", linestyle=":").
            Default: None.
        x_log : bool
            Whether the x_axis should be plotted with a logarithmic scaling
            (True), or linear scaling (False). Default: True.
        y_log : dict or bool
            A dictionary containing field:boolean pairs, setting the
            logarithmic property for that field. May be overridden after
            instantiation using set_log. A single boolean can be passed to
            signify all fields should use logarithmic (True) or linear scaling
            (False). Default: True.

        Returns
        -------
        yt ProfilePlot
        """
        # y_fields may be a single name or a list of names; flatten before
        # de-duplicating (a list inside a set() call is unhashable).
        y_list = [y_fields] if isinstance(y_fields, str) else list(y_fields)
        fields_needed = self._fields_for_plot([x_field] + y_list, weight_field)

        ds = _get_default_ds(self._obj, fields_needed)
        ad = ds.all_data()
        return yt.ProfilePlot(
            ad,
            x_field,
            y_fields,
            weight_field=weight_field,
            n_bins=n_bins,
            accumulation=accumulation,
            fractional=fractional,
            label=label,
            plot_spec=plot_spec,
            x_log=x_log,
            y_log=y_log,
        )
def _load_single_grid(
    ds_xr, sel_info, geom, use_callable, fields, length_unit, **kwargs
) -> ytDataset:
    """Build a yt dataset made of a single grid from the selected xarray fields.

    Parameters
    ----------
    ds_xr : xr.Dataset
        the source Dataset
    sel_info : _xr_to_yt.Selection
        the validated selection for ``fields``
    geom : tuple
        (yt geometry name, axis order)
    use_callable : bool
        if True, fields are linked through a reader instead of loaded up front
    fields : list[str]
        the fields to include
    length_unit : str or float
        passed to yt as ``length_unit``
    kwargs :
        extra keyword arguments for the yt loader

    Returns
    -------
    a yt dataset: ``yt.load_uniform_grid`` for stretched grids, otherwise
    ``yt.load_amr_grids`` with one grid.
    """
    geometry = geom[0]
    interp_required, data_shp, bbox = sel_info.interp_validation(geometry)

    if sel_info.ndims == 2:
        # promote 2D data to 3D with a length-1 third axis spanning [-0.5, 0.5]
        axis_order = _xr_to_yt._add_3rd_axis_name(geometry, geom[1])
        geom = (geometry, axis_order)
        data_shp = data_shp + (1,)
        bbox = np.vstack([bbox, [-0.5, 0.5]])

    # grid description for the single (level-0) grid, built from the
    # possibly-promoted shape and bounding box
    g_dict = sel_info.grid_dict.copy()
    g_dict["dimensions"] = data_shp
    g_dict["left_edge"] = bbox[:, 0]
    g_dict["right_edge"] = bbox[:, 1]

    if use_callable:
        reader = _get_xarray_reader(ds_xr, sel_info, interp_required=interp_required)

    data = {}
    for field in fields:
        units = sel_info.units[field]
        if use_callable:
            data[field] = (reader, units)
        else:
            vals = _load_full_field_from_xr(
                ds_xr, field, sel_info, interp_required=interp_required
            )
            data[field] = (vals, units)

    if sel_info.grid_type == _xr_to_yt._GridType.STRETCHED:
        return yt.load_uniform_grid(
            data,
            data_shp,
            geometry=geom[0],
            bbox=bbox,
            length_unit=length_unit,
            cell_widths=sel_info.cell_widths,
            axis_order=geom[1],
            **kwargs,
        )

    data.update(g_dict)
    grid_data = [data]
    return yt.load_amr_grids(
        grid_data,
        data_shp,
        geometry=geom[0],
        bbox=bbox,
        length_unit=length_unit,
        axis_order=geom[1],
        **kwargs,
    )


def _load_chunked_grid(
    ds_xr, sel_info, geom, use_callable, fields, length_unit, chunksizes, **kwargs
) -> ytDataset:
    """Build a yt dataset by decomposing the selection into multiple grids.

    Parameters are as for ``_load_single_grid``, plus ``chunksizes``: a single
    integer (Python or numpy) applied to every dimension, or a sequence with one
    entry per dimension.

    Raises
    ------
    ValueError
        if a sequence ``chunksizes`` does not match the dimensionality.
    NotImplementedError
        if the selected fields are not 3D.
    """
    if isinstance(chunksizes, (int, np.integer)):
        # np.integer scalars are not ``int`` instances, so accept both here
        chunksizes = np.array((int(chunksizes),) * sel_info.ndims)
    elif len(chunksizes) != sel_info.ndims:
        raise ValueError(
            f"The number of elements in chunksizes {len(chunksizes)} "
            f"must match the dimensionality {sel_info.ndims}"
        )
    else:
        chunksizes = np.asarray(chunksizes, dtype=int)

    if sel_info.ndims != 3:
        raise NotImplementedError(
            "Can only load a chunked grid with 3D fields at present."
        )

    geometry = geom[0]
    # get the global shape and bounding box
    interp_required, data_shp, bbox = sel_info.interp_validation(geometry)
    # note: if interp_required, data_shp is number of cells
    # otherwise it is number of nodes (which are treated as new cell centers).
    # the bbox will already account for this as well.

    # initialize the global starting index
    si = np.array([0, 0, 0], dtype=int)
    si = sel_info.starting_indices + si

    # do some grid/chunk counting
    chnkinfo = ChunkInfo(data_shp, chunksizes, starting_index_offset=si)
    ytxr_log.info(f"Constructing a yt chunked grid with {chnkinfo.n_tots} chunks.")

    # select field for grabbing coordinate arrays -- fields should all be
    # verified by now
    fld = fields[0]
    cnames = sel_info.selected_coords

    if not interp_required:
        dxyz = np.array([cell_wids[0] for cell_wids in sel_info.cell_widths])

    # build arrays of the left_edges, right_edges and dimensions. these will
    # be organized by dimension first (all of the left edges in x, all
    # the left edges in y, etc.)
    left_edges = []
    right_edges = []
    subgrid_sizes = []
    subgrid_start = []
    subgrid_end = []
    for idim in range(sel_info.ndims):
        si_0 = chnkinfo.si[idim]
        ei_0 = chnkinfo.ei[idim]
        c = cnames[idim]
        rev_ax = sel_info.reverse_axis[idim]
        if not rev_ax:
            le_0 = ds_xr[fld].coords[c].isel({c: si_0}).values
            if not interp_required:
                # move the edges so the node is now a cell center
                le_0 = le_0 - dxyz[idim] / 2.0
            # bbox value below already accounts for interp_required
            max_val = bbox[idim, 1]
            re_0 = np.concatenate([le_0[1:], [max_val]])
        else:
            # NOTE(review): for a reversed axis the edges are built in reversed
            # chunk order, but the sizes/start/end indices appended below are
            # not reversed, and the node shift uses the same sign as the
            # ascending branch. Confirm against ChunkInfo and the sign
            # convention of sel_info.cell_widths.
            re_0 = ds_xr[fld].coords[c].isel({c: si_0[::-1]}).values
            if not interp_required:
                # move the edges so the node is now a cell center
                re_0 = re_0 - dxyz[idim] / 2.0
            min_val = bbox[idim, 0]
            le_0 = np.concatenate([[min_val], re_0[:-1]])

        # sizes also already account for interp_required
        subgrid_size = chnkinfo.sizes[idim]

        left_edges.append(le_0)
        right_edges.append(re_0)
        subgrid_sizes.append(subgrid_size)
        subgrid_start.append(si_0)
        subgrid_end.append(ei_0)

    # these arrays are ordered by dimension. e.g., left_edges[0] will be
    # all the first-dimension left edges
    left_edges = np.meshgrid(*left_edges, indexing="ij")
    right_edges = np.meshgrid(*right_edges, indexing="ij")
    subgrid_sizes = np.meshgrid(*subgrid_sizes, indexing="ij")
    subgrid_start = np.meshgrid(*subgrid_start, indexing="ij")
    subgrid_end = np.meshgrid(*subgrid_end, indexing="ij")

    # re-organize by grid number so that, e.g., the left_edges are the usual
    # left_edges (left_edges[0] is the min x, y, z of grid 0)
    left_edges = np.column_stack([le.ravel() for le in left_edges])
    right_edges = np.column_stack([re.ravel() for re in right_edges])
    dimensions = np.column_stack([sz.ravel() for sz in subgrid_sizes])
    subgrid_start = np.column_stack([sz.ravel() for sz in subgrid_start])
    subgrid_end = np.column_stack([sz.ravel() for sz in subgrid_end])

    # now ready to build the list of grids
    if use_callable:
        reader = _get_xarray_reader(ds_xr, sel_info, interp_required=interp_required)
    else:
        # load every field once, then slice per grid below
        full_field_vals = {
            field: _load_full_field_from_xr(
                ds_xr, field, sel_info, interp_required=interp_required
            )
            for field in fields
        }

    grid_data = []
    n_grids = len(left_edges)
    for igrid in range(n_grids):
        gdict = {
            "left_edge": left_edges[igrid],
            "right_edge": right_edges[igrid],
            "dimensions": dimensions[igrid],
            "level": 0,
        }
        for field in fields:
            units = sel_info.units[field]
            if use_callable:
                gdict[field] = (reader, units)
            else:
                g_si = subgrid_start[igrid]
                g_ei = subgrid_end[igrid]
                gridvals = full_field_vals[field][
                    g_si[0] : g_ei[0], g_si[1] : g_ei[1], g_si[2] : g_ei[2]
                ]
                gdict[field] = (gridvals, units)
        grid_data.append(gdict)

    return yt.load_amr_grids(
        grid_data,
        data_shp,
        geometry=geom[0],
        bbox=bbox,
        length_unit=length_unit,
        axis_order=geom[1],
        **kwargs,
    )


def _get_default_ds(ds_xr: xr.Dataset, field):
    """Load a yt dataset for ``field`` (a name or list of names) via ``ds.yt``.

    The geometry comes from (and, if unset, is inferred by) the accessor; every
    other ``load_grid`` option keeps its default.
    """
    geom = ds_xr.yt.geometry  # will trigger inference here
    # if the grid were cached this might be easier...
    return ds_xr.yt.load_grid(fields=field, geometry=geom)


def _yt_2D_plot(yt_function, ds_xr: xr.Dataset, normal, field, **im_kwargs):
    """Load ``field`` with default options and call a yt 2D plot function on it.

    ``yt_function`` is called as ``yt_function(ds, normal, ("stream", field),
    **im_kwargs)``.
    """
    ds = _get_default_ds(ds_xr, field)
    # TODO: validate ``normal`` before handing it to yt
    return yt_function(ds, normal, ("stream", field), **im_kwargs)