//
// Copyright 2004 Andras Varga
//
// This library is free software, you can redistribute it and/or modify
// it under  the terms of the GNU Lesser General Public License
// as published by the Free Software Foundation;
// either version 2 of the License, or any later version.
// The library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU Lesser General Public License for more details.
//

//#include <stdlib.h>

#include <math.h>
#include <fstream>

#include "TrafGen.h"
#include "TCPSocket.h"
#include "TCPCommand_m.h"
#include "UDPSocket.h"
#include "IPControlInfo.h"
#include "AnsaUDPControlInfo_m.h"
#include "TrafGenPacket_m.h"
#include "IPv4InterfaceData.h"
#include "InterfaceTableAccess.h"
#include "IInterfaceTable.h"



// Register the TrafGen class with the simulation kernel as a module type.
Define_Module(TrafGen);

void TrafGen::initialize(int stage)
{
  if (stage == 4)
  {      
    numSent = 0;
    numReceived = 0;
    counter = 0;
    
    
    
    const char *fileName = par("flowDefFile");
          
    if (fileName == NULL || (!strcmp(fileName, "")) || !LoadFlowsFromXML(fileName))
        error("Error reading TrafGen flows from file %s", fileName);
        
    detectRoles();
    initTimers();
    bindSockets();
    initStats();
    
    
    for(TG::Flows::iterator it =  flows.begin(); it != flows.end(); ++it)
    {
      std::cout << "getId " << it->getId() << std::endl;
      std::cout << "getStartTime " << it->getStartTime() << std::endl;
      std::cout << "getDuration " << it->getDuration() << std::endl;
      std::cout << "getSrcIP " << it->getSrcIP() << std::endl;
      std::cout << "getDstIP " << it->getDstIP() << std::endl;
      std::cout << "getTos " << (int) it->getTos() << std::endl;
      std::cout << "getTtl " << it->getTtl() << std::endl;
      std::cout << "getProtocol " << it->getProtocol() << std::endl;
      std::cout << "getSrcPort " << it->getSrcPort() << std::endl;
      std::cout << "getDstPort " << it->getDstPort() << std::endl;
      if(it->isAnalyzing())
        std::cout << "anal " << std::endl;
      if(it->isGenerating())
        std::cout << "gen " <<  std::endl;
    }
    
    WATCH(numSent);
    WATCH(numReceived);
    WATCH_VECTOR(sentStatistics);
    WATCH_VECTOR(receivedStatistics);
    
  }
  else if( stage == 5)
    installRcvStatIts(); 
}

bool TrafGen::LoadFlowsFromXML(const char * filename)
{
  cXMLElement* trafgenConfig = ev.getXMLDocument(filename);
  if (trafgenConfig == NULL)
    error("Cannot read TrafGen flows configuration from file: %s", filename);

  std::string nodeName = trafgenConfig->getTagName();

  if(nodeName == "trafgen")
  {  
    cXMLElementList flowsConfigList= trafgenConfig->getChildren();
    for (cXMLElementList::iterator flowIt = flowsConfigList.begin(); flowIt != flowsConfigList.end(); flowIt++)
    {
      std::string elName = (*flowIt)->getTagName();
      if (elName == "flow" && (*flowIt)->getAttribute("id"))
        parseFlow(*(*flowIt));
      else
        error("Wrong flow definition");
    }

  }
  else
    return false;
    
  if(flows.size() == 0)
    return false;
        
  return true;
}

