From e9aebac1a2a4732763c2f7e4428a23983d4eb6a3 Mon Sep 17 00:00:00 2001 From: bt Date: Fri, 10 Apr 2026 19:31:31 +0200 Subject: [common] Exchange messages between servers --- server/message.go | 37 +++++++++++++++ server/remote.go | 112 +++++++++++++++++++++++++++++++++++++++++++ server/server.go | 139 ++++++++++-------------------------------------------- server/user.go | 68 +++++++++++++++++++++++--- 4 files changed, 234 insertions(+), 122 deletions(-) create mode 100644 server/remote.go (limited to 'server') diff --git a/server/message.go b/server/message.go index c1384f7..2487abd 100644 --- a/server/message.go +++ b/server/message.go @@ -1,6 +1,8 @@ package server import ( + "errors" + "fmt" "log" "time" @@ -27,3 +29,38 @@ func (s *Server) SendBroadcast(msg string) { } } } + +func (s *Server) handleMessage(msg core.Message) error { + log.Println("message:", msg.Source, "->", msg.Target, msg.Content) + + channel, host, err := core.ReadAddr(msg.Target) + if err != nil { + return err + } + + if host == s.name { + return s.handleLocalMessage(channel, msg) + } + + return s.handleOutboundMessage(channel, host, msg) +} + +func (s *Server) handleLocalMessage(channel string, msg core.Message) error { + s.usersMu.RLock() + user, ok := s.users[channel] + if !ok { + return errors.New("target not found") + } + s.usersMu.RUnlock() + + return user.Send(msg) +} + +func (s *Server) handleOutboundMessage(channel, host string, msg core.Message) error { + remote, err := s.getRemote(host) + if err != nil { + return fmt.Errorf("cannot access remote server: %w", err) + } + + return core.Send(remote.Conn, msg) +} diff --git a/server/remote.go b/server/remote.go new file mode 100644 index 0000000..2449511 --- /dev/null +++ b/server/remote.go @@ -0,0 +1,112 @@ +package server + +import ( + "log" + "net" + + "go.rctt.net/solec/core" +) + +type RemoteServer struct { + Name string + Conn net.Conn +} + +func NewRemoteServer(name string, conn net.Conn) RemoteServer { + return RemoteServer{name, conn} +} + +func (s *Server) handleServerConn(conn net.Conn) { + defer conn.Close() + + name, err := s.performServerAuth(conn) + if err != nil { + log.Println("server auth error:", err) + return + } + + s.serversMu.RLock() + if _, ok := s.servers[name]; ok { + log.Println("server already connected") + return + } + s.serversMu.RUnlock() + + rs := NewRemoteServer(name, conn) + s.serversMu.Lock() + s.servers[name] = rs + s.serversMu.Unlock() + log.Println("connection from server:", name) + + defer func() { + s.serversMu.Lock() + log.Println("server disconnected: ", rs.Name) + delete(s.servers, rs.Name) + s.serversMu.Unlock() + }() + + if err := s.readInput(conn); err != nil { + log.Println(err) + } +} + +func (s *Server) performServerAuth(conn net.Conn) (string, error) { + payload, err := core.Decode(conn) + if err != nil { + return "", err + } + auth, ok := payload.(core.ServerAuth) + if !ok { + return "", core.ErrUnexpectedPayloadType + } + + if err := core.Send(conn, core.Success{}); err != nil { + return "", err + } + + return auth.Name, nil +} + +func (s *Server) getRemote(name string) (RemoteServer, error) { + s.serversMu.RLock() + remote, ok := s.servers[name] + s.serversMu.RUnlock() + + if ok { + return remote, nil + } + + conn, err := s.initRemoteConn(name) + if err != nil { + return RemoteServer{}, err + } + + rs := NewRemoteServer(name, conn) + s.serversMu.Lock() + s.servers[name] = rs + s.serversMu.Unlock() + log.Println("connected to server:", name) + + return rs, nil +} + +func (s *Server) initRemoteConn(name string) (net.Conn, error) { + conn, err := net.Dial("tcp", name+":9999") + if err != nil { + return conn, err + } + + hs := core.Handshake{0, 1, core.ConnTypeServer} + if err := core.Send(conn, hs); err != nil { + conn.Close() + return conn, err + } + + auth := core.ServerAuth{Name: s.name} + if err := core.Send(conn, auth); err != nil { + conn.Close() + return conn, err + } + + return conn, nil +} diff --git a/server/server.go b/server/server.go index 712f654..7573968 100644 --- a/server/server.go +++ b/server/server.go @@ -13,7 +13,9 @@ type Server struct { listenAddr string name string users map[string]User - mu sync.Mutex + servers map[string]RemoteServer + usersMu sync.RWMutex + serversMu sync.RWMutex } func NewServer(listenAddr string, name string) *Server { @@ -21,6 +23,7 @@ func NewServer(listenAddr string, name string) *Server { listenAddr: listenAddr, name: name, users: make(map[string]User), + servers: make(map[string]RemoteServer), } } @@ -45,158 +48,64 @@ func (s *Server) Start() error { func (s *Server) handleConn(conn net.Conn) { defer conn.Close() - if err := s.performHandshake(conn); err != nil { - log.Println("handshake error:", err) - return - } - - user, err := s.performAuth(conn) + cType, err := s.performHandshake(conn) if err != nil { - log.Println("auth error:", err) + log.Println("handshake error:", err) return } - defer func() { - s.mu.Lock() - log.Println("client disconnected: ", user.Name) - delete(s.users[user.Name].Conns, conn) - if len(s.users[user.Name].Conns) == 0 { - log.Println("all connections closed for user:", user.Name) - delete(s.users, user.Name) - } - s.mu.Unlock() - }() - - if err := s.readInput(conn, user); err != nil { - log.Println(err) + switch cType { + case core.ConnTypeUnknown: + log.Println("invalid connection type") + case core.ConnTypeUser: + s.handleUserConn(conn) + case core.ConnTypeServer: + s.handleServerConn(conn) } } -func (s *Server) performHandshake(conn net.Conn) error { - serverHs := core.Handshake{0, 1} +func (s *Server) performHandshake(conn net.Conn) (core.ConnType, error) { + serverHs := core.Handshake{0, 2, core.ConnTypeServer} if err := core.Send(conn, serverHs); err != nil { - return err + return core.ConnTypeUnknown, err } clientPayload, err := core.Decode(conn) if err != nil { - return err + return core.ConnTypeUnknown, err } clientHs, ok := clientPayload.(core.Handshake) if !ok { - return core.ErrUnexpectedPayloadType + return core.ConnTypeUnknown, core.ErrUnexpectedPayloadType } if serverHs.Major != clientHs.Major { - return errors.New("server and client are using different protocol version") - } - - return nil -} - -func (s *Server) performAuth(conn net.Conn) (User, error) { - clientPayload, err := core.Decode(conn) - if err != nil { - return User{}, err - } - - clientAuth, ok := clientPayload.(core.Auth) - if !ok { - return User{}, core.ErrUnexpectedPayloadType - } - - // For testing --- - if clientAuth.Pass != "valid" { - if err := core.Send(conn, core.Error{core.ErrorAuthFailed}); err != nil { - log.Println("cannot send auth error:", err) - } - - return User{}, core.ErrAuthInvalidPassword - } - // --- - - if err := core.Send(conn, core.Success{}); err != nil { - return User{}, err - } - - user, ok := s.users[clientAuth.Name] - if !ok { - log.Println("initial connection from user:", clientAuth.Name) - user = NewUser(conn, clientAuth) - s.users[clientAuth.Name] = user - return user, nil + return clientHs.ConnType, errors.New("server and client are using different protocol version") } - log.Println("next connection from user:", user.Name) - user.Conns[conn] = struct{}{} - return user, nil + return clientHs.ConnType, nil } -func (s *Server) readInput(conn net.Conn, user User) error { +func (s *Server) readInput(conn net.Conn) error { for { payload, err := core.Decode(conn) if err != nil { return err } - if err := s.handlePayload(conn, user, payload); err != nil { + if err := s.handlePayload(payload); err != nil { log.Print("handler error: ", err) } } } -func (s *Server) handlePayload(conn net.Conn, user User, payload any) error { +func (s *Server) handlePayload(payload any) error { switch v := payload.(type) { case core.Message: - return s.handleMessage(conn, user, v) + return s.handleMessage(v) default: return core.ErrUnexpectedPayloadType } } - -func (s *Server) handleMessage(conn net.Conn, user User, msg core.Message) error { - log.Println("message:", user.Name, "->", msg.Target, msg.Content) - - channel, host, err := core.ReadAddr(msg.Target) - if err != nil { - return err - } - - if host == s.name { - return s.handleLocalMessage(channel, msg) - } - - return s.handleOutboundMessage(channel, host, msg) -} - -func (s *Server) handleLocalMessage(channel string, msg core.Message) error { - user, ok := s.users[channel] - if !ok { - return errors.New("target not found") - } - - return user.Send(msg) -} - -func (s *Server) handleOutboundMessage(channel, host string, msg core.Message) error { - conn, err := net.Dial("tcp", host+":9999") - if err != nil { - return err - } - defer conn.Close() - - hs := core.Handshake{0, 1} - if err := core.Send(conn, hs); err != nil { - return err - } - - // TODO, servers should not use this type of auth - auth := core.Auth{"server", "valid"} - if err := core.Send(conn, auth); err != nil { - return err - } - - return core.Send(conn, msg) -} diff --git a/server/user.go b/server/user.go index f69d126..5d2731c 100644 --- a/server/user.go +++ b/server/user.go @@ -1,6 +1,7 @@ package server import ( + "log" "net" "go.rctt.net/solec/core" @@ -11,9 +12,9 @@ type User struct { Conns map[net.Conn]struct{} } -func NewUser(conn net.Conn, auth core.Auth) User { +func NewUser(conn net.Conn, name string) User { u := User{ - Name: auth.Name, + Name: name, Conns: make(map[net.Conn]struct{}), } @@ -31,12 +32,65 @@ func (u *User) Send(payload core.Wrapper) error { return nil } -func (u *User) Auth(pass string) error { - // TODO: Implement auth +func (s *Server) handleUserConn(conn net.Conn) { + name, err := s.performUserAuth(conn) + if err != nil { + log.Println("user auth error:", err) + return + } - if pass != "valid" { - return core.ErrAuthInvalidPassword + s.usersMu.Lock() + user, ok := s.users[name] + if ok { + log.Println("next connection from user:", user.Name) + user.Conns[conn] = struct{}{} + } else { + log.Println("initial connection from user:", name) + user = NewUser(conn, name) + s.users[name] = user } + s.usersMu.Unlock() - return nil + defer func() { + s.usersMu.Lock() + log.Println("client disconnected: ", user.Name) + delete(s.users[user.Name].Conns, conn) + if len(s.users[user.Name].Conns) == 0 { + log.Println("all connections closed for user:", user.Name) + delete(s.users, user.Name) + } + s.usersMu.Unlock() + }() + + if err := s.readInput(conn); err != nil { + log.Println(err) + } +} + +func (s *Server) performUserAuth(conn net.Conn) (string, error) { + clientPayload, err := core.Decode(conn) + if err != nil { + return "", err + } + + clientAuth, ok := clientPayload.(core.UserAuth) + if !ok { + return "", core.ErrUnexpectedPayloadType + } + + // For testing --- + if clientAuth.Pass != "valid" { + if err := core.Send(conn, core.Error{core.ErrorAuthFailed}); err != nil { + log.Println("cannot send auth error:", err) + } + + return "", core.ErrAuthInvalidPassword + } + // --- + + if err := core.Send(conn, core.Success{}); err != nil { + return "", err + } + + return clientAuth.Name, nil } -- cgit v1.2.3