/*
 * session.C: Implementation of a snmpsession
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later 
 * version.
 * 
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 59 Temple Place - Suite 330,
 * Boston, MA 02111-1307, USA.
 *
 *
 * See the AUTHORS file for a list of people who have hacked on 
 * this code. 
 * See the ChangeLog file for a list of changes.
 *
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/stat.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>
#include <netdb.h>
#include <unistd.h>
#include <errno.h>
#include <ctype.h>

#include <queue>

#include "snmpkit"

#include "ber.h"
#include "oidseq.h"
#include "snmpsock.h"

#define SESSION_DEBUGSNMP_FLAG 0x1ul
#define SESSION_DISABLED_FLAG  0x2ul
#define MAXTHREADS 100

// Shared console state for SNMP_session::printstr() and end(); in this file
// it is only read or written while lastprint_m is held.
char SNMP_session::need_newline; 
SNMP_session *SNMP_session::lastprint;
// Statically initialised like joiner_m below: POSIX does not promise that a
// zero-filled pthread_mutex_t is a usable mutex.
pthread_mutex_t SNMP_session::lastprint_m=PTHREAD_MUTEX_INITIALIZER;

// Hand-off from finished worker threads to the joiner thread: each worker
// pushes its own id onto tojoin under joiner_m and signals joiner_cv.
pthread_mutex_t joiner_m=PTHREAD_MUTEX_INITIALIZER;
pthread_cond_t joiner_cv=PTHREAD_COND_INITIALIZER;
std::queue<pthread_t> tojoin;

/*
 * Counting limiter for session worker threads, plus a "finished" flag.
 *
 * inc() reserves a slot, blocking while `limit` slots are reserved, and
 * clears the finished flag. dec() releases a slot and wakes one blocked
 * inc(). set_done() sets the finished flag. done() is true when no slots are
 * reserved and set_done() has been called since the last inc().
 */
class inuse_t{
  pthread_mutex_t inuse_m;
  pthread_cond_t inuse_cv;
  unsigned int inuse;    // slots currently reserved
  unsigned int limit;    // maximum simultaneous reservations
  int finished;          // set_done() called since the last inc()

  // Owns a mutex and a condition variable: copying would duplicate them.
  // Declared but never defined (pre-C++11 non-copyable idiom).
  inuse_t(const inuse_t&);
  inuse_t &operator=(const inuse_t&);
public:
  explicit inuse_t(unsigned int lim):inuse(0),limit(lim),finished(0){
    pthread_mutex_init(&inuse_m,NULL);
    pthread_cond_init(&inuse_cv,NULL);
  }
  void inc(){
    pthread_mutex_lock(&inuse_m);
    // Re-test after every wakeup: pthread_cond_wait may return spuriously,
    // and another inc() may grab the freed slot before this one runs.
    while(inuse>=limit)
      pthread_cond_wait(&inuse_cv,&inuse_m);
    inuse++;
    finished = 0;
    pthread_mutex_unlock(&inuse_m);
  }
  void dec(){
    pthread_mutex_lock(&inuse_m);
    inuse--;
    pthread_mutex_unlock(&inuse_m);
    pthread_cond_signal(&inuse_cv);
  }
  int done(){
    pthread_mutex_lock(&inuse_m);
    int state=!inuse && finished;
    pthread_mutex_unlock(&inuse_m);
    return state;
  }
  void set_done(){
    pthread_mutex_lock(&inuse_m);
    finished=1;
    pthread_mutex_unlock(&inuse_m);
  }
};

// Process-wide limiter on running session worker threads (at most
// MAXTHREADS). SNMP_session's constructor reserves a slot with inc(); the
// joiner thread releases it with dec() when it reaps the worker.
inuse_t inuse(MAXTHREADS);
// Handle of the joiner thread; set by start_joiner() when it sets
// joiner_running, and joined by SNMP_sessions_done().
pthread_t joiner_th;