void TrafGen::parseFlow(const cXMLElement& flowConfig)
{
  TG::FlowRecord flw;
  cXMLElement* element;
  
  flw.setId(flowConfig.getAttribute("id"));
  
  if(flw.getId() == "" || getFlowById(flw.getId()) != flows.end())
  {
    error("Missing or duplicate flow ID");
    return;
  }
  
  element = flowConfig.getFirstChildWithTag("duration");
  if(element != NULL)
    flw.setDuration(atof(element->getNodeValue()));
    
  element = flowConfig.getFirstChildWithTag("start_time");
  if(element != NULL)
    flw.setStartTime(atof(element->getNodeValue()));
    
  cXMLElement* ipHeader = flowConfig.getFirstChildWithTag("ip_header");
  if(ipHeader != NULL)
  {
    element = ipHeader->getFirstChildWithTag("source_ip");
    if(element != NULL)
      flw.setSrcIP(IPAddress(element->getNodeValue()));
    
    element = ipHeader->getFirstChildWithTag("destination_ip");
    if(element != NULL)
      flw.setDstIP(IPAddress(element->getNodeValue()));
    
    element = ipHeader->getFirstChildWithTag("tos");
    if(element != NULL)
      flw.setTos(readDscp(element->getNodeValue()));
      
    element = ipHeader->getFirstChildWithTag("ttl");
    if(element != NULL)
      flw.setTtl(atoi(element->getNodeValue()));
      
    element = ipHeader->getFirstChildWithTag("protocol");
    if(element != NULL)
    {
      std::string prtStr = element->getNodeValue();
      if (prtStr == "tcp" || prtStr == "Tcp" || prtStr == "TCP")
        flw.setProtocol(IP_PROT_TCP);
      else
        flw.setProtocol(IP_PROT_UDP);
    }
  }
  
  cXMLElement* transHeader = flowConfig.getFirstChildWithTag("transport_header");
  if(ipHeader != NULL)
  {
    element = transHeader->getFirstChildWithTag("source_port");
    if(element != NULL)
      flw.setSrcPort(atoi(element->getNodeValue()));
    
    element = transHeader->getFirstChildWithTag("destination_port");
    if(element != NULL)
      flw.setDstPort(atoi(element->getNodeValue()));
  }
  
  cXMLElement* appEl = flowConfig.getFirstChildWithTag("application");
  if(appEl != NULL && appEl->getAttribute("type"))
  {
    if(!flw.setApplication(*appEl))
      error("Wrong application configuration for flow with ID: %s", flw.getId().c_str());
  }
  
  if(!isValidFlow(flw))
  {
    error("Invalid definition of flow with ID: %s", flw.getId().c_str());
    return;
  }
    

  flows.push_back(flw);
}

// A flow is usable only when both IPv4 addresses are specified and the TTL
// and both ports are at least 1. Returns false as soon as one check fails.
bool TrafGen::isValidFlow(TG::FlowRecord &flw)
{
  const bool rejected =
      flw.getSrcIP().get4() == IPAddress::UNSPECIFIED_ADDRESS ||
      flw.getDstIP().get4() == IPAddress::UNSPECIFIED_ADDRESS ||
      flw.getTtl() < 1 ||
      flw.getSrcPort() < 1 ||
      flw.getDstPort() < 1;

  return !rejected;
}

// Translates the textual <tos> value of a flow into a number.
//
// Accepted symbolic forms (prefix in the spellings XX, Xx or xx):
//   CSn   with n in 1..7               -> n * 32
//   AFcd  with c in 1..4, d in 1..3    -> c * 32 + d * 8
//   EF                                 -> 184
// Everything else is converted with atoi(), so plain decimal numbers work and
// unparsable text yields 0. The result is narrowed to unsigned char.
unsigned char TrafGen::readDscp(std::string dscpString)
{
  if(dscpString.size() == 3)
  {
    const std::string prf = dscpString.substr(0,2);
    // A non-digit wraps around to a large unsigned value and fails the range test.
    const unsigned char cls = dscpString[2] - '0';
    if((prf == "CS" || prf == "Cs" || prf == "cs") && cls > 0 && cls < 8)
      return cls*32;
  }
  else if (dscpString.size() == 4)
  {
    const std::string prf = dscpString.substr(0,2);
    const unsigned char cls = dscpString[2] - '0';
    const unsigned char dp = dscpString[3] - '0';
    if((prf == "AF" || prf == "Af" || prf == "af") && cls > 0 && cls < 5 && dp > 0 && dp < 4)
      return cls*32 + dp*8;
  }
  else if(dscpString == "EF" || dscpString == "Ef" || dscpString == "ef")
    return 184;

  // Not a recognised symbolic name: fall back to a numeric value.
  return atoi(dscpString.c_str());
}

