// Maria union expression class -*- c++ -*-

#include "snprintf.h"

#ifdef __GNUC__
# pragma implementation
#endif // __GNUC__
#include "UnionExpression.h"
#include "UnionType.h"
#include "UnionValue.h"
#include "Printer.h"

/** @file UnionExpression.C
 * Union value constructor
 */

/* Copyright © 1999-2002 Marko Mäkelä (msmakela@tcs.hut.fi).

   This file is part of MARIA, a reachability analyzer and model checker
   for high-level Petri nets.

   MARIA is free software; you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   MARIA is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   General Public License for more details.

   The GNU General Public License is often shipped with GNU software, and
   is generally kept in a file called COPYING or LICENSE.  If you do not
   have a copy of the license, write to the Free Software Foundation,
   59 Temple Place, Suite 330, Boston, MA 02111 USA. */

/** Construct a union constructor expression.
 * @param type	the union type being constructed (must be Type::tUnion)
 * @param expr	the basic expression supplying the component value;
 *		this object keeps a pointer to it and destroys it later
 * @param i	index of the union component selected by this constructor
 */
UnionExpression::UnionExpression (const class Type& type,
				  class Expression& expr,
				  card_t i) :
  Expression (),
  myUnionType (&type),
  myExpr (&expr),
  myIndex (i)
{
  // Sanity checks: the type is a union, the component value fits
  // component i of that union, and the component is a basic expression.
  assert (type.getKind () == Type::tUnion);
  assert (expr.getType ()->isAssignable
	  (static_cast<const class UnionType&>(type)[i]));
  assert (expr.isBasic ());
  Expression::setType (type);
}

/** Destructor: releases the component expression passed to the constructor */
UnionExpression::~UnionExpression ()
{
  class Expression* const component = myExpr;
  component->destroy ();
}

/** Evaluate the union constructor.
 * @param valuation	variable bindings used to evaluate the component
 * @return		a new UnionValue wrapping the component value,
 *			or NULL if evaluating the component failed
 */
class Value*
UnionExpression::do_eval (const class Valuation& valuation) const
{
  if (class Value* const component = myExpr->eval (valuation))
    return new class UnionValue (*myUnionType, myIndex, *component);
  // component evaluation failed; propagate the failure
  return NULL;
}

/** Ground the component expression under a valuation.
 * @param valuation	variable bindings
 * @param transition	forwarded to the component's ground ()
 * @param declare	forwarded to the component's ground ()
 * @return		NULL if grounding the component failed; a copy of this
 *			expression if the component was unchanged; otherwise a
 *			new union expression over the grounded component, itself
 *			grounded again
 */
class Expression*
UnionExpression::ground (const class Valuation& valuation,
			 class Transition* transition,
			 bool declare)
{
  class Expression* const grounded =
    myExpr->ground (valuation, transition, declare);
  if (!grounded)
    return NULL;
  assert (valuation.isOK ());

  if (grounded != myExpr) {
    // Call through an Expression* so that the single-argument ground ()
    // is looked up in the base class (as the original static_cast did).
    class Expression* const rebuilt =
      new class UnionExpression (*myUnionType, *grounded, myIndex);
    return rebuilt->ground (valuation);
  }

  // The component grounded to itself: release the returned pointer and
  // hand back a copy of this expression instead.
  grounded->destroy ();
  return copy ();
}

/** Apply a substitution to the component expression.
 * @param substitution	the substitution to apply
 * @return		a copy of this expression if the component was
 *			unchanged; otherwise cse () of a new union expression
 *			over the substituted component
 */
class Expression*
UnionExpression::substitute (class Substitution& substitution)
{
  class Expression* const replaced = myExpr->substitute (substitution);
  if (replaced != myExpr) {
    class UnionExpression* const rebuilt =
      new class UnionExpression (*myUnionType, *replaced, myIndex);
    return rebuilt->cse ();
  }
  replaced->destroy ();
  return copy ();
}

bool
UnionExpression::depends (const class VariableSet& vars,
			  bool complement) const
{
  return myExpr->depends (vars, complement);
}

bool
UnionExpression::forVariables (bool (*operation)
			       (const class Expression&,void*),
			       void* data) const
{
  return myExpr->forVariables (operation, data);
}

/** Determine whether a union value could be produced by this expression.
 * @param value		a union value
 * @param valuation	variable bindings
 * @return		false if the value selects a different component;
 *			otherwise the component expression's verdict on the
 *			contained value
 */
bool
UnionExpression::isCompatible (const class Value& value,
			       const class Valuation& valuation) const
{
  assert (value.getKind () == Value::vUnion &&
	  value.getType ().getKind () == Type::tUnion);

  const class UnionValue& unionValue =
    static_cast<const class UnionValue&>(value);

  if (unionValue.getIndex () != myIndex)
    return false;
  return myExpr->isCompatible (unionValue.getValue (), valuation);
}

