###########################################################
# check min/max extremes
###########################################################
#
# For parameters accumulated from the forecast start, the recomputed hourly extremes are checked
#   - the problem is that the limits for e.g. the 0-3 h period should be very different from
#     those for e.g. 0-120 h, even if the value is divided by the period length in hours!

import traceback
import sys
import operator
from gribapi import *
import argparse

from grib_check_setup import params, def_undef

# When true, main() prints the full traceback of a GribInternalError to
# stderr; otherwise only the error message is written.
VERBOSE = 1  # verbose error reporting


def process_inline_args():
  """Parse the command line into the module-level global `args`.

  Prints the help text and exits with status 1 when no arguments are given.
  Afterwards args.verbosity is an int (0 when -v is not given) and
  args.warning is a bool (False unless -w is given).
  """
  global args
  parser = argparse.ArgumentParser()
  parser.add_argument('inp_file', type=argparse.FileType('r'), nargs='*', help="enter input file name(s)")
  # default=0 replaces the former post-parse "if None" fix-up
  parser.add_argument("-v", "--verbosity", type=int, default=0, help="increase output verbosity [0-2]")
  # store_true already defaults to False, no fix-up needed
  parser.add_argument("-w", "--warning", action="store_true", help="warnings are treated as errors..")
  parser.add_argument("-his", "--history", action="store_true", help="add value ranges history, if exists, to listing..")
  parser.add_argument("-l", "--listing", type=str, nargs="?", const="all", help="list defined parameters (search if argument provided!)")

  args = parser.parse_args()

  # no arguments at all: show usage instead of silently doing nothing
  if not sys.argv[1:]:
    parser.print_help()
    parser.exit(1)


def parse(keyValDef, gribKeysValuesDict=None):
  """Substitute the actual message value for every "key:ANY" item of keyValDef.

  Not called at the moment (its call site is commented out). It can be
  activated for future specific requirements, e.g. to allow only selected
  step ranges to be matched.

  keyValDef          -- list of "gribKey:value" strings
  gribKeysValuesDict -- mapping gribKey -> value read from the GRIB message.
                        check_limits keeps this dict local, so it has to be
                        passed in. A key missing from it is substituted as
                        the string "None".
  Returns a new list: items containing "ANY" become "key:<actual value>";
  all other items are copied unchanged.
  """
  if gribKeysValuesDict is None:
    gribKeysValuesDict = {}

  keyValDefParsed = []
  for x in keyValDef:
    if "ANY" in x:
      key = x.split(":")[0]
      keyValDefParsed.append(key + ":" + str(gribKeysValuesDict.get(key)))
    else:
      keyValDefParsed.append(x)

  if args.verbosity >= 2:
    print("keyValDef = ", keyValDef)
    if keyValDefParsed != keyValDef:
      print("keyValDefParsed = ", keyValDefParsed)

  return keyValDefParsed


def _select_definition(gid, paramId, uniqKeysList):
  """Return the key of params[paramId]["def"] that best matches message gid.

  A definition key is a comma separated list of "gribKey:value" items
  (blanks ignored). It matches when every listed gribKey has exactly that
  value in the message; a key not defined in the message compares as "None".
  Among the matching definitions the one with the most items wins, and
  "default" is returned when nothing matches. Exits with status 1 when two
  matching definitions have the same number of items (ambiguous set-up).
  """
  if len(uniqKeysList) == 1:
    # only one key is used by the definitions: it must be "default"
    assert uniqKeysList[0] == "default",\
      "The default ranges set up must exist for paramId %s! Exiting.." % paramId
    return "default"

  # actual grib key values (as strings) for all keys used by the definitions
  gribKeysValuesDict = {}
  for k in uniqKeysList:
    if grib_is_defined(gid, k):
      gribKeysValuesDict[k] = grib_get(gid, k, str)

  if args.verbosity >= 2:
    print("gribKeysValuesDict = ", gribKeysValuesDict)

  matched = []
  matchedLength = []

  for k in params[paramId]["def"].keys():

    keyValDef = [x.replace(" ", "") for x in k.split(",")]  # ignore blanks

    # the same keys as in the definition, but with the message's values
    gribValues = []
    for item in keyValDef:
      key = item.split(":")[0]
      gribValues.append(key + ":" + str(gribKeysValuesDict.get(key)))

    # NOTE: parse() could substitute "ANY" values here; currently disabled
    keyValDefParsed = keyValDef

    # the definition matches when the message has all its key:value pairs
    if set(gribValues).issubset(set(keyValDefParsed)):
      matched.append(k)
      matchedLength.append(len(keyValDefParsed))

  if not matched:
    # default set up is selected if nothing better matches
    return "default"

  # check if there is any ambiguity what should be matched for paramId
  #   - the number of matched grib key values must be unique
  #   - all defined grib keys must be matched
  if len(matchedLength) > len(set(matchedLength)):
    print("There is an ambiguity what should be matched for paramId:", paramId)
    print("...matched definitions (only 1 must be matched!):")
    longest = max(matchedLength)
    for x, length in enumerate(matchedLength):
      if length == longest:
        print("%4d => %s" % (x, matched[x]))
    print("matched: ", matched)
    print("Check and fix the definition file! Exiting..")
    sys.exit(1)

  # select the best (most specific) match!
  return matched[matchedLength.index(max(matchedLength))]


