Skip to content

Commit

Permalink
fix math tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ratnania committed Apr 14, 2019
1 parent 150eb70 commit dc7614c
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 22 deletions.
22 changes: 19 additions & 3 deletions lampy/lambdify.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .utilities import get_decorators
from .utilities import get_pyccel_imports_code
from .utilities import get_dependencies_code
from .utilities import math_atoms_as_str
from .printing import pycode
from .interface import LambdaInterface

Expand Down Expand Up @@ -136,8 +137,25 @@ def _lambdify(func, namespace={}, **kwargs):
return pycode(func)
# ...

# ...
imports = []
# ...

# ... get math functions and constants
math_elements = math_atoms_as_str(func)
math_imports = []
for e in math_elements:
math_imports += [Import(e, 'numpy')]

imports += math_imports

# convert to a string
imports = '\n'.join([pycode(i) for i in imports])
# ...

# ... print python code
code = get_pyccel_imports_code()
code += '\n' + imports + '\n'
code += get_dependencies_code(list(user_functions.values()))
code += '\n\n'
code += pycode(func)
Expand Down Expand Up @@ -168,8 +186,6 @@ def _lambdify(func, namespace={}, **kwargs):
# ...

# we return a module, that will processed by epyccel
if not typed_functions:
raise NotImplementedError('TODO')

# ... module case
from pyccel.epyccel import epyccel
Expand All @@ -184,7 +200,7 @@ def _lambdify(func, namespace={}, **kwargs):
# ####### DEBUG
# return f2py_func

if not typed_functions or not func.is_procedure:
if not func.is_procedure:
return f2py_func
# ...

Expand Down
8 changes: 3 additions & 5 deletions lampy/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .ast import BasicMap
from .ast import PartialFunction
from .ast import LampyLambda
from .ast import FunctionSymbol


#=========================================================================
Expand Down Expand Up @@ -136,13 +137,10 @@ def __init__(self, expr, **kwargs):
# ...

# ... get all functions
calls = list(expr.atoms(AppliedUndef))
map_funcs = [i.args[0] for i in calls if i.__class__.__name__ in _internal_map_functors]
callables = [i.func for i in calls if not i.__class__.__name__ in _internal_functors]
functions = list(set(map_funcs + callables))
functions = list(expr.atoms(FunctionSymbol))

for f in functions:
if str(f) in _elemental_math_functions:
if f.name in _elemental_math_functions:
type_domain = self.default_type
type_codomain = self.default_type

Expand Down
14 changes: 10 additions & 4 deletions lampy/tests/inprogress_math.py → lampy/tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@

#=========================================================
def test_annotate_map_list(**settings):
L = lambda xs: map(sin, xs)
sin = np.sin
l = lambda xs: map(sin, xs)

L = _lambdify( L, **settings )
print(L)
L = _lambdify( l, **settings )

xs = np.linspace(0., np.pi, 100)
out = L(xs)
expected = list(l(xs))
assert(np.allclose( out, expected ))

print('DONE.')

Expand All @@ -34,7 +39,8 @@ def teardown_function():

##########################################
#if __name__ == '__main__':
# settings = {}
## settings = {'ast_only' : True}
# settings = {'printing_only' : True}
## settings = {'printing_only' : True}
#
# test_annotate_map_list(**settings)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from lampy import add, mul

#=========================================================
def test_map_list(**settings):
def test_map_sin_list(**settings):
settings['semantic_only'] = True

L = lambda xs: map(sin, xs)
Expand Down Expand Up @@ -42,8 +42,8 @@ def teardown_function():
from sympy import cache
cache.clear_cache()

##########################################
###########################################
#if __name__ == '__main__':
# settings = {'semantic_only' : True}
#
# test_map_list(**settings)
# test_map_sin_list(**settings)
34 changes: 27 additions & 7 deletions lampy/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
from sympy.core.function import UndefinedFunction
from sympy import sympify
from sympy import Dummy
from sympy import Function
from sympy import preorder_traversal
from sympy import NumberSymbol
from sympy.printing.pycode import _known_functions_math
from sympy.printing.pycode import _known_constants_math

from pyccel.codegen.utilities import get_source_function
from pyccel.ast.datatypes import dtype_and_precsision_registry as dtype_registry
Expand All @@ -24,7 +29,22 @@
from pyccel.ast.datatypes import get_default_value
from pyccel.parser import Parser

#==============================================================================
from .ast import FunctionSymbol

#==========================================================================
def math_atoms_as_str(expr):
math_functions = [str(type(i)) for i in preorder_traversal(expr) if isinstance(i, Function)]
math_functions += [i.name for i in expr.atoms(FunctionSymbol)]
math_functions = [i for i in math_functions if i in _known_functions_math.values()]
math_functions = list(set(math_functions)) # remove redundancies

math_constants = [str(i) for i in preorder_traversal(expr) if isinstance(i, NumberSymbol)]
math_constants = [i for i in math_constants if i in _known_constants_math.values()]
math_constants = list(set(math_constants)) # remove redundancies

return math_functions + math_constants

#==========================================================================
def get_decorators(cls):
target = cls
decorators = {}
Expand All @@ -45,7 +65,7 @@ def visit_FunctionDef(node):
node_iter.visit(ast.parse(inspect.getsource(target)))
return decorators

#==============================================================================
#==========================================================================
def get_pyccel_imports_code():
code = ''
code += '\nfrom pyccel.decorators import types'
Expand All @@ -60,15 +80,15 @@ def get_pyccel_imports_code():

return code

#==============================================================================
#==========================================================================
def get_numpy_imports_code():
code = ''
code += '\nfrom numpy import zeros'
code += '\nfrom numpy import float64'

return code

#==============================================================================
#==========================================================================
def get_dependencies_code(user_functions):
code = ''
for f in user_functions:
Expand All @@ -78,7 +98,7 @@ def get_dependencies_code(user_functions):
return code


#==============================================================================
#==========================================================================
def parse_where_stmt(where_stmt):
"""syntactic parsing of the where statement."""

Expand Down Expand Up @@ -115,7 +135,7 @@ def parse_where_stmt(where_stmt):
return where_stmt


#==============================================================================
#==========================================================================
# TODO move as method of FunctionDef
def get_results_shape(func):
"""returns a dictionary that contains for each result, its shape. When using
Expand Down Expand Up @@ -175,7 +195,7 @@ def get_results_shape(func):
return d_shapes


#==============================================================================
#==========================================================================
def _get_default_value(var, op=None):
"""Returns the default value of a variable depending on its datatype and the
used operation."""
Expand Down

0 comments on commit dc7614c

Please sign in to comment.