-
Notifications
You must be signed in to change notification settings - Fork 0
/
scatterHist.py
67 lines (52 loc) · 2.24 KB
/
scatterHist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
'''
Created on Sep 15, 2016
@author: urishaham
'''
import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter
def _symmetric_bins(arrays, binwidth):
    """Return ``(lim, edges)`` for a symmetric range ``[-lim, lim]`` covering all finite values.

    ``lim`` is the smallest multiple of *binwidth* strictly greater than the
    largest finite absolute value found in *arrays*. It equals ``binwidth``
    when there are no finite values. ``edges`` are the ``binwidth``-spaced
    histogram bin edges spanning ``[-lim, lim]``.

    Raises:
        ValueError: if *binwidth* is not positive.
    """
    if binwidth <= 0:
        raise ValueError('binwidth must be positive, got %r' % (binwidth,))
    values = np.concatenate([np.ravel(np.asarray(a, dtype=float)) for a in arrays])
    # Drop NaN/inf so one bad sample cannot crash the limit computation
    # (int(nan) raises) or stretch the axes to infinity.
    values = values[np.isfinite(values)]
    xymax = np.max(np.fabs(values)) if values.size else 0.0
    half = int(xymax / binwidth) + 1
    lim = half * binwidth
    # An explicit edge count avoids np.arange's floating-point end-point
    # ambiguity (an extra or missing edge) for widths such as 0.1.
    edges = np.linspace(-lim, lim, 2 * half + 1)
    return lim, edges


def scatterHist(x1, x2, y1, y2, axis1='', axis2='', title='', name1='', name2='',
                plots_dir='', binwidth=0.5):
    """Scatter two 2-D point sets with marginal histograms; optionally save as EPS.

    The first set (``x1``, ``x2``) is drawn in blue and the second set
    (``y1``, ``y2``) in red. The main panel is a scatter plot of both sets.
    The top panel shows density-normalized step histograms of the horizontal
    coordinates (``x1``, ``y1``). The right panel shows those of the vertical
    coordinates (``x2``, ``y2``). Both axes use the same symmetric range,
    which is large enough to contain every finite coordinate of both sets.

    Args:
        x1, x2: horizontal / vertical coordinates of the first (blue) set.
        y1, y2: horizontal / vertical coordinates of the second (red) set.
        axis1, axis2: labels of the scatter panel's horizontal / vertical axes.
        title: title drawn above the top histogram. It is also the file-name
            stem when saving.
        name1, name2: legend entries for the blue / red sets.
        plots_dir: directory to write ``<title>.eps`` into. Nothing is saved
            when it is empty (or None).
        binwidth: histogram bin width. The axis limit is rounded up to a
            multiple of it.

    Raises:
        ValueError: if *binwidth* is not positive.
    """
    # Layout in figure-fraction coordinates: a square scatter panel with thin
    # histogram strips above it and to its right.
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    bottom_h = left_h = left + width + 0.02
    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom_h, width, 0.2]
    rect_histy = [left_h, bottom, 0.2, height]

    fig = plt.figure(figsize=(8, 8))
    axScatter = fig.add_axes(rect_scatter)
    axHistx = fig.add_axes(rect_histx)
    axHisty = fig.add_axes(rect_histy)

    # Keep the artists so the legend labels are bound to them explicitly.
    sc1 = axScatter.scatter(x1, x2, color='blue', s=3)
    sc2 = axScatter.scatter(y1, y2, color='red', s=3)

    # The range must cover BOTH sets. Deriving it from (x1, x2) alone clipped
    # red points and dropped red samples from the histograms.
    lim, bins = _symmetric_bins((x1, x2, y1, y2), binwidth)
    axScatter.set_xlim((-lim, lim))
    axScatter.set_ylim((-lim, lim))

    # 'density' replaces 'normed', which was removed in Matplotlib 3.1.
    hist_kw = dict(bins=bins, density=True, stacked=True, histtype='step')
    for horiz, vert, color in ((x1, x2, 'blue'), (y1, y2, 'red')):
        axHistx.hist(horiz, color=color, **hist_kw)
        axHisty.hist(vert, orientation='horizontal', color=color, **hist_kw)
    axHistx.set_xlim(axScatter.get_xlim())
    axHisty.set_ylim(axScatter.get_ylim())

    # Hide every tick label on the marginal panels. Use a fresh formatter per
    # axis; formatters hold a reference to their axis and should not be shared.
    for ax in (axHistx, axHisty):
        ax.xaxis.set_major_formatter(NullFormatter())
        ax.yaxis.set_major_formatter(NullFormatter())

    axScatter.set_xlabel(axis1, fontsize=18)
    axScatter.set_ylabel(axis2, fontsize=18)
    axHistx.set_title(title, fontsize=18)
    axScatter.legend([sc1, sc2], [name1, name2], fontsize=18)

    # Save before showing so the written file never depends on what the
    # interactive backend does with the figure once it is displayed.
    if plots_dir:
        fig.savefig(os.path.join(plots_dir, title + '.eps'), format='eps')
    plt.show(block=False)