Files
2023-11-22 20:41:41 -07:00

166 lines
4.0 KiB
Go

package main
import (
	"bufio"
	"fmt"
	"os/exec"
	"regexp"

	"git.ohea.xyz/golang/config"
)
// Config holds the user-editable settings that decide when the VPN is
// brought up or down in response to NetworkManager events.
type Config struct {
// IgnoreDevices selects how Devices is interpreted. When true, events from
// the listed devices are skipped and every other device is handled; when
// false, only events from the listed devices are handled.
IgnoreDevices bool
// Devices is the list of device names that IgnoreDevices applies to.
Devices []string
// IgnoreConnections selects how Connections is interpreted. When true, the
// VPN is enabled for every connection except the listed ones; when false,
// the VPN is enabled only for the listed connections. Connecting to a
// network that does not qualify brings the VPN down.
IgnoreConnections bool
// Connections is the list of connection names that IgnoreConnections
// applies to.
Connections []string
// VPNConnection is the connection name passed to `nmcli con up` and
// `nmcli con down`.
VPNConnection string
}
// isMatch reports whether match passes the filter described by strings.
//
// When ignoreStrings is false the list acts as an allow-list: the result is
// true only if match equals one of the entries. When ignoreStrings is true the
// list acts as a deny-list: the result is true only if match equals none of
// the entries. Comparison is exact string equality.
func isMatch(ignoreStrings bool, strings []string, match string) bool {
	listed := false
	for idx := 0; idx < len(strings); idx++ {
		if strings[idx] == match {
			listed = true
			break
		}
	}
	// Allow-list mode passes listed values; deny-list mode passes the rest.
	return listed != ignoreStrings
}
// PrintArray writes message, a colon and a space, then the items joined by
// ", ", followed by a newline, to standard output. An empty items slice
// produces just "message: " and the newline.
func PrintArray(message string, items []string) {
	fmt.Print(message + ": ")
	for idx := range items {
		// Emit the separator before every item except the first.
		if idx > 0 {
			fmt.Print(", ")
		}
		fmt.Print(items[idx])
	}
	fmt.Print("\n")
}
// enableVPN brings up the NetworkManager connection named vpnCon by running
// `nmcli con up <vpnCon>` and waiting for the command to finish.
//
// It returns an error if vpnCon is empty, or if nmcli cannot be started or
// exits unsuccessfully. In the latter case the underlying error is wrapped so
// callers can inspect it with errors.Is / errors.As. Error messages carry no
// trailing newline; the caller decides how to terminate the line.
func enableVPN(vpnCon string) error {
	// An empty name would be passed to nmcli as an empty argument; reject it
	// here with a message that points at the configuration instead.
	if vpnCon == "" {
		return fmt.Errorf("could not launch VPN: no VPN connection name configured")
	}
	if err := exec.Command("nmcli", "con", "up", vpnCon).Run(); err != nil {
		return fmt.Errorf("could not launch VPN: %w", err)
	}
	return nil
}
// disableVPN takes down the NetworkManager connection named vpnCon by running
// `nmcli con down <vpnCon>` and waiting for the command to finish.
//
// It returns an error if vpnCon is empty, or if nmcli cannot be started or
// exits unsuccessfully. In the latter case the underlying error is wrapped so
// callers can inspect it with errors.Is / errors.As. Error messages carry no
// trailing newline; the caller decides how to terminate the line.
func disableVPN(vpnCon string) error {
	// An empty name would be passed to nmcli as an empty argument; reject it
	// here with a message that points at the configuration instead.
	if vpnCon == "" {
		return fmt.Errorf("could not stop VPN: no VPN connection name configured")
	}
	if err := exec.Command("nmcli", "con", "down", vpnCon).Run(); err != nil {
		return fmt.Errorf("could not stop VPN: %w", err)
	}
	return nil
}
// usingConnectionRe matches an `nmcli monitor` line of the form
// "<device>: using connection '<name>'" and captures the device and the
// connection name.
var usingConnectionRe = regexp.MustCompile(`(?P<device>\S+?): using connection \'(?P<network>.+?)\'`)