// Start-up parameters handed to a new worker thread by SNMP_session's
// constructor. SNMP_run_session() reads them and deletes the struct.
struct run_session_t{
  SNMP_session *session;          // session passed to the user routine
  void *(*fp)(SNMP_session*);     // user routine run by the worker thread
};

void *SNMP_run_session(void *rs){
  run_session_t *data=reinterpret_cast<run_session_t*>(rs);
  void *retval=(data->fp)(data->session);
  delete data;

  pthread_mutex_lock(&joiner_m);
  tojoin.push(pthread_self());
  pthread_mutex_unlock(&joiner_m);

  pthread_cond_signal(&joiner_cv);
  return retval;
}

// joiner_running is 1 while joiner_th refers to a joiner thread created by
// start_joiner() and not yet joined by SNMP_sessions_done(); guarded by
// joiner_running_m.
pthread_mutex_t joiner_running_m = PTHREAD_MUTEX_INITIALIZER;
int joiner_running = 0;

int SNMP_sessions_done(){
  int *retval = 0;
  inuse.set_done();
  pthread_join(joiner_th,reinterpret_cast<void**>(&retval));
  pthread_mutex_lock(&joiner_running_m);
  joiner_running = 0;
  pthread_mutex_unlock(&joiner_running_m);
  int i=0;
  if(retval){
    i=*retval;
    delete retval;
  }
  return i;
}

// NOTE(review): MAXPACKSIZE is not referenced in this file; presumably it is
// meant for the socket layer -- confirm before changing it.
#define MAXPACKSIZE 10240

// Socket shared by all sessions; created by the first SNMP_session
// constructor from the timeout/retries/port values below.
SNMP_socket *sock=NULL;
// Arguments for that SNMP_socket; set them with set_snmpsock_props()
// before the first session is constructed.
int timeout=10;
int retries=5;
int port=0;


/*
 * Body of the joiner thread.
 *
 * Joins every worker thread that SNMP_run_session() queues on tojoin and
 * releases its inuse slot. When a worker returned a non-NULL pointer, the
 * int it points to becomes the current result and the pointer is deleted.
 * The thread exits once inuse.done() is true. It returns a new int holding
 * the last collected result (0 if none), which SNMP_sessions_done() deletes.
 */
void *joiner(void *){
  int *retval=new int;
  *retval=0;

  pthread_mutex_lock(&joiner_m);
  for(;;){
    // Drain the queue before deciding whether to sleep. A worker that
    // finished while we were busy has already signalled; waiting without
    // looking first would lose that wakeup.
    while(!tojoin.empty()){
      pthread_t dead=tojoin.front();
      tojoin.pop();
      inuse.dec();
      int *th_retval=NULL;
      pthread_join(dead,reinterpret_cast<void**>(&th_retval));
      if(th_retval!=NULL){
	*retval=*th_retval;
	delete th_retval;
      }
    }
    if(inuse.done())
      break;
    // Workers push onto tojoin under joiner_m and signal afterwards.
    // joiner_m stays held from the emptiness check until
    // pthread_cond_wait releases it, so no completion slips by.
    // Spurious wakeups just re-run the loop.
    pthread_cond_wait(&joiner_cv,&joiner_m);
  }
  pthread_mutex_unlock(&joiner_m);
  return retval;
}

/*
 * Record the timeout (t), retry count (r) and port (p) used to create the
 * shared SNMP_socket. In this file they are read only when the first
 * SNMP_session constructor creates that socket, so call this beforehand.
 */
void set_snmpsock_props(int t,int r,int p){
  port=p;
  retries=r;
  timeout=t;
}

/*
 * Send one SNMP request and return the variable bindings of its response.
 *
 * tag is the request PDU tag; oids supplies the variable bindings. The
 * packet goes to he->h_addr_list[ipidx] through the shared socket. While
 * sock->call() returns NULL, the host's next address is tried. Once every
 * address has failed, the session is disabled permanently.
 * Replies whose request id differs from this request's are discarded.
 *
 * Returns a new OidSeq holding the response's variable bindings, or NULL if
 * the session is (or becomes) disabled. A non-zero error status raises
 * SNMPBadOidException, carrying the text of the offending oid, or an empty
 * string when the agent's error-index is 0. The other listed exceptions
 * report malformed replies.
 */
