-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplotting.py
More file actions
101 lines (79 loc) · 3.07 KB
/
plotting.py
File metadata and controls
101 lines (79 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import matplotlib.pyplot as plt
from utils import get_stylized_rgb
def plot_and_save(
    stylized_images,
    content_images,
    style_images,
    output_path="output.png",
    figsize=(11, 10),
):
    """Draw content/style/stylized triplets in a grid, save the figure, and show it.

    Row ``i`` shows ``content_images[i]``, ``style_images[i]`` and
    ``get_stylized_rgb(stylized_images[i])`` side by side. Column titles are
    placed on the first row only, and every axis is hidden.

    Args:
        stylized_images: Sequence of stylized outputs. Each element is passed
            through ``get_stylized_rgb`` before display. Its length sets the
            number of grid rows.
        content_images: Content images, indexed in parallel with
            ``stylized_images``.
        style_images: Style images, indexed in parallel with
            ``stylized_images``.
        output_path: File the figure is written to. Defaults to
            ``"output.png"``, the name this function always used.
        figsize: Matplotlib figure size in inches, ``(width, height)``.

    Raises:
        ValueError: If ``stylized_images`` is empty.
    """
    # len() rather than truthiness: the inputs may be array-like batches, for
    # which bool() is ambiguous.
    n_rows = len(stylized_images)
    if n_rows == 0:
        raise ValueError("stylized_images must contain at least one image")

    # squeeze=False keeps ``axs`` 2-D even when n_rows == 1, so a single code
    # path covers both the one-row and multi-row layouts.
    fig, axs = plt.subplots(n_rows, 3, figsize=figsize, squeeze=False)

    for ax, title in zip(axs[0], ("Original Image", "Style Image", "Transferred Image")):
        ax.set_title(title)

    for i, row in enumerate(axs):
        row[0].imshow(content_images[i])
        row[1].imshow(style_images[i])
        row[2].imshow(get_stylized_rgb(stylized_images[i]))
        # Turning off axis coords
        for ax in row:
            ax.axis("off")

    fig.tight_layout()
    fig.savefig(output_path)
    plt.show()
def save_stylized_outputs(stylized_images, content_images, style_images, output_dir="images/output"):
    """Write one content/style/stylized comparison figure per image to disk.

    For each index ``i`` a 1x3 figure is saved as ``<output_dir>/<i + 1>.png``.
    The figure shows ``content_images[i]``, ``style_images[i]`` and
    ``get_stylized_rgb(stylized_images[i])`` with titles and hidden axes.

    Args:
        stylized_images: Sequence of stylized outputs. Each element is passed
            through ``get_stylized_rgb`` before display.
        content_images: Content images, indexed in parallel with
            ``stylized_images``.
        style_images: Style images, indexed in parallel with
            ``stylized_images``.
        output_dir: Directory for the PNG files. It is created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    titles = ("Original Image", "Style Image", "Transferred Image")
    for i, stylized in enumerate(stylized_images):
        fig, axs = plt.subplots(1, 3, figsize=(11, 4))
        try:
            axs[0].imshow(content_images[i])
            axs[1].imshow(style_images[i])
            axs[2].imshow(get_stylized_rgb(stylized))
            for ax, title in zip(axs, titles):
                ax.set_title(title)
                ax.axis("off")
            fig.tight_layout()
            # 1-based file names: the first image is saved as "1.png".
            fig.savefig(os.path.join(output_dir, f"{i + 1}.png"))
        finally:
            # Release the figure even when drawing or saving raises, so a
            # failure partway through a batch does not leave figures open.
            plt.close(fig)
def plot_for_streamlit(stylized_images, content_images, style_images, figsize=(11, 15)):
    """Build (but do not save or show) a content/style/stylized comparison grid.

    Row ``i`` shows ``content_images[i]``, ``style_images[i]`` and
    ``get_stylized_rgb(stylized_images[i])``. Column titles are placed on the
    first row only, and every axis is hidden. The figure is returned so the
    caller can render it.

    Args:
        stylized_images: Sequence of stylized outputs. Each element is passed
            through ``get_stylized_rgb`` before display. Its length sets the
            number of grid rows.
        content_images: Content images, indexed in parallel with
            ``stylized_images``.
        style_images: Style images, indexed in parallel with
            ``stylized_images``.
        figsize: Matplotlib figure size in inches, ``(width, height)``.

    Returns:
        The matplotlib ``Figure`` containing the grid.

    Raises:
        ValueError: If ``stylized_images`` is empty.
    """
    # len() rather than truthiness: the inputs may be array-like batches, for
    # which bool() is ambiguous.
    n_rows = len(stylized_images)
    if n_rows == 0:
        raise ValueError("stylized_images must contain at least one image")

    # squeeze=False keeps ``axs`` 2-D even when n_rows == 1, so a single code
    # path covers both the one-row and multi-row layouts.
    fig, axs = plt.subplots(n_rows, 3, figsize=figsize, squeeze=False)

    for ax, title in zip(axs[0], ("Original Image", "Style Image", "Transferred Image")):
        ax.set_title(title)

    for i, row in enumerate(axs):
        row[0].imshow(content_images[i])
        row[1].imshow(style_images[i])
        row[2].imshow(get_stylized_rgb(stylized_images[i]))
        for ax in row:
            ax.axis("off")

    fig.tight_layout()
    return fig