void TrafGen::detectRoles()
{
  IInterfaceTable *ift = InterfaceTableAccess().get();
   
  for(TG::Flows::iterator it =  flows.begin(); it != flows.end(); ++it)
  {
    for(int i = 0; i < ift->getNumInterfaces(); ++i)
    {
      if(ift->getInterface(i)->ipv4Data()->getIPAddress() == it->getSrcIP().get4())
        it->setGenerating(true);
      if(ift->getInterface(i)->ipv4Data()->getIPAddress() == it->getDstIP().get4())
        it->setAnalyzing(true);
    }
  }

}

void TrafGen::initTimers()
{
  for(TG::Flows::iterator it =  flows.begin(); it != flows.end(); ++it)
  {
    if(it->isGenerating() && it->getDuration() > 0.0)
    {
      cMessage *timer = new cMessage(it->getId().c_str());
      scheduleAt(it->getStartTime(), timer);
    }
  }
}

void TrafGen::initStats()
{
  for(TG::Flows::iterator it =  flows.begin(); it != flows.end(); ++it)
  {
    if(it->isAnalyzing())
    {
      TG::RcvFlowRecord rr;
      rr.setId(it->getId());
      rr.setStartTime(it->getStartTime());
      receivedStatistics.push_back(rr);
    }
    if(it->isGenerating())
    {
      TG::SntFlowRecord sr;
      sr.setId(it->getId());
      sr.setStartTime(it->getStartTime());
      sentStatistics.push_back(sr);
    }
  }
}

void TrafGen::bindSockets()
{
  for(TG::Flows::iterator it =  flows.begin(); it != flows.end(); ++it)
  {
    if(it->isAnalyzing())
    {
      if(it->getProtocol() == IP_PROT_UDP)
      {
        cMessage *msg = new cMessage("UDP_C_BIND", UDP_C_BIND);
        AnsaUDPControlInfo *ctrl = new AnsaUDPControlInfo();
        ctrl->setSockId(UDPSocket::generateSocketId());
        ctrl->setSrcPort(it->getDstPort());
        msg->setControlInfo(ctrl);
        send(msg, "udpOut");
      }
    }
  }
}

void TrafGen::installRcvStatIts()
{
  cTopology topology;
  
  topology.extractByNedTypeName(cStringTokenizer("inet.ansa.TrafficGenerator.TrafGen").asVector());
  
  for(TG::Flows::iterator it =  flows.begin(); it != flows.end(); ++it)
  {
    if(it->isGenerating())
    {
      bool found = false;
      for(int i = 0; i < topology.getNumNodes(); ++i)
      {
        cTopology::Node *node = topology.getNode(i);
        TrafGen *tgpt = check_and_cast<TrafGen*> (node->getModule());
        if(tgpt->getFlowById(it->getId())->isAnalyzing())
        {
          it->setRcvModStatsIt(tgpt->getRcvFlowStatById(it->getId()));
          found = true;
        }
      }
      if(!found)
        error("Missing receiver for flow with ID: %s", it->getId().c_str());
    }
  }
} 


// Creates the payload packet for one transmission of the given flow.
// The packet is named "TrafGenPacket - Flow <id>", carries the flow ID, the
// application name and the current simulation time as send timestamp, and its
// byte length is the application's packet size plus the application's
// additional encapsulation overhead. The caller owns the returned packet.
cPacket *TrafGen::createPacket(TG::Flows::iterator flowIt)
{
    // Build the name in a std::string: flow IDs come from the configuration
    // file and may be arbitrarily long, which would overflow a fixed buffer.
    std::string msgName("TrafGenPacket - Flow ");
    msgName += flowIt->getId().c_str();

    int totalDataLen = flowIt->getPApplication()->getPacketSize() + flowIt->getPApplication()->anotherEncapsulationOverhead();

    TrafGenPacket *payload = new TrafGenPacket(msgName.c_str());
    payload->setFlowId(flowIt->getId().c_str());
    payload->setApplication((flowIt->getAppName()).c_str());
    payload->setSentTime(simTime().dbl());
    payload->setByteLength(totalDataLen);
    return payload;
}

