-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 8ae4df3
Showing
23 changed files
with
2,538 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.ipynb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import xarray as xr | ||
import numpy as np | ||
import os | ||
import pandas as pd | ||
|
||
|
||
def load_data(datatype): | ||
# 使用os获取文件夹下的所有文件名称并使用xarray进行读取并合并为一个DataArray | ||
dir_path = os.path.join(os.path.dirname(os.getcwd()), 'data', datatype) | ||
filenames = os.listdir(dir_path) | ||
chlo_mean = [] | ||
|
||
for filename in filenames: | ||
with xr.open_dataset(os.path.join(dir_path, filename)) as daily: | ||
print(filename) | ||
daily = daily.expand_dims("time") | ||
daily.coords["time"] = pd.to_datetime([filename[4:12]]) | ||
chlo_mean.append(daily) | ||
|
||
merged = xr.concat(chlo_mean, dim="time") | ||
return merged | ||
|
||
|
||
def merge(datatype): | ||
merged = load_data(datatype) | ||
merged = merged.isel(lat=slice(1, None), lon=slice(1, None)) | ||
dir_path = os.path.join(os.path.dirname(os.getcwd()), 'data') | ||
filename = datatype + ".nc" | ||
merged.to_netcdf(os.path.join(dir_path, filename)) | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import os | ||
import random | ||
import shutil | ||
from shutil import copy2 | ||
from shutil import copytree | ||
|
||
|
||
def data_combine(data_folder, data_root): | ||
image_names = os.listdir(data_folder) | ||
combined_dir = os.path.join(data_root, 'combined') | ||
os.mkdir(combined_dir, mode=777) | ||
image_num = len(image_names) | ||
for i in range(image_num - 2): | ||
image_name = image_names[i + 1] | ||
sample_date = image_name[4:12] | ||
dst = os.path.join(combined_dir, sample_date) | ||
os.mkdir(dst, mode=777) | ||
|
||
copy2(os.path.join(data_folder, image_names[i]), dst) | ||
copy2(os.path.join(data_folder, image_names[i + 1]), dst) | ||
copy2(os.path.join(data_folder, image_names[i + 2]), dst) | ||
os.rename(os.path.join(dst, image_names[i]), os.path.join(dst, 'prev.nc')) | ||
os.rename(os.path.join(dst, image_names[i + 1]), os.path.join(dst, 'cur.nc')) | ||
os.rename(os.path.join(dst, image_names[i + 2]), os.path.join(dst, 'next.nc')) | ||
|
||
|
||
|
||
def data_split(data_folder, data_root, train_scales = 0.8,val_scales = 0.1,test_scales = 0.1): | ||
sample_names = os.listdir(data_folder) | ||
|
||
train_folder = os.path.join(data_root, 'train') #分割后的训练数据集路径 | ||
val_folder = os.path.join(data_root, 'val') | ||
test_folder = os.path.join(data_root, 'test') | ||
|
||
# os.mkdir(train_folder, mode=777) | ||
# os.mkdir(val_folder, mode=777) | ||
# os.mkdir(test_folder, mode=777) | ||
|
||
sample_num = len(sample_names) | ||
index_list = list(range(sample_num)) | ||
random.shuffle(index_list) | ||
|
||
train_stop_flag = sample_num * train_scales | ||
val_stop_flag = sample_num * (train_scales + val_scales) | ||
|
||
train_num = 0 | ||
val_num = 0 | ||
test_num = 0 | ||
|
||
for i in range(sample_num): | ||
if i <= train_stop_flag: | ||
copytree(os.path.join(data_folder, sample_names[index_list[i]]), os.path.join(train_folder, sample_names[index_list[i]])) | ||
train_num += 1 | ||
elif i <= val_stop_flag: | ||
copytree(os.path.join(data_folder, sample_names[index_list[i]]), os.path.join(val_folder, sample_names[index_list[i]])) | ||
val_num += 1 | ||
else: | ||
copytree(os.path.join(data_folder, sample_names[index_list[i]]), os.path.join(test_folder, sample_names[index_list[i]])) | ||
test_num += 1 | ||
|
||
print('训练集', train_num) | ||
print('验证集', val_num) | ||
print('测试集', test_num) | ||
|
||
|
||
if __name__ == '__main__': | ||
# data_root = os.path.join(os.path.dirname(os.getcwd()), 'data', 'ECS') | ||
# data_folder = os.path.join(data_root, 'raw') # 数据源文件地址 | ||
# data_combine(data_folder, data_root) | ||
data_root = os.path.join(os.path.dirname(os.getcwd()), 'data', 'ECS') | ||
data_folder = os.path.join(data_root, 'combined') # 数据源文件地址 | ||
data_split(data_folder, data_root) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import torch | ||
from torch.utils.data import Dataset | ||
import xarray as xr | ||
import numpy as np | ||
import utils | ||
import math | ||
import os | ||
import pandas as pd | ||
|
||
|
||
class DINCAEDataset(Dataset): | ||
def __init__(self, data_dir=None, transform=None, target_transform=None, use_elevation=True): | ||
""" | ||
@param data_dir: 数据集目录名 | ||
根据目录名获取相对路径,访问对应文件夹 | ||
""" | ||
self.transform = transform | ||
self.target_transform = target_transform | ||
self.data_dir = data_dir | ||
self.sample_dirs = os.listdir(data_dir) | ||
with xr.open_dataset(os.path.join(os.path.dirname(data_dir), 'bathymetry.nc')) as bathymetry: | ||
self.elevation = bathymetry.elevation.to_numpy() | ||
self.elevation = np.float_power(abs(self.elevation), 1.0/5)*np.sign(self.elevation)/5 | ||
self.use_elevation = use_elevation | ||
|
||
def __getitem__(self, item): | ||
""" | ||
@param item: | ||
@return: 返回第item个input和第item个mask | ||
""" | ||
date = self.sample_dirs[item] | ||
|
||
sample_dir = os.path.join(self.data_dir, date) | ||
with xr.open_dataset(os.path.join(sample_dir, 'prev.nc')) as prev: | ||
prev = prev | ||
with xr.open_dataset(os.path.join(sample_dir, 'cur.nc')) as cur: | ||
cur = cur | ||
with xr.open_dataset(os.path.join(sample_dir, 'next.nc')) as next: | ||
next = next | ||
|
||
channel = 9 | ||
input_np = np.zeros((channel, cur.CHL1_mean.shape[0], cur.CHL1_mean.shape[1])) | ||
missing_mask = np.zeros((1, cur.CHL1_mean.shape[0], cur.CHL1_mean.shape[1])) | ||
lat = cur.lat | ||
lon = cur.lon | ||
# 设置输入input_np, 为1-3层为chlo | ||
input_np[0] = prev.CHL1_mean | ||
input_np[1] = cur.CHL1_mean | ||
input_np[2] = next.CHL1_mean | ||
if self.use_elevation: | ||
input_np[3] = np.round(cur.CHL1_flags[0] % 16 >= 8) | ||
input_np[8] = self.elevation | ||
input_np[4] = np.expand_dims((lat - 22.0)/(32.0 - 22.0) * 2 - 1, 1) | ||
input_np[5] = np.expand_dims((lon - 118.0)/(126.0 - 118.0) * 2 - 1, 0) | ||
input_np[6] = np.sin(pd.to_datetime(date).dayofyear / 366 * math.pi * 2) | ||
input_np[7] = np.cos(pd.to_datetime(date).dayofyear / 366 * math.pi * 2) | ||
|
||
# 数据的mask | ||
missing_mask[0] = 1 - cur.CHL1_mean.to_masked_array().mask | ||
# target是当天的数据 | ||
target = input_np[1:2] | ||
|
||
# 数据的transform | ||
missing_mask_ts = torch.from_numpy(missing_mask) | ||
input_ts = torch.from_numpy(input_np) | ||
target_ts = torch.from_numpy(target) | ||
|
||
if self.transform is not None: | ||
input_ts = torch.cat([self.transform(input_ts[0:3]), input_ts[3:]], 0) | ||
|
||
if self.target_transform is not None: | ||
target_ts = self.target_transform(target_ts) | ||
|
||
return input_ts, missing_mask_ts, target_ts | ||
|
||
def __len__(self): | ||
return len(self.sample_dirs) | ||
|
||
|
||
|
||
|
||
|
||
class MaskDataset(Dataset): | ||
def __init__(self, mask_dir): | ||
""" | ||
@param mask_dir: 数据集目录名 | ||
根据目录名获取相对路径,访问对应文件夹 | ||
""" | ||
self.mask_dir = mask_dir | ||
self.mask_names = os.listdir(mask_dir) | ||
|
||
def __getitem__(self, item): | ||
""" | ||
@param item: | ||
@return: 返回第item个mask | ||
""" | ||
mask_path = os.path.join(self.mask_dir, self.mask_names[item]) | ||
with xr.open_dataset(mask_path) as mask: | ||
mask = mask | ||
return torch.from_numpy(np.expand_dims(1 - mask.CHL1_mean.to_masked_array().mask, 0)) | ||
|
||
def __len__(self): | ||
return len(self.mask_names) |
Oops, something went wrong.