Example Usage

This page contains example code showing how to use the GOOSE Dataset in a custom application. Everything can be tried out in the linked Jupyter Notebook.

Pytorch Dataset

The GOOSE Dataset is divided into three splits: train, test and validation. The first step is to read the data from the dataset folder.

In the following example this is achieved in two main steps:

  1. Parse the images from the root folder into three Python lists of dictionaries containing the image paths and related information (one list per split).
  2. Create PyTorch Dataset objects that load the images, and use them to train models or to run inference.

Additionally, the dataset provides a class-mapping CSV file with information about each class, such as its label id, its class name, and whether the class has instances (thing) or not (stuff).

Parsing and Reading Data

def __check_labels(img_path: str, lbl_path: str) -> 'tuple[bool, list[str] | None]':
    '''
    Check whether all label files belonging to an image exist.

    The label base name is the image file name with its extension and its
    last two underscore-separated fields removed. For that base name, the
    files ``<base>_color<ext>``, ``<base>_instanceids<ext>`` and
    ``<base>_labelids<ext>`` must all exist in ``lbl_path``. Used to filter
    out images whose labels are missing.

    Parameters:
        img_path    :   path of the image file
        lbl_path    :   folder in which the label files are looked up

    Returns:
        (True, [color_name, instanceids_name, labelids_name]) if every label
        file exists, otherwise (False, None).
    '''
    # splitext (instead of split('.')) keeps base names containing dots intact
    stem, ext = os.path.splitext(os.path.basename(img_path))
    base = '_'.join(stem.split('_')[:-2])

    names = []
    for suffix in ('color', 'instanceids', 'labelids'):
        lbl_name = f'{base}_{suffix}{ext}'
        if not os.path.exists(os.path.join(lbl_path, lbl_name)):
            return False, None
        names.append(lbl_name)

    return True, names

def __goose_datadict_folder(img_path: str, lbl_path: str):
    '''
    Create the list of data dictionaries for one dataset split.

    Every direct subfolder of ``img_path`` is scanned for ``*.png`` images.
    The labels are looked up in the subfolder with the same name below
    ``lbl_path``. Images without a complete set of labels are skipped.

    Parameters:
        img_path    :   folder containing the image subfolders of the split
        lbl_path    :   folder containing the matching label subfolders

    Returns:
        List of dicts with the keys 'img_path', 'semantic_path',
        'instance_path' and 'color_path', ordered by subfolder and image path.
    '''
    # normpath strips the trailing separator so basename works on any OS;
    # sorting makes the order independent of the file system
    subfolders = sorted(
        os.path.basename(os.path.normpath(f))
        for f in glob.glob(os.path.join(img_path, '*/'))
    )

    datadict = []
    for s in tqdm.tqdm(subfolders):
        imgs_p = os.path.join(img_path, s)
        lbls_p = os.path.join(lbl_path, s)
        for img in sorted(glob.glob(os.path.join(imgs_p, '*.png'))):
            valid, lbl_names = __check_labels(img, lbls_p)
            if not valid:
                continue

            # __check_labels returns the names in the order color, instanceids, labelids
            color, instance, semantic = (os.path.join(lbls_p, n) for n in lbl_names)
            datadict.append({
                'img_path': img,
                'semantic_path': semantic,
                'instance_path': instance,
                'color_path': color,
            })

    return datadict

def goose_create_dataDict(
    src_path: str,
    mapping_csv_name: 'str | None' = 'goose_label_mapping.csv',
) -> 'tuple[list[dict], list[dict], list[dict], list[dict] | None]':
    '''
    Parse the GOOSE dataset folder into per-split data dictionaries.

    Expects the folders ``<src_path>/images/<split>`` and
    ``<src_path>/labels/<split>`` for the splits 'test', 'train' and 'val'.

    Parameters:

        src_path            :   path to dataset

        mapping_csv_name    :   file name (relative to src_path) of the class
                                mapping CSV; None skips loading it

    Returns:

        datadict_test       : list of dicts with the dataset test images information

        datadict_train      : list of dicts with the dataset train images information

        datadict_val        : list of dicts with the dataset validation images information

        mapping             : list with one dict per CSV row, or None
    '''
    if mapping_csv_name is None:
        mapping = None
    else:
        mapping_path = os.path.join(src_path, mapping_csv_name)
        # explicit encoding: independent of the platform locale; '-sig' also
        # strips a UTF-8 BOM that would otherwise corrupt the first header
        with open(mapping_path, newline='', encoding='utf-8-sig') as f:
            mapping = list(csv.DictReader(f))

    img_path = os.path.join(src_path, 'images')
    lbl_path = os.path.join(src_path, 'labels')

    splits = {}
    for split in ('test', 'train', 'val'):
        print("### " + split.capitalize() + " Data ###")
        splits[split] = __goose_datadict_folder(
            os.path.join(img_path, split),
            os.path.join(lbl_path, split),
        )

    return splits['test'], splits['train'], splits['val'], mapping

