Using fastai for segmentation

from pathlib import Path
from drone_detector.processing.tiling import *
import os
from fastai.vision.all import *
from drone_detector.engines.fastai.data import *
outpath = Path('../data/historic_map/processed/raster_tiles/')

fnames = [Path(outpath/f) for f in os.listdir(outpath)]

dls = SegmentationDataLoaders.from_label_func('../data/historic_map/', bs=16,
                                              codes=['Marshes'],
                                              fnames=fnames,
                                              label_func=partial(label_from_different_folder,
                                                                 original_folder='raster_tiles',
                                                                 new_folder='mask_tiles'),
                                              batch_tfms = [
                                                  *aug_transforms(max_rotate=0., max_warp=0.),
                                                  Normalize.from_stats(*imagenet_stats)
                                              ])

label_from_different_folder is a helper located in drone_detector.engines.fastai.data. That module also contains helpers for working with, for instance, multispectral images or time series of images.
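Its exact implementation is not repeated here, but based on how it is called above, a minimal sketch of such a label function (for illustration only, the real helper may differ) could look like this:

from pathlib import Path

def label_from_different_folder(fn, original_folder, new_folder):
    # Sketch only: point an image tile to its mask tile by swapping the folder name,
    # e.g. .../raster_tiles/tile_001.tif -> .../mask_tiles/tile_001.tif.
    # The actual helper lives in drone_detector.engines.fastai.data.
    return Path(str(fn).replace(original_folder, new_folder))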

dls.show_batch(max_n=16)

Train a basic U-Net, using a pretrained ResNet50 as the encoder. to_fp16() tells the model to train in half precision, which uses less memory. The loss function is FocalLossFlat; for segmentation we need to specify axis=1. The metrics are Dice and JaccardCoeff, fairly standard segmentation metrics.

learn = unet_learner(dls, resnet50, pretrained=True, n_in=3, n_out=2,
                     metrics=[Dice(), JaccardCoeff()], loss_func=FocalLossFlat(axis=1)
                    ).to_fp16()
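Dice and JaccardCoeff measure essentially the same thing: with J the Jaccard index, Dice = 2J / (1 + J). A small toy example (not part of the original notebook) illustrates the relationship:

import numpy as np

# Toy binary masks: 1 = marsh, 0 = background (illustrative values only)
pred   = np.array([1, 1, 0, 0])
target = np.array([1, 0, 0, 0])

intersection = np.logical_and(pred, target).sum()
union        = np.logical_or(pred, target).sum()

jaccard = intersection / union                             # 1 / 2 = 0.5
dice    = 2 * intersection / (pred.sum() + target.sum())   # 2 / 3 ≈ 0.667

assert np.isclose(dice, 2 * jaccard / (1 + jaccard))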

Search for a suitable learning rate.

learn.lr_find()
SuggestedLRs(valley=3.630780702224001e-05)

Train the model for 2 epochs with the encoder layers frozen, then for 10 epochs with all layers unfrozen.

from fastai.callback.progress import ShowGraphCallback
learn.fine_tune(10, freeze_epochs=2, base_lr=3e-5, cbs=ShowGraphCallback)
epoch train_loss valid_loss dice jaccard_coeff time
0 0.086749 0.066611 0.050679 0.025998 00:13
1 0.070000 0.036666 0.721976 0.564916 00:13

epoch train_loss valid_loss dice jaccard_coeff time
0 0.031274 0.025833 0.810985 0.682064 00:13
1 0.030273 0.021505 0.839143 0.722865 00:13
2 0.029289 0.019079 0.852387 0.742747 00:13
3 0.027417 0.048365 0.708193 0.548218 00:13
4 0.030767 0.020879 0.826265 0.703961 00:13
5 0.027386 0.018578 0.867091 0.765366 00:13
6 0.025325 0.017101 0.863993 0.760552 00:13
7 0.022994 0.016151 0.879036 0.784179 00:13
8 0.021048 0.015779 0.878420 0.783198 00:13
9 0.019957 0.015841 0.876572 0.780265 00:13
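For context, fine_tune with these arguments corresponds roughly to the two-stage schedule below. This is only a sketch of fastai's behaviour (the library also adjusts the one-cycle parameters and lowers base_lr for the unfrozen stage), not an exact reproduction of its internals.

# Rough equivalent of learn.fine_tune(10, freeze_epochs=2, base_lr=3e-5) (sketch only)
learn.freeze()                                    # train only the decoder/head first
learn.fit_one_cycle(2, slice(3e-5))
learn.unfreeze()                                  # then fine-tune the whole network
learn.fit_one_cycle(10, slice(3e-5/100, 3e-5))    # discriminative learning rates per layer group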

Return to full precision.

learn.to_fp32()
<fastai.learner.Learner>

Check results.

learn.show_results(max_n=8)

preds = learn.get_preds(with_input=False, with_decoded=False)
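get_preds returns the per-pixel class probabilities together with the target masks for the validation set. If hard masks are needed, one way (assuming class index 1 corresponds to marshes) is to take the argmax over the class dimension:

probs, targets = preds            # probs: (n_images, n_classes, height, width)
pred_masks = probs.argmax(dim=1)  # per-pixel class index; here 1 is assumed to be 'Marshes'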

Export the model for later use.

learn.path = Path('../data/historic_map/models')
learn.export('resnet50_focalloss_swamps.pkl')

Define some helper functions for inference, such as one for removing all resizing transforms from the dataloader pipelines.

def label_func(fn):
    # Map an image tile path to the corresponding mask tile path.
    return str(fn).replace('raster_tiles', 'mask_tiles')

@patch
def remove(self:Pipeline, t):
    "Remove every transform of the same type as `t` from the pipeline."
    self.fs = L(o for o in self.fs if not isinstance(o, t.__class__))

@patch
def set_base_transforms(self:DataLoader):
    "Drop any item or batch transform that resizes, i.e. anything with a `size` attribute."
    for attr in ['after_item', 'after_batch']:
        tfms = getattr(self, attr)
        for o in list(tfms):
            if hasattr(o, 'size'): tfms.remove(o)
        setattr(self, attr, tfms)

Load the learner and remove all resizing transforms. If you run out of memory, just restart the kernel.

testlearn = load_learner('../data/historic_map/models/resnet50_focalloss_swamps.pkl', cpu=False)
testlearn.dls.valid.set_base_transforms()
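To verify that the resizing transforms are gone, one option is to simply print the remaining pipelines:

# Neither pipeline should contain size-changing transforms any more
print(testlearn.dls.valid.after_item)
print(testlearn.dls.valid.after_batch)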

The model is tested with map patches from different areas and of different sizes. Two of the images are from 1965 and two from 1984, and their sizes vary between 600x600 and 1500x1500 pixels. Three example results are shown below.

import PIL
import matplotlib.pyplot as plt

def unet_predict(fn):
    "Predict a marsh mask for the image `fn` and black out everything predicted as background."
    image = np.array(PIL.Image.open(fn))
    # predict returns (decoded mask, raw prediction, probabilities); keep the decoded mask
    mask = testlearn.predict(PILImage.create(image))[0].numpy()
    img = image.copy()
    img[mask == 0] = 0  # zero out all channels where the prediction is background
    return PIL.Image.fromarray(img.astype(np.uint8))

test_images = [f'../data/historic_map/test_patches/{f}' for f in os.listdir('../data/historic_map/test_patches/')]

First result.

patch_pred = unet_predict(test_images[0])

fig, axs = plt.subplots(1,2, figsize=(10,5),dpi=300)
for a in axs:
    a.set_yticks([])
    a.set_xticks([])
axs[0].imshow(PIL.Image.open(test_images[0]))
axs[1].imshow(patch_pred)
axs[0].set_title(test_images[0].split('/')[-1])
axs[1].set_title('Predicted marshes')
plt.show()

Second result.

patch_pred = unet_predict(test_images[1])

fig, axs = plt.subplots(1,2, figsize=(10,5),dpi=300)
for a in axs:
    a.set_yticks([])
    a.set_xticks([])
axs[0].imshow(PIL.Image.open(test_images[1]))
axs[1].imshow(patch_pred)
axs[0].set_title(test_images[1].split('/')[-1])
axs[1].set_title('Predicted marshes')
plt.show()

Third result.

patch_pred = unet_predict(test_images[3])

fig, axs = plt.subplots(1,2, figsize=(10,5),dpi=300)
for a in axs:
    a.set_yticks([])
    a.set_xticks([])
axs[0].imshow(PIL.Image.open(test_images[3]))
axs[1].imshow(patch_pred)
axs[0].set_title(test_images[3].split('/')[-1])
axs[1].set_title('Predicted marshes')
plt.show()