Skip to content

Commit

Permalink
#2725 Make all versions consider NEMO_FUNCTIONS
Browse files Browse the repository at this point in the history
  • Loading branch information
sergisiso committed Oct 1, 2024
1 parent e957c87 commit f663f16
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 37 deletions.
24 changes: 0 additions & 24 deletions examples/nemo/scripts/acc_kernels_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,6 @@
"bdytide_init", "bdy_init", "bdy_segs", "sbc_cpl_init",
"asm_inc_init", "dia_obs_init"] # Str handling, init routine

# Currently fparser has no way of distinguishing array accesses from
# function calls if the symbol is imported from some other module.
# We therefore work-around this by keeping a list of known NEMO
# functions that must be excluded from within KERNELS regions.
NEMO_FUNCTIONS = ["alpha_charn", "cd_neutral_10m", "cpl_freq", "cp_air",
"eos_pt_from_ct", "gamma_moist", "l_vap",
"sbc_dcy", "solfrac", "psi_h", "psi_m", "psi_m_coare",
"psi_h_coare", "psi_m_ecmwf", "psi_h_ecmwf", "q_sat",
"rho_air", "visc_air", "sbc_dcy", "glob_sum",
"glob_sum_full", "ptr_sj", "ptr_sjk", "interp1", "interp2",
"interp3", "integ_spline"]


class ExcludeSettings():
'''
Class to hold settings on what to exclude from OpenACC KERNELS regions.
Expand Down Expand Up @@ -260,17 +247,6 @@ def valid_acc_kernel(node):
"other loops", enode)
return False

# Finally, check that we haven't got any 'array accesses' that are in
# fact function calls.
refs = node.walk(ArrayReference)
for ref in refs:
# Check if this reference has the name of a known function and if that
# reference appears outside said known function.
if ref.name.lower() in NEMO_FUNCTIONS and \
ref.name.lower() != routine_name.lower():
log_msg(routine_name,
f"Loop contains function call: {ref.name}", ref)
return False
return True


Expand Down
3 changes: 1 addition & 2 deletions examples/nemo/scripts/acc_loops_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ def trans(psyir):

# This are functions that are called from inside parallel regions,
# annotate them with 'acc routine'
if subroutine.name.lower().startswith("sign_") or subroutine.name in (
"solfrac", ):
if subroutine.name.lower().startswith("sign_"):
ACCRoutineTrans().apply(subroutine)
print(f"Marked {subroutine.name} as GPU-enabled")
continue
Expand Down
3 changes: 1 addition & 2 deletions examples/nemo/scripts/omp_gpu_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def trans(psyir):

# This are functions that are called from inside parallel regions,
# annotate them with 'omp declare target'
if subroutine.name.lower().startswith("sign_") or subroutine.name in (
"solfrac", ):
if subroutine.name.lower().startswith("sign_"):
OMPDeclareTargetTrans().apply(subroutine)
print(f"Marked {subroutine.name} as GPU-enabled")
# We continue parallelising inside the routine, but this could
Expand Down
33 changes: 24 additions & 9 deletions examples/nemo/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@
"interp1", "interp2", "interp3", "integ_spline", "sbc_dcy",
"sum", "sign_", "ddpdd"]

# Currently fparser has no way of distinguishing array accesses from
# function calls if the symbol is imported from some other module.
# We therefore work-around this by keeping a list of known NEMO functions.
NEMO_FUNCTIONS = ["alpha_charn", "cd_neutral_10m", "cpl_freq", "cp_air",
"eos_pt_from_ct", "gamma_moist", "l_vap",
"sbc_dcy", "solfrac", "psi_h", "psi_m", "psi_m_coare",
"psi_h_coare", "psi_m_ecmwf", "psi_h_ecmwf", "q_sat",
"rho_air", "visc_air", "sbc_dcy", "glob_sum",
"glob_sum_full", "ptr_sj", "ptr_sjk", "interp1", "interp2",
"interp3", "integ_spline"]


VERBOSE = False


Expand Down Expand Up @@ -131,15 +143,18 @@ def enhance_tree_information(schedule):
ArrayType.Extent.ATTRIBUTE,
ArrayType.Extent.ATTRIBUTE,
ArrayType.Extent.ATTRIBUTE]))
elif reference.symbol.name == "sbc_dcy":
# The parser gets this wrong, it is a Call not an Array access
if not isinstance(reference.symbol, RoutineSymbol):
# We haven't already specialised this Symbol.
reference.symbol.specialise(RoutineSymbol)
call = Call.create(reference.symbol)
for child in reference.children:
call.addchild(child.detach())
reference.replace_with(call)
elif reference.symbol.name in NEMO_FUNCTIONS:
if reference.symbol.is_import or reference.symbol.is_unresolved:
# The parser gets these wrong, they are Calls not ArrayRefs
if not isinstance(reference.symbol, RoutineSymbol):
# We haven't already specialised this Symbol.
reference.symbol.specialise(RoutineSymbol)
if not (isinstance(reference.parent, Call) and
reference.parent.routine is reference):
call = Call.create(reference.symbol)
for child in reference.children[:]:
call.addchild(child.detach())
reference.replace_with(call)


def inline_calls(schedule):
Expand Down

0 comments on commit f663f16

Please sign in to comment.