/*
   Copyright 2023 Laurent Ulrich (laurentu@gmail.com)

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. 
*/
package main

import "fmt"
import "net/http"
import "io"
import "io/ioutil"
import "log"
import "gopkg.in/yaml.v3"
import "time"
import "sync"
import "os/signal"
import "os"
import "errors"
import "net"
import "os/exec"
import "flag"

// Check describes how every host of a group is probed. It is decoded from
// the group's "check" YAML mapping; validateConfiguration fills in defaults
// for fields left unset.
type Check struct {
	// Type selects the probe implementation: "http", "tcp" or "smtp".
	Type string `yaml:"type"`
	// Host is an optional host name used by http probes.
	Host string `yaml:"host"`
	// Port is the TCP port of the probed service (validateConfiguration
	// defaults it to 80 for http and 25 for smtp; tcp requires it).
	Port int `yaml:"port"`
	// Uri is the request path of http probes (defaults to "/").
	Uri string `yaml:"uri"`
	// Method is the HTTP method of http probes: HEAD (default) or GET.
	Method string `yaml:"method"`
	// Timeout bounds a single probe, in seconds.
	Timeout int `yaml:"timeout"`
	// Interval is the minimum number of seconds between two probes of a host.
	Interval int `yaml:"interval"`
}

// Group is one entry of the top-level configuration map: a set of hosts that
// share one health check, and the pf tables their state is mirrored into.
type Group struct {
	// Name is decoded from the optional "group" key; the code in this file
	// identifies groups by their map key instead and never reads this field.
	Name string `yaml:"group"`
	// Tables lists the pf tables a healthy host is added to (and removed
	// from when unhealthy).
	Tables []string `yaml:"tables"`
	// Hosts lists the IP addresses to probe.
	Hosts []string `yaml:"hosts"`
	// Check is the probe applied to every host of the group.
	Check Check `yaml:"check"`
}


// main loads the YAML configuration named by -f, validates it, and starts one
// checkGroup goroutine per group. It then blocks until an interrupt signal
// arrives, closes the shared stop channel, and waits for every group and host
// goroutine to return.
func main() {
	log.Println("Starting")

	confFileName := flag.String("f", "/usr/local/etc/hostchecker.yaml", "YAML configuration file")
	flag.Parse()

	raw, err := ioutil.ReadFile(*confFileName)
	if err != nil {
		log.Fatalf("Configuration open error #%v ", err)
	}

	var conf map[string]Group
	if err := yaml.Unmarshal(raw, &conf); err != nil {
		log.Fatalf("Configuration read error #%v", err)
	}
	if err := validateConfiguration(conf); err != nil {
		log.Fatal("Configuration error #", err)
	}

	// Closing stop is the single shutdown broadcast; workers tracks every
	// group goroutine (and, through checkGroup, every host goroutine).
	stop := make(chan bool)
	var workers sync.WaitGroup
	for groupName, groupConf := range conf {
		log.Println("Checking group", groupName, groupConf)
		workers.Add(1)
		go checkGroup(groupName, groupConf, &workers, stop)
	}

	interrupt := make(chan os.Signal, 10)
	signal.Notify(interrupt, os.Interrupt)

	received := <-interrupt
	log.Println("Received signal", received)
	log.Println("main closing stopChannel")
	close(stop)
	workers.Wait()
}

// checkGroup starts one checkHost goroutine per distinct host of group and
// applies every status those goroutines report to the group's pf tables
// (via updateTables) until stopChannel is closed.
//
// The caller must have called waitGroup.Add(1) for this goroutine; it calls
// waitGroup.Done when it returns. Each host goroutine it starts is added to
// the same waitGroup.
func checkGroup(name string, group Group, waitGroup *sync.WaitGroup, stopChannel chan bool) {
	defer waitGroup.Done()

	channels := make(map[string]chan int, len(group.Hosts))
	for _, host := range group.Hosts {
		// A duplicate would overwrite the map entry and leave the first
		// goroutine reporting into a channel nobody reads.
		if _, dup := channels[host]; dup {
			log.Println("checkGroup", name, "duplicate host", host, "ignored")
			continue
		}
		channel := make(chan int, 1)
		channels[host] = channel
		waitGroup.Add(1)
		go checkHost(channel, name, host, group.Check, waitGroup, stopChannel)
	}

	// Sweep every host channel without blocking, then wait once per round.
	// This keeps reaction time and shutdown latency independent of the
	// number of hosts.
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		for host, channel := range channels {
			select {
			case status := <-channel:
				log.Println("Status for ", host, "is", status, "in group", name)
				updateTables(host, status, group.Tables)
			default:
			}
		}
		select {
		case <-stopChannel:
			log.Println("checkGroup", name, "stopChannel")
			return
		case <-ticker.C:
		}
	}
}

