fix hra test; move function from datastack to utils to avoid circular import

Emily Soth 2025-02-06 17:44:35 -08:00
parent 1142d32bd7
commit b1d0043eb4
5 changed files with 54 additions and 54 deletions
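The circular import named in the commit title comes from hra.py importing datastack only to reuse the spatial-file copy helper, while datastack in turn needs model modules at archive-build time, so whichever module is imported first sees the other in a partially initialized state. The following is a minimal, self-contained sketch of that failure mode; the module names echo the natcap.invest layout, but the file contents are purely illustrative (the cycle is collapsed to two modules for brevity, with MODEL_SPEC standing in for the real model lookup).

import importlib
import os
import sys
import tempfile
import textwrap

pkg_root = tempfile.mkdtemp()
pkg_dir = os.path.join(pkg_root, 'demo_pkg')
os.makedirs(pkg_dir)
open(os.path.join(pkg_dir, '__init__.py'), 'w').close()

# "datastack" needs something from "hra" at import time
# (an illustrative stand-in for the real model lookup)...
with open(os.path.join(pkg_dir, 'datastack.py'), 'w') as f:
    f.write(textwrap.dedent('''\
        from .hra import MODEL_SPEC  # hra is only partially initialized here

        def _copy_spatial_files(path, target_dir):
            return target_dir
    '''))

# ...while "hra" imports "datastack" just to reuse the copy helper.
with open(os.path.join(pkg_dir, 'hra.py'), 'w') as f:
    f.write(textwrap.dedent('''\
        from . import datastack

        MODEL_SPEC = {'model_id': 'habitat_risk_assessment'}
    '''))

sys.path.insert(0, pkg_root)
try:
    importlib.import_module('demo_pkg.hra')
except ImportError as error:
    # "cannot import name 'MODEL_SPEC' from partially initialized module ..."
    print(error)

Moving the shared helper into a leaf module such as utils.py lets hra.py drop its datastack import entirely, which is the structure this commit adopts.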

View File

@@ -96,53 +96,6 @@ def _tarfile_safe_extract(archive_path, dest_dir_path):
safe_extract(tar, dest_dir_path)
def _copy_spatial_files(spatial_filepath, target_dir):
"""Copy spatial files to a new directory.
Args:
spatial_filepath (str): The filepath to a GDAL-supported file.
target_dir (str): The directory where all component files of
``spatial_filepath`` should be copied. If this directory does not
exist, it will be created.
Returns:
filepath (str): The path to a representative file copied into the
``target_dir``. If possible, this will match the basename of
``spatial_filepath``, so if someone provides an ESRI Shapefile called
``my_vector.shp``, the return value will be ``os.path.join(target_dir,
my_vector.shp)``.
"""
LOGGER.info(f'Copying {spatial_filepath} --> {target_dir}')
if not os.path.exists(target_dir):
os.makedirs(target_dir)
source_basename = os.path.basename(spatial_filepath)
return_filepath = None
spatial_file = gdal.OpenEx(spatial_filepath)
for member_file in spatial_file.GetFileList():
# ArcGIS Binary/Grid format includes the directory in the file listing.
# The parent directory isn't strictly needed, so we can just skip it.
if os.path.isdir(member_file):
continue
target_basename = os.path.basename(member_file)
target_filepath = os.path.join(target_dir, target_basename)
if source_basename == target_basename:
return_filepath = target_filepath
shutil.copyfile(member_file, target_filepath)
spatial_file = None
# I can't conceive of a case where the basename of the source file does not
# match any of the member file basenames, but just in case there's a
# weird GDAL driver that does this, it seems reasonable to fall back to
# whichever of the member files was most recent.
if not return_filepath:
return_filepath = target_filepath
return return_filepath
def format_args_dict(args_dict, model_id):
"""Nicely format an arguments dictionary for writing to a stream.
@@ -375,7 +328,7 @@ def build_datastack_archive(args, model_id, datastack_path):
target_dir = os.path.join(
contained_files_dir,
f'{row_index}_{basename}')
target_filepath = _copy_spatial_files(
target_filepath = utils.copy_spatial_files(
source_filepath, target_dir)
target_filepath = os.path.relpath(
target_filepath, data_dir)
@@ -427,7 +380,7 @@ def build_datastack_archive(args, model_id, datastack_path):
# Create a directory with a readable name, something like
# "aoi_path_vector" or "lulc_cur_path_raster".
spatial_dir = os.path.join(data_dir, f'{key}_{input_type}')
target_arg_value = _copy_spatial_files(
target_arg_value = utils.copy_spatial_files(
source_path, spatial_dir)
files_found[source_path] = target_arg_value

View File

@@ -16,13 +16,13 @@ from osgeo import gdal
from osgeo import ogr
from osgeo import osr
from . import datastack
from . import gettext
from . import spec_utils
from . import utils
from . import validation
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
# RESILIENCE stressor shorthand to use when parsing tables
@@ -2445,7 +2445,7 @@ def _override_datastack_archive_criteria_table_path(
os.path.splitext(os.path.basename(value))[0])
LOGGER.info(f"Copying spatial file {value} --> "
f"{dir_for_this_spatial_data}")
new_path = datastack._copy_spatial_files(
new_path = utils.copy_spatial_files(
value, dir_for_this_spatial_data)
criteria_table_array[row, col] = new_path
known_files[value] = new_path

View File

@@ -793,3 +793,50 @@ def matches_format_string(test_string, format_string):
if re.fullmatch(pattern, test_string):
return True
return False
def copy_spatial_files(spatial_filepath, target_dir):
"""Copy spatial files to a new directory.
Args:
spatial_filepath (str): The filepath to a GDAL-supported file.
target_dir (str): The directory where all component files of
``spatial_filepath`` should be copied. If this directory does not
exist, it will be created.
Returns:
filepath (str): The path to a representative file copied into the
``target_dir``. If possible, this will match the basename of
``spatial_filepath``, so if someone provides an ESRI Shapefile called
``my_vector.shp``, the return value will be ``os.path.join(target_dir,
my_vector.shp)``.
"""
LOGGER.info(f'Copying {spatial_filepath} --> {target_dir}')
if not os.path.exists(target_dir):
os.makedirs(target_dir)
source_basename = os.path.basename(spatial_filepath)
return_filepath = None
spatial_file = gdal.OpenEx(spatial_filepath)
for member_file in spatial_file.GetFileList():
# ArcGIS Binary/Grid format includes the directory in the file listing.
# The parent directory isn't strictly needed, so we can just skip it.
if os.path.isdir(member_file):
continue
target_basename = os.path.basename(member_file)
target_filepath = os.path.join(target_dir, target_basename)
if source_basename == target_basename:
return_filepath = target_filepath
shutil.copyfile(member_file, target_filepath)
spatial_file = None
# I can't conceive of a case where the basename of the source file does not
# match any of the member file basenames, but just in case there's a
# weird GDAL driver that does this, it seems reasonable to fall back to
# whichever of the member files was most recent.
if not return_filepath:
return_filepath = target_filepath
return return_filepath
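For reference, a hedged usage sketch of the relocated helper (the paths below are hypothetical): GetFileList() reports every component file of a GDAL dataset, so for an ESRI Shapefile the .shp, .shx, .dbf and .prj members are all copied into the target directory, and the returned path is the copy whose basename matches the source.

from natcap.invest import utils

copied = utils.copy_spatial_files(
    'input/my_vector.shp',            # hypothetical source dataset
    'archive_data/aoi_path_vector')   # created if it does not exist
# copied == 'archive_data/aoi_path_vector/my_vector.shp'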

View File

@@ -1161,8 +1161,8 @@ class HRAModelTests(unittest.TestCase):
def test_model(self):
"""HRA: end-to-end test of the model, including datastack."""
from natcap.invest import datastack
from natcap.invest import hra
from natcap.invest import datastack
args = {
'workspace_dir': os.path.join(self.workspace_dir, 'workspace'),
@@ -1263,7 +1263,7 @@ class HRAModelTests(unittest.TestCase):
archive_path = os.path.join(self.workspace_dir, 'datstack.tar.gz')
datastack.build_datastack_archive(
args, 'natcap.invest.hra', archive_path)
args, 'habitat_risk_assessment', archive_path)
unarchived_path = os.path.join(self.workspace_dir, 'unarchived_data')
unarchived_args = datastack.extract_datastack_archive(

View File

@@ -83,7 +83,7 @@ class EndpointFunctionTests(unittest.TestCase):
response_data = json.loads(response.get_data(as_text=True))
self.assertEqual(
set(response_data),
{'type', 'args', 'model_id', 'model_title', 'invest_version'})
{'type', 'args', 'model_id', 'invest_version'})
def test_write_parameter_set_file(self):
"""UI server: write_parameter_set_file endpoint."""