Skip to content

Commit 93d1c40

Browse files
authored
Add files via upload
0 parents  commit 93d1c40

File tree

5 files changed

+408
-0
lines changed

5 files changed

+408
-0
lines changed

linCC.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import numpy as np
2+
import colour
3+
import cv2
4+
import matplotlib.pyplot as plt
5+
import linConversions as conv
6+
import os
7+
from PIL import Image
8+
9+
def sRGB_to_lin(rgb):
10+
# Converts array or list of sRGB values to linear RGB values
11+
rgb = np.array(rgb)/255
12+
linear = np.where(rgb<=0.04045, rgb/12.92, ((rgb+0.055)/1.055)**(2.4))
13+
return((linear*255).astype(np.uint8))
14+
15+
def ColourCorrect(image_lin, corrected_img_lin, corrected_img_sRGB, source_lin,
16+
reference_RGB, terms):
17+
18+
# Load in reference data and convert to linear RGB
19+
reference_RGB = np.loadtxt(reference_RGB, delimiter=",",
20+
skiprows=1, usecols=(0, 1, 2))
21+
reference_lin = conv.sRGB_to_lin(reference_RGB)
22+
23+
# Calculating the colour correction matrix and colour correcting the image
24+
# using the reference and source colour values
25+
CCM, colour_corrected=colour.characterisation.colour_correction_Cheung2004(
26+
image_lin, source_lin, reference_lin, terms)
27+
colour_corrected = colour_corrected.astype(int).clip(0,255)
28+
29+
# Converting the linear images to sRGB
30+
image_sRGB = conv.lin_to_sRGB(image_lin)
31+
corrected_sRGB = conv.lin_to_sRGB(colour_corrected)
32+
33+
# Plotting all images
34+
plt.imshow(image_lin)
35+
plt.title("Source linear")
36+
plt.show()
37+
38+
plt.imshow(colour_corrected)
39+
plt.title("Corrected linear")
40+
plt.show()
41+
42+
plt.imshow(image_sRGB)
43+
plt.title("Source sRGB")
44+
plt.show()
45+
46+
plt.imshow(corrected_sRGB)
47+
plt.title("Corrected sRGB")
48+
plt.show()
49+
50+
# Saving corrected images (cv2 works in BGR, so R and B are swapped)
51+
corrected_img_BGR_lin = colour_corrected[:,:,::-1]
52+
corrected_img_BGR_sRGB = corrected_sRGB[:,:,::-1]
53+
cv2.imwrite(corrected_img_lin, corrected_img_BGR_lin)
54+
cv2.imwrite(corrected_img_sRGB, corrected_img_BGR_sRGB)
55+
return(colour_corrected, reference_lin, CCM)
56+
57+
def BatchCorrect(map_path, checker_source_lin, checker_reference_sRGB, terms,
58+
corrected_dir):
59+
os.makedirs(corrected_dir)
60+
for image in sorted(os.listdir(map_path)):
61+
if image.endswith(".jpg") or image.endswith(".png"):
62+
image_name = image.replace(".jpg","")
63+
image_lin = sRGB_to_lin(Image.open(map_path+"/"+image))
64+
ColourCorrect(image_lin, corrected_dir+"/"+image_name+"_cor_lin.jpg",
65+
corrected_dir+"/"+image_name+"_cor_sRGB.jpg",
66+
checker_source_lin, checker_reference_sRGB, terms)

