from pathlib import Path
import os
from fastai.vision.all import *
from fastai.callback.wandb import *
from fastai.callback.progress import ShowGraphCallback
from fastai.callback.tracker import SaveModelCallback
from drone_detector.engines.fastai.data import *
from drone_detector.metrics import JaccardCoeffMulti
import wandb
# Put the parent directory on the import path so that the project-local
# `src` package can be imported from this notebook.
import sys

sys.path.append('..')
from src.dataloaders import SegmentationDataLoadersFix
from src.augmentations import RandomErasingSeg
Train model
1 Model training
Specify which patch locations are used for validation.
# Directory with the processed training tiles; tiles are read from
# <outpath>/<year>/raster_tiles.
outpath = Path('../data/processed/train')
# Tile file names of year 1984, sorted so that the slice below is deterministic.
patches = sorted(os.listdir(outpath/'1984'/'raster_tiles'))
# The last 25 % of the sorted tile names are held out for validation.
val_files = patches[int(len(patches)*.75):]
# Full paths of every raster tile of both years (training and validation alike).
fnames = [Path(outpath/year/'raster_tiles'/f) for year in ['1965', '1984']
          for f in os.listdir(outpath/year/'raster_tiles')]
Locations are the same for both years in order to prevent data leakage.
import rasterio as rio
import rasterio.merge as rio_merge
import rasterio.plot as rioplot
from contextlib import ExitStack

# Visual sanity check of the split: one mosaic per (year, subset) pair.
val_set = set(val_files)  # set membership instead of scanning the list per tile
train_files = [s for s in patches if s not in val_set]

fig, ax = plt.subplots(1,4, figsize=(12,4))
for a in ax: a.axis('off')

# (year, tile names, panel title) in left-to-right panel order.
panels = [
    ('1965', train_files, '1965 training area'),
    ('1965', val_files, '1965 validation area'),
    ('1984', train_files, '1984 training area'),
    ('1984', val_files, '1984 validation area'),
]

for a, (year, files, title) in zip(ax, panels):
    # ExitStack closes every opened dataset once the mosaic has been built;
    # the merged array stays usable after the datasets are closed.
    with ExitStack() as stack:
        rasters = [stack.enter_context(rio.open(outpath/year/'raster_tiles'/f))
                   for f in files]
        mos, tfm = rio_merge.merge(rasters)
    # Move axis 0 of the merged array to the last position before plotting.
    a.imshow(np.moveaxis(mos,0,2))
    a.set_title(title)

plt.tight_layout()
plt.show()
We need to slightly modify fastai’s SegmentationDataLoaders
so that we can specify the tiles to use as validation data during training, as well as edit RandomErasing
so that it also erases the masks. These can be found in the src
folder.
The models are trained using the fastai library. During training, the images are randomly flipped horizontally; brightness, contrast and saturation are randomly adjusted; images are rotated randomly (maximum rotation 5 degrees); and random areas are erased from the images.
# Build the segmentation DataLoaders. `val_fnames` is the extra argument of the
# project's SegmentationDataLoadersFix that names the tiles used for validation.
dls = SegmentationDataLoadersFix.from_label_func(
    '../data/',
    bs=8,
    # Six classes; the position in this list is presumably the class index
    # stored in the mask tiles (verify against the mask data).
    codes=['Background',
           'Fields',
           'Mires',
           'Roads',
           'Watercourses',
           'Water bodies'],
    fnames=fnames,
    # Presumably maps an image path to its mask by swapping the
    # 'raster_tiles' folder for 'mask_tiles' (helper comes from drone_detector).
    label_func=(partial(label_from_different_folder,
                        original_folder='raster_tiles',
                        new_folder='mask_tiles')),
    batch_tfms=[
        # Augmentations: rotation up to 5 degrees, no warp, zoom capped at 1,
        # zero padding, plus random saturation changes.
        *aug_transforms(max_rotate=5.,
                        max_warp=0.,
                        max_zoom=1,
                        pad_mode='zeros',
                        xtra_tfms=[
                            Saturation(max_lighting=0.3)]),
        # Project-local erasing transform that is applied to the masks as well.
        RandomErasingSeg(max_count=2, erasing_mode='gaussian'),
        Normalize.from_stats(*imagenet_stats)
    ],
    num_workers=0,
    val_fnames=val_files)
from matplotlib import colors

# One colour per class index 0-5, in the same order as the `codes` list.
cmap = colors.ListedColormap(['white', 'tab:orange', 'tab:grey', 'tab:red', 'tab:cyan', 'tab:blue'])
bounds=[0,1,2,3,4,5]

