// Maria vector expression class -*- c++ -*-

#include "snprintf.h"

#ifdef __GNUC__
# pragma implementation
#endif // __GNUC__
#include "VectorExpression.h"
#include "VectorValue.h"
#include "Printer.h"
#include <string.h>

/** @file VectorExpression.C
 * Vector constructor
 */

/* Copyright © 1999-2002 Marko Mäkelä (msmakela@tcs.hut.fi).

   This file is part of MARIA, a reachability analyzer and model checker
   for high-level Petri nets.

   MARIA is free software; you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   MARIA is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   General Public License for more details.

   The GNU General Public License is often shipped with GNU software, and
   is generally kept in a file called COPYING or LICENSE.  If you do not
   have a copy of the license, write to the Free Software Foundation,
   59 Temple Place, Suite 330, Boston, MA 02111 USA. */

/** Construct a vector expression with all component slots empty.
 * The caller is expected to fill in every component afterwards.
 * @param type	the vector type; must have at least one component
 */
VectorExpression::VectorExpression (const class VectorType& type) :
  Expression (),
  myComponents (NULL)
{
  const card_t numComponents = type.getSize ();
  assert (numComponents > 0);
  // the trailing () value-initializes the array: every slot starts as NULL
  myComponents = new class Expression*[numComponents] ();
  // all slots are still NULL here, so this only records the type
  setType (type);
}

/** Release every component expression and the component array.
 * A slot may still be NULL if the expression was abandoned before all
 * components were assigned (for instance, ground () fills the slots from
 * the highest index downwards and discards the partially built expression
 * when a component fails to ground); such slots are skipped.
 */
VectorExpression::~VectorExpression ()
{
  for (card_t i = static_cast<const class VectorType*>(getType ())->getSize ();
       i--; )
    if (myComponents[i])
      myComponents[i]->destroy ();
  delete[] myComponents;
}

void
VectorExpression::setType (const class Type& type)
{
  assert (type.getKind () == Type::tVector);
  const class Type& itemType =
    static_cast<const class VectorType&>(type).getItemType ();

  for (card_t i = static_cast<const class VectorType&>(type).getSize ();
       i--; )
    if (myComponents[i])
      myComponents[i]->setType (itemType);

  Expression::setType (type);
}

class Value*
VectorExpression::do_eval (const class Valuation& valuation) const
{
  class VectorValue* vector = new class VectorValue (*getType ());

  for (card_t i = static_cast<const class VectorType*>(getType ())->getSize ();
       i--; ) {
    assert (!!myComponents[i]);
    if (class Value* v = myComponents[i]->eval (valuation)) {
      assert (&v->getType () ==
	      &static_cast<const class VectorType*>
	      (getType ())->getItemType ());
      (*vector)[i] = v;
    }
    else {
      delete vector;
      return NULL;
    }
  }

  return constrain (valuation, vector);
}

class Expression*
VectorExpression::ground (const class Valuation& valuation,
			  class Transition* transition,
			  bool declare)
{
  bool same = true;
  const class VectorType* const type =
    static_cast<const class VectorType*>(getType ());
  class VectorExpression* expr = new class VectorExpression (*type);

  for (card_t i = static_cast<const class VectorType*>(getType ())->getSize ();
       i--; ) {
    assert (!!myComponents[i]);
    class Expression* e =
      myComponents[i]->ground (valuation, transition, declare);
    if (!e) {
      expr->destroy ();
      return NULL;
    }

    assert (valuation.isOK ());

    if (e != myComponents[i])
      same = false;
    expr->myComponents[i] = e;
  }

  if (same) {
    expr->destroy ();
    return copy ();
  }

  return static_cast<class Expression*>(expr)->ground (valuation);
}

/** Apply a substitution to every component.
 * @param substitution	the substitution to apply
 * @return		a copy () of this expression if no component changed,
 *			otherwise the result of cse () on the rebuilt vector
 */
class Expression*
VectorExpression::substitute (class Substitution& substitution)
{
  const class VectorType* const type =
    static_cast<const class VectorType*>(getType ());
  class VectorExpression* const result = new class VectorExpression (*type);
  bool unchanged = true;

  for (card_t i = type->getSize (); i > 0; ) {
    --i;
    assert (!!myComponents[i]);
    class Expression* const s = myComponents[i]->substitute (substitution);
    unchanged = unchanged && s == myComponents[i];
    result->myComponents[i] = s;
  }

  if (!unchanged)
    return result->cse ();

  result->destroy ();
  return copy ();
}

/** Determine whether any component depends on the given variable set.
 * @param vars		the variable set to test against
 * @param complement	passed through unchanged to the components
 *			(presumably selects the complement of vars; confirm)
 * @return		true as soon as one component reports a dependency
 */
bool
VectorExpression::depends (const class VariableSet& vars,
			   bool complement) const
{
  card_t i = static_cast<const class VectorType*>(getType ())->getSize ();
  while (i--) {
    assert (!!myComponents[i]);
    if (myComponents[i]->depends (vars, complement))
      return true;
  }
  return false;
}

