From c9c7a085c744d5023f932b2a0a6dba08153d2ba7 Mon Sep 17 00:00:00 2001 From: bt Date: Sun, 24 May 2026 16:29:03 +0200 Subject: [daemon] Add TLS support --- .gitignore | 3 ++- cmd/daemon/main.go | 32 +++++++++++++++++++++++-- server/message.go | 2 +- server/remote.go | 2 +- server/server.go | 70 +++++++++++++++++++++++++++++++++++++++++------------- tools/key-gen.sh | 5 ++++ tools/run-tls.sh | 8 +++++++ 7 files changed, 100 insertions(+), 22 deletions(-) create mode 100755 tools/key-gen.sh create mode 100755 tools/run-tls.sh diff --git a/.gitignore b/.gitignore index 884e3de..2c206f1 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -test.db \ No newline at end of file +test.db +*.pem \ No newline at end of file diff --git a/cmd/daemon/main.go b/cmd/daemon/main.go index 10fb61c..aebc085 100644 --- a/cmd/daemon/main.go +++ b/cmd/daemon/main.go @@ -27,11 +27,15 @@ var ( ) func main() { - fmt.Println("SOLEC MOCK SERVER") + fmt.Println("SOLEC SERVER") fmt.Println("Commands:", slices.Sorted(maps.Keys(cmds))) addr := flag.String("a", "localhost:9999", "listening address:port") name := flag.String("n", "localhost", "server name") + enableTls := flag.Bool("tls", false, "Enable TLS") + certPath := flag.String("tls-cert", "", "TLS certificate PEM file path") + keyPath := flag.String("tls-key", "", "TLS key PEM file path") + flag.Parse() db, err := storage.InitDb("test.db") @@ -39,7 +43,17 @@ func main() { panic(err) } - serv = server.NewServer(*addr, *name, db) + cfg := server.Config{ + ListenAddr: *addr, + Name: *name, + } + + if *enableTls { + cfg.Tls = true + cfg.CertPem, cfg.KeyPem = loadKeys(*certPath, *keyPath) + } + + serv = server.NewServer(cfg, db) serv.AddChannel("test") go func() { @@ -53,6 +67,20 @@ func main() { readCmds() } +func loadKeys(certPath, keyPath string) ([]byte, []byte) { + cert, err := os.ReadFile(certPath) + if err != nil { + panic(err) + } + + key, err := os.ReadFile(keyPath) + if err != nil { + panic(err) + } + + return cert, key +} + func readCmds() { sc := bufio.NewScanner(os.Stdin) for sc.Scan() { diff --git a/server/message.go b/server/message.go index b8b0631..e814d9d 100644 --- a/server/message.go +++ b/server/message.go @@ -46,7 +46,7 @@ func (s *Server) handleMessage(sender net.Conn, connType core.ConnType, msg core log.Println("cannot write to database", err) } - if addr.Host == s.name { + if addr.Host == s.cfg.Name { return s.handleLocalMessage(sender, addr, msg) } diff --git a/server/remote.go b/server/remote.go index 6f9bd66..e1829b1 100644 --- a/server/remote.go +++ b/server/remote.go @@ -102,7 +102,7 @@ func (s *Server) initRemoteConn(name string) (net.Conn, error) { return conn, err } - auth := core.ServerAuth{Name: s.name} + auth := core.ServerAuth{Name: s.cfg.Name} if err := core.Send(conn, auth); err != nil { conn.Close() return conn, err diff --git a/server/server.go b/server/server.go index 58714c5..b5840df 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "crypto/tls" "errors" "log" "net" @@ -10,8 +11,7 @@ import ( ) type Server struct { - listenAddr string - name string + cfg Config users map[string]User // TODO: Use full address instead of just name servers map[string]RemoteServer channels map[string]*Channel @@ -21,23 +21,66 @@ type Server struct { Storage Storage } -func NewServer(listenAddr string, name string, storage Storage) *Server { +type Config struct { + ListenAddr string + Name string + Tls bool + CertPem []byte + KeyPem []byte +} + +func NewServer(cfg Config, storage Storage) *Server { return &Server{ - listenAddr: listenAddr, - name: name, - users: make(map[string]User), - servers: make(map[string]RemoteServer), - channels: make(map[string]*Channel), - Storage: storage, + cfg: cfg, + Storage: storage, + users: make(map[string]User), + servers: make(map[string]RemoteServer), + channels: make(map[string]*Channel), } } func (s *Server) Start() error { - ln, err := net.Listen("tcp", s.listenAddr) + if s.cfg.Tls { + return s.listenTls() + } + + return s.listenPlain() +} + +func (s *Server) AddChannel(name string) { + s.channelsMu.Lock() + defer s.channelsMu.Unlock() + s.channels[name] = NewChannel(name) + log.Println("created channel", name) +} + +func (s *Server) listenPlain() error { + ln, err := net.Listen("tcp", s.cfg.ListenAddr) + if err != nil { + return err + } + + s.listen(ln) + return nil +} + +func (s *Server) listenTls() error { + cert, err := tls.X509KeyPair(s.cfg.CertPem, s.cfg.KeyPem) + if err != nil { + return err + } + + cfg := &tls.Config{Certificates: []tls.Certificate{cert}} + ln, err := tls.Listen("tcp", s.cfg.ListenAddr, cfg) if err != nil { return err } + s.listen(ln) + return nil +} + +func (s *Server) listen(ln net.Listener) { for { conn, err := ln.Accept() if err != nil { @@ -50,13 +93,6 @@ func (s *Server) Start() error { } } -func (s *Server) AddChannel(name string) { - s.channelsMu.Lock() - defer s.channelsMu.Unlock() - s.channels[name] = NewChannel(name) - log.Println("created channel", name) -} - func (s *Server) handleConn(conn net.Conn) { defer conn.Close() diff --git a/tools/key-gen.sh b/tools/key-gen.sh new file mode 100755 index 0000000..14957d4 --- /dev/null +++ b/tools/key-gen.sh @@ -0,0 +1,5 @@ +#!/bin/sh +openssl req -x509 -newkey rsa:4096 \ + -keyout key.pem -out cert.pem \ + -sha256 -days 3650 -nodes \ + -subj "/C=XX/ST=TEST/L=TEST/O=TEST/OU=TEST/CN=TEST" diff --git a/tools/run-tls.sh b/tools/run-tls.sh new file mode 100755 index 0000000..be668f5 --- /dev/null +++ b/tools/run-tls.sh @@ -0,0 +1,8 @@ +#!/bin/sh + +tmux \ + new-session "go run cmd/daemon/main.go -tls -tls-cert cert.pem -tls-key key.pem; read" \; \ + split-window "sleep 0.5; go run cmd/client/main.go -tls -tls-debug -u user1; read" \; \ + split-window "sleep 0.5; go run cmd/client/main.go -tls -tls-debug -u user2; read" \; \ + split-window "sleep 0.5; go run cmd/client/main.go -tls -tls-debug -u user3; read" \; \ + select-layout tiled; \ No newline at end of file -- cgit v1.2.3