linConversions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
3+
def sRGB_to_lin(rgb):
4+
# Converts array or list of sRGB values to linear RGB values
5+
rgb = np.array(rgb)/255
6+
linear = np.where(rgb<=0.04045, rgb/12.92, ((rgb+0.055)/1.055)**(2.4))
7+
return((linear*255).astype(int))
8+
9+
def lin_to_XYZ(lin):
10+
# Converts array of linear RGB values to XYZ values
11+
lin = np.array(lin)/255
12+
M = np.array([[0.4124, 0.3576, 0.1805],
13+
[0.2126, 0.7152, 0.0722],
14+
[0.0193, 0.1192, 0.9505]])
15+
return(np.transpose(np.matmul(M, np.transpose(lin))))
16+
17+
def lin_to_sRGB(lin):
18+
# Converts array or list of linear RGB values to sRGB values
19+
lin = np.array(lin)/255
20+
rgb = np.where(lin <= 0.0031308, 12.92*lin, (1.055*lin)**(1/2.4)-0.055)
21+
return((rgb*255).astype(int))
22+
23+
def sRGB_to_XYZ(rgb):
24+
# Converts array of sRGB values to XYZ values
25+
rgb = rgb/255
26+
linear = np.where(rgb<=0.04045, rgb/12.92, ((rgb+0.055)/1.055)**2.4)
27+
M = np.array([[0.4124, 0.3576, 0.1805],
28+
[0.2126, 0.7152, 0.0722],
29+
[0.0193, 0.1192, 0.9505]])
30+
return(np.transpose(np.matmul(M, np.transpose(linear))))