// Sends one packet of the given flow through the "udpOut" gate.
// The packet gets kind UDP_C_DATA and a control info filled from the flow's
// addresses, ports, TOS and TTL. Before sending, the local send statistics
// (packet count, application bytes), the receiver-side sent-packet counter
// and numSent are updated.
void TrafGen::sendPacket(TG::Flows::iterator flowIt)
{
    cPacket *packet = createPacket(flowIt);
    packet->setKind(UDP_C_DATA);

    // Control info describing the UDP/IP parameters of this flow.
    AnsaUDPControlInfo *udpInfo = new AnsaUDPControlInfo();
    udpInfo->setSrcAddr(flowIt->getSrcIP());
    udpInfo->setSrcPort(flowIt->getSrcPort());
    udpInfo->setDestAddr(flowIt->getDstIP());
    udpInfo->setDestPort(flowIt->getDstPort());
    udpInfo->setDiffServCodePoint(flowIt->getTos());
    udpInfo->setTimeToLive(flowIt->getTtl());
    packet->setControlInfo(udpInfo);

    // Statistics count application bytes only, without the extra overhead.
    int appDataLen = packet->getByteLength() - flowIt->getPApplication()->anotherEncapsulationOverhead();

    TG::SntStats::iterator sentStat = getSntFlowStatById(flowIt->getId());
    sentStat->addTotalSentPkts();
    sentStat->addTotalBytes(appDataLen);

    flowIt->getRcvModStatsIt()->addTotalSentPkts();
    numSent++;

    send(packet, "udpOut");
}

// Handles a message received from the transport layer and deletes it.
// UDP error indications are dropped. For a TrafGenPacket that belongs to a
// known flow with a receive-statistics record, the record is updated: end
// time, start time and initial min/max delay on the first packet, min/max
// delay, accumulated delay and jitter, application bytes and packet count.
// Jitter is the absolute difference between this packet's delay and the
// average delay of the packets received before it (0 for the first packet).
void TrafGen::processPacket(cPacket *msg)
{
  if (msg->getKind() == UDP_I_ERROR)
  {
    delete msg;
    return;
  }

  TrafGenPacket *tgPacket = check_and_cast<TrafGenPacket *>(msg);

  TG::Flows::iterator flow = getFlowById(tgPacket->getFlowId());
  if (flow == flows.end())
  {
    delete msg;
    return;
  }

  TG::RcvStats::iterator stat = getRcvFlowStatById(flow->getId());
  if (stat == receivedStatistics.end())
  {
    delete msg;
    return;
  }

  const double now = simTime().dbl();
  const double delay = now - tgPacket->getSentTime();
  const bool firstPacket = (stat->getTotalRcvPkts() == 0);

  stat->setEndTime(now);

  // Compare against the average of the earlier packets, before this packet
  // is added to the totals.
  const double jitter = firstPacket
      ? 0.0
      : fabs(delay - (stat->getTotalDelay()/stat->getTotalRcvPkts()));

  if (firstPacket)
  {
    stat->setStartTime(now);
    stat->setMinDelay(delay);
    stat->setMaxDelay(delay);
  }

  if (stat->getMinDelay() > delay)
    stat->setMinDelay(delay);
  if (stat->getMaxDelay() < delay)
    stat->setMaxDelay(delay);

  stat->addTotalDelay(delay);
  stat->addTotalJitter(jitter);

  stat->addTotalBytes(tgPacket->getByteLength() - flow->getPApplication()->anotherEncapsulationOverhead());
  stat->addTotalRcvPkts();
  numReceived++;

  delete msg;
}

// Message dispatcher. A self-message is a flow timer named after the flow ID:
// one packet is sent and the timer is re-armed with the application's next
// packet interval while the current time is before start time + duration,
// otherwise the timer is deleted. Any other message is passed to
// processPacket() as an incoming packet.
void TrafGen::handleMessage(cMessage *msg)
{
    if (!msg->isSelfMessage())
    {
        processPacket(PK(msg));
        return;
    }

    TG::Flows::iterator flow = getFlowById(msg->getName());
    sendPacket(flow);

    // Flow finished: the timer is no longer needed.
    if (!(simTime().dbl() < flow->getStartTime() + flow->getDuration()))
    {
        delete msg;
        return;
    }

    scheduleAt(simTime()+ flow->getPApplication()->getNextPacketTime(), msg);
}

