// Command hostchecker periodically probes groups of hosts (HTTP, TCP or SMTP)
// and adds each host to, or deletes it from, the configured pf tables
// depending on whether the probe succeeded.
package main

import (
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"net/textproto"
	"os"
	"os/exec"
	"os/signal"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"gopkg.in/yaml.v3"
)

const (
	// pollInterval is how often checkGroup looks for new host statuses.
	pollInterval = 100 * time.Millisecond
	// minCheckPause is the minimum pause between two probes of one host,
	// even when the configured interval is zero or the probe itself was slow.
	minCheckPause = 300 * time.Millisecond
)

// errUnsupportedCheck is reported for a check type that has no probe.
var errUnsupportedCheck = errors.New("unsupported check type")

// Check describes how a single host is probed.
type Check struct {
	Type     string `yaml:"type"`     // "http", "tcp" or "smtp"
	Host     string `yaml:"host"`     // optional Host header override (http only)
	Port     int    `yaml:"port"`     // TCP port to probe
	Uri      string `yaml:"uri"`      // request path (http only)
	Method   string `yaml:"method"`   // "HEAD" or "GET" (http only)
	Timeout  int    `yaml:"timeout"`  // probe timeout, in seconds
	Interval int    `yaml:"interval"` // delay between probes, in seconds
}

// Group is a set of hosts sharing one check and one list of pf tables.
type Group struct {
	Name   string   `yaml:"group"`
	Tables []string `yaml:"tables"`
	Hosts  []string `yaml:"hosts"`
	Check  Check    `yaml:"check"`
}

// main loads and validates the configuration, starts one checkGroup
// goroutine per group and waits for an interrupt or termination signal, at
// which point it closes stopChannel and waits for every worker to exit.
func main() {
	log.Println("Starting")
	confFileName := flag.String("f", "/usr/local/etc/hostchecker.yaml", "YAML configuration file")
	flag.Parse()

	yamlFile, err := ioutil.ReadFile(*confFileName)
	if err != nil {
		log.Fatalf("Configuration open error #%v ", err)
	}
	var conf map[string]Group
	if err := yaml.Unmarshal(yamlFile, &conf); err != nil {
		log.Fatalf("Configuration read error #%v", err)
	}
	if err := validateConfiguration(conf); err != nil {
		log.Fatal("Configuration error #", err)
	}

	// Register the handler before starting workers so that a signal arriving
	// during start-up still results in a graceful shutdown. SIGTERM is handled
	// as well because that is what service managers send.
	exit := make(chan os.Signal, 10)
	signal.Notify(exit, os.Interrupt, syscall.SIGTERM)

	var waitGroup sync.WaitGroup
	stopChannel := make(chan bool)
	for name, group := range conf {
		log.Println("Checking group", name, group)
		waitGroup.Add(1)
		go checkGroup(name, group, &waitGroup, stopChannel)
	}

	s := <-exit
	log.Println("Received signal", s)
	log.Println("main closing stopChannel")
	close(stopChannel)
	waitGroup.Wait()
}

// checkGroup starts one checkHost goroutine per distinct host of the group
// and applies every status they report to the group's pf tables. It returns,
// after calling waitGroup.Done, once stopChannel is closed.
func checkGroup(name string, group Group, waitGroup *sync.WaitGroup, stopChannel chan bool) {
	defer waitGroup.Done()

	channels := make(map[string]chan int, len(group.Hosts))
	for _, host := range group.Hosts {
		if _, seen := channels[host]; seen {
			// A host listed twice would otherwise get a second prober whose
			// channel replaces the first one in the map, leaving the first
			// prober with a channel nobody reads.
			continue
		}
		channel := make(chan int, 1)
		channels[host] = channel
		waitGroup.Add(1)
		go checkHost(channel, name, host, group.Check, waitGroup, stopChannel)
	}

	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()
	for {
		select {
		case <-stopChannel:
			log.Println("checkGroup", name, "stopChannel")
			return
		case <-ticker.C:
			// Non-blocking sweep: one idle host must not delay the others.
			for host, channel := range channels {
				select {
				case status := <-channel:
					log.Println("Status for ", host, "is", status, "in group", name)
					updateTables(host, status, group.Tables)
				default:
				}
			}
		}
	}
}

