SakuraD commited on
Commit
3ad9c6b
1 Parent(s): 27b5582
Files changed (3) hide show
  1. app.py +8 -2
  2. test_dark.ipynb +33 -0
  3. test_exposure.ipynb +0 -0
app.py CHANGED
@@ -23,14 +23,17 @@ def dark_inference(img):
23
  state_dict = torch.load(checkpoint_file_path, map_location='cpu')
24
  model.load_state_dict(state_dict)
25
  model.eval()
 
26
 
27
  transform = Compose([
28
  ToTensor(),
29
  Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
30
  ConvertImageDtype(torch.float)
31
  ])
 
 
32
 
33
- enhanced_img = model(transform(img).unsqueeze(0))
34
  return enhanced_img[0].permute(1, 2, 0).detach().numpy()
35
 
36
 
@@ -40,13 +43,16 @@ def exposure_inference(img):
40
  state_dict = torch.load(checkpoint_file_path, map_location='cpu')
41
  model.load_state_dict(state_dict)
42
  model.eval()
 
43
 
44
  transform = Compose([
45
  ToTensor(),
46
  ConvertImageDtype(torch.float)
47
  ])
 
 
48
 
49
- enhanced_img = model(transform(img).unsqueeze(0))
50
  return enhanced_img[0].permute(1, 2, 0).detach().numpy()
51
 
52
 
 
23
  state_dict = torch.load(checkpoint_file_path, map_location='cpu')
24
  model.load_state_dict(state_dict)
25
  model.eval()
26
+ print(f'Load model from {checkpoint_file_path}')
27
 
28
  transform = Compose([
29
  ToTensor(),
30
  Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
31
  ConvertImageDtype(torch.float)
32
  ])
33
+ input_img = transform(img)
34
+ print(f'Image shape: {input_img.shape}')
35
 
36
+ enhanced_img = model(input_img.unsqueeze(0))
37
  return enhanced_img[0].permute(1, 2, 0).detach().numpy()
38
 
39
 
 
43
  state_dict = torch.load(checkpoint_file_path, map_location='cpu')
44
  model.load_state_dict(state_dict)
45
  model.eval()
46
+ print(f'Load model from {checkpoint_file_path}')
47
 
48
  transform = Compose([
49
  ToTensor(),
50
  ConvertImageDtype(torch.float)
51
  ])
52
+ input_img = transform(img)
53
+ print(f'Image shape: {input_img.shape}')
54
 
55
+ enhanced_img = model(input_img.unsqueeze(0))
56
  return enhanced_img[0].permute(1, 2, 0).detach().numpy()
57
 
58
 
test_dark.ipynb CHANGED
@@ -246,6 +246,39 @@
246
  "plt.imshow(enhanced_img[0].permute(1, 2, 0).detach().numpy())"
247
  ]
248
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  {
250
  "cell_type": "code",
251
  "execution_count": null,
 
246
  "plt.imshow(enhanced_img[0].permute(1, 2, 0).detach().numpy())"
247
  ]
248
  },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": 9,
252
+ "metadata": {},
253
+ "outputs": [],
254
+ "source": [
255
+ "def dark_inference(img_path):\n",
256
+ " model = IAT()\n",
257
+ " checkpoint_file_path = './checkpoint/best_Epoch_lol.pth'\n",
258
+ " state_dict = torch.load(checkpoint_file_path, map_location='cpu')\n",
259
+ " model.load_state_dict(state_dict)\n",
260
+ " model.eval()\n",
261
+ "\n",
262
+ " img = np.array(Image.open(img_path))\n",
263
+ " transform = Compose([\n",
264
+ " ToTensor(), \n",
265
+ " Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), \n",
266
+ " ConvertImageDtype(torch.float) \n",
267
+ " ])\n",
268
+ "\n",
269
+ " enhanced_img = model(transform(img).unsqueeze(0))\n",
270
+ " return enhanced_img[0].permute(1, 2, 0).detach().numpy()"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": 10,
276
+ "metadata": {},
277
+ "outputs": [],
278
+ "source": [
279
+ "out = dark_inference('./dark_imgs/1.jpg')"
280
+ ]
281
+ },
282
  {
283
  "cell_type": "code",
284
  "execution_count": null,
test_exposure.ipynb CHANGED
The diff for this file is too large to render. See raw diff