def check_limits():
  """Check min/max of every GRIB message in args.inp_file against params.

  For each message:
    - its paramId must have definitions in params;
    - for accumulations (typeOfStatisticalProcessing == 1) the extremes are
      divided by the step in hours first;
    - the best matching range definition is selected (_select_definition);
    - the minimum must lie in [minL, minH] and the maximum in [maxL, maxH].
      For statistically processed fields at step 0, both must be 0 instead.
  A message is printed for every violated condition. If any occurred, the
  process exits with status 1.
  """
  paramId_gribKeys = {}  # cache: paramId -> grib keys used by its definitions
  warnings = 0

  # open grib files
  for finp in args.inp_file:

    # go through all grib messages
    while True:

      gid = grib_new_from_file(finp)
      if gid is None:
        break

      paramId = grib_get(gid, 'paramId', str)

      assert paramId in params.keys(), "Definitions for paramId %s not found!" % paramId

      paramName = grib_get(gid, 'name', str)
      fieldCount = grib_get(gid, 'count', str)
      step = grib_get(gid, "step", int)

      minVal = grib_get(gid, "min")
      maxVal = grib_get(gid, "max")

      if args.verbosity >= 1:
        print("\nProcessing file %s, field %s (%s, paramId=%s), minimum: %g, maximum: %g" %
              (finp.name, fieldCount, paramName, paramId, minVal, maxVal))

      # 0 is used below as "typeOfStatisticalProcessing not defined".
      # NOTE(review): confirm 0 is never a meaningful value of this key for
      # the checked data, otherwise such fields skip the step-0 check.
      typeOfStatisticalProcessing = 0
      if grib_is_defined(gid, "typeOfStatisticalProcessing"):
        # Only "Accumulation" in hours starting from the beginning is handled!
        #  - if more general options are needed it can be fine tuned in grib_check_setup.py
        #    (i.e. the key asserts below could be moved to the setup file)

        typeOfStatisticalProcessing = grib_get(gid, "typeOfStatisticalProcessing", int)

        # Accumulation
        if typeOfStatisticalProcessing == 1:
          # typeOfTimeIncrement = 2 [Successive times processed have same start time of forecast, forecast time is incremented]
          assert grib_get(gid, "typeOfTimeIncrement", int) == 2
          # indicatorOfUnitForTimeRange = 1 [Hour]
          assert grib_get(gid, "indicatorOfUnitForTimeRange", int) == 1

          # turn extremes accumulated since the forecast start into hourly ones
          if step != 0:
            minVal = minVal / step
            maxVal = maxVal / step

          if args.verbosity >= 1:
            print("Accumulation (extremes recomputed)!, step %dh, minimum: %g, maximum: %g" % (step, minVal, maxVal))

      if paramId in paramId_gribKeys:
        uniqKeysList = paramId_gribKeys[paramId]
      else:
        # find out which keys will be checked for given paramId
        #  - do it only once for given paramId
        keysValues = [item.replace(" ", "")
                      for k in params[paramId]["def"].keys()
                      for item in k.split(",")]
        uniqKeysList = list(set([x.split(":")[0] for x in keysValues]))
        paramId_gribKeys[paramId] = uniqKeysList

        if args.verbosity >= 1:
          print("...grib keys to be checked for paramId %s: %s" % (paramId, uniqKeysList))

      matchedKey = _select_definition(gid, paramId, uniqKeysList)
      matchedRanges = params[paramId]["def"][matchedKey][0][0]

      #################
      # check min/max!

      minL = matchedRanges[0]
      minH = matchedRanges[1]
      maxL = matchedRanges[2]
      maxH = matchedRanges[3]

      if typeOfStatisticalProcessing != 0 and step == 0:

        if not (minVal == 0 and maxVal == 0):
          warnings += 1
          print("Min/max of accumulated parameter at step 0 must be 0! "
                "(%s, field %s (%s, paramId=%s), matched definition (%s)) minVal=%g, maxVal=%g)" %
                (finp.name, fieldCount, paramName, paramId, matchedKey, minVal, maxVal))

      else:

        if not (minL <= minVal <= minH):
          warnings += 1
          print("warning: %s, field %s (%s, paramId=%s), matched definition (%s) =>  minimum value %f is not in [%f,%f]" %
                (finp.name, fieldCount, paramName, paramId, matchedKey, minVal, minL, minH))

        if not (maxL <= maxVal <= maxH):
          warnings += 1
          print("warning: %s, field %s (%s, paramId=%s), matched definition (%s) =>  maximum value %f is not in [%f,%f]" %
                (finp.name, fieldCount, paramName, paramId, matchedKey, maxVal, maxL, maxH))

        if args.verbosity >= 1:
          print("...matched definition (%s) =>  minimum ranges [%g,%g], maximum ranges [%g,%g]" %
                (matchedKey, minL, minH, maxL, maxH))

      grib_release(gid)

    finp.close()

  if warnings > 0:
    if args.verbosity >= 1:
      print("...number of limit exceeding warnings: %d\n" % warnings)
    # NOTE(review): the -w/--warning option is not consulted here; any
    # warning fails the run. Confirm that this is intended.
    sys.exit(1)


