{ "cells": [ { "cell_type": "markdown", "id": "4be05348-4be8-406d-97fd-49f43c1c6aa0", "metadata": {}, "source": [ "# Custom interpolation functions in embedded datasets\n", "\n", "The Cartesian yt dataset returned by `yt_xarray.transformations.build_interpolated_cartesian_ds` consists of a Cartesian grid (or set of refined grids) that samples the underlying xarray dataset as needed. In the final step of this process, the native coordinates of the Cartesian grid points are calculated (using the supplied `Transformer` object -- see previous notebooks) and then the underlying xarray dataset is sampled at those native coordinates. Because the requested native coordinates are unlikely to coincide exactly with coordinate values in the xarray dataset, interpolation is required, and there are several ways to control it:\n", "\n", "The keyword argument `interp_method` can be set to `'nearest'` (the default) or `'interpolate'`. If `'nearest'`, the value nearest to the requested point is used; if `'interpolate'`, a linear interpolation is used. In both cases, underlying xarray methods do the work: `'nearest'` uses `xr.DataArray.sel(interp_dict, method=\"nearest\")` and `'interpolate'` uses `xr.DataArray.interp` (which relies on the scipy N-D linear interpolation method). \n", "\n", "Additionally, you can supply your own interpolation function to `build_interpolated_cartesian_ds`. The function will be called with ``interp_func(data=data_array, coords=eval_coords)``, where\n", "``data_array`` is an xarray ``DataArray`` and ``eval_coords`` is a list of 1D\n", "`np.ndarray` objects ordered by the `native_coords` of the supplied transformer. The function must return an `np.ndarray` of the same shape as each array in ``eval_coords``. \n", "\n", "As an example, we'll build a custom interpolation function similar to the implementation used when you set `interp_method='nearest'`. 
\n", "\n", "Let's first import all the modules we need then initialize a transformer:" ] }, { "cell_type": "code", "execution_count": 1, "id": "9e578ab2-9afc-4ae3-a1af-e80ad70b44ee", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('altitude', 'latitude', 'longitude')" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np \n", "import numpy.typing as npt\n", "import xarray as xr\n", "import yt_xarray\n", "from yt_xarray.transformations import GeocentricCartesian, build_interpolated_cartesian_ds\n", "from yt_xarray.sample_data import load_random_xr_data\n", "from typing import List \n", "import yt\n", "\n", "gc = GeocentricCartesian(radial_type='altitude', r_o=6371.)\n", "gc.native_coords" ] }, { "cell_type": "markdown", "id": "75d10c1d-7d48-4e0f-9954-105b8c1a8cf6", "metadata": {}, "source": [ "The transformer `native_coords` tuple above indicates the order in which coordinate arrays will be supplied. \n", "\n", "To write an interpolation function that samples an `xarray.DataArray` at its nearest points:" ] }, { "cell_type": "code", "execution_count": 2, "id": "d147e64a-155a-42b1-89af-ef5a97efcb8c", "metadata": {}, "outputs": [], "source": [ "def my_interp(data: xr.DataArray = None, \n", " coords: List[npt.NDArray] = None) -> npt.NDArray:\n", " print(\"hello from my_interp!\")\n", " c0 = coords[0] # altitude \n", " c1 = coords[1] # latitude\n", " c2 = coords[2] # longitude\n", " \n", " interp_dict = {\n", " 'altitude': xr.DataArray(c0, dims='points'), \n", " 'latitude': xr.DataArray(c1, dims='points'), \n", " 'longitude': xr.DataArray(c2, dims='points'),\n", " } \n", " vals = data.sel(interp_dict, method='nearest')\n", " return vals.values" ] }, { "cell_type": "markdown", "id": "b15d6b9f-1ffa-44c6-bdd8-0f7ba99b4565", "metadata": {}, "source": [ "Now let's get a sample dataset to use:" ] }, { "cell_type": "code", "execution_count": 3, "id": "b7ad9c13-f74e-46e7-9eff-2e8b4ff3b1c9", "metadata": {}, 
"outputs": [], "source": [ "ds = load_random_xr_data({'field1':('altitude', 'latitude', 'longitude')},\n", " {'altitude': (0, 2000, 10), \n", " 'latitude': (0, 20, 15), \n", " 'longitude': (10, 20, 12)}, \n", " length_unit='m')" ] }, { "cell_type": "code", "execution_count": 4, "id": "1035014e-bc33-4354-8690-940097be0756", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.Dataset>\n",
"Dimensions: (altitude: 10, latitude: 15, longitude: 12)\n",
"Coordinates:\n",
" * altitude (altitude) float64 0.0 222.2 444.4 ... 1.556e+03 1.778e+03 2e+03\n",
" * latitude (latitude) float64 0.0 1.429 2.857 4.286 ... 17.14 18.57 20.0\n",
" * longitude (longitude) float64 10.0 10.91 11.82 12.73 ... 18.18 19.09 20.0\n",
"Data variables:\n",
" field1 (altitude, latitude, longitude) float64 0.9765 0.4326 ... 0.01994\n",
"Attributes:\n",
" geospatial_vertical_units: m