// updateTables runs pfctl once per table to add host (status == 0) or delete
// it (any other status). Failures are logged and do not stop the loop.
func updateTables(host string, status int, tables []string) {
	op := "add"
	if status != 0 {
		op = "del"
	}
	for _, table := range tables {
		cmd := exec.Command("pfctl", "-t", table, "-T", op, host)
		if out, err := cmd.CombinedOutput(); err != nil {
			log.Println("Unable to run command #", cmd, "error", err, "output", strings.TrimSpace(string(out)))
		}
	}
}

// checkHost probes host immediately and then again every check.Interval
// seconds (never more often than minCheckPause), sending 0 on status after a
// successful probe and 1 after a failed one. It returns, after calling
// waitGroup.Done, once stopChannel is closed.
func checkHost(status chan<- int, group string, host string, check Check, waitGroup *sync.WaitGroup, stopChannel chan bool) {
	defer waitGroup.Done()

	interval := time.Duration(check.Interval) * time.Second
	timer := time.NewTimer(0) // first probe runs right away
	defer timer.Stop()

	for {
		select {
		case <-stopChannel:
			log.Println("checkHost", host, "group", group, "stopChannel")
			return
		case <-timer.C:
		}

		started := time.Now()
		code := 0
		if err := runCheck(host, check); err != nil {
			code = 1
			log.Println("checkHost", host, "group", group, "error", err)
		}

		// The receiver may already have exited during shutdown; never block
		// on the send without also watching stopChannel.
		select {
		case status <- code:
		case <-stopChannel:
			log.Println("checkHost", host, "group", group, "stopChannel")
			return
		}

		// The interval is measured from the start of the probe.
		wait := interval - time.Since(started)
		if wait < minCheckPause {
			wait = minCheckPause
		}
		timer.Reset(wait)
	}
}

// runCheck dispatches to the probe matching check.Type. An unknown type is
// reported as an error so the host is never considered healthy by accident.
func runCheck(host string, check Check) error {
	switch check.Type {
	case "http":
		return CheckHTTP(host, check)
	case "tcp":
		return CheckTCP(host, check)
	case "smtp":
		return CheckSMTP(host, check)
	default:
		return fmt.Errorf("%w %q", errUnsupportedCheck, check.Type)
	}
}

// CheckHTTP sends a check.Method request for check.Uri to host on check.Port
// (80 when the port is 0) and returns an error if no HTTP response is
// obtained within check.Timeout seconds. When check.Host is set it is sent as
// the Host header.
//
// NOTE(review): any HTTP response, whatever its status code, counts as
// success; confirm whether 5xx answers should mark the host as down.
func CheckHTTP(host string, check Check) error {
	port := check.Port
	if port == 0 {
		port = 80
	}
	// JoinHostPort brackets IPv6 literals, which a plain "%s:%d" would not.
	authority := net.JoinHostPort(host, strconv.Itoa(port))
	if port == 80 {
		// Keep the default port implicit so the Host header stays "host".
		authority = strings.TrimSuffix(authority, ":80")
	}

	req, err := http.NewRequest(check.Method, "http://"+authority+check.Uri, nil)
	if err != nil {
		return err
	}
	if len(check.Host) > 0 {
		// The client ignores a "Host" entry in req.Header; the override has
		// to be set on req.Host.
		req.Host = check.Host
	}

	client := &http.Client{
		Timeout: time.Duration(check.Timeout) * time.Second,
	}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// Drain the body so the connection can be reused; a read error here does
	// not change the verdict because a response was already received.
	_, _ = io.Copy(io.Discard, resp.Body)
	return nil
}

