#!/usr/bin/env ruby

# Copyright (c) 2015 Alexandra Figlovskaya <fglval@gmail.com>
# Copyright (c) 2015-2019 Aleksey Cheusov <vle@gmx.net>
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

require 'optparse'
require 'set'

# Command-line settings; the option parser stores every switch here.
@options = {}

# Set to a non-nil value once a malformed input line has been reported.
@err = nil

# Placeholder label given to predictions whose score is below the threshold.
@unspecified_class = "__strelka_i_raketa__"

# Print one human-readable row with precision, recall and F1.
#
# class_name           - row label, right-aligned in a 13-character column
# p, r, f1             - metric values, printed with up to 4 significant digits
# p_comment, r_comment - text shown next to P and R
# f1_comment           - accepted so the signature matches print_raw; the row
#                        layout has no column for it, so it is not printed
def print_pretty(class_name, p, p_comment, r, r_comment, f1, f1_comment)
  # Pass exactly as many values as the format has directives: a surplus
  # argument makes Ruby warn under -w and raise ArgumentError under -d.
  puts format("%13s P, R, F1:  %-6.4g %-13s,  %-6.4g %-13s,  %-6.4g",
              class_name, p, p_comment, r, r_comment, f1)
end

# Print the accuracy as one human-readable row.
#
# a         - accuracy value, printed with up to 4 significant digits
# a_comment - text shown after the value, left-aligned in 13 characters
def print_accuracy_pretty(a, a_comment)
  row = format("Accuracy              :  %-6.4g %-13s", a, a_comment)
  puts row
end

# Print precision, recall and F1 as three tab-separated lines of the form
# "<class_name>\t<metric>\t<value>\t<comment>"; comments are stripped of
# surrounding whitespace.
def print_raw(class_name, p, p_comment, r, r_comment, f1, f1_comment)
  rows = [["P", p, p_comment], ["R", r, r_comment], ["F1", f1, f1_comment]]
  rows.each do |metric, value, comment|
    puts "#{class_name}\t#{metric}\t#{value}\t#{comment.strip}"
  end
end

# Print the accuracy as one tab-separated line with an empty first field:
# "\tA\t<value>\t<comment>"; the comment is stripped of surrounding whitespace.
def print_accuracy_raw(a, a_comment)
  comment = a_comment.strip
  puts "\tA\t#{a}\t#{comment}"
end

# Print a P/R/F1 row, using the raw tab-separated layout when the :raw
# option is set and the human-readable layout otherwise.
def print_stat(class_name, p, p_comment, r, r_comment, f1, f1_comment)
  return print_raw(class_name, p, p_comment, r, r_comment, f1, f1_comment) if @options[:raw]

  print_pretty(class_name, p, p_comment, r, r_comment, f1, f1_comment)
end

# Print the accuracy row, using the raw tab-separated layout when the :raw
# option is set and the human-readable layout otherwise.
def print_accuracy(a, a_comment)
  return print_accuracy_raw(a, a_comment) if @options[:raw]

  print_accuracy_pretty(a, a_comment)
end

# Render "a/b" with a right-aligned and b left-aligned, each padded to at
# least five characters.
def pretty_div(a, b)
  numerator = a.to_s.rjust(5)
  denominator = b.to_s.ljust(5)
  "#{numerator}/#{denominator}"
end

# Canonicalize a label so that numerically equal spellings compare equal:
# a leading "+" is dropped and, for decimal numbers, an all-zero fractional
# part is removed. Anything else is returned as its string form.
def normalize_tag(tag)
  label = tag.to_s.sub(/^[+]/, "") # +1 => 1
  return label unless label =~ /^-?[0-9]+[.][0-9]+$/

  label.sub(/[.]0+$/, "") # -1.0000 => -1
end

# Split an input line into [first_tag, second_tag, score].
#
# Whitespace runs are collapsed before splitting. Both tags go through
# normalize_tag. The score is the third token converted with to_f, or
# Float::MAX when the line has no third token. A line that does not have
# exactly two or three tokens is reported on STDERR (with a leading "fake"
# marker removed from the message) and @err is set to 1.
# When the score is below the :treshold option, the second tag is replaced
# with @unspecified_class.
#
# line - text to split
# fn   - file name used only in the error message
def split_into_3(line, fn)
  squeezed = line.gsub(/\s+/, " ").strip
  fields = squeezed.split(/ /)

  score = Float::MAX
  if fields.size == 3
    score = fields[2].to_f
  elsif fields.size != 2
    STDERR.puts("Bad line '#{squeezed.sub(/^fake ?/, "")}' in file '#{fn}'")
    @err = 1
  end

  first_tag = normalize_tag(fields[0])
  second_tag = normalize_tag(fields[1])
  second_tag = @unspecified_class if score < @options[:treshold]

  [first_tag, second_tag, score]