OidSeq *SNMP_session::do_req(Tags tag, OidSeq *oids)
  throw(SNMPPacketNotSequenceException,SNMPRespNotSequenceException,
	SNMPNotResponseTagException,SNMPSeqnoNotIntException,
	SNMPStateNotIntException,SNMPFaultOidNotIntException,
	OidSeqBadLayoutException,SNMPBadOidException,std::bad_alloc,
	SocketSendShortExecption,BerSequenceTagException,BerLengthException,
	BerIntTagException,BerIntLengthExecption,BerCounterTagException,
	BerCounterLengthExecption,BerStringTagException,BerNullTagException,
	BerNullLengthExecption,BerOidTagException,BerTimeTickTagException,
	BerTimeTickLengthExecption,BerIPAddrLengthExecption){
  // Owns at most one BerSequence and deletes it on scope exit. Reply
  // structures are therefore freed on every return/throw path, and a stale
  // PDU is freed when reset() replaces it.
  class seq_holder{
    BerSequence *p;
    seq_holder(const seq_holder&);             // not copyable
    seq_holder &operator=(const seq_holder&);
  public:
    explicit seq_holder(BerSequence *q=NULL):p(q){}
    ~seq_holder(){ delete p; }
    BerSequence *get() const { return p; }
    void reset(BerSequence *q){ delete p; p=q; }
  };

  if(flags&SESSION_DISABLED_FLAG)
    return NULL;
  /* ------- Construct the packet ------- */
  /* does not need to be a very random number */
  long seqno=random(); /* ITS4: ignore */
  BerSequence *request=new BerSequence(tag);
  request->append(new BerInt(seqno));
  request->append(new BerInt(0L));
  request->append(new BerInt(0L));
  request->append(oids->Seq());
  BerSequence *opacket=new BerSequence(SEQUENCE_TAG);
  opacket->append(new BerInt(0L));
  opacket->append(new BerString(community));
  opacket->append(request);
  // NOTE(review): opacket is never freed. That may be deliberate if
  // oids->Seq() returns a sequence still owned by `oids`, since deleting
  // opacket would then free it twice. Confirm ownership before adding a delete.

  ustring opackdat;
  opacket->encode(opackdat);

  if(flags&SESSION_DEBUGSNMP_FLAG){
    __write_debug("Sent", opacket);
    __write_debug_bin(opackdat);
  }

  /* ------- Send the packet ------- */
  unsigned char *inbuf;
  seq_holder pdu;     // response PDU of the reply currently being examined
  long seqno2;
  do{
    int buflen=opackdat.length();
    while((inbuf=sock->call(he->h_addr_list[ipidx],he->h_length,
			    he->h_addrtype,opackdat.data(),buflen))==NULL){
      ipidx++;
      if(he->h_addr_list[ipidx]==NULL){
	flags|=SESSION_DISABLED_FLAG;
	return NULL;
      }
    }

    // The message wrapper is only needed until the PDU is extracted from
    // it, so it is released at the end of each iteration.
    seq_holder top(new BerSequence(inbuf));
    if(flags&SESSION_DEBUGSNMP_FLAG){
      __write_debug("Received", top.get());
      // NOTE(review): buflen starts as the request length; it is the reply
      // length only if sock->call() updates it -- confirm.
      ustring foo(inbuf,buflen);
      __write_debug_bin(foo);
    }

    //throw away wrapper and check that this is a response
    pdu.reset(dynamic_cast<BerSequence*>(top.get()->extract(top.get()->begin()+2)));
    if(pdu.get()==NULL)
      throw SNMPRespNotSequenceException();
    if(pdu.get()->type()!=GET_RESP_TAG) throw SNMPNotResponseTagException();

    /* make sure that this is a response to this request not some
       previous request. What was happening with slow was that I 
       would send a packet then it would time out and then I 
       would assume that the packet was lost and then I would 
       send another copy. Right after that, the reply to the 
       first packet would arrive and so I would move on and send 
       another packet. Then the second reply to the first packet 
       would arrive and the program would get all confused. This 
       could also fix the problem where table lines are sometimes 
       repeated over slow links. */
    BerInt *i=dynamic_cast<BerInt*>(*pdu.get()->begin());
    if(i==NULL) throw SNMPSeqnoNotIntException();
    seqno2=i->value();
  }while(seqno2!=seqno);

  BerSequence *resp=pdu.get();
  // make sure that no errors came back
  /* if this fails then the device is not sending back a properly
     constructed packet*/
  BerInt *bi=dynamic_cast<BerInt*>(*(resp->begin()+1));
  if(bi==NULL) throw SNMPStateNotIntException();
  if(bi->value()!=0){ //error
    /* if this fails something more sophisticated will have to be 
       written -- hopefully this will never fail. */
    if((bi=dynamic_cast<BerInt*>(*(resp->begin()+2)))==NULL) 
      throw SNMPFaultOidNotIntException();
    long errindex=bi->value();
    // error-index 0 means the error (e.g. tooBig) is not tied to a variable
    // binding, so there is no oid to name; using it as an index would step
    // before the start of the list.
    if(errindex<1){
      std::string none;
      throw SNMPBadOidException(none);
    }
    BerSequence *varbinds=dynamic_cast<BerSequence*>(*(resp->begin()+3));
    if(varbinds==NULL)
      throw SNMPRespNotSequenceException();
    // NOTE(review): errindex is not checked against the number of variable
    // bindings; an agent reporting a too-large index still reads past the end.
    BerSequence *varbind=
      dynamic_cast<BerSequence*>(*(varbinds->begin()+(errindex-1)));
    if(varbind==NULL)
      throw OidSeqBadLayoutException();
    BerOid *oid=dynamic_cast<BerOid*>(*(varbind->begin()));
    if(oid==NULL)
      throw OidSeqBadLayoutException();
    std::string buf;
    oid->ascii_print(buf);
    throw SNMPBadOidException(buf);
  } 
  BerSequence *varbinds=dynamic_cast<BerSequence*>(resp->extract(resp->begin()+3));
  if(varbinds==NULL)
    throw SNMPRespNotSequenceException();
  return new OidSeq(varbinds);
}