linDiffPlots.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import numpy as np
2+
import csv
3+
import matplotlib.pyplot as plt
4+
import colour.models as cm
5+
import skimage.color as sc
6+
import linConversions as conv
7+
8+
SMALL_SIZE = 10*2
9+
MEDIUM_SIZE = 12*2
10+
BIGGER_SIZE = 14*2
11+
12+
plt.rcParams.update({"text.usetex": True,"font.family": "serif",
13+
"font.serif": ["Palatino"]}) # controls default text sizes
14+
plt.rc('axes', titlesize=MEDIUM_SIZE) # fontsize of the axes title
15+
plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
16+
plt.rc('xtick', labelsize=SMALL_SIZE, direction='in') # fontsize of the tick labels
17+
plt.rc('ytick', labelsize=SMALL_SIZE, direction='in') # fontsize of the tick labels
18+
plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize
19+
plt.rc('figure', figsize='15, 6') # size of the figure, used to be '4, 3' in inches
20+
21+
def diffPlot(source_lin, reference_lin, corrected_lin):
22+
source_XYZ = conv.lin_to_XYZ(source_lin)
23+
reference_XYZ = conv.lin_to_XYZ(reference_lin)
24+
corrected_XYZ = conv.lin_to_XYZ(corrected_lin)
25+
26+
source_Lab = sc.xyz2lab(source_XYZ)
27+
reference_Lab = sc.xyz2lab(reference_XYZ)
28+
corrected_Lab = sc.xyz2lab(corrected_XYZ)
29+
30+
dE00_source = sc.deltaE_ciede2000(source_Lab, reference_Lab)
31+
dE00_corrected = sc.deltaE_ciede2000(corrected_Lab, reference_Lab)
32+
33+
diff_source_XYZ = np.sqrt(np.sum((source_XYZ - reference_XYZ)**2, axis=1))*100
34+
diff_corrected_XYZ = np.sqrt(np.sum((corrected_XYZ - reference_XYZ)**2, axis=1))*100
35+
36+
diff_source_lin = np.sqrt(np.sum((np.array(source_lin)/255 - np.array(reference_lin)/255)**2, axis=1))*100
37+
diff_corrected_lin = np.sqrt(np.sum((np.array(corrected_lin)/255 - np.array(reference_lin)/255)**2, axis=1))*100
38+
39+
x = np.arange(len(diff_source_XYZ))+1
40+
width = 0.4
41+
42+
plt.figure()
43+
source_bar_XYZ = plt.bar(x - width/2, diff_source_XYZ, width, label='Source')
44+
corrected_bar_XYZ = plt.bar(x + width/2, diff_corrected_XYZ, width, label='Corrected')
45+
plt.title("Colour distance in XYZ-space")
46+
plt.xlabel("Patch")
47+
plt.ylabel("Distance")
48+
plt.ylim([0,np.max(np.concatenate([diff_source_XYZ,diff_corrected_XYZ]))+5])
49+
plt.vlines(np.arange(5.5, len(diff_source_XYZ),6),0,140, color="red")
50+
plt.legend()
51+
plt.show()
52+
53+
plt.figure()
54+
source_bar_RGB = plt.bar(x - width/2, diff_source_lin, width, label='Source')
55+
corrected_bar_RGB = plt.bar(x + width/2, diff_corrected_lin, width, label='Corrected')
56+
plt.title("Colour distance in RGB-space")
57+
plt.xlabel("Patch")
58+
plt.ylabel("Distance")
59+
plt.ylim([0,np.max(np.concatenate([diff_source_lin,diff_corrected_lin]))+5])
60+
print()
61+
plt.vlines(np.arange(5.5, len(diff_source_lin),6),0,140, color="red")
62+
plt.legend()
63+
plt.show()
64+
65+
plt.figure()
66+
source_bar_dE = plt.bar(x - width/2, dE00_source, width, label='Source')
67+
corrected_bar_dE = plt.bar(x + width/2, dE00_corrected, width, label='Corrected')
68+
plt.title("Colour difference in dE00")
69+
plt.xlabel("Patch")
70+
plt.ylabel("dE00")
71+
plt.xticks(np.arange(1,30.5,1))
72+
plt.vlines(np.arange(0.5, len(dE00_corrected)+0.6,6),0,140, color="red")
73+
plt.ylim([0,np.max(np.concatenate([dE00_source,dE00_corrected]))+5])
74+
plt.xlim(0.5,30.5)
75+
plt.legend()
76+
plt.savefig(r"figures\dE00.pdf")
77+
plt.show()
78+
79+
source_err_XYZ = np.mean(diff_source_XYZ)
80+
corrected_err_XYZ =np.mean(diff_corrected_XYZ)
81+
82+
source_err_lin = np.mean(diff_source_lin)
83+
corrected_err_lin =np.mean(diff_corrected_lin)
84+
85+
dE00_source_avg = np.mean(dE00_source)
86+
dE00_corrected_avg = np.mean(dE00_corrected)
87+
return(source_err_XYZ,corrected_err_XYZ, source_err_lin,corrected_err_lin,dE00_source_avg,dE00_corrected_avg)
88+
89+
def diffChart(reference_lin, corrected_lin, source_lin):
90+
ref_sRGB = conv.lin_to_sRGB(reference_lin)
91+
cor_sRGB = conv.lin_to_sRGB(corrected_lin)
92+
sou_sRGB = conv.lin_to_sRGB(source_lin)
93+
94+
colours2 = np.array([ref_sRGB,cor_sRGB])
95+
colours1 = np.array([ref_sRGB,sou_sRGB])
96+
97+
fig, axs = plt.subplots(nrows=2, sharex=True, figsize=(15,4))
98+
axs[0].set_title('Colour comparisons')
99+
axs[0].imshow(colours1)
100+
axs[0].set_yticks([0,1])
101+
axs[0].set_yticklabels(["reference","source"])
102+
103+
#axs[1].set_title('blue should be down')
104+
axs[1].imshow(colours2)
105+
axs[1].set_yticks([0,1])
106+
axs[1].set_yticklabels(["reference","corrected"])
107+
108+
plt.xticks(np.arange(0,29.5,1),np.arange(1,30.5,1,dtype=int))
109+
plt.xlabel("Patch number")
110+
plt.savefig(r"figures\difference_visualised.pdf")
111+
plt.show()
112+

