# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module that holds a global gspread.client.Client."""
from __future__ import annotations

import abc
import datetime
from typing import Any, Callable, Mapping, Sequence
from google.auth import credentials
from google.generativeai.notebook import html_utils
from google.generativeai.notebook import ipython_env
from google.generativeai.notebook import sheets_id


# The code may be running in an environment where the gspread library has not
# been installed. In that case `gspread` is set to None and the import error is
# kept so that it can be reported when the client is actually used.
_gspread_import_error: Exception | None = None
try:
  # First try importing the `gspread` submodule of the `gspread` package.
  # pylint: disable-next=g-import-not-at-top
  from gspread import gspread
except (ImportError, ModuleNotFoundError):
  try:
    # Fall back to importing the top-level `gspread` package.
    # pylint: disable-next=g-import-not-at-top
    import gspread
  except ImportError as e:
    _gspread_import_error = e
    gspread = None

# Base class of exceptions that gspread.open(), open_by_url() and open_by_key()
# may throw. Falls back to the built-in Exception when gspread is unavailable.
GSpreadException = Exception if gspread is None else gspread.exceptions.GSpreadException  # type: ignore


class SpreadsheetNotFoundError(RuntimeError):
  """Raised when a Sheets document cannot be found or cannot be opened."""


def _get_import_error() -> Exception:
  """Builds the error to raise when gspread is needed but was not imported.

  Returns:
    A RuntimeError whose message embeds the import error recorded at module
    load time.
  """
  template = '"gspread" module not imported, got: {}'
  return RuntimeError(template.format(_gspread_import_error))