end

# Mean squared error: the sum of squared differences between paired
# outcomes and predictions, divided by the number of outcomes.
def mse(outcomes, predictions)
  squared_errors = outcomes.zip(predictions).map do |actual, predicted|
    (actual - predicted) ** 2
  end
  total = squared_errors.inject(0.0) { |acc, err| acc + err }
  total / outcomes.length.to_f
end

# Root mean squared error: the square root of mse for the same inputs.
def rmse(outcomes, predictions)
  mean_squared = mse(outcomes, predictions)
  Math.sqrt(mean_squared)
end

# Mean absolute error: the sum of absolute differences between paired
# outcomes and predictions, divided by the number of outcomes.
def mae(outcomes, predictions)
  absolute_errors = outcomes.zip(predictions).map do |actual, predicted|
    (actual - predicted).abs
  end
  total = absolute_errors.inject(0.0) { |acc, err| acc + err }
  total / outcomes.length.to_f
end

# Labels given with -x; empty unless the option is used.
@options[:excluded] = Set.new()

# Declare the command-line switches and parse ARGV in place. Each default
# is assigned right before the switch that can override it.
OptionParser.new do |opts|
  opts.banner = <<EOF
heri-stat calculates precision (P), recall (R), F1, accuracy (A),
  mean absolute error (MAE), mean squared error (MSE) or
  root mean squared error (RMSE) for given outcomes and predictions.
  Unless -g is specified, predictions of classification is expected on input.
Usage:
    heri-stat -h
    heri-stat    [-g mode] [-R] [OPTIONS] <outcomes> <predictions>
    heri-stat -1 [-g mode] [-R] [OPTIONS] [files...]
