# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLMFunction."""
from __future__ import annotations

import abc
import dataclasses
from typing import AbstractSet, Any, Callable, Iterable, Mapping, Optional, Sequence

from google.generativeai.notebook.lib import llmfn_input_utils
from google.generativeai.notebook.lib import llmfn_output_row
from google.generativeai.notebook.lib import llmfn_outputs
from google.generativeai.notebook.lib import llmfn_post_process
from google.generativeai.notebook.lib import llmfn_post_process_cmds
from google.generativeai.notebook.lib import model as model_lib
from google.generativeai.notebook.lib import prompt_utils


# Signature of a user-defined comparison function used by LLMCompareFunction.
#
# In the same spirit as post-processing functions (see: llmfn_post_process.py),
# we keep the LLM functions flexible by passing the entire left- and
# right-hand side rows to the comparison function, rather than only their
# result values.
#
# Possible use-cases include adding a scoring function as a post-process
# command, then comparing the scores.
#
# The first argument is the left-hand side row, the second is the right-hand
# side row; the return value may be of any type.
CompareFn = Callable[
    [llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView],
    Any,
]


def _is_equal_fn(
    lhs: llmfn_output_row.LLMFnOutputRowView,
    rhs: llmfn_output_row.LLMFnOutputRowView,
) -> bool:
  """Default function used when comparing outputs."""
  return lhs.result_value() == rhs.result_value()


def _convert_compare_fn_to_batch_add_fn(
    fn: Callable[
        [
            llmfn_output_row.LLMFnOutputRowView,
            llmfn_output_row.LLMFnOutputRowView,
        ],
        Any,
    ]
) -> llmfn_post_process.LLMCompareFnPostProcessBatchAddFn:
  """Vectorize a single-row-based comparison function."""

  def _fn(
      lhs_and_rhs_rows: Sequence[
          tuple[
              llmfn_output_row.LLMFnOutputRowView,
              llmfn_output_row.LLMFnOutputRowView,
          ]
      ]
  ) -> Sequence[Any]:
    return [fn(lhs, rhs) for lhs, rhs in lhs_and_rhs_rows]

  return _fn


@dataclasses.dataclass
class _PromptInfo:
  """Fields needed to run one prompt template against one set of inputs.

  Instances are yielded by `_generate_prompts()`.
  """

  # Index of the prompt template within the sequence of prompts.
  prompt_num: int
  # The prompt template, before placeholder substitution.
  prompt: str
  # Index of the input mapping within the normalized inputs.
  input_num: int
  # Placeholder name -> value used to fill in `prompt`.
  prompt_vars: Mapping[str, str]
  # `prompt` with `prompt_vars` substituted in; this is the text passed to the
  # model.
  model_input: str


def _generate_prompts(
    prompts: Sequence[str], inputs: llmfn_input_utils.LLMFunctionInputs | None
) -> Iterable[_PromptInfo]:
  """Expands prompt templates against inputs.

  Every prompt is paired with every input mapping; for each prompt, all inputs
  are visited before moving on to the next prompt.

  Args:
    prompts: Prompt templates, with optional keyword placeholders.
    inputs: Values to substitute into the placeholders in `prompts`. When None
      or empty, each prompt is still emitted once with no substitutions.

  Yields:
    One _PromptInfo instance per (prompt, input) combination. Substitution is
    done with `str.format`, so a placeholder missing from an input mapping
    raises KeyError.
  """
  normalized_inputs: Sequence[Mapping[str, str]] = (
      [] if inputs is None else llmfn_input_utils.to_normalized_inputs(inputs)
  )

  # Substitute a single empty mapping so that every prompt runs at least once
  # even when there are no inputs.
  input_rows = normalized_inputs or [{}]

  for prompt_idx, template in enumerate(prompts):
    for input_idx, variables in enumerate(input_rows):
      yield _PromptInfo(
          prompt_num=prompt_idx,
          prompt=template,
          input_num=input_idx,
          prompt_vars=variables,
          model_input=template.format(**variables),
      )