/** Bind the variables of the component expression from a union value.
 * Nothing is bound when the value selects a different component.
 * @param value		a union value of this expression's type
 * @param valuation	the valuation to be extended
 * @param vars		forwarded unchanged to the component
 */
void
UnionExpression::getLvalues (const class Value& value,
			     class Valuation& valuation,
			     const class VariableSet& vars) const
{
  assert (value.getKind () == Value::vUnion);
  assert (&value.getType () == getType ());
  const class UnionValue& unionValue =
    static_cast<const class UnionValue&>(value);
  if (unionValue.getIndex () != myIndex)
    return;
  myExpr->getLvalues (unionValue.getValue (), valuation, vars);
}

void
UnionExpression::getLvalues (const class VariableSet& rvalues,
			     class VariableSet*& lvalues) const
{
  myExpr->getLvalues (rvalues, lvalues);
}

#ifdef EXPR_COMPILE
# include "CExpression.h"
# include <stdio.h>

/** Generate C code that binds the component's variables.
 * The component is addressed in the generated code as
 * "<lvalue>.u.u<index>".
 * @param cexpr		the compilation output
 * @param indent	indentation level, forwarded unchanged
 * @param vars		forwarded unchanged to the component
 * @param lvalue	C expression denoting the union value (non-NULL)
 */
void
UnionExpression::compileLvalue (class CExpression& cexpr,
				unsigned indent,
				const class VariableSet& vars,
				const char* lvalue) const
{
  assert (!!lvalue);
  const size_t prefixLength = strlen (lvalue);
  // room for ".u.u", the decimal component index and the terminator
  const size_t suffixRoom = 25;
  char* const member = new char[prefixLength + suffixRoom];
  memcpy (member, lvalue, prefixLength);
  snprintf (member + prefixLength, suffixRoom, ".u.u%u", myIndex);

  myExpr->compileLvalue (cexpr, indent, vars, member);
  delete[] member;
}

/** Generate C code that checks whether a value matches this expression.
 * Emits "if (<value>.t!=<index>)" followed by an errComp error, then lets
 * the component check "<value>.u.u<index>".
 * @param cexpr		the compilation output
 * @param indent	indentation level
 * @param vars		forwarded unchanged to the component
 * @param value		C expression denoting the union value
 */
void
UnionExpression::compileCompatible (class CExpression& cexpr,
				    unsigned indent,
				    const class VariableSet& vars,
				    const char* value) const
{
  const size_t prefixLength = strlen (value);
  // room for ".u.u", the decimal component index and the terminator
  const size_t suffixRoom = 25;
  char* const member = new char[prefixLength + suffixRoom];
  char* const suffix = member + prefixLength;
  memcpy (member, value, prefixLength);
  snprintf (suffix, suffixRoom, ".u.u%u", myIndex);
  // skip ".u.u" to reach the decimal index alone
  const char* const tag = suffix + 4;

  class StringBuffer& out = cexpr.getOut ();
  out.indent (indent);
  out.append ("if (");
  out.append (value);
  out.append (".t!=");
  out.append (tag);
  out.append (")\n");
  cexpr.compileError (indent + 2, errComp);

  myExpr->compileCompatible (cexpr, indent, vars, member);
  delete[] member;
}

/** Generate C code that constructs the union value.
 * Emits "<lvalue>.t=<index>;", compiles the component into
 * "<lvalue>.u.u<index>", and finally emits the type constraint check.
 * @param cexpr		the compilation output
 * @param indent	indentation level
 * @param lvalue	C expression denoting the union value to assign
 * @param vars		forwarded unchanged to the component
 */
void
UnionExpression::compile (class CExpression& cexpr,
			  unsigned indent,
			  const char* lvalue,
			  const class VariableSet* vars) const
{
  const size_t prefixLength = strlen (lvalue);
  // room for ".u.u", the decimal component index and the terminator
  const size_t suffixRoom = 25;
  char* const member = new char[prefixLength + suffixRoom];
  char* const suffix = member + prefixLength;
  memcpy (member, lvalue, prefixLength);
  snprintf (suffix, suffixRoom, ".u.u%u", myIndex);
  // skip ".u.u" to reach the decimal index alone
  const char* const tag = suffix + 4;

  // select the component: <lvalue>.t=<index>;
  class StringBuffer& out = cexpr.getOut ();
  out.indent (indent);
  out.append (lvalue);
  out.append (".t=");
  out.append (tag);
  out.append (";\n");

  myExpr->compile (cexpr, indent, member, vars);
  delete[] member;
  compileConstraint (cexpr, indent, lvalue);
}

#endif // EXPR_COMPILE

void
UnionExpression::display (const class Printer& printer) const
{
  printer.print (static_cast<const class UnionType*>(myUnionType)
		 ->getComponentName (myIndex));
  printer.delimiter ('=');
  myExpr->display (printer);
}
