Skip to content

Commit

Permalink
Update raster_tile.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dphogan authored Aug 21, 2020
1 parent 6446beb commit b896216
Showing 1 changed file with 0 additions and 43 deletions.
43 changes: 0 additions & 43 deletions solaris/tile/raster_tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,49 +458,6 @@ def save_tile(self, tile_data, mask, profile, dest_fname_base=None):
# os.path.join(self.dest_dir, dest_fname))
# os.remove(os.path.join(self.dest_dir, 'tmp.tif'))

def fill_all_nodata(self, nodata_fill):
"""
Fills all tile nodata values with a fill value.
The standard workflow is to run this function only after generating label masks and using the original output
from the raster tiler to filter out label pixels that overlap nodata pixels in a tile. For example,
solaris.vector.mask.instance_mask will filter out nodata pixels from a label mask if a reference_im is provided,
and after this step nodata pixels may be filled by calling this method.
nodata_fill : int, float, or str, optional
Default is to not fill any nodata values. Otherwise, pixels outside of the aoi_boundary and pixels inside
the aoi_boundary with the nodata value will be filled. "mean" will fill pixels with the channel-wise mean.
Providing an int or float will fill pixels in all channels with the provided value.
Returns: list
The fill values, in case the mean of the src image should be used for normalization later.
"""
src = _check_rasterio_im_load(self.src_name)
if nodata_fill == "mean":
arr = src.read()
arr_nan = np.where(arr!=src.nodata, arr, np.nan)
fill_values = np.nanmean(arr_nan, axis=tuple(range(1, arr_nan.ndim)))
print('Fill values set to {}'.format(fill_values))
elif isinstance(nodata_fill, (float, int)):
fill_values = src.meta['count'] * [nodata_fill]
print('Fill values set to {}'.format(fill_values))
else:
raise TypeError('nodata_fill must be "mean", int, or float. {} was supplied.'.format(nodata_fill))
src.close()
for tile_path in self.tile_paths:
tile_src = rasterio.open(tile_path, "r+")
tile_data = tile_src.read()
for i in np.arange(tile_data.shape[0]):
tile_data[i,...][tile_data[i,...] == tile_src.nodata] = fill_values[i] # set fill value for each band
if tile_src.meta['count'] == 1:
tile_src.write(tile_data[0, :, :], 1)
else:
for band in range(1, tile_src.meta['count'] + 1):
# base-1 vs. base-0 indexing...bleh
tile_src.write(tile_data[band-1, :, :], band)
tile_src.close()
return fill_values

def fill_all_nodata(self, nodata_fill):
"""
Fills all tile nodata values with a fill value.
Expand Down

0 comments on commit b896216

Please sign in to comment.