// updateTables mirrors a host status into each pf table in tables by running
// `pfctl -t <table> -T <op> <host>`, where op is "add" for status 0 and "del"
// for any other status. A failing command is logged with its error and the
// output pfctl produced, and the remaining tables are still updated.
func updateTables(host string, status int, tables []string) {
	op := pfctlOperation(status)
	for _, table := range tables {
		cmd := exec.Command("pfctl", "-t", table, "-T", op, host)
		if output, err := cmd.CombinedOutput(); err != nil {
			log.Printf("Unable to run command # %v: %v (output %q)", cmd, err, output)
		}
	}
}

// pfctlOperation maps a check status to the pfctl table command: "add" when
// the host is healthy (status 0), "del" otherwise.
func pfctlOperation(status int) string {
	if status == 0 {
		return "add"
	}
	return "del"
}

// checkHost probes host with check at most once every check.Interval seconds
// and sends the outcome on status: 0 when the probe succeeded, 1 when it
// failed. group is used only in log messages.
//
// It runs until stopChannel is closed. The caller must have called
// waitGroup.Add(1) for it; it calls waitGroup.Done when it returns. The send
// on status also waits on stopChannel. Once the reader (checkGroup) has
// returned, a send into a full channel would otherwise block forever and
// stall waitGroup.Wait in main.
func checkHost(status chan<- int, group string, host string, check Check, waitGroup *sync.WaitGroup, stopChannel chan bool) {
	defer waitGroup.Done()

	// The interval is checked at this polling granularity.
	ticker := time.NewTicker(300 * time.Millisecond)
	defer ticker.Stop()

	var lastCheck time.Time // zero value: the first probe runs immediately
	for {
		if time.Since(lastCheck).Seconds() > float64(check.Interval) {
			lastCheck = time.Now()
			var err error
			switch check.Type {
			case "http":
				err = CheckHTTP(host, check)
			case "tcp":
				err = CheckTCP(host, check)
			case "smtp":
				err = CheckSMTP(host, check)
			default:
				// Fail closed: never report a host healthy without probing it.
				err = fmt.Errorf("unknown check type %q", check.Type)
			}
			result := 0
			if err != nil {
				result = 1
				log.Println("checkHost", host, "group", group, "error", err)
			}
			select {
			case status <- result:
			case <-stopChannel:
				log.Println("checkHost", host, "group", group, "stopChannel")
				return
			}
		}
		select {
		case <-stopChannel:
			log.Println("checkHost", host, "group", group, "stopChannel")
			return
		case <-ticker.C:
		}
	}
}

// CheckHTTP sends one HTTP request (check.Method, normally HEAD or GET) for
// check.Uri to host on check.Port and reports whether the server is healthy.
//
// host is an IP literal; IPv6 addresses are bracketed correctly. When
// check.Port is 0 the HTTP default 80 is used. When check.Host is set it is
// sent as the Host header, so name-based virtual hosts can be probed by
// address. The whole exchange is bounded by check.Timeout seconds.
//
// It returns an error for request/transport failures and for 5xx responses;
// any other response counts as healthy.
func CheckHTTP(host string, check Check) error {
	port := check.Port
	if port == 0 {
		port = 80 // same default validateConfiguration applies to http checks
	}
	target := fmt.Sprintf("http://%s%s", net.JoinHostPort(host, fmt.Sprint(port)), check.Uri)

	client := &http.Client{
		Timeout: time.Duration(check.Timeout) * time.Second,
	}
	req, err := http.NewRequest(check.Method, target, nil)
	if err != nil {
		return err
	}
	if check.Host != "" {
		// The client ignores a "Host" entry in req.Header; req.Host is what
		// is actually written on the wire.
		req.Host = check.Host
	}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// Drain the body so the transport can reuse the connection. A read error
	// here does not change the verdict, so it is deliberately ignored.
	_, _ = io.Copy(io.Discard, resp.Body)

	if resp.StatusCode >= http.StatusInternalServerError {
		return fmt.Errorf("%s %s returned %s", req.Method, target, resp.Status)
	}
	return nil
}