Dataset Module

This Dataset class is specific to semantic segmentation and performs a square center crop and a resize of the images. It can be used like any other PyTorch Dataset object to train a model.

class GOOSE_SemanticDataset(Dataset):
    """
    Example Pytorch Dataset Module for semantic tasks with GOOSE.

    Each item is an (image, semantic label) pair loaded from the paths stored
    in the data dictionaries created by *goose_create_dataDict*.
    """

    def __init__(self, dataset_dict: List[Dict], crop: bool = True, resize_size: Iterable[int] = None):
        '''
        Parameters:
            dataset_dict  [Iter]    : List of Dicts with the images information generated by *goose_create_dataDict*

            crop          [Bool]    : Whether to make a square center crop of the images or not

            resize_size   [Iter]    : (width, height) target size of the images (applied after the crop if crop == True).
                                      None keeps the (cropped) size.
        '''
        self.dataset_dict   = dataset_dict
        self.transforms     = transforms.Compose([
            transforms.ToTensor(),
            ])
        self.resize_size    = resize_size
        self.crop           = crop

    def preprocess(self, image, resample=None):
        '''
        Square-crop (if enabled) and resize (if a size is set) a PIL image.

        Parameters:
            image       : PIL image, or None

            resample    : PIL resampling filter used for the resize. Defaults
                          to nearest neighbour, which label maps require so
                          that no interpolated (invalid) class ids appear.

        Returns:
            The processed image, or None if ``image`` is None.
        '''
        if image is None:
            return None

        if self.crop:
            # Square crop in the center; the side is the shorter image edge
            side = min(image.width, image.height)
            image = transforms.CenterCrop((side, side))(image)

        if self.resize_size is not None:
            if resample is None:
                resample = Image.NEAREST
            image = image.resize(self.resize_size, resample=resample)

        return image

    def __getitem__(self, i):
        '''
        Parameter:
            i   [int]                   : Index of the image to get

        Returns:
            image_tensor [torch.Tensor] : 3 x H x W Tensor

            label_tensor [torch.Tensor] : H x W Tensor as semantic map
        '''
        sample = self.dataset_dict[i]
        # Context managers close the underlying files once converted
        with Image.open(sample['img_path']) as img_file:
            image = img_file.convert('RGB')
        # NOTE(review): 'L' assumes the label-id PNG stores raw ids as
        # grayscale values; a palette ('P') PNG would need a different path
        with Image.open(sample['semantic_path']) as lbl_file:
            label = lbl_file.convert('L')

        # Bilinear for the RGB image avoids nearest-neighbour aliasing;
        # the label map must use nearest so class ids are never interpolated
        image = self.preprocess(image, resample=Image.BILINEAR)
        label = self.preprocess(label, resample=Image.NEAREST)

        image_tensor = self.transforms(image)
        label_tensor = torch.from_numpy(np.array(label)).long()

        return image_tensor, label_tensor

    def __len__(self):
        return len(self.dataset_dict)

Training

For this example, SuperGradients is used to simplify the training process, but the workflow is very similar with any other custom model or framework.

First, the images are parsed and the data dictionaries with their information are created. These are wrapped in Dataset objects, which are then passed to DataLoader objects so that they can be used with the SuperGradients Trainer.

from torch.utils.data import DataLoader
import super_gradients as sg
from super_gradients.training.metrics.segmentation_metrics import IoU

## Load the data
#
PATH = '/path/to/goose'
test_dict, train_dict, val_dict, mapping_dict = goose_create_dataDict(PATH)

train_dataset = GOOSE_SemanticDataset(train_dict, crop=True, resize_size=(768,768))
val_dataset   = GOOSE_SemanticDataset(val_dict, crop=True, resize_size=(768,768))

## Set-up for training
#

# Create output directory (no error if it already exists)
EXPERIMENT_NAME = "GOOSE_train"
WS_PATH = os.getcwd()
CHECKPOINT_DIR = os.path.join(WS_PATH, 'output', 'ckpts')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Params
BATCH_SIZE      = 5
N_EPOCHS        = 10

# Dataloaders
train_dataloader    = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=5, drop_last=True)
# Validation: shuffling is unnecessary, and drop_last=False keeps the last
# incomplete batch so every validation image contributes to the metrics
val_dataloader      = DataLoader(val_dataset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=5, drop_last=False)

Then the Trainer is configured and the model is loaded. In this case, weights pre-trained on the Cityscapes dataset are loaded.

# Trainer Set-up
# Enum with the model names; needed for Models.DDRNET_39 below
from super_gradients.common.object_names import Models

# setup_device expects the device name as a string ('cuda' or 'cpu')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sg.setup_device(device=device)
trainer = sg.Trainer(experiment_name=EXPERIMENT_NAME, ckpt_root_dir=CHECKPOINT_DIR)

