diff --git a/speasy/webservices/amda/_impl.py b/speasy/webservices/amda/_impl.py index 69173093..700c020f 100644 --- a/speasy/webservices/amda/_impl.py +++ b/speasy/webservices/amda/_impl.py @@ -123,10 +123,18 @@ def dl_parameter_chunk(self, start_time: datetime, stop_time: datetime, paramete return None def dl_parameter(self, start_time: datetime, stop_time: datetime, parameter_id: str, - extra_http_headers: Dict or None = None, output_format: str = 'ASCII', **kwargs) -> Optional[ + extra_http_headers: Dict or None = None, output_format: str = 'ASCII', restricted_period=False, + **kwargs) -> Optional[ SpeasyVariable]: dt = timedelta(days=amda_cfg.max_chunk_size_days()) - + if restricted_period: + try: + username, password = _get_credentials() + kwargs['userID'] = username + kwargs['password'] = password + except MissingCredentials: + raise MissingCredentials( + "Restricted period requested but no credentials provided, please add your AMDA credentials.") if stop_time - start_time > dt: var = None curr_t = start_time diff --git a/speasy/webservices/amda/ws.py b/speasy/webservices/amda/ws.py index 4b082b30..eac5d786 100644 --- a/speasy/webservices/amda/ws.py +++ b/speasy/webservices/amda/ws.py @@ -139,6 +139,13 @@ def is_user_timetable(self, timetable_id: str or TimetableIndex): def is_user_parameter(self, parameter_id: str or ParameterIndex): return _is_user_prod(parameter_id, self.flat_inventory.parameters) + def has_time_restriction(self, product_id: str or SpeasyIndex, start_time: str or datetime, + stop_time: str or datetime): + dataset = self._find_parent_dataset(product_id) + if dataset and hasattr(dataset, 'timeRestriction'): + return dataset.timeRestriction.intersect(DateTimeRange(start_time, stop_time)) + return False + def product_version(self, parameter_id: str or ParameterIndex): """Get date of last modification of dataset or parameter. @@ -351,13 +358,28 @@ def get_user_catalog(self, catalog_id: str or CatalogIndex) -> Optional[Catalog] catalog_id = to_xmlid(catalog_id) return self._impl.dl_user_catalog(catalog_id=catalog_id) - @AllowedKwargs(PROXY_ALLOWED_KWARGS + CACHE_ALLOWED_KWARGS + GET_DATA_ALLOWED_KWARGS + ['output_format']) - @ParameterRangeCheck() - @Cacheable(prefix="amda", version=product_version, fragment_hours=lambda x: 12, entry_name=_amda_cache_entry_name) - @Proxyfiable(GetProduct, get_parameter_args) def get_parameter(self, product, start_time, stop_time, extra_http_headers: Dict or None = None, output_format: str or None = None, **kwargs) -> Optional[ SpeasyVariable]: + if self.has_time_restriction(product, start_time, stop_time): + kwargs['disable_proxy'] = True + kwargs['restricted_period'] = True + return self._get_parameter(product, start_time, stop_time, extra_http_headers=extra_http_headers, + output_format=output_format or amda_cfg.output_format(), **kwargs) + else: + return self._get_parameter(product, start_time, stop_time, extra_http_headers=extra_http_headers, + output_format=output_format or amda_cfg.output_format(), **kwargs) + + @AllowedKwargs( + PROXY_ALLOWED_KWARGS + CACHE_ALLOWED_KWARGS + GET_DATA_ALLOWED_KWARGS + ['output_format', 'restricted_period']) + @ParameterRangeCheck() + @Cacheable(prefix="amda", version=product_version, fragment_hours=lambda x: 12, entry_name=_amda_cache_entry_name) + @Proxyfiable(GetProduct, get_parameter_args) + def _get_parameter(self, product, start_time, stop_time, + extra_http_headers: Dict or None = None, output_format: str or None = None, + restricted_period=False, **kwargs) -> \ + Optional[ + SpeasyVariable]: """Get parameter data. Parameters @@ -393,7 +415,8 @@ def get_parameter(self, product, start_time, stop_time, log.debug(f'Get data: product = {product}, data start time = {start_time}, data stop time = {stop_time}') return self._impl.dl_parameter(start_time=start_time, stop_time=stop_time, parameter_id=product, extra_http_headers=extra_http_headers, - output_format=output_format or amda_cfg.output_format()) + output_format=output_format, + restricted_period=restricted_period) def get_dataset(self, dataset_id: str or DatasetIndex, start: str or datetime, stop: str or datetime, **kwargs) -> Dataset or None: @@ -672,6 +695,13 @@ def _find_parent_dataset(self, product_id: str or DatasetIndex or ParameterIndex if product_id in dataset: return to_xmlid(dataset) + def _time_restriction_range(self, product_id: str or DatasetIndex or ParameterIndex or ComponentIndex) -> Optional[ + DateTimeRange]: + dataset = self._find_parent_dataset(product_id) + if dataset and hasattr(dataset, 'timeRestriction'): + return DateTimeRange(dataset.timeRestriction, dataset.stop_date) + return None + def product_type(self, product_id: str or SpeasyIndex) -> ProductType: """Returns product type for any known ADMA product from its index or ID. diff --git a/tests/test_amda_parameter.py b/tests/test_amda_parameter.py index 773dd578..af63af1e 100644 --- a/tests/test_amda_parameter.py +++ b/tests/test_amda_parameter.py @@ -5,7 +5,7 @@ import unittest from ddt import data, ddt -from datetime import datetime +from datetime import datetime, timedelta import numpy as np @@ -63,6 +63,21 @@ def test_dataset_items_datatype(self): for item in self.dataset: self.assertTrue(isinstance(self.dataset[item], spz.SpeasyVariable)) + def test_restricted_time_range(self): + from speasy.webservices.amda._impl import credential_are_valid + if credential_are_valid(): + self.skipTest("Should only run when credentials are not valid") + dataset = None + for dataset in spz.inventories.flat_inventories.amda.datasets.values(): + if hasattr(dataset, 'timeRestriction'): + break + if dataset is not None: + from speasy.webservices.amda.exceptions import MissingCredentials + from speasy.core import make_utc_datetime + with self.assertRaises(MissingCredentials): + spz.amda.get_dataset(dataset, make_utc_datetime(dataset.timeRestriction), + make_utc_datetime(dataset.timeRestriction) + timedelta(minutes=1)) + @ddt class AMDAParametersPlots(unittest.TestCase):