oiv_ld_phenotyping / src /com_augmentations.py
treizh's picture
Upload folder using huggingface_hub
fc262e7 verified
raw
history blame contribute delete
No virus
7.93 kB
from pprint import pprint
import numpy as np
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform
import torch
from torch.utils.data import Dataset
import com_image as ci
import com_plot as cp
class FixPatchBrightness(ImageOnlyTransform):
    """Image-only transform that pulls a patch's brightness towards a target.

    The brightness of a patch is the mean, over all pixels, of
    ``sqrt(0.241 * r**2 + 0.691 * g**2 + 0.068 * b**2)`` where ``r``, ``g``
    and ``b`` are the three planes returned by ``cv2.split`` in that order
    (assumes a 3-channel image in RGB order -- TODO confirm against the
    loader).

    A patch whose brightness lies inside the threshold window is returned
    untouched.  Outside the window:

    * brighter than the target: a power-law look-up table with exponent
      ``avg_bright / brightness_target`` is applied;
    * otherwise: a linear ``alpha * img + beta`` correction is applied through
      ``cv2.convertScaleAbs``.

    Args:
        brightness_target: brightness the correction aims for.
        brightness_thresholds: bounds of the "leave it alone" window; the
            smallest and largest values are used, so order does not matter.
        always_apply: forwarded to the albumentations base class.
        p: probability of applying the transform, forwarded to the base class.
    """

    def __init__(
        self,
        brightness_target=115,
        brightness_thresholds=(115, 130),
        always_apply: bool = False,
        p: float = 0.5,
    ):
        # Keyword arguments keep this call independent of the positional
        # order of ``always_apply`` and ``p`` in the base-class constructor.
        super().__init__(always_apply=always_apply, p=p)
        self.brightness_target = brightness_target
        self.brightness_thresholds = brightness_thresholds

    def apply(self, img, brightness_target=None, brightness_thresholds=None, **params):
        """Return ``img`` corrected towards the target brightness.

        Args:
            img: image to correct; must split into exactly three planes.
            brightness_target: per-call override of the instance target.
            brightness_thresholds: per-call override of the instance window.
            **params: extra parameters supplied by the caller; unused here.

        Returns:
            ``img`` itself when no correction is needed (or possible),
            otherwise the corrected image produced by OpenCV.
        """
        if brightness_target is None:
            brightness_target = self.brightness_target
        if brightness_thresholds is None:
            brightness_thresholds = self.brightness_thresholds

        r, g, b = cv2.split(img)
        avg_bright = np.sqrt(
            0.241 * np.power(r.astype(float), 2)
            + 0.691 * np.power(g.astype(float), 2)
            + 0.068 * np.power(b.astype(float), 2)
        ).mean()

        # Pass the sequence itself (not unpacked) so that a one-element
        # threshold sequence is accepted as a degenerate window.
        tmin, tmax = min(brightness_thresholds), max(brightness_thresholds)
        if not (avg_bright < tmin or avg_bright > tmax):
            return img

        if avg_bright > brightness_target:
            # Too bright: darken with a power-law look-up table.
            gamma = brightness_target / avg_bright
            if gamma == 1:
                return img
            inv_gamma = 1.0 / gamma
            table = np.array(
                [((i / 255.0) ** inv_gamma) * 255 for i in np.arange(0, 256)]
            ).astype("uint8")
            return cv2.LUT(src=img, lut=table)

        if avg_bright == 0:
            # All-black patch: the gain below would divide by zero and feed
            # inf/NaN to OpenCV.  There is nothing to rescale, so leave it.
            return img

        # Too dark (or outside the window but not above the target): linear
        # correction.  For a uniform patch, alpha * avg + beta == target.
        return cv2.convertScaleAbs(
            src=img,
            alpha=(brightness_target + avg_bright) / (2 * avg_bright),
            beta=(brightness_target - avg_bright) / 2,
        )