class GSpreadClient(abc.ABC):
  """Abstract interface over gspread.client.Client.

  Callers talk to Google Sheets through this interface rather than through
  gspread directly, which allows a mock implementation to be injected for
  testing.
  """

  @abc.abstractmethod
  def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
    """Validates that `sid` identifies a Google Sheets document.

    Implementations raise an exception if it does not.

    Args:
      sid: The identifier for the document.
    """

  @abc.abstractmethod
  def get_all_records(
      self,
      sid: sheets_id.SheetsIdentifier,
      worksheet_id: int,
  ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
    """Returns all records for a Google Sheets worksheet.

    Args:
      sid: The identifier for the document.
      worksheet_id: Identifies the worksheet within the document to read.

    Returns:
      A pair of (records, callable): the records read from the worksheet, and
      a zero-argument callable returning None.
    """

  @abc.abstractmethod
  def write_records(
      self,
      sid: sheets_id.SheetsIdentifier,
      rows: Sequence[Sequence[Any]],
  ) -> None:
    """Writes `rows` to a new worksheet in the Google Sheets document.

    Args:
      sid: The identifier for the document.
      rows: The rows to write; each row is a sequence of cell values.
    """


class GSpreadClientImpl(GSpreadClient):
  """Concrete implementation of GSpreadClient."""

  def __init__(self, client: Any, env: ipython_env.IPythonEnv | None):
    """Constructor.

    Args:
      client: Instance of gspread.client.Client.
      env: Optional instance of IPythonEnv. This is used to display messages
        such as the URL of the output Worksheet. When None, messages are
        printed instead.
    """
    self._client = client
    self._ipython_env = env

  def _open(self, sid: sheets_id.SheetsIdentifier):
    """Opens a Sheets document from `sid`.

    The document is looked up by name, then by key, then by URL; the first of
    these that `sid` provides (i.e. that is truthy) is used.

    Args:
      sid: The identifier for the Sheets document.

    Raises:
      SpreadsheetNotFoundError: If the Sheets document cannot be found or
        cannot be opened, or if `sid` provides no name, key or URL.

    Returns:
      The spreadsheet object returned by the gspread client for `sid`.
    """
    try:
      if sid.name():
        return self._client.open(sid.name())
      if sid.key():
        return self._client.open_by_key(str(sid.key()))
      if sid.url():
        return self._client.open_by_url(str(sid.url()))
    except GSpreadException as exc:
      raise SpreadsheetNotFoundError(
          "Unable to find Sheets with {}".format(sid)
      ) from exc
    raise SpreadsheetNotFoundError("Invalid sheets_id.SheetsIdentifier")

  def _display_worksheet_message(
      self, prefix: str, sheet: Any, worksheet: Any
  ) -> None:
    """Reports `prefix` followed by a description of `worksheet`.

    With an IPythonEnv the worksheet is shown as an HTML link to its URL;
    otherwise a plain-text line is printed.

    Args:
      prefix: Text placed before the worksheet description.
      sheet: The spreadsheet that contains `worksheet`.
      worksheet: The worksheet being reported.
    """
    label = "{} in {}".format(worksheet.title, sheet.title)
    if self._ipython_env is None:
      print("{} {}".format(prefix, label))
      return
    anchor = html_utils.get_anchor_tag(
        url=sheets_id.SheetsURL(worksheet.url), text=label
    )
    self._ipython_env.display_html("{} {}".format(prefix, anchor))

  def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
    """Validates `sid` by opening the document it refers to.

    Args:
      sid: The identifier for the Sheets document.

    Raises:
      SpreadsheetNotFoundError: If the document cannot be opened.
    """
    self._open(sid)

  def get_all_records(
      self,
      sid: sheets_id.SheetsIdentifier,
      worksheet_id: int,
  ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
    """Returns all records of a worksheet plus a deferred display callback.

    Args:
      sid: The identifier for the Sheets document.
      worksheet_id: Passed to the spreadsheet's get_worksheet() to select the
        worksheet to read.

    Raises:
      SpreadsheetNotFoundError: If the document cannot be opened.

    Returns:
      A pair of (records, display_fn). `records` is the result of the
      worksheet's get_all_records(); `display_fn` takes no arguments and, when
      called, reports which worksheet the inputs were read from.
    """
    sheet = self._open(sid)
    worksheet = sheet.get_worksheet(worksheet_id)

    # The message is not shown here; the caller decides if and when to show it.
    def _display_fn() -> None:
      self._display_worksheet_message(
          "Reading inputs from worksheet", sheet, worksheet
      )

    return worksheet.get_all_records(), _display_fn

  def write_records(
      self,
      sid: sheets_id.SheetsIdentifier,
      rows: Sequence[Sequence[Any]],
  ) -> None:
    """Writes `rows` to a newly created worksheet and reports its location.

    Args:
      sid: The identifier for the Sheets document.
      rows: The rows to append to the new worksheet.

    Raises:
      SpreadsheetNotFoundError: If the document cannot be opened.
    """
    sheet = self._open(sid)

    # Create a new Worksheet.
    # `title` has to be carefully constructed: some characters like colon ":"
    # will not work with gspread in Worksheet.append_rows().
    current_datetime = datetime.datetime.now()
    # Seconds since the epoch are computed with timestamp() rather than the
    # "%s" strftime directive: "%s" is a platform-specific extension that is
    # not available everywhere (e.g. on Windows).
    epoch_seconds = int(current_datetime.timestamp())
    title = f"Results {current_datetime:%Y_%m_%d} ({epoch_seconds})"

    # append_rows() will resize the worksheet as needed, so `rows` and `cols`
    # can be set to 1 to create a worksheet with only a single cell.
    worksheet = sheet.add_worksheet(title=title, rows=1, cols=1)
    worksheet.append_rows(values=rows)

    self._display_worksheet_message(
        "Results written to new worksheet", sheet, worksheet
    )


class NullGSpreadClient(GSpreadClient):
  """Stand-in GSpreadClient for when the gspread library is not available.

  No method does any work: each one raises the error built by
  _get_import_error(), which reports the original import failure.
  """

  def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
    """Always raises, since gspread could not be imported."""
    raise _get_import_error()

  def get_all_records(
      self,
      sid: sheets_id.SheetsIdentifier,
      worksheet_id: int,
  ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
    """Always raises, since gspread could not be imported."""
    raise _get_import_error()

  def write_records(
      self,
      sid: sheets_id.SheetsIdentifier,
      rows: Sequence[Sequence[Any]],
  ) -> None:
    """Always raises, since gspread could not be imported."""
    raise _get_import_error()


# Global instance of gspread client. This is None until authorize() or
# testonly_set_client() has been called.
_gspread_client: GSpreadClient | None = None


def authorize(
    creds: credentials.Credentials, env: ipython_env.IPythonEnv | None
) -> None:
  """Installs the global gspread client, authorized with `creds`.

  If the gspread library could not be imported, a NullGSpreadClient is
  installed instead.

  Args:
    creds: Credentials passed to gspread.authorize().
    env: Optional IPythonEnv handed to the client for displaying messages.
  """
  global _gspread_client
  if gspread is None:
    _gspread_client = NullGSpreadClient()
    return
  authorized = gspread.authorize(creds)  # type: ignore
  _gspread_client = GSpreadClientImpl(client=authorized, env=env)


def get_client() -> GSpreadClient:
  """Returns the global gspread client.

  Returns:
    The client installed by authorize() or testonly_set_client().

  Raises:
    RuntimeError: If no client has been installed yet.
  """
  # Compare against None explicitly: a client object that happens to be falsy
  # (e.g. a test double defining __len__ or __bool__) is still a valid client.
  if _gspread_client is None:
    raise RuntimeError("Must call authorize() first")
  return _gspread_client


def testonly_set_client(client: GSpreadClient) -> None:
  """Overrides the global client for testing."""
  global _gspread_client
  _gspread_client = client
