package main

import "fmt"
import "net/http"
import "io"
import "io/ioutil"
import "log"
import "gopkg.in/yaml.v3"
import "time"
import "sync"
import "os/signal"
import "os"
import "errors"
import "net"
import "os/exec"
import "flag"

// Check describes how the hosts of a group are probed. It is decoded from the
// "check" section of a group in the YAML configuration.
type Check struct {
	// Type selects the probe to run: "http", "tcp" or "smtp".
	Type string `yaml:"type"`
	// Host is an optional virtual host name for HTTP checks.
	Host string `yaml:"host"`
	// Port is the port number of the checked service.
	Port int `yaml:"port"`
	// Uri is the request path used by HTTP checks.
	Uri string `yaml:"uri"`
	// Method is the HTTP request method; validation accepts HEAD and GET.
	Method string `yaml:"method"`
	// Timeout is the maximum duration of a single probe, in seconds.
	Timeout int `yaml:"timeout"`
	// Interval is the delay between two probes of the same host, in seconds.
	Interval int `yaml:"interval"`
}

// Group is one entry of the YAML configuration, which main decodes into a
// map[string]Group keyed by group name.
//
// Name is filled from the "group" key. Tables lists the pf tables that
// updateTables passes to "pfctl -t". Hosts lists the addresses to probe;
// validateConfiguration rejects any entry that net.ParseIP cannot parse.
// Check holds the probe settings shared by every host of the group.
type Group struct {
	Name   string   `yaml:"group"`
	Tables []string `yaml:"tables"`
	Hosts  []string `yaml:"hosts"`
	Check  Check    `yaml:"check"`
}


// main loads and validates the YAML configuration named by the -f flag,
// starts one checkGroup goroutine per configured group, then blocks until an
// interrupt signal arrives. On interrupt it closes the stop channel and waits
// for every goroutine registered on the wait group to finish.
func main() {
	log.Println("Starting")

	confPath := flag.String("f", "/usr/local/etc/hostchecker.yaml", "YAML configuration file")
	flag.Parse()

	rawConf, readErr := ioutil.ReadFile(*confPath)
	if readErr != nil {
		log.Fatalf("Configuration open error #%v ", readErr)
	}

	var groups map[string]Group
	if parseErr := yaml.Unmarshal(rawConf, &groups); parseErr != nil {
		log.Fatalf("Configuration read error #%v", parseErr)
	}
	if confErr := validateConfiguration(groups); confErr != nil {
		log.Fatal("Configuration error #", confErr)
	}

	// Closing stop (never sending on it) broadcasts shutdown to all workers.
	var workers sync.WaitGroup
	stop := make(chan bool)
	for groupName, groupConf := range groups {
		log.Println("Checking group", groupName, groupConf)
		workers.Add(1)
		go checkGroup(groupName, groupConf, &workers, stop)
	}

	signals := make(chan os.Signal, 10)
	signal.Notify(signals, os.Interrupt)

	received := <-signals
	log.Println("Received signal", received)
	log.Println("main closing stopChannel")
	close(stop)
	workers.Wait()
}

// checkGroup supervises one group: it starts a checkHost goroutine per host
// and applies every status those goroutines report to the group's tables via
// updateTables.
//
// name is the group name used in log messages, group is its configuration.
// waitGroup must already account for this goroutine (the caller calls Add);
// checkGroup calls Done when it returns and registers each checkHost
// goroutine it starts. The function returns once stopChannel is closed.
func checkGroup(name string, group Group, waitGroup *sync.WaitGroup, stopChannel chan bool) {
	defer waitGroup.Done()

	channels := make(map[string]chan int, len(group.Hosts))
	for _, host := range group.Hosts {
		// A host listed twice would overwrite its map entry and orphan the
		// first goroutine's channel, so only start one checker per host.
		if _, seen := channels[host]; seen {
			continue
		}
		channel := make(chan int, 1)
		channels[host] = channel
		waitGroup.Add(1)
		go checkHost(channel, name, host, group.Check, waitGroup, stopChannel)
	}

	// Poll all hosts once per tick, so status latency and the reaction time to
	// a stop request do not grow with the number of hosts in the group.
	poll := time.NewTicker(100 * time.Millisecond)
	defer poll.Stop()

	for {
		select {
		case <-stopChannel:
			log.Println("checkGroup", name, "stopChannel")
			return
		case <-poll.C:
		}

		for host, channel := range channels {
			select {
			case status := <-channel:
				log.Println("Status for ", host, "is", status, "in group", name)
				updateTables(host, status, group.Tables)
			default:
				// Nothing reported for this host since the last pass.
			}
		}
	}
}

