基于 CNN 的 AI 分类模型开发¶
本案例主要介绍如何快速利用 AIE Python SDK 创建机器学习建模流程。我们主要使用到 Python SDK的Machine Learning Proxy 模块(下文简称 AieMlProxy )。该模块涵盖了一系列用户与训练集群之间的交互接口,包括:鉴权、数据加载、训练任务提交、任务状态和日志查看、模型推理等。
导入 AIE Python SDK 包并初始化¶
import aie
aie.Authenticate()
aie.Initialize()
创建工作目录¶
PACKAGE_PATH = "cnn_clf_demo"
!ls -l
!mkdir {PACKAGE_PATH}
!mkdir {PACKAGE_PATH}/data
!touch {PACKAGE_PATH}/__init__.py
!ls -l {PACKAGE_PATH}
数据集加载¶
AI Earth 平台目前主要存储和管理三类数据,分别是
- 影像类(即栅格数据,包括 Image 和 ImageCollection)
- 矢量类(包括 Feature 和 FeatureCollection )
- 数据集(除影像、矢量之外的非时空类数据,包括用户上传、代码生成的 csv、txt、json、zip 等格式的文件)
其中,数据集 又分为公开数据集和私有数据集。公开数据集采用 SpatioTemporal Asset Catalog ( STAC )进行管理,数据集合或单项数据均有各自全局唯一的 STAC ID ;私有数据集为用户自行上传的数据集。
公开数据集¶
公开数据集包含 CV 和遥感领域常见的 benchmark ,通过 MlProxy 模块提供的 STAC 接口来查询和获取。
# 导入AieMlProxy模块
from aie.client.mlproxy import MlProxy
# 列出所有公开数据集
MlProxy.list_stac_datasets()
# 通过dataset id获取数据集描述信息,以CIFAR-10数据集为例
stac_desc = MlProxy.get_stac_dataset("AIE_PUBLIC_DATA_CIFAR10_DATASET_V10_20220627")
print(stac_desc)
# 获取数据集split
TRAIN_PATH = stac_desc.get('train_path')
VALID_PATH = stac_desc.get('valid_path')
print(VALID_PATH)
print(TRAIN_PATH)
私有数据集¶
在本项目页面左侧,依次点击 数据 → 项目数据 → 导入数据 → 自主上传数据 以导入特定的单景影像到项目中。 私有数据集导入以后,默认会挂载到 /home/data 目录中(可使用终端命令行查看)
配置文件¶
%%writefile {PACKAGE_PATH}/config.py
OSS_HOST = "oss-cn-hangzhou-internal.aliyuncs.com"
OSS_WORK_DIR = "pai/cnn_clf_demo"
OSS_CHECKPOINT_DIR = "pai/cnn_clf_demo/checkpoint"
OPEN_DATA_BUCKET = "aie-sample-data"
OPEN_DATA_ENDPOINT = "http://oss-cn-hangzhou-internal.aliyuncs.com"
# Ouput info
OUTPUT_MODEL_FILE_NAME = "cnnDemoModelBest.pth"
# Local tmp dir on PAI Cluster (to save dataset)
PAI_LOCAL_TMP_DIR = "./tmp/"
# Hyperparams
BATCH_SIZE = 256
NUM_LABELS = 10
NUM_EPOCHES = 1
STAC_TEST_PATH = ""
STAC_TRAIN_MAPPING_PATH = ""
STAC_CLASS_DICT_PATH = ""
# 追加公开数据集split的路径到配置文件
!echo 'STAC_TRAIN_PATH = '\"{TRAIN_PATH}\" >> {PACKAGE_PATH}/config.py
!echo 'STAC_VALID_PATH = '\"{VALID_PATH}\" >> {PACKAGE_PATH}/config.py
文章评论