/* Copyright 2023 Laurent Ulrich (laurentu@gmail.com) Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package main import "fmt" import "net/http" import "io" import "io/ioutil" import "log" import "gopkg.in/yaml.v3" import "time" import "sync" import "os/signal" import "os" import "errors" import "net" import "os/exec" import "flag" import "log/syslog" type Check struct { Type string `yaml:"type"` Host string `yaml:"host"` Port int `yaml:"port"` Uri string `yaml:"uri"` Method string `yaml:"method"` Timeout int `yaml:"timeout"` Interval int `yaml:"interval"` } type Group struct { Name string `yaml:"group"` Tables []string `yaml:"tables"` Hosts []string `yaml:"hosts"` Check Check `yaml:"check"` } var verbose int func main() { if verbose > 0 { log.Println("Starting") } confFileName := flag.String("f", "/usr/local/etc/hostchecker.yaml", "YAML configuration file") flag.IntVar(&verbose, "verbose", 0, "Set logs verbosity") useSyslog := flag.Bool("syslog", false, "Send logs to syslog") flag.Parse() if *useSyslog == true { syslogWriter, err := syslog.New(syslog.LOG_ERR, "hostchecker") if err != nil { log.Fatal("Error opening syslog #", err) } log.SetOutput(syslogWriter) } var waitGroup sync.WaitGroup var conf map[string]Group yamlFile, err := ioutil.ReadFile(*confFileName) if err != nil { log.Fatalf("Configuration open error #%v ", err) } err = yaml.Unmarshal(yamlFile, &conf) if err != nil { log.Fatalf("Configuration read error #%v", err) } err = validateConfiguration(conf) if err != nil { log.Fatal("Configuration error #", err) } stopChannel := make(chan bool) for name, group := range conf { if verbose > 0 { log.Println("Checking group", name, group) } waitGroup.Add(1) go checkGroup(name, group, &waitGroup, stopChannel) } exit := make(chan os.Signal, 10) signal.Notify(exit, os.Interrupt) s := <-exit if verbose > 1 { log.Println("Received signal", s) } if verbose > 0 { log.Println("main closing stopChannel") } close(stopChannel) waitGroup.Wait() } func checkGroup(name string, group Group, waitGroup *sync.WaitGroup, stopChannel chan bool) { channels := make(map[string]chan int) statusPerHost := make(map[string]int) for _, host := range group.Hosts { channel := make(chan int, 1) channels[host] = channel statusPerHost[host] = -1 waitGroup.Add(1) go checkHost(channels[host], name, host, group.Check, waitGroup, stopChannel) } for { select { case <-stopChannel: if verbose > 0 { log.Println("checkGroup", name, "stopChannel") } waitGroup.Done() return default: for host, channel := range channels { select { case stop := <-stopChannel: if verbose > 0 { log.Println("checkGroup", name, "stopChannel", stop) } break case status := <-channel: if verbose > 1 { log.Println("Status for ", host, "was", statusPerHost[host], "in group", name) log.Println("Status for ", host, "is", status, "in group", name) } if statusPerHost[host] == -1 || statusPerHost[host] != status { statusPerHost[host] = status updateTables(host, status, group.Tables) } default: time.Sleep(100 * time.Millisecond) } } } } } func updateTables(host string, status int, tables []string) { for _, table := range tables { op := "add" if status != 0 { op = "del" } cmd := exec.Command("pfctl", "-t", table, "-T", op, host) err := cmd.Run() if err != nil { log.Println("Unable to run command #", cmd) } } } func checkHost(status chan<- int, group string, host string, check Check, waitGroup *sync.WaitGroup, stopChannel chan bool) { var lastCheck time.Time for { select { case <-stopChannel: if verbose > 0 { log.Println("checkHost", host, "group", group, "stopChannel") } waitGroup.Done() return default: var err error if time.Since(lastCheck).Seconds() > float64(check.Interval) { lastCheck = time.Now() err = nil switch check.Type { case "http": err = CheckHTTP(host, check) case "tcp": err = CheckTCP(host, check) case "smtp": err = CheckSMTP(host, check) } if err != nil { status <- 1 if verbose > 1 { log.Println("checkHost", host, "group", group, "error", err) } } else { status <- 0 } } time.Sleep(300 * time.Millisecond) } } } func CheckHTTP(host string, check Check) error { client := &http.Client{ Timeout: time.Duration(check.Timeout) * time.Second, } req, err := http.NewRequest(check.Method, fmt.Sprintf("http://%s%s", host, check.Uri), nil) if err != nil { return err } if len(check.Host) > 0 { req.Header.Set("Host", check.Host) } resp, err := client.Do(req) if err != nil { return err } defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) return nil } func CheckTCP(host string, check Check) error { cnx, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, check.Port), time.Duration(check.Timeout)*time.Second) if err != nil { return err } cnx.Close() return nil } func CheckSMTP(host string, check Check) error { return nil } func validateConfiguration(conf map[string]Group) error { for name, group := range conf { if verbose > 1 { log.Println("Validating configuration", name) } if len(group.Tables) == 0 { return errors.New(fmt.Sprintf("No tables in group %s", name)) } if verbose > 1 { log.Println("Hosts", group.Hosts) } if len(group.Hosts) == 0 { return errors.New(fmt.Sprintf("No hosts in group %s", name)) } for _, host := range group.Hosts { ip := net.ParseIP(host) if ip == nil { return errors.New(fmt.Sprintf("Host %v is not an IP in group %s", host, name)) } } switch group.Check.Type { case "http": if len(group.Check.Method) == 0 { group.Check.Method = "HEAD" } switch group.Check.Method { case "HEAD": case "GET": default: return errors.New(fmt.Sprintf("Check method shoud be HEAD or GET in group %s", name)) } if group.Check.Port == 0 { group.Check.Port = 80 } if len(group.Check.Uri) == 0 { group.Check.Uri = "/" } case "tcp": if group.Check.Port == 0 { return errors.New(fmt.Sprintf("Check port is undefined or 0 in group %s", name)) } case "smtp": if group.Check.Port == 0 { group.Check.Port = 25 } default: return errors.New(fmt.Sprintf("Check type should be http, smtp or tcp in group %s", name)) } if group.Check.Interval == 0 { group.Check.Interval = 5 } if group.Check.Timeout == 0 { group.Check.Timeout = 2 } conf[name] = group } return nil }