#!/usr/bin/env vpython
# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import argparse
import json
import multiprocessing
import sys

from core import benchmark_utils
from core import bot_platforms
from core import retrieve_story_timing
from core import sharding_map_generator


def GetParser():
  """Build the command line parser for the sharding map tool.

  Two subcommands are registered, and each one stores its handler in the
  `func` attribute of the parsed namespace:
    update: regenerate the shard maps of the selected perf builders.
    create: create a shard map for a single benchmark.

  Returns:
    The configured argparse.ArgumentParser.
  """
  parser = argparse.ArgumentParser(
      description='Generate perf test sharding map.')
  subparsers = parser.add_subparsers()

  # --debug is registered on the top-level parser (not on a subparser), so on
  # the command line it has to be given before the subcommand name.
  parser.add_argument(
      '--debug', action='store_true',
      help=('Whether to include detailed debug info of the sharding map in '
            'the shard maps.'), default=False)

  parser_update = subparsers.add_parser('update')
  parser_update.add_argument(
      '--regenerate-timing-data', '-r', action='store_true',
      help=('Whether to regenerate timing data for all builders in '
            'chromium.perf'), default=False)
  parser_update.add_argument(
      '--builders', '-b', action='store', nargs='*',
      help=('The builder names to reshard. If not specified, use all '
            'perf builders'),
      choices=bot_platforms.ALL_PLATFORM_NAMES,
      default=bot_platforms.ALL_PLATFORM_NAMES)
  parser_update.set_defaults(func=_UpdateShardsForBuilders)

  parser_create = subparsers.add_parser('create')
  parser_create.add_argument(
      '--benchmark', help='The benchmark that you want to create shard for',
      required=True)
  parser_create.add_argument(
      '--timing-data-source', '-t', choices=bot_platforms.ALL_PLATFORM_NAMES,
      help='The timing data that you want to use. If not set, it will assume '
           'all stories use the same amount of time to run')
  parser_create.add_argument(
      # Pinpoint typically has 16 machines for each hardware type, so the
      # default uses half of them to avoid starving the pool.
      '--shards-num', type=int, default=8,
      help="The number of shards you'd like to use, default is %(default)s")
  parser_create.add_argument(
      '--output-path', default='new_shard_map.json',
      help='Output file path for the shard map, default is `%(default)s`')
  parser_create.set_defaults(func=_CreateShardMapForBenchmark)
  return parser


def _GenerateBenchmarksToShardsList(benchmarks):
  """Build |benchmarks_to_shard| from the given list of |benchmarks|.

    benchmarks_to_shard is a list of all the benchmarks to be sharded, one
    dictionary per benchmark, in the same order as |benchmarks|:
    [{
       "name": "benchmark_1",
       "repeat": <number of pageset_repeat>,
       "stories": [ "storyA", "storyB",...]
      },
      {
       "name": "benchmark_2",
       "repeat": <number of pageset_repeat>,
       "stories": [ "storyA", "storyB",...]
      },
       ...
    ]

    "repeat" falls back to 1 when the benchmark options carry no
    'pageset_repeat' entry.

    The "stories" field contains a list of ordered story names. Note that
    this should match the actual order in which the benchmark stories are
    executed for the sharding algorithm to be effective.
  """
  # Each benchmark class is instantiated twice per entry: once to read its
  # options and once to list its story names.
  return [
      {
          'name': benchmark_class.Name(),
          'repeat': benchmark_class().options.get('pageset_repeat', 1),
          'stories': benchmark_utils.GetBenchmarkStoryNames(benchmark_class()),
      }
      for benchmark_class in benchmarks
  ]


def _LoadTimingData(args):
  """Fetch the story timing data of one builder and save it as JSON.

  Takes a single tuple argument so that it can be used with Pool.map.

  Args:
    args: a (builder_name, timing_file_path) pair. The average story timing
      data of the last 5 days for |builder_name| is written to
      |timing_file_path|.
  """
  builder_name, timing_file_path = args
  timing_data = retrieve_story_timing.FetchAverageStortyTimingData(
      configurations=[builder_name], num_last_days=5)
  with open(timing_file_path, 'w') as timing_file:
    json.dump(timing_data, timing_file, indent=4, separators=(',', ': '))
  print('Finish retrieve story timing data for %s' % repr(builder_name))