void TrafGen::finish()
{
  if(receivedStatistics.size() > 0)
  {
    std::ofstream outFile;
    std::string moduleName = this->getParentModule()->getName();
    std::string filename = "results/TrafGen-" + moduleName + ".txt";
    outFile.open (filename.c_str());
    
    int numFlows = 0;
    double tStartTime = 0.0;
    double tEndTime = 0.0;
    double tMinDelay = 0.0;
    double tMaxDelay = 0.0;
    double tTotDelay = 0.0;
    double tTotJitter = 0.0;
    long tReceived = 0;
    long tBytes = 0;
    long tDropped = 0;
    
    for(TG::RcvStats::iterator it =  receivedStatistics.begin(); it != receivedStatistics.end(); ++it)
    {
      TG::Flows::iterator fit = getFlowById(it->getId()); 
      outFile << "----------------------------------------------------------" << std::endl;
      outFile << "Flow ID: " << it->getId() << std::endl;
      outFile << "From:\t" << fit->getSrcIP().str() << ":" << fit->getSrcPort() << std::endl;
      outFile << "To:\t" << fit->getDstIP().str() << ":" << fit->getDstPort() << std::endl;
      outFile << "----------------------------------------------------------" << std::endl;
      if(it->getTotalRcvPkts() > 0)
      { 
        if(numFlows == 0)
        {
          tStartTime = it->getStartTime();
          tEndTime = it->getEndTime();
          tMinDelay = it->getMinDelay();
          tMaxDelay = it->getMaxDelay();
        }
        else
        {
          if(tStartTime > it->getStartTime())
            tStartTime = it->getStartTime();
          if(tEndTime < it->getEndTime())
            tEndTime = it->getEndTime();
          if(tMinDelay > it->getMinDelay())
            tMinDelay = it->getMinDelay();
          if(tMaxDelay < it->getMaxDelay())
            tMaxDelay = it->getMaxDelay();
        }
        tTotDelay += it->getTotalDelay();
        tTotJitter += it->getTotalJitter();
        tReceived += it->getTotalRcvPkts();
        tBytes += it->getTotalBytes();
        tDropped += it->getTotalSentPkts() - it->getTotalRcvPkts();
        ++numFlows;
        
        double duration = it->getEndTime() - it->getStartTime();
        outFile << "Total time               = " << duration << " s" << std::endl; 
        outFile << "Total packets            = " << it->getTotalRcvPkts() << std::endl;
        outFile << "Minimum delay            = " << it->getMinDelay() << " s" << std::endl;
        outFile << "Maximum delay            = " << it->getMaxDelay() << " s" << std::endl;
        outFile << "Average delay            = " << it->getTotalDelay() / it->getTotalRcvPkts() << " s" << std::endl;
        outFile << "Average jitter           = " << it->getTotalJitter() / it->getTotalRcvPkts() << " s" << std::endl;
        outFile << "Bytes received           = " << it->getTotalBytes() << std::endl;
        outFile << "Average bitrate          = " << (it->getTotalBytes() * 8)/(duration * 1000) << " Kbit/s" << std::endl;
        outFile << "Average packet rate      = " << it->getTotalRcvPkts()/duration << " pkt/s" << std::endl;
        outFile << "Packets dropped          = " << it->getTotalSentPkts() - it->getTotalRcvPkts() << std::endl;
      }
      else
      {
        outFile << "No received data yet" << std::endl;
      }
      
      outFile << "----------------------------------------------------------" << std::endl;
      
    }
    outFile << std::endl;   
    outFile << "__________________________________________________________" << std::endl;
    outFile << "****************  TOTAL RESULTS   ******************" << std::endl;
    outFile << "__________________________________________________________" << std::endl;

    double duration = tEndTime - tStartTime;
    outFile << "Number of flows          = " << numFlows << std::endl;
    outFile << "Total time               = " << duration << " s" << std::endl; 
    outFile << "Total packets            = " << tReceived << std::endl;
    outFile << "Minimum delay            = " << tMinDelay << " s" << std::endl;
    outFile << "Maximum delay            = " << tMaxDelay << " s" << std::endl;
    outFile << "Average delay            = " << tTotDelay / tReceived << " s" << std::endl;
    outFile << "Average jitter           = " << tTotJitter / tReceived << " s" << std::endl;
    outFile << "Bytes received           = " << tBytes << std::endl;
    outFile << "Average bitrate          = " << (tBytes * 8)/(duration * 1000) << " Kbit/s" << std::endl;
    outFile << "Average packet rate      = " << tReceived/duration << " pkt/s" << std::endl;
    outFile << "Packets dropped          = " << tDropped  << std::endl;
    
    outFile << "----------------------------------------------------------" << std::endl;
    
    outFile.close();
  }
}