// CheckTCP returns an error unless a TCP connection to host on check.Port
// can be established within check.Timeout seconds.
func CheckTCP(host string, check Check) error {
	addr := net.JoinHostPort(host, strconv.Itoa(check.Port))
	cnx, err := net.DialTimeout("tcp", addr, time.Duration(check.Timeout)*time.Second)
	if err != nil {
		return err
	}
	cnx.Close()
	return nil
}

// CheckSMTP connects to host on check.Port (25 when the port is 0) and
// returns an error unless the server sends a 220 greeting within
// check.Timeout seconds.
func CheckSMTP(host string, check Check) error {
	port := check.Port
	if port == 0 {
		port = 25
	}
	timeout := time.Duration(check.Timeout) * time.Second

	cnx, err := net.DialTimeout("tcp", net.JoinHostPort(host, strconv.Itoa(port)), timeout)
	if err != nil {
		return err
	}
	conn := textproto.NewConn(cnx)
	defer conn.Close() // also closes cnx

	if timeout > 0 {
		if err := cnx.SetDeadline(time.Now().Add(timeout)); err != nil {
			return err
		}
	}
	if _, _, err := conn.ReadResponse(220); err != nil {
		return fmt.Errorf("smtp greeting from %s: %w", host, err)
	}
	// Best-effort polite goodbye; the greeting already proved the service up.
	_, _ = conn.Cmd("QUIT")
	return nil
}

// validateConfiguration checks every group of conf and fills in the default
// check settings, writing the completed group back into conf. It returns the
// first problem found. Each group needs at least one table and one host, and
// every host must be an IP address.
func validateConfiguration(conf map[string]Group) error {
	for name, group := range conf {
		log.Println("Validating configuration", name)
		if len(group.Tables) == 0 {
			return fmt.Errorf("No tables in group %s", name)
		}
		log.Println("Hosts", group.Hosts)
		if len(group.Hosts) == 0 {
			return fmt.Errorf("No hosts in group %s", name)
		}
		for _, host := range group.Hosts {
			if net.ParseIP(host) == nil {
				return fmt.Errorf("Host %v is not an IP in group %s", host, name)
			}
		}
		if err := validateCheck(name, &group.Check); err != nil {
			return err
		}
		// group is a copy of the map value; store the defaults back.
		conf[name] = group
	}
	return nil
}

// validateCheck validates the check of group name and fills in its defaults:
// method HEAD, port 80 and URI "/" for http, port 25 for smtp, interval 5s
// and timeout 2s for every type.
func validateCheck(name string, check *Check) error {
	if check.Port < 0 || check.Port > 65535 {
		return fmt.Errorf("Check port %d is out of range in group %s", check.Port, name)
	}
	if check.Interval < 0 {
		return fmt.Errorf("Check interval is negative in group %s", name)
	}
	if check.Timeout < 0 {
		return fmt.Errorf("Check timeout is negative in group %s", name)
	}

	switch check.Type {
	case "http":
		if len(check.Method) == 0 {
			check.Method = "HEAD"
		}
		if check.Method != "HEAD" && check.Method != "GET" {
			return fmt.Errorf("Check method should be HEAD or GET in group %s", name)
		}
		if check.Port == 0 {
			check.Port = 80
		}
		if len(check.Uri) == 0 {
			check.Uri = "/"
		}
	case "tcp":
		if check.Port == 0 {
			return fmt.Errorf("Check port is undefined or 0 in group %s", name)
		}
	case "smtp":
		if check.Port == 0 {
			check.Port = 25
		}
	default:
		return fmt.Errorf("Check type should be http, smtp or tcp in group %s", name)
	}

	if check.Interval == 0 {
		check.Interval = 5
	}
	if check.Timeout == 0 {
		check.Timeout = 2
	}
	return nil
}