# Number of GOOSE classes predicted by the model and scored by the metrics
NUM_CLASSES = 64

## Load Model
#
model = sg.training.models.get(model_name=Models.DDRNET_39,
                               num_classes=NUM_CLASSES,
                               pretrained_weights='cityscapes')

# Set-up Training params
# Decay the learning rate at 30 %, 60 % and 90 % of the epochs
lr_updates = [int(.3 * N_EPOCHS), int(.6 * N_EPOCHS), int(.9 * N_EPOCHS)]
train_params = {
            "max_epochs": N_EPOCHS,
            "lr_mode":"step",
            "lr_updates": lr_updates,
            "lr_decay_factor": 0.1,
            "initial_lr": 0.005,
            "optimizer": 'sgd',
            "loss": 'cross_entropy',
            "average_best_models": False,
            "greater_metric_to_watch_is_better": True,
            "loss_logging_items_names": ["loss"],
            "drop_last": True,
            }

train_params["train_metrics_list"] = [IoU(num_classes=NUM_CLASSES)]
train_params["valid_metrics_list"] = [IoU(num_classes=NUM_CLASSES)]
train_params["metric_to_watch"]    = "IoU"
Finally, the training is started.

## Train
# Start training with the configured model, parameters and data loaders
#
trainer.train(
    model=model,
    training_params=train_params,
    train_loader=train_dataloader,
    valid_loader=val_dataloader,
)

Test Inference

This example shows how to run inference with a pre-trained checkpoint using PyTorch and how to display the results as a semantic map.

Load Model

from matplotlib import pyplot as plt
import matplotlib
import matplotlib.patches as mpatches


def run_inference(img, model):
    '''
    Run inference and return the predicted semantic mask.

    Parameters:
        img     :   image tensor; a tensor that is not 4-D gets a leading
                    batch dimension added (e.g. C x H x W -> 1 x C x H x W)
        model   :   callable returning per-class scores; assumed to be
                    N x K x H x W (TODO confirm for the used architecture)

    Returns:
        H x W tensor of class ids for a single image (or a batch of one),
        N x H x W tensor for a batch of N > 1 images.
    '''
    if img.dim() != 4:
        img = img.unsqueeze(0)

    # Pure inference: no autograd graph is needed
    with torch.no_grad():
        scores = model(img)

    # Pick the best class per pixel (dim 1 = classes, not the batch);
    # a sigmoid before the argmax would not change the result
    labels = torch.argmax(scores, dim=1)

    # Remove only the batch dimension, and only when it has size 1
    return labels.squeeze(0)

## Import Model
#
# Enum with the model names; needed for Models.DDRNET_39 below
from super_gradients.common.object_names import Models

model = sg.training.models.get(model_name=Models.DDRNET_39,
                               num_classes=64,
                               checkpoint_path="path/to/checkpoint.pth")
# Evaluation mode for inference (dropout off, batch-norm running statistics)
model.eval()

Run Inference on Images

This iterates through randomly chosen images of a dataset object and runs inference on them. The input image, the network output and the ground truth are displayed together with the corresponding class ids.

## Iterate through images and run inference on them
#
N_SAMPLES = 10
NUM_CLASSES = 64
viridis = matplotlib.colormaps['viridis'].resampled(NUM_CLASSES)

# Dataset over the parsed test split, preprocessed like the training data
test_dataset = GOOSE_SemanticDataset(test_dict, crop=True, resize_size=(768, 768))

# Distinct random indices (randint samples with replacement and could repeat images)
n_show = min(N_SAMPLES, len(test_dataset))
for idx in np.random.choice(len(test_dataset), n_show, replace=False):
    # Configure plot: one figure with three panels per sample
    f, axarr = plt.subplots(1, 3)
    f.subplots_adjust(hspace=10.0, right=1.5)

    axarr[0].set_xlabel("RGB")
    axarr[1].set_xlabel("Predicted")
    axarr[2].set_xlabel("Ground Truth")

    # Get images
    img, gt = test_dataset[idx]
    pred = run_inference(img, model)

    # C x H x W tensor -> H x W x C array for imshow
    axarr[0].imshow(img.permute(1, 2, 0).numpy())

    for ax, seg in ((axarr[1], np.asarray(pred.cpu())), (axarr[2], np.asarray(gt))):
        im = ax.imshow(seg, cmap=viridis)
        im.set_clim(0, NUM_CLASSES - 1)

        # Legend: one patch per class id present, coloured exactly as drawn
        handles = [
            mpatches.Patch(color=im.cmap(im.norm(c)), label=f"{c}")
            for c in np.unique(seg)
        ]
        ax.legend(handles=handles, bbox_to_anchor=(1.0, 1.00))

    plt.show()
The results should look similar to this:

results