/** Apply an operation to this expression and then to each component
 * subtree, stopping at the first call that returns false.
 * @param operation	the operation to apply
 * @param data		opaque data handed to every call of operation
 * @return		false if some call returned false, true otherwise
 */
bool
VectorExpression::forExpressions (bool (*operation)
				  (const class Expression&,void*),
				  void* data) const
{
  bool ok = (*operation) (*this, data);
  // the components are visited from the last index down to the first
  for (card_t i = static_cast<const class VectorType*>(getType ())->getSize ();
       ok && i--; ) {
    assert (!!myComponents[i]);
    ok = myComponents[i]->forExpressions (operation, data);
  }
  return ok;
}

/** Check whether a vector value could be produced by this expression.
 * @param value		the value to check; must be a vector value
 * @param valuation	passed through unchanged to the components
 * @return		false if the value's type is not assignable to this
 *			expression's type or any component is incompatible
 */
bool
VectorExpression::isCompatible (const class Value& value,
				const class Valuation& valuation) const
{
  assert (value.getKind () == Value::vVector &&
	  value.getType ().getKind () == Type::tVector);

  const class VectorValue& actual =
    static_cast<const class VectorValue&>(value);

  assert (actual.getSize () ==
	  static_cast<const class VectorType&>(actual.getType ()).getSize ());

  if (!actual.getType ().isAssignable (*getType ()))
    return false;

  card_t i = static_cast<const class VectorType*>(getType ())->getSize ();
  while (i--)
    if (!myComponents[i]->isCompatible (actual[i], valuation))
      return false;

  return true;
}

/** Pass each component of a vector value to the matching component
 * expression's getLvalues ().
 * @param value		a vector value of exactly this expression's type
 * @param valuation	passed through unchanged to the components
 * @param vars		passed through unchanged to the components
 */
void
VectorExpression::getLvalues (const class Value& value,
			      class Valuation& valuation,
			      const class VariableSet& vars) const
{
  assert (value.getKind () == Value::vVector);
  assert (&value.getType () == getType ());
  const class VectorValue& parts =
    static_cast<const class VectorValue&>(value);

  card_t i = static_cast<const class VectorType*>(getType ())->getSize ();
  while (i--)
    myComponents[i]->getLvalues (parts[i], valuation, vars);
}

/** Call getLvalues () on every component, from the last to the first.
 * @param rvalues	passed through unchanged to the components
 * @param lvalues	passed through unchanged to the components
 */
void
VectorExpression::getLvalues (const class VariableSet& rvalues,
			      class VariableSet*& lvalues) const
{
  card_t i = static_cast<const class VectorType*>(getType ())->getSize ();
  while (i--)
    myComponents[i]->getLvalues (rvalues, lvalues);
}

#ifdef EXPR_COMPILE
# include "CExpression.h"
# include <stdio.h>

void
VectorExpression::compileLvalue (class CExpression& cexpr,
				 unsigned indent,
				 const class VariableSet& vars,
				 const char* lvalue) const
{
  assert (!!lvalue);
  const size_t len = strlen (lvalue);
  char* const newlvalue = new char[len + 25];
  char* const suffix = newlvalue + len;
  memcpy (newlvalue, lvalue, len);

  const card_t size =
    static_cast<const class VectorType*>(getType ())->getSize ();
  for (card_t i = 0; i < size; i++) {
    snprintf (suffix, 25, ".a[%u]", i);
    myComponents[i]->compileLvalue (cexpr, indent, vars, newlvalue);
  }
  delete[] newlvalue;
}

void
VectorExpression::compileCompatible (class CExpression& cexpr,
				     unsigned indent,
				     const class VariableSet& vars,
				     const char* value) const
{
  const size_t len = strlen (value);
  char* const val = new char[len + 25];
  char* const suffix = val + len;
  memcpy (val, value, len);

  const card_t size =
    static_cast<const class VectorType*>(getType ())->getSize ();
  for (card_t i = 0; i < size; i++) {
    snprintf (suffix, 25, ".a[%u]", i);
    myComponents[i]->compileCompatible (cexpr, indent, vars, val);
  }
  delete[] val;
}

void
VectorExpression::compile (class CExpression& cexpr,
			   unsigned indent,
			   const char* lvalue,
			   const class VariableSet* vars) const
{
  size_t len = strlen (lvalue);
  char* const newlvalue = new char[len + 25];
  char* const suffix = newlvalue + len;
  memcpy (newlvalue, lvalue, len);

  for (card_t i = 0;
       i < static_cast<const class VectorType*>(getType ())->getSize ();
       i++) {
    snprintf (suffix, 25, ".a[%u]", i);
    myComponents[i]->compile (cexpr, indent, newlvalue, vars);
  }
  delete[] newlvalue;
  compileConstraint (cexpr, indent, lvalue);
}

#endif // EXPR_COMPILE

void
VectorExpression::display (const class Printer& printer) const
{
  const class VectorType* type =
    static_cast<const class VectorType*>(getType ());
  printer.delimiter ('{')++;
  for (card_t i = 0;;) {
    assert (!!myComponents[i]);
    myComponents[i]->display (printer);
    if (++i == type->getSize ())
      break;
    printer.delimiter (',');
  }
  --printer.delimiter ('}');
}