// Linear search for the flow with the given ID.
// Returns an iterator to the first match, or flows.end() when there is none.
TG::Flows::iterator TrafGen::getFlowById(std::string s_id)
{
  TG::Flows::iterator pos = flows.begin();
  while (pos != flows.end() && !(pos->getId() == s_id))
    ++pos;
  return pos;
}

// Linear search for the send-statistics record of the flow with the given ID.
// Returns an iterator to the first match, or sentStatistics.end() when there
// is none.
TG::SntStats::iterator TrafGen::getSntFlowStatById(std::string s_id)
{
  TG::SntStats::iterator pos = sentStatistics.begin();
  while (pos != sentStatistics.end() && !(pos->getId() == s_id))
    ++pos;
  return pos;
}

// Linear search for the receive-statistics record of the flow with the given
// ID. Returns an iterator to the first match, or receivedStatistics.end()
// when there is none.
TG::RcvStats::iterator TrafGen::getRcvFlowStatById(std::string r_id)
{
  TG::RcvStats::iterator pos = receivedStatistics.begin();
  while (pos != receivedStatistics.end() && !(pos->getId() == r_id))
    ++pos;
  return pos;
}


// Creates an empty flow description: no ID, zero timing values, unspecified
// addresses, zeroed header fields, no application and no role assigned.
TG::FlowRecord::FlowRecord()
{
  // identification and timing
  id.clear();
  duration = 0.0;
  startTime = 0.0;

  // IP header fields
  dstIP = IPAddress::UNSPECIFIED_ADDRESS;
  srcIP = IPAddress::UNSPECIFIED_ADDRESS;
  protocol = 0;
  ttl = 0;
  tos = 0;

  // transport header fields
  dstPort = 0;
  srcPort = 0;

  // traffic model and roles
  pApplication = NULL;
  analyzing = false;
  generating = false;
}


// Instantiates the traffic application named by the "type" attribute of the
// <application> element and lets it load its settings from that element.
// Returns the result of the application's loadConfig(), or false when the
// "type" attribute is missing or empty or no application object is available.
bool TG::FlowRecord::setApplication(const cXMLElement& appConfig)
{
  // getAttribute() yields NULL for a missing attribute; building a
  // std::string from NULL would be undefined behaviour.
  const char *type = appConfig.getAttribute("type");
  if(type == NULL || *type == '\0')
    return false;

  appName = type;

  pApplication = check_and_cast<ITrafGenApplication *>(createOne(appName.c_str()));

  if(pApplication != NULL)
    return pApplication->loadConfig(appConfig);

  return false;
}

// Creates a receive-statistics record with all counters and accumulators
// zeroed.
// NOTE(review): no end time is initialised here although callers use
// setEndTime()/getEndTime() on these records - confirm it is set elsewhere.
TG::RcvFlowRecord::RcvFlowRecord()
{
  startTime = 0.0;

  // packet and byte counters
  totalRcvPkts = 0;
  totalSentPkts = 0;
  totalBytes = 0.0;

  // delay statistics
  maxDelay = 0.0;
  minDelay = 0.0;
  totalJitter = 0.0;
  totalDelay = 0.0;
}



