This commit is contained in:
laurentu 2023-08-17 16:14:13 +02:00
parent ceabd1d76b
commit b10c025860
1 changed files with 114 additions and 116 deletions

View File

@ -1,17 +1,17 @@
/* /*
Copyright 2023 Laurent Ulrich (laurentu@gmail.com) Copyright 2023 Laurent Ulrich (laurentu@gmail.com)
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
package main package main
@ -32,7 +32,7 @@ import "flag"
type Check struct { type Check struct {
Type string `yaml:"type"` Type string `yaml:"type"`
Host string `yaml:"host"` Host string `yaml:"host"`
Port int `yaml:"port"` Port int `yaml:"port"`
Uri string `yaml:"uri"` Uri string `yaml:"uri"`
Method string `yaml:"method"` Method string `yaml:"method"`
@ -47,12 +47,11 @@ type Group struct {
Check Check `yaml:"check"` Check Check `yaml:"check"`
} }
func main() { func main() {
log.Println("Starting") log.Println("Starting")
confFileName := flag.String("f", "/usr/local/etc/hostchecker.yaml", "YAML configuration file") confFileName := flag.String("f", "/usr/local/etc/hostchecker.yaml", "YAML configuration file")
flag.Parse() flag.Parse()
var waitGroup sync.WaitGroup var waitGroup sync.WaitGroup
@ -66,10 +65,10 @@ func main() {
log.Fatalf("Configuration read error #%v", err) log.Fatalf("Configuration read error #%v", err)
} }
err = validateConfiguration(conf) err = validateConfiguration(conf)
if err != nil { if err != nil {
log.Fatal("Configuration error #", err) log.Fatal("Configuration error #", err)
} }
stopChannel := make(chan bool) stopChannel := make(chan bool)
for name, group := range conf { for name, group := range conf {
log.Println("Checking group", name, group) log.Println("Checking group", name, group)
@ -112,7 +111,7 @@ func checkGroup(name string, group Group, waitGroup *sync.WaitGroup, stopChannel
break break
case status := <-channel: case status := <-channel:
log.Println("Status for ", host, "is", status, "in group", name) log.Println("Status for ", host, "is", status, "in group", name)
updateTables( host, status, group.Tables ) updateTables(host, status, group.Tables)
default: default:
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
@ -121,23 +120,23 @@ func checkGroup(name string, group Group, waitGroup *sync.WaitGroup, stopChannel
} }
} }
func updateTables( host string, status int, tables []string ) { func updateTables(host string, status int, tables []string) {
for _, table := range tables { for _, table := range tables {
op := "add" op := "add"
if status != 0 { if status != 0 {
op = "del" op = "del"
} }
cmd := exec.Command("pfctl", "-t", table, "-T", op, host) cmd := exec.Command("pfctl", "-t", table, "-T", op, host)
err := cmd.Run() err := cmd.Run()
if err != nil { if err != nil {
log.Println("Unable to run command #", cmd) log.Println("Unable to run command #", cmd)
} }
} }
} }
func checkHost(status chan<- int, group string, host string, check Check, waitGroup *sync.WaitGroup, stopChannel chan bool) { func checkHost(status chan<- int, group string, host string, check Check, waitGroup *sync.WaitGroup, stopChannel chan bool) {
var lastCheck time.Time var lastCheck time.Time
for { for {
select { select {
@ -147,58 +146,57 @@ func checkHost(status chan<- int, group string, host string, check Check, waitGr
return return
default: default:
var err error var err error
if time.Since(lastCheck).Seconds() > float64(check.Interval) { if time.Since(lastCheck).Seconds() > float64(check.Interval) {
lastCheck = time.Now() lastCheck = time.Now()
err = nil err = nil
switch check.Type { switch check.Type {
case "http": case "http":
err = CheckHTTP(host, check) err = CheckHTTP(host, check)
case "tcp": case "tcp":
err = CheckTCP(host, check) err = CheckTCP(host, check)
case "smtp": case "smtp":
err = CheckSMTP(host, check) err = CheckSMTP(host, check)
} }
if err != nil { if err != nil {
status <- 1 status <- 1
log.Println("checkHost", host, "group", group, "error", err) log.Println("checkHost", host, "group", group, "error", err)
} else { } else {
status <- 0 status <- 0
} }
} }
time.Sleep(300*time.Millisecond) time.Sleep(300 * time.Millisecond)
} }
} }
} }
func CheckHTTP(host string, check Check) error { func CheckHTTP(host string, check Check) error {
client := &http.Client{ client := &http.Client{
Timeout: time.Duration(check.Timeout) * time.Second, Timeout: time.Duration(check.Timeout) * time.Second,
} }
req, err := http.NewRequest(check.Method, fmt.Sprintf("http://%s%s", host, check.Uri), nil) req, err := http.NewRequest(check.Method, fmt.Sprintf("http://%s%s", host, check.Uri), nil)
if err != nil { if err != nil {
return err return err
} }
if len(check.Host) > 0 { if len(check.Host) > 0 {
req.Header.Set("Host", check.Host) req.Header.Set("Host", check.Host)
} }
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body) _, _ = io.ReadAll(resp.Body)
return nil return nil
} }
func CheckTCP(host string, check Check) error { func CheckTCP(host string, check Check) error {
cnx, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, check.Port), time.Duration(check.Timeout) * time.Second) cnx, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, check.Port), time.Duration(check.Timeout)*time.Second)
if err != nil { if err != nil {
return err return err
} }
cnx.Close() cnx.Close()
return nil return nil
} }
@ -207,57 +205,57 @@ func CheckSMTP(host string, check Check) error {
} }
func validateConfiguration(conf map[string]Group) error { func validateConfiguration(conf map[string]Group) error {
for name, group := range conf { for name, group := range conf {
log.Println("Validating configuration", name) log.Println("Validating configuration", name)
if len(group.Tables) == 0 { if len(group.Tables) == 0 {
return errors.New(fmt.Sprintf("No tables in group %s", name)) return errors.New(fmt.Sprintf("No tables in group %s", name))
} }
log.Println("Hosts", group.Hosts) log.Println("Hosts", group.Hosts)
if len(group.Hosts) == 0 { if len(group.Hosts) == 0 {
return errors.New(fmt.Sprintf("No hosts in group %s", name)) return errors.New(fmt.Sprintf("No hosts in group %s", name))
} }
for _, host := range group.Hosts { for _, host := range group.Hosts {
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip == nil { if ip == nil {
return errors.New(fmt.Sprintf("Host %v is not an IP in group %s", host, name)) return errors.New(fmt.Sprintf("Host %v is not an IP in group %s", host, name))
} }
} }
switch group.Check.Type { switch group.Check.Type {
case "http": case "http":
if len(group.Check.Method) == 0 { if len(group.Check.Method) == 0 {
group.Check.Method = "HEAD" group.Check.Method = "HEAD"
} }
switch group.Check.Method { switch group.Check.Method {
case "HEAD": case "HEAD":
case "GET": case "GET":
default: default:
return errors.New(fmt.Sprintf("Check method shoud be HEAD or GET in group %s", name)) return errors.New(fmt.Sprintf("Check method shoud be HEAD or GET in group %s", name))
} }
if group.Check.Port == 0 { if group.Check.Port == 0 {
group.Check.Port = 80 group.Check.Port = 80
} }
if len(group.Check.Uri) == 0 { if len(group.Check.Uri) == 0 {
group.Check.Uri = "/" group.Check.Uri = "/"
} }
case "tcp": case "tcp":
if group.Check.Port == 0 { if group.Check.Port == 0 {
return errors.New(fmt.Sprintf("Check port is undefined or 0 in group %s", name)) return errors.New(fmt.Sprintf("Check port is undefined or 0 in group %s", name))
} }
case "smtp": case "smtp":
if group.Check.Port == 0 { if group.Check.Port == 0 {
group.Check.Port = 25 group.Check.Port = 25
} }
default: default:
return errors.New(fmt.Sprintf("Check type should be http, smtp or tcp in group %s",name)) return errors.New(fmt.Sprintf("Check type should be http, smtp or tcp in group %s", name))
} }
if group.Check.Interval == 0 { if group.Check.Interval == 0 {
group.Check.Interval = 5 group.Check.Interval = 5
} }
if group.Check.Timeout == 0 { if group.Check.Timeout == 0 {
group.Check.Timeout = 2 group.Check.Timeout = 2
} }
conf[name] = group conf[name] = group
} }
return nil return nil
} }