treizh committed
Commit fc262e7
1 Parent(s): ac0c4c7

Upload folder using huggingface_hub

src/com_augmentations.py ADDED
@@ -0,0 +1,255 @@
+ from pprint import pprint
+
+ import numpy as np
+
+ import cv2
+
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ from albumentations import ImageOnlyTransform
+
+ import torch
+ from torch.utils.data import Dataset
+
+ import com_image as ci
+ import com_plot as cp
+
+
+ class FixPatchBrightness(ImageOnlyTransform):
+     def __init__(
+         self,
+         brightness_target=115,
+         brightness_thresholds=(115, 130),
+         always_apply: bool = False,
+         p: float = 0.5,
+     ):
+         super().__init__(always_apply, p)
+         self.brightness_target = brightness_target
+         self.brightness_thresholds = brightness_thresholds
+
+     def apply(self, img, brightness_target=None, brightness_thresholds=None, **params):
+         brightness_target = (
+             self.brightness_target if brightness_target is None else brightness_target
+         )
+         brightness_thresholds = (
+             self.brightness_thresholds
+             if brightness_thresholds is None
+             else brightness_thresholds
+         )
+
+         # Perceived brightness averaged over the patch (HSP-style channel weights)
+         r, g, b = cv2.split(img)
+         avg_bright = np.sqrt(
+             0.241 * np.power(r.astype(float), 2)
+             + 0.691 * np.power(g.astype(float), 2)
+             + 0.068 * np.power(b.astype(float), 2)
+         ).mean()
+
+         tmin, tmax = min(*brightness_thresholds), max(*brightness_thresholds)
+
+         if avg_bright < tmin or avg_bright > tmax:
+             if avg_bright > brightness_target:
+                 # Too bright: darken with a gamma-correction lookup table
+                 gamma = brightness_target / avg_bright
+                 if gamma != 1:
+                     inv_gamma = 1.0 / gamma
+                     table = np.array(
+                         [((i / 255.0) ** inv_gamma) * 255 for i in np.arange(0, 256)]
+                     ).astype("uint8")
+                     return cv2.LUT(src=img, lut=table)
+                 else:
+                     return img
+             else:
+                 # Too dark: linear gain/offset scaled toward the target
+                 return cv2.convertScaleAbs(
+                     src=img,
+                     alpha=(brightness_target + avg_bright) / (2 * avg_bright),
+                     beta=(brightness_target - avg_bright) / 2,
+                 )
+         else:
+             return img
+
+
+ def build_albumentations(
+     image_size: int,
+     gamma=(60, 180),
+     brightness_limit=0.15,
+     contrast_limit=0.25,
+     crop=None,
+     center_crop: int = -1,
+     mean=(0.485, 0.456, 0.406),
+     std=(0.229, 0.224, 0.225),
+     brightness_target=None,
+     brightness_thresholds=None,
+     affine_transforms={"H": 0.3, "V": 0.3, "R": 0.3, "T": 0.3},
+ ):
+     albs_ = {"resize": [A.Resize(height=image_size, width=image_size, p=1)]}
+
+     if brightness_target is not None and brightness_thresholds is not None:
+         albs_ = albs_ | {
+             "fix_brightness": [
+                 FixPatchBrightness(
+                     brightness_target=brightness_target,
+                     brightness_thresholds=brightness_thresholds,
+                     p=1,
+                 )
+             ]
+         }
+
+     if crop is not None:
+         if isinstance(crop, int):
+             albs_ = albs_ | {
+                 "crop_and_pad": [
+                     A.RandomCrop(height=crop, width=crop, p=0.5),
+                     A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
+                 ]
+             }
+         elif isinstance(crop, dict):
+             crop_val = crop["value"]
+             crop_p = crop["p"]
+             albs_ = albs_ | {
+                 "crop_and_pad": [
+                     A.PadIfNeeded(min_height=crop_val, min_width=crop_val, p=1),
+                     A.RandomCrop(height=crop_val, width=crop_val, p=crop_p),
+                     A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
+                 ]
+             }
+
+     if center_crop > -1:
+         albs_ = albs_ | {
+             "center_crop": [
+                 A.PadIfNeeded(min_height=center_crop, min_width=center_crop, p=1),
+                 A.CenterCrop(height=center_crop, width=center_crop, p=1),
+             ]
+         }
+
+     affine = []
+     for k, v in affine_transforms.items():
+         if k == "H":
+             affine.append(A.HorizontalFlip(p=v))
+         elif k == "V":
+             affine.append(A.VerticalFlip(p=v))
+         elif k == "R":
+             affine.append(A.RandomRotate90(p=v))
+         elif k == "T":
+             affine.append(A.Transpose(p=v))
+     albs_ = albs_ | {"affine": affine}
+
+     color = []
+     if brightness_limit is not None and contrast_limit is not None:
+         color.append(
+             A.RandomBrightnessContrast(
+                 brightness_limit=brightness_limit,
+                 contrast_limit=contrast_limit,
+                 p=0.5,
+             )
+         )
+     if gamma is not None:
+         color.append(A.RandomGamma(gamma_limit=gamma, p=0.5))
+
+     albs_ = albs_ | {"color": color}
+
+     return albs_ | {
+         "to_tensor": [A.Normalize(mean=mean, std=std, p=1), ToTensorV2()],
+         "un_normalize": [
+             A.Normalize(
+                 mean=[-m / s for m, s in zip(mean, std)],
+                 std=[1.0 / s for s in std],
+                 always_apply=True,
+                 max_pixel_value=1.0,
+             ),
+         ],
+     }
+
+
+ def get_augmentations(
+     image_size: int = 224,
+     gamma=(60, 180),
+     brightness_limit=0.15,
+     contrast_limit=0.25,
+     crop=180,
+     center_crop: int = -1,
+     kinds: list = ["resize", "to_tensor"],
+     mean=(0.485, 0.456, 0.406),
+     std=(0.229, 0.224, 0.225),
+     brightness_target=None,
+     brightness_thresholds=None,
+     affine_transforms={"H": 0.3, "V": 0.3, "R": 0.3, "T": 0.3},
+ ):
+     kinds = list(kinds)  # copy so neither the caller's list nor the default is mutated
+     if "train" in kinds:
+         kinds.insert(kinds.index("train"), "affine")
+         kinds.insert(kinds.index("train"), "color")
+         kinds.remove("train")
+     td_ = build_albumentations(
+         image_size=image_size,
+         gamma=gamma,
+         brightness_limit=brightness_limit,
+         contrast_limit=contrast_limit,
+         crop=crop,
+         center_crop=center_crop,
+         mean=mean,
+         std=std,
+         brightness_target=brightness_target,
+         brightness_thresholds=brightness_thresholds,
+         affine_transforms=affine_transforms,
+     )
+     augs = []
+     for k in kinds:
+         if k:
+             augs += td_[k]
+     return A.Compose(augs)
+
+
+ class MlcPatches(Dataset):
+     def __init__(self, dataframe, transform, path_to_images) -> None:
+         super().__init__()
+         self.dataframe = dataframe
+         self.transform = transform
+         self.path_to_images = path_to_images
+
+     def __len__(self):
+         return self.dataframe.shape[0]
+
+     def __getitem__(self, index):
+         img = self.transform(image=self.get_image(index=index))["image"]
+         return {"image": img, "labels": torch.tensor([1])}
+
+     def get_image(self, index):
+         return ci.load_image(
+             file_name=self.dataframe.file_name.to_list()[index],
+             path_to_images=self.path_to_images,
+         )
+
+
+ def test_augmentations(
+     df,
+     image_size,
+     path_to_images,
+     columns: list = [],
+     kinds: list = ["resize", "to_tensor"],
+     rows: int = 2,
+     cols: int = 4,
+     **aug_params,
+ ):
+     sample = df.sample(n=1)
+     src_dataset = MlcPatches(
+         dataframe=sample,
+         transform=get_augmentations(
+             image_size=image_size, kinds=["resize", "to_tensor"], **aug_params
+         ),
+         path_to_images=path_to_images,
+     )
+
+     test_dataset = MlcPatches(
+         dataframe=sample,
+         transform=get_augmentations(image_size=image_size, kinds=kinds, **aug_params),
+         path_to_images=path_to_images,
+     )
+     pprint(sample[[c for c in ["file_name"] + columns if c in sample]])
+     cp.tensor_image_to_grid(
+         images=[(src_dataset[0]["image"], "source")]
+         + [(test_dataset[0]["image"], "augmented") for i in range(rows * cols)],
+         transform=get_augmentations(
+             image_size=image_size, kinds=["un_normalize"], **aug_params
+         ),
+         row_count=rows,
+         col_count=cols,
+         figsize=(cols * 4, rows * 4),
+     )
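
A minimal usage sketch for the module above (the dataframe and folder names are illustrative, not part of the upload). Note that "train" in kinds expands to the "affine" and "color" groups, and "fix_brightness" is only present when both brightness_target and brightness_thresholds are given:

    train_tf = get_augmentations(
        image_size=224,
        kinds=["resize", "fix_brightness", "crop_and_pad", "train", "to_tensor"],
        brightness_target=115,
        brightness_thresholds=(115, 130),
    )
    ds = MlcPatches(dataframe=df, transform=train_tf, path_to_images=images_dir)
    sample = ds[0]  # {"image": CxHxW float tensor, "labels": tensor([1])}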
src/com_const.py ADDED
@@ -0,0 +1,17 @@
+ from pathlib import Path
+
+ path_to_here = Path(__file__).resolve().parent
+ path_to_root = path_to_here.parent
+
+ path_to_data = path_to_root.joinpath("data")
+
+ path_to_images = path_to_root.joinpath("images")
+ path_to_plates = path_to_images.joinpath("plates")
+ path_to_leaf_discs = path_to_images.joinpath("leaf_discs")
+ path_to_leaf_patches = path_to_images.joinpath("leaf_patches")
+
+ path_to_checkpoints = path_to_root.joinpath("checkpoints")
+ path_to_chk_detector = path_to_checkpoints.joinpath("leaf_disc_detector")
+ path_to_chk_oiv = path_to_checkpoints.joinpath("oiv_scorer")
+
+ path_to_src = path_to_root.joinpath("src")
src/com_func.py ADDED
@@ -0,0 +1,20 @@
+ import pandas as pd
+
+
+ def ensure_folder(forced_path, return_string: bool = True):
+     path = forced_path.parent
+     if path.is_dir() is False:
+         path.mkdir(parents=True, exist_ok=True)
+     return str(forced_path) if return_string is True else forced_path
+
+
+ def read_dataframe(path, sep=";") -> pd.DataFrame:
+     try:
+         return pd.read_csv(filepath_or_buffer=str(path), sep=sep)
+     except Exception:
+         return None
+
+
+ def write_dataframe(df: pd.DataFrame, path, sep=";") -> pd.DataFrame:
+     df.to_csv(path_or_buf=ensure_folder(path, return_string=True), sep=sep, index=False)
+     return df
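
For reference, a round-trip with the helpers above (the path is illustrative): write_dataframe creates the parent folder on demand, and read_dataframe returns None instead of raising when the file is missing or unparsable:

    from pathlib import Path
    out = Path("data/tmp/table.csv")
    write_dataframe(df, out)
    df_back = read_dataframe(out)  # None if out does not exist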
src/com_image.py ADDED
@@ -0,0 +1,162 @@
+ from pathlib import Path
+ from typing import Any, Union
+
+ import numpy as np
+
+ import cv2
+
+ from PIL import Image, ImageEnhance
+
+
+ def load_image(file_name, path_to_images=None, rgb: bool = True):
+     path = (
+         file_name
+         if isinstance(file_name, Path) is True
+         else path_to_images.joinpath(file_name)
+     )
+
+     img = None  # stays defined even if reading fails
+     try:
+         img = cv2.imread(str(path))
+         if rgb is True:
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+     except Exception:
+         print(file_name)
+     return img
+
+
+ def to_pil(image):
+     return Image.fromarray(image)
+
+
+ def to_cv2(image):
+     return np.array(image)
+
+
+ def enhance_pil_image(
+     image, color=1, brightness=1, contrast=1, sharpness=1
+ ) -> Image.Image:
+     image = ImageEnhance.Sharpness(
+         image=ImageEnhance.Brightness(
+             image=ImageEnhance.Contrast(
+                 image=ImageEnhance.Color(
+                     image=(
+                         image
+                         if isinstance(image, Image.Image) is True
+                         else to_pil(image=image)
+                     )
+                 ).enhance(color)
+             ).enhance(contrast)
+         ).enhance(brightness)
+     ).enhance(sharpness)
+     return image
+
+
+ def ensure_odd(
+     i: int,
+     min_val: Union[None, int] = None,
+     max_val: Union[None, int] = None,
+ ) -> int:
+     """Transforms an even number into an odd number by adding one,
+     then clamps the result to the optional bounds.
+     Arguments:
+         i {int} -- number
+     Returns:
+         int -- odd number
+     """
+     if (i > 0) and (i % 2 == 0):
+         i += 1
+     if min_val is not None:
+         i = max(i, min_val)
+     if max_val is not None:
+         i = min(i, max_val)
+     return i
+
+
+ def get_morphology_kernel(size: int, shape: int):
+     """Builds morphology kernel
+     :param size: kernel size, must be odd number
+     :param shape: select shape of kernel
+     :return: Morphology kernel
+     """
+     size = ensure_odd(size)
+     return cv2.getStructuringElement(shape, (size, size))
+
+
+ def close(
+     image: Any,
+     kernel_size: int = 3,
+     kernel_shape: int = cv2.MORPH_ELLIPSE,
+     rois: tuple = (),
+     proc_times: int = 1,
+ ):
+     """Morphology - Close wrapper
+     Arguments:
+         image {numpy array} -- Source image
+         kernel_size {int} -- kernel size
+         kernel_shape {int} -- cv2 constant
+         rois -- Regions of Interest
+         proc_times {int} -- iterations
+     Returns:
+         numpy array -- closed image
+     """
+     morph_kernel = get_morphology_kernel(kernel_size, kernel_shape)
+     if rois:
+         result = image.copy()
+         for roi in rois:
+             r = roi.as_rect()
+             result[r.top : r.bottom, r.left : r.right] = cv2.morphologyEx(
+                 result[r.top : r.bottom, r.left : r.right],
+                 cv2.MORPH_CLOSE,
+                 morph_kernel,
+                 iterations=proc_times,
+             )
+     else:
+         result = cv2.morphologyEx(
+             image, cv2.MORPH_CLOSE, morph_kernel, iterations=proc_times
+         )
+     return result
+
+
+ def get_concat_h_multi_resize(im_list, resample=Image.Resampling.BICUBIC):
+     min_height = min(im.height for im in im_list)
+     im_list_resize = [
+         im.resize(
+             (int(im.width * min_height / im.height), min_height), resample=resample
+         )
+         for im in im_list
+     ]
+     total_width = sum(im.width for im in im_list_resize)
+     dst = Image.new("RGB", (total_width, min_height))
+     pos_x = 0
+     for im in im_list_resize:
+         dst.paste(im, (pos_x, 0))
+         pos_x += im.width
+     return dst
+
+
+ def get_concat_v_multi_resize(im_list, resample=Image.Resampling.BICUBIC):
+     min_width = min(im.width for im in im_list)
+     im_list_resize = [
+         im.resize((min_width, int(im.height * min_width / im.width)), resample=resample)
+         for im in im_list
+     ]
+     total_height = sum(im.height for im in im_list_resize)
+     dst = Image.new("RGB", (min_width, total_height))
+     pos_y = 0
+     for im in im_list_resize:
+         dst.paste(im, (0, pos_y))
+         pos_y += im.height
+     return dst
+
+
+ def get_concat_tile_resize(im_list_2d, resample=Image.Resampling.BICUBIC):
+     im_list_v = [
+         get_concat_h_multi_resize(im_list_h, resample=resample)
+         for im_list_h in im_list_2d
+     ]
+     return get_concat_v_multi_resize(im_list_v, resample=resample)
+
+
+ def get_tiles(img_list, row_count, resample=Image.Resampling.BICUBIC):
+     if isinstance(img_list, np.ndarray) is False:
+         img_list = np.asarray(img_list, dtype="object")
+     return get_concat_tile_resize(np.split(img_list, row_count), resample)
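
A small sketch of the tiling helpers above (file names are illustrative; four images split evenly into two rows, since np.split requires an even division):

    imgs = [to_pil(load_image(f, path_to_images=images_dir)) for f in file_names[:4]]
    sheet = get_tiles(imgs, row_count=2)  # per-row resize, then horizontal/vertical concat
    sheet.save("contact_sheet.png")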
src/com_plot.py ADDED
@@ -0,0 +1,54 @@
+ import matplotlib.pyplot as plt
+
+
+ def _update_axis(
+     axis, image, title=None, fontsize=18, remove_axis=True, title_loc="center"
+ ):
+     axis.imshow(image, origin="upper")
+     if title is not None:
+         axis.set_title(title, fontsize=fontsize, loc=title_loc)
+     if remove_axis is True:
+         axis.set_axis_off()
+
+
+ def tensor_image_to_grid(
+     images: list,
+     transform,
+     row_count,
+     col_count=None,
+     figsize=(20, 20),
+     fontsize=None,
+ ):
+     def split_image_title(image):
+         if isinstance(image, tuple):
+             return image[0], image[1]
+         else:
+             return image, None
+
+     def torch_to_image(t):
+         return transform(image=t.permute(1, 2, 0).numpy())["image"]
+
+     col_count = row_count if col_count is None else col_count
+     if len(images) == 1:
+         img, title = split_image_title(images[0])
+         plt.imshow(torch_to_image(img))
+         if title is not None:
+             plt.title(title)
+         plt.tight_layout()
+         plt.axis("off")
+     else:
+         _, axii = plt.subplots(row_count, col_count, figsize=figsize)
+         for ax, image in zip(axii.reshape(-1), images):
+             try:
+                 img, title = split_image_title(image)
+                 _update_axis(
+                     axis=ax,
+                     image=torch_to_image(img),
+                     remove_axis=True,
+                     title=title,
+                     fontsize=figsize[0] if fontsize is None else fontsize,
+                 )
+             except Exception:
+                 pass
+
+     plt.tight_layout()
+     plt.show()
src/leaf_patch_annotation.ipynb ADDED
@@ -0,0 +1,552 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## 202311 Dataset Annotation"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Imports"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "%load_ext autoreload\n",
+     "%autoreload 2"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from pathlib import Path\n",
+     "import warnings\n",
+     "from datetime import datetime as dt\n",
+     "import inspect\n",
+     "\n",
+     "import pandas as pd\n",
+     "import numpy as np\n",
+     "\n",
+     "import altair as alt\n",
+     "import plotly.express as px\n",
+     "\n",
+     "from PIL import Image, ImageEnhance\n",
+     "\n",
+     "from siuba import _ as s\n",
+     "from siuba import filter as sfilter\n",
+     "from siuba import mutate\n",
+     "\n",
+     "import panel as pn\n",
+     "\n",
+     "import com_const as cc\n",
+     "import com_func as cf"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Setup"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "warnings.simplefilter(action=\"ignore\", category=UserWarning)\n",
+     "warnings.simplefilter(action=\"ignore\", category=FutureWarning)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "pd.set_option(\"display.max_colwidth\", 500)\n",
+     "pd.set_option(\"display.max_columns\", 500)\n",
+     "pd.set_option(\"display.width\", 1000)\n",
+     "pd.set_option(\"display.max_rows\", 16)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "alt.data_transformers.disable_max_rows();"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "pn.extension(\n",
+     "    \"plotly\", \"terminal\", \"tabulator\", \"vega\", notifications=True, console_output=\"disable\"\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "template = pn.template.BootstrapTemplate(title=\"OIV Annotation Tool\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Globals"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "current_row = None"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "image_quality_options = {\n",
+     "    \"?\": np.nan,\n",
+     "    \"good\": \"good_images\",\n",
+     "    \"crop\": \"crop_images\",\n",
+     "    \"missing\": \"missing_images\",\n",
+     "    \"dark\": \"dark_images\",\n",
+     "    \"blur\": \"blur_images\",\n",
+     "    \"color\": \"color_images\",\n",
+     "    \"water\": \"water_images\",\n",
+     "}"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Load Data"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "df = cf.read_dataframe(path=cc.path_to_data.joinpath(\"oiv_annotation.csv\")).sort_values(\n",
+     "    [\"experiment\", \"inoc\", \"dpi\", \"plaque\", \"row\", \"col\"]\n",
+     ")\n",
+     "if \"seen_at\" not in df:\n",
+     "    df = df >> mutate(seen_at=np.nan)\n",
+     "df.seen_at = pd.to_datetime(df.seen_at)\n",
+     "df = df.set_index(\"file_name\")\n",
+     "df"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Functions"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def update_image(image_name: str, color, brightness, contrast, sharpness):\n",
+     "    image_path = cc.path_to_leaf_patches.joinpath(image_name)\n",
+     "    if image_path.is_file() is False:\n",
+     "        # magenta placeholder when the patch file is missing\n",
+     "        fig = px.imshow(\n",
+     "            np.array(\n",
+     "                [\n",
+     "                    [[255, 0, 255], [255, 0, 255], [255, 0, 255]],\n",
+     "                    [[255, 0, 255], [255, 0, 255], [255, 0, 255]],\n",
+     "                    [[255, 0, 255], [255, 0, 255], [255, 0, 255]],\n",
+     "                ],\n",
+     "                dtype=np.uint8,\n",
+     "            )\n",
+     "        )\n",
+     "    else:\n",
+     "        image = Image.open(image_path)\n",
+     "        image = ImageEnhance.Color(image=image).enhance(color)\n",
+     "        image = ImageEnhance.Contrast(image=image).enhance(contrast)\n",
+     "        image = ImageEnhance.Brightness(image=image).enhance(brightness)\n",
+     "        image = ImageEnhance.Sharpness(image=image).enhance(sharpness)\n",
+     "        fig = px.imshow(image)\n",
+     "    fig.update_layout(coloraxis_showscale=False)\n",
+     "    fig.update_xaxes(showticklabels=False)\n",
+     "    fig.update_yaxes(showticklabels=False)\n",
+     "    fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))\n",
+     "    return fig"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def plot_classes(df_: pd.DataFrame, var: str):\n",
+     "    d = pd.DataFrame(\n",
+     "        data={\n",
+     "            var: df_[var]\n",
+     "            .fillna(\"?\")\n",
+     "            .astype(str)\n",
+     "            .str.replace(\".0\", \"\")\n",
+     "            .str.replace(\"images\", \"\")\n",
+     "        }\n",
+     "    )\n",
+     "    bars = (\n",
+     "        alt.Chart(d)\n",
+     "        .mark_bar()\n",
+     "        .encode(\n",
+     "            y=alt.Y(var, title=None),\n",
+     "            x=alt.X(\"count()\", axis=None),\n",
+     "            color=alt.Color(var, legend=None),\n",
+     "        )\n",
+     "    )\n",
+     "    text = bars.mark_text(align=\"center\", dy=0, dx=12).encode(\n",
+     "        y=alt.Y(var, title=None),\n",
+     "        x=alt.X(\"count()\", axis=None),\n",
+     "        color=alt.Color(var, legend=None),\n",
+     "        text=\"count()\",\n",
+     "    )\n",
+     "\n",
+     "    return (bars + text).configure_view(stroke=None).configure_axis(grid=False)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Widgets"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "img_current = pn.pane.Plotly(height=750, align=(\"center\", \"center\"))\n",
+     "mkd_current = pn.pane.Markdown(sizing_mode=\"scale_width\", align=\"center\")\n",
+     "\n",
+     "sl_contrast = pn.widgets.EditableFloatSlider(\n",
+     "    name=\"Contrast\", start=0.0, end=7.5, value=1.5, step=0.1, sizing_mode=\"scale_width\"\n",
+     ")\n",
+     "sl_color = pn.widgets.EditableFloatSlider(\n",
+     "    name=\"Color\", start=0.0, end=5.0, value=1.0, step=0.1, sizing_mode=\"scale_width\"\n",
+     ")\n",
+     "sl_brightness = pn.widgets.EditableFloatSlider(\n",
+     "    name=\"Brightness\",\n",
+     "    start=0.0,\n",
+     "    end=5.0,\n",
+     "    value=1.0,\n",
+     "    step=0.1,\n",
+     "    sizing_mode=\"scale_width\",\n",
+     ")\n",
+     "sl_sharpness = pn.widgets.EditableFloatSlider(\n",
+     "    name=\"Sharpness\", start=0.0, end=2.0, value=1.5, step=0.1, sizing_mode=\"scale_width\"\n",
+     ")\n",
+     "\n",
+     "c_image_processing = pn.Card(\n",
+     "    pn.Column(sl_brightness, sl_color, sl_contrast, sl_sharpness),\n",
+     "    title=\"Image Processing Options\",\n",
+     "    sizing_mode=\"scale_width\",\n",
+     ")\n",
+     "\n",
+     "pg_progress = pn.widgets.Tqdm(name=\"Progress\", align=\"center\", max=len(df))\n",
+     "\n",
+     "rgb_oiv = pn.widgets.RadioButtonGroup(\n",
+     "    name=\"OIV\",\n",
+     "    options=[\"?\", 1, 3, 5, 7, 9],\n",
+     "    button_style=\"outline\",\n",
+     "    button_type=\"success\",\n",
+     ")\n",
+     "\n",
+     "rgb_source = pn.widgets.RadioButtonGroup(\n",
+     "    name=\"Image quality\",\n",
+     "    options=list(image_quality_options.keys()),\n",
+     "    button_style=\"outline\",\n",
+     "    button_type=\"success\",\n",
+     "    value=\"?\",\n",
+     ")\n",
+     "\n",
+     "sel_def_img_quality = pn.widgets.Select(\n",
+     "    name=\"Default Image Quality\", options=list(image_quality_options.keys())\n",
+     ")\n",
+     "\n",
+     "mc_filter_quality = pn.widgets.MultiChoice(\n",
+     "    name=\"Allow qualities\",\n",
+     "    options=list(image_quality_options.values()),\n",
+     "    value=list(image_quality_options.values()),\n",
+     ")\n",
+     "\n",
+     "rgb_target = pn.widgets.RadioButtonGroup(\n",
+     "    name=\"Annotation target\",\n",
+     "    options=[\"All\", \"OIV\", \"Image quality\"],\n",
+     "    button_style=\"outline\",\n",
+     "    button_type=\"success\",\n",
+     "    value=\"All\",\n",
+     ")\n",
+     "\n",
+     "c_anno_options = pn.Card(\n",
+     "    pn.Column(\n",
+     "        pn.Row(pn.pane.Markdown(\"**Annotate**\"), rgb_target),\n",
+     "        sel_def_img_quality,\n",
+     "        mc_filter_quality,\n",
+     "    ),\n",
+     "    title=\"Annotation Options\",\n",
+     "    sizing_mode=\"scale_width\",\n",
+     ")\n",
+     "\n",
+     "pn_hist_oiv = pn.pane.Vega()\n",
+     "pn_hist_source = pn.pane.Vega()\n",
+     "\n",
+     "c_hists = pn.Card(\n",
+     "    pn.Column(\n",
+     "        pn.pane.Markdown(\"### OIV\"),\n",
+     "        pn_hist_oiv,\n",
+     "        pn.pane.Markdown(\"### Image Quality\"),\n",
+     "        pn_hist_source,\n",
+     "    ),\n",
+     "    title=\"Annotation Overview\",\n",
+     "    sizing_mode=\"scale_width\",\n",
+     ")\n",
+     "\n",
+     "sw_ui_state = pn.widgets.Switch(name=\"active\", value=False)\n",
+     "alt_ui_state = pn.pane.Alert(\"Annotations will be stored\", alert_type=\"primary\")\n",
+     "\n",
+     "pn_ui_state = pn.Row(sw_ui_state, alt_ui_state)\n",
+     "\n",
+     "\n",
+     "bt_next = pn.widgets.Button(name=\"Next\", button_type=\"primary\")\n",
+     "bt_previous = pn.widgets.Button(name=\"Previous\", button_type=\"primary\")\n",
+     "\n",
+     "ui_annotation = pn.GridSpec(sizing_mode=\"scale_width\", align=\"center\", max_height=120)\n",
+     "\n",
+     "ui_annotation[1, 0] = bt_previous\n",
+     "ui_annotation[0, 1:5] = rgb_source\n",
+     "ui_annotation[1, 1:5] = rgb_oiv\n",
+     "ui_annotation[1, 5] = bt_next"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Callbacks"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "@pn.depends(\n",
+     "    sl_color.param.value,\n",
+     "    sl_contrast.param.value,\n",
+     "    sl_brightness.param.value,\n",
+     "    sl_sharpness.param.value,\n",
+     "    watch=True,\n",
+     ")\n",
+     "def on_preprocess_changed(color, contrast, brightness, sharpness):\n",
+     "    img_current.object = update_image(\n",
+     "        image_name=current_row.file_name,\n",
+     "        color=color,\n",
+     "        brightness=brightness,\n",
+     "        contrast=contrast,\n",
+     "        sharpness=sharpness,\n",
+     "    )\n",
+     "\n",
+     "\n",
+     "def update_ui_state(ui_state: bool):\n",
+     "    if ui_state is True:\n",
+     "        alt_ui_state.object = \"Annotations will be stored\"\n",
+     "        alt_ui_state.alert_type = \"primary\"\n",
+     "    else:\n",
+     "        alt_ui_state.object = \"Annotations will be discarded\"\n",
+     "        alt_ui_state.alert_type = \"danger\"\n",
+     "\n",
+     "\n",
+     "@pn.depends(sw_ui_state, watch=True)\n",
+     "def on_ui_state_changed(new_state: bool):\n",
+     "    update_ui_state(new_state)\n",
+     "\n",
+     "\n",
+     "def select_next(event):\n",
+     "    global current_row\n",
+     "    global df\n",
+     "    now = dt.now()\n",
+     "    if current_row is not None and (event is None or event.obj.name == \"Next\"):\n",
+     "        if rgb_target.value in [\"All\", \"OIV\"] and rgb_oiv.value != \"?\":\n",
+     "            df.at[current_row.file_name, \"oiv\"] = int(rgb_oiv.value)\n",
+     "            df.at[current_row.file_name, \"oiv_annotated_at\"] = now\n",
+     "\n",
+     "        if rgb_target.value in [\"All\", \"Image quality\"] and rgb_source.value != \"?\":\n",
+     "            df.at[current_row.file_name, \"source_annotated_at\"] = now\n",
+     "            df.at[current_row.file_name, \"source\"] = image_quality_options[\n",
+     "                rgb_source.value\n",
+     "            ]\n",
+     "        cf.write_dataframe(\n",
+     "            df=df.reset_index(),\n",
+     "            path=cc.path_to_data.joinpath(\n",
+     "                \"oiv_annotation.csv\" if sw_ui_state.value is True else \"oiv_annotation_test.csv\"\n",
+     "            ),\n",
+     "        )\n",
+     "        df.at[current_row.file_name, \"seen_at\"] = now\n",
+     "\n",
+     "    df_cr = df >> sfilter(s.source.isin(mc_filter_quality.value))\n",
+     "\n",
+     "    if rgb_target.value == \"All\":\n",
+     "        df_cr = df_cr >> sfilter(s.oiv.isna() | s.source.isna())\n",
+     "    elif rgb_target.value == \"OIV\":\n",
+     "        df_cr = df_cr >> sfilter(s.oiv.isna())\n",
+     "    if rgb_target.value == \"Image quality\":\n",
+     "        df_cr = df_cr >> sfilter(s.source.isna())\n",
+     "    remaining = len(df_cr)\n",
+     "    if event is None or event.obj.name == \"Next\":\n",
+     "        df_cr = df_cr.reset_index()\n",
+     "        current_row = df_cr.sample(n=1).iloc[0] if len(df_cr) > 0 else None\n",
+     "    elif event.obj.name == \"Previous\":\n",
+     "        current_row = (\n",
+     "            (df.reset_index() >> sfilter(~s.seen_at.isna()))\n",
+     "            .sort_values(\"seen_at\", ascending=False)\n",
+     "            .reset_index(drop=True)\n",
+     "            .iloc[0]\n",
+     "        )\n",
+     "        df.at[current_row.file_name, \"seen_at\"] = None\n",
+     "\n",
+     "    if current_row is not None:\n",
+     "        rgb_source.value = (\n",
+     "            sel_def_img_quality.value\n",
+     "            if pd.isnull(current_row.source)\n",
+     "            else {v: k for k, v in image_quality_options.items()}[current_row.source]\n",
+     "        )\n",
+     "        rgb_oiv.value = (\n",
+     "            current_row.oiv if current_row.oiv in [1, 3, 5, 7, 9] else \"?\"\n",
+     "        )\n",
+     "\n",
+     "    pg_progress.value = len(df) - remaining\n",
+     "    file_name = current_row.file_name if current_row is not None else \"\"\n",
+     "    img_current.object = update_image(\n",
+     "        image_name=file_name,\n",
+     "        color=sl_color.value,\n",
+     "        brightness=sl_brightness.value,\n",
+     "        contrast=sl_contrast.value,\n",
+     "        sharpness=sl_sharpness.value,\n",
+     "    )\n",
+     "    mkd_current.object = f\"## {file_name}\"\n",
+     "    df_unf = df >> sfilter(s.source.isin(mc_filter_quality.value))\n",
+     "    pn_hist_source.object = plot_classes(df_unf, \"source\")\n",
+     "    pn_hist_oiv.object = plot_classes(df_unf, \"oiv\")\n",
+     "\n",
+     "\n",
+     "@pn.depends(rgb_target, watch=True)\n",
+     "def on_target_changed(target):\n",
+     "    rgb_oiv.disabled = target == \"Image quality\"\n",
+     "    rgb_source.disabled = target == \"OIV\"\n",
+     "\n",
+     "\n",
+     "# @pn.depends(rgb_oiv, watch=True)\n",
+     "# def on_oiv_changed(_):\n",
+     "#     select_next(None)\n",
+     "\n",
+     "\n",
+     "bt_next.on_click(select_next)\n",
+     "bt_previous.on_click(select_next)\n",
+     "\n",
+     "update_ui_state(sw_ui_state.value)\n",
+     "select_next(None)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## UI"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "template.sidebar.append(pn_ui_state)\n",
+     "template.sidebar.append(c_image_processing)\n",
+     "template.sidebar.append(c_anno_options)\n",
+     "\n",
+     "template.main.append(\n",
+     "    pn.Row(\n",
+     "        pn.Column(\n",
+     "            # mkd_current,\n",
+     "            img_current,\n",
+     "            ui_annotation,\n",
+     "        ),\n",
+     "        pn.Column(c_hists, pg_progress),\n",
+     "    )\n",
+     ")\n",
+     "\n",
+     "template.servable()"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Please launch with the command \"panel serve leaf_patch_annotation.ipynb --show --dev\" from the \"src\" folder"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "env",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.9.2"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
src/leaf_patch_extractor.ipynb ADDED
@@ -0,0 +1,470 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Extract Leaf Patches From Plates"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Imports"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "%load_ext autoreload\n",
+     "%autoreload 2"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from datetime import datetime as dt\n",
+     "import warnings\n",
+     "import random\n",
+     "\n",
+     "from tqdm import tqdm\n",
+     "\n",
+     "import cv2\n",
+     "\n",
+     "import pandas as pd\n",
+     "\n",
+     "from siuba import _ as s\n",
+     "from siuba import filter as sfilter\n",
+     "from siuba import mutate, select, if_else\n",
+     "\n",
+     "import panel as pn\n",
+     "\n",
+     "import torch\n",
+     "\n",
+     "from pytorch_lightning.callbacks import (\n",
+     "    RichProgressBar,\n",
+     "    ModelCheckpoint,\n",
+     "    LearningRateMonitor,\n",
+     ")\n",
+     "from pytorch_lightning import Trainer\n",
+     "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
+     "from pytorch_lightning.loggers import TensorBoardLogger\n",
+     "\n",
+     "\n",
+     "import com_const as cc\n",
+     "import com_image as ci\n",
+     "import com_func as cf\n",
+     "import leaf_patch_extractor_model as lpem"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Setup"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "warnings.simplefilter(action=\"ignore\", category=UserWarning)\n",
+     "warnings.simplefilter(action=\"ignore\", category=FutureWarning)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "pd.set_option(\"display.max_colwidth\", 500)\n",
+     "pd.set_option(\"display.max_columns\", 500)\n",
+     "pd.set_option(\"display.width\", 1000)\n",
+     "pd.set_option(\"display.max_rows\", 16)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "pn.extension(notifications=True, console_output=\"disable\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Train Disc Detector"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Load Datasets"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "train, val, test = [\n",
+     "    cf.read_dataframe(cc.path_to_data.joinpath(f\"ldd_{d}.csv\"))\n",
+     "    for d in [\"train\", \"val\", \"test\"]\n",
+     "]\n",
+     "\n",
+     "print(len(train), len(test), len(val))"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Test Augmentations"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# aug_ = lpem.get_augmentations(image_size=10, kinds=[\"resize\", \"train\"])\n",
+     "\n",
+     "# test_aug_dataset = lpem.LeafDiskDetectorDataset(csv=train, transform=aug_)\n",
+     "\n",
+     "# file_name = train.sample(n=1).plate_name.to_list()[0]\n",
+     "\n",
+     "# print(aug_[0].width, aug_[0].height)\n",
+     "\n",
+     "# lpem.make_patches_grid(\n",
+     "#     images=[\n",
+     "#         test_aug_dataset.draw_image_with_boxes(plate_name=file_name) for _ in range(12)\n",
+     "#     ],\n",
+     "#     row_count=3,\n",
+     "#     col_count=4,\n",
+     "#     figsize=(12, 6),\n",
+     "# )"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Train"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# model = lpem.LeafDiskDetector(\n",
+     "#     batch_size=15,\n",
+     "#     learning_rate=7.0e-05,\n",
+     "#     image_factor=10,\n",
+     "#     max_epochs=1000,\n",
+     "#     train_data=train,\n",
+     "#     val_data=val,\n",
+     "#     test_data=test,\n",
+     "#     augmentations_kinds=[\"resize\", \"train\", \"to_tensor\"],\n",
+     "#     augmentations_params={\"gamma\": (60, 180)},\n",
+     "#     num_workers=2,\n",
+     "#     accumulate_grad_batches=5,\n",
+     "#     scheduler=\"steplr\",\n",
+     "#     scheduler_params={\"step_size\": 10, \"gamma\": 0.80},\n",
+     "# )\n",
+     "\n",
+     "# model.eval()\n",
+     "# len(model(torch.rand(2, 3, 128, 128)))\n",
+     "\n",
+     "# model.hr_desc()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# trainer = Trainer(\n",
+     "#     default_root_dir=str(cc.path_to_chk_detector),\n",
+     "#     logger=TensorBoardLogger(\n",
+     "#         save_dir=str(cc.path_to_chk_detector),\n",
+     "#         version=model.model_name + \"_\" + dt.now().strftime(\"%Y%m%d_%H%M%S\"),\n",
+     "#         name=\"lightning_logs\",\n",
+     "#     ),\n",
+     "#     accelerator=\"gpu\",\n",
+     "#     max_epochs=model.max_epochs,\n",
+     "#     log_every_n_steps=5,\n",
+     "#     callbacks=[\n",
+     "#         RichProgressBar(),\n",
+     "#         EarlyStopping(monitor=\"val_loss\", mode=\"min\", patience=10, min_delta=0.0005),\n",
+     "#         ModelCheckpoint(\n",
+     "#             save_top_k=1,\n",
+     "#             monitor=\"val_loss\",\n",
+     "#             auto_insert_metric_name=True,\n",
+     "#             filename=model.model_name\n",
+     "#             + \"-{val_loss:.3f}-{epoch}-{train_loss:.3f}-{step}\",\n",
+     "#         ),\n",
+     "#         LearningRateMonitor(logging_interval=\"epoch\"),\n",
+     "#     ],\n",
+     "#     accumulate_grad_batches=model.accumulate_grad_batches,\n",
+     "# )\n",
+     "\n",
+     "# trainer.fit(model)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Extract Patches"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Load Model"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "ld_model: lpem.LeafDiskDetector = lpem.LeafDiskDetector.load_from_checkpoint(\n",
+     "    cc.path_to_chk_detector.joinpath(\"leaf_disc_detector.ckpt\")\n",
+     ")\n",
+     "ld_model.hr_desc()"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Predict All Bounding Boxes"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "bb_predictions_path = cc.path_to_data.joinpath(\"train_ld_bounding_boxes.csv\")\n",
+     "\n",
+     "bb_predictions = (\n",
+     "    cf.read_dataframe(bb_predictions_path)\n",
+     "    if bb_predictions_path.is_file() is True\n",
+     "    else pd.DataFrame()\n",
+     ")\n",
+     "\n",
+     "bb_predictions"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "plates = list(cc.path_to_plates.rglob(\"*.JPG\"))\n",
+     "len(plates)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "errors = []\n",
+     "handled_plates = (\n",
+     "    bb_predictions.file_name.unique() if \"file_name\" in bb_predictions else []\n",
+     ")\n",
+     "\n",
+     "for plate in tqdm(plates):\n",
+     "    if plate.name in handled_plates:\n",
+     "        continue\n",
+     "    try:\n",
+     "        current_data = ld_model.index_plate(plate) >> mutate(\n",
+     "            disc_name=s.file_name.str.replace(\" \", \"\").replace(\".JPG\", \"\")\n",
+     "            + \"_\"\n",
+     "            + s.row.astype(str)\n",
+     "            + \"_\"\n",
+     "            + s.col.astype(str)\n",
+     "            + \".png\"\n",
+     "        )\n",
+     "        bb_predictions = pd.concat([bb_predictions, current_data])\n",
+     "    except Exception:\n",
+     "        errors.append(plate)\n",
+     "\n",
+     "print(errors)\n",
+     "cf.write_dataframe(\n",
+     "    bb_predictions.sort_values([\"file_name\", \"col\", \"row\"]).reset_index(drop=True)\n",
+     "    >> mutate(disc_name=s.disc_name.str.replace(\".JPG\", \"\")),\n",
+     "    bb_predictions_path,\n",
+     ")\n",
+     "\n",
+     "bb_predictions = cf.read_dataframe(bb_predictions_path)\n",
+     "bb_predictions"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "selected_image = random.choice(plates)\n",
+     "bboxes = bb_predictions >> sfilter(s.file_name == selected_image.name)\n",
+     "pn.Column(\n",
+     "    pn.pane.Markdown(f\"### {selected_image.name}\"),\n",
+     "    pn.pane.DataFrame(bboxes),\n",
+     "    pn.pane.Image(\n",
+     "        ci.to_pil(\n",
+     "            lpem.print_boxes(\n",
+     "                image_name=selected_image,\n",
+     "                boxes=bboxes,\n",
+     "                draw_first_line=True,\n",
+     "                return_plot=False,\n",
+     "            )\n",
+     "        ),\n",
+     "        sizing_mode=\"scale_width\",\n",
+     "    ),\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Extract Needed Patches"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "#### Model Training"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "df_model_training = pd.concat(\n",
+     "    [\n",
+     "        cf.read_dataframe(cc.path_to_data.joinpath(f\"oiv_{d}.csv\"))\n",
+     "        for d in [\"train\", \"val\", \"test\"]\n",
+     "    ]\n",
+     ").sort_values([\"file_name\"]).reset_index(drop=True)\n",
+     "df_model_training"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "err = {}\n",
+     "\n",
+     "for file_name in tqdm(df_model_training.file_name):\n",
+     "    row = (bb_predictions >> sfilter(s.disc_name == file_name)).reset_index(drop=True)\n",
+     "    lpem.handle_bbox(\n",
+     "        row.iloc[0],\n",
+     "        add_process_image=True,\n",
+     "        paths=dict(\n",
+     "            segmented_leaf_disc=cc.path_to_leaf_discs,\n",
+     "            leaf_disc_patch=cc.path_to_leaf_patches,\n",
+     "            plates=cc.path_to_plates,\n",
+     "        ),\n",
+     "        errors=err,\n",
+     "    )\n",
+     "err"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "#### Genotype differentiation"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "df_gd = cf.read_dataframe(\n",
+     "    cc.path_to_data.joinpath(\"genotype_differenciation_dataset.csv\")\n",
+     ")\n",
+     "df_gd"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "err = {}\n",
+     "\n",
+     "for file_name in tqdm(df_gd.file_name):\n",
+     "    row = (bb_predictions >> sfilter(s.disc_name == file_name)).reset_index(drop=True)\n",
+     "    lpem.handle_bbox(\n",
+     "        row.iloc[0],\n",
+     "        add_process_image=True,\n",
+     "        paths=dict(\n",
+     "            segmented_leaf_disc=cc.path_to_leaf_discs,\n",
+     "            leaf_disc_patch=cc.path_to_leaf_patches,\n",
+     "            plates=cc.path_to_plates,\n",
+     "        ),\n",
+     "        errors=err,\n",
+     "    )\n",
+     "err"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "env",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.9.2"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
src/leaf_patch_extractor_model.py ADDED
@@ -0,0 +1,1292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import math
3
+
4
+ from rich.console import Console
5
+ from rich.table import Table
6
+ from rich.pretty import Pretty
7
+
8
+ import numpy as np
9
+
10
+ import pandas as pd
11
+
12
+ import cv2
13
+
14
+ from sklearn.cluster import MeanShift
15
+
16
+ from skimage.transform import hough_circle, hough_circle_peaks
17
+
18
+
19
+ import torch
20
+ from torch.utils.data import Dataset, DataLoader
21
+ from torchvision import transforms
22
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
23
+
24
+ from torchvision.models.detection import (
25
+ fasterrcnn_resnet50_fpn_v2,
26
+ FasterRCNN_ResNet50_FPN_V2_Weights,
27
+ )
28
+
29
+ import pytorch_lightning as pl
30
+ from pytorch_lightning.callbacks import RichProgressBar
31
+ from pytorch_lightning import Trainer
32
+
33
+ import albumentations as A
34
+ from albumentations.pytorch.transforms import ToTensorV2
35
+
36
+ import matplotlib.pyplot as plt
37
+
38
+ import com_const as cc
39
+ import com_image as ci
40
+
41
+ g_device = (
42
+ "mps"
43
+ if torch.backends.mps.is_built() is True
44
+ else "cuda" if torch.backends.cuda.is_built() else "cpu"
45
+ )
46
+
47
+
48
+ def load_tray_image(image_name):
49
+ return ci.load_image(
50
+ file_name=image_name, path_to_images=cc.path_to_plates, rgb=True
51
+ )
52
+
53
+
54
+ def build_albumentations(
55
+ image_size: int = 10,
56
+ gamma=(60, 180),
57
+ mean=(0.485, 0.456, 0.406),
58
+ std=(0.229, 0.224, 0.225),
59
+ ):
60
+ return {
61
+ "resize": [
62
+ A.Resize(height=image_size * 32 * 2, width=image_size * 32 * 3, p=1)
63
+ ],
64
+ "train": [
65
+ A.HorizontalFlip(p=0.3),
66
+ A.RandomBrightnessContrast(
67
+ brightness_limit=0.25, contrast_limit=0.25, p=0.5
68
+ ),
69
+ A.RandomGamma(gamma_limit=gamma, p=0.5),
70
+ ],
71
+ "to_tensor": [A.Normalize(mean=mean, std=std, p=1), ToTensorV2()],
72
+ "un_normalize": [
73
+ A.Normalize(
74
+ mean=[-m / s for m, s in zip(mean, std)],
75
+ std=[1.0 / s for s in std],
76
+ always_apply=True,
77
+ max_pixel_value=1.0,
78
+ ),
79
+ ],
80
+ }
81
+
82
+
83
+ def get_augmentations(
84
+ image_size: int = 10,
85
+ gamma=(60, 180),
86
+ kinds: list = ["resize", "to_tensor"],
87
+ mean=(0.485, 0.456, 0.406),
88
+ std=(0.229, 0.224, 0.225),
89
+ inferrence: bool = False,
90
+ ):
91
+ td_ = build_albumentations(
92
+ image_size=image_size,
93
+ gamma=gamma,
94
+ mean=mean,
95
+ std=std,
96
+ )
97
+ augs = []
98
+ for k in kinds:
99
+ augs += td_[k]
100
+ if inferrence is True:
101
+ return A.Compose(augs)
102
+ else:
103
+ return A.Compose(
104
+ augs,
105
+ bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
106
+ )
107
+
108
+
109
+ def safe_row_col(row, col):
110
+ """Ensures that row is a string and col is an integer
111
+ Args:
112
+ row (int or str): row output must be string
113
+ col (int or str): col output must be int
114
+ """
115
+ if row is not None and col is not None:
116
+ if isinstance(col, str):
117
+ row, col = col, row
118
+ return row, col
119
+
120
+
121
+ def _update_axis(axis, image, title=None, fontsize=10, remove_axis=True):
122
+ axis.imshow(image, origin="upper")
123
+ if title is not None:
124
+ axis.set_title(title, fontsize=fontsize)
125
+
126
+
127
+ def make_patches_grid(images, row_count, col_count=None, figsize=(20, 20)):
128
+ col_count = row_count if col_count is None else col_count
129
+ _, axii = plt.subplots(row_count, col_count, figsize=figsize)
130
+ for ax, image in zip(axii.reshape(-1), images):
131
+ if isinstance(image, tuple):
132
+ title = image[1]
133
+ image = image[0]
134
+ else:
135
+ title = None
136
+ try:
137
+ _update_axis(axis=ax, image=image, remove_axis=True, title=title)
138
+ except:
139
+ pass
140
+ ax.set_axis_off()
141
+
142
+ plt.tight_layout()
143
+ plt.show()
144
+
145
+
146
+ def print_boxes(
147
+ image_name,
148
+ boxes,
149
+ highlight=(None, None),
150
+ draw_first_line: bool = False,
151
+ return_plot: bool = True,
152
+ ):
153
+ r, c = safe_row_col(*highlight)
154
+ image = load_tray_image(image_name=image_name)
155
+
156
+ fnt = cv2.FONT_HERSHEY_SIMPLEX
157
+ fnt_scale = 3
158
+ fnt_thickness = 8
159
+
160
+ column_colors = {
161
+ 1: (255, 0, 0),
162
+ 2: (0, 0, 255),
163
+ 3: (255, 255, 0),
164
+ 4: (0, 255, 255),
165
+ }
166
+
167
+ for box in boxes[["x1", "y1", "x2", "y2", "cx", "cy", "row", "col"]].values:
168
+ color = (
169
+ (255, 0, 255)
170
+ if c == box[7] and r == box[6]
171
+ else column_colors.get(box[7], (255, 255, 244))
172
+ )
173
+ thickness = 20 if c == box[7] and r == box[6] else 10
174
+ image = cv2.rectangle(
175
+ image,
176
+ (int(box[0]), int(box[1])),
177
+ (int(box[2]), int(box[3])),
178
+ color,
179
+ thickness,
180
+ )
181
+ label = str(box[6]).upper() + str(int(box[7]))
182
+ (w, h), _ = cv2.getTextSize(label, fnt, fnt_scale, fnt_thickness)
183
+ x, y = (int(box[0]), int(box[1]) - fnt_thickness)
184
+ image = cv2.rectangle(
185
+ image,
186
+ (x - fnt_thickness, y - h - fnt_thickness),
187
+ (x + fnt_thickness + w, y + fnt_thickness),
188
+ color,
189
+ -1,
190
+ )
191
+ image = cv2.putText(
192
+ image,
193
+ label,
194
+ (x + fnt_thickness, y),
195
+ fnt,
196
+ fnt_scale,
197
+ (0, 0, 0),
198
+ fnt_thickness,
199
+ )
200
+
201
+ if draw_first_line is True:
202
+ line = get_first_vert_line(image_name=image_name)
203
+ if line is not None:
204
+ x1, y1, x2, y2 = line
205
+ cv2.line(
206
+ image,
207
+ [
208
+ int(i)
209
+ for i in (np.array([x2, y2]) - np.array([x1, y1])) * 10
210
+ + np.array([x1, y1])
211
+ ],
212
+ [
213
+ int(i)
214
+ for i in (np.array([x1, y1]) - np.array([x2, y2])) * 10
215
+ + np.array([x2, y2])
216
+ ],
217
+ (255, 0, 255),
218
+ 20,
219
+ lineType=8,
220
+ )
221
+
222
+ if return_plot is True:
223
+ plt.figure(figsize=(10, 10))
224
+ plt.imshow(image)
225
+ plt.tight_layout()
226
+ plt.axis("off")
227
+ plt.show()
228
+ else:
229
+ return image
230
+
231
+
232
+ def crop_to_vert(image):
233
+ return image[0 : image.shape[1] // 2, 0 : image.shape[0] // 3]
234
+
235
+
236
+ def get_first_vert_line(image_name, min_angle=80, max_angle=100):
+     r, *_ = cv2.split(load_tray_image(image_name))
+
+     red_crop = cv2.normalize(
+         crop_to_vert(r),
+         None,
+         alpha=0,
+         beta=200,
+         norm_type=cv2.NORM_MINMAX,
+     )
+
+     lines = cv2.HoughLinesP(
+         image=ci.close(
+             cv2.Canny(red_crop, 50, 200, None, 3),
+             kernel_size=5,
+             proc_times=5,
+         ),
+         rho=1,
+         theta=np.pi / 180,
+         threshold=50,
+         minLineLength=red_crop.shape[0] // 5,
+         maxLineGap=20,
+     )
+     if lines is None:
+         return None
+
+     min_angle, max_angle = min(min_angle, max_angle), max(min_angle, max_angle)
+     min_x = red_crop.shape[1]  # start past the right edge so any valid line can win
+     sel_line = None
+     for line in lines:
+         x1, y1, x2, y2 = line[0]
+         line_angle = math.atan2(y2 - y1, x2 - x1) * 180 / math.pi * -1
+         if min_angle <= abs(line_angle) <= max_angle and min(x1, x2) < min_x:
+             min_x = min(x1, x2)
+             sel_line = (x1, y1, x2, y2)
+
+     return sel_line
+
+
+ def draw_first_line(image_name, dot_size=10, crop_canvas: bool = False):
+     canvas = load_tray_image(image_name)
+     if crop_canvas is True:
+         canvas = crop_to_vert(canvas)
+     line = get_first_vert_line(image_name=image_name)
+     if line is None:
+         # No line detected: return the canvas untouched
+         return canvas
+     x1, y1, x2, y2 = line
+     cv2.circle(canvas, (x1, y1), dot_size, (255, 0, 0))
+     cv2.circle(canvas, (x2, y2), dot_size, (0, 255, 0))
+     cv2.line(canvas, (x1, y1), (x2, y2), (0, 0, 255), 10)
+     return canvas
+
+
+ def get_bbox(image_name, bboxes, row, col):
+     if isinstance(bboxes, pd.Series):
+         return bboxes
+     else:
+         row, col = safe_row_col(row, col)
+         return bboxes[
+             (
+                 bboxes.file_name
+                 == (image_name.name if isinstance(image_name, Path) else image_name)
+             )
+             & (bboxes.row == row)
+             & (bboxes.col == col)
+         ].iloc[0]
+
+
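+ # The Hough search below assumes the detector's box fits the disc tightly: candidate
+ # radii are sampled within 10 px of half the box's longest side.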
+ def get_hough_leaf_disc_circle(
+     image_name,
+     bboxes,
+     row=-1,
+     col=-1,
+     padding: int = 10,
+     allow_move: bool = False,
+ ):
+     padded_leaf_disk = get_leaf_disk_wbb(
+         image_name=image_name,
+         bboxes=bboxes,
+         row=row,
+         col=col,
+         padding=padding,
+     )
+     *_, b = cv2.split(padded_leaf_disk)
+
+     min_t, max_t = 100, 200
+     rb = cv2.Canny(
+         cv2.normalize(
+             b,
+             None,
+             alpha=0,
+             beta=200,
+             norm_type=cv2.NORM_MINMAX,
+         ),
+         min_t,
+         max_t,
+         None,
+         3,
+     )
+
+     bbox = get_bbox(image_name=image_name, bboxes=bboxes, row=row, col=col)
+     hough_radii = np.arange(bbox.max_size // 2 - 10, bbox.max_size // 2 + 10, 10)
+     hough_res = hough_circle(rb, hough_radii)
+
+     # Select the single most prominent circle
+     _, cx, cy, radii = hough_circle_peaks(
+         hough_res,
+         hough_radii,
+         min_xdistance=10,
+         min_ydistance=10,
+         total_num_peaks=1,
+     )
+
+     cx = cx[0]
+     cy = cy[0]
+     r = radii[0]
+
+     if allow_move is True:
+         # Shift the circle back inside the patch when it overflows an edge
+         h, w, _ = padded_leaf_disk.shape
+         if cx - r < 0:
+             cx += abs(r - cx)
+         if cx + r > w:
+             cx -= abs(r - cx)
+         if cy - r < 0:
+             cy += abs(cy - r)
+         if cy + r > h:
+             cy -= abs(cy - r)
+
+     return dict(cx=cx, cy=cy, r=r)
+
+
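+ # Note on the sqrt(2) factor used below: the largest axis-aligned square inscribed in
+ # a circle of radius r has side r * sqrt(2), hence a half-side of r / sqrt(2), so the
+ # extracted patch is the biggest square that fits entirely inside the leaf disc.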
+ def get_hough_leaf_disk_patch(
+     image_name,
+     bboxes,
+     patch_size=-1,
+     row=-1,
+     col=-1,
+     padding: int = 10,
+     radius_crop=0,
+     disc=None,
+     allow_move: bool = False,
+     image_folder=None,
+ ):
+     if patch_size > 0:
+         try:
+             bbox = get_bbox(image_name, bboxes, row, col)
+             cx = int(bbox.cx)
+             cy = int(bbox.cy)
+         except Exception:
+             return None
+         patch_size = patch_size // 2
+
+         return A.crop(
+             load_tray_image(image_name, image_folder=image_folder),
+             cx - patch_size,
+             cy - patch_size,
+             cx + patch_size,
+             cy + patch_size,
+         )
+     else:
+         if disc is None:
+             disc = get_hough_leaf_disc_circle(
+                 image_name=image_name,
+                 bboxes=bboxes,
+                 row=row,
+                 col=col,
+                 padding=padding,
+                 allow_move=allow_move,
+             )
+
+         r = int((disc["r"] - radius_crop) / math.sqrt(2))
+         cx = int(disc["cx"])
+         cy = int(disc["cy"])
+
+         left = cx - r
+         top = cy - r
+         right = cx + r
+         bottom = cy + r
+
+         return get_leaf_disk_wbb(
+             image_name=image_name,
+             bboxes=bboxes,
+             row=row,
+             col=col,
+             padding=padding,
+         )[top:bottom, left:right]
+
+
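+ # get_hough_segment_disk blanks everything outside the detected circle by AND-ing the
+ # patch with a filled-circle mask, then crops to the circle's bounding square.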
+ def get_hough_segment_disk(
+     image_name,
+     bboxes,
+     row=-1,
+     col=-1,
+     padding: int = 10,
+     radius_crop=0,
+     disc=None,
+     allow_move: bool = False,
+ ):
+     if disc is None:
+         disc = get_hough_leaf_disc_circle(
+             image_name=image_name,
+             bboxes=bboxes,
+             row=row,
+             col=col,
+             padding=padding,
+             allow_move=allow_move,
+         )
+
+     padded_leaf_disk = get_leaf_disk_wbb(
+         image_name=image_name,
+         bboxes=bboxes,
+         row=row,
+         col=col,
+         padding=padding,
+     )
+     r = int(disc["r"] - radius_crop)
+     cx = int(disc["cx"])
+     cy = int(disc["cy"])
+     left = cx - r
+     top = cy - r
+     right = cx + r
+     bottom = cy + r
+
+     return cv2.bitwise_and(
+         padded_leaf_disk,
+         padded_leaf_disk,
+         mask=cv2.circle(np.zeros_like(padded_leaf_disk[:, :, 0]), (cx, cy), r, 255, -1),
+     )[top:bottom, left:right]
+
+
+ def draw_hough_bb_to_patch_process(
+     image_name,
+     bboxes,
+     row=-1,
+     col=-1,
+     padding: int = 10,
+     radius_crop=0,
+     disc=None,
+     allow_move: bool = False,
+ ):
+     if disc is None:
+         disc = get_hough_leaf_disc_circle(
+             image_name=image_name,
+             bboxes=bboxes,
+             row=row,
+             col=col,
+             padding=padding,
+             allow_move=allow_move,
+         )
+
+     padded_leaf_disk = get_leaf_disk_wbb(
+         image_name=image_name,
+         bboxes=bboxes,
+         row=row,
+         col=col,
+         padding=padding,
+     )
+     r = int(disc["r"] - radius_crop)
+     rc = int((disc["r"] - radius_crop) / math.sqrt(2))
+     cx = int(disc["cx"])
+     cy = int(disc["cy"])
+     left = cx - r
+     top = cy - r
+     right = cx + r
+     bottom = cy + r
+
+     # Draw, in order: inscribed square, circle bounding box, center dot, detected circle
+     image = cv2.rectangle(
+         padded_leaf_disk,
+         (cx - rc, cy - rc),
+         (cx + rc, cy + rc),
+         (0, 255, 0),
+         5,
+     )
+     image = cv2.rectangle(image, (left, top), (right, bottom), (255, 0, 155), 5)
+     image = cv2.circle(image, (cx, cy), 10, (255, 0, 155), -1)
+     return cv2.circle(image, (cx, cy), r, (255, 0, 155), 5)
+
+
+ def get_leaf_disk_wbb(
+     image_name, bboxes, row=-1, col=-1, padding: int = 0, image_path: Path = None
+ ):
+     # `padding` expands the detector box on all sides, clamped to the image bounds,
+     # to match the callers above that pass it explicitly.
+     try:
+         bbox = get_bbox(image_name, bboxes, row, col)
+         image = load_tray_image(image_name if image_path is None else image_path)
+         top = max(int(bbox.y1) - padding, 0)
+         bottom = min(int(bbox.y2) + padding, image.shape[0])
+         left = max(int(bbox.x1) - padding, 0)
+         right = min(int(bbox.x2) + padding, image.shape[1])
+         return image[top:bottom, left:right]
+     except Exception:
+         return None
+
+
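+ # The "fast" variants below skip the Hough transform and derive the circle directly
+ # from the detector's bounding box: its center, with half of the longest side (scaled
+ # by percent_radius) as the radius.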
+ def get_fast_leaf_disc_circle(
+     image_name, bboxes, row=-1, col=-1, percent_radius: float = 1.0
+ ):
+     bbox = get_bbox(image_name=image_name, bboxes=bboxes, row=row, col=col)
+     return int(bbox.cx), int(bbox.cy), int((bbox.max_size / 2) * percent_radius)
+
+
+ def get_fast_segment_disk(
+     image_name,
+     bboxes,
+     row=-1,
+     col=-1,
+     percent_radius: float = 1.0,
+     image_path: Path = None,
+ ):
+     cx, cy, r = get_fast_leaf_disc_circle(
+         image_name=image_name,
+         bboxes=bboxes,
+         row=row,
+         col=col,
+         percent_radius=percent_radius,
+     )
+     src_image = load_tray_image(image_name if image_path is None else image_path)
+     left = cx - r
+     top = cy - r
+     right = cx + r
+     bottom = cy + r
+
+     return cv2.bitwise_and(
+         src_image,
+         src_image,
+         mask=cv2.circle(np.zeros_like(src_image[:, :, 0]), (cx, cy), r, 255, -1),
+     )[top:bottom, left:right]
+
+
+ def get_fast_leaf_disk_patch(
+     image_name,
+     bboxes,
+     row=-1,
+     col=-1,
+     percent_radius: float = 1.0,
+     image_path: Path = None,
+ ):
+     cx, cy, r = get_fast_leaf_disc_circle(
+         image_name=image_name,
+         bboxes=bboxes,
+         row=row,
+         col=col,
+         percent_radius=percent_radius,
+     )
+     r = int(r / math.sqrt(2))
+     left = cx - r
+     top = cy - r
+     right = cx + r
+     bottom = cy + r
+
+     return load_tray_image(image_name if image_path is None else image_path)[
+         top:bottom, left:right
+     ]
+
+
+ def draw_fast_bb_to_patch_process(
+     image_name,
+     bboxes,
+     row=-1,
+     col=-1,
+     percent_radius: float = 1.0,
+     image_path: Path = None,
+     add_center: bool = True,
+ ):
+     cx, cy, r = get_fast_leaf_disc_circle(
+         image_name=image_name,
+         bboxes=bboxes,
+         row=row,
+         col=col,
+         percent_radius=percent_radius,
+     )
+     bbox = get_bbox(image_name=image_name, bboxes=bboxes, row=row, col=col)
+     image = load_tray_image(image_name if image_path is None else image_path)
+     rc = int(r / math.sqrt(2))
+
+     cv2.circle(image, (cx, cy), r, color=(255, 0, 155), thickness=5)
+     if add_center is True:
+         cv2.circle(image, (cx, cy), 10, color=(255, 0, 155), thickness=-1)
+     cv2.rectangle(image, (cx - rc, cy - rc), (cx + rc, cy + rc), (0, 255, 0), 5)
+
+     return image[int(bbox.y1) : int(bbox.y2), int(bbox.x1) : int(bbox.x2)]
+
+
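+ # Minimal usage sketch (hypothetical dataframe with plate_name and x1..y2 columns):
+ #   ds = LeafDiskDetectorDataset(csv=df_train, transform=get_augmentations(...))
+ #   img, target = ds[0]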
+ class LeafDiskDetectorDataset(Dataset):
+     def __init__(
+         self,
+         csv,
+         transform=None,
+         yxyx: bool = False,
+         return_id: bool = False,
+         bboxes: bool = True,
+     ):
+         self.boxes = csv.copy()
+         self.images = list(self.boxes.plate_name.unique())
+         self.transforms = transform
+         if transform is not None:
+             self.width, self.height = transform[0].width, transform[0].height
+         else:
+             self.width, self.height = 0, 0
+         self.yxyx = yxyx
+         self.return_id = return_id
+         self.bboxes = bboxes
+
+     def __len__(self):
+         return len(self.images)
+
+     def load_boxes(self, idx):
+         if "x1" in self.boxes.columns:
+             boxes = self.boxes[self.boxes.plate_name == self.images[idx]].dropna()
+             size = boxes.shape[0]
+             return (
+                 (size, boxes[["x1", "y1", "x2", "y2"]].values) if size > 0 else (0, [])
+             )
+         return 0, []
+
+     def load_tray_image(self, idx):
+         return load_tray_image(self.images[idx])
+
+     def get_by_sample_name(self, plate_name):
+         return self[self.images.index(plate_name)]
+
+     def get_image_by_name(self, plate_name):
+         return load_tray_image(plate_name)
+
+     def draw_image_with_boxes(self, plate_name):
+         image, target, *_ = self[self.images.index(plate_name)]
+         boxes = target[self.get_boxes_key()]
+         # Boxes are stored as yxyx when self.yxyx is True, xyxy otherwise
+         box_indexes = [1, 0, 3, 2] if self.yxyx is True else [0, 1, 2, 3]
+         for box in boxes:
+             image = cv2.rectangle(
+                 image,
+                 (int(box[box_indexes[0]]), int(box[box_indexes[1]])),
+                 (int(box[box_indexes[2]]), int(box[box_indexes[3]])),
+                 (255, 0, 0),
+                 2,
+             )
+         return image
+
+     def get_boxes_key(self):
+         return "bboxes" if self.bboxes is True else "boxes"
+
+     def __getitem__(self, index):
+         # load_boxes returns a list of [xmin, ymin, xmax, ymax]
+         num_box, boxes = self.load_boxes(index)
+         img = self.load_tray_image(index)  # return an image
+
+         if num_box > 0:
+             boxes = torch.as_tensor(boxes, dtype=torch.float32)
+         else:
+             # negative example, ref: https://github.com/pytorch/vision/issues/2144
+             boxes = torch.zeros((0, 4), dtype=torch.float32)
+
+         image_id = torch.tensor([index])
+         labels = torch.ones((num_box,), dtype=torch.int64)
+         target = {
+             self.get_boxes_key(): boxes,
+             "labels": labels,
+             "image_id": image_id,
+             "area": torch.as_tensor(
+                 (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
+                 dtype=torch.float32,
+             ),
+             "iscrowd": torch.zeros((num_box,), dtype=torch.int64),
+             "img_size": torch.tensor([self.height, self.width]),
+             "img_scale": torch.tensor([1.0]),
+         }
+
+         if self.transforms is not None:
+             sample = {
+                 "image": img,
+                 "bboxes": target[self.get_boxes_key()],
+                 "labels": labels,
+             }
+             sample = self.transforms(**sample)
+             img = sample["image"]
+             if num_box > 0:
+                 # Convert to ndarray to allow slicing
+                 boxes = np.array(sample["bboxes"])
+                 # Convert to yxyx if requested
+                 if self.yxyx is True:
+                     boxes[:, [0, 1, 2, 3]] = boxes[:, [1, 0, 3, 2]]
+                 # Convert back to tensor
+                 target[self.get_boxes_key()] = torch.as_tensor(
+                     boxes, dtype=torch.float32
+                 )
+             else:
+                 target[self.get_boxes_key()] = torch.zeros((0, 4), dtype=torch.float32)
+         else:
+             img = transforms.ToTensor()(img)
+         if self.return_id is True:
+             return img, target, image_id
+         return img, target
+
+
+ def collate_fn(batch):
+     # Detection targets are variable-length, so they are kept as a tuple of dicts;
+     # only the images are stacked into a single float tensor.
+     images, targets = tuple(zip(*batch))
+     images = torch.stack(images).float()
+     return images, targets
+
+
+ def find_best_lr(model, default_root_dir=cc.path_to_chk_detector):
+     # Run the learning-rate finder; the result overrides hparams.learning_rate.
+     # Note: `auto_lr_find` and `Trainer.tune` are the PyTorch Lightning 1.x API.
+     trainer = Trainer(
+         default_root_dir=default_root_dir,
+         auto_lr_find=True,
+         accelerator="gpu",
+         callbacks=[RichProgressBar()],
+     )
+
+     # call tune to find the lr
+     trainer.tune(model)
+
+     return model.learning_rate
+
+
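+ # Hypothetical usage: the tuner updates model.learning_rate in place, e.g.
+ #   best_lr = find_best_lr(model)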
+ class LeafDiskDetector(pl.LightningModule):
+     def __init__(
+         self,
+         batch_size: int,
+         learning_rate: float,
+         max_epochs: int,
+         image_factor: int,
+         train_data: pd.DataFrame,
+         val_data: pd.DataFrame,
+         test_data: pd.DataFrame,
+         augmentations_kinds: list = ["resize", "train", "to_tensor"],
+         augmentations_params: dict = {"gamma": (60, 180)},
+         num_workers: int = 0,
+         accumulate_grad_batches: int = 3,
+         selected_device: str = g_device,
+         optimizer: str = "adam",
+         scheduler: str = None,
+         scheduler_params: dict = {},
+     ):
+         super().__init__()
+
+         self.model_name = "ldd"
+
+         # Hyperparameters
+         self.batch_size = batch_size
+         self.selected_device = selected_device
+         self.learning_rate = learning_rate
+         self.num_workers = num_workers
+         self.max_epochs = max_epochs
+         self.accumulate_grad_batches = accumulate_grad_batches
+
+         # Dataframes
+         self.train_data = train_data
+         self.val_data = val_data
+         self.test_data = test_data
+
+         # Optimizer
+         self.optimizer = optimizer
+         self.scheduler = scheduler
+         self.scheduler_params = scheduler_params
+
+         # Albumentations
+         self.image_factor = image_factor
+         self.augmentations_kinds = augmentations_kinds
+         self.augmentations_params = augmentations_params
+
+         self.train_augmentations = get_augmentations(
+             image_size=self.image_factor,
+             kinds=self.augmentations_kinds,
+             **self.augmentations_params,
+         )
+
+         self.val_augmentations = get_augmentations(
+             image_size=self.image_factor,
+             kinds=["resize", "to_tensor"],
+             **self.augmentations_params,
+         )
+
+         # Model
+         self.encoder = fasterrcnn_resnet50_fpn_v2(
+             weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
+         )
+         num_classes = 2  # 1 class (leaf disc) + background
+         # get number of input features for the classifier
+         in_features = self.encoder.roi_heads.box_predictor.cls_score.in_features
+         # replace the pre-trained head with a new one
+         self.encoder.roi_heads.box_predictor = FastRCNNPredictor(
+             in_features, num_classes
+         )
+
+         self.save_hyperparameters()
+
+     def hr_desc(self):
+         table = Table(title=f"{self.model_name} params & values")
+         table.add_column("Param", justify="right", style="bold", no_wrap=True)
+         table.add_column("Value")
+
+         def add_pairs(table_, attributes: list) -> None:
+             for a in attributes:
+                 try:
+                     table_.add_row(a, Pretty(getattr(self, a)))
+                 except Exception:
+                     # Silently skip attributes that are not set on this instance
+                     pass
+
+         add_pairs(
+             table,
+             ["model_name", "batch_size", "num_workers", "accumulate_grad_batches"],
+         )
+         table.add_row("image_width", Pretty(self.train_augmentations[0].width))
+         table.add_row("image_height", Pretty(self.train_augmentations[0].height))
+         add_pairs(
+             table,
+             ["image_factor", "augmentations_kinds", "augmentations_params"],
+         )
+
+         add_pairs(
+             table,
+             ["learning_rate", "optimizer", "scheduler", "scheduler_params"],
+         )
+
+         for name, df in zip(
+             ["train", "val", "test"],
+             [self.train_data, self.val_data, self.test_data],
+         ):
+             table.add_row(
+                 name,
+                 Pretty(
+                     f"shape: {str(df.shape)}, images: {len(df.plate_name.unique())}"
+                 ),
+             )
+
+         Console().print(table)
+
+     def configure_optimizers(self):
+         # Optimizer
+         if self.optimizer == "adam":
+             optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+         elif self.optimizer == "sgd":
+             optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+         else:
+             optimizer = None
+
+         # Scheduler
+         if self.scheduler == "cycliclr":
+             scheduler = torch.optim.lr_scheduler.CyclicLR(
+                 optimizer,
+                 base_lr=self.learning_rate,
+                 max_lr=0.01,
+                 step_size_up=100,
+                 # The cycle mode is read from scheduler_params, triangular by default
+                 mode=self.scheduler_params.get("mode", "triangular"),
+             )
+         elif self.scheduler == "steplr":
+             self.scheduler_params["optimizer"] = optimizer
+             scheduler = torch.optim.lr_scheduler.StepLR(**self.scheduler_params)
+             self.scheduler_params.pop("optimizer")
+         elif self.scheduler == "plateau":
+             scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                 optimizer,
+                 mode="min",
+                 factor=0.2,
+                 patience=10,
+                 min_lr=1e-6,
+             )
+             scheduler = {"scheduler": scheduler, "monitor": "val_loss"}
+         else:
+             scheduler = None
+         if scheduler is None:
+             return optimizer
+         return [optimizer], [scheduler]
+
+     def train_dataloader(self):
+         return DataLoader(
+             LeafDiskDetectorDataset(
+                 csv=self.train_data,
+                 transform=self.train_augmentations,
+                 bboxes=False,
+             ),
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             collate_fn=collate_fn,
+             pin_memory=True,
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             LeafDiskDetectorDataset(
+                 csv=self.val_data,
+                 transform=self.val_augmentations,
+                 bboxes=False,
+             ),
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             collate_fn=collate_fn,
+             pin_memory=True,
+         )
+
+     def test_dataloader(self):
+         return DataLoader(
+             LeafDiskDetectorDataset(
+                 csv=self.test_data,
+                 transform=self.val_augmentations,
+                 bboxes=False,
+             ),
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             collate_fn=collate_fn,
+             pin_memory=True,
+         )
+
+     def forward(self, x):
+         return self.encoder(x)
+
+     def step_(self, batch, batch_index):
+         x, y = batch
+         # torchvision detection models only return the loss dict in train mode,
+         # so train() is forced here even for validation and test batches
+         self.train()
+         loss_dict = self.encoder(x, y)
+         return sum(loss for loss in loss_dict.values())
+
+     def training_step(self, batch, batch_idx):
+         loss = self.step_(batch=batch, batch_index=batch_idx)
+         self.log(
+             "train_loss", loss, on_step=True, prog_bar=True, batch_size=self.batch_size
+         )
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         loss = self.step_(batch=batch, batch_index=batch_idx)
+         self.log(
+             "val_loss",
+             loss,
+             on_epoch=True,
+             on_step=False,
+             prog_bar=True,
+             batch_size=self.batch_size,
+         )
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         loss = self.step_(batch=batch, batch_index=batch_idx)
+         self.log("test_loss", loss)
+         return loss
+
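+     # prepare_bboxes runs the detector on one tray image, keeps predictions above
+     # score_threshold, rescales them to the original resolution, and drops boxes
+     # whose aspect ratio or relative area makes them unlikely to be leaf discs.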
+     def prepare_bboxes(
+         self,
+         image_name,
+         score_threshold=0.90,
+         ar_threshold=1.5,
+         size_threshold=0.30,
+     ):
+         augs = get_augmentations(
+             image_size=self.image_factor,
+             kinds=["resize", "to_tensor"],
+             inferrence=True,
+             **self.augmentations_params,
+         )
+         image = load_tray_image(image_name=image_name)
+
+         self.to(g_device)
+         self.eval()
+         predictions = self(augs(image=image)["image"].to(g_device).unsqueeze(0))
+
+         boxes = predictions[0]["boxes"].detach().to("cpu").numpy()
+         scores = predictions[0]["scores"].detach().to("cpu").numpy()
+
+         filtered_predictions = [
+             [box[i] for i in range(4)]
+             for box, score in zip(boxes, scores)
+             if score > score_threshold
+         ]
+
+         # The image is already at its original size, so this pass is effectively an
+         # identity that validates the boxes; the real rescaling from the model's
+         # resolution happens in the mutate below
+         restore_size = A.Compose(
+             [A.Resize(width=image.shape[1], height=image.shape[0])],
+             bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
+         )
+
+         sample = {
+             "image": image,
+             "bboxes": filtered_predictions,
+             "labels": [1 for _ in range(len(filtered_predictions))],
+         }
+         sample = restore_size(**sample)
+
+         resized_predictions = sample["bboxes"]
+
+         # Local import: siuba's filter/mutate would shadow builtins at module level
+         from siuba import _, filter, mutate
+
+         boxes = (
+             pd.DataFrame(data=resized_predictions, columns=["x1", "y1", "x2", "y2"])
+             >> mutate(
+                 x1=_.x1 * image.shape[1] / augs[0].width,
+                 y1=_.y1 * image.shape[0] / augs[0].height,
+                 x2=_.x2 * image.shape[1] / augs[0].width,
+                 y2=_.y2 * image.shape[0] / augs[0].height,
+             )
+             >> mutate(width=_.x2 - _.x1, height=_.y2 - _.y1)
+             >> mutate(cx=(_.x1 + _.x2) / 2, cy=(_.y1 + _.y2) / 2)
+             >> mutate(area=_.width * _.height)
+             >> mutate(ar=_.width / _.height)
+         )
+         boxes.insert(
+             0,
+             "file_name",
+             image_name.name if isinstance(image_name, Path) else image_name,
+         )
+         boxes["max_size"] = boxes[["width", "height"]].max(axis=1)
+
+         # Discard boxes that are too elongated in either direction
+         ar_boxes = (
+             boxes
+             >> filter(_.width / _.height < ar_threshold)
+             >> filter(_.height / _.width < ar_threshold)
+         )
+
+         return ar_boxes[ar_boxes.area > ar_boxes.area.max() * size_threshold]
+
+     @staticmethod
+     def init_cols(bboxes):
+         bboxes = bboxes.copy()
+
+         # Handle columns
+         X = np.reshape(bboxes.cx.to_list(), (-1, 1))
+         ms = MeanShift(bandwidth=100, bin_seeding=True)
+         ms.fit(X)
+         cols = ms.predict(X)
+         bboxes["col"] = cols
+
+         bboxes = bboxes.sort_values("cx")
+         bboxes["mean_cx"] = (
+             bboxes.groupby("col").transform("mean", numeric_only=True).cx
+         )
+         bboxes = bboxes.sort_values("mean_cx")
+         for i, val in enumerate(bboxes.mean_cx.unique()):
+             bboxes.loc[bboxes["mean_cx"] == val, "col"] = i
+
+         # Handle rows
+         bboxes = bboxes.sort_values("cy")
+         X = np.reshape(bboxes.cy.to_list(), (-1, 1))
+         ms = MeanShift(bandwidth=100, bin_seeding=True)
+         ms.fit(X)
+         rows = ms.predict(X)
+         bboxes["row"] = rows
+
+         bboxes = bboxes.sort_values("cy")
+         bboxes["mean_cy"] = (
+             bboxes.groupby("row").transform("mean", numeric_only=True).cy
+         )
+         bboxes = bboxes.sort_values("mean_cy")
+         for i, val in zip(["a", "b", "c"], bboxes.mean_cy.unique()):
+             bboxes.loc[bboxes["mean_cy"] == val, "row"] = i
+
+         bboxes = bboxes.sort_values("cx")
+
+         return bboxes
+
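+     # finalize_indexing corrects column indices when whole columns are missing: a gap
+     # wider than ~1.1x the disc size (measured from the first vertical line, or
+     # between consecutive columns) shifts the remaining labels one slot to the right.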
+     @staticmethod
+     def finalize_indexing(bboxes, image_name):
+         bboxes = bboxes.copy()
+         bboxes = bboxes.sort_values("cx")
+         labels_unique = bboxes.col.unique()
+         labels = bboxes.col.to_numpy()
+         if len(labels_unique) < 4:
+             inc_labels = [[i, 0] for i in range(len(labels_unique))]
+             max_width = bboxes.max_size.max()
+
+             # Handle the left-most label.
+             # Half of the max width is removed to take the tray's margins into account.
+             left_most_line = get_first_vert_line(image_name=image_name)
+             if left_most_line is not None:
+                 left_most_point = bboxes.x1.min() - min(
+                     left_most_line[0], left_most_line[2]
+                 )
+             else:
+                 left_most_point = bboxes.x1.min() - (max_width / 2)
+             i = 1
+             while left_most_point > i * 1.1 * max_width:
+                 inc_labels[0][1] += 1
+                 i += 1
+
+             # Handle the next labels
+             prev_min_min = bboxes[bboxes.col == 0].x2.max()
+
+             for label in labels_unique[1:]:
+                 current_label_contours = bboxes[bboxes.col == label]
+                 max_width = current_label_contours.max_size.max()
+                 min_left = current_label_contours.x1.min()
+                 i = 1
+                 while min_left - prev_min_min > i * 1.1 * max_width:
+                     inc_labels[label][1] += 1
+                     i += 1
+                 prev_min_min = min_left + max_width
+
+             for pos, inc in reversed(inc_labels):
+                 labels[labels >= pos] += inc
+
+             bboxes["col"] = labels
+
+         bboxes["col"] += 1
+
+         return bboxes.sort_values(["row", "col"])
+
+     def index_plate(
+         self,
+         image_name,
+         score_threshold=0.90,
+         ar_threshold=1.5,
+         size_threshold=0.50,
+     ):
+         bboxes = self.prepare_bboxes(
+             image_name=image_name,
+             score_threshold=score_threshold,
+             ar_threshold=ar_threshold,
+             size_threshold=size_threshold,
+         )
+         if bboxes.shape[0] == 0:
+             return bboxes
+
+         bboxes = self.init_cols(bboxes=bboxes)
+         bboxes = self.finalize_indexing(bboxes=bboxes, image_name=image_name)
+
+         return bboxes
+
+
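+ # End-to-end sketch (hypothetical checkpoint and file names, for illustration only):
+ #   model = LeafDiskDetector.load_from_checkpoint("ldd.ckpt")
+ #   boxes = model.index_plate(image_name="tray_001.jpg")
+ #   print_boxes(image_name="tray_001.jpg", boxes=boxes)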
+ def test_augmentations(
+     df,
+     image_size,
+     kinds: list = ["resize", "train"],
+     row_count=2,
+     col_count=4,
+     **aug_params,
+ ):
+     src_dataset = LeafDiskDetectorDataset(
+         csv=df,
+         transform=get_augmentations(
+             image_size=image_size, kinds=["resize"], **aug_params
+         ),
+     )
+
+     test_dataset = LeafDiskDetectorDataset(
+         csv=df,
+         transform=get_augmentations(image_size=image_size, kinds=kinds, **aug_params),
+     )
+
+     image_name = df.sample(n=1).iloc[0].plate_name
+
+     images = [(src_dataset.draw_image_with_boxes(plate_name=image_name), "Source")] + [
+         (test_dataset.draw_image_with_boxes(plate_name=image_name), "Augmented")
+         for _ in range(row_count * col_count - 1)
+     ]
+
+     make_patches_grid(
+         images=images,
+         row_count=row_count,
+         col_count=col_count,
+         figsize=(col_count * 4, row_count * 3),
+     )
+
+
+ def get_file_path_from_row(row, path_to_patches: Path):
+     return path_to_patches.joinpath(row.file_name)
+
+
+ def get_fast_images(
+     row, path_to_patches, percent_radius: float = 1.0, add_process_image: bool = False
+ ):
+     # Every crop is optional: when one fails, its key is simply left out of the dict
+     d = {}
+     try:
+         d["leaf_disc_box"] = get_leaf_disk_wbb(
+             row.file_name, row, image_path=get_file_path_from_row(row, path_to_patches)
+         )
+     except Exception:
+         pass
+     try:
+         d["segmented_leaf_disc"] = get_fast_segment_disk(
+             image_name=row.file_name,
+             bboxes=row,
+             percent_radius=percent_radius,
+             image_path=get_file_path_from_row(row, path_to_patches),
+         )
+     except Exception:
+         pass
+     try:
+         d["leaf_disc_patch"] = get_fast_leaf_disk_patch(
+             image_name=row.file_name,
+             bboxes=row,
+             percent_radius=percent_radius,
+             image_path=get_file_path_from_row(row, path_to_patches),
+         )
+     except Exception:
+         pass
+     if add_process_image is True:
+         try:
+             d["process_image"] = draw_fast_bb_to_patch_process(
+                 image_name=row.file_name,
+                 bboxes=row,
+                 percent_radius=percent_radius,
+                 image_path=get_file_path_from_row(row, path_to_patches),
+             )
+         except Exception:
+             pass
+
+     return d
+
+
+ def save_images(row: pd.Series, images_data: dict, errors: dict, paths: dict):
+     fn = f"{Path(row.file_name).stem}_{row.row}_{int(row.col)}.png"
+     for k, image in images_data.items():
+         if k not in paths:
+             continue
+         path_to_image = paths[k].joinpath(fn)
+         if image is not None:
+             if path_to_image.is_file() is False:
+                 cv2.imwrite(str(path_to_image), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+         elif errors is not None:
+             errors[k].append(row.file_name)
+
+
+ def handle_bbox(
+     row: pd.Series,
+     paths: dict,
+     errors: dict = None,
+     percent_radius: float = 1.0,
+     add_process_image: bool = False,
+ ):
+     save_images(
+         row=row,
+         images_data=get_fast_images(
+             row=row,
+             percent_radius=percent_radius,
+             add_process_image=add_process_image,
+             path_to_patches=paths["plates"],
+         ),
+         errors=errors,
+         paths=paths,
+     )
src/leaf_patch_gen_diff.ipynb ADDED
@@ -0,0 +1,650 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Genotype Differenciation"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "## Imports"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "%load_ext autoreload\n",
24
+ "%autoreload 2"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "import warnings\n",
34
+ "\n",
35
+ "import numpy as np\n",
36
+ "import pandas as pd\n",
37
+ "\n",
38
+ "import scipy.stats as stats\n",
39
+ "import statsmodels.api as sm\n",
40
+ "from statsmodels.formula.api import ols\n",
41
+ "from statsmodels.regression.linear_model import RegressionResultsWrapper\n",
42
+ "from statsmodels.stats.multicomp import pairwise_tukeyhsd\n",
43
+ "\n",
44
+ "from matplotlib.figure import Figure\n",
45
+ "import seaborn as sns\n",
46
+ "import panel as pn\n",
47
+ "\n",
48
+ "import com_const as cc\n",
49
+ "import com_func as cf\n",
50
+ "import com_image as ci"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "metadata": {},
56
+ "source": [
57
+ "## Setup"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "warnings.simplefilter(action=\"ignore\", category=UserWarning)\n",
67
+ "warnings.simplefilter(action=\"ignore\", category=FutureWarning)"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "pd.set_option(\"display.max_colwidth\", 500)\n",
77
+ "pd.set_option(\"display.max_columns\", 500)\n",
78
+ "pd.set_option(\"display.width\", 1000)\n",
79
+ "pd.set_option(\"display.max_rows\", 20)"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "sns.set_style(\"whitegrid\")"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "pn.extension(\"ipywidgets\", \"plotly\", design=\"material\")"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "markdown",
102
+ "metadata": {},
103
+ "source": [
104
+ "## Constants"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "stars = [-np.log(0.05), -np.log(0.01), -np.log(0.001), -np.log(0.0001)]"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "metadata": {},
119
+ "source": [
120
+ "## Functions"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "def plot_single_progression(\n",
130
+ " ax,\n",
131
+ " df,\n",
132
+ " target,\n",
133
+ " title: str,\n",
134
+ " hue=\"gen\",\n",
135
+ " style=\"gen\",\n",
136
+ " show_legend: bool = False,\n",
137
+ "):\n",
138
+ " lp = sns.lineplot(\n",
139
+ " df.sort_values(hue),\n",
140
+ " x=\"dpi\",\n",
141
+ " y=target,\n",
142
+ " hue=hue,\n",
143
+ " markers=True,\n",
144
+ " style=style,\n",
145
+ " dashes=False,\n",
146
+ " palette=\"tab10\",\n",
147
+ " markersize=12,\n",
148
+ " ax=ax,\n",
149
+ " )\n",
150
+ " lp.set_yticklabels([\"\", \"3\", \"\", \"5\", \"\", \"7\", \"\", \"9\"])\n",
151
+ " ax.set_title(title)\n",
152
+ " if show_legend is True:\n",
153
+ " sns.move_legend(ax, \"upper left\", bbox_to_anchor=(1, 1))\n",
154
+ " else:\n",
155
+ " ax.get_legend().set_visible(False)"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "def get_model(\n",
165
+ " df: pd.DataFrame, target: str, formula: str, dpi: int = None\n",
166
+ ") -> RegressionResultsWrapper:\n",
167
+ " df_ = df[df.dpi == dpi] if dpi is not None else df\n",
168
+ " return ols(f\"{target} {formula}\", data=df_).fit()"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "def anova_table(aov, add_columns: bool = True):\n",
178
+ " \"\"\"\n",
179
+ " The function below was created specifically for the one-way ANOVA table\n",
180
+ " results returned for Type II sum of squares\n",
181
+ " \"\"\"\n",
182
+ " if add_columns is True:\n",
183
+ " aov[\"mean_sq\"] = aov[:][\"sum_sq\"] / aov[:][\"df\"]\n",
184
+ "\n",
185
+ " aov[\"eta_sq\"] = aov[:-1][\"sum_sq\"] / sum(aov[\"sum_sq\"])\n",
186
+ "\n",
187
+ " aov[\"omega_sq\"] = (\n",
188
+ " aov[:-1][\"sum_sq\"] - (aov[:-1][\"df\"] * aov[\"mean_sq\"][-1])\n",
189
+ " ) / (sum(aov[\"sum_sq\"]) + aov[\"mean_sq\"][-1])\n",
190
+ "\n",
191
+ " cols = [\"sum_sq\", \"df\", \"mean_sq\", \"F\", \"PR(>F)\", \"eta_sq\", \"omega_sq\"]\n",
192
+ " aov = aov[cols]\n",
193
+ " return aov"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": null,
199
+ "metadata": {},
200
+ "outputs": [],
201
+ "source": [
202
+ "def plot_assumptions(models: list, titles: list, figsize=(12, 4)):\n",
203
+ " fig = Figure(figsize=figsize)\n",
204
+ " fig.suptitle(\"Probability plot of model residual's\", fontsize=\"x-large\")\n",
205
+ " axii = fig.subplots(1, len(models))\n",
206
+ " for ax, model, title in zip(axii, models, titles):\n",
207
+ " _ = stats.probplot(model.resid, plot=ax, rvalue=True)\n",
208
+ " ax.set_title(title)\n",
209
+ "\n",
210
+ " return fig"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "execution_count": null,
216
+ "metadata": {},
217
+ "outputs": [],
218
+ "source": [
219
+ "def hghlight_rejection(s):\n",
220
+ " df = pd.DataFrame(columns=s.columns, index=s.index)\n",
221
+ " df.loc[s[\"reject_pred\"].ne(s[\"reject_obs\"]), [\"group1\", \"group2\"]] = (\n",
222
+ " \"background: red\"\n",
223
+ " )\n",
224
+ " df.loc[s[\"reject_pred\"].eq(s[\"reject_obs\"]), [\"group1\", \"group2\"]] = (\n",
225
+ " \"background: green\"\n",
226
+ " )\n",
227
+ " df.loc[s.reject_pred, [\"reject_pred\"]] = \"background: green\"\n",
228
+ " df.loc[~s.reject_pred, [\"reject_pred\"]] = \"background: red\"\n",
229
+ " df.loc[s.reject_obs, [\"reject_obs\"]] = \"background: green\"\n",
230
+ " df.loc[~s.reject_obs, [\"reject_obs\"]] = \"background: red\"\n",
231
+ " return df"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": null,
237
+ "metadata": {},
238
+ "outputs": [],
239
+ "source": [
240
+ "def get_tuckey_df(endog, groups, df_genotypes) -> pd.DataFrame:\n",
241
+ " tukey = pairwise_tukeyhsd(endog=endog, groups=groups)\n",
242
+ " df_tuc = pd.DataFrame(tukey._results_table)\n",
243
+ " df_tuc.columns = [str(c) for c in df_tuc.iloc[0]]\n",
244
+ " ret = (\n",
245
+ " df_tuc.drop(df_tuc.index[0])\n",
246
+ " .assign(group1=lambda s: s.group1.astype(str))\n",
247
+ " .assign(group2=lambda s: s.group2.astype(str))\n",
248
+ " .assign(reject=lambda s: s.reject.astype(str) == \"True\")\n",
249
+ " )\n",
250
+ " ret[\"p-adj\"] = tukey.pvalues\n",
251
+ " if df_genotypes is None:\n",
252
+ " return ret\n",
253
+ " else:\n",
254
+ " return (\n",
255
+ " ret.merge(right=df_genotypes, how=\"left\", left_on=\"group1\", right_on=\"gen\")\n",
256
+ " .drop([\"gen\"], axis=1)\n",
257
+ " .rename(columns={\"rpvloci\": \"group1_rpvloci\"})\n",
258
+ " .merge(right=df_genotypes, how=\"left\", left_on=\"group2\", right_on=\"gen\")\n",
259
+ " .drop([\"gen\"], axis=1)\n",
260
+ " .rename(columns={\"rpvloci\": \"group2_rpvloci\"})\n",
261
+ " )\n",
262
+ "\n",
263
+ "\n",
264
+ "def get_tuckey_compare(df, df_genotypes=None, groups: str = \"gen\"):\n",
265
+ " merge_on = (\n",
266
+ " [\"group1\", \"group2\"]\n",
267
+ " if df_genotypes is None\n",
268
+ " else [\"group1\", \"group2\", \"group1_rpvloci\", \"group2_rpvloci\"]\n",
269
+ " )\n",
270
+ " df_poiv = get_tuckey_df(df.p_oiv, df[groups], df_genotypes=df_genotypes)\n",
271
+ " df_oiv = get_tuckey_df(df.oiv, df[groups], df_genotypes=df_genotypes)\n",
272
+ " df = pd.merge(left=df_poiv, right=df_oiv, on=merge_on, suffixes=[\"_pred\", \"_obs\"])\n",
273
+ " return df"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": [
282
+ "def df_tukey_cmp_plot(df, groups):\n",
283
+ " df_tukey = (\n",
284
+ " get_tuckey_compare(df=df, groups=groups, df_genotypes=None)\n",
285
+ " .assign(pair_groups=lambda s: s.group1 + \"\\n\" + s.group2)\n",
286
+ " .sort_values(\"p-adj_obs\")\n",
287
+ " )\n",
288
+ "\n",
289
+ " df_tukey_reject = df_tukey[df_tukey.reject_obs & df_tukey.reject_pred]\n",
290
+ " df_tukey_accept = df_tukey[~df_tukey.reject_obs & ~df_tukey.reject_pred]\n",
291
+ " df_tukey_diverge = df_tukey[df_tukey.reject_obs != df_tukey.reject_pred]\n",
292
+ "\n",
293
+ " fig = Figure(figsize=(20, 6))\n",
294
+ " ax_reject, ax_diverge, ax_accept = fig.subplots(\n",
295
+ " 1,\n",
296
+ " 3,\n",
297
+ " gridspec_kw={\n",
298
+ " \"width_ratios\": [\n",
299
+ " len(df_tukey_reject),\n",
300
+ " len(df_tukey_diverge),\n",
301
+ " len(df_tukey_accept),\n",
302
+ " ]\n",
303
+ " },\n",
304
+ " sharey=True,\n",
305
+ " )\n",
306
+ "\n",
307
+ " for ax in [ax_reject, ax_accept, ax_diverge]:\n",
308
+ " ax.set_yticks(ticks=stars, labels=[\"*\", \"**\", \"***\", \"****\"])\n",
309
+ " ax.grid(False)\n",
310
+ "\n",
311
+ " ax_reject.set_title(\"Rejected\")\n",
312
+ " ax_diverge.set_title(\"Conflict\")\n",
313
+ " ax_accept.set_title(\"Accepted\")\n",
314
+ "\n",
315
+ " for ax, df in zip(\n",
316
+ " [ax_reject, ax_accept, ax_diverge],\n",
317
+ " [df_tukey_reject, df_tukey_accept, df_tukey_diverge],\n",
318
+ " ):\n",
319
+ " for star in stars:\n",
320
+ " ax.axhline(y=star, linestyle=\"-\", color=\"black\", alpha=0.5)\n",
321
+ " ax.bar(\n",
322
+ " x=df[\"pair_groups\"],\n",
323
+ " height=-np.log(df[\"p-adj_pred\"]),\n",
324
+ " width=-0.4,\n",
325
+ " align=\"edge\",\n",
326
+ " color=\"green\",\n",
327
+ " label=\"predictions\",\n",
328
+ " )\n",
329
+ " ax.bar(\n",
330
+ " x=df[\"pair_groups\"],\n",
331
+ " height=-np.log(df[\"p-adj_obs\"]),\n",
332
+ " width=0.4,\n",
333
+ " align=\"edge\",\n",
334
+ " color=\"blue\",\n",
335
+ " label=\"scorings\",\n",
336
+ " )\n",
337
+ " ax.margins(0.01)\n",
338
+ "\n",
339
+ " ax_accept.legend(loc=\"upper left\", bbox_to_anchor=[0, 1], ncols=1, fancybox=True)\n",
340
+ " ax_reject.set_ylabel(\"-log(p value)\")\n",
341
+ " ax_reject.tick_params(axis=\"y\", which=\"major\", labelsize=16)\n",
342
+ "\n",
343
+ " fig.subplots_adjust(wspace=0.05, hspace=0.05)\n",
344
+ "\n",
345
+ " return fig"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "execution_count": null,
351
+ "metadata": {},
352
+ "outputs": [],
353
+ "source": [
354
+ "def plot_patches(df, diff_only: bool = True):\n",
355
+ " if diff_only is True:\n",
356
+ " df = df[(df.oiv != df.p_oiv)]\n",
357
+ " df = df.assign(diff=lambda s: s.oiv != s.p_oiv).sort_values(\n",
358
+ " [\"diff\", \"oiv\", \"p_oiv\"]\n",
359
+ " )\n",
360
+ " return pn.GridBox(\n",
361
+ " *[\n",
362
+ " pn.Column(\n",
363
+ " pn.pane.Markdown(f\"### {row.file_name}|{row.oiv}->p{row.p_oiv}\"),\n",
364
+ " pn.pane.Image(\n",
365
+ " object=ci.enhance_pil_image(\n",
366
+ " image=ci.load_image(\n",
367
+ " file_name=row.file_name,\n",
368
+ " path_to_images=cc.path_to_leaf_patches,\n",
369
+ " ),\n",
370
+ " brightness=1.5,\n",
371
+ " )\n",
372
+ " ),\n",
373
+ " )\n",
374
+ " for _, row in df.iterrows()\n",
375
+ " ],\n",
376
+ " ncols=len(df),\n",
377
+ " )"
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "markdown",
382
+ "metadata": {},
383
+ "source": [
384
+ "## Load Data"
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": null,
390
+ "metadata": {},
391
+ "outputs": [],
392
+ "source": [
393
+ "df = cf.read_dataframe(\n",
394
+ " path=cc.path_to_data.joinpath(\"genotype_differenciation_dataset.csv\")\n",
395
+ ").assign(exp=lambda s: s.experiment + s.inoc.astype(str))\n",
396
+ "df"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "code",
401
+ "execution_count": null,
402
+ "metadata": {},
403
+ "outputs": [],
404
+ "source": [
405
+ "df_dpi_6 = df[df.dpi == 6]\n",
406
+ "df_dpi_6"
407
+ ]
408
+ },
409
+ {
410
+ "cell_type": "markdown",
411
+ "metadata": {},
412
+ "source": [
413
+ "## Visualizations"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": null,
419
+ "metadata": {},
420
+ "outputs": [],
421
+ "source": [
422
+ "fig = Figure(figsize=(12, 4))\n",
423
+ "ax_oiv, ax_p_oiv = fig.subplots(nrows=1, ncols=2)\n",
424
+ "\n",
425
+ "full_oiv = \"OIV 452-1\"\n",
426
+ "df_oiv = df.copy()\n",
427
+ "df_oiv[full_oiv] = df_oiv.oiv\n",
428
+ "df_p_oiv = df.copy()\n",
429
+ "df_p_oiv[full_oiv] = df_p_oiv.p_oiv\n",
430
+ "\n",
431
+ "var = \"gen\"\n",
432
+ "\n",
433
+ "plot_single_progression(\n",
434
+ " ax=ax_oiv, df=df_oiv, target=full_oiv, title=\"Human scored OIV 452-1\"\n",
435
+ ")\n",
436
+ "\n",
437
+ "plot_single_progression(\n",
438
+ " ax=ax_p_oiv,\n",
439
+ " df=df_p_oiv,\n",
440
+ " target=full_oiv,\n",
441
+ " title=\"Model predicted OIV 452-1\",\n",
442
+ " show_legend=True,\n",
443
+ ")\n",
444
+ "\n",
445
+ "fig"
446
+ ]
447
+ },
448
+ {
449
+ "cell_type": "code",
450
+ "execution_count": null,
451
+ "metadata": {},
452
+ "outputs": [],
453
+ "source": [
454
+ "fig = Figure(figsize=(16, 6))\n",
455
+ "sns.histplot(\n",
456
+ " df_dpi_6.sort_values(\"gen\"),\n",
457
+ " x=\"gen\",\n",
458
+ " hue=\"gen\",\n",
459
+ " shrink=0.8,\n",
460
+ " ax=fig.subplots(1, 1),\n",
461
+ ")\n",
462
+ "\n",
463
+ "fig"
464
+ ]
465
+ },
466
+ {
467
+ "cell_type": "markdown",
468
+ "metadata": {},
469
+ "source": [
470
+ "## ANOVA"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "metadata": {},
477
+ "outputs": [],
478
+ "source": [
479
+ "rpv_formula = f\"~ C(gen) + C(exp) + C(exp):C(gen)\""
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "code",
484
+ "execution_count": null,
485
+ "metadata": {},
486
+ "outputs": [],
487
+ "source": [
488
+ "(\n",
489
+ " pd.concat(\n",
490
+ " [\n",
491
+ " sm.stats.anova_lm(\n",
492
+ " get_model(df=df, target=\"oiv\", dpi=i, formula=rpv_formula)\n",
493
+ " ).assign(dpi=i)\n",
494
+ " for i in sorted(list(df.dpi.unique()))\n",
495
+ " ]\n",
496
+ " )\n",
497
+ " .reset_index()\n",
498
+ " .set_index(\"dpi\")\n",
499
+ " .drop(\n",
500
+ " [\"df\", \"sum_sq\", \"mean_sq\"],\n",
501
+ " axis=1,\n",
502
+ " )\n",
503
+ " .query(\"index != 'Residual'\")\n",
504
+ " .query(\"index != 'C(exp)'\")\n",
505
+ " .rename(columns={\"index\": \"source of variation\"})\n",
506
+ " .replace(\"C(gen)\", \"genotype (between)\")\n",
507
+ " .replace(\"C(exp):C(gen)\", \"interaction genotype/experiment\")\n",
508
+ " .reset_index()\n",
509
+ ")"
510
+ ]
511
+ },
512
+ {
513
+ "cell_type": "code",
514
+ "execution_count": null,
515
+ "metadata": {},
516
+ "outputs": [],
517
+ "source": [
518
+ "df_dpi_6.groupby(\"gen\").agg(\n",
519
+ " {\"oiv\": [\"mean\", \"std\"], \"p_oiv\": [\"mean\", \"std\"]}\n",
520
+ ").reset_index()"
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "execution_count": null,
526
+ "metadata": {},
527
+ "outputs": [],
528
+ "source": [
529
+ "pn.GridBox(\n",
530
+ " pn.Column(\n",
531
+ " pn.pane.Markdown(\"### Annotated\"),\n",
532
+ " anova_table(\n",
533
+ " sm.stats.anova_lm(\n",
534
+ " get_model(df=df_dpi_6, target=\"oiv\", dpi=6, formula=rpv_formula),\n",
535
+ " typ=2,\n",
536
+ " )\n",
537
+ " ),\n",
538
+ " ),\n",
539
+ " pn.Column(\n",
540
+ " pn.pane.Markdown(\"### Predicted\"),\n",
541
+ " anova_table(\n",
542
+ " sm.stats.anova_lm(\n",
543
+ " get_model(df=df_dpi_6, target=\"p_oiv\", dpi=6, formula=rpv_formula),\n",
544
+ " typ=2,\n",
545
+ " )\n",
546
+ " ),\n",
547
+ " ),\n",
548
+ " ncols=2,\n",
549
+ ")"
550
+ ]
551
+ },
552
+ {
553
+ "cell_type": "code",
554
+ "execution_count": null,
555
+ "metadata": {},
556
+ "outputs": [],
557
+ "source": [
558
+ "plot_assumptions(\n",
559
+ " models=[\n",
560
+ " get_model(df=df_dpi_6, target=\"oiv\", dpi=6, formula=rpv_formula),\n",
561
+ " get_model(df=df_dpi_6, target=\"p_oiv\", dpi=6, formula=rpv_formula),\n",
562
+ " ],\n",
563
+ " titles=[\"Score OIV 452-1\", \"Predicted OIV 452-1\"],\n",
564
+ " figsize=(10, 5),\n",
565
+ ")"
566
+ ]
567
+ },
568
+ {
569
+ "cell_type": "markdown",
570
+ "metadata": {},
571
+ "source": [
572
+ "# Tukey HSD"
573
+ ]
574
+ },
575
+ {
576
+ "cell_type": "code",
577
+ "execution_count": null,
578
+ "metadata": {},
579
+ "outputs": [],
580
+ "source": [
581
+ "dft = get_tuckey_compare(df=df_dpi_6, groups=\"gen\", df_genotypes=None)\n",
582
+ "dft.style.apply(hghlight_rejection, axis=None)"
583
+ ]
584
+ },
585
+ {
586
+ "cell_type": "code",
587
+ "execution_count": null,
588
+ "metadata": {},
589
+ "outputs": [],
590
+ "source": [
591
+ "df_tukey_cmp_plot(df=df_dpi_6, groups=\"gen\")"
592
+ ]
593
+ },
594
+ {
595
+ "cell_type": "code",
596
+ "execution_count": null,
597
+ "metadata": {},
598
+ "outputs": [],
599
+ "source": [
600
+ "df_cmp_means = (\n",
601
+ " (df_dpi_6[df_dpi_6.gen.isin([\"1441s\", \"1466s\"])])\n",
602
+ " .groupby(\"gen\")\n",
603
+ " .agg({\"oiv\": [\"mean\", \"std\"], \"p_oiv\": [\"mean\", \"std\"]})\n",
604
+ " .reset_index()\n",
605
+ ")\n",
606
+ "df_cmp_means[\"difference\"] = df_cmp_means.oiv[\"mean\"] - df_cmp_means.p_oiv[\"mean\"]\n",
607
+ "df_cmp_means"
608
+ ]
609
+ },
610
+ {
611
+ "cell_type": "code",
612
+ "execution_count": null,
613
+ "metadata": {},
614
+ "outputs": [],
615
+ "source": [
616
+ "plot_patches(df_dpi_6[df_dpi_6.gen.isin([\"1441s\"])], diff_only=True)"
617
+ ]
618
+ },
619
+ {
620
+ "cell_type": "code",
621
+ "execution_count": null,
622
+ "metadata": {},
623
+ "outputs": [],
624
+ "source": [
625
+ "plot_patches(df_dpi_6[df_dpi_6.gen.isin([\"1466s\"])], diff_only=True)"
626
+ ]
627
+ }
628
+ ],
629
+ "metadata": {
630
+ "kernelspec": {
631
+ "display_name": "env",
632
+ "language": "python",
633
+ "name": "python3"
634
+ },
635
+ "language_info": {
636
+ "codemirror_mode": {
637
+ "name": "ipython",
638
+ "version": 3
639
+ },
640
+ "file_extension": ".py",
641
+ "mimetype": "text/x-python",
642
+ "name": "python",
643
+ "nbconvert_exporter": "python",
644
+ "pygments_lexer": "ipython3",
645
+ "version": "3.9.2"
646
+ }
647
+ },
648
+ "nbformat": 4,
649
+ "nbformat_minor": 2
650
+ }
src/leaf_patch_oiv_predictor.ipynb ADDED
@@ -0,0 +1,397 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "### Step by Step OIV 452-1 predictor Training"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "## Imports"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "%load_ext autoreload\n",
24
+ "%autoreload 2"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "import warnings\n",
34
+ "from pathlib import Path\n",
35
+ "import shutil\n",
36
+ "\n",
37
+ "from tqdm import tqdm\n",
38
+ "\n",
39
+ "import pandas as pd\n",
40
+ "\n",
41
+ "from sklearn.metrics import (\n",
42
+ " confusion_matrix,\n",
43
+ " mean_squared_error,\n",
44
+ " ConfusionMatrixDisplay,\n",
45
+ " classification_report,\n",
46
+ ")\n",
47
+ "\n",
48
+ "import matplotlib.pyplot as plt\n",
49
+ "import altair as alt\n",
50
+ "\n",
51
+ "import panel as pn\n",
52
+ "\n",
53
+ "import com_const as cc\n",
54
+ "import com_func as cf\n",
55
+ "import com_augmentations as ca\n",
56
+ "import leaf_patch_oiv_predictor_model as lpopm"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "metadata": {},
62
+ "source": [
63
+ "## Setup"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "# Remove warnings\n",
73
+ "warnings.simplefilter(action=\"ignore\", category=UserWarning)\n",
74
+ "warnings.simplefilter(action=\"ignore\", category=FutureWarning)"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "pd.options.display.float_format = \"{:4,.4f}\".format\n",
84
+ "\n",
85
+ "pd.set_option(\"display.max_colwidth\", 500)\n",
86
+ "pd.set_option(\"display.max_columns\", 500)\n",
87
+ "pd.set_option(\"display.width\", 1000)\n",
88
+ "pd.set_option(\"display.max_rows\", 16)"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "alt.data_transformers.disable_max_rows()"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "pn.extension(\"plotly\", \"vega\", notifications=True, console_output=\"disable\")"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "metadata": {},
112
+ "source": [
113
+ "## Dataset"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "metadata": {},
119
+ "source": [
120
+ "### Load"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "train, val, test = [\n",
130
+ " cf.read_dataframe(cc.path_to_data.joinpath(f\"oiv_{d}.csv\"))\n",
131
+ " for d in [\"train\", \"val\", \"test\"]\n",
132
+ "]\n",
133
+ "alt.hconcat(\n",
134
+ " *[\n",
135
+ " alt.Chart(df.assign(oiv=lambda x: x.oiv.astype(str)))\n",
136
+ " .mark_bar()\n",
137
+ " .encode(x=\"oiv\", y=\"count()\", color=\"source\", tooltip=\"count()\")\n",
138
+ " .properties(width=200, height=300, title=title)\n",
139
+ " for (df, title) in [\n",
140
+ " (train, \"train\"),\n",
141
+ " (val, \"val\"),\n",
142
+ " (test, \"test\"),\n",
143
+ " ]\n",
144
+ " ]\n",
145
+ ")"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "# src_patches = (\n",
155
+ "# Path(cc.path_to_root)\n",
156
+ "# .joinpath(\"..\")\n",
157
+ "# .joinpath(\"leafdisks_powderymildew\")\n",
158
+ "# .joinpath(\"data_in\")\n",
159
+ "# .joinpath(\"202311_dataset\")\n",
160
+ "# .joinpath(\"patches\")\n",
161
+ "# )\n",
162
+ "# src_patches.is_dir()\n",
163
+ "\n",
164
+ "# for d in [train, val, test]:\n",
165
+ "# for fn in tqdm(d.file_name):\n",
166
+ "# shutil.copy(src=src_patches.joinpath(fn), dst=cc.path_to_leaf_patches.joinpath(fn))"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "markdown",
171
+ "metadata": {},
172
+ "source": [
173
+ "### Augmentation"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "augmentations_kinds = [\"fix_brightness\", \"resize\", \"affine\", \"color\", \"to_tensor\"]\n",
183
+ "augmentations_params = dict(\n",
184
+ " gamma=(60, 120),\n",
185
+ " brightness_limit=0.15,\n",
186
+ " contrast_limit=0.25,\n",
187
+ " brightness_target=115,\n",
188
+ " brightness_thresholds=(115, 130),\n",
189
+ ")\n",
190
+ "\n",
191
+ "ca.test_augmentations(\n",
192
+ " df=train,\n",
193
+ " image_size=224,\n",
194
+ " path_to_images=cc.path_to_leaf_patches,\n",
195
+ " kinds=augmentations_kinds,\n",
196
+ " columns=[\"oiv\"],\n",
197
+ " **augmentations_params\n",
198
+ ")"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "markdown",
203
+ "metadata": {},
204
+ "source": [
205
+ "## Model"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "markdown",
210
+ "metadata": {},
211
+ "source": [
212
+ "### Find Batch Size"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": null,
218
+ "metadata": {},
219
+ "outputs": [],
220
+ "source": [
221
+ "batch_size = 615"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "markdown",
226
+ "metadata": {},
227
+ "source": [
228
+ "We trained the models on an NVIDIA A100 80GB PCIe that allowed us a batch size of 769 that we reduced to 615 t avoid monopolizing the GPU. Uncomment the the following block to calculate optimal batch size"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": null,
234
+ "metadata": {},
235
+ "outputs": [],
236
+ "source": [
237
+ "# batch_size = lpopm.get_bs(\n",
238
+ "# batch_size=300,\n",
239
+ "# train=train,\n",
240
+ "# val=val,\n",
241
+ "# test=test,\n",
242
+ "# augmentations_kinds=augmentations_kinds,\n",
243
+ "# augmentations_params=augmentations_params,\n",
244
+ "# shrink_factor=0.8,\n",
245
+ "# )"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "metadata": {},
252
+ "outputs": [],
253
+ "source": [
254
+ "\n",
255
+ "batch_size"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "markdown",
260
+ "metadata": {},
261
+ "source": [
262
+ "### Find Learning Rate"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": null,
268
+ "metadata": {},
269
+ "outputs": [],
270
+ "source": [
271
+ "learning_rate = 0.000363"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "markdown",
276
+ "metadata": {},
277
+ "source": [
278
+ "We found that we our selected batch size the best learning rate was 0.000363. The function hereafter will calculate on optimal learning rate for your setup."
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": null,
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
+ "# learning_rate = lpopm.get_lr(\n",
288
+ "# train=train,\n",
289
+ "# val=val,\n",
290
+ "# test=test,\n",
291
+ "# augmentations_params=augmentations_params,\n",
292
+ "# augmentations_kinds=augmentations_kinds,\n",
293
+ "# batch_size=batch_size,\n",
294
+ "# lr_times=10,\n",
295
+ "# )\n"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "metadata": {},
302
+ "outputs": [],
303
+ "source": [
304
+ "learning_rate"
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "markdown",
309
+ "metadata": {},
310
+ "source": [
311
+ "### Train"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": null,
317
+ "metadata": {},
318
+ "outputs": [],
319
+ "source": [
320
+ "# lpopm.train_model(\n",
321
+ "# path_to_images=cc.path_to_leaf_patches,\n",
322
+ "# train=train,\n",
323
+ "# val=val,\n",
324
+ "# test=test,\n",
325
+ "# monitor_loss=\"mse\",\n",
326
+ "# augmentations_kinds=augmentations_kinds,\n",
327
+ "# augmentations_params=augmentations_params,\n",
328
+ "# batch_size=batch_size,\n",
329
+ "# learning_rate=learning_rate,\n",
330
+ "# )"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "markdown",
335
+ "metadata": {},
336
+ "source": [
337
+ "### Validate"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": null,
343
+ "metadata": {},
344
+ "outputs": [],
345
+ "source": [
346
+ "model = lpopm.OivDetPatchesNet.load_from_checkpoint(\n",
347
+ " cc.path_to_chk_oiv.joinpath(\"oiv_scorer.ckpt\")\n",
348
+ ")\n",
349
+ "model.path_to_images = cc.path_to_leaf_patches\n",
350
+ "model.hr_desc()"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "metadata": {},
357
+ "outputs": [],
358
+ "source": [
359
+ "test_data = model.test_data.assign(oiv=lambda x :x.fixed_oiv)\n",
360
+ "test_data[\"p_oiv\"] = model.predict(test_data)\n",
361
+ "\n",
362
+ "print(f\"MSE: {mean_squared_error(test_data.oiv.astype(int), test_data.p_oiv.astype(int)):.3f}\")\n",
363
+ "ConfusionMatrixDisplay.from_predictions(\n",
364
+ " test_data.oiv.astype(int), test_data.p_oiv.astype(int)\n",
365
+ ");"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": null,
371
+ "metadata": {},
372
+ "outputs": [],
373
+ "source": []
374
+ }
375
+ ],
376
+ "metadata": {
377
+ "kernelspec": {
378
+ "display_name": "env",
379
+ "language": "python",
380
+ "name": "python3"
381
+ },
382
+ "language_info": {
383
+ "codemirror_mode": {
384
+ "name": "ipython",
385
+ "version": 3
386
+ },
387
+ "file_extension": ".py",
388
+ "mimetype": "text/x-python",
389
+ "name": "python",
390
+ "nbconvert_exporter": "python",
391
+ "pygments_lexer": "ipython3",
392
+ "version": "3.9.2"
393
+ }
394
+ },
395
+ "nbformat": 4,
396
+ "nbformat_minor": 2
397
+ }
src/leaf_patch_oiv_predictor_model.py ADDED
@@ -0,0 +1,1266 @@
1
+ from copy import deepcopy
2
+ from functools import partial
3
+ from pathlib import Path
4
+ from datetime import datetime as dt
5
+ import json
6
+
7
+ from rich.console import Console
8
+ from rich.table import Table
9
+ from rich.pretty import Pretty
10
+
11
+ from tqdm import tqdm
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+ from siuba import _ as s
16
+ from siuba import filter as sfilter
17
+ from siuba import mutate
18
+
19
+ from sklearn.metrics import (
20
+ classification_report,
21
+ mean_absolute_error,
22
+ mean_squared_error,
23
+ )
24
+
25
+ import torch
26
+ from torch.utils.data import DataLoader, Dataset
27
+ from torch import nn
28
+
29
+ import torchmetrics
30
+
31
+ import albumentations as A
32
+ from albumentations.pytorch import ToTensorV2
33
+
34
+ import pytorch_lightning as pl
35
+
36
+ from pytorch_lightning.callbacks import (
37
+ RichProgressBar,
38
+ DeviceStatsMonitor,
39
+ ModelCheckpoint,
40
+ LearningRateMonitor,
41
+ )
42
+ from pytorch_lightning import Trainer
43
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
44
+ from pytorch_lightning.loggers import TensorBoardLogger
45
+ from pytorch_lightning.tuner.tuning import Tuner
46
+
47
+ from coral_pytorch.losses import corn_loss
48
+ from coral_pytorch.dataset import proba_to_label, corn_label_from_logits
49
+
50
+ from transformers import logging
51
+ from transformers import (
52
+ ViTForImageClassification,
53
+ SegformerForImageClassification,
54
+ BeitForImageClassification,
55
+ SwinForImageClassification,
56
+ ConvNextForImageClassification,
57
+ DeiTForImageClassificationWithTeacher,
58
+ ResNetForImageClassification,
59
+ )
60
+
61
+ import com_const as cc
62
+ import com_image as ci
63
+ import com_augmentations as ca
64
+ import com_func as cf
65
+
66
+
67
+ logging.set_verbosity_error()
68
+
69
+ torch.set_float32_matmul_precision("high")
70
+
71
+ oiv_models_overview_path = cc.path_to_data.joinpath("oiv_models_overview.csv")
72
+
73
+ g_device = (
74
+ "mps"
75
+ if torch.backends.mps.is_built() is True
76
+ else "cuda" if torch.backends.cuda.is_built() else "cpu"
77
+ )
78
+
79
+ checkpoints_dict = {
80
+ "hf_vit_g16": {
81
+ "path": "google/vit-base-patch16-224-in21k",
82
+ "name": "Google ViT 16",
83
+ "link": "https://huggingface.co/google/vit-base-patch16-224-in21k",
84
+ "class": ViTForImageClassification,
85
+ },
86
+ "hf_bb_16": {
87
+ "path": "microsoft/beit-base-patch16-224-pt22k-ft22k",
88
+ "name": "BEiT (base-sized model, fine-tuned on ImageNet-22k)",
89
+ "link": "https://huggingface.co/microsoft/beit-base-patch16-224-pt22k-ft22k",
90
+ "class": BeitForImageClassification,
91
+ },
92
+ "hf_seg": {
93
+ "path": "nvidia/mit-b0",
94
+ "name": "Segformer",
95
+ "link": "https://huggingface.co/nvidia/mit-b0",
96
+ "class": SegformerForImageClassification,
97
+ },
98
+ "hf_bl_16": {
99
+ "path": "microsoft/beit-large-patch16-224-pt22k-ft22k",
100
+ "name": "BEiT (large-sized model, fine-tuned on ImageNet-22k)",
101
+ "link": "https://huggingface.co/microsoft/beit-large-patch16-224-pt22k-ft22k",
102
+ "class": BeitForImageClassification,
103
+ },
104
+ "hf_vit_g32": {
105
+ "path": "google/vit-large-patch32-384",
106
+ "name": "Vision Transformer (large-sized model)",
107
+ "link": "https://huggingface.co/google/vit-large-patch32516-384",
108
+ "class": ViTForImageClassification,
109
+ },
110
+ "hf_swt_t": {
111
+ "path": "microsoft/swin-tiny-patch4-window7-224",
112
+ "name": "Swin Transformer (tiny-sized model)",
113
+ "link": "https://huggingface.co/microsoft/swin-tiny-patch4-window7-224",
114
+ "class": SwinForImageClassification,
115
+ },
116
+ "hf_cnx_t": {
117
+ "path": "facebook/convnext-tiny-224",
118
+ "name": "ConvNeXT (tiny-sized model)",
119
+ "link": "https://huggingface.co/facebook/convnext-tiny-224",
120
+ "class": ConvNextForImageClassification,
121
+ },
122
+ "hf_det_b": {
123
+ "path": "facebook/deit-base-distilled-patch16-224",
124
+ "name": "Distilled Data-efficient Image Transformer (base-sized model)",
125
+ "link": "https://huggingface.co/facebook/deit-base-distilled-patch16-224",
126
+ "class": DeiTForImageClassificationWithTeacher,
127
+ },
128
+ "hf_swt_l": {
129
+ "path": "microsoft/swin-large-patch4-window12-384-in22k",
130
+ "name": "Swin Transformer (large-sized model)",
131
+ "link": "https://huggingface.co/microsoft/swin-large-patch4-window12-384-in22k",
132
+ "class": SwinForImageClassification,
133
+ },
134
+ "hf_deit_s": {
135
+ "path": "facebook/deit-small-patch16-224",
136
+ "name": "Data-efficient Image Transformer (small-sized model)",
137
+ "link": "https://huggingface.co/facebook/deit-small-patch16-224",
138
+ "class": ViTForImageClassification,
139
+ },
140
+ "hf_seg_b3": {
141
+ "path": "nvidia/mit-b3",
142
+ "name": "SegFormer (b3-sized) encoder pre-trained-only",
143
+ "link": "https://huggingface.co/nvidia/mit-b3",
144
+ "class": SegformerForImageClassification,
145
+ },
146
+ "hf_vit_gl": {
147
+ "path": "google/vit-large-patch16-224",
148
+ "name": "Vision Transformer (large-sized model)",
149
+ "link": "https://huggingface.co/google/vit-large-patch16-224",
150
+ "class": ViTForImageClassification,
151
+ },
152
+ "hf_resnet": {
153
+ "path": "microsoft/resnet-50",
154
+ "name": "ResNet-50 v1.5",
155
+ "link": "https://huggingface.co/microsoft/resnet-50",
156
+ "class": ResNetForImageClassification,
157
+ },
158
+ }
159
+
160
+
161
+ def prepare_dataframe(
162
+ df, excluded_sources, predicted_var, invert_scale: bool = False
163
+ ) -> pd.DataFrame:
164
+ df = df >> sfilter(~s[predicted_var].isna()) >> sfilter(s[predicted_var] > 0)
165
+ if isinstance(excluded_sources, list) and excluded_sources:
166
+ df = df.filter_data(excluded_sources)
167
+ elif isinstance(excluded_sources, dict):
168
+ for k, v in excluded_sources.items():
169
+ if k == "==":
170
+ df = df >> sfilter(s[v[0]] == s[v[1]])
171
+ elif k == "!=":
172
+ df = df >> sfilter(s[v[0]] != s[v[1]])
173
+ df = df.reset_index(drop=True)
174
+ if isinstance(predicted_var, str):
175
+ df[predicted_var] = (df[predicted_var] - 1) // 2
176
+ df[predicted_var] = df[predicted_var].astype(int)
177
+ if invert_scale is True:
178
+ df[predicted_var] = df[predicted_var].max() - df[predicted_var]
179
+ elif isinstance(predicted_var, list):
180
+ for pv in predicted_var:
181
+ df[pv] = (df[pv] - 1) // 2
182
+ df[pv] = df[pv].astype(int)
183
+ if invert_scale is True:
184
+ df[pv] = df[pv].max() - df[pv]
185
+
186
+ return df
187
+
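+ # A minimal sketch (illustration only, not part of the pipeline) of how
+ # prepare_dataframe below remaps the OIV scale, assuming the usual odd-valued
+ # OIV scores {1, 3, 5, 7, 9}: (score - 1) // 2 yields contiguous class indices
+ # {0..4}; with invert_scale=True the remapped values are then reversed.
+ # >>> import pandas as pd
+ # >>> s = (pd.Series([1, 3, 5, 7, 9]) - 1) // 2
+ # >>> s.to_list()
+ # [0, 1, 2, 3, 4]
+ # >>> (s.max() - s).to_list()
+ # [4, 3, 2, 1, 0]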
188
+
189
+ class OivDetPatches(Dataset):
190
+ def __init__(
191
+ self,
192
+ dataframe,
193
+ transform,
194
+ predicted_var: str = None,
195
+ path_to_images=cc.path_to_leaf_patches,
196
+ ) -> None:
197
+ super().__init__()
198
+ if isinstance(dataframe, pd.DataFrame):
199
+ self.dataframe = dataframe.reset_index(drop=True)
200
+ elif isinstance(dataframe, list):
201
+ self.dataframe = pd.DataFrame(data={"file_name": dataframe})
202
+ # predicted_var is used by __getitem__ for both input types
203
+ self.predicted_var = predicted_var
204
+ self.transform = transform
205
+ self.path_to_images = path_to_images
206
+
207
+ def __len__(self):
208
+ return self.dataframe.shape[0]
209
+
210
+ def __getitem__(self, index):
211
+ img = self.transform(image=self.get_image(index=index))["image"]
212
+ if self.dataframe.shape[1] == 1 or self.predicted_var is None:
213
+ return {"image": img}
214
+ else:
215
+ return {
216
+ "image": img,
217
+ "label": torch.tensor(
218
+ self.dataframe.loc[index, self.predicted_var], dtype=torch.long
219
+ ),
220
+ }
221
+
222
+ def get_resizer(self, to_tensor: bool = False):
223
+ for a in self.transform:
224
+ if isinstance(a, A.Resize):
225
+ if to_tensor is True:
226
+ return A.Compose([a, ToTensorV2()])
227
+ return A.Compose([a])
228
+ else:
229
+ return None
230
+
231
+ def get_image(self, index):
232
+ return ci.load_image(
233
+ file_name=self.dataframe.file_name.to_list()[index],
234
+ path_to_images=self.path_to_images,
235
+ )
236
+
237
+ def get_resized_image(self, index, to_tensor: bool = False):
238
+ t = self.get_resizer(to_tensor=to_tensor)
239
+ if t is not None:
240
+ return t(
241
+ image=ci.load_image(
242
+ file_name=self.dataframe.file_name.to_list()[index],
243
+ path_to_images=self.path_to_images,
244
+ )
245
+ )["image"]
246
+ else:
247
+ return self.get_image(index=index)
248
+
249
+ def get_data(self, index):
250
+ return self.dataframe.iloc[index]
251
+
252
+
253
+ def get_encoder_data(enc_key) -> dict:
254
+ return checkpoints_dict["hf_swt_t" if enc_key == "pretrained" else enc_key]
255
+
256
+
257
+ class OivDetPatchesNet(pl.LightningModule):
258
+ def __init__(
259
+ self,
260
+ batch_size: int,
261
+ learning_rate: float,
262
+ max_epochs: int,
263
+ num_workers,
264
+ accumulate_grad_batches,
265
+ train: pd.DataFrame,
266
+ val: pd.DataFrame,
267
+ test: pd.DataFrame,
268
+ predicted_var: str = "oiv",
269
+ backbone: str = "hf_swt_t",
270
+ data_source: str = "improved_patches_v3",
271
+ augmentations_kinds: list = ["resize", "train", "to_tensor"],
272
+ augmentations_params: dict = {"gamma": (60, 180), "crop": None},
273
+ optimizer: str = "adam",
274
+ scheduler: str = None,
275
+ scheduler_params: dict = {},
276
+ conv_feature_sizes=None,
277
+ linear_features_sizes=[],
278
+ exclude_if_source: list = [],
279
+ weight_loss: bool = False,
280
+ ordinal_regression_model=None,
281
+ monitor_loss: str = "mse",
282
+ skip_linear: bool = False,
283
+ use_sigmoid: bool = False,
284
+ val_monitor_target: str = "val_monitor",
285
+ val_monitor_mode: str = "min",
286
+ salt_name: str = "oiv",
287
+ binary_data: dict = None,
288
+ path_to_images: str = cc.path_to_leaf_patches,
289
+ invert_scale: bool = False,
290
+ ) -> None:
291
+ super().__init__()
292
+
293
+ self.backbone = backbone
294
+ self.conv_feature_sizes = conv_feature_sizes
295
+ self.linear_features_sizes = linear_features_sizes
296
+
297
+ self.predicted_var = predicted_var
298
+ self.invert_scale = invert_scale
299
+ self.model_name = (
300
+ f"{salt_name}_{predicted_var}_{self.backbone_name}_{monitor_loss}"
301
+ )
302
+ if isinstance(ordinal_regression_model, str):
303
+ self.model_name = self.model_name + "_" + ordinal_regression_model
304
+ else:
305
+ self.model_name = self.model_name + "_" + "classic"
306
+ self.short_model_name = f"oiv_{predicted_var}"
307
+
308
+ # dataframes
309
+ self.exclude_if_source = exclude_if_source
310
+ self.train_data = prepare_dataframe(
311
+ df=train,
312
+ excluded_sources=self.exclude_if_source,
313
+ predicted_var=self.predicted_var,
314
+ invert_scale=invert_scale,
315
+ )
316
+ self.val_data = prepare_dataframe(
317
+ df=val,
318
+ excluded_sources=self.exclude_if_source,
319
+ predicted_var=self.predicted_var,
320
+ invert_scale=invert_scale,
321
+ )
322
+ self.test_data = prepare_dataframe(
323
+ df=test,
324
+ excluded_sources=self.exclude_if_source,
325
+ predicted_var=self.predicted_var,
326
+ invert_scale=invert_scale,
327
+ )
328
+ self.data_source = data_source
329
+ self.path_to_images = path_to_images
330
+ self.labels_cardinal = len(self.train_data[self.predicted_var].unique())
331
+ self.binary_data = binary_data
332
+
333
+ # Encoder
334
+ enc_data = get_encoder_data(self.backbone)
335
+ self.encoder = enc_data["class"].from_pretrained(
336
+ enc_data["path"],
337
+ num_labels=self.labels_cardinal,
338
+ problem_type="single_label_classification",
339
+ ignore_mismatched_sizes=True,
340
+ )
341
+
342
+ self.image_size = 224
343
+ self.ordinal_regression_model = ordinal_regression_model
344
+ self.flatten = nn.Flatten()
345
+ self.skip_linear = skip_linear
346
+ self.use_sigmoid = use_sigmoid
347
+ if self.ordinal_regression_model == "corn":
348
+ self.linear_out = nn.Linear(
349
+ in_features=self._get_conv_output_size(self.image_size),
350
+ out_features=self.labels_cardinal - 1,
351
+ )
352
+ else:
353
+ self.linear_out = nn.Linear(
354
+ in_features=self._get_conv_output_size(self.image_size),
355
+ out_features=self.labels_cardinal,
356
+ )
357
+
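+ # Sketch of the CORN head wiring above (assuming labels_cardinal == 5): the
+ # model emits labels_cardinal - 1 = 4 conditional logits per sample, which
+ # coral_pytorch.dataset.corn_label_from_logits decodes back to a class index
+ # by thresholding the cumulative product of the sigmoids at 0.5:
+ # >>> import torch
+ # >>> from coral_pytorch.dataset import corn_label_from_logits
+ # >>> logits = torch.tensor([[2.0, 1.0, -1.0, -3.0]])
+ # >>> corn_label_from_logits(logits)
+ # tensor([2])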
358
+ # Hyperparameters
359
+ self.batch_size = batch_size
360
+ self.selected_device = g_device
361
+ self.learning_rate = learning_rate
362
+ self.start_lr = self.learning_rate
363
+ self.num_workers = num_workers
364
+ self.max_epochs = max_epochs
365
+ self.accumulate_grad_batches = accumulate_grad_batches
366
+ self.weight_loss = weight_loss
367
+ if self.ordinal_regression_model is None:
368
+ if weight_loss is True:
369
+ vc = self.train_data[self.predicted_var].value_counts()
371
+ self.criterion = nn.CrossEntropyLoss(
372
+ weight=torch.FloatTensor(
373
+ [vc[i] / len(self.train_data) for i in range(self.labels_cardinal)]
373
+ )
374
+ )
375
+ else:
376
+ self.criterion = nn.CrossEntropyLoss()
377
+ elif self.ordinal_regression_model == "mse":
378
+ self.criterion = nn.MSELoss()
379
+ elif self.ordinal_regression_model == "mae":
380
+ self.criterion = nn.L1Loss()
381
+ elif self.ordinal_regression_model == "corn":
382
+ self.criterion = corn_loss
383
+
384
+ # Set up attributes for computing the MAE
385
+ self.monitor_loss = monitor_loss
386
+ self.val_monitor_target = val_monitor_target
387
+ self.val_monitor_mode = val_monitor_mode
388
+ if self.monitor_loss == "mse":
389
+ self.train_monitor = torchmetrics.MeanSquaredError()
390
+ self.val_monitor = torchmetrics.MeanSquaredError()
391
+ self.test_monitor = torchmetrics.MeanSquaredError()
392
+ elif self.monitor_loss == "mae":
393
+ self.train_monitor = torchmetrics.MeanAbsoluteError()
394
+ self.val_monitor = torchmetrics.MeanAbsoluteError()
395
+ self.test_monitor = torchmetrics.MeanAbsoluteError()
396
+
397
+ # Optimizer
398
+ self.optimizer = optimizer
399
+ self.scheduler = scheduler
400
+ self.scheduler_params = scheduler_params
401
+
402
+ # albumentations
403
+ self.augmentations_kinds = augmentations_kinds
404
+ self.augmentations_params = augmentations_params
405
+ self.augmentations_params["mean"] = (0.5, 0.5, 0.5)
406
+ self.augmentations_params["std"] = (0.5, 0.5, 0.5)
407
+
408
+ self.train_augmentations = ca.get_augmentations(
409
+ image_size=self.image_size,
410
+ kinds=self.augmentations_kinds,
411
+ **self.augmentations_params,
412
+ )
413
+
414
+ self.val_augmentations = ca.get_augmentations(
415
+ image_size=self.image_size,
416
+ kinds=["resize", "to_tensor"],
417
+ **self.augmentations_params,
418
+ )
419
+
420
+ self._thresholds = None
421
+ self._thresholds_source = None
422
+
423
+ self.save_hyperparameters()
424
+
425
+ def forward(self, x, binary_data=None, *args, **kwargs):
426
+ x = self.encoder(x)
427
+ if hasattr(x, "logits"):
428
+ x = x.logits
429
+ x = self.flatten(x)
430
+ if binary_data is not None:
431
+ x = torch.cat((x, binary_data), dim=1)  # concatenate extra binary features along the feature dim
432
+ if self.linear_out is not None:
433
+ x = self.linear_out(x)
434
+ if self.use_sigmoid:
435
+ x = nn.functional.sigmoid(x)
436
+ return x
437
+
438
+ def hr_desc(self):
439
+ table = Table(title=f"{self.model_name} params & values")
440
+ table.add_column("Param", justify="right", style="bold", no_wrap=True)
441
+ table.add_column("Value")
442
+
443
+ def add_pairs(table_, attributes: list) -> None:
444
+ for a in attributes:
445
+ try:
446
+ table_.add_row(a, Pretty(getattr(self, a)))
447
+ except Exception:
448
+ pass
449
+
450
+ add_pairs(
451
+ table,
452
+ [
453
+ "backbone",
454
+ "predicted_var",
455
+ "invert_scale",
456
+ "skip_linear",
457
+ "use_sigmoid",
458
+ "loss_function",
459
+ "monitor_loss",
460
+ "val_monitor_target",
461
+ "val_monitor_mode",
462
+ "ordinal_regression_model",
463
+ "checkpoint_mode",
464
+ ],
465
+ )
466
+ for k, v in get_encoder_data(self.backbone).items():
467
+ if isinstance(v, str):
468
+ table.add_row(k, Pretty(v))
469
+
470
+ add_pairs(
471
+ table,
472
+ ["batch_size", "image_size", "augmentations_kinds", "augmentations_params"],
473
+ )
474
+
475
+ try:
476
+ if self.backbone == "custom":
477
+ table.add_row(
478
+ "Conv Encoder",
479
+ "\n".join(
480
+ [layer_data.hr_desc() for layer_data in self.conv_feature_sizes]
481
+ ),
482
+ )
483
+ table.add_row(
484
+ "Conv output size", str(self._get_conv_output_size(self.image_size))
485
+ )
486
+ table.add_row("Linear Encoder", self.encoder.lin_encoder.hr_desc())
487
+ except Exception:
488
+ pass
489
+
490
+ eis = str(self.exclude_if_source)
491
+ if ">" in eis:
492
+ eis = (
493
+ eis.split(">")[1]
494
+ .replace("(", "")
495
+ .replace(")", "")
496
+ .replace("_,", "")
497
+ .replace("_.", "")
498
+ )
499
+ table.add_row("exclude_if_source", Pretty(eis))
500
+ table.add_row(
501
+ "path_to_images",
502
+ str(self.path_to_images.relative_to(cc.path_to_root.absolute())),
503
+ )
504
+
505
+ table.add_row(
506
+ "include_if_source",
507
+ str(self.train_data.source.sort_values().unique()),
508
+ )
509
+
510
+ add_pairs(
511
+ table,
512
+ [
513
+ "weight_loss",
514
+ "learning_rate",
515
+ "start_lr",
516
+ "optimizer",
517
+ "scheduler",
518
+ "scheduler_params",
519
+ "val_split",
520
+ ],
521
+ )
522
+
523
+ for name, df in zip(
524
+ ["train", "val", "test"],
525
+ [self.train_data, self.val_data, self.test_data],
526
+ ):
527
+ table.add_row(name, str(df.shape))
528
+
529
+ add_pairs(table_=table, attributes=["data_source"])
530
+
531
+ Console().print(table)
532
+
533
+ def do_test_augmentations(self):
534
+ ca.test_augmentations(
535
+ self.train_data,
536
+ self.image_size,
537
+ kinds=self.augmentations_kinds,
538
+ **self.augmentations_params,
539
+ )
540
+
541
+ def predict_sample(self, sample, device=g_device):
542
+ self.to(device)
543
+ if self.ordinal_regression_model == "coral":
544
+ prediction = proba_to_label(
545
+ torch.sigmoid(self(sample["image"].unsqueeze(0).to(device)))
546
+ )
547
+ elif self.ordinal_regression_model == "corn":
548
+ prediction = corn_label_from_logits(
549
+ self(sample["image"].unsqueeze(0).to(device))
550
+ )
551
+ else:
552
+ prediction = torch.argmax(
553
+ self(sample["image"].unsqueeze(0).to(device)),
554
+ dim=1,
555
+ )
556
+ return prediction.detach().to("cpu").flatten()
557
+
558
+ def predict_image(self, file_path, device=g_device):
559
+ return self.predict_sample(
560
+ self.val_augmentations(
561
+ image=ci.load_image(
562
+ file_path
563
+ if isinstance(file_path, Path)
564
+ else self.path_to_images.joinpath(file_path)
565
+ )
566
+ ),
567
+ device=device,
568
+ )
569
+
570
+ def embed_sample(self, sample, device=g_device):
571
+ self.to(device)
572
+ if self.ordinal_regression_model == "coral":
573
+ raise NotImplementedError
574
+ elif self.ordinal_regression_model == "corn":
575
+ embeddings = torch.sigmoid(self(sample["image"].unsqueeze(0).to(device)))
576
+ else:
577
+ embeddings = torch.sigmoid(self(sample["image"].unsqueeze(0).to(device)))
578
+ return embeddings.detach().to("cpu").flatten().numpy()
579
+
580
+ def embed_image(self, file_path, device=g_device):
581
+ return self.embed_sample(
582
+ self.val_augmentations(
583
+ image=ci.load_image(
584
+ file_path
585
+ if isinstance(file_path, Path)
586
+ else self.path_to_images.joinpath(file_path)
587
+ )
588
+ ),
589
+ device=device,
590
+ )
591
+
592
+ def predict(self, dataset: str = "val", show_progress: bool = True):
593
+ predictions = []
594
+
595
+ self.eval()
596
+ self.to(g_device)
597
+ dataset = self.get_dataset(dataset=dataset)
598
+
599
+ if show_progress is True:
600
+ for sample in tqdm(dataset, desc="Predicting"):
601
+ predictions.append(self.predict_sample(sample=sample))
602
+ else:
603
+ for sample in dataset:
604
+ predictions.append(self.predict_sample(sample=sample))
605
+
606
+ return torch.stack(predictions).detach().cpu().numpy()
607
+
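+ # Usage sketch (hypothetical `model` instance): predict() resolves its
+ # argument through get_dataset below, so it accepts a split name
+ # ("train"/"val"/"test"), a DataFrame with a file_name column, a list of
+ # file names, or an OivDetPatches instance, and returns class indices:
+ # >>> preds = model.predict(dataset="test", show_progress=False)
+ # >>> preds.shape  # (n_samples, 1) numpy array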
608
+ def embed_data(self, dataset, device=g_device, predicted_var=None):
609
+ self.eval()
610
+ self.to(device)
611
+ dataset = self.get_dataset(dataset=dataset, predicted_var=predicted_var)
612
+ ret = pd.DataFrame()
613
+
614
+ for i in tqdm(range(len(dataset))):
615
+ sample = dataset[i]
616
+ embedding = self.embed_sample(sample=sample, device=device)
617
+ ret = pd.concat(
618
+ [
619
+ ret,
620
+ pd.DataFrame(
621
+ data={
622
+ "file_name": dataset.dataframe.file_name.to_list()[i],
623
+ "oiv": sample["label"],
624
+ }
625
+ | {f"Dim {i}": [enc] for i, enc in enumerate(emmbedding)}
626
+ ),
627
+ ]
628
+ )
629
+
630
+ return ret
631
+
632
+ def _get_conv_output_size(self, image_shape):
633
+ batch_size = 3
634
+ tensor_ = self.encoder(
635
+ torch.rand(batch_size, 3, image_shape, image_shape)
636
+ )
637
+ return tensor_.logits.size(1) if hasattr(tensor_, "logits") else tensor_.size(1)
638
+
639
+ def get_dataset(self, dataset: str = "val", predicted_var=None):
640
+ if isinstance(dataset, str):
641
+ dataset = (
642
+ self.val_dataloader().dataset
643
+ if dataset == "val"
644
+ else (
645
+ self.test_dataloader().dataset
646
+ if dataset == "test"
647
+ else self.train_dataloader().dataset
648
+ )
649
+ )
650
+ elif isinstance(dataset, pd.DataFrame):
651
+ return OivDetPatches(
652
+ dataframe=dataset,
653
+ transform=self.val_augmentations,
654
+ path_to_images=self.path_to_images,
655
+ predicted_var=predicted_var,
656
+ )
657
+ elif isinstance(dataset, OivDetPatches):
658
+ return dataset
659
+ elif isinstance(dataset, list):
660
+ return OivDetPatches(dataframe=dataset, transform=self.val_augmentations)
661
+ return dataset
662
+
663
+ def configure_optimizers(self):
664
+ # Optimizer
665
+ if self.optimizer == "adam":
666
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
667
+ elif self.optimizer == "sgd":
668
+ optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
669
+ else:
670
+ optimizer = None
671
+
672
+ # Scheduler
673
+ if self.scheduler == "cycliclr":
674
+ scheduler = torch.optim.lr_scheduler.CyclicLR(
675
+ optimizer,
676
+ base_lr=self.learning_rate,
677
+ max_lr=0.01,
678
+ step_size_up=100,
679
+ mode=self.scheduler_params.get("mode", "triangular"),
680
+ )
681
+ elif self.scheduler == "steplr":
682
+ self.scheduler_params["optimizer"] = optimizer
683
+ scheduler = torch.optim.lr_scheduler.StepLR(**self.scheduler_params)
684
+ self.scheduler_params.pop("optimizer")
685
+ elif self.scheduler == "plateau":
686
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
687
+ optimizer,
688
+ mode="min",
689
+ factor=0.2,
690
+ patience=10,
691
+ min_lr=1e-6,
692
+ )
693
+ scheduler = {"scheduler": scheduler, "monitor": "val_loss"}
694
+ else:
695
+ scheduler = None
696
+ if scheduler is None:
697
+ return optimizer
698
+ else:
699
+ return [optimizer], [scheduler]
700
+
701
+ def train_dataloader(self) -> DataLoader:
702
+ return DataLoader(
703
+ OivDetPatches(
704
+ dataframe=self.train_data,
705
+ transform=self.train_augmentations,
706
+ predicted_var=self.predicted_var,
707
+ path_to_images=self.path_to_images,
708
+ ),
709
+ batch_size=self.batch_size,
710
+ shuffle=True,
711
+ num_workers=self.num_workers,
712
+ pin_memory=True,
713
+ )
714
+
715
+ def val_dataloader(self):
716
+ return DataLoader(
717
+ OivDetPatches(
718
+ dataframe=self.val_data,
719
+ transform=self.val_augmentations,
720
+ predicted_var=self.predicted_var,
721
+ path_to_images=self.path_to_images,
722
+ ),
723
+ batch_size=self.batch_size,
724
+ num_workers=self.num_workers,
725
+ pin_memory=True,
726
+ )
727
+
728
+ def test_dataloader(self):
729
+ return DataLoader(
730
+ OivDetPatches(
731
+ dataframe=self.test_data,
732
+ transform=self.val_augmentations,
733
+ predicted_var=self.predicted_var,
734
+ path_to_images=self.path_to_images,
735
+ ),
736
+ batch_size=self.batch_size,
737
+ num_workers=self.num_workers,
738
+ pin_memory=True,
739
+ )
740
+
741
+ def compute_loss(self, preds, targets):
742
+ return self.criterion(preds, targets)
743
+
744
+ def _shared_step(self, batch):
745
+ x, target = batch["image"], batch["label"]
746
+ logits = self(x)
747
+ if self.ordinal_regression_model == "corn":
748
+ loss = self.criterion(logits, target, num_classes=self.labels_cardinal)
749
+ predicted_labels = corn_label_from_logits(logits)
750
+ else:
751
+ loss = self.compute_loss(logits, target)
752
+ predicted_labels = torch.argmax(logits, dim=1)
753
+ return loss, target, predicted_labels
754
+
755
+ def training_step(self, batch, batch_idx):
756
+ loss, true_labels, predicted_labels = self._shared_step(batch)
757
+ self.log("train_loss", loss)
758
+ self.train_monitor(predicted_labels, true_labels)
759
+ self.log("train_monitor", self.train_monitor, on_epoch=True, on_step=False)
760
+ return loss # this is passed to the optimizer for training
761
+
762
+ def validation_step(self, batch, batch_idx):
763
+ loss, true_labels, predicted_labels = self._shared_step(batch)
764
+ self.log("val_loss", loss)
765
+ self.val_monitor(predicted_labels, true_labels)
766
+ self.log(
767
+ "val_monitor",
768
+ self.val_monitor,
769
+ on_epoch=True,
770
+ on_step=False,
771
+ prog_bar=self.val_monitor_target == "val_monitor",
772
+ )
773
+ return loss
774
+
775
+ def test_step(self, batch, batch_idx):
776
+ loss, true_labels, predicted_labels = self._shared_step(batch)
777
+ self.test_monitor(predicted_labels, true_labels)
778
+ self.log("test_monitor", self.test_monitor, on_epoch=True, on_step=False)
779
+ return loss
780
+
781
+ def check_forward(self, index: int = 0):
782
+ self.eval()
783
+ return self.training_step(
784
+ next(
785
+ iter(
786
+ DataLoader(
787
+ OivDetPatches(
788
+ dataframe=self.val_data,
789
+ transform=self.val_augmentations,
790
+ predicted_var=self.predicted_var,
791
+ path_to_images=self.path_to_images,
792
+ ),
793
+ batch_size=2,
794
+ num_workers=index,
795
+ pin_memory=True,
796
+ )
797
+ )
798
+ ),
799
+ index,
800
+ )
801
+
802
+ def get_trainer(
803
+ self,
804
+ checkpoints_path: Path = cc.path_to_chk_oiv,
805
+ log_every_n_steps: int = 5,
806
+ patience: int = 10,
807
+ patience_min_delta: float = 0.0005,
808
+ ):
809
+ callbacks = [
810
+ RichProgressBar(),
811
+ EarlyStopping(
812
+ monitor=self.val_monitor_target,
813
+ mode=self.val_monitor_mode,
814
+ patience=patience,
815
+ min_delta=patience_min_delta,
816
+ ),
817
+ DeviceStatsMonitor(),
818
+ ModelCheckpoint(
819
+ save_top_k=1,
820
+ monitor=self.val_monitor_target,
821
+ mode=self.val_monitor_mode,
822
+ auto_insert_metric_name=True,
823
+ filename=self.short_model_name
824
+ + "-{val_monitor:.3f}-{epoch}-{val_loss:.3f}-{train_loss:.3f}-{step}",
825
+ ),
826
+ LearningRateMonitor(logging_interval="epoch"),
827
+ ]
828
+ return Trainer(
829
+ default_root_dir=str(checkpoints_path),
830
+ logger=TensorBoardLogger(
831
+ save_dir=str(checkpoints_path),
832
+ version=self.model_name + "_" + dt.now().strftime("%Y%m%d_%H%M%S"),
833
+ name="lightning_logs",
834
+ ),
835
+ accelerator="cpu" if self.selected_device == "cpu" else "gpu",
836
+ max_epochs=self.max_epochs,
837
+ log_every_n_steps=log_every_n_steps,
838
+ callbacks=callbacks,
839
+ accumulate_grad_batches=self.accumulate_grad_batches,
840
+ )
841
+
842
+ def tune_trainer(
843
+ self,
844
+ trainer: Trainer,
845
+ tune_options: list = ["find_lr", "find_bs"],
846
+ find_bs_mode: str = "binsearch",
847
+ ):
848
+ tuner = Tuner(trainer=trainer)
849
+ if "find_lr" in tune_options:
850
+ tuner.lr_find(self)
851
+ if "find_bs" in tune_options:
852
+ tuner.scale_batch_size(model=self, mode=find_bs_mode)
853
+
854
+ @staticmethod
855
+ def short_bin_label(label):
856
+ if label == "sporulation":
857
+ return "sp"
858
+ if label == "necrosis_dots":
859
+ return "nd"
860
+ if label == "necrosis_stains":
861
+ return "nf"
862
+ if label == "necrosis_senescence":
863
+ return "ns"
864
+ if label == "necrosis":
865
+ return "n"
866
+ if label == "stains":
867
+ return "s"
868
+
869
+ @staticmethod
870
+ def name_from_backbone(backbone):
871
+ if isinstance(backbone, str):
872
+ return backbone
873
+ elif isinstance(backbone, dict):
874
+ ret = "bin"
875
+ if "labels" in backbone:
876
+ labels = backbone["labels"]
877
+ if isinstance(labels, str):
878
+ labels = (
879
+ labels.replace("[", "")
880
+ .replace("]", "")
881
+ .replace("'", "")
882
+ .replace('"', "")
883
+ .replace(" ", "")
884
+ .split(",")
885
+ )
886
+ for label in labels:
887
+ ret += "." + OivDetPatchesNet.short_bin_label(label)
888
+ if "max_epochs" in backbone:
889
+ ret += f"_me{backbone['max_epochs']}"
890
+ if "exclude_if_source" in backbone:
891
+ ret += f"_xis{len(backbone['exclude_if_source'])}"
892
+ if "_" not in ret:
893
+ ret += "_max_f1wa"
894
+ return ret
895
+ else:
896
+ raise Exception(f"Unknown backbone type {type(backbone)}")
897
+
898
+ @property
899
+ def labels(self):
900
+ return [1, 3, 5, 7, 9] if self.labels_cardinal == 5 else [1, 5, 9]
901
+
902
+ @property
903
+ def grad_cam_layer(self):
904
+ if isinstance(self.backbone, str):
905
+ return self.encoder.swin.layernorm
906
+ elif isinstance(self.backbone, dict):
907
+ return self.encoder.encoder.swin.layernorm
908
+ else:
909
+ raise Exception(f"Unknown backbone type {type(self.backbone)}")
910
+
911
+ @property
912
+ def relative_path_to_images(self):
913
+ return self.path_to_images.relative_to(cc.path_to_root.absolute())
914
+
915
+ @property
916
+ def backbone_name(self):
917
+ return OivDetPatchesNet.name_from_backbone(self.backbone)
918
+
919
+
920
+ def get_model(path_to_model) -> OivDetPatchesNet:
921
+ return OivDetPatchesNet.load_from_checkpoint(str(path_to_model))
922
+
923
+
924
+ def expand_dict(d: dict) -> pd.DataFrame:
925
+ return pd.DataFrame(data={k: [v] for k, v in d.items()})
926
+
927
+
928
+ def get_model_data(chk, test_data="val"):
929
+ name_data = {"target_var": chk.stem.split("-")[0][4:]} | {
930
+ k: v
931
+ for k, v in {
932
+ kv.split("=")[0]: kv.split("=")[1] for kv in chk.stem.split("-")[1:]
933
+ }.items()
934
+ }
935
+
936
+ model_: OivDetPatchesNet = get_model(str(chk))
937
+
938
+ model_data = {}
939
+ for key in [
940
+ "batch_size",
941
+ "num_workers",
942
+ "max_epochs",
943
+ "accumulate_grad_batches",
944
+ "image_size",
945
+ "augmentations_kinds",
946
+ "augmentations_params",
947
+ "labels",
948
+ "invert_scale",
949
+ "exclude_if_source",
950
+ "learning_rate",
951
+ "start_lr",
952
+ "optimizer",
953
+ "scheduler",
954
+ "data_source",
955
+ "relative_path_to_images",
956
+ "loss_function",
957
+ "checkpoint_monitor",
958
+ "checkpoint_mode",
959
+ "ordinal_regression_model",
960
+ "monitor_loss",
961
+ "use_sigmoid",
962
+ "skip_linear",
963
+ ]:
964
+ try:
965
+ model_data[key] = str(getattr(model_, key))
966
+ except Exception:
967
+ pass
968
+
969
+ eis = str(model_.exclude_if_source)
970
+ if ">" in eis:
971
+ eis = (
972
+ eis.split(">")[1]
973
+ .replace("(", "")
974
+ .replace(")", "")
975
+ .replace("_,", "")
976
+ .replace("_.", "")
977
+ )
978
+ model_data["exclude_if_source"] = eis
979
+
980
+ for k, v in model_.scheduler_params.items():
981
+ model_data[f"sched_{k}"] = v
982
+
983
+ model_data["backbone"] = model_.backbone_name
984
+
985
+ y_hat = model_.predict(dataset=test_data, show_progress=False)
986
+ y = model_.get_dataset(dataset=test_data).dataframe[model_.predicted_var]
987
+ cr = classification_report(
988
+ y_true=y, y_pred=y_hat, output_dict=True, target_names=[1, 3, 5, 7, 9]
989
+ )
990
+
991
+ return expand_dict(
992
+ d={
993
+ "file_name": chk.stem,
994
+ "date": chk.parent.parent.name.split("_")[-2],
995
+ "time": chk.parent.parent.name.split("_")[-1],
996
+ }
997
+ | name_data
998
+ | model_data
999
+ | {
1000
+ "accuracy": cr["accuracy"],
1001
+ "macro avg": cr["macro avg"]["f1-score"],
1002
+ "weighted avg": cr["weighted avg"]["f1-score"],
1003
+ "mse": mean_squared_error(y, y_hat),
1004
+ "rmse": mean_squared_error(y, y_hat, squared=False),
1005
+ "mae": mean_absolute_error(y, y_hat),
1006
+ }
1007
+ | {
1008
+ f"f1-{k}": v["f1-score"]
1009
+ for k, v in cr.items()
1010
+ if isinstance(k, int) is True
1011
+ }
1012
+ | {
1013
+ name: str(df.shape[0])
1014
+ for name, df in zip(
1015
+ ["train_count", "val_count", "test_count"],
1016
+ [model_.train_data, model_.val_data, model_.test_data],
1017
+ )
1018
+ }
1019
+ | {"file_path": str(chk)}
1020
+ )
1021
+
1022
+
1023
+ def update_models_overview(
1024
+ test_data: pd.DataFrame = "val",
1025
+ overview_path=oiv_models_overview_path,
1026
+ checkpoints_path=cc.path_to_chk_oiv,
1027
+ force_reset: bool = False,
1028
+ add_new: bool = True,
1029
+ ):
1030
+ if overview_path.is_file() is True and force_reset is False:
1031
+ models_overview = cf.read_dataframe(overview_path)
1032
+ else:
1033
+ models_overview = pd.DataFrame().assign(file_name=None)
1034
+
1035
+ if add_new is False:
1036
+ return models_overview
1037
+
1038
+ checkpoints = sorted(
1039
+ [
1040
+ chk
1041
+ for chk in checkpoints_path.rglob("*.ckpt")
1042
+ if chk.name.startswith(".") is False
1043
+ ]
1044
+ )
1045
+
1046
+ for chk in tqdm(checkpoints, desc="Building models summaries"):
1047
+ if chk.stem in models_overview.file_name.unique():
1048
+ continue
1049
+ try:
1050
+ new_data = get_model_data(chk, test_data)
1051
+ except Exception as e:
1052
+ print("______________________________")
1053
+ print(chk)
1054
+ print(str(e))
1055
+ # break
1056
+ else:
1057
+ models_overview = pd.concat([models_overview, new_data]).reset_index(
1058
+ drop=True
1059
+ )
1060
+ ret = models_overview.replace(
1061
+ ["False", "True"], [False, True]
1062
+ ).sort_values(["date", "time"]) >> sfilter(
1063
+ s.file_path.isin([str(c) for c in checkpoints])
1064
+ )
1065
+ ret["timestamp"] = pd.to_datetime(
1066
+ pd.to_datetime(ret.date, format="%Y%m%d").astype(str)
1068
+ + " "
1069
+ + pd.to_datetime(ret.time, format="%H%M%S").astype(str)
1069
+ )
1070
+
1071
+ return cf.write_dataframe(df=ret, path=overview_path)
1072
+
1073
+
1074
+ def create_model(
1075
+ train,
1076
+ val,
1077
+ test,
1078
+ augmentations_kinds,
1079
+ augmentations_params,
1080
+ backbone: str = "hf_swt_t",
1081
+ orm: str = "corn",
1082
+ predicted_var="oiv",
1083
+ learning_rate: float = 0.00055,
1084
+ batch_size=400,
1085
+ monitor_loss="mse",
1086
+ scheduler="steplr",
1087
+ scheduler_params={"step_size": 6, "gamma": 0.85},
1088
+ exclude_if_source=[],
1089
+ path_to_images=cc.path_to_leaf_patches,
1090
+ data_source="raw_dataset",
1091
+ conv_feature_sizes=None,
1092
+ linear_features_sizes=[],
1093
+ invert_scale: bool = False,
1094
+ ):
1095
+ return OivDetPatchesNet(
1096
+ backbone=backbone,
1097
+ train=train,
1098
+ val=val,
1099
+ test=test,
1100
+ batch_size=batch_size,
1101
+ learning_rate=learning_rate,
1102
+ num_workers=10,
1103
+ max_epochs=200,
1104
+ predicted_var=predicted_var,
1105
+ accumulate_grad_batches=1,
1106
+ scheduler=scheduler,
1107
+ scheduler_params=scheduler_params,
1108
+ augmentations_kinds=augmentations_kinds,
1109
+ exclude_if_source=exclude_if_source,
1110
+ augmentations_params=augmentations_params,
1111
+ ordinal_regression_model=orm,
1112
+ monitor_loss=monitor_loss,
1113
+ path_to_images=path_to_images,
1114
+ data_source=data_source,
1115
+ conv_feature_sizes=conv_feature_sizes,
1116
+ linear_features_sizes=linear_features_sizes,
1117
+ invert_scale=invert_scale,
1118
+ )
1119
+
1120
+
1121
+ def train_model(
1122
+ train,
1123
+ val,
1124
+ test,
1125
+ augmentations_kinds,
1126
+ augmentations_params,
1127
+ backbone: str = "hf_swt_t",
1128
+ orm: str = "corn",
1129
+ predicted_var="oiv",
1130
+ learning_rate: float = 0.00055,
1131
+ batch_size=400,
1132
+ monitor_loss="mae",
1133
+ scheduler="steplr",
1134
+ scheduler_params={"step_size": 6, "gamma": 0.85},
1135
+ patience=15,
1136
+ exclude_if_source=[],
1137
+ path_to_images=cc.path_to_leaf_patches,
1138
+ data_source="raw_dataset",
1139
+ conv_feature_sizes=None,
1140
+ linear_features_sizes=[],
1141
+ invert_scale: bool = False,
1142
+ checkpoints_path: Path = cc.path_to_chk_oiv,
1143
+ ):
1144
+ model = create_model(
1145
+ backbone=backbone,
1146
+ train=train,
1147
+ val=val,
1148
+ test=test,
1149
+ augmentations_kinds=augmentations_kinds,
1150
+ augmentations_params=augmentations_params,
1151
+ orm=orm,
1152
+ learning_rate=learning_rate,
1153
+ batch_size=batch_size,
1154
+ monitor_loss=monitor_loss,
1155
+ scheduler=scheduler,
1156
+ scheduler_params=scheduler_params,
1157
+ exclude_if_source=exclude_if_source,
1158
+ predicted_var=predicted_var,
1159
+ path_to_images=path_to_images,
1160
+ data_source=data_source,
1161
+ conv_feature_sizes=conv_feature_sizes,
1162
+ linear_features_sizes=linear_features_sizes,
1163
+ invert_scale=invert_scale,
1164
+ )
1165
+ model.hr_desc()
1166
+ trainer = model.get_trainer(
1167
+ patience=patience, log_every_n_steps=1, checkpoints_path=checkpoints_path
1168
+ )
1169
+ trainer.fit(model)
1170
+
1171
+
1172
+ def get_bs(
1173
+ train,
1174
+ val,
1175
+ test,
1176
+ augmentations_kinds,
1177
+ augmentations_params,
1178
+ backbone: str = "hf_swt_t",
1179
+ predicted_var="oiv",
1180
+ orm: str = "corn",
1181
+ batch_size=400,
1182
+ find_bs_mode: str = "binsearch",
1183
+ shrink_factor: float = 1.0,
1184
+ conv_feature_sizes=None,
1185
+ linear_features_sizes=[],
1186
+ ):
1187
+ model_ = create_model(
1188
+ backbone=backbone,
1189
+ train=train,
1190
+ val=val,
1191
+ test=test,
1192
+ augmentations_kinds=augmentations_kinds,
1193
+ augmentations_params=augmentations_params,
1194
+ orm=orm,
1195
+ batch_size=batch_size,
1196
+ conv_feature_sizes=conv_feature_sizes,
1197
+ linear_features_sizes=linear_features_sizes,
1198
+ predicted_var=predicted_var,
1199
+ )
1200
+ trainer = model_.get_trainer(checkpoints_path=cc.path_to_chk_oiv)
1201
+ model_.tune_trainer(
1202
+ trainer=trainer, tune_options=["find_bs"], find_bs_mode=find_bs_mode
1203
+ )
1204
+ return int(model_.batch_size * shrink_factor)
1205
+
1206
+
1207
+ def _inner_get_lr(
1208
+ train,
1209
+ val,
1210
+ test,
1211
+ augmentations_kinds,
1212
+ augmentations_params,
1213
+ backbone: str = "hf_swt_t",
1214
+ predicted_var="oiv",
1215
+ orm: str = "corn",
1216
+ batch_size=400,
1217
+ conv_feature_sizes=None,
1218
+ linear_features_sizes=[],
1219
+ ):
1220
+ model_ = create_model(
1221
+ backbone=backbone,
1222
+ train=train,
1223
+ val=val,
1224
+ test=test,
1225
+ augmentations_kinds=augmentations_kinds,
1226
+ augmentations_params=augmentations_params,
1227
+ orm=orm,
1228
+ batch_size=batch_size,
1229
+ conv_feature_sizes=conv_feature_sizes,
1230
+ linear_features_sizes=linear_features_sizes,
1231
+ predicted_var=predicted_var,
1232
+ )
1233
+ trainer = model_.get_trainer(checkpoints_path=cc.path_to_chk_oiv)
1234
+ model_.tune_trainer(trainer=trainer, tune_options=["find_lr"])
1235
+ return model_.learning_rate
1236
+
1237
+
1238
+ def get_lr(
1239
+ batch_size: int,
1240
+ train,
1241
+ val,
1242
+ test,
1243
+ augmentations_kinds,
1244
+ augmentations_params,
1245
+ backbone: str = "hf_swt_t",
1246
+ lr_times: int = 5,
1247
+ conv_feature_sizes=None,
1248
+ linear_features_sizes=[],
1249
+ predicted_var="oiv",
1250
+ ):
1251
+ lrs = [
1252
+ _inner_get_lr(
1253
+ backbone=backbone,
1254
+ train=train,
1255
+ val=val,
1256
+ test=test,
1257
+ augmentations_kinds=augmentations_kinds,
1258
+ augmentations_params=augmentations_params,
1259
+ batch_size=batch_size,
1260
+ conv_feature_sizes=conv_feature_sizes,
1261
+ linear_features_sizes=linear_features_sizes,
1262
+ predicted_var=predicted_var,
1263
+ )
1264
+ for _ in range(lr_times)
1265
+ ]
1266
+ return sum(lrs) / len(lrs)
src/repo_manager.ipynb ADDED
File without changes