import rasterio as rio
import geopandas as gpd
from pathlib import Path
import rasterio.plot as rioplot
import matplotlib.pyplot as plt
Unet workflow
Example workflow for training a semantic segmentation model with fastai
path_to_data = Path('workflow_examples/')
train_raster = path_to_data/'104_28_Hiidenportti_Chunk1_orto.tif'
train_shp = path_to_data/'104_28_Hiidenportti_Chunk1_orto.geojson'
test_raster = path_to_data/'104_42_Hiidenportti_Chunk5_orto.tif'
test_shp = path_to_data/'104_42_Hiidenportti_Chunk5_orto.geojson'
The example data are RGB UAV images from Hiidenportti, and the task is to detect and segment different types of deadwood. The reference data are annotated as polygons, and the target column is 'layer'.
The training area looks like this:
fig, axs = plt.subplots(1,2, dpi=150, figsize=(10,3))
with rio.open(train_raster) as src:
    rioplot.show(src, ax=axs[0])
train_gdf = gpd.read_file(train_shp)
train_gdf.plot(column='layer', ax=axs[1], cmap='seismic')
plt.suptitle('Train area')
plt.tight_layout()
plt.show()
The test area looks like this:
fig, axs = plt.subplots(1,2, dpi=150, figsize=(5,3))
with rio.open(test_raster) as src:
    rioplot.show(src, ax=axs[0])
test_gdf = gpd.read_file(test_shp)
test_gdf.plot(column='layer', ax=axs[1], cmap='seismic')
plt.suptitle('Test area')
plt.tight_layout()
plt.show()
Install required dependencies
The simplest way to install fastai is
conda install -c fastchan fastai
Other ways to install can be found here.
Create a raster dataset
In this example, the data are split into 224x224 pixel tiles with no overlap.
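To make the tiling concrete, here is a minimal sketch of how non-overlapping 224x224 windows can be laid out over a raster of a given width and height. It is not geo2ml's implementation: tile_windows is a hypothetical helper, and it simply skips the incomplete strips at the right and bottom edges, which may differ from how create_raster_dataset handles them.

```python
def tile_windows(width, height, size=224):
    """Return (col_off, row_off, width, height) for each full, non-overlapping tile.

    Hypothetical helper: incomplete tiles at the right and bottom edges are skipped.
    """
    windows = []
    # Step by the tile size in both directions, so tiles never overlap
    for row_off in range(0, height - size + 1, size):
        for col_off in range(0, width - size + 1, size):
            windows.append((col_off, row_off, size, size))
    return windows


# A 1000 x 500 pixel raster fits 4 full tile columns and 2 full tile rows
wins = tile_windows(1000, 500)
print(len(wins))   # 8
print(wins[:2])    # [(0, 0, 224, 224), (224, 0, 224, 224)]
```

With rasterio, each tuple could be passed to rasterio.windows.Window(col_off, row_off, width, height) to read just that tile.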
CLI
geo2ml_create_raster_dataset \
    example_data/workflow_examples/104_28_Hiidenportti_Chunk1_orto.tif \
    example_data/workflow_examples/104_28_Hiidenportti_Chunk1_orto.geojson \
    example_data/workflow_examples/unet/train \
    --gridsize_x 224 --gridsize_y 224 --target_column layer
geo2ml_create_raster_dataset \
    example_data/workflow_examples/104_42_Hiidenportti_Chunk5_orto.tif \
    example_data/workflow_examples/104_42_Hiidenportti_Chunk5_orto.geojson \
    example_data/workflow_examples/unet/test \
    --gridsize_x 224 --gridsize_y 224 --target_column layer
Python
from geo2ml.scripts.data import create_raster_dataset

outpath = path_to_data/'unet'

create_raster_dataset(raster_path=train_raster, mask_path=train_shp, outpath=outpath/'train',
                      save_grid=False, target_column='layer', gridsize_x=224, gridsize_y=224)

create_raster_dataset(raster_path=test_raster, mask_path=test_shp, outpath=outpath/'test',
                      save_grid=False, target_column='layer', gridsize_x=224, gridsize_y=224)
Dataset structure
The commands above create the dataset in path_to_data/'unet', which contains the folders train and test. Each of these contains
- folder images, which contains the tiled raster patches
- folder mask_images, which contains the rasterized masks corresponding to the files in images
- label_map.txt, which has the mapping between integer mask values and classes
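Because each file in mask_images corresponds to a file in images, a quick sanity check is to pair them by file name and list any tile that lacks a counterpart. This is a minimal standard-library sketch, assuming (as the training code does) that a tile and its mask share the same file name; pair_tiles is a hypothetical helper, not part of geo2ml.

```python
from pathlib import Path


def pair_tiles(split_dir):
    """Pair <split_dir>/images/<name> with <split_dir>/mask_images/<name>.

    Hypothetical helper. Returns (pairs, names_without_mask, names_without_image).
    """
    split_dir = Path(split_dir)
    images = {p.name for p in (split_dir / 'images').iterdir() if p.is_file()}
    masks = {p.name for p in (split_dir / 'mask_images').iterdir() if p.is_file()}
    # Only names present in both folders form a usable (image, mask) pair
    pairs = [(split_dir / 'images' / n, split_dir / 'mask_images' / n)
             for n in sorted(images & masks)]
    return pairs, sorted(images - masks), sorted(masks - images)
```

For example, pairs, no_mask, no_image = pair_tiles(outpath/'train') gives the matched tiles plus any orphaned file names to inspect.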
Train the model
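from fastai.vision.all import *

# Use only the image tiles that have a corresponding file in mask_images
fnames = [Path(outpath/'train/images'/f) for f in os.listdir(outpath/'train/mask_images')]

def label_from_different_folder(fn, original_folder, new_folder):
    # Replaces every occurrence of original_folder in the path string with new_folder
    return str(fn).replace(original_folder, new_folder)

# valid_pct is left at its default (0.2), so a random 20% of the tiles is used for validation
dls = SegmentationDataLoaders.from_label_func(
    outpath/'train/images', bs=8,
    codes=['Background', 'Standing', 'Fallen'],
    fnames=fnames,
    label_func=partial(label_from_different_folder,
                       original_folder='images',
                       new_folder='mask_images'),
    batch_tfms=[
        # Random horizontal flips, zoom and lighting changes; rotation and warping disabled
        *aug_transforms(max_rotate=0., max_warp=0.),
        # Normalize with ImageNet statistics to match the pretrained encoder
        Normalize.from_stats(*imagenet_stats)
    ])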
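label_from_different_folder works by string replacement on the whole path, so every occurrence of 'images' is replaced, including one inside a parent directory name (for example a hypothetical /data/images_2023/ folder). Below is a sketch of a pathlib-based alternative that swaps only the tile's immediate parent folder. mask_from_image_path is a hypothetical name; it could be passed as label_func in place of the partial(...) call above.

```python
from pathlib import Path


def mask_from_image_path(fn, new_folder='mask_images'):
    """Map .../images/<name> to .../mask_images/<name>, changing only the parent folder."""
    fn = Path(fn)
    return fn.parent.parent / new_folder / fn.name


def label_from_different_folder(fn, original_folder, new_folder):
    # The string-replacement version from the training code, for comparison
    return str(fn).replace(original_folder, new_folder)


fn = Path('/data/images_2023/unet/train/images/tile_0.tif')
print(mask_from_image_path(fn))
# /data/images_2023/unet/train/mask_images/tile_0.tif  (on POSIX)
print(label_from_different_folder(fn, 'images', 'mask_images'))
# /data/mask_images_2023/unet/train/mask_images/tile_0.tif  (parent folder renamed too)
```

For the folder layout used in this example both functions give the same result; the pathlib version only matters when 'images' also appears elsewhere in the path.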
Fine-tune a Unet that uses resnet34 as the encoder.
learn = unet_learner(dls, resnet34)
learn.fine_tune(6)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 0.102724 | 0.063255 | 00:31 |

epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 0.060898 | 0.049642 | 00:32 |
1 | 0.049917 | 0.035729 | 00:32 |
2 | 0.038387 | 0.032115 | 00:32 |
3 | 0.030265 | 0.028458 | 00:32 |
4 | 0.025367 | 0.025913 | 00:32 |
5 | 0.022732 | 0.025986 | 00:32 |
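fine_tune first trains one epoch with the pretrained encoder frozen (the first table) and then six epochs with the whole network unfrozen (the second table). The validation loss is lowest after epoch 4 and rises very slightly in epoch 5, while the learner keeps the weights from the last epoch. The snippet below only reads the logged values to locate the best epoch; if keeping the best weights matters, fastai's SaveModelCallback, monitoring valid_loss, is one option.

```python
# valid_loss values from the unfrozen phase of learn.fine_tune(6) above
valid_loss = [0.049642, 0.035729, 0.032115, 0.028458, 0.025913, 0.025986]

# Index of the smallest validation loss
best_epoch = min(range(len(valid_loss)), key=valid_loss.__getitem__)
print(best_epoch, valid_loss[best_epoch])  # 4 0.025913
```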
See the results.
learn.show_results(max_n=4)
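show_results gives a visual comparison of predictions and targets. For a numeric check, for example on the test tiles, per-class intersection over union (IoU) is a common segmentation metric. This is a minimal framework-free sketch, assuming the masks are flat sequences of integer class indices where 0, 1 and 2 correspond to the codes 'Background', 'Standing' and 'Fallen'. per_class_iou is a hypothetical helper and the masks below are toy values; fastai also provides segmentation metrics such as DiceMulti that can be passed to unet_learner through metrics.

```python
def per_class_iou(pred, target, n_classes=3):
    """Intersection over union for each class index in range(n_classes).

    Classes absent from both pred and target get None instead of a score.
    """
    if len(pred) != len(target):
        raise ValueError('pred and target must have the same length')
    ious = []
    for c in range(n_classes):
        # Pixels where both masks have class c / where at least one has class c
        inter = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, target) if p == c or t == c)
        ious.append(inter / union if union else None)
    return ious


codes = ['Background', 'Standing', 'Fallen']
pred = [0, 0, 1, 1, 2, 0]    # toy predicted mask, flattened
target = [0, 1, 1, 1, 2, 2]  # toy reference mask, flattened
for name, iou in zip(codes, per_class_iou(pred, target)):
    print(name, iou)
# IoU: Background 1/3, Standing 2/3, Fallen 1/2
```

Real masks are 2D arrays, so they would need flattening first (e.g. mask.ravel().tolist() for a NumPy array); for full-size tiles a vectorized NumPy version would be much faster.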