// deviceConnectedRe matches an `nmcli monitor` line of the form
// "<device>: connected" and captures the device.
var deviceConnectedRe = regexp.MustCompile(`(?P<device>\S+?): connected`)

// main loads the configuration, prints the active device and connection
// filters, then follows the output of `nmcli monitor` line by line. When a
// monitored device starts using a connection, main records whether that
// connection should have the VPN enabled or disabled; when a device then
// reports "connected", the recorded action is carried out via enableVPN or
// disableVPN.
//
// Errors are reported on standard output. main returns when the configuration
// cannot be loaded, when a new configuration file was just created, or when
// the nmcli monitor output ends.
func main() {
	configData := config.Config[Config]{
		Name:     "nmcli-vpn-activator",
		Filename: "config",
		Config: Config{
			IgnoreDevices:     false,
			Devices:           []string{},
			IgnoreConnections: true,
			Connections:       []string{},
			VPNConnection:     "",
		},
	}

	// NOTE(review): Get presumably loads the file into configData.Config and
	// reports whether it had to create a fresh one — confirm against the
	// config package.
	created, err := configData.Get()
	if err != nil {
		fmt.Printf("Could not get config: %v\n", err)
		return
	}
	if created {
		fmt.Println("New config created, please update and restart.")
		return
	}
	cfg := &configData.Config

	if cfg.IgnoreDevices {
		PrintArray("Ignoring the following devices", cfg.Devices)
	} else {
		PrintArray("Enabling the VPN on the following devices", cfg.Devices)
	}
	if cfg.IgnoreConnections {
		PrintArray("Ignoring the following connections", cfg.Connections)
	} else {
		PrintArray("Enabling the VPN on the following connections", cfg.Connections)
	}

	monitor := exec.Command("nmcli", "monitor")
	stdout, err := monitor.StdoutPipe()
	if err != nil {
		fmt.Printf("could not get stdout: %v\n", err)
		return
	}
	if err := monitor.Start(); err != nil {
		fmt.Printf("Could not start process: %v\n", err)
		return
	}

	// Pending action, set when a monitored device starts connecting and
	// consumed by the next "connected" event. At most one flag is true.
	connectingEnable := false
	connectingDisable := false

	// Read whole lines: a raw Read may return a partial line or several lines
	// at once, which would make the per-event matching below miss events.
	scanner := bufio.NewScanner(stdout)
	for scanner.Scan() {
		line := scanner.Text()

		if m := usingConnectionRe.FindStringSubmatch(line); m != nil {
			device, network := m[1], m[2]
			if !isMatch(cfg.IgnoreDevices, cfg.Devices, device) {
				continue
			}
			if isMatch(cfg.IgnoreConnections, cfg.Connections, network) {
				fmt.Printf("Monitored device %s has started connecting to VPN enabled network %s.\n", device, network)
				connectingEnable = true
				connectingDisable = false
			} else {
				fmt.Printf("Monitored device %s has started connecting to VPN disabled network %s.\n", device, network)
				connectingDisable = true
				connectingEnable = false
			}
			continue
		}

		m := deviceConnectedRe.FindStringSubmatch(line)
		if m == nil {
			continue
		}
		device := m[1]
		// NOTE(review): the pending action fires on the next "connected" line
		// from any device, not only the device that set it.
		switch {
		case connectingEnable:
			fmt.Printf("Monitored device %s has connected to VPN enabled network, enabling VPN.\n", device)
			if err := enableVPN(cfg.VPNConnection); err != nil {
				fmt.Printf("Could not enable VPN: %v\n", err)
			}
			connectingEnable = false
		case connectingDisable:
			fmt.Printf("Monitored device %s has connected to VPN disabled network, disabling VPN.\n", device)
			if err := disableVPN(cfg.VPNConnection); err != nil {
				fmt.Printf("Could not disable VPN: %v\n", err)
			}
			connectingDisable = false
		}
	}

	if err := scanner.Err(); err != nil {
		fmt.Printf("could not read stdout: %v\n", err)
	}
	// Reap the child once its output is exhausted; Wait must follow all reads
	// from the stdout pipe.
	if err := monitor.Wait(); err != nil {
		fmt.Printf("nmcli monitor exited: %v\n", err)
		return
	}
	fmt.Println("nmcli monitor exited.")
}