def build_albumentations(
    image_size: int,
    gamma=(60, 180),
    brightness_limit=0.15,
    contrast_limit=0.25,
    crop=None,
    center_crop: int = -1,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    brightness_target=None,
    brightness_thresholds=None,
    affine_transforms={"H": 0.3, "V": 0.3, "R": 0.3, "T": 0.3},
):
    """Build the named groups of albumentations transforms used by this module.

    Args:
        image_size: side length used for the resize and for the final padding
            of the crop stages.
        gamma: ``gamma_limit`` for ``A.RandomGamma``; ``None`` leaves it out.
        brightness_limit: limit for ``A.RandomBrightnessContrast``.
        contrast_limit: limit for ``A.RandomBrightnessContrast``; the
            transform is only added when both limits are not ``None``.
        crop: ``None`` (no crop stage), an ``int`` crop size, or a dict with
            the keys ``"value"`` (crop size) and ``"p"`` (crop probability).
            Any other type is ignored.
        center_crop: size of the centre crop; values of ``-1`` or below
            disable the stage.
        mean: per-channel mean passed to ``A.Normalize``.
        std: per-channel standard deviation passed to ``A.Normalize``.
        brightness_target: target for ``FixPatchBrightness``.
        brightness_thresholds: window for ``FixPatchBrightness``; the stage is
            only added when both brightness arguments are not ``None``.
        affine_transforms: maps the codes ``"H"``, ``"V"``, ``"R"`` and ``"T"``
            to probabilities for horizontal flip, vertical flip, 90 degree
            rotation and transpose.  Unknown codes are ignored.  The default
            dict is only read, never modified.

    Returns:
        A dict mapping stage names to lists of transforms.  ``"resize"``,
        ``"affine"``, ``"color"``, ``"to_tensor"`` and ``"un_normalize"`` are
        always present; ``"fix_brightness"``, ``"crop_and_pad"`` and
        ``"center_crop"`` only when requested.
    """
    stages = {"resize": [A.Resize(height=image_size, width=image_size, p=1)]}

    if brightness_target is not None and brightness_thresholds is not None:
        stages["fix_brightness"] = [
            FixPatchBrightness(
                brightness_target=brightness_target,
                brightness_thresholds=brightness_thresholds,
                p=1,
            )
        ]

    if crop is not None:
        if isinstance(crop, int):
            # Random crop half of the time, then pad back to the full size.
            stages["crop_and_pad"] = [
                A.RandomCrop(height=crop, width=crop, p=0.5),
                A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
            ]
        elif isinstance(crop, dict):
            crop_val = crop["value"]
            crop_p = crop["p"]
            # Pad up to the crop size first, crop, then pad to the full size.
            stages["crop_and_pad"] = [
                A.PadIfNeeded(min_height=crop_val, min_width=crop_val, p=1),
                A.RandomCrop(height=crop_val, width=crop_val, p=crop_p),
                A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
            ]

    if center_crop > -1:
        stages["center_crop"] = [
            A.PadIfNeeded(min_height=center_crop, min_width=center_crop, p=1),
            A.CenterCrop(height=center_crop, width=center_crop, p=1),
        ]

    # One-letter code -> transform class; the output follows the iteration
    # order of ``affine_transforms``.
    geometric_by_code = {
        "H": A.HorizontalFlip,
        "V": A.VerticalFlip,
        "R": A.RandomRotate90,
        "T": A.Transpose,
    }
    stages["affine"] = [
        geometric_by_code[code](p=probability)
        for code, probability in affine_transforms.items()
        if code in geometric_by_code
    ]

    color_ops = []
    if brightness_limit is not None and contrast_limit is not None:
        color_ops.append(
            A.RandomBrightnessContrast(
                brightness_limit=brightness_limit,
                contrast_limit=contrast_limit,
                p=0.5,
            )
        )
    if gamma is not None:
        color_ops.append(A.RandomGamma(gamma_limit=gamma, p=0.5))
    stages["color"] = color_ops

    stages["to_tensor"] = [A.Normalize(mean=mean, std=std, p=1), ToTensorV2()]

    # A second Normalize parameterised with -mean/std and 1/std on a unit
    # pixel range; presumably meant to undo the normalisation above.
    inverse_mean = [-m / s for m, s in zip(mean, std)]
    inverse_std = [1.0 / s for s in std]
    stages["un_normalize"] = [
        A.Normalize(
            mean=inverse_mean,
            std=inverse_std,
            always_apply=True,
            max_pixel_value=1.0,
        ),
    ]
    return stages
def get_augmentations(
    image_size: int = 224,
    gamma=(60, 180),
    brightness_limit=0.15,
    contrast_limit=0.25,
    crop=180,
    center_crop: int = -1,
    kinds: list = ["resize", "to_tensor"],
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    brightness_target=None,
    brightness_thresholds=None,
    affine_transforms={"H": 0.3, "V": 0.3, "R": 0.3, "T": 0.3},
):
    """Compose the requested augmentation stages into one pipeline.

    Args:
        image_size: forwarded to ``build_albumentations``.
        gamma: forwarded to ``build_albumentations``.
        brightness_limit: forwarded to ``build_albumentations``.
        contrast_limit: forwarded to ``build_albumentations``.
        crop: forwarded to ``build_albumentations``.
        center_crop: forwarded to ``build_albumentations``.
        kinds: stage names, in the order they are applied.  The alias
            ``"train"`` stands for ``"affine"`` followed by ``"color"``.
            Falsy entries (``None``, ``""``) are skipped.  The sequence is
            only read, never modified.
        mean: forwarded to ``build_albumentations``.
        std: forwarded to ``build_albumentations``.
        brightness_target: forwarded to ``build_albumentations``.
        brightness_thresholds: forwarded to ``build_albumentations``.
        affine_transforms: forwarded to ``build_albumentations``.

    Returns:
        An ``A.Compose`` holding the transforms of every requested stage.

    Raises:
        KeyError: if a stage name is not produced by ``build_albumentations``
            for the given arguments.
    """
    # Expand the "train" alias into a private list so that neither the
    # caller's sequence nor the shared default argument is modified.
    stage_names = []
    for kind in kinds:
        if kind == "train":
            stage_names.extend(("affine", "color"))
        else:
            stage_names.append(kind)

    td_ = build_albumentations(
        image_size=image_size,
        gamma=gamma,
        brightness_limit=brightness_limit,
        contrast_limit=contrast_limit,
        crop=crop,
        center_crop=center_crop,
        mean=mean,
        std=std,
        brightness_target=brightness_target,
        brightness_thresholds=brightness_thresholds,
        affine_transforms=affine_transforms,
    )

    augs = []
    for name in stage_names:
        if name:
            augs += td_[name]
    return A.Compose(augs)
