# 从三波段遥感影像中提取单个波段的代码 — extract a single band from a three-band remote-sensing image.
from osgeo import gdal
from osgeo import ogr, osr
import os, sys,shutil
import numpy as np
import datetime
class Three2one:
    """Extract a single band from a multi-band raster (e.g. a 3-band image) using GDAL."""

    def __init__(self):
        # Treat file names as UTF-8 and shapefile attribute text as CP936 (GBK),
        # so that non-ASCII (e.g. Chinese) paths and attributes are handled.
        gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
        gdal.SetConfigOption("SHAPE_ENCODING", "CP936")
        ogr.RegisterAll()
        gdal.AllRegister()

    def read_raster(self, RasterPath):
        """Read every band of a raster into memory.

        :param RasterPath: path of the input raster.
        :return: ``(data, SpacialRef)``. ``data`` is a 2-D array ``(rows, cols)``
            for a single-band raster, otherwise a 3-D array ``(bands, rows, cols)``.
            ``SpacialRef`` is ``[geotransform, projection, nodata]``; the nodata
            value is taken from band 1 and may be ``None``.
        :raises OSError: if GDAL cannot open ``RasterPath``.
        :raises ValueError: if the raster has no bands.
        """
        dataset = gdal.Open(RasterPath, gdal.GA_ReadOnly)
        if not dataset:
            # Fail loudly: continuing with a None dataset only crashes later
            # with an unhelpful AttributeError.
            raise OSError('打开文件失败: %s' % RasterPath)
        if dataset.RasterCount == 0:
            raise ValueError('raster has no bands: %s' % RasterPath)
        geotransform = dataset.GetGeoTransform()
        projection = dataset.GetProjection()
        nodata = dataset.GetRasterBand(1).GetNoDataValue()
        bands = [dataset.GetRasterBand(i + 1).ReadAsArray()
                 for i in range(dataset.RasterCount)]
        dataset = None  # release the GDAL handle
        data = bands[0] if len(bands) == 1 else np.stack(bands)
        return data, [geotransform, projection, nodata]

    @staticmethod
    def _gdal_datatype(dtype):
        """Return the GDAL pixel type used to store an array of NumPy ``dtype``."""
        mapping = {
            'uint8': gdal.GDT_Byte,
            # GDT_Int8 only exists in newer GDAL releases; older ones fall back to Byte.
            'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Byte),
            'uint16': gdal.GDT_UInt16,
            'int16': gdal.GDT_Int16,
            'uint32': gdal.GDT_UInt32,
            'int32': gdal.GDT_Int32,
            # 64-bit integers are narrowed to Int32 for compatibility with
            # readers lacking 64-bit support; values beyond Int32 range are lost.
            'int64': gdal.GDT_Int32,
            'uint64': gdal.GDT_Int32,
        }
        # Anything else (floats, bool, ...) is stored as 32-bit float.
        return mapping.get(np.dtype(dtype).name, gdal.GDT_Float32)

    def Creat_raster(self, RasterCreatPath, ArrayDate, SpacialRef, ctmap=None, DriverName="GTiff"):
        """Write a 2-D ``(rows, cols)`` or 3-D ``(bands, rows, cols)`` array to a raster file.

        :param RasterCreatPath: output path.
        :param ArrayDate: array to write; its dtype selects the GDAL pixel type.
        :param SpacialRef: ``[geotransform, projection, nodata]`` as returned by
            ``read_raster``. Georeferencing is only written for the GTiff driver.
        :param ctmap: optional color-table entries ``(index, r, g, b, alpha)``;
            applied only when a single band is written.
        :param DriverName: GDAL driver short name.
        :raises ValueError: if the array is not 2-D/3-D or the driver is unknown.
        :raises OSError: if GDAL cannot create the output file.
        """
        ArrayDate = np.asarray(ArrayDate)
        if ArrayDate.ndim == 3:
            bands = ArrayDate
        elif ArrayDate.ndim == 2:
            bands = ArrayDate[np.newaxis, ...]
        else:
            raise ValueError('expected a 2-D or 3-D array, got shape %r' % (ArrayDate.shape,))
        multiband = ArrayDate.ndim == 3
        band_count, YSize, XSize = bands.shape
        datatype = self._gdal_datatype(ArrayDate.dtype)

        geotransform = SpacialRef[0]
        projection = SpacialRef[1]
        NoData_value = SpacialRef[2]
        if NoData_value is None and datatype == gdal.GDT_Float32:
            # NaN is only a meaningful nodata marker for floating-point bands.
            NoData_value = float('nan')

        driver = gdal.GetDriverByName(DriverName)
        if driver is None:
            raise ValueError('unknown GDAL driver: %s' % DriverName)
        dataset = driver.Create(RasterCreatPath, XSize, YSize, band_count, datatype)
        if dataset is None:
            raise OSError('cannot create raster: %s' % RasterCreatPath)

        color = None
        if ctmap is not None:
            color = gdal.ColorTable()
            for entry in ctmap:
                color.SetColorEntry(entry[0], (entry[1], entry[2], entry[3], entry[4]))

        if DriverName == "GTiff":
            if geotransform is not None:
                dataset.SetGeoTransform(geotransform)
            if projection is not None:
                dataset.SetProjection(projection)

        for i, band_array in enumerate(bands):
            band = dataset.GetRasterBand(i + 1)
            band.WriteArray(band_array)
            if NoData_value is not None:
                band.SetNoDataValue(NoData_value)
            if DriverName == "GTiff":
                band.ComputeStatistics(True)  # approximate statistics are acceptable
            if multiband:
                print('波段', i, '写入完成')
            elif color is not None:
                band.SetRasterColorTable(color)
        dataset = None  # flush and close the output file

    def run(self, InputRaster, OutputRaster, band=1):
        """Copy one band of ``InputRaster`` to ``OutputRaster`` as uint8 (in memory).

        :param band: 1-based index of the band to extract (default: the first band).
        :raises ValueError: if ``band`` is out of range.
        """
        data, SpacialRef = self.read_raster(InputRaster)
        # read_raster returns a 2-D array for a single-band raster; indexing
        # that directly would select a row instead of a band.
        stack = data[np.newaxis, ...] if data.ndim == 2 else data
        if not 1 <= band <= stack.shape[0]:
            raise ValueError('band %d out of range 1..%d' % (band, stack.shape[0]))
        self.Creat_raster(OutputRaster, stack[band - 1].astype('uint8'), SpacialRef)
        print('successful')

    def move_index(self, height, width, block, overlap_rate):
        """Compute square sliding-window positions covering a ``height`` x ``width`` image.

        The window edge is ``min(block, height, width)``. With ``overlap_rate == 0``
        the image is tiled without overlap and the last window on each axis is
        shortened to fit. With ``0 < overlap_rate < 1`` windows advance by
        ``edge * (1 - overlap_rate)`` pixels and keep the full edge size; an extra
        window aligned to the far border is added when needed to reach it.

        :return: ``(x_idx, y_idx, nXBK, nYBK)`` — column offsets, row offsets,
            window widths and window heights.
        :raises ValueError: if the window would be empty or ``overlap_rate``
            is outside ``[0, 1)``.
        """
        window = min(block, height, width)
        if window <= 0:
            raise ValueError('window size must be positive (block=%r, height=%r, width=%r)'
                             % (block, height, width))
        if not 0.0 <= overlap_rate < 1.0:
            raise ValueError('overlap_rate must be in [0, 1), got %r' % (overlap_rate,))
        stride = max(1, int(window * (1 - overlap_rate)))
        y_idx, nYBK = self._axis_windows(height, window, stride, overlap_rate)
        x_idx, nXBK = self._axis_windows(width, window, stride, overlap_rate)
        return x_idx, y_idx, nXBK, nYBK

    @staticmethod
    def _axis_windows(length, window, stride, overlap_rate):
        """Return ``(starts, sizes)`` of the windows along one axis of ``length`` pixels."""
        if overlap_rate == 0.0:
            # Non-overlapping tiling; the last tile is clipped, never empty.
            starts = list(range(0, length, window))
            return starts, [min(window, length - s) for s in starts]
        starts = list(range(0, length - window + 1, stride))
        if starts[-1] + window < length:
            # Add a border-aligned window so the far edge is covered.
            starts.append(length - window)
        return starts, [window] * len(starts)

    def run_block(self, InputRaster, OutputRaster, band=1, block=1024, overlap_rate=0.0):
        """Copy one band of ``InputRaster`` to a GeoTIFF ``OutputRaster`` tile by tile.

        The output keeps the band's pixel type, georeferencing and nodata value.
        Only one tile is held in memory at a time.

        :param band: 1-based index of the band to extract (default: the first band).
        :param block: tile edge length in pixels (see ``move_index``).
        :param overlap_rate: tile overlap fraction passed to ``move_index``.
        :raises OSError: if the input cannot be opened or the output cannot be created.
        :raises ValueError: if ``band`` is out of range.
        """
        dataset = gdal.Open(InputRaster, gdal.GA_ReadOnly)
        if not dataset:
            raise OSError('打开文件失败: %s' % InputRaster)
        if not 1 <= band <= dataset.RasterCount:
            raise ValueError('band %d out of range 1..%d' % (band, dataset.RasterCount))
        XSize = dataset.RasterXSize
        YSize = dataset.RasterYSize
        band_in = dataset.GetRasterBand(band)
        datatype = band_in.DataType
        nodata = band_in.GetNoDataValue()

        driver = gdal.GetDriverByName("GTiff")
        dataset_new = driver.Create(OutputRaster, XSize, YSize, 1, datatype)
        if dataset_new is None:
            raise OSError('cannot create raster: %s' % OutputRaster)
        geotransform = dataset.GetGeoTransform()
        projection = dataset.GetProjection()
        if geotransform is not None:
            dataset_new.SetGeoTransform(geotransform)
        if projection is not None:
            dataset_new.SetProjection(projection)

        band_out = dataset_new.GetRasterBand(1)
        x_idx, y_idx, nXBK, nYBK = self.move_index(YSize, XSize, block, overlap_rate)
        for y_start, y_size in zip(y_idx, nYBK):
            for x_start, x_size in zip(x_idx, nXBK):
                tile = band_in.ReadAsArray(x_start, y_start, x_size, y_size)
                band_out.WriteArray(tile, x_start, y_start)
            print('第%d行处理完成。' % y_start)
        if nodata is not None:
            band_out.SetNoDataValue(nodata)
        dataset_new = None  # flush and close the output file
        dataset = None
if __name__ == '__main__':
    # Usage: python <script> [input.tif [output.tif]]
    # Without arguments the example paths below are used.
    InputRaster = sys.argv[1] if len(sys.argv) > 1 else r'./变化pred.tif'
    OutputRaster = sys.argv[2] if len(sys.argv) > 2 else r'./变化pred_single.tif'
    Three2one().run_block(InputRaster, OutputRaster)
# 文章评论 ("article comments") — leftover text from the web page this script was copied from.