// updateTables adds host to, or removes it from, every table in tables by
// running "pfctl -t <table> -T <op> <host>". A status of 0 adds the host, any
// other status deletes it. A failing command is logged together with its
// error and does not prevent the remaining tables from being updated.
func updateTables(host string, status int, tables []string) {
	// The operation depends only on status, so compute it once.
	op := "add"
	if status != 0 {
		op = "del"
	}

	for _, table := range tables {
		cmd := exec.Command("pfctl", "-t", table, "-T", op, host)
		if err := cmd.Run(); err != nil {
			log.Println("Unable to run command #", cmd, "error", err)
		}
	}
}

// checkHost probes a single host until stopChannel is closed.
//
// A probe of the kind selected by check.Type runs as soon as the function
// starts and then whenever more than check.Interval seconds have elapsed
// since the previous probe started (evaluated every 300ms). Each result is
// sent on status: 0 when the probe succeeded, 1 when it returned an error.
// group is only used in log messages. waitGroup.Done is called on return.
func checkHost(status chan<- int, group string, host string, check Check, waitGroup *sync.WaitGroup, stopChannel chan bool) {
	// Every return below is triggered by stopChannel, so the shutdown message
	// and the wait group bookkeeping live in one place.
	defer func() {
		log.Println("checkHost", host, "group", group, "stopChannel")
		waitGroup.Done()
	}()

	interval := time.Duration(check.Interval) * time.Second
	poll := time.NewTicker(300 * time.Millisecond)
	defer poll.Stop()

	var lastCheck time.Time // zero value makes the first probe run immediately
	for {
		select {
		case <-stopChannel:
			return
		default:
		}

		if time.Since(lastCheck) > interval {
			lastCheck = time.Now()

			var err error
			switch check.Type {
			case "http":
				err = CheckHTTP(host, check)
			case "tcp":
				err = CheckTCP(host, check)
			case "smtp":
				err = CheckSMTP(host, check)
			}

			result := 0
			if err != nil {
				result = 1
				log.Println("checkHost", host, "group", group, "error", err)
			}

			// The receiver may already have stopped; a plain send would then
			// block forever and keep the wait group from ever completing.
			select {
			case status <- result:
			case <-stopChannel:
				return
			}
		}

		// Wait for the next poll, but wake up at once on shutdown.
		select {
		case <-stopChannel:
			return
		case <-poll.C:
		}
	}
}

// CheckHTTP sends one HTTP request to host and reports whether it answered.
//
// The request uses check.Method and check.Uri and is sent to check.Port
// (80 when the port is 0). When check.Host is set it is sent as the Host
// header, so a virtual host can be probed through an IP address. The whole
// exchange is bounded by check.Timeout seconds.
//
// It returns nil when a response with a status below 500 was received, and an
// error when the request could not be built or sent, or when the server
// answered with a 5xx status.
func CheckHTTP(host string, check Check) error {
	port := check.Port
	if port == 0 {
		port = 80
	}
	// JoinHostPort adds the brackets an IPv6 literal needs inside a URL.
	target := fmt.Sprintf("http://%s%s", net.JoinHostPort(host, fmt.Sprint(port)), check.Uri)

	req, err := http.NewRequest(check.Method, target, nil)
	if err != nil {
		return err
	}
	if len(check.Host) > 0 {
		// For client requests net/http ignores a "Host" entry in req.Header;
		// the Host header is taken from req.Host.
		req.Host = check.Host
	}

	client := &http.Client{
		Timeout: time.Duration(check.Timeout) * time.Second,
	}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Drain the body without buffering it so the connection can be reused.
	// A read error here is deliberately ignored: the server did answer.
	_, _ = io.Copy(io.Discard, resp.Body)

	// A server that answers with a server-side error is not healthy.
	if resp.StatusCode >= http.StatusInternalServerError {
		return fmt.Errorf("unexpected HTTP status %s", resp.Status)
	}
	return nil
}

// CheckTCP reports whether a TCP connection to host on check.Port can be
// established within check.Timeout seconds. The connection is closed as soon
// as it is established. It returns the dial error, or nil on success.
func CheckTCP(host string, check Check) error {
	// JoinHostPort brackets IPv6 literals; a plain "host:port" concatenation
	// yields an address that cannot be dialed for such hosts.
	address := net.JoinHostPort(host, fmt.Sprint(check.Port))

	cnx, err := net.DialTimeout("tcp", address, time.Duration(check.Timeout)*time.Second)
	if err != nil {
		return err
	}
	// Only reachability matters; a failure while closing is irrelevant.
	_ = cnx.Close()
	return nil
}