/*
 * Deep-copy a hostent (such as the static result of gethostbyname()) so it
 * survives later resolver calls. Returns NULL when src is NULL.
 *
 * Allocation scheme (del_hostent() must release things the same way):
 *   h_name and each alias string -> strdup()     (free())
 *   each address buffer          -> new char[]   (delete[])
 *   h_aliases / h_addr_list      -> new char*[]  (delete[])
 *   the hostent itself           -> new          (delete, by the caller)
 * Both pointer arrays are always allocated and NULL-terminated.
 */
hostent *dup_hostent(hostent *src){ 
  if(src==NULL)
    return NULL;
  hostent *dest=new hostent;
  dest->h_name=src->h_name ? strdup(src->h_name) : NULL;
  dest->h_addrtype=src->h_addrtype; 
  dest->h_length=src->h_length;

  int count=0;
  if(src->h_aliases!=NULL)
    while(src->h_aliases[count]!=NULL)
      count++;
  dest->h_aliases=new char*[count+1];
  for(int i=0;i<count;i++)
    dest->h_aliases[i]=strdup(src->h_aliases[i]);
  dest->h_aliases[count]=NULL;

  count=0;
  if(src->h_addr_list!=NULL)
    while(src->h_addr_list[count]!=NULL)
      count++;
  dest->h_addr_list=new char*[count+1];
  for(int i=0;i<count;i++){
    dest->h_addr_list[i]=new char[src->h_length];
    memcpy(dest->h_addr_list[i],src->h_addr_list[i],src->h_length);
  }
  dest->h_addr_list[count]=NULL;
  return dest;
} 

