# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes that define arguments for populating ArgumentParser.

The argparse module's ArgumentParser.add_argument() takes several parameters and
is quite customizable. However this can lead to bugs where arguments do not
behave as expected.

For better ease-of-use and better testability, define a set of classes for the
types of flags used by LLM Magics.

Sample usage:

  str_flag = SingleValueFlagDef(name="title", required=True)
  enum_flag = EnumFlagDef(name="colors", required=True, enum_type=ColorsEnum)

  str_flag.add_argument_to_parser(my_parser)
  enum_flag.add_argument_to_parser(my_parser)
"""
from __future__ import annotations

import abc
import argparse
import dataclasses
import enum
from typing import Any, Callable, Sequence, Union, Tuple

from google.generativeai.notebook.lib import llmfn_inputs_source
from google.generativeai.notebook.lib import llmfn_outputs

# These are the intermediate types that argparse.ArgumentParser.parse_args()
# will pass command line arguments into.
_PARSETYPES = Union[str, int, float]
# These are the final result types that the intermediate parsed values will be
# converted into. It is a superset of _PARSETYPES because we support converting
# the parsed type into a more precise type, e.g. from str to Enum.
_DESTTYPES = Union[
    _PARSETYPES,
    enum.Enum,
    Tuple[str, Callable[[str, str], Any]],  # For --compare_fn
    Sequence[str],  # For --ground_truth
    llmfn_inputs_source.LLMFnInputsSource,  # For --inputs
    llmfn_outputs.LLMFnOutputsSink,  # For --outputs
]

# The signature of a function that converts a command line argument from the
# intermediate parsed type to the result type.
_PARSEFN = Callable[[_PARSETYPES], _DESTTYPES]


def _get_type_name(x: type[Any]) -> str:
  try:
    return x.__name__
  except AttributeError:
    return str(x)


def _validate_flag_name(name: str) -> str:
  """Validation for long and short names for flags."""
  if not name:
    raise ValueError("Cannot be empty")
  if name[0] == "-":
    raise ValueError("Cannot start with dash")
  return name


@dataclasses.dataclass(frozen=True)
class FlagDef(abc.ABC):
  """Abstract base class for flag definitions.

  Attributes:
    name: Long name, e.g. "colors" will define the flag "--colors".
    required: Whether the flag must be provided on the command line.
    short_name: Optional short name.
    parse_type: The type that ArgumentParser should parse the command line
      argument to.
    dest_type: The type that the parsed value is converted to. This is used when
      we want ArgumentParser to parse as one type, then convert to a different
      type. E.g. for enums we parse as "str" then convert to the desired enum
      type in order to provide cleaner help messages.
    parse_to_dest_type_fn: If provided, this function will be used to convert
      the value from `parse_type` to `dest_type`. This can be used for
      validation as well.
    choices: If provided, limit the set of acceptable values to these choices.
    help_msg: If provided, adds help message when -h is used in the command
      line.
  """

  name: str
  required: bool = False

  short_name: str | None = None

  parse_type: type[_PARSETYPES] = str
  dest_type: type[_DESTTYPES] | None = None
  parse_to_dest_type_fn: _PARSEFN | None = None

  choices: list[_PARSETYPES] | None = None
  help_msg: str | None = None

  @abc.abstractmethod
  def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
    """Adds this flag as an argument to `parser`.

    Child classes should implement this as a call to parser.add_argument()
    with the appropriate parameters.

    Args:
      parser: The parser to which this argument will be added.
    """

  @abc.abstractmethod
  def _do_additional_validation(self) -> None:
    """For child classes to do additional validation."""

  def _get_dest_type(self) -> type[_DESTTYPES]:
    """Returns the final converted type."""
    return self.parse_type if self.dest_type is None else self.dest_type

  def _get_parse_to_dest_type_fn(
      self,
  ) -> _PARSEFN:
    """Returns a function to convert from parse_type to dest_type."""
    if self.parse_to_dest_type_fn is not None:
      return self.parse_to_dest_type_fn

    dest_type = self._get_dest_type()
    if dest_type == self.parse_type:
      return lambda x: x
    else:
      return dest_type

  def __post_init__(self):
    _validate_flag_name(self.name)
    if self.short_name is not None:
      _validate_flag_name(self.short_name)

    self._do_additional_validation()


def _has_non_default_value(
    namespace: argparse.Namespace,
    dest: str,
    has_default: bool = False,
    default_value: Any = None,
) -> bool:
  """Returns true if `namespace.dest` is set to a non-default value.

  Args:
    namespace: The Namespace that is populated by ArgumentParser.
    dest: The attribute in the Namespacde to be populated.
    has_default: "None" is a valid default value so we use an additional
      `has_default` boolean to indicate that `default_value` is present.
    default_value: The default value to use when `has_default` is True.

  Returns:
    Whether namespace.dest is set to something other than the default value.
  """
  if not hasattr(namespace, dest):
    return False

  if not has_default:
    # No default value provided so `namespace.dest` cannot possibly be equal to
    # the default value.
    return True

  return getattr(namespace, dest) != default_value


class _SingleValueStoreAction(argparse.Action):
  """Custom Action for storing a value in an argparse.Namespace.

  This action checks that the flag is specified at-most once.
  """

  def __init__(
      self,
      option_strings,
      dest,
      dest_type: type[Any],
      parse_to_dest_type_fn: _PARSEFN,
      **kwargs,
  ):
    super().__init__(option_strings, dest, **kwargs)
    self._dest_type = dest_type
    self._parse_to_dest_type_fn = parse_to_dest_type_fn

  def __call__(
      self,
      parser: argparse.ArgumentParser,
      namespace: argparse.Namespace,
      values: str | Sequence[Any] | None,
      option_string: str | None = None,
  ):
    # Because `nargs` is set to 1, `values` must be a Sequence, rather
    # than a string.
    assert not isinstance(values, str) and not isinstance(values, bytes)

    if _has_non_default_value(
        namespace,
        self.dest,
        has_default=hasattr(self, "default"),
        default_value=getattr(self, "default"),
    ):
      raise argparse.ArgumentError(
          self, "Cannot set {} more than once".format(option_string)
      )

    try:
      converted_value = self._parse_to_dest_type_fn(values[0])
    except Exception as e:
      raise argparse.ArgumentError(
          self,
          'Error with value "{}", got {}: {}'.format(
              values[0], _get_type_name(type(e)), e
          ),
      )

    if not isinstance(converted_value, self._dest_type):
      raise RuntimeError(
          "Converted to wrong type, expected {} got {}".format(
              _get_type_name(self._dest_type),
              _get_type_name(type(converted_value)),
          )
      )
    setattr(namespace, self.dest, converted_value)


class _MultiValuesAppendAction(argparse.Action):
  """Custom Action for appending values in an argparse.Namespace.

  This action checks that the flag is specified at-most once.
  """

  def __init__(
      self,
      option_strings,
      dest,
      dest_type: type[Any],
      parse_to_dest_type_fn: _PARSEFN,
      **kwargs,
  ):
    super().__init__(option_strings, dest, **kwargs)
    self._dest_type = dest_type
    self._parse_to_dest_type_fn = parse_to_dest_type_fn

  def __call__(
      self,
      parser: argparse.ArgumentParser,
      namespace: argparse.Namespace,
      values: str | Sequence[Any] | None,
      option_string: str | None = None,
  ):
    # Because `nargs` is set to "+", `values` must be a Sequence, rather
    # than a string.
    assert not isinstance(values, str) and not isinstance(values, bytes)

    curr_value = getattr(namespace, self.dest)
    if curr_value:
      raise argparse.ArgumentError(
          self, "Cannot set {} more than once".format(option_string)
      )

    for value in values:
      try:
        converted_value = self._parse_to_dest_type_fn(value)
      except Exception as e:
        raise argparse.ArgumentError(
            self,
            'Error with value "{}", got {}: {}'.format(
                values[0], _get_type_name(type(e)), e
            ),
        )

      if not isinstance(converted_value, self._dest_type):
        raise RuntimeError(
            "Converted to wrong type, expected {} got {}".format(
                self._dest_type, type(converted_value)
            )
        )
      if converted_value in curr_value:
        raise argparse.ArgumentError(
            self, 'Duplicate values "{}"'.format(value)
        )

      curr_value.append(converted_value)


class _BooleanValueStoreAction(argparse.Action):
  """Custom Action for setting a boolean value in argparse.Namespace.

  The boolean flag expects the default to be False and will set the value to
  True.
  This action checks that the flag is specified at-most once.
  """

  def __init__(
      self,
      option_strings,
      dest,
      **kwargs,
  ):
    super().__init__(option_strings, dest, **kwargs)

  def __call__(
      self,
      parser: argparse.ArgumentParser,
      namespace: argparse.Namespace,
      values: str | Sequence[Any] | None,
      option_string: str | None = None,
  ):
    if _has_non_default_value(
        namespace,
        self.dest,
        has_default=True,
        default_value=False,
    ):
      raise argparse.ArgumentError(
          self, "Cannot set {} more than once".format(option_string)
      )

    setattr(namespace, self.dest, True)


@dataclasses.dataclass(frozen=True)
class SingleValueFlagDef(FlagDef):
  """Definition for a flag that takes a single value.

  Sample usage:
    # This defines a flag that can be specified on the command line as:
    #   --count=10
    flag = SingleValueFlagDef(name="count", parse_type=int, required=True)
    flag.add_argument_to_parser(argument_parser)

  Attributes:
    default_value: Default value for optional flags.
  """

  class _DefaultValue(enum.Enum):
    """Special value to represent "no value provided".

    "None" can be used as a default value, so in order to differentiate between
    "None" and "no value provided", create a special value for "no value
    provided".
    """

    NOT_SET = None

  default_value: _DESTTYPES | _DefaultValue | None = _DefaultValue.NOT_SET

  def _has_default_value(self) -> bool:
    """Returns whether `default_value` has been provided."""
    return self.default_value != SingleValueFlagDef._DefaultValue.NOT_SET

  def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
    args = ["--" + self.name]
    if self.short_name is not None:
      args += ["-" + self.short_name]

    kwargs = {}
    if self._has_default_value():
      kwargs["default"] = self.default_value
    if self.choices is not None:
      kwargs["choices"] = self.choices
    if self.help_msg is not None:
      kwargs["help"] = self.help_msg

    parser.add_argument(
        *args,
        action=_SingleValueStoreAction,
        type=self.parse_type,
        dest_type=self._get_dest_type(),
        parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(),
        required=self.required,
        nargs=1,
        **kwargs,
    )

  def _do_additional_validation(self) -> None:
    if self.required:
      if self._has_default_value():
        raise ValueError("Required flags cannot have default value")
    else:
      if not self._has_default_value():
        raise ValueError("Optional flags must have a default value")

    if self._has_default_value() and self.default_value is not None:
      if not isinstance(self.default_value, self._get_dest_type()):
        raise ValueError(
            "Default value must be of the same type as the destination type"
        )


class EnumFlagDef(SingleValueFlagDef):
  """Definition for a flag that takes a value from an Enum.

  Sample usage:
    # This defines a flag that can be specified on the command line as:
    #   --color=red
    flag = SingleValueFlagDef(name="color", enum_type=ColorsEnum,
                              required=True)
    flag.add_argument_to_parser(argument_parser)
  """

  def __init__(self, *args, enum_type: type[enum.Enum], **kwargs):
    if not issubclass(enum_type, enum.Enum):
      raise TypeError('"enum_type" must be of type Enum')

    # These properties are set by "enum_type" so don"t let the caller set them.
    if "parse_type" in kwargs:
      raise ValueError(
          'Cannot set "parse_type" for EnumFlagDef; set "enum_type" instead'
      )
    kwargs["parse_type"] = str

    if "dest_type" in kwargs:
      raise ValueError(
          'Cannot set "dest_type" for EnumFlagDef; set "enum_type" instead'
      )
    kwargs["dest_type"] = enum_type

    if "choices" in kwargs:
      # Verify that entries in `choices` are valid enum values.
      for x in kwargs["choices"]:
        try:
          enum_type(x)
        except ValueError:
          raise ValueError(
              'Invalid value in "choices": "{}"'.format(x)
          ) from None
    else:
      kwargs["choices"] = [x.value for x in enum_type]

    super().__init__(*args, **kwargs)


class MultiValuesFlagDef(FlagDef):
  """Definition for a flag that takes multiple values.

  Sample usage:
    # This defines a flag that can be specified on the command line as:
    #   --colors=red green blue
    flag = MultiValuesFlagDef(name="colors", parse_type=str, required=True)
    flag.add_argument_to_parser(argument_parser)
  """

  def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
    args = ["--" + self.name]
    if self.short_name is not None:
      args += ["-" + self.short_name]

    kwargs = {}
    if self.choices is not None:
      kwargs["choices"] = self.choices
    if self.help_msg is not None:
      kwargs["help"] = self.help_msg

    parser.add_argument(
        *args,
        action=_MultiValuesAppendAction,
        type=self.parse_type,
        dest_type=self._get_dest_type(),
        parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(),
        required=self.required,
        default=[],
        nargs="+",
        **kwargs,
    )

  def _do_additional_validation(self) -> None:
    # No additional validation needed.
    pass


@dataclasses.dataclass(frozen=True)
class BooleanFlagDef(FlagDef):
  """Definition for a Boolean flag.

  A boolean flag is always optional with a default value of False. The flag does
  not take any values. Specifying the flag on the commandline will set it to
  True.
  """

  def _do_additional_validation(self) -> None:
    if self.dest_type is not None:
      raise ValueError("dest_type cannot be set for BooleanFlagDef")
    if self.parse_to_dest_type_fn is not None:
      raise ValueError("parse_to_dest_type_fn cannot be set for BooleanFlagDef")
    if self.choices is not None:
      raise ValueError("choices cannot be set for BooleanFlagDef")

  def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
    args = ["--" + self.name]
    if self.short_name is not None:
      args += ["-" + self.short_name]

    kwargs = {}
    if self.help_msg is not None:
      kwargs["help"] = self.help_msg

    parser.add_argument(
        *args,
        action=_BooleanValueStoreAction,
        type=bool,
        required=False,
        default=False,
        nargs=0,
        **kwargs,
    )