def _GenerateShardMap(
    builder, num_of_shards, output_path, debug, benchmark):
  """Generate a sharding map and write it to |output_path| as JSON.

  Args:
    builder: the platform object whose `timing_file_path` and
      `benchmarks_to_run` are used. When falsy, no timing data is loaded.
    num_of_shards: the number of shards passed on to the map generator.
    output_path: the path of the JSON file to write.
    debug: passed on to the map generator as its `debug` argument.
    benchmark: when truthy, only the benchmark with this name is sharded;
      otherwise every benchmark of |builder| is.
  """
  if builder:
    with open(builder.timing_file_path) as timing_file:
      timing_data = json.load(timing_file)
  else:
    timing_data = []

  # NOTE(review): |builder| is dereferenced unconditionally below, so a None
  # builder (which _CreateShardMapForBenchmark passes when no timing data
  # source is given) raises AttributeError here.
  selected_benchmarks = []
  for candidate in builder.benchmarks_to_run:
    if benchmark and candidate.Name() != benchmark:
      continue
    selected_benchmarks.append(candidate)

  sharding_map = sharding_map_generator.generate_sharding_map(
      _GenerateBenchmarksToShardsList(selected_benchmarks),
      timing_data,
      num_shards=num_of_shards,
      debug=debug)
  with open(output_path, 'w') as map_file:
    json.dump(sharding_map, map_file, indent=4, separators=(',', ': '))


def _UpdateShardsForBuilders(args):
  """Regenerate the shard maps of the builders selected on the command line.

  Args:
    args(Namespace object): the namespace object for the subparser `update`.
      The attributes read here are:
        `builders`: the names of the builders to reshard
        `regenerate_timing_data`: whether to refetch the timing data of
          those builders before resharding
        `debug`: passed on to _GenerateShardMap
  """
  builders = {b for b in bot_platforms.ALL_PLATFORMS if b.name in args.builders}
  # multiprocessing.Pool rejects a process count of 0, so skip the refresh
  # entirely when no builder was selected.
  if args.regenerate_timing_data and builders:
    print('Update shards timing data. May take a while...')
    load_timing_args = [(b.name, b.timing_file_path) for b in builders]
    # One worker process per builder.
    p = multiprocessing.Pool(len(load_timing_args))
    try:
      p.map(_LoadTimingData, load_timing_args)
    finally:
      # Always release the worker processes, even if fetching failed.
      p.close()
      p.join()

  for b in builders:
    _GenerateShardMap(
        b, b.num_shards, b.shards_map_file_path, args.debug, benchmark=None)
    print('Updated sharding map for %s' % repr(b.name))


def _CreateShardMapForBenchmark(args):
  """Create the shard map for a single benchmark.

  Args:
    args(Namespace object): the namespace object for the subparser `create`.
      The attributes read here are:
        `benchmark`: the name of the benchmark that we want the shard map for
        `shards_num`: the total number of shards that we want to use
        `output_path`: the output file path for the shard map
        `timing_data_source`: the name of the builder whose timing data is
          used. This is a single string, not a list of names. When falsy, no
          builder is looked up.
        `debug`: passed on to _GenerateShardMap
  """
  source_name = args.timing_data_source
  if source_name:
    matching_platforms = [
        platform for platform in bot_platforms.ALL_PLATFORMS
        if platform.name == source_name]
    # Exactly one platform is expected to match; the unpacking raises
    # ValueError otherwise.
    (builder,) = matching_platforms
  else:
    builder = None
  _GenerateShardMap(
      builder, args.shards_num, args.output_path, args.debug, args.benchmark)


def main():
  """Parse the command line and run the handler of the chosen subcommand."""
  options = GetParser().parse_args()
  # Each subparser stores its handler in `func` via set_defaults.
  handler = options.func
  handler(options)

# Run the tool only when this file is executed as a script. main() has no
# return statement, so sys.exit() receives None.
if __name__ == '__main__':
  sys.exit(main())
