Skip to content

Commit

Permalink
Starting to implement curve fitting (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebranlard committed Jun 18, 2020
1 parent 767931b commit 891c8c8
Show file tree
Hide file tree
Showing 6 changed files with 1,122 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ else
endif

testfile=weio/_tests/FASTIn_arf_coords.txt
testfile= ws_01.outb
testfile= TestFit.csv
all:
ifeq ($(detected_OS),Darwin) # Mac OS X
./pythonmac pyDatView.py $(testfile)
Expand Down
4 changes: 3 additions & 1 deletion pydatview/GUIPlotPanel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from .common import *
from .GUICommon import *
from .GUIToolBox import MyMultiCursor, MyNavigationToolbar2Wx
from .GUITools import LogDecToolPanel, MaskToolPanel, RadialToolPanel
from .GUITools import LogDecToolPanel, MaskToolPanel, RadialToolPanel, CurveFitToolPanel
# from spectral import fft_wrap

font = {'size' : 8}
Expand Down Expand Up @@ -478,6 +478,8 @@ def showTool(self,toolName=''):
self.toolPanel=MaskToolPanel(self)
elif toolName=='FASTRadialAverage':
self.toolPanel=RadialToolPanel(self)
elif toolName=='CurveFitting':
self.toolPanel=CurveFitToolPanel(self)
else:
raise Exception('Unknown tool {}'.format(toolName))
self.toolSizer.Add(self.toolPanel, 0, wx.EXPAND|wx.ALL, 5)
Expand Down
289 changes: 288 additions & 1 deletion pydatview/GUITools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import wx
import numpy as np
import pandas as pd

# For log dec tool
from .damping import logDecFromDecay
from .common import CHAR, Error
from .common import CHAR, Error, pretty_num_short
from collections import OrderedDict
from .curve_fitting import model_fit, extract_key_tuples, extract_key_num


# --------------------------------------------------------------------------------}
Expand Down Expand Up @@ -260,3 +263,287 @@ def updateTabList(self,event=None):
self.cbTabs.Clear()
[self.cbTabs.Append(tn) for tn in tabListNames]
self.cbTabs.SetSelection(iSel)


# --------------------------------------------------------------------------------}
# --- Curve Fitting
# --------------------------------------------------------------------------------{
MODELS =[
{'label':'User defined model',
'name':'eval:',
'formula':'{a}*x**2 + {b}',
'coeffs':None,
'consts':None,
'bounds':None },
{'label':'Power law (u,alpha)',
'name':'predef: powerlaw_u_alpha',
'formula':'{u_ref} * (z / {z_ref}) ** {alpha}',
'coeffs': 'u_ref=10, alpha=0.1',
'consts': 'z_ref=100',
'bounds': 'u_ref=(0,inf), alpha=(-1,1)'},
{'label':'Gaussian',
'name':'eval:',
'formula':'1/({sigma}*sqrt(2*pi)) * exp(-1/2 * ((x-{mu})/{sigma})**2)',
'coeffs' :'sigma=1, mu=0',
'consts' :None,
'bounds' :None},
{'label':'Gaussian with y-offset',
'name':'eval:',
'formula':'1/({sigma}*sqrt(2*pi)) * exp(-1/2 * ((x-{mu})/{sigma})**2) + {y0}',
'coeffs' :'sigma=1, mu=0, y0=0',
'consts' :None,
'bounds' :'sigma=(-inf,inf), mu=(-inf,inf), y0=(-inf,inf)'},
{'label':'Exponential decay',
'name':'eval:',
'formula':'{A}*exp(-{k}*x)+{B}',
'coeffs' :'k=1, A=1, B=0',
'consts' :None,
'bounds' :None},
]

class CurveFitToolPanel(GUIToolPanel):
def __init__(self, parent):
super(CurveFitToolPanel,self).__init__(parent)

# Data
self.x = None
self.y_fit = None
#
# GUI Objecst
btClose = self.getBtBitmap(self, 'Close','close', self.destroy)
btClear = self.getBtBitmap(self, 'Clear','sun', self.onClear) # DELETE
btAdd = self.getBtBitmap(self, 'Add','add' , self.onAdd)
btCompFit = self.getBtBitmap(self, 'Fit','check', self.onCurveFit)
btHelp = self.getBtBitmap(self, 'Help','help', self.onHelp)

boldFont = self.GetFont().Bold()
lbOutputs = wx.StaticText(self, -1, 'Model outputs')
lbInputs = wx.StaticText(self, -1, 'Model inputs ')
lbOutputs.SetFont(boldFont)
lbInputs.SetFont(boldFont)

self.textFormula = wx.TextCtrl(self, wx.ID_ANY, '')
self.textGuess = wx.TextCtrl(self, wx.ID_ANY, '')
self.textBounds = wx.TextCtrl(self, wx.ID_ANY, '')
self.textConstants = wx.TextCtrl(self, wx.ID_ANY, '')

self.textFormulaNum = wx.TextCtrl(self, wx.ID_ANY, '', style=wx.TE_READONLY)
self.textCoeffs = wx.TextCtrl(self, wx.ID_ANY, '', style=wx.TE_READONLY)
self.textInfo = wx.TextCtrl(self, wx.ID_ANY, '', style=wx.TE_READONLY)

Models=[d['label'] for d in MODELS]
self.cbModels = wx.ComboBox(self, choices=Models, style=wx.CB_READONLY)
self.cbModels.SetSelection(0)

btSizer = wx.FlexGridSizer(rows=3, cols=2, hgap=2, vgap=0)
btSizer.Add(btClose ,0,flag = wx.ALL|wx.EXPAND, border = 1)
btSizer.Add(btClear ,0,flag = wx.ALL|wx.EXPAND, border = 1)
btSizer.Add(btAdd ,0,flag = wx.ALL|wx.EXPAND, border = 1)
btSizer.Add(btCompFit ,0,flag = wx.ALL|wx.EXPAND, border = 1)
btSizer.Add(btHelp ,0,flag = wx.ALL|wx.EXPAND, border = 1)

# self.lb = wx.StaticText( self, -1, """Select tables, averaging method and average parameter (`Period` methods uses the `azimuth` signal) """)
# self.cbTabs = wx.ComboBox(self, choices=tabListNames, style=wx.CB_READONLY)
# self.cbMethod = wx.ComboBox(self, choices=sAVG_METHODS, style=wx.CB_READONLY)
# self.cbMethod.SetSelection(0)
# self.textAverageParam = wx.TextCtrl(self, wx.ID_ANY, '2',size = (36,-1), style=wx.TE_PROCESS_ENTER)

# vertSizerCB = wx.BoxSizer(wx.VERTICAL)
# vertSizerCB.Add(wx.StaticText(self, -1, 'Model:') ,0, flag = wx.LEFT|wx.EXPAND,border = 1)
# vertSizerCB.Add(self.cbModels ,0, flag = wx.LEFT|wx.EXPAND,border = 1)

inputSizer = wx.FlexGridSizer(rows=5, cols=2, hgap=0, vgap=0)
inputSizer.Add(lbInputs ,0, flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM,border = 1)
inputSizer.Add(self.cbModels ,1, flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM|wx.EXPAND,border = 1)
inputSizer.Add(wx.StaticText(self, -1, 'Formula:') ,0, flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM,border = 1)
inputSizer.Add(self.textFormula ,1, flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM|wx.EXPAND,border = 1)
inputSizer.Add(wx.StaticText(self, -1, 'Guess:') ,0, flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM,border = 1)
inputSizer.Add(self.textGuess ,1, flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM|wx.EXPAND,border = 1)
inputSizer.Add(wx.StaticText(self, -1, 'Bounds:') ,0, flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM,border = 1)
inputSizer.Add(self.textBounds ,1, flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM|wx.EXPAND,border = 1)
inputSizer.Add(wx.StaticText(self, -1, 'Constants:'),0, flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM,border = 1)
inputSizer.Add(self.textConstants ,1, flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM|wx.EXPAND,border = 1)
inputSizer.AddGrowableCol(1,1)

outputSizer = wx.FlexGridSizer(rows=5, cols=2, hgap=0, vgap=0)
outputSizer.Add(lbOutputs ,0 , flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM,border = 1)
outputSizer.Add(wx.StaticText(self, -1, '') ,0 , flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM,border = 1)
outputSizer.Add(wx.StaticText(self, -1, 'Formula:'),0 , flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM,border = 1)
outputSizer.Add(self.textFormulaNum ,1 , flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM|wx.EXPAND,border = 1)
outputSizer.Add(wx.StaticText(self, -1, 'Parameters:') ,0 , flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM,border = 1)
outputSizer.Add(self.textCoeffs ,1 , flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM|wx.EXPAND,border = 1)
outputSizer.Add(wx.StaticText(self, -1, 'Accuracy:') ,0 , flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM,border = 1)
outputSizer.Add(self.textInfo ,1 , flag=wx.ALIGN_LEFT|wx.ALIGN_CENTER_VERTICAL|wx.TOP|wx.BOTTOM|wx.EXPAND,border = 1)
outputSizer.AddGrowableCol(1,0.5)


#print(font.GetFamily(),font.GetStyle(),font.GetPointSize())
#font.SetFamily(wx.FONTFAMILY_DEFAULT)
#font.SetFamily(wx.FONTFAMILY_MODERN)
#font.SetFamily(wx.FONTFAMILY_SWISS)
#font.SetPointSize(8)
#print(font.GetFamily(),font.GetStyle(),font.GetPointSize())
#self.SetFont(font)

horzSizer = wx.BoxSizer(wx.HORIZONTAL)
horzSizer.Add(inputSizer ,1.0, flag = wx.LEFT|wx.EXPAND,border = 2)
horzSizer.Add(outputSizer ,1.0, flag = wx.LEFT|wx.EXPAND,border = 9)

vertSizer = wx.BoxSizer(wx.VERTICAL)
# vertSizer.Add(self.lbHelp ,0, flag = wx.LEFT ,border = 1)
vertSizer.Add(horzSizer ,1, flag = wx.LEFT|wx.EXPAND,border = 1)

self.sizer = wx.BoxSizer(wx.HORIZONTAL)
self.sizer.Add(btSizer ,0, flag = wx.LEFT ,border = 1)
# self.sizer.Add(vertSizerCB ,0, flag = wx.LEFT ,border = 1)
self.sizer.Add(vertSizer ,1, flag = wx.EXPAND|wx.LEFT ,border = 1)
self.SetSizer(self.sizer)

self.Bind(wx.EVT_COMBOBOX, self.onModelChange, self.cbModels)

self.onModelChange()

def onModelChange(self,event=None):
iModel = self.cbModels.GetSelection()
d = MODELS[iModel]

# Formula
if d['name'].find('eval:')==0 :
self.textFormula.Enable(True)
else:
self.textFormula.Enable(False)
self.textFormula.SetValue(d['formula'])

# Guess
if d['coeffs'] is None:
self.textGuess.SetValue('')
else:
self.textGuess.SetValue(d['coeffs'])

# Constants
if d['consts'] is None or len(d['consts'].strip())==0:
self.textConstants.Enable(False)
self.textConstants.SetValue('')
else:
self.textConstants.Enable(True)
self.textConstants.SetValue(d['consts'])

# Bounds
self.textBounds.Enable(True)
if d['bounds'] is None or len(d['bounds'].strip())==0:
self.textBounds.SetValue('all=(-np.inf, np.inf)')
else:
self.textBounds.SetValue(d['bounds'])

# Outputs
self.textFormulaNum.SetValue('(Click on Fit)')
self.textCoeffs.SetValue('')
self.textInfo.SetValue('')

def onCurveFit(self,event=None):
self.x = None
self.y_fit = None
if len(self.parent.plotData)!=1:
Error(self,'Curve fitting tool only works with a single curve. Plot less data.')
return
PD =self.parent.plotData[0]

iModel = self.cbModels.GetSelection()
d = MODELS[iModel]

# Formula
sFunc=d['name']
if sFunc=='eval:':
sFunc+=self.textFormula.GetLineText(0)


# Bounds
bounds=self.textBounds.GetLineText(0).replace('np.inf','inf')

# dBounds=extract_key_tuples(self.textBounds.GetLineText(0).replace('np.inf','inf'))
# if len(dBounds)>0:
# if 'all' in dBounds.keys():
# bounds=dBounds['all']
# else:
# bounds=dBounds
# else:
# bounds=None
#
# Guess
p0=self.textGuess.GetLineText(0).replace('np.inf','inf')
# dGuess=extract_key_num(self.textGuess.GetLineText(0).replace('np.inf','inf'))
# if len(dGuess)>0:
# p0=dGuess
# else:
# p0=None
# Constants
fun_kwargs=extract_key_num(self.textConstants.GetLineText(0).replace('np.inf','inf'))
print('>>> Model fit sFunc :',sFunc )
print('>>> Model fit p0 :',p0 )
print('>>> Model fit bounds:',bounds )
print('>>> Model fit kwargs:',fun_kwargs)
y_fit, pfit, fitter = model_fit(sFunc, PD.x, PD.y, p0=p0, bounds=bounds,**fun_kwargs)

formatter = lambda x: pretty_num_short(x, digits=3)
formula_num = fitter.formula_num(fmt=formatter)
# Update info
self.textFormulaNum.SetValue(formula_num)
self.textCoeffs.SetValue(', '.join(['{}={:s}'.format(k,formatter(v)) for k,v in fitter.model['coeffs'].items()]))
self.textInfo.SetValue('R2 = {:.3f} '.format(fitter.model['R2']))

# Saving
d['formula'] = self.textFormula.GetLineText(0)
d['bounds'] = self.textBounds.GetLineText(0)
d['coeffs'] = self.textGuess.GetLineText(0)
d['consts'] = self.textConstants.GetLineText(0)


# Plot
ax=self.parent.fig.axes[0]
ax.plot(PD.x,y_fit,'o')
self.parent.canvas.draw()

self.x=PD.x
self.y_fit=y_fit
self.sx=PD.sx
self.sy=PD.sy

def onClear(self,event=None):
self.parent.redraw() # DATA HAS CHANGED
self.onModelChange()

def onAdd(self,event=None):
name='model_fit'
if self.x is not None and self.y_fit is not None:
df=pd.DataFrame({self.sx:self.x, self.sy:self.y_fit})
print('Adding>>',df)
self.parent.mainframe.load_df(df,name,bAdd=True)

def onHelp(self,event=None):
print('>>> Help')
# try:
# avgParam = float(self.textAverageParam.GetLineText(0))
# except:
# raise Exception('Error: the averaging parameter needs to be an integer or a float')
# iSel = self.cbTabs.GetSelection()
# avgMethod = AVG_METHODS[self.cbMethod.GetSelection()]
# tabList = self.parent.selPanel.tabList
# mainframe = self.parent.mainframe
# if iSel==0:
# dfs, names, errors = tabList.radialAvg(avgMethod,avgParam)
# mainframe.load_dfs(dfs,names,bAdd=True)
# if len(errors)>0:
# raise Exception('Error: The mask failed on some tables:\n\n'+'\n'.join(errors))
# else:
# dfs, names = tabList.get(iSel-1).radialAvg(avgMethod,avgParam)
# mainframe.load_dfs(dfs,names,bAdd=True)
#
# self.updateTabList()
#
# def updateTabList(self,event=None):
# tabList = self.parent.selPanel.tabList
# tabListNames = ['All opened tables']+tabList.getDisplayTabNames()
# iSel=np.min([self.cbTabs.GetSelection(),len(tabListNames)])
# self.cbTabs.Clear()
# [self.cbTabs.Append(tn) for tn in tabListNames]
# self.cbTabs.SetSelection(iSel)
39 changes: 38 additions & 1 deletion pydatview/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import platform
import datetime
import re

CHAR={
'menu' : u'\u2630',
Expand All @@ -15,7 +16,9 @@
'clear' : u'-',
'sun' : u'\u2600',
'suncloud' : u'\u26C5',
'cloud' : u'\u2601'
'cloud' : u'\u2601',
'check' : u'\u2714',
'help' : u'\u2753'
}
# --------------------------------------------------------------------------------}
# --- ellude
Expand Down Expand Up @@ -115,6 +118,24 @@ def ellude_common(strings,minLength=2):
return strings


# --------------------------------------------------------------------------------}
# --- Key value
# --------------------------------------------------------------------------------{
def extract_key_tuples(text):
"""
all=(0.1,-2),b=(inf,0), c=(-inf,0.3e+10)
"""
regex = re.compile(r'(?P<key>[\w\-]+)=\((?P<value1>[0-9+epinf.-]*?),(?P<value2>[0-9+epinf.-]*?)\)($|,)')
return {match.group("key"): (np.float(match.group("value1")),np.float(match.group("value2"))) for match in regex.finditer(text.replace(' ',''))}


def extract_key_num(text):
"""
all=0.1, b=inf, c=-0.3e+10
"""
regex = re.compile(r'(?P<key>[\w\-]+)=(?P<value>[0-9+epinf.-]*?)($|,)')
return {match.group("key"): np.float(match.group("value")) for match in regex.finditer(text.replace(' ',''))}

# --------------------------------------------------------------------------------}
# ---
# --------------------------------------------------------------------------------{
Expand Down Expand Up @@ -313,6 +334,22 @@ def pretty_num(x):
else:
return '{:.3e}'.format(x)

def pretty_num_short(x,digits=3):
if digits==4:
if abs(x)<1000 and abs(x)>1e-1:
return "{:.4f}".format(x)
else:
return "{:.4e}".format(x)
elif digits==3:
if abs(x)<1000 and abs(x)>1e-1:
return "{:.3f}".format(x)
else:
return "{:.3e}".format(x)
elif digits==2:
if abs(x)<1000 and abs(x)>1e-1:
return "{:.2f}".format(x)
else:
return "{:.2e}".format(x)

# --------------------------------------------------------------------------------}
# --- Chinese characters
Expand Down
Loading

0 comments on commit 891c8c8

Please sign in to comment.