def listing():
  """Print the parameter definitions selected by args.listing.

  With args.listing == "all" every paramId is printed, ordered by name.
  Otherwise two groups are printed:
    - paramIds containing args.listing, ordered by paramId;
    - paramIds whose name contains it (case-insensitive), ordered by name.
  A blank line separates the groups when the first one is not empty.
  """
  pattern = args.listing
  byId = {}
  byName = {}

  if pattern == "all":
    byName = {pid: params[pid]["name"] for pid in params.keys()}
  else:
    needle = pattern.lower()
    for pid in params.keys():
      if pattern in pid:
        byId[pid] = params[pid]["name"]
      name = params[pid]["name"]
      if needle in name.lower():
        byName[pid] = name

  idRows = sorted(byId.items(), key=operator.itemgetter(0))
  printOut(idRows, params)
  if idRows:
    print("")

  nameRows = sorted(byName.items(), key=operator.itemgetter(1))
  printOut(nameRows, params)


def printOut(sortedKeys, params):
  """Print the range definitions of every (paramId, name) pair in sortedKeys.

  For each paramId, the ranges params[paramId]["def"][validFor][0][0] are
  printed as [min <s[0],s[1]> max <s[2],s[3]>] for every definition key
  validFor. The "default" definition is printed first, prefixed with
  "paramId   [name]". The other definitions follow, indented to the width of
  that prefix, or with the full prefix when args.history is set. Ranges equal
  to def_undef are flagged as undefined. With args.history, the version
  history of each definition is printed after it.
  """
  format1 = "%s   [%s]   [min <%g,%g> max <%g,%g>] (%s) %s"
  format2 = "%s   [min <%g,%g> max <%g,%g>] (%s) %s"
  undef_txt = "!!RANGES ARE UNDEFINED!!"

  for x in sortedKeys:
    paramId, name = x[0], x[1]
    defs = params[paramId]["def"]
    # Blanks as wide as this paramId's "paramId   [name]" prefix. Computed
    # per paramId so it never depends on which definition comes first, nor
    # on a previously printed paramId.
    empty_string = ' ' * len("%s   [%s]" % (paramId, name))
    # "default" first; the stable sort keeps the other definitions in order
    for validFor in sorted(defs.keys(), key=lambda k: k != "default"):
      s = defs[validFor][0][0]
      undef = undef_txt if s == def_undef else ""
      ranges = (float(s[0]), float(s[1]), float(s[2]), float(s[3]))
      if validFor == "default" or args.history:
        print(format1 % ((paramId, name) + ranges + (validFor, undef)))
      else:
        print(format2 % ((empty_string,) + ranges + (validFor, undef)))
      if args.history:
        history(paramId, validFor)

def history(paramId, validFor):
  """Print the version history of params[paramId]["def"][validFor].

  Each entry s is printed with its ranges s[0][0..3] (min <s[0][0],s[0][1]>,
  max <s[0][2],s[0][3]>), "valid from" s[1] and info s[2]. Entries are
  numbered from len(defs) down to 1 in list order. An unknown paramId
  prints a "not found" note instead.
  """
  if paramId in params.keys():

    defs = params[paramId]["def"][validFor]

    print("")
    print("# History of changes: ")
    print("#")

    for i, s in enumerate(defs):
      # number by position: list.index() would return the first *equal*
      # entry, mis-numbering duplicated versions (and is O(n) per call)
      print("#  %3d.version (valid from: %s)" % (len(defs) - i, s[1]))
      print("#     min <%g,%g>" % (float(s[0][0]), float(s[0][1])))
      print("#     max <%g,%g>" % (float(s[0][2]), float(s[0][3])))
      print("#        info:%s" % s[2])

  else:
    print("Definitions for paramId %s not found" % paramId)

  print("")


def main():
  """Parse the command line, then either list definitions or check limits.

  Returns 1 when a GribInternalError is raised. Depending on VERBOSE, the
  full traceback or only the error message is written to stderr. Returns
  None (exit status 0) otherwise.
  """
  try:
    process_inline_args()
    action = check_limits if args.listing is None else listing
    action()
  except GribInternalError as err:
    if not VERBOSE:
      sys.stderr.write(err.msg)
    else:
      traceback.print_exc(file=sys.stderr)
    return 1

if __name__ == "__main__":
  # main()'s return value (None or 1) becomes the process exit status
  raise SystemExit(main())
