diff --git a/src/radclss/config/output_config.py b/src/radclss/config/output_config.py index ec1a9fd..0274df9 100644 --- a/src/radclss/config/output_config.py +++ b/src/radclss/config/output_config.py @@ -21,6 +21,10 @@ OUTPUT_STATION_ATTRS = dict( long_name="Bankhead National Forest AMF-3 In-Situ Ground Observation Station Identifers" ) +OUTPUT_STATION_NAME_ATTRS = dict( + long_name="Bankhead National Forest AMF-3 In-Situ Ground Observation Station Names" +) + OUTPUT_LAT_ATTRS = dict( long_name="Latitude of BNF AMF-3 Ground Observation Site", units="Degrees North" ) @@ -179,6 +183,7 @@ def get_output_config(): "gate_time_attrs": OUTPUT_GATE_TIME_ATTRS, "time_offset_attrs": OUTPUT_TIME_OFFSET_ATTRS, "station_attrs": OUTPUT_STATION_ATTRS, + "station_name_attrs": OUTPUT_STATION_NAME_ATTRS, "lat_attrs": OUTPUT_LAT_ATTRS, "lon_attrs": OUTPUT_LON_ATTRS, "alt_attrs": OUTPUT_ALT_ATTRS, diff --git a/src/radclss/core/radclss_core.py b/src/radclss/core/radclss_core.py index 808eaf0..97f6cad 100644 --- a/src/radclss/core/radclss_core.py +++ b/src/radclss/core/radclss_core.py @@ -25,7 +25,7 @@ def radclss( dod_version="1.0", discard_var={}, verbose=False, - base_station="M1", + base_station=0, current_client=None, nexrad=True, nexrad_site=None, @@ -73,8 +73,8 @@ def radclss( Dictionary containing variables to drop from each datastream. Default is {}. verbose : bool, optional Option to print additional information during processing. Default is False. - base_station : str, optional - The base station name to use for time variables. Default is "M1". + base_station : int, optional + The base station index to use for time variables. Default is 0. current_client : dask.distributed.Client, optional Option to supply an existing Dask client for parallel processing. Set to None to use the current active client. Default is None. @@ -303,7 +303,7 @@ def _get_nexrad_wrapper(time_str): output_config["site"], input_site_dict, nexrad_radar=nexrad_site, - height_bins=height_bins + height_bins=height_bins, ) results = current_client.map(_get_nexrad_wrapper, time_list) @@ -340,7 +340,7 @@ def _get_nexrad_wrapper(time_str): output_config["site"], input_site_dict, nexrad_radar=nexrad_site, - height_bins=height_bins + height_bins=height_bins, ) successful_count = 0 @@ -436,7 +436,7 @@ def _get_nexrad_wrapper(time_str): print("=" * 80) print(f" Time coordinate method: {time_coords}") - ds_concat[k] = ds_concat[k].drop_duplicates('time') + ds_concat[k] = ds_concat[k].drop_duplicates("time") if "radar" in time_coords: if verbose: print(f" Reindexing all datasets to {time_coords} time coordinates") @@ -541,7 +541,8 @@ def _get_nexrad_wrapper(time_str): print(f" Time arrays from {k}:") print(ds_concat[k]["base_time"]) ds_concat[k] = ds_concat[k].drop(["time_offset", "base_time"]) - nexrad_columns = nexrad_columns.drop(["time_offset", "base_time"]) + if nexrad_columns is not None: + nexrad_columns = nexrad_columns.drop(["time_offset", "base_time"]) first_key = list(ds_concat.keys())[0] for k in list(ds_concat.keys())[1:]: for var in ds_concat[k].data_vars: @@ -550,12 +551,13 @@ def _get_nexrad_wrapper(time_str): print(f"Dropping {var} from {k}") ds_concat[k] = ds_concat[k].drop(var) - for var in nexrad_columns.data_vars: - for k in ds_concat.keys(): - if var in ds_concat[k].data_vars: - if verbose: - print(f"Dropping {var} from nexrad_columns") - nexrad_columns = nexrad_columns.drop(var) + if nexrad_columns is not None: + for var in nexrad_columns.data_vars: + for k in ds_concat.keys(): + if var in ds_concat[k].data_vars: + if verbose: + print(f"Dropping {var} from nexrad_columns") + nexrad_columns = nexrad_columns.drop(var) ds_concat = xr.merge([x for x in ds_concat.values()]) if verbose: @@ -601,8 +603,14 @@ def _get_nexrad_wrapper(time_str): # Calculate time as seconds since base_time ds["time_offset"] = ds["time"] - ds["station"] = ds_concat["station"] - ds["height"] = ds_concat["height"] + ds["station"] = ("station", ds_concat["station"].values) + # The DOD may define station_name as a char-array data var (station, string_length_N). + # Replace it with a 1-D string coord populated from the source dataset. + if "station_name" in ds.variables: + ds = ds.drop_vars("station_name") + ds = ds.assign_coords(station_name=("station", ds_concat["station_name"].values)) + ds.coords["station_name"].attrs.update(output_config["station_name_attrs"]) + ds["height"] = ("height", ds_concat["height"].values) ds["lat"][:] = ds_concat.isel(time=0)["lat"][:] ds["lon"][:] = ds_concat.isel(time=0)["lon"][:] ds["alt"][:] = ds_concat.isel(time=0)["alt"][:] @@ -708,7 +716,6 @@ def _get_nexrad_wrapper(time_str): else: instrument = k site = base_station - site = site.upper() if instrument == "kazr2": _instrument_tasks.append( diff --git a/src/radclss/util/column_utils.py b/src/radclss/util/column_utils.py index 36cbe83..7635159 100644 --- a/src/radclss/util/column_utils.py +++ b/src/radclss/util/column_utils.py @@ -18,6 +18,29 @@ _nexrad_cache = {} +def _station_index(ds, station): + """Return integer index along the ``station`` dim for a name or int. + + ``station`` may be a single value or array-like of names/ints. Names are + looked up against ``ds["station_name"]`` if present, otherwise against + ``ds["station"]`` for back-compat with datasets that stored string names + directly on the ``station`` coord. Ints are returned as-is. + """ + if np.isscalar(station): + if isinstance(station, (int, np.integer)): + return int(station) + if "station_name" in ds.variables: + names = ds["station_name"].values + else: + names = ds["station"].values + matches = np.where(names == station)[0] + if len(matches) == 0: + raise ValueError(f"Station '{station}' not found in dataset") + return int(matches[0]) + + return np.array([_station_index(ds, s) for s in station], dtype=int) + + def _read_sonde_cached(path, exclude): """Return a fully-loaded sonde dataset, reading from disk only on first call.""" if path not in _sonde_cache: @@ -157,7 +180,7 @@ def _vpt_to_column_timeseries(radar, height_bins): data_vars[key] = xr.DataArray(arr, dims=["time", "height"], attrs=attrs) ds = xr.Dataset(data_vars, coords={"height": zgate, "time": abs_times}) - #ds = ds.dropna("height") + # ds = ds.dropna("height") valid_h = np.isfinite(ds["height"]) print(ds["reflectivity"], np.sum(np.isnan(ds["reflectivity"].values))) if int(valid_h.sum()) > 0: @@ -231,8 +254,7 @@ def _vpt_nan_fill(radar, height_bins): attrs=attrs, ) - ds = xr.Dataset(data_vars, - coords={"height": height_bins, "time": abs_times}) + ds = xr.Dataset(data_vars, coords={"height": height_bins, "time": abs_times}) ds = ds.drop_duplicates("time", keep="first") dedup_times_s = ds["time"].values.astype("datetime64[s]") @@ -365,9 +387,7 @@ def get_nexrad_column( .interp(height=height_bins) ) else: - target_height = xr.DataArray( - height_bins, dims="height", name="height" - ) + target_height = xr.DataArray(height_bins, dims="height", name="height") da = da.reindex(height=target_height) # Add the latitude and longitude of the extracted column @@ -691,7 +711,6 @@ def _prepare_match( "Invalid resample method. Please choose 'mean', 'sum', or 'skip'." ) - matched = matched.assign_coords(coords=dict(station=site)) matched = matched.expand_dims("station") for attr in ("lat", "lon", "alt"): @@ -707,10 +726,11 @@ def _prepare_match( def _apply_match(column, site, matched): """Merge a prepared match result into ``column`` in-place and return it.""" + site_idx = _station_index(column, site) for k in matched.data_vars: if k in column.data_vars: - column[k].sel(station=site)[:] = matched.sel(station=site)[k][:].astype( - column[k].dtype + column[k].isel(station=site_idx)[:] = ( + matched[k].isel(station=0)[:].astype(column[k].dtype) ) if "_FillValue" in column[k].attrs: if isinstance(column[k].attrs["_FillValue"], str): @@ -806,7 +826,8 @@ def match_datasets_act( def _add_station_vars(ds, sites, site_alt): - ds["station"] = sites + ds["station"] = ("station", np.arange(len(sites), dtype=int)) + ds = ds.assign_coords(station_name=("station", np.asarray(sites))) # Assign the Main and Supplemental Site altitudes ds = ds.assign(alt=("station", site_alt)) # Add attributes for Time, Latitude, Longitude, and Sites @@ -815,6 +836,7 @@ def _add_station_vars(ds, sites, site_alt): ds.gate_time.attrs.update(output_config["gate_time_attrs"]) ds.time_offset.attrs.update(output_config["time_offset_attrs"]) ds.station.attrs.update(output_config["station_attrs"]) + ds.coords["station_name"].attrs.update(output_config["station_name_attrs"]) ds.lat.attrs.update(output_config["lat_attrs"]) ds.lon.attrs.update(output_config["lon_attrs"]) ds.alt.attrs.update(output_config["alt_attrs"]) diff --git a/src/radclss/vis/quicklooks.py b/src/radclss/vis/quicklooks.py index 3e7e8df..8a9a50b 100644 --- a/src/radclss/vis/quicklooks.py +++ b/src/radclss/vis/quicklooks.py @@ -9,6 +9,24 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable +def _resolve_station(ds, station): + """Resolve a station identifier (str name or int index) to its int index. + + Falls back to ``ds["station"]`` for legacy datasets that stored string + names directly on the ``station`` coord. + """ + if isinstance(station, (int, np.integer)): + return int(station) + if "station_name" in ds.variables: + names = ds["station_name"].values + else: + names = ds["station"].values + matches = np.where(names == station)[0] + if len(matches) == 0: + raise ValueError(f"Station '{station}' not found in dataset") + return int(matches[0]) + + def create_radclss_columns( radclss, field="corrected_reflectivity", @@ -95,14 +113,19 @@ def create_radclss_columns( col = i % 2 if len(axarr.shape) == 1: axarr = np.expand_dims(axarr, axis=0) - ds[field].sel(station=station).sel( + station_idx = _resolve_station(ds, station) + ds[field].isel(station=station_idx).sel( time=slice( radar_time.strftime("%Y-%m-%dT00:00:00"), final_time.strftime("%Y-%m-%dT00:00:00"), ) ).plot(y="height", ax=axarr[row, col], vmin=vmin, vmax=vmax, **kwargs) long_name = ds[field].attrs.get("long_name", field) - axarr[row, col].set_title(f"{station} {long_name}") + if "station_name" in ds.variables: + label = str(ds["station_name"].values[station_idx]) + else: + label = str(ds["station"].values[station_idx]) + axarr[row, col].set_title(f"{label} {long_name}") if isinstance(radclss, str): ds.close() @@ -193,6 +216,12 @@ def create_radclss_rainfall_timeseries( print("\n") return + dis_idx = _resolve_station(ds, dis_site) + if "station_name" in ds.variables: + dis_name = str(ds["station_name"].values[dis_idx]) + else: + dis_name = str(ds["station"].values[dis_idx]) + # Define the time of the radar file we are plotting against radar_time = datetime.datetime.strptime( np.datetime_as_string(ds["time"].data[0], unit="s"), "%Y-%m-%dT%H:%M:%S" @@ -205,12 +234,12 @@ def create_radclss_rainfall_timeseries( # Top right hand subplot - Radar TimeSeries ax2 = fig.add_subplot(311) - ds[field].sel(station=dis_site).plot( + ds[field].isel(station=dis_idx).plot( x="time", ax=ax2, cmap=cmap, vmin=vmin, vmax=vmax ) ax2.set_title( - "Extracted Radar Columns and In-Situ Sensors (RadCLss), BNF Site: " + dis_site + "Extracted Radar Columns and In-Situ Sensors (RadCLss), BNF Site: " + dis_name ) ax2.set_ylabel("Height [m]") ax2.set_xlabel("Time [UTC]") @@ -222,15 +251,15 @@ def create_radclss_rainfall_timeseries( ax3 = fig.add_subplot(312) # CMAC derived rain rate - ds["rain_rate_A"].sel(station=dis_site).sel(height=rheight, method="nearest").plot( + ds["rain_rate_A"].isel(station=dis_idx).sel(height=rheight, method="nearest").plot( x="time", ax=ax3, label="CMAC" ) # Pluvio2 Weighing Bucket Rain Gauge - ds["intensity_rtnrt"].sel(station=dis_site).plot(x="time", ax=ax3, label="PLUVIO2") + ds["intensity_rtnrt"].isel(station=dis_idx).plot(x="time", ax=ax3, label="PLUVIO2") # LDQUANTS derived rain rate - ds["ldquants_rain_rate"].sel(station=dis_site).plot( + ds["ldquants_rain_rate"].isel(station=dis_idx).plot( x="time", ax=ax3, label="LDQUANTS" ) @@ -258,22 +287,22 @@ def create_radclss_rainfall_timeseries( # CMAC Accumulated Rain Rates radar_accum = act.utils.accumulate_precip( - ds.sel(station=dis_site).sel(height=rheight, method="nearest"), "rain_rate_A" + ds.isel(station=dis_idx).sel(height=rheight, method="nearest"), "rain_rate_A" ).compute() # CMAC Accumulated Rain Rates radar_accum["rain_rate_A_accumulated"].plot(x="time", ax=ax4, label="CMAC") # PLUVIO2 Accumulation - if dis_site == "M1": + if dis_name == "M1": gauge_precip_accum = act.utils.accumulate_precip( - ds.sel(station=dis_site), "intensity_rtnrt" + ds.isel(station=dis_idx), "intensity_rtnrt" ).intensity_rtnrt_accumulated.compute() gauge_precip_accum.plot(x="time", ax=ax4, label="PLUVIO2") # LDQUANTS Accumulation - if dis_site == "M1" or dis_site == "S30": + if dis_name == "M1" or dis_name == "S30": ld_precip_accum = act.utils.accumulate_precip( - ds.sel(station=dis_site), "ldquants_rain_rate" + ds.isel(station=dis_idx), "ldquants_rain_rate" ).ldquants_rain_rate_accumulated.compute() ld_precip_accum.plot(x="time", ax=ax4, label="LDQUANTS") @@ -308,9 +337,9 @@ def create_radclss_rainfall_timeseries( # Clean up this function ax = np.array([ax2, ax3, ax4]) del radar_accum - if dis_site == "M1" or dis_site == "S30": + if dis_name == "M1" or dis_name == "S30": del ld_precip_accum - if dis_site == "M1": + if dis_name == "M1": del gauge_precip_accum if isinstance(radclss, str): ds.close() diff --git a/tests/test_radclss.py b/tests/test_radclss.py index 6de8b2c..df5f489 100644 --- a/tests/test_radclss.py +++ b/tests/test_radclss.py @@ -9,6 +9,12 @@ from distributed import Client, LocalCluster +def _by_name(ds, name): + """Select the slice of ``ds`` whose ``station_name`` equals ``name``.""" + idx = list(ds["station_name"].values).index(name) + return ds.isel(station=idx) + + def test_radclss_serial(): test_data_path = arm_test_data.DATASETS.abspath @@ -88,103 +94,68 @@ def test_radclss_serial(): assert my_columns.dims["time"] == 6 assert my_columns.dims["height"] == 32 assert my_columns.dims["station"] == 6 - assert np.array_equal( - my_columns["station"].values, ["M1", "S4", "S20", "S30", "S40", "S13"] - ) + assert np.array_equal(my_columns["station"].values, np.arange(6)) + assert list(my_columns["station_name"].values) == [ + "M1", + "S4", + "S20", + "S30", + "S40", + "S13", + ] # Radar and sonde data check - for site in my_columns["station"].values: - missing_value = ( - my_columns["csapr2_reflectivity"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not ( - my_columns["csapr2_reflectivity"].sel(station=site) == missing_value - ).all() + for name in my_columns["station_name"].values: + col = _by_name(my_columns, name) + missing_value = col["csapr2_reflectivity"].attrs.get("missing_value", None) + assert not (col["csapr2_reflectivity"] == missing_value).all() - for site in ["M1", "S4", "S20", "S30", "S40"]: + for name in ["M1", "S4", "S20", "S30", "S40"]: + col = _by_name(my_columns, name) # Sonde data - missing_value = ( - my_columns["sonde_u_wind"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not (my_columns["sonde_u_wind"].sel(station=site) == missing_value).all() - missing_value = ( - my_columns["sonde_v_wind"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not (my_columns["sonde_v_wind"].sel(station=site) == missing_value).all() - missing_value = ( - my_columns["sonde_temp"].sel(station=site).attrs.get("missing_value", None) - ) - assert not (my_columns["sonde_temp"].sel(station=site) == missing_value).all() - missing_value = ( - my_columns["sonde_rh"].sel(station=site).attrs.get("missing_value", None) - ) - assert not (my_columns["sonde_rh"].sel(station=site) == missing_value).all() - missing_value = ( - my_columns["sonde_bar_pres"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not ( - my_columns["sonde_bar_pres"].sel(station=site) == missing_value - ).all() + missing_value = col["sonde_u_wind"].attrs.get("missing_value", None) + assert not (col["sonde_u_wind"] == missing_value).all() + missing_value = col["sonde_v_wind"].attrs.get("missing_value", None) + assert not (col["sonde_v_wind"] == missing_value).all() + missing_value = col["sonde_temp"].attrs.get("missing_value", None) + assert not (col["sonde_temp"] == missing_value).all() + missing_value = col["sonde_rh"].attrs.get("missing_value", None) + assert not (col["sonde_rh"] == missing_value).all() + missing_value = col["sonde_bar_pres"].attrs.get("missing_value", None) + assert not (col["sonde_bar_pres"] == missing_value).all() # Met data check - for site in ["M1", "S20", "S30", "S40"]: - missing_value = ( - my_columns["temp_mean"].sel(station=site).attrs.get("_FillValue", None) - ) - assert not (my_columns["temp_mean"].sel(station=site) == missing_value).all() + for name in ["M1", "S20", "S30", "S40"]: + col = _by_name(my_columns, name) + missing_value = col["temp_mean"].attrs.get("_FillValue", None) + assert not (col["temp_mean"] == missing_value).all() - for site in ["S4"]: - missing_value = ( - my_columns["temp_mean"].sel(station=site).attrs.get("_FillValue", None) - ) - assert (my_columns["temp_mean"].sel(station=site) == missing_value).all() + for name in ["S4"]: + col = _by_name(my_columns, name) + missing_value = col["temp_mean"].attrs.get("_FillValue", None) + assert (col["temp_mean"] == missing_value).all() # Pluvio data check - missing_value = ( - my_columns["accum_nrt"].sel(station="M1").attrs.get("_FillValue", None) - ) - assert not (my_columns["accum_nrt"].sel(station="M1") == missing_value).all() - missing_value = ( - my_columns["bucket_nrt"].sel(station="M1").attrs.get("_FillValue", None) - ) - assert not (my_columns["bucket_nrt"].sel(station="M1") == missing_value).all() - - for site in ["S20", "S30", "S40", "S13", "S4"]: - missing_value = ( - my_columns["accum_nrt"].sel(station=site).attrs.get("_FillValue", None) - ) - assert (my_columns["accum_nrt"].sel(station=site) == missing_value).all() - missing_value = ( - my_columns["bucket_nrt"].sel(station=site).attrs.get("_FillValue", None) - ) - assert (my_columns["bucket_nrt"].sel(station=site) == missing_value).all() + m1 = _by_name(my_columns, "M1") + missing_value = m1["accum_nrt"].attrs.get("_FillValue", None) + assert not (m1["accum_nrt"] == missing_value).all() + missing_value = m1["bucket_nrt"].attrs.get("_FillValue", None) + assert not (m1["bucket_nrt"] == missing_value).all() + + for name in ["S20", "S30", "S40", "S13", "S4"]: + col = _by_name(my_columns, name) + missing_value = col["accum_nrt"].attrs.get("_FillValue", None) + assert (col["accum_nrt"] == missing_value).all() + missing_value = col["bucket_nrt"].attrs.get("_FillValue", None) + assert (col["bucket_nrt"] == missing_value).all() # LD data check - for site in ["M1", "S30"]: - missing_value = ( - my_columns["ldquants_rain_rate"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not ( - my_columns["ldquants_rain_rate"].sel(station=site) == missing_value - ).all() - missing_value = ( - my_columns["ldquants_med_diameter"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not ( - my_columns["ldquants_med_diameter"].sel(station=site) == missing_value - ).all() + for name in ["M1", "S30"]: + col = _by_name(my_columns, name) + missing_value = col["ldquants_rain_rate"].attrs.get("missing_value", None) + assert not (col["ldquants_rain_rate"] == missing_value).all() + missing_value = col["ldquants_med_diameter"].attrs.get("missing_value", None) + assert not (col["ldquants_med_diameter"] == missing_value).all() def test_radclss_parallel(): @@ -267,103 +238,68 @@ def test_radclss_parallel(): assert my_columns.dims["time"] == 6 assert my_columns.dims["height"] == 32 assert my_columns.dims["station"] == 6 - assert np.array_equal( - my_columns["station"].values, ["M1", "S4", "S20", "S30", "S40", "S13"] - ) + assert np.array_equal(my_columns["station"].values, np.arange(6)) + assert list(my_columns["station_name"].values) == [ + "M1", + "S4", + "S20", + "S30", + "S40", + "S13", + ] # Radar and sonde data check - for site in my_columns["station"].values: - missing_value = ( - my_columns["csapr2_reflectivity"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not ( - my_columns["csapr2_reflectivity"].sel(station=site) == missing_value - ).all() + for name in my_columns["station_name"].values: + col = _by_name(my_columns, name) + missing_value = col["csapr2_reflectivity"].attrs.get("missing_value", None) + assert not (col["csapr2_reflectivity"] == missing_value).all() - for site in ["M1", "S4", "S20", "S30", "S40"]: + for name in ["M1", "S4", "S20", "S30", "S40"]: + col = _by_name(my_columns, name) # Sonde data - missing_value = ( - my_columns["sonde_u_wind"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not (my_columns["sonde_u_wind"].sel(station=site) == missing_value).all() - missing_value = ( - my_columns["sonde_v_wind"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not (my_columns["sonde_v_wind"].sel(station=site) == missing_value).all() - missing_value = ( - my_columns["sonde_temp"].sel(station=site).attrs.get("missing_value", None) - ) - assert not (my_columns["sonde_temp"].sel(station=site) == missing_value).all() - missing_value = ( - my_columns["sonde_rh"].sel(station=site).attrs.get("missing_value", None) - ) - assert not (my_columns["sonde_rh"].sel(station=site) == missing_value).all() - missing_value = ( - my_columns["sonde_bar_pres"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not ( - my_columns["sonde_bar_pres"].sel(station=site) == missing_value - ).all() + missing_value = col["sonde_u_wind"].attrs.get("missing_value", None) + assert not (col["sonde_u_wind"] == missing_value).all() + missing_value = col["sonde_v_wind"].attrs.get("missing_value", None) + assert not (col["sonde_v_wind"] == missing_value).all() + missing_value = col["sonde_temp"].attrs.get("missing_value", None) + assert not (col["sonde_temp"] == missing_value).all() + missing_value = col["sonde_rh"].attrs.get("missing_value", None) + assert not (col["sonde_rh"] == missing_value).all() + missing_value = col["sonde_bar_pres"].attrs.get("missing_value", None) + assert not (col["sonde_bar_pres"] == missing_value).all() # Met data check - for site in ["M1", "S20", "S30", "S40"]: - missing_value = ( - my_columns["temp_mean"].sel(station=site).attrs.get("_FillValue", None) - ) - assert not (my_columns["temp_mean"].sel(station=site) == missing_value).all() + for name in ["M1", "S20", "S30", "S40"]: + col = _by_name(my_columns, name) + missing_value = col["temp_mean"].attrs.get("_FillValue", None) + assert not (col["temp_mean"] == missing_value).all() - for site in ["S4"]: - missing_value = ( - my_columns["temp_mean"].sel(station=site).attrs.get("_FillValue", None) - ) - assert (my_columns["temp_mean"].sel(station=site) == missing_value).all() + for name in ["S4"]: + col = _by_name(my_columns, name) + missing_value = col["temp_mean"].attrs.get("_FillValue", None) + assert (col["temp_mean"] == missing_value).all() # Pluvio data check - missing_value = ( - my_columns["accum_nrt"].sel(station="M1").attrs.get("_FillValue", None) - ) - assert not (my_columns["accum_nrt"].sel(station="M1") == missing_value).all() - missing_value = ( - my_columns["bucket_nrt"].sel(station="M1").attrs.get("_FillValue", None) - ) - assert not (my_columns["bucket_nrt"].sel(station="M1") == missing_value).all() - - for site in ["S20", "S30", "S40", "S13", "S4"]: - missing_value = ( - my_columns["accum_nrt"].sel(station=site).attrs.get("_FillValue", None) - ) - assert (my_columns["accum_nrt"].sel(station=site) == missing_value).all() - missing_value = ( - my_columns["bucket_nrt"].sel(station=site).attrs.get("_FillValue", None) - ) - assert (my_columns["bucket_nrt"].sel(station=site) == missing_value).all() + m1 = _by_name(my_columns, "M1") + missing_value = m1["accum_nrt"].attrs.get("_FillValue", None) + assert not (m1["accum_nrt"] == missing_value).all() + missing_value = m1["bucket_nrt"].attrs.get("_FillValue", None) + assert not (m1["bucket_nrt"] == missing_value).all() + + for name in ["S20", "S30", "S40", "S13", "S4"]: + col = _by_name(my_columns, name) + missing_value = col["accum_nrt"].attrs.get("_FillValue", None) + assert (col["accum_nrt"] == missing_value).all() + missing_value = col["bucket_nrt"].attrs.get("_FillValue", None) + assert (col["bucket_nrt"] == missing_value).all() # LD data check - for site in ["M1", "S30"]: - missing_value = ( - my_columns["ldquants_rain_rate"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not ( - my_columns["ldquants_rain_rate"].sel(station=site) == missing_value - ).all() - missing_value = ( - my_columns["ldquants_med_diameter"] - .sel(station=site) - .attrs.get("missing_value", None) - ) - assert not ( - my_columns["ldquants_med_diameter"].sel(station=site) == missing_value - ).all() + for name in ["M1", "S30"]: + col = _by_name(my_columns, name) + missing_value = col["ldquants_rain_rate"].attrs.get("missing_value", None) + assert not (col["ldquants_rain_rate"] == missing_value).all() + missing_value = col["ldquants_med_diameter"].attrs.get("missing_value", None) + assert not (col["ldquants_med_diameter"] == missing_value).all() def test_subset_points(): @@ -400,7 +336,8 @@ def test_subset_points(): "S30": (34.38501, -86.92757, 183), } subset_ds = radclss.util.subset_points(radar_files[0], input_site_dict, sonde=None) - assert set(subset_ds["station"].values) == {"M1", "S30"} + assert set(subset_ds["station_name"].values) == {"M1", "S30"} + assert np.array_equal(subset_ds["station"].values, np.arange(2)) assert "reflectivity" in subset_ds.data_vars assert subset_ds.dims["station"] == 2 assert np.array_equal(subset_ds["height"].values, np.arange(500, 8500, 250)) @@ -415,7 +352,7 @@ def test_subset_points(): subset_ds = radclss.util.subset_points( radar_files[0], input_site_dict, sonde=sonde_files ) - assert set(subset_ds["station"].values) == {"M1", "S30"} + assert set(subset_ds["station_name"].values) == {"M1", "S30"} assert np.array_equal(subset_ds["height"].values, np.arange(500, 8500, 250)) assert "reflectivity" in subset_ds.data_vars assert "sonde_u_wind" in subset_ds.data_vars @@ -532,9 +469,15 @@ def test_radclss_with_kasacr(): # Basic structure checks assert isinstance(my_columns, xr.Dataset) assert my_columns.dims["station"] == 6 - assert np.array_equal( - my_columns["station"].values, ["M1", "S4", "S20", "S30", "S40", "S13"] - ) + assert np.array_equal(my_columns["station"].values, np.arange(6)) + assert list(my_columns["station_name"].values) == [ + "M1", + "S4", + "S20", + "S30", + "S40", + "S13", + ] # Check that CSAPR2 data exists assert ( @@ -552,7 +495,7 @@ def test_radclss_with_kasacr(): assert len(kasacr_vars) > 0, "Expected KASACR variables in dataset" # Check that KASACR reflectivity data exists - for site in my_columns["station"].values: + for name in my_columns["station_name"].values: # Find reflectivity variable for KASACR kasacr_refl_vars = [ var @@ -561,15 +504,12 @@ def test_radclss_with_kasacr(): ] if len(kasacr_refl_vars) > 0: kasacr_refl = kasacr_refl_vars[0] - missing_value = ( - my_columns[kasacr_refl] - .sel(station=site) - .attrs.get("missing_value", None) - ) + col = _by_name(my_columns, name) + missing_value = col[kasacr_refl].attrs.get("missing_value", None) # At least some data should be non-missing assert not ( - my_columns[kasacr_refl].sel(station=site) == missing_value - ).all(), f"All KASACR data is missing for station {site}" + col[kasacr_refl] == missing_value + ).all(), f"All KASACR data is missing for station {name}" def test_radclss_with_kazr(): @@ -677,9 +617,15 @@ def test_radclss_with_kazr(): # Basic structure checks assert isinstance(my_columns, xr.Dataset) assert my_columns.dims["station"] == 6 - assert np.array_equal( - my_columns["station"].values, ["M1", "S4", "S20", "S30", "S40", "S13"] - ) + assert np.array_equal(my_columns["station"].values, np.arange(6)) + assert list(my_columns["station_name"].values) == [ + "M1", + "S4", + "S20", + "S30", + "S40", + "S13", + ] # Check that CSAPR2 data exists assert ( @@ -697,7 +643,7 @@ def test_radclss_with_kazr(): assert len(kazr_vars) > 0, "Expected KAZR variables in dataset" # Check that KAZR reflectivity data exists - for site in my_columns["station"].values: + for name in my_columns["station_name"].values: # Find reflectivity variable for KAZR kazr_refl_vars = [ var @@ -706,15 +652,12 @@ def test_radclss_with_kazr(): ] if len(kazr_refl_vars) > 0: kazr_refl = kazr_refl_vars[0] - missing_value = ( - my_columns[kazr_refl] - .sel(station=site) - .attrs.get("missing_value", None) - ) + col = _by_name(my_columns, name) + missing_value = col[kazr_refl].attrs.get("missing_value", None) # At least some data should be non-missing assert not ( - my_columns[kazr_refl].sel(station=site) == missing_value - ).all(), f"All KAZR data is missing for station {site}" + col[kazr_refl] == missing_value + ).all(), f"All KAZR data is missing for station {name}" def test_radclss_parallel_with_nexrad(): @@ -773,9 +716,15 @@ def test_radclss_parallel_with_nexrad(): # Basic structure checks assert isinstance(my_columns, xr.Dataset) assert my_columns.dims["station"] == 6 - assert np.array_equal( - my_columns["station"].values, ["M1", "S4", "S20", "S30", "S40", "S13"] - ) + assert np.array_equal(my_columns["station"].values, np.arange(6)) + assert list(my_columns["station_name"].values) == [ + "M1", + "S4", + "S20", + "S30", + "S40", + "S13", + ] # Check that CSAPR2 data exists csapr2_vars = [var for var in my_columns.data_vars if "csapr2" in var.lower()] @@ -790,7 +739,7 @@ def test_radclss_parallel_with_nexrad(): assert len(nexrad_vars) > 0, "Expected NEXRAD variables in dataset" # Check that both radar systems have reflectivity data - for site in my_columns["station"].values: + for name in my_columns["station_name"].values: # Find CSAPR2 reflectivity csapr2_refl_vars = [ var @@ -799,15 +748,12 @@ def test_radclss_parallel_with_nexrad(): ] if len(csapr2_refl_vars) > 0: csapr2_refl = csapr2_refl_vars[0] - missing_value = ( - my_columns[csapr2_refl] - .sel(station=site) - .attrs.get("missing_value", None) - ) + col = _by_name(my_columns, name) + missing_value = col[csapr2_refl].attrs.get("missing_value", None) # At least some data should be non-missing assert not ( - my_columns[csapr2_refl].sel(station=site) == missing_value - ).all(), f"All CSAPR2 data is missing for station {site}" + col[csapr2_refl] == missing_value + ).all(), f"All CSAPR2 data is missing for station {name}" def test_radclss_parallel_with_kasacr(): @@ -917,9 +863,15 @@ def test_radclss_parallel_with_kasacr(): # Basic structure checks assert isinstance(my_columns, xr.Dataset) assert my_columns.dims["station"] == 6 - assert np.array_equal( - my_columns["station"].values, ["M1", "S4", "S20", "S30", "S40", "S13"] - ) + assert np.array_equal(my_columns["station"].values, np.arange(6)) + assert list(my_columns["station_name"].values) == [ + "M1", + "S4", + "S20", + "S30", + "S40", + "S13", + ] # Check that CSAPR2 data exists assert ( @@ -934,7 +886,7 @@ def test_radclss_parallel_with_kasacr(): assert len(kasacr_vars) > 0, "Expected KASACR variables in dataset" # Check that KASACR reflectivity data exists - for site in my_columns["station"].values: + for name in my_columns["station_name"].values: # Find reflectivity variable for KASACR kasacr_refl_vars = [ var @@ -943,15 +895,12 @@ def test_radclss_parallel_with_kasacr(): ] if len(kasacr_refl_vars) > 0: kasacr_refl = kasacr_refl_vars[0] - missing_value = ( - my_columns[kasacr_refl] - .sel(station=site) - .attrs.get("missing_value", None) - ) + col = _by_name(my_columns, name) + missing_value = col[kasacr_refl].attrs.get("missing_value", None) # At least some data should be non-missing assert not ( - my_columns[kasacr_refl].sel(station=site) == missing_value - ).all(), f"All KASACR data is missing for station {site}" + col[kasacr_refl] == missing_value + ).all(), f"All KASACR data is missing for station {name}" def test_match_datasets_act():