class MlcPatches(Dataset):
    """Dataset that serves transformed image patches listed in a dataframe.

    Args:
        dataframe: table with one row per patch; its ``file_name`` column is
            read to locate each image.
        transform: callable invoked as ``transform(image=...)`` whose result
            is indexed with ``"image"``.
        path_to_images: passed to ``ci.load_image`` together with the file
            name.
    """

    def __init__(self, dataframe, transform, path_to_images) -> None:
        super().__init__()
        self.dataframe = dataframe
        self.transform = transform
        self.path_to_images = path_to_images

    def __len__(self):
        """Return the number of rows in the dataframe."""
        return self.dataframe.shape[0]

    def __getitem__(self, index):
        """Return the transformed image at ``index`` with a constant label.

        The ``"labels"`` entry is always ``torch.tensor([1])``.
        """
        img = self.transform(image=self.get_image(index=index))["image"]
        return {"image": img, "labels": torch.tensor([1])}

    def _file_name(self, index):
        """Return the file name stored at row position ``index``."""
        # Positional lookup of a single cell; avoids converting the whole
        # column to a list on every access.
        return self.dataframe.file_name.iloc[index]

    def get_image(self, index):
        """Load the untransformed image at row position ``index``."""
        return ci.load_image(
            file_name=self._file_name(index),
            path_to_images=self.path_to_images,
        )
def test_augmentations(
    df,
    image_size,
    path_to_images,
    columns: list = [],
    kinds: list = ["resize", "to_tensor"],
    rows: int = 2,
    cols: int = 4,
    **aug_params,
):
    """Show one random patch next to repeated augmentations of it.

    One row is drawn from ``df``; its ``file_name`` (plus any of ``columns``
    present in the sample) is pretty-printed, then the un-augmented patch and
    ``rows * cols`` augmented versions are handed to
    ``cp.tensor_image_to_grid``.

    Args:
        df: dataframe to sample the patch from.
        image_size: forwarded to ``get_augmentations``.
        path_to_images: forwarded to ``MlcPatches``.
        columns: extra column names to print for the sampled row.
        kinds: augmentation stages used for the augmented versions.
        rows: number of grid rows.
        cols: number of grid columns.
        **aug_params: extra keyword arguments for ``get_augmentations``.

    NOTE(review): ``1 + rows * cols`` images are sent to a ``rows`` x ``cols``
    grid; how the plotting helper treats the surplus image is not visible
    from here -- confirm.
    """
    picked = df.sample(n=1)

    # Reference pipeline: resize and tensor conversion only.
    reference_pipeline = get_augmentations(
        image_size=image_size, kinds=["resize", "to_tensor"], **aug_params
    )
    src_dataset = MlcPatches(
        dataframe=picked,
        transform=reference_pipeline,
        path_to_images=path_to_images,
    )

    # Pipeline under test, built from the caller's stage list.
    augmented_pipeline = get_augmentations(
        image_size=image_size, kinds=kinds, **aug_params
    )
    test_dataset = MlcPatches(
        dataframe=picked,
        transform=augmented_pipeline,
        path_to_images=path_to_images,
    )

    shown_columns = [c for c in ["file_name"] + columns if c in picked]
    pprint(picked[shown_columns])

    tiles = [(src_dataset[0]["image"], "source")]
    for _ in range(rows * cols):
        # Each access runs the pipeline again on the same patch.
        tiles.append((test_dataset[0]["image"], "augmented"))

    display_transform = get_augmentations(
        image_size=image_size, kinds=(["un_normalize"]), **aug_params
    )
    cp.tensor_image_to_grid(
        images=tiles,
        transform=display_transform,
        row_count=rows,
        col_count=cols,
        figsize=(cols * 4, rows * 4),
    )