linMain.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import linMask as m
2+
import linCC as CC
3+
import linDiffPlots as dP
4+
5+
6+
if __name__ == "__main__":
7+
# Put in file names:
8+
# Image with colour checker
9+
img_path = r"img\src1.jpg"
10+
# csv files in which data from image for squares of checker will be stored
11+
source_XYZ_csv = r"data\src1XYZ.csv"
12+
source_lin_csv = r"data\src1RGB.csv"
13+
# csv files in which data from corrected image for squares of checker will
14+
# be stored
15+
corrected_XYZ_csv = r"data\src1CorXYZ.csv"
16+
corrected_lin_csv = r"data\src1CorRGB.csv"
17+
# csv with reference sRGB data of colours on checker
18+
reference_csv_RGB = r"data\reference_RGB.csv"
19+
# Image files in which the corrected images are saved
20+
corrected_img_lin = r"img\src1Corlin.jpg"
21+
corrected_img_sRGB = r"img\src1CorsRGB.jpg"
22+
# map of images to be corrected
23+
map_path = r"img\tobecorrected"
24+
corrected_directory = r"img\tobecorrected\corrected"
25+
26+
27+
# Fill in the shape of the colour checker
28+
n_vertical_swatches = 5
29+
n_horizontal_swatches = 6
30+
total = n_vertical_swatches * n_horizontal_swatches
31+
32+
# Use the horizontal line in the figure to line up the checker horizontally
33+
rotation = -1
34+
35+
# It is optional to crop the image if positioning the mask is too difficult,
36+
# because the colour checker is small in the image
37+
left = 900
38+
top = 1600
39+
right = 800
40+
bottom = 1200
41+
42+
# Choose the width of the mask squares to fit inside the squares of the checker
43+
swatch_width = 130
44+
45+
# Choose the position of the top left mask, using the pixel count on the axes
46+
vertical_start = 140
47+
horizontal_start = 20
48+
49+
# choose the vertical and horizontal steps so that each black mask is inside one
50+
# of the squares on the colour checker
51+
vertical_steps = 210
52+
horizontal_steps = 210
53+
54+
write_csv = 0
55+
correct = 1
56+
compare = 1
57+
batch = 0
58+
59+
# Choose number of correction terms out of:
60+
# [3, 5, 7, 8, 10, 11, 14, 16, 17, 19, 20, 22]
61+
correction_terms = 20
62+
63+
# ============================================================================#
64+
lin = 1
65+
source_lin, source_XYZ, image_lin = m.createMasks(n_vertical_swatches,
66+
n_horizontal_swatches, rotation, left, top, right, bottom,
67+
swatch_width, vertical_start, horizontal_start, vertical_steps,
68+
horizontal_steps, img_path, lin)
69+
70+
if write_csv == 1:
71+
m.writeCSVs(source_lin, source_XYZ, source_XYZ_csv, source_lin_csv)
72+
73+
if correct == 1:
74+
corrected, reference_lin, CCM = CC.ColourCorrect(image_lin, corrected_img_lin,
75+
corrected_img_sRGB, source_lin, reference_csv_RGB,
76+
correction_terms)
77+
78+
if compare == 1:
79+
lin = 0
80+
corrected_lin, corrected_XYZ, cor_image_lin = m.createMasks(
81+
n_vertical_swatches, n_horizontal_swatches, rotation, left,
82+
top, right, bottom, swatch_width, vertical_start,
83+
horizontal_start, vertical_steps,horizontal_steps,
84+
corrected_img_lin, lin)
85+
if write_csv == 1:
86+
m.writeCSVs(corrected_lin, corrected_XYZ, corrected_XYZ_csv,
87+
corrected_lin_csv)
88+
source_err_XYZ,corrected_err_XYZ, source_err_lin,\
89+
corrected_err_lin, dE00_source, dE00_corrected = dP.diffPlot(source_lin,
90+
reference_lin, corrected_lin)
91+
92+
dP.diffChart(reference_lin, corrected_lin, source_lin)
93+
94+
print(f"Average error in XYZ values for source image is {source_err_XYZ:.3f}%")
95+
print(f"Average error in XYZ values for corrected image is {corrected_err_XYZ:.3f}%")
96+
print(f"Average error in RGB values for source image is {source_err_lin:.3f}%")
97+
print(f"Average error in RGB values for corrected image is {corrected_err_lin:.3f}%")
98+
print(f"Average dE00 for source image is {dE00_source:.3f}")
99+
print(f"Average dE00 for corrected image is {dE00_corrected:.3f}")
100+
101+
if batch == 1:
102+
CC.BatchCorrect(map_path, source_lin, reference_csv_RGB, correction_terms,
103+
corrected_directory)
104+
105+
106+
107+
108+

0 commit comments

Comments
 (0)