/*
 * Free everything dup_hostent() allocated inside *dead: the name, the alias
 * strings, the address buffers and both pointer arrays. The released
 * pointers are reset to NULL. The hostent struct itself is not freed; the
 * caller deletes it. Passing NULL is a no-op.
 */
void del_hostent(hostent *dead){ 
  if(dead==NULL)
    return;
  free(dead->h_name);                    // strdup()'d
  dead->h_name=NULL;
  if(dead->h_aliases!=NULL){
    for(int i=0;dead->h_aliases[i]!=NULL;i++)
      free(dead->h_aliases[i]);          // strdup()'d
    delete[] dead->h_aliases;
    dead->h_aliases=NULL;
  }
  if(dead->h_addr_list!=NULL){
    for(int i=0;dead->h_addr_list[i]!=NULL;i++)
      delete[] dead->h_addr_list[i];     // new char[]'d
    delete[] dead->h_addr_list;
    dead->h_addr_list=NULL;
  }
}

/*
 * Start the joiner thread unless one is already running.
 * Throws JoinerCreateException with pthread_create's error code on failure.
 */
static void
start_joiner(void)
{
  int ret=0;

  pthread_mutex_lock(&joiner_running_m);
  if (!joiner_running) {
    ret=pthread_create(&joiner_th,NULL,joiner,NULL);
    if(ret==0)
      joiner_running = 1;
  }
  // Release the lock before throwing; otherwise one failed create leaves
  // joiner_running_m locked and deadlocks every later session.
  pthread_mutex_unlock(&joiner_running_m);
  if(ret)
    throw JoinerCreateException(ret);
}

/*
 * Create a session for host and start a worker thread that runs
 * start_routine(this).
 *
 * host is resolved with gethostbyname() and a private copy of the result is
 * kept; comm is the SNMP community. The shared socket is created on first
 * use. The constructor then reserves an inuse slot (blocking while
 * MAXTHREADS are taken), starts the joiner thread if needed, and starts the
 * worker. A constructor that throws never runs the destructor. So if any
 * step after the slot reservation fails, the slot and the host copy are
 * released here before the exception propagates.
 * NOTE(review): gethostbyname() is not thread-safe; concurrent construction
 * from several threads is presumably unsupported -- confirm.
 */
SNMP_session::SNMP_session(const std::string &host,
			   void *(*start_routine)(SNMP_session *),
			   const std::string &comm) 
  throw(std::bad_alloc,SocketNoUDPException,SocketCreateFailException,
	ReceiverCreateException,SessionHostNotFoundException,
	JoinerCreateException,SessionWorkerCreateException):
  community(comm),ipidx(0),hostname(host),flags(0){
  if(sock==NULL){
    sock=new SNMP_socket(timeout,retries,port);
  }


  if((he=dup_hostent(gethostbyname(host.c_str())))==NULL)
    throw SessionHostNotFoundException(h_errno);
  fflush(stderr);

  inuse.inc();
  run_session_t *rs=NULL;
  try{
    start_joiner();
    rs=new run_session_t;
    rs->fp=start_routine;
    rs->session=this;
    pthread_t newthread;
    int ret=pthread_create(&newthread,NULL,SNMP_run_session,rs);
    if(ret)
      throw SessionWorkerCreateException(ret);
  }catch(...){
    // No worker owns rs and nothing will ever dec() this slot; leaving it
    // reserved would keep inuse.done() false and hang SNMP_sessions_done().
    delete rs;
    inuse.dec();
    del_hostent(he);
    delete he;
    throw;
  }
}

/*
 * Release the private copy of the host's resolver data made by the
 * constructor: first its contents, then the hostent itself.
 */
SNMP_session::~SNMP_session(){
  hostent *const owned=he;
  del_hostent(owned);
  delete owned;
}

/*
 * Turn on packet tracing for this session.
 *
 * Creates the first free file among snmplog.0, snmplog.1, ... in the current
 * directory. O_EXCL makes "name is unused" and "create it" one atomic step.
 * It writes a "Contacting a.b.c.d" line there, then sets the debug flag
 * that makes do_req() log every packet sent and received.
 *
 * Throws DebugFileOpenException(errno) if the file cannot be created for a
 * reason other than the name being taken; debugging then stays off.
 */
void SNMP_session::setDebug()
  throw(DebugFileOpenException){
  char namebuf[32];
  int fd;

  for(int filenum=0;;filenum++){
    snprintf(namebuf,sizeof(namebuf),"snmplog.%d",filenum);
    /* ITS4: ignore open */
    fd=open(namebuf,O_WRONLY|O_CREAT|O_EXCL,0644);
    if(fd!=-1 || errno!=EEXIST)
      break;
  }
  // open() failed for a reason other than the name already existing.
  if(fd==-1) throw DebugFileOpenException(errno);
  debugfile=fd;
  // Enable tracing only once there is a file to write it to.
  flags|=SESSION_DEBUGSNMP_FLAG;

  char buf[256];
  // NOTE(review): assumes h_addr_list holds 4-byte IPv4 addresses.
  int blen=snprintf(buf,sizeof(buf),"Contacting %u.%u.%u.%u\n",
		    (static_cast<unsigned>(he->h_addr_list[ipidx][0])&0xff),
		    (static_cast<unsigned>(he->h_addr_list[ipidx][1])&0xff),
		    (static_cast<unsigned>(he->h_addr_list[ipidx][2])&0xff),
		    (static_cast<unsigned>(he->h_addr_list[ipidx][3])&0xff));
  write(debugfile,buf,blen);
}

/*
 * Append a readable dump of packet to the debug log: a line holding dirstr,
 * then packet's ascii_print() output, then a newline.
 */
void SNMP_session::__write_debug(const std::string &dirstr,
			       BerSequence *packet){
  std::string text(dirstr);
  text.append(1,'\n');
  packet->ascii_print(text);
  text.append(1,'\n');

  write(debugfile,text.data(),text.length());
}

// write the binary version of the packet
/*
 * Append a hex dump of str to the debug log: each byte as two lowercase hex
 * digits and a space, a newline after every 16th byte, and a final newline.
 */
void SNMP_session::__write_debug_bin(const ustring &str){
  std::string prntbuf;
  char littlebuf[10];

  // Packets can be far longer than a char can count, so the per-line
  // position uses an unsigned long.
  unsigned long count=0;
  for(ustring::const_iterator cur=str.begin();cur!=str.end();++cur){
    snprintf(littlebuf,sizeof(littlebuf),"%02x ",
	     static_cast<unsigned int>(static_cast<unsigned char>(*cur)));
    prntbuf+=littlebuf;
    if(++count%16==0)
      prntbuf+="\n";
  }
  prntbuf+="\n";
  write(debugfile,prntbuf.data(),prntbuf.length());  
}

/*
 * Print str to stdout on behalf of this session, coordinating line breaks
 * with other sessions through the shared lastprint/need_newline state.
 *
 * neednl:       stored as the shared "a newline is owed" state. A later
 *               print from another session, or from this session with
 *               neednl == 0, first emits that newline.
 * ck_name_flag: requests a hostname="<hostname>"; prefix. It is applied
 *               when another session printed last, or when this session
 *               printed last and neednl == 0.
 * The output is written while lastprint_m is held, so the bookkeeping
 * always matches what actually reached stdout.
 */
void SNMP_session::printstr(char neednl,char *str, char ck_name_flag){
  // Index bit 0: leading newline; bit 1: hostname prefix.
  static const char *basestr[]={"%s%s","\n%s%s","hostname=\"%s\";%s",
				  "\nhostname=\"%s\";%s"};
  std::string hn;
  int idx=0;
  pthread_mutex_lock(&lastprint_m);
  if(lastprint!=this){
    if(need_newline)
      idx|=1;
    if(ck_name_flag){
      idx|=2;
      hn=hostname;
    }
  }else{
    if(!neednl && ck_name_flag){
      idx|=2;
      hn=hostname;
    }
    // OR, not assign: keep the hostname bit set just above.
    if(need_newline && !neednl)
      idx|=1;
  }
  need_newline=neednl;
  lastprint=this;
  printf(basestr[idx],hn.c_str(),str); /* ITS4: ignore */
  pthread_mutex_unlock(&lastprint_m);
}

/*
 * Emit the newline still owed by the last printstr() call, if any, and
 * clear that state so that neither a repeated end() nor the next printstr()
 * emits it again.
 */
void SNMP_session::end(){
  pthread_mutex_lock(&lastprint_m);
  if(need_newline){
    putchar('\n');
    need_newline=0;
  }
  pthread_mutex_unlock(&lastprint_m);
}

/*
 * Return 1 if the low 32 bits of mask form a valid netmask, else 0. A valid
 * netmask is a run of one bits at the top followed only by zero bits:
 * 255.255.255.0, 255.255.255.255 and the all-zero mask qualify;
 * 255.0.255.0 and 0.0.0.255 do not.
 */
int contigbits(unsigned int mask){
  // The complement of a valid mask has the form 0...01...1. Adding one to
  // such a value carries into a single higher bit, so ANDing the two gives
  // zero exactly for valid masks. Unsigned wrap-around on all-ones is
  // well defined.
  unsigned int inv=~mask & 0xffffffffu;
  return (inv & (inv+1u))==0;
}

/*
 * Parse one host specification, append a new session to dest for every
 * host it names, and return dest.
 *
 * Accepted forms, each optionally followed by "(community)" which replaces
 * `community` for the hosts it covers:
 *   a hostname                 router.example.com
 *   a single address           10.1.1.102
 *   a last-octet range         10.1.1.102-150
 *   network / prefix length    10.1.1.0/24              (prefix 1..31)
 *   network / dotted netmask   10.1.1.0/255.255.255.0   (non-zero mask)
 * The network forms skip the network and broadcast addresses.
 *
 * hostspec is modified in place: a "(community)" suffix is removed from it.
 */