// CheckTCP reports whether a TCP connection to host on check.Port can be
// established within check.Timeout seconds. The connection is closed
// immediately; only reachability is tested. host may be an IPv4 or IPv6
// literal. net.JoinHostPort adds the brackets IPv6 needs.
func CheckTCP(host string, check Check) error {
	address := net.JoinHostPort(host, fmt.Sprint(check.Port))
	cnx, err := net.DialTimeout("tcp", address, time.Duration(check.Timeout)*time.Second)
	if err != nil {
		return err
	}
	// A close error says nothing about the remote service, so it is ignored.
	_ = cnx.Close()
	return nil
}

// CheckSMTP reports whether an SMTP server on host:check.Port is accepting
// mail sessions. It connects, reads the server greeting, and requires it to
// carry reply code 220 (service ready). It then sends QUIT and closes the
// connection. Connecting and reading the greeting are each bounded by
// check.Timeout seconds.
func CheckSMTP(host string, check Check) error {
	timeout := time.Duration(check.Timeout) * time.Second
	cnx, err := net.DialTimeout("tcp", net.JoinHostPort(host, fmt.Sprint(check.Port)), timeout)
	if err != nil {
		return err
	}
	defer cnx.Close()
	if timeout > 0 {
		// Without a deadline a server that accepts but never greets would
		// block this check forever.
		if err := cnx.SetDeadline(time.Now().Add(timeout)); err != nil {
			return err
		}
	}

	greeting := make([]byte, 512)
	n, err := io.ReadAtLeast(cnx, greeting, 3)
	if err != nil {
		return fmt.Errorf("reading SMTP greeting: %w", err)
	}
	if string(greeting[:3]) != "220" {
		return fmt.Errorf("unexpected SMTP greeting %q", greeting[:n])
	}
	// End the session politely; a failure here does not make the server unhealthy.
	_, _ = cnx.Write([]byte("QUIT\r\n"))
	return nil
}

// validateConfiguration checks every group in conf and fills in check
// defaults, writing each completed group back into conf. It returns the first
// problem found; because map iteration order is random, which problem is
// reported first can vary between runs.
//
// Rules:
//   - conf must contain at least one group
//   - each group needs at least one table and one host, every host an IP literal
//   - http: method defaults to HEAD and must be HEAD or GET, port defaults
//     to 80, uri defaults to "/" and must start with "/"
//   - tcp: port is mandatory
//   - smtp: port defaults to 25
//   - every port must be within 1-65535
//   - interval and timeout must not be negative; they default to 5 and 2
func validateConfiguration(conf map[string]Group) error {
	if len(conf) == 0 {
		return errors.New("No groups in configuration")
	}
	for name, group := range conf {
		log.Println("Validating configuration", name)
		if len(group.Tables) == 0 {
			return fmt.Errorf("No tables in group %s", name)
		}
		log.Println("Hosts", group.Hosts)
		if len(group.Hosts) == 0 {
			return fmt.Errorf("No hosts in group %s", name)
		}
		for _, host := range group.Hosts {
			if net.ParseIP(host) == nil {
				return fmt.Errorf("Host %v is not an IP in group %s", host, name)
			}
		}

		check := &group.Check // defaults are written into the local copy of group
		switch check.Type {
		case "http":
			if check.Method == "" {
				check.Method = "HEAD"
			}
			if check.Method != "HEAD" && check.Method != "GET" {
				return fmt.Errorf("Check method should be HEAD or GET in group %s", name)
			}
			if check.Port == 0 {
				check.Port = 80
			}
			if check.Uri == "" {
				check.Uri = "/"
			} else if check.Uri[0] != '/' {
				// Otherwise the path would be glued onto the host part of the URL.
				return fmt.Errorf("Check uri should start with / in group %s", name)
			}
		case "tcp":
			if check.Port == 0 {
				return fmt.Errorf("Check port is undefined or 0 in group %s", name)
			}
		case "smtp":
			if check.Port == 0 {
				check.Port = 25
			}
		default:
			return fmt.Errorf("Check type should be http, smtp or tcp in group %s", name)
		}
		if check.Port < 1 || check.Port > 65535 {
			return fmt.Errorf("Check port %d is out of range 1-65535 in group %s", check.Port, name)
		}

		if check.Interval < 0 {
			return fmt.Errorf("Check interval is negative in group %s", name)
		}
		if check.Interval == 0 {
			check.Interval = 5
		}
		if check.Timeout < 0 {
			return fmt.Errorf("Check timeout is negative in group %s", name)
		}
		if check.Timeout == 0 {
			check.Timeout = 2
		}
		// Updating an existing key while ranging over the map is permitted.
		conf[name] = group
	}
	return nil
}