OPTIONS:
EOF

  opts.on('-h', '--help','display this message and exit') do
    puts opts
    exit 0
  end

  @options[:raw] = false
  opts.on('-R', '--raw','raw tab-separated output') do
    @options[:raw] = true
  end

  # NOTE: for :micro_avg, :macro_avg, :statistics and :accuracy a true value
  # means the corresponding output is suppressed; the reporting code prints
  # each section only while its flag is false.
  @options[:micro_avg] = false
  opts.on('-m', '--micro-avg','disable micro averaged P/R/F1 output') do
    @options[:micro_avg] = true
  end

  @options[:macro_avg] = false
  opts.on('-r', '--macro-avg','disable macro averaged P/R/F1 output') do
    @options[:macro_avg] = true
  end

  @options[:statistics] = false
  opts.on('-c', '--per-class','disable output of per-class statistics') do
    @options[:statistics] = true
  end

  @options[:accuracy] = false
  opts.on('-a', '--accuracy','disable output of accuracy') do
    @options[:accuracy] = true
  end

  @options[:single] = false
  opts.on('-1', '--single','obtain both outcomes and
predicted classes from single source. If this option is specified,
the first token on input represents the outcome class
and second one -- predicted class') do
    @options[:single] = true
  end

  # false means that no "unclassified" label is in effect.
  @options[:unclassified] = false
  opts.on("-u", "--unclassified=UNCLASSIFIED", 'set the label for "unclassified" object') do |u|
    @options[:unclassified] = u.to_s
  end

  # The default is the lowest finite Float, so no score falls below it.
  @options[:treshold] = -Float::MAX
  opts.on("-t", "--treshold=TRESHOLD", 'Minimal treshold for score') do |u|
    @options[:treshold] = u.to_f
    # A threshold also turns on "unclassified" handling, using the internal
    # placeholder label as the unclassified label.
    @options[:unclassified] = @unspecified_class
  end

  # NOTE(review): :excluded_labels is not read anywhere in the visible code;
  # the -x handler stores its labels in :excluded instead.
  @options[:excluded_labels] = false
  opts.on("-x", "--exclude=LABELS", 'exclude the specified labels from statistics') do |x|
    # LABELS is a comma-separated list.
    @options[:excluded] = Set.new(x.split(","))
  end

  # The mode is a combination of letters: "s" selects MSE, "r" selects RMSE
  # and "a" selects MAE. Any other character is rejected with exit status 1.
  opts.on("-g", "--regression=mode", 'Calculate MAE, MSE or RMSE') do |g|
    @options[:regression] = true

    @options[:mse]  = g.include?("s")
    @options[:rmse] = g.include?("r")
    @options[:mae]  = g.include?("a")

    if /[^sra]/ =~ g
      STDERR.puts "Invalid mode '#{g}' for option -g"
      exit 1
    end
  end

  opts.separator " "
end.parse!

# When an "unclassified" label is in effect the accuracy output is
# suppressed; otherwise the micro average output is suppressed.
if @options[:unclassified]
  @options[:accuracy]=true
else
  @options[:micro_avg]=true
end

# Return s stripped of surrounding whitespace and cut after its second
# whitespace-separated token. Strings with fewer than three tokens are
# returned stripped but otherwise unchanged.
#
# Tokens may be separated by runs of whitespace: the separator inside the
# kept part is preserved as written.
def get_up_to_two_tokens(s)
  # \s+ (not a single \s) between tokens, so "a  b extra" is truncated too
  # instead of slipping through and being rejected later as a bad line.
  s.strip.sub(/^(\S+\s+\S+)\s.*$/, "\\1")
end

# Return s stripped of surrounding whitespace and cut after its third
# whitespace-separated token. Strings with fewer than four tokens are
# returned stripped but otherwise unchanged.
#
# Tokens may be separated by runs of whitespace: the separators inside the
# kept part are preserved as written.
def get_up_to_three_tokens(s)
  # \s+ (not a single \s) between tokens, so "a  b  c extra" is truncated
  # too instead of slipping through and being rejected later as a bad line.
  s.strip.sub(/^(\S+\s+\S+\s+\S+)\s.*$/, "\\1")
end

# Load the outcome and prediction labels into outcome_tags/prediction_tags.
if @options[:single]
  # Single-source mode: each input line is "<outcome> <prediction> [score]";
  # lines come from the files named on the command line or standard input.
  outcome_tags = []
  prediction_tags = []
  while (line = gets)
    gt, rt, _score = split_into_3(get_up_to_three_tokens(line), "")
    outcome_tags << gt
    prediction_tags << rt
  end
else
  # Two-file mode needs both paths; report a missing one with a readable
  # message instead of a TypeError backtrace.
  if ARGV.size < 2
    STDERR.puts("Two files are expected: <outcomes> <predictions>")
    exit 1
  end

  # File.read is used instead of IO.read because IO.read may run a path
  # that starts with "|" as a subprocess.
  outcome_tags = File.read(ARGV[0]).split("\n").map { |x| get_up_to_two_tokens(x) }
  prediction_tags = File.read(ARGV[1]).split("\n").map { |x| get_up_to_two_tokens(x) }
  if outcome_tags.length != prediction_tags.length
    STDERR.puts("Dataset and prediction files should contain the same amount of lines");
    exit 1
  end

  # split_into_3 returns the label of interest in second position, so a
  # dummy first token is prepended; split_into_3 removes it again from the
  # text of its error message.
  outcome_tags.each_index do |i|
    _fake, outcome_tags[i], _score = split_into_3("fake " + outcome_tags[i], ARGV[0])
    _fake, prediction_tags[i], _score = split_into_3("fake " + prediction_tags[i], ARGV[1])
  end
end

# Bad lines were already reported while reading; stop after all of them
# have been seen.
exit 1 if @err

# Regression mode: treat the tags as numbers, print the requested error
# measures and exit without running the classification statistics.
if @options[:regression]
  outcome_tags.map!(&:to_f)
  prediction_tags.map!(&:to_f)

  mae_value = mae(outcome_tags, prediction_tags)
  mse_value = mse(outcome_tags, prediction_tags)
  rmse_value = rmse(outcome_tags, prediction_tags)

  # The output order is fixed: MSE, RMSE, MAE.
  measures = [[:mse, "MSE", mse_value], [:rmse, "RMSE", rmse_value], [:mae, "MAE", mae_value]]
  measures.each do |key, label, value|
    next unless @options[key]

    puts(@options[:raw] ? "\t#{label}\t#{value}" : "#{label}: #{value}")
  end

  exit 0
end

# Per-label counters: occurrences as outcome, occurrences as prediction,
# and predictions that matched the outcome (true positives).
tag2outcome_cnt = Hash.new(0)
tag2prediction_cnt = Hash.new(0)
tag2TP_cnt = Hash.new(0)
all_precision = 0
all_recall = 0

outcome_tags.zip(prediction_tags).each do |outcome, predicted|
  # "+= 0" creates the cell, so every outcome label gets a row even when it
  # is never predicted
  tag2TP_cnt[outcome] += 0
  tag2prediction_cnt[outcome] += 0
  tag2outcome_cnt[outcome] += 1 unless outcome == @options[:unclassified]

  # excluded and unclassified predictions are not counted
  next if @options[:excluded].include?(predicted)
  next if predicted == @options[:unclassified]

  hit = (outcome == predicted ? 1 : 0)
  tag2prediction_cnt[predicted] += 1
  tag2TP_cnt[predicted] += hit
end

# Drop the rows of excluded labels that were created through outcomes.
@options[:excluded].each do |label|
  tag2TP_cnt.delete(label)
  tag2prediction_cnt.delete(label)
end

# Walk the labels in sorted order, compute P/R/F1 for each, print the row
# unless per-class output is disabled, and accumulate the totals used by
# the averages.
all_tp = 0
all_f1 = 0
res_tag2TP_cnt = tag2TP_cnt.sort_by { |label, _count| label }
res_tag2TP_cnt.each do |t, tp|
  next if @options[:excluded].include?(t)

  predicted_total = tag2prediction_cnt[t]
  outcome_total = tag2outcome_cnt[t]

  # a ratio with a zero denominator is reported as 0.0
  precision = predicted_total > 0.0 ? tp.to_f / predicted_total : 0.0
  recall = outcome_total > 0.0 ? tp.to_f / outcome_total : 0.0
  pr_sum = precision + recall
  f1 = pr_sum > 0.0 ? 2 * precision * recall / pr_sum : 0.0

  unless @options[:statistics]
    print_stat("Class  %-6s" % [t],
               precision, pretty_div(tp, predicted_total),
               recall, pretty_div(tp, outcome_total),
               f1, "")
  end

  all_precision += precision
  all_recall += recall
  all_tp += tp
  all_f1 += f1
end

# Total number of counted predictions and of counted outcomes.
all_rt = tag2prediction_cnt.values.inject(0) { |acc, count| acc + count }
all_gt = tag2outcome_cnt.values.inject(0) { |acc, count| acc + count }

# With a single label left, accuracy and both averages are switched off.
included_labels_count = res_tag2TP_cnt.size
if included_labels_count == 1
  [:accuracy, :micro_avg, :macro_avg].each { |key| @options[key] = true }
end

# Accuracy: true positives over counted predictions.
unless @options[:accuracy]
  # With no counted predictions the ratio would be 0/0 (NaN); report 0.0,
  # as the per-class statistics do for a zero denominator.
  accuracy = (all_rt > 0 ? all_tp.to_f / all_rt : 0.0)
  print_accuracy(accuracy, pretty_div(all_tp, all_rt))
end

# Micro average: P, R and F1 computed from the totals over all labels.
unless @options[:micro_avg]
  # Zero denominators are reported as 0.0, as the per-class statistics do.
  # Without the guards F1 is NaN whenever P + R == 0 (for example when no
  # prediction matches its outcome).
  micro_avg_precision = (all_rt > 0 ? all_tp.to_f / all_rt : 0.0)
  micro_avg_recall = (all_gt > 0 ? all_tp.to_f / all_gt : 0.0)
  micro_pr_sum = micro_avg_precision + micro_avg_recall
  micro_avg_f1 = (micro_pr_sum > 0.0 ? 2 * micro_avg_precision * micro_avg_recall / micro_pr_sum : 0.0)
  print_stat("Micro average",
             micro_avg_precision, pretty_div(all_tp, all_rt),
             micro_avg_recall, pretty_div(all_tp, all_gt),
             micro_avg_f1, "")
end

# Macro average: the per-label P, R and F1 sums divided by the number of
# labels; skipped when there are no labels or the output is disabled.
class_count = tag2TP_cnt.size
if class_count > 0 && !@options[:macro_avg]
  macro_avg_precision = all_precision / class_count
  macro_avg_recall = all_recall / class_count
  macro_avg_f1 = all_f1 / class_count
  print_stat("Macro average",
             macro_avg_precision, "",
             macro_avg_recall, "",
             macro_avg_f1, "")
end