# vmin/vmax pin the colour scale to the full class range 0-5.
dls.show_batch(max_n=8, cmap=cmap, alpha=1, vmin=0, vmax=5)
We used the U-Net (Ronneberger, Fischer and Brox, 2015) architecture with ResNet152 as the encoder and Focal Loss as the loss function. The model evaluation metrics are the Dice and Jaccard coefficients for multiclass tasks. The Dice coefficient is twice the area of overlap in pixels divided by the total number of pixels in the two masks being compared, while the Jaccard coefficient is the area of intersection divided by the area of union. Both metrics are computed for each class separately and then averaged, excluding the background class.
# U-Net learner with a pretrained ResNet152 encoder; n_out=6 matches the six
# entries of `codes`. Mixed-precision training is enabled with to_fp16().
learn = unet_learner(dls, arch=resnet152, pretrained=True, n_in=3, n_out=6, blur=False, self_attention=False,
                     metrics=[DiceMulti(), JaccardCoeffMulti()], loss_func=FocalLossFlat(axis=1)).to_fp16()
learn.lr_find()
SuggestedLRs(valley=1.737800812406931e-05)
Track the training with wandb.
# Start a Weights & Biases run in the 'historical-maps' project.
wandb.init(project='historical-maps')
As the task is fairly simple, we trained the model for 1 frozen epoch (only the classification layer and decoder), and 10 unfrozen epochs (all layers unfrozen).
# 1 frozen epoch followed by 10 unfrozen epochs; WandbCallback logs to the
# wandb run started above.
learn.fine_tune(10, freeze_epochs=1, base_lr=1e-4, cbs=[ShowGraphCallback, WandbCallback(log_preds_every_epoch=True)])
epoch | train_loss | valid_loss | dice_multi | jaccard_coeff_multi | time |
---|---|---|---|---|---|
0 | 0.366200 | 0.135884 | 0.579627 | 0.462131 | 00:48 |
epoch | train_loss | valid_loss | dice_multi | jaccard_coeff_multi | time |
---|---|---|---|---|---|
0 | 0.096456 | 0.061018 | 0.827349 | 0.696469 | 00:44 |
1 | 0.078594 | 0.105482 | 0.760327 | 0.599985 | 00:44 |
2 | 0.065460 | 0.035690 | 0.883655 | 0.777175 | 00:44 |
3 | 0.051111 | 0.033925 | 0.881897 | 0.776236 | 00:44 |
4 | 0.042035 | 0.035209 | 0.891968 | 0.791245 | 00:44 |
5 | 0.035553 | 0.025388 | 0.910085 | 0.820936 | 00:44 |
6 | 0.031272 | 0.031957 | 0.895384 | 0.796915 | 00:44 |
7 | 0.028636 | 0.024199 | 0.910118 | 0.821767 | 00:44 |
8 | 0.027473 | 0.023532 | 0.914583 | 0.828892 | 00:44 |
9 | 0.025727 | 0.023381 | 0.914946 | 0.829465 | 00:44 |
# Switch the learner back to full precision and close the wandb run.
learn.to_fp32()
wandb.finish()
# Show results on the validation set with the class colormap.
learn.show_results(max_n=8, cmap=cmap, alpha=0.8, vmin=0, vmax=5)
learn.validate()
(#3) [0.023381350561976433,0.9149572189762392,0.8294852675339317]
# Point the learner at ../models before exporting it as a pickle.
learn.path = Path('../models')
learn.export(fname='resnet152_focal_cutmasks.pkl')
2 Test with remaining labeled data
2.1 Example result patches
# Reload the exported learner (cpu=False) and attach the DataLoaders built
# earlier in the notebook.
learn = load_learner('../models/resnet152_focal_cutmasks.pkl', cpu=False)
learn.dls = dls
# Labelled tiles that were not used during training.
valid_im_path = Path('../data/processed/val/')
val_images = [Path(valid_im_path/year/'raster_tiles'/f) for year in ['1965', '1984']
              for f in os.listdir(valid_im_path/year/'raster_tiles')]
# with_labels=True so that the masks are loaded too and metrics can be computed.
testdl = learn.dls.test_dl(val_images, with_labels=True, bs=32, shuffle=True)
testdl.show_batch(max_n=8, cmap=cmap, alpha=0.8, vmin=0, vmax=5)
Check the loss, dice score and jaccard score for this dataset.
# Evaluate on the test dataloader instead of the default validation set.
learn.validate(dl=testdl)
(#3) [0.029747433960437775,0.9198446067228588,0.8382089817195901]
And then some example results.
# Plot two example predictions from the test dataloader and save the figure.
learn.show_results(dl=testdl, max_n=2, cmap=cmap, alpha=.95)
f = plt.gcf()
# Normaliser and mappable for the colour bar; only used by the disabled lines below.
norm = colors.Normalize(vmin=0,vmax=6)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
#cbar = fig.colorbar(sm, ax=f.axes, ticks=np.arange(0.5,6.5), aspect=75)
#cbar.ax.set_yticklabels(['Background', 'Fields', 'Mires', 'Roads', 'Watercourses', 'Water bodies'])
plt.savefig('../data/figures/poster_preds.png', dpi=300, bbox_inches='tight')
# NOTE(review): tight_layout() runs after savefig(), so it only affects the
# displayed figure, not the saved PNG. Confirm this order is intended.
plt.tight_layout()
2.2 Classwise metrics
These results are computed before any morphological post-processing and before merging the results into large tiles.
from sklearn.metrics import precision_score, recall_score
def dice(targs, preds, cls_id):
    """Return the Dice coefficient of class `cls_id` between `targs` and `preds`.

    Both tensors are binarised (1 where the value equals `cls_id`, else 0).
    The score is ``2 * |T & P| / (|T| + |P|)``, where T and P are the
    elements equal to `cls_id` in `targs` and `preds` respectively.

    Returns ``None`` when the class occurs in neither tensor.
    """
    targ_mask = torch.where(targs==cls_id, 1, 0)
    pred_mask = torch.where(preds==cls_id, 1, 0)
    inter = (targ_mask*pred_mask).float().sum()
    # Sum of both mask sizes (|T| + |P|): the Dice denominator, not a set union.
    total = (targ_mask+pred_mask).float().sum()
    return 2 * inter/total if total > 0 else None
def jaccard(targs, preds, cls_id):
    """Return the Jaccard coefficient (IoU) of class `cls_id`.

    Both tensors are binarised (1 where the value equals `cls_id`, else 0).
    The score is ``|T & P| / (|T| + |P| - |T & P|)``, i.e. intersection
    over union of the two binary masks.

    Returns ``None`` when the class occurs in neither tensor.
    """
    targ_mask = torch.where(targs==cls_id, 1, 0)
    pred_mask = torch.where(preds==cls_id, 1, 0)
    inter = (targ_mask*pred_mask).float().sum()
    # |T| + |P|; subtracting the intersection below gives the true union size.
    total = (targ_mask+pred_mask).float().sum()
    return inter/(total-inter) if total > 0 else None
def pre(targs, preds, cls_id):
    """Return the precision of class `cls_id`, treated as the positive label.

    Both tensors are binarised (1 where the value equals `cls_id`, else 0),
    flattened and passed to scikit-learn's `precision_score`.
    """
    y_true = torch.where(targs==cls_id, 1, 0).flatten()
    y_pred = torch.where(preds==cls_id, 1, 0).flatten()
    return precision_score(y_true, y_pred)
def rec(targs, preds, cls_id):
    """Return the recall of class `cls_id`, treated as the positive label.

    Both tensors are binarised (1 where the value equals `cls_id`, else 0),
    flattened and passed to scikit-learn's `recall_score`.
    """
    y_true = torch.where(targs==cls_id, 1, 0).flatten()
    y_pred = torch.where(preds==cls_id, 1, 0).flatten()
    return recall_score(y_true, y_pred)
2.2.1 Validation set
# Targets and decoded predictions for the validation set; the first returned
# item is discarded.
_,targs, preds = learn.get_preds(with_decoded=True)
2.2.1.1 Fields
Dice score for fields:
dice(targs, preds, 1)  # class index 1 = 'Fields' in the `codes` list
TensorBase(0.9731)
Jaccard score for fields:
jaccard(targs, preds, 1)  # class index 1 = 'Fields' in the `codes` list
TensorBase(0.9477)
Precision for fields:
pre(targs, preds, 1)  # class index 1 = 'Fields' in the `codes` list
0.972258257302713
Recall for fields:
rec(targs, preds, 1)  # class index 1 = 'Fields' in the `codes` list
0.9740398934335991
2.2.1.2 Mires
Dice score for mires:
dice(targs, preds, 2)  # class index 2 = 'Mires' in the `codes` list
TensorBase(0.8912)
Jaccard score for mires:
jaccard(targs, preds, 2)  # class index 2 = 'Mires' in the `codes` list
TensorBase(0.8037)
Precision for mires:
pre(targs, preds, 2)  # class index 2 = 'Mires' in the `codes` list
0.884052339005425
Recall for mires:
rec(targs, preds, 2)  # class index 2 = 'Mires' in the `codes` list
0.8984412879455308
2.2.1.3 Roads
Dice score for roads:
dice(targs, preds, 3)  # class index 3 = 'Roads' in the `codes` list
TensorBase(0.8770)
Jaccard score for roads:
jaccard(targs, preds, 3)  # class index 3 = 'Roads' in the `codes` list
TensorBase(0.7809)
Precision for roads:
pre(targs, preds, 3)  # class index 3 = 'Roads' in the `codes` list
0.8530348089242903
Recall for roads:
rec(targs, preds, 3)  # class index 3 = 'Roads' in the `codes` list
0.90231700895208
2.2.1.4 Watercourses
Dice score for watercourses:
dice(targs, preds, 4)  # class index 4 = 'Watercourses' in the `codes` list
TensorBase(0.7721)
Jaccard score for watercourses:
jaccard(targs, preds, 4)  # class index 4 = 'Watercourses' in the `codes` list
TensorBase(0.6288)
Precision for watercourses:
pre(targs, preds, 4)  # class index 4 = 'Watercourses' in the `codes` list
0.7471939707544204
Recall for watercourses:
rec(targs, preds, 4)  # class index 4 = 'Watercourses' in the `codes` list
0.7987637892901098
2.2.1.5 Water bodies
Dice score for water bodies:
dice(targs, preds, 5)  # class index 5 = 'Water bodies' in the `codes` list
TensorBase(0.9931)
Jaccard score for water bodies:
jaccard(targs, preds, 5)  # class index 5 = 'Water bodies' in the `codes` list
TensorBase(0.9863)
Precision for water bodies:
pre(targs, preds, 5)  # class index 5 = 'Water bodies' in the `codes` list
0.9910614109293016
Recall for water bodies:
rec(targs, preds, 5)  # class index 5 = 'Water bodies' in the `codes` list
0.9951029661699824
2.2.2 Test set
# Targets and decoded predictions for the test dataloader; this rebinds
# `targs` and `preds`, so the metric calls below refer to the test set.
_, targs, preds = learn.get_preds(dl=testdl, with_decoded=True)
2.2.2.1 Fields
Dice score for fields:
dice(targs, preds, 1)  # class index 1 = 'Fields' in the `codes` list
TensorBase(0.9697)
Jaccard score for fields:
jaccard(targs, preds, 1)  # class index 1 = 'Fields' in the `codes` list
TensorBase(0.9412)
Precision for fields:
pre(targs, preds, 1)  # class index 1 = 'Fields' in the `codes` list
0.9618274055387686
Recall for fields:
rec(targs, preds, 1)  # class index 1 = 'Fields' in the `codes` list
0.9777684946199127
2.2.2.2 Mires
Dice score for mires:
dice(targs, preds, 2)  # class index 2 = 'Mires' in the `codes` list
TensorBase(0.8981)
Jaccard score for mires:
jaccard(targs, preds, 2)  # class index 2 = 'Mires' in the `codes` list
TensorBase(0.8151)
Precision for mires:
pre(targs, preds, 2)  # class index 2 = 'Mires' in the `codes` list
0.8723801784942353
Recall for mires:
rec(targs, preds, 2)  # class index 2 = 'Mires' in the `codes` list
0.9254702282614122
2.2.2.3 Roads
Dice score for roads:
dice(targs, preds, 3)  # class index 3 = 'Roads' in the `codes` list
TensorBase(0.8816)
Jaccard score for roads:
jaccard(targs, preds, 3)  # class index 3 = 'Roads' in the `codes` list
TensorBase(0.7882)
Precision for roads:
pre(targs, preds, 3)  # class index 3 = 'Roads' in the `codes` list
0.8464330652120925
Recall for roads:
rec(targs, preds, 3)  # class index 3 = 'Roads' in the `codes` list
0.9197872447043832
2.2.2.4 Watercourses
Dice score for watercourses:
dice(targs, preds, 4)  # class index 4 = 'Watercourses' in the `codes` list
TensorBase(0.8009)
Jaccard score for watercourses:
jaccard(targs, preds, 4)  # class index 4 = 'Watercourses' in the `codes` list
TensorBase(0.6679)
Precision for watercourses:
pre(targs, preds, 4)  # class index 4 = 'Watercourses' in the `codes` list
0.783207010142257
Recall for watercourses:
rec(targs, preds, 4)  # class index 4 = 'Watercourses' in the `codes` list
0.8193276521039975
2.2.2.5 Water bodies
Dice score for water bodies:
dice(targs, preds, 5)  # class index 5 = 'Water bodies' in the `codes` list
TensorBase(0.9892)
Jaccard score for water bodies:
jaccard(targs, preds, 5)  # class index 5 = 'Water bodies' in the `codes` list
TensorBase(0.9786)
Precision for water bodies:
pre(targs, preds, 5)  # class index 5 = 'Water bodies' in the `codes` list
0.9865743571649199
Recall for water bodies:
rec(targs, preds, 5)  # class index 5 = 'Water bodies' in the `codes` list
0.9917838442196202