diff --git a/conf.yaml b/conf.yaml index 688dd25..bf8d0ee 100644 --- a/conf.yaml +++ b/conf.yaml @@ -6,6 +6,7 @@ IDE: - 172.17.4.31 check: type: http + host: check port: 80 uri: / method: HEAD @@ -36,3 +37,11 @@ SMTP: interval: 5 timeout: 2 +TEST: + tables: + - table1 + hosts: + - 172.17.4.40 + check: + type: tcp + port: 8080 diff --git a/hostchecker.go b/hostchecker.go new file mode 100644 index 0000000..978b347 --- /dev/null +++ b/hostchecker.go @@ -0,0 +1,237 @@ +package main + +import "fmt" +import "net/http" +import "io" +import "io/ioutil" +import "log" +import "gopkg.in/yaml.v3" +import "time" +import "sync" +import "os/signal" +import "os" +import "errors" +import "net" + +type Check struct { + Type string `yaml:"type"` + Host string `yaml:"host"` + Port int `yaml:"port"` + Uri string `yaml:"uri"` + Method string `yaml:"method"` + Timeout int `yaml:"timeout"` + Interval int `yaml:"interval"` +} + +type Group struct { + Name string `yaml:"group"` + Tables []string `yaml:"tables"` + Hosts []string `yaml:"hosts"` + Check Check `yaml:"check"` +} + + +func main() { + log.Println("Starting") + + var waitGroup sync.WaitGroup + + var conf map[string]Group + yamlFile, err := ioutil.ReadFile("conf.yaml") + if err != nil { + log.Fatalf("Configuration open error #%v ", err) + } + err = yaml.Unmarshal(yamlFile, &conf) + if err != nil { + log.Fatalf("Configuration read error #%v", err) + } + + err = validateConfiguration(conf) + if err != nil { + log.Fatal("Configuration error #", err) + } + stopChannel := make(chan bool) + for name, group := range conf { + log.Println("Checking group", name, group) + waitGroup.Add(1) + go checkGroup(name, group, &waitGroup, stopChannel) + } + + exit := make(chan os.Signal, 10) + + signal.Notify(exit, os.Interrupt) + + s := <-exit + log.Println("Received signal", s) + log.Println("main closing stopChannel") + close(stopChannel) + waitGroup.Wait() +} + +func checkGroup(name string, group Group, waitGroup *sync.WaitGroup, stopChannel chan bool) { + channels := make(map[string]chan int) + for _, host := range group.Hosts { + channel := make(chan int, 1) + channels[host] = channel + waitGroup.Add(1) + go checkHost(channels[host], name, host, group.Check, waitGroup, stopChannel) + } + + for { + select { + case <-stopChannel: + log.Println("checkGroup", name, "stopChannel") + waitGroup.Done() + return + break + default: + for host, channel := range channels { + select { + case stop := <-stopChannel: + log.Println("checkGroup", name, "stopChannel", stop) + break + case status := <-channel: + log.Println("Status for ", host, "is", status, "in group", name) + if status == 0 { + addHostToTables( host, group.Tables ) + } else { + removeHostFromTables( host, group.Tables) + } + default: + time.Sleep(100 * time.Millisecond) + } + } + } + } +} + +func addHostToTables( host string, tables []string ) { +} +func removeHostFromTables( host string, tables []string ) { +} +func checkHost(status chan<- int, group string, host string, check Check, waitGroup *sync.WaitGroup, stopChannel chan bool) { + + var lastCheck time.Time + + for { + select { + case <-stopChannel: + log.Println("checkHost", host, "group", group, "stopChannel") + waitGroup.Done() + return + default: + var err error + if time.Since(lastCheck).Seconds() > float64(check.Interval) { + lastCheck = time.Now() + err = nil + switch check.Type { + case "http": + err = CheckHTTP(host, check) + case "tcp": + err = CheckTCP(host, check) + case "smtp": + err = CheckSMTP(host, check) + } + if err != nil { + status <- 1 + log.Println("checkHost", host, "group", group, "error", err) + } else { + status <- 0 + } + } + time.Sleep(300*time.Millisecond) + } + } +} + +func CheckHTTP(host string, check Check) error { + + + client := &http.Client{ + Timeout: time.Duration(check.Timeout) * time.Second, + } + req, err := http.NewRequest(check.Method, fmt.Sprintf("http://%s%s", host, check.Uri), nil) + if err != nil { + return err + } + if len(check.Host) > 0 { + req.Header.Set("Host", check.Host) + } + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + return nil +} + +func CheckTCP(host string, check Check) error { + cnx, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, check.Port), time.Duration(check.Timeout) * time.Second) + if err != nil { + return err + } + cnx.Close() + return nil +} + +func CheckSMTP(host string, check Check) error { + return nil +} + +func validateConfiguration(conf map[string]Group) error { + for name, group := range conf { + log.Println("Validating configuration", name) + if len(group.Tables) == 0 { + return errors.New(fmt.Sprintf("No tables in group %s", name)) + } + log.Println("Hosts", group.Hosts) + if len(group.Hosts) == 0 { + return errors.New(fmt.Sprintf("No hosts in group %s", name)) + } + for _, host := range group.Hosts { + ip := net.ParseIP(host) + if ip == nil { + return errors.New(fmt.Sprintf("Host %v is not an IP in group %s", host, name)) + } + } + switch group.Check.Type { + case "http": + if len(group.Check.Method) == 0 { + group.Check.Method = "HEAD" + } + switch group.Check.Method { + case "HEAD": + case "GET": + default: + return errors.New(fmt.Sprintf("Check method shoud be HEAD or GET in group %s", name)) + } + if group.Check.Port == 0 { + group.Check.Port = 80 + } + if len(group.Check.Uri) == 0 { + group.Check.Uri = "/" + } + case "tcp": + if group.Check.Port == 0 { + return errors.New(fmt.Sprintf("Check port is undefined or 0 in group %s", name)) + } + case "smtp": + if group.Check.Port == 0 { + group.Check.Port = 25 + } + + default: + return errors.New(fmt.Sprintf("Check type should be http, smtp or tcp in group %s",name)) + } + if group.Check.Interval == 0 { + group.Check.Interval = 5 + } + if group.Check.Timeout == 0 { + group.Check.Timeout = 2 + } + conf[name] = group + } + return nil +} diff --git a/main.go b/main.go deleted file mode 100644 index ad9a8a7..0000000 --- a/main.go +++ /dev/null @@ -1,164 +0,0 @@ -package main - -import "fmt" -import "net/http" -import "io" -import "io/ioutil" -import "log" -import "gopkg.in/yaml.v3" -import "time" -import "sync" -import "os/signal" -import "os" - -type Check struct { - Type string `yaml:"type"` - Port int `yaml:"port"` - Uri string `yaml:"uri"` - Method string `yaml:"method"` - Timeout int `yaml:"timeout"` - Interval int `yaml:"interval"` -} - -type Group struct { - Name string `yaml:"group"` - Tables []string `yaml:"tables"` - Hosts []string `yaml:"hosts"` - Check Check `yaml:"check"` -} - -func main() { - log.Println("Starting") - - var waitGroup sync.WaitGroup - - var conf map[string]Group - yamlFile, err := ioutil.ReadFile("conf.yaml") - if err != nil { - log.Fatalf("Configuration open error #%v ", err) - } - err = yaml.Unmarshal(yamlFile, &conf) - if err != nil { - log.Fatalf("Configuration read error #%v", err) - } - - stopChannel := make(chan bool) - for name, group := range conf { - log.Println("Checking group", name, group) - waitGroup.Add(1) - go checkGroup(name, group, &waitGroup, stopChannel) - } - - exit := make(chan os.Signal, 1) - - signal.Notify(exit, os.Interrupt) - - s := <-exit - log.Println("Received signal", s) - log.Println("main closing stopChannel") - close(stopChannel) - waitGroup.Wait() -} - -func checkGroup(name string, group Group, waitGroup *sync.WaitGroup, stopChannel chan bool) { - channels := make(map[string]chan int) - for _, host := range group.Hosts { - channel := make(chan int, 1) - channels[host] = channel - waitGroup.Add(1) - go checkHost(channels[host], name, host, group.Check, waitGroup, stopChannel) - } - - for { - select { - case <-stopChannel: - log.Println("checkGroup", name, "stopChannel") - waitGroup.Done() - return - break - default: - for host, channel := range channels { - select { - case stop := <-stopChannel: - log.Println("checkGroup", name, "stopChannel", stop) - break - case status := <-channel: - log.Println("Status for ", host, "is", status) - default: - time.Sleep(100 * time.Millisecond) - } - } - } - } -} - -func checkHost(status chan<- int, group string, host string, check Check, waitGroup *sync.WaitGroup, stopChannel chan bool) { - - if check.Timeout == 0 { - check.Timeout = 5 - } - if check.Interval == 0 { - check.Interval = 60 - } - if len(check.Method) == 0 { - check.Method = "HEAD" - } - if len(check.Uri) == 0 { - check.Uri = "/" - } - if check.Port == 0 { - check.Port = 80 - } - for { - select { - case <-stopChannel: - log.Println("checkHost", host, "group", group, "stopChannel") - waitGroup.Done() - return - default: - var err error - - err = nil - switch check.Type { - case "http": - err = CheckHTTP(host, check) - case "tcp": - err = CheckTCP(host, check) - case "smtp": - err = CheckSMTP(host, check) - default: - err = CheckHTTP(host, check) - } - if err != nil { - status <- 1 - log.Println("checkHost", host, "group", group, "error", err) - } else { - status <- 0 - } - time.Sleep(time.Duration(check.Interval) * time.Second) - } - } -} - -func CheckHTTP(host string, check Check) error { - - client := &http.Client{ - Timeout: time.Second * time.Duration(check.Timeout), - } - resp, err := client.Head(fmt.Sprintf("http://%s", host)) - if err != nil { - return err - } - _, _ = io.ReadAll(resp.Body) - resp.Body.Close() - - return nil -} - -func CheckTCP(host string, check Check) error { - return nil -} - -func CheckSMTP(host string, check Check) error { - return nil -}