std::list<SNMP_session*> &SNMP_sessions(std::list<SNMP_session*> &dest,
				   std::string &hostspec,
				   void *(*start_routine)(SNMP_session *),
				   const std::string &community)
  throw(std::bad_alloc,SocketNoUDPException,SocketCreateFailException,
	ReceiverCreateException,SessionHostNotFoundException,
	JoinerCreateException,SessionWorkerCreateException,
	SessionCommunityException,SessionOctetOverflowException,
	SessionBadSubnetException,SessionNetbitsOverflowException,
	SessionBadNetmaskException){
  // strip off the community name
  std::string cmty;
  std::string::size_type tmploc=hostspec.find('(');
  if(tmploc==std::string::npos)
    cmty=community;
  else{
    std::string::size_type end=hostspec.find(')',tmploc);
    if(end==std::string::npos)
      throw SessionCommunityException();
    // std::string::substr takes (position, length), not (first, last).
    cmty=hostspec.substr(tmploc+1,end-tmploc-1);
    // Keep everything before '(' minus trailing blanks, so "host(comm)"
    // and "host (comm)" both yield "host".
    std::string::size_type hostlen=tmploc;
    while(hostlen>0 &&
	  isspace(static_cast<unsigned char>(hostspec[hostlen-1])))
      hostlen--;
    hostspec=hostspec.substr(0,hostlen);
  }

  if(!hostspec.empty() && isdigit(static_cast<unsigned char>(hostspec[0]))){
    /* assume that this is 
       a) an ipaddress like 10.1.1.102
       b) a range of ipaddresses line 10.1.1.102-150
       c) a network and a subnet mask 10.1.1.0/24
    */
    if(hostspec.find('-')!=std::string::npos){ // the range case
      unsigned int val[5];
      int num=sscanf(hostspec.c_str(),"%u.%u.%u.%u-%u",val,val+1,val+2,val+3,
		    val+4);
      // NOTE(review): SessionBadRangeException is not in this function's
      // exception specification, so throwing it goes to std::unexpected().
      // Add it to the specification here and in the header.
      if(num!=5)
	throw SessionBadRangeException();
      if(val[0]>=256 || val[1]>=256 || val[2]>=256 || val[3]>=256 ||
	 val[4]>=256)
	throw SessionOctetOverflowException();
      for(;val[3]<=val[4];val[3]++){
	char buf[20];
	snprintf(buf,sizeof(buf),"%u.%u.%u.%u",val[0],val[1],val[2],val[3]);
	dest.push_back(new SNMP_session(buf,start_routine,cmty));
      }
    } else if(hostspec.find('/')!=std::string::npos){ // the network case
      unsigned int val[8];
      unsigned int baseaddr, topaddr, mask=0;
      int numread=sscanf(hostspec.c_str(),"%u.%u.%u.%u/%u.%u.%u.%u",val,val+1,
		      val+2,val+3,val+4,val+5,val+6,val+7);
      // Only val[0..numread-1] were written by sscanf.
      if(numread<4)
	throw SessionBadSubnetException();
      if(val[0]>=256 || val[1]>=256 || val[2]>=256 || val[3]>=256)
	throw SessionOctetOverflowException();
      switch(numread){
      case 5:
	// /0 would shift a 32-bit value by 32 (undefined behaviour) and
	// would enumerate the whole address space, so it is rejected.
	if(val[4]==0 || val[4]>=32)
	  throw SessionNetbitsOverflowException();
	mask=0xffffffffu << (32-val[4]);
	break;
      case 8:
	if(val[4]>=256 || val[5]>=256 || val[6]>=256 || val[7]>=256)
	  throw SessionOctetOverflowException();
	mask=(val[4]<<24)|(val[5]<<16)|(val[6]<<8)|val[7];
	if(mask==0 || !contigbits(mask))
	   throw SessionBadNetmaskException();
	break;
      default:
	throw SessionBadSubnetException();
      }
      baseaddr=((val[0]<<24)|(val[1]<<16)|(val[2]<<8)|val[3]) & mask;
      topaddr=baseaddr|~mask;
	
      // loop through all the addrs skipping the network and 
      // broadcast
      for(baseaddr++;baseaddr<topaddr;baseaddr++){
	char buf[20];
	snprintf(buf,sizeof(buf),"%u.%u.%u.%u",(baseaddr&0xff000000u)>>24,
		 (baseaddr&0xff0000)>>16,(baseaddr&0xff00)>>8,baseaddr&0xff);
	dest.push_back(new SNMP_session(buf,start_routine,cmty));
      }	
    } else  // the single ip case
      dest.push_back(new SNMP_session(hostspec,start_routine,cmty));
  }else     // a hostname (or empty spec, rejected by the resolver)
    dest.push_back(new SNMP_session(hostspec,start_routine,cmty));
  return dest;
}

/*
 * Create sessions for every host specification in hostspecs, in order,
 * appending them to dest; returns dest. Each element is handled by the
 * single-hostspec overload, which may modify it in place.
 */
std::list<SNMP_session*> &SNMP_sessions(std::list<SNMP_session*> &dest,
					std::list<std::string> &hostspecs,
				   void *(*start_routine)(SNMP_session*),
				   const std::string &community)
  throw(std::bad_alloc,SocketNoUDPException,SocketCreateFailException,
	ReceiverCreateException,SessionHostNotFoundException,
	JoinerCreateException,SessionWorkerCreateException,
	SessionCommunityException,SessionOctetOverflowException,
	SessionBadSubnetException,SessionNetbitsOverflowException,
	SessionBadNetmaskException){
  std::list<std::string>::iterator spec=hostspecs.begin();
  const std::list<std::string>::iterator last=hostspecs.end();
  while(spec!=last){
    SNMP_sessions(dest,*spec,start_routine,community);
    ++spec;
  }
  return dest;
}