// CheckSMTP reports whether host greets like an SMTP server.
//
// It connects to check.Port (25 when the port is 0), reads the first three
// bytes of the server greeting and expects the reply code "220". Connecting
// and reading are each bounded by check.Timeout seconds. It returns nil on a
// 220 greeting, and an error when the connection fails, the greeting cannot
// be read in time, or another reply code is received.
func CheckSMTP(host string, check Check) error {
	timeout := time.Duration(check.Timeout) * time.Second
	port := check.Port
	if port == 0 {
		port = 25
	}

	conn, err := net.DialTimeout("tcp", net.JoinHostPort(host, fmt.Sprint(port)), timeout)
	if err != nil {
		return err
	}
	defer conn.Close()

	// A zero timeout means "no limit" for the dial; keep the same meaning for
	// the greeting instead of setting a deadline that has already expired.
	if timeout > 0 {
		if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
			return err
		}
	}

	// The reply code is the first three bytes of the greeting, for both the
	// single-line ("220 ") and the multi-line ("220-") form.
	var code [3]byte
	if _, err := io.ReadFull(conn, code[:]); err != nil {
		return fmt.Errorf("reading SMTP greeting: %w", err)
	}
	if string(code[:]) != "220" {
		return fmt.Errorf("unexpected SMTP greeting code %q", code[:])
	}

	// Best-effort polite goodbye; the check already succeeded.
	_, _ = conn.Write([]byte("QUIT\r\n"))
	return nil
}

// validateConfiguration checks every group of conf and fills in defaults.
//
// A group needs at least one table and one host, and every host must be an
// IP address. The check must be of type http, tcp or smtp, with a port in
// 0-65535 and non-negative interval and timeout. Defaults written back into
// conf: interval 5 and timeout 2 for every type; method HEAD, port 80 and
// uri "/" for http; port 25 for smtp. A tcp check has no default port.
//
// It returns the first problem found, or nil when conf is valid. Groups that
// were validated before the failing one keep their defaults.
func validateConfiguration(conf map[string]Group) error {
	for name, group := range conf {
		log.Println("Validating configuration", name)
		if err := validateGroup(name, &group); err != nil {
			return err
		}
		// group is a copy; store it back so the defaults reach the caller.
		conf[name] = group
	}
	return nil
}

// validateGroup checks the tables and hosts of one group, then its check.
func validateGroup(name string, group *Group) error {
	if len(group.Tables) == 0 {
		return configErrorf("No tables in group %s", name)
	}
	log.Println("Hosts", group.Hosts)
	if len(group.Hosts) == 0 {
		return configErrorf("No hosts in group %s", name)
	}
	for _, host := range group.Hosts {
		if net.ParseIP(host) == nil {
			return configErrorf("Host %v is not an IP in group %s", host, name)
		}
	}
	return validateCheck(name, &group.Check)
}

// validateCheck validates the probe settings of group name and applies the
// per-type and common defaults to check.
func validateCheck(name string, check *Check) error {
	if check.Port < 0 || check.Port > 65535 {
		return configErrorf("Check port %d is out of range in group %s", check.Port, name)
	}
	if check.Interval < 0 {
		return configErrorf("Check interval is negative in group %s", name)
	}
	if check.Timeout < 0 {
		return configErrorf("Check timeout is negative in group %s", name)
	}

	switch check.Type {
	case "http":
		if len(check.Method) == 0 {
			check.Method = "HEAD"
		}
		if check.Method != "HEAD" && check.Method != "GET" {
			return configErrorf("Check method should be HEAD or GET in group %s", name)
		}
		if check.Port == 0 {
			check.Port = 80
		}
		if len(check.Uri) == 0 {
			check.Uri = "/"
		}
		// The uri is appended directly after the host when the URL is built,
		// so a relative path would be merged into the host part.
		if check.Uri[0] != '/' {
			return configErrorf("Check uri should start with / in group %s", name)
		}
	case "tcp":
		if check.Port == 0 {
			return configErrorf("Check port is undefined or 0 in group %s", name)
		}
	case "smtp":
		if check.Port == 0 {
			check.Port = 25
		}
	default:
		return configErrorf("Check type should be http, smtp or tcp in group %s", name)
	}

	if check.Interval == 0 {
		check.Interval = 5
	}
	if check.Timeout == 0 {
		check.Timeout = 2
	}
	return nil
}

// configErrorf builds a configuration error from a format string. It goes
// through errors.New so that the file's errors import stays in use.
func configErrorf(format string, args ...interface{}) error {
	return errors.New(fmt.Sprintf(format, args...))
}