class LLMFunction(
    Callable[
        [Optional[llmfn_input_utils.LLMFunctionInputs]],
        llmfn_outputs.LLMFnOutputs,
    ],
    metaclass=abc.ABCMeta,
):
  """Abstract base shared by LLMFunctionImpl and LLMCompareFunction.

  Subclasses implement `get_placeholders()` and `_call_impl()`. Calling an
  instance runs `_call_impl()` and wraps the returned entries in an
  `llmfn_outputs.LLMFnOutputs`. Post-processing commands are attached through
  the `add_post_process_*` methods, each of which returns `self` so calls can
  be chained.
  """

  def __init__(
      self,
      outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None]
      | None = None,
  ):
    """Initializes state shared by all LLM functions.

    Args:
      outputs_ipython_display_fn: Optional function that will be used to
        override how the outputs of this LLMFunction will be displayed in a
        notebook (See further documentation in LLMFnOutputs.__init__().)
    """
    self._outputs_ipython_display_fn = outputs_ipython_display_fn
    # Commands are kept in insertion order; subclasses run them in that order.
    self._post_process_cmds: list[
        llmfn_post_process_cmds.LLMFnPostProcessCommand
    ] = []

  @abc.abstractmethod
  def get_placeholders(self) -> AbstractSet[str]:
    """Returns the placeholder names that inputs to this function should supply."""

  @abc.abstractmethod
  def _call_impl(
      self, inputs: llmfn_input_utils.LLMFunctionInputs | None
  ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]:
    """Produces the output entries for `inputs`; this backs `__call__()`."""

  def __call__(
      self, inputs: llmfn_input_utils.LLMFunctionInputs | None = None
  ) -> llmfn_outputs.LLMFnOutputs:
    """Runs this function over `inputs` and wraps the resulting entries."""
    entries = self._call_impl(inputs)
    return llmfn_outputs.LLMFnOutputs(
        outputs=entries,
        ipython_display_fn=self._outputs_ipython_display_fn,
    )

  def _append_post_process_cmd(
      self, cmd: llmfn_post_process_cmds.LLMFnPostProcessCommand
  ) -> LLMFunction:
    """Appends `cmd` to this function's post-processing commands; returns self."""
    self._post_process_cmds.append(cmd)
    return self

  def add_post_process_reorder_fn(
      self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchReorderFn
  ) -> LLMFunction:
    """Appends a reorder post-processing command named `name` wrapping `fn`."""
    return self._append_post_process_cmd(
        llmfn_post_process_cmds.LLMFnPostProcessReorderCommand(name=name, fn=fn)
    )

  def add_post_process_add_fn(
      self,
      name: str,
      fn: llmfn_post_process.LLMFnPostProcessBatchAddFn,
  ) -> LLMFunction:
    """Appends an add post-processing command named `name` wrapping `fn`."""
    return self._append_post_process_cmd(
        llmfn_post_process_cmds.LLMFnPostProcessAddCommand(name=name, fn=fn)
    )

  def add_post_process_replace_fn(
      self,
      name: str,
      fn: llmfn_post_process.LLMFnPostProcessBatchReplaceFn,
  ) -> LLMFunction:
    """Appends a replace post-processing command named `name` wrapping `fn`."""
    return self._append_post_process_cmd(
        llmfn_post_process_cmds.LLMFnPostProcessReplaceCommand(name=name, fn=fn)
    )


class LLMFunctionImpl(LLMFunction):
  """Callable class that executes the contents of a Magics cell.

  An LLMFunctionImpl is constructed from the Magics command line and cell
  contents specified by the user. It is defined by:
  - A model instance,
  - Model arguments, and
  - One or more prompt templates (e.g. "The opposite of {word} is"), each with
    optional keyword placeholders.

  The function takes as its input a sequence of dictionaries containing
  values for keyword replacement, e.g. [{"word": "hot"}, {"word": "tall"}].

  This will cause the model to be executed with the following prompts:
    "The opposite of hot is"
    "The opposite of tall is"

  The results will be returned in a LLMFnOutputs instance.
  """

  def __init__(
      self,
      model: model_lib.AbstractModel,
      prompts: Sequence[str],
      model_args: model_lib.ModelArguments | None = None,
      outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None]
      | None = None,
  ):
    """Constructor.

    Args:
      model: The model that the prompts will execute on.
      prompts: A sequence of prompt templates with optional placeholders. The
        placeholders will be replaced by the inputs passed into this function.
      model_args: Optional set of model arguments to configure how the model
        executes the prompts. Defaults to `model_lib.ModelArguments()`.
      outputs_ipython_display_fn: See documentation in LLMFunction.__init__().
    """
    super().__init__(outputs_ipython_display_fn=outputs_ipython_display_fn)
    self._model = model
    self._prompts = prompts
    self._model_args = (
        model_lib.ModelArguments() if model_args is None else model_args
    )

    # The placeholders of this function are the union of the placeholders
    # found in every prompt template (empty when there are no prompts).
    self._placeholders = frozenset().union(
        *(prompt_utils.get_placeholders(prompt) for prompt in self._prompts)
    )

  def _run_post_processing_cmds(
      self, results: Sequence[llmfn_output_row.LLMFnOutputRow]
  ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
    """Runs post-processing commands over `results`.

    Commands run in the order they were added; each one receives the rows
    produced by the previous one.

    Args:
      results: The output rows produced for a single prompt/input pair.

    Returns:
      The rows after every post-processing command has run.

    Raises:
      llmfn_post_process.PostProcessExecutionError: If a command is not an
        LLMFnImplPostProcessCommand, or a command raises a RuntimeError.
    """
    for cmd in self._post_process_cmds:
      if not isinstance(
          cmd, llmfn_post_process_cmds.LLMFnImplPostProcessCommand
      ):
        raise llmfn_post_process.PostProcessExecutionError(
            "Unsupported post-process command type: {}".format(type(cmd))
        )
      try:
        results = cmd.run(results)
      except llmfn_post_process.PostProcessExecutionError:
        raise
      except RuntimeError as e:
        # Chain explicitly so the traceback shows the command's error as the
        # direct cause rather than as an error raised while handling it.
        raise llmfn_post_process.PostProcessExecutionError(
            'Error executing "{}", got {}: {}'.format(
                cmd.name(), type(e).__name__, e
            )
        ) from e
    return results

  def get_placeholders(self) -> AbstractSet[str]:
    """Returns the union of the placeholders in all prompt templates."""
    return self._placeholders

  def _call_impl(
      self, inputs: llmfn_input_utils.LLMFunctionInputs | None
  ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]:
    """Runs every prompt/input combination through the model.

    Returns:
      One LLMFnOutputEntry per combination, in the order produced by
      `_generate_prompts()`. Each entry holds one row per text result returned
      by the model (with RESULT_NUM and TEXT_RESULT columns), after
      post-processing.
    """
    results: list[llmfn_outputs.LLMFnOutputEntry] = []
    for info in _generate_prompts(prompts=self._prompts, inputs=inputs):
      model_results = self._model.call_model(
          model_input=info.model_input, model_args=self._model_args
      )
      output_rows: list[llmfn_output_row.LLMFnOutputRow] = [
          llmfn_output_row.LLMFnOutputRow(
              data={
                  llmfn_outputs.ColumnNames.RESULT_NUM: result_num,
                  llmfn_outputs.ColumnNames.TEXT_RESULT: text_result,
              },
              result_type=str,
          )
          for result_num, text_result in enumerate(model_results.text_results)
      ]
      results.append(
          llmfn_outputs.LLMFnOutputEntry(
              prompt_num=info.prompt_num,
              input_num=info.input_num,
              prompt=info.prompt,
              prompt_vars=info.prompt_vars,
              model_input=info.model_input,
              model_results=model_results,
              output_rows=self._run_post_processing_cmds(output_rows),
          )
      )
    return results


class LLMCompareFunction(LLMFunction):
  """LLMFunction for comparisons.

  LLMCompareFunction runs an input over a pair of LLMFunctions and compares the
  results row by row.
  """

  def __init__(
      self,
      lhs_name_and_fn: tuple[str, LLMFunction],
      rhs_name_and_fn: tuple[str, LLMFunction],
      compare_name_and_fns: Sequence[tuple[str, CompareFn]] | None = None,
      outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None]
      | None = None,
  ):
    """Constructor.

    Args:
      lhs_name_and_fn: Name and function for the left-hand side of the
        comparison.
      rhs_name_and_fn: Name and function for the right-hand side of the
        comparison.
      compare_name_and_fns: Optional names and functions for comparing the
        results of the left- and right-hand sides. The last entry produces the
        value of the result cell; the others are added as post-processing
        commands. When None or empty, an "is_equal" comparison of the two
        result values is used.
      outputs_ipython_display_fn: See documentation in LLMFunction.__init__().
    """
    super().__init__(outputs_ipython_display_fn=outputs_ipython_display_fn)
    self._lhs_name: str = lhs_name_and_fn[0]
    self._lhs_fn: LLMFunction = lhs_name_and_fn[1]
    self._rhs_name: str = rhs_name_and_fn[0]
    self._rhs_fn: LLMFunction = rhs_name_and_fn[1]
    self._placeholders = frozenset(self._lhs_fn.get_placeholders()).union(
        self._rhs_fn.get_placeholders()
    )

    if not compare_name_and_fns:
      self._result_name = "is_equal"
      self._result_compare_fn = _is_equal_fn
    else:
      # Assume the last entry in `compare_name_and_fns` is the one that
      # produces value for the result cell.
      name, fn = compare_name_and_fns[-1]
      self._result_name = name
      self._result_compare_fn = fn

      # Treat the other compare_fns as post-processing operators.
      for name, cmp_fn in compare_name_and_fns[:-1]:
        self.add_compare_post_process_add_fn(
            name=name, fn=_convert_compare_fn_to_batch_add_fn(cmp_fn)
        )

  def _run_post_processing_cmds(
      self,
      lhs_output_rows: Sequence[llmfn_output_row.LLMFnOutputRow],
      rhs_output_rows: Sequence[llmfn_output_row.LLMFnOutputRow],
      results: Sequence[llmfn_output_row.LLMFnOutputRow],
  ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
    """Runs post-processing commands over `results`.

    Commands run in the order they were added. LLMFnImplPostProcessCommand
    instances receive only the combined rows; LLMCompareFnPostProcessCommand
    instances receive (lhs_row, rhs_row, combined_row) triples.

    Args:
      lhs_output_rows: Left-hand side rows, truncated to the compared length.
      rhs_output_rows: Right-hand side rows, truncated to the compared length.
      results: Combined rows built from `lhs_output_rows` and
        `rhs_output_rows`.

    Returns:
      The rows after every post-processing command has run.

    Raises:
      llmfn_post_process.PostProcessExecutionError: If a command type is not
        supported, or a command raises a RuntimeError.
    """
    for cmd in self._post_process_cmds:
      try:
        if isinstance(cmd, llmfn_post_process_cmds.LLMFnImplPostProcessCommand):
          results = cmd.run(results)
        elif isinstance(
            cmd, llmfn_post_process_cmds.LLMCompareFnPostProcessCommand
        ):
          # NOTE(review): rows are paired purely by position. If an earlier
          # command dropped or reordered rows in `results`, they may no longer
          # line up with the lhs/rhs rows; confirm this is intended.
          results = cmd.run(
              list(zip(lhs_output_rows, rhs_output_rows, results))
          )
        else:
          raise RuntimeError(
              "Unsupported post-process command type: {}".format(type(cmd))
          )
      except llmfn_post_process.PostProcessExecutionError:
        raise
      except RuntimeError as e:
        # Chain explicitly so the traceback shows the command's error as the
        # direct cause rather than as an error raised while handling it.
        raise llmfn_post_process.PostProcessExecutionError(
            'Error executing "{}", got {}: {}'.format(
                cmd.name(), type(e).__name__, e
            )
        ) from e
    return results

  def get_placeholders(self) -> AbstractSet[str]:
    """Returns the union of the placeholders of both compared functions."""
    return self._placeholders

  @staticmethod
  def _check_entries_match(
      lhs_entry: llmfn_outputs.LLMFnOutputEntry,
      rhs_entry: llmfn_outputs.LLMFnOutputEntry,
  ) -> None:
    """Raises RuntimeError unless both entries share prompt num, input num and vars."""
    if lhs_entry.prompt_num != rhs_entry.prompt_num:
      raise RuntimeError(
          "Prompt num mismatch: {} vs {}".format(
              lhs_entry.prompt_num, rhs_entry.prompt_num
          )
      )
    if lhs_entry.input_num != rhs_entry.input_num:
      raise RuntimeError(
          "Input num mismatch: {} vs {}".format(
              lhs_entry.input_num, rhs_entry.input_num
          )
      )
    if lhs_entry.prompt_vars != rhs_entry.prompt_vars:
      raise RuntimeError(
          "Prompt vars mismatch: {} vs {}".format(
              lhs_entry.prompt_vars, rhs_entry.prompt_vars
          )
      )

  def _combine_output_rows(
      self,
      prompt_vars: Mapping[str, str],
      lhs_output_rows: Sequence[llmfn_output_row.LLMFnOutputRow],
      rhs_output_rows: Sequence[llmfn_output_row.LLMFnOutputRow],
  ) -> list[llmfn_output_row.LLMFnOutputRow]:
    """Builds one combined row per (lhs, rhs) row pair, paired by position."""
    output_rows: list[llmfn_output_row.LLMFnOutputRow] = []
    for result_num, (lhs_output_row, rhs_output_row) in enumerate(
        zip(lhs_output_rows, rhs_output_rows)
    ):
      # Although it is possible for RESULT_NUM (the index of each
      # text_result if a prompt produces multiple text_results) to be
      # different between the left and right sides, we ignore their
      # RESULT_NUM entries and write our own.
      row_data: dict[str, Any] = {
          llmfn_outputs.ColumnNames.RESULT_NUM: result_num,
          self._result_name: self._result_compare_fn(
              lhs_output_row, rhs_output_row
          ),
      }
      output_row = llmfn_output_row.LLMFnOutputRow(
          data=row_data, result_type=Any
      )

      # Add the prompt vars.
      output_row.add(llmfn_outputs.ColumnNames.PROMPT_VARS, prompt_vars)

      # Add the cells from both sides, prefixed with each side's name.
      for name, row in [
          (self._lhs_name, lhs_output_row),
          (self._rhs_name, rhs_output_row),
      ]:
        for k, v in row.items():
          if k != llmfn_outputs.ColumnNames.RESULT_NUM:
            # We use LLMFnOutputRow.add() because it handles column
            # name collisions.
            output_row.add("{}_{}".format(name, k), v)

      output_rows.append(output_row)
    return output_rows

  def _call_impl(
      self, inputs: llmfn_input_utils.LLMFunctionInputs | None
  ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]:
    """Runs both functions on `inputs` and compares their entries pairwise.

    Raises:
      RuntimeError: If paired entries disagree on prompt num, input num or
        prompt vars.
    """
    lhs_results = self._lhs_fn(inputs)
    rhs_results = self._rhs_fn(inputs)

    outputs: list[llmfn_outputs.LLMFnOutputEntry] = []
    # NOTE(review): zip() stops at the shorter result list, so extra entries on
    # one side (e.g. a function with more prompts) are silently dropped.
    for lhs_entry, rhs_entry in zip(lhs_results, rhs_results):
      self._check_entries_match(lhs_entry, rhs_entry)

      # The two functions may have different numbers of results due to
      # options like candidate_count, so we can only compare up to the
      # minimum of the two.
      num_output_rows = min(
          len(lhs_entry.output_rows), len(rhs_entry.output_rows)
      )
      lhs_output_rows = lhs_entry.output_rows[:num_output_rows]
      rhs_output_rows = rhs_entry.output_rows[:num_output_rows]
      output_rows = self._combine_output_rows(
          prompt_vars=lhs_entry.prompt_vars,
          lhs_output_rows=lhs_output_rows,
          rhs_output_rows=rhs_output_rows,
      )

      outputs.append(
          llmfn_outputs.LLMFnOutputEntry(
              prompt_num=lhs_entry.prompt_num,
              input_num=lhs_entry.input_num,
              prompt_vars=lhs_entry.prompt_vars,
              output_rows=self._run_post_processing_cmds(
                  lhs_output_rows=lhs_output_rows,
                  rhs_output_rows=rhs_output_rows,
                  results=output_rows,
              ),
          )
      )
    return outputs

  def add_compare_post_process_add_fn(
      self,
      name: str,
      fn: llmfn_post_process.LLMCompareFnPostProcessBatchAddFn,
  ) -> LLMFunction:
    """Appends a comparison "add" post-processing command; returns self.

    Args:
      name: Name of the command.
      fn: Batch function receiving the lhs/rhs rows being compared.

    Returns:
      This function, so calls can be chained.
    """
    self._post_process_cmds.append(
        llmfn_post_process_cmds.LLMCompareFnPostProcessAddCommand(
            name=name, fn=fn
        )
    )
    return self
