summaryrefslogtreecommitdiffstats
path: root/server
diff options
context:
space:
mode:
authorbt <bt@rctt.net>2026-04-10 19:31:31 +0200
committerbt <bt@rctt.net>2026-04-18 22:33:20 +0200
commite9aebac1a2a4732763c2f7e4428a23983d4eb6a3 (patch)
treeb66ea36939ea75360ed6c554cf352348d19786bb /server
parentf66e28aa88a5f4176934001fa9e4967ddccde4a9 (diff)
downloadsolec-e9aebac1a2a4732763c2f7e4428a23983d4eb6a3.tar.gz
solec-e9aebac1a2a4732763c2f7e4428a23983d4eb6a3.zip
[common] Exchange messages between servers
Diffstat (limited to 'server')
-rw-r--r--server/message.go37
-rw-r--r--server/remote.go112
-rw-r--r--server/server.go139
-rw-r--r--server/user.go68
4 files changed, 234 insertions, 122 deletions
diff --git a/server/message.go b/server/message.go
index c1384f7..2487abd 100644
--- a/server/message.go
+++ b/server/message.go
@@ -1,6 +1,8 @@
package server
import (
+ "errors"
+ "fmt"
"log"
"time"
@@ -27,3 +29,38 @@ func (s *Server) SendBroadcast(msg string) {
}
}
}
+
+func (s *Server) handleMessage(msg core.Message) error {
+ log.Println("message:", msg.Source, "->", msg.Target, msg.Content)
+
+ channel, host, err := core.ReadAddr(msg.Target)
+ if err != nil {
+ return err
+ }
+
+ if host == s.name {
+ return s.handleLocalMessage(channel, msg)
+ }
+
+ return s.handleOutboundMessage(channel, host, msg)
+}
+
+func (s *Server) handleLocalMessage(channel string, msg core.Message) error {
+ s.usersMu.RLock()
+ user, ok := s.users[channel]
+ if !ok {
+ return errors.New("target not found")
+ }
+ s.usersMu.RUnlock()
+
+ return user.Send(msg)
+}
+
+func (s *Server) handleOutboundMessage(channel, host string, msg core.Message) error {
+ remote, err := s.getRemote(host)
+ if err != nil {
+ return fmt.Errorf("cannot access remote server: %w", err)
+ }
+
+ return core.Send(remote.Conn, msg)
+}
diff --git a/server/remote.go b/server/remote.go
new file mode 100644
index 0000000..2449511
--- /dev/null
+++ b/server/remote.go
@@ -0,0 +1,112 @@
+package server
+
+import (
+ "log"
+ "net"
+
+ "go.rctt.net/solec/core"
+)
+
+type RemoteServer struct {
+ Name string
+ Conn net.Conn
+}
+
+func NewRemoteServer(name string, conn net.Conn) RemoteServer {
+ return RemoteServer{name, conn}
+}
+
+func (s *Server) handleServerConn(conn net.Conn) {
+ defer conn.Close()
+
+ name, err := s.performServerAuth(conn)
+ if err != nil {
+ log.Println("server auth error:", err)
+ return
+ }
+
+ s.serversMu.RLock()
+ if _, ok := s.servers[name]; ok {
+ log.Println("server already connected")
+ return
+ }
+ s.serversMu.RUnlock()
+
+ rs := NewRemoteServer(name, conn)
+ s.serversMu.Lock()
+ s.servers[name] = rs
+ s.serversMu.Unlock()
+ log.Println("connection from server:", name)
+
+ defer func() {
+ s.serversMu.Lock()
+ log.Println("server disconnected: ", rs.Name)
+ delete(s.servers, rs.Name)
+ s.serversMu.Unlock()
+ }()
+
+ if err := s.readInput(conn); err != nil {
+ log.Println(err)
+ }
+}
+
+func (s *Server) performServerAuth(conn net.Conn) (string, error) {
+ payload, err := core.Decode(conn)
+ if err != nil {
+ return "", err
+ }
+ auth, ok := payload.(core.ServerAuth)
+ if !ok {
+ return "", core.ErrUnexpectedPayloadType
+ }
+
+ if err := core.Send(conn, core.Success{}); err != nil {
+ return "", err
+ }
+
+ return auth.Name, nil
+}
+
+func (s *Server) getRemote(name string) (RemoteServer, error) {
+ s.serversMu.RLock()
+ remote, ok := s.servers[name]
+ s.serversMu.RUnlock()
+
+ if ok {
+ return remote, nil
+ }
+
+ conn, err := s.initRemoteConn(name)
+ if err != nil {
+ return RemoteServer{}, err
+ }
+
+ rs := NewRemoteServer(name, conn)
+ s.serversMu.Lock()
+ s.servers[name] = rs
+ s.serversMu.Unlock()
+ log.Println("connected to server:", name)
+
+ return rs, nil
+}
+
+func (s *Server) initRemoteConn(name string) (net.Conn, error) {
+ conn, err := net.Dial("tcp", name+":9999")
+ if err != nil {
+ return conn, err
+ }
+
+ hs := core.Handshake{0, 1, core.ConnTypeServer}
+ if err := core.Send(conn, hs); err != nil {
+ conn.Close()
+ return conn, err
+ }
+
+ auth := core.ServerAuth{Name: s.name}
+ if err := core.Send(conn, auth); err != nil {
+ conn.Close()
+ return conn, err
+ }
+
+ return conn, nil
+}
diff --git a/server/server.go b/server/server.go
index 712f654..7573968 100644
--- a/server/server.go
+++ b/server/server.go
@@ -13,7 +13,9 @@ type Server struct {
listenAddr string
name string
users map[string]User
- mu sync.Mutex
+ servers map[string]RemoteServer
+ usersMu sync.RWMutex
+ serversMu sync.RWMutex
}
func NewServer(listenAddr string, name string) *Server {
@@ -21,6 +23,7 @@ func NewServer(listenAddr string, name string) *Server {
listenAddr: listenAddr,
name: name,
users: make(map[string]User),
+ servers: make(map[string]RemoteServer),
}
}
@@ -45,158 +48,64 @@ func (s *Server) Start() error {
func (s *Server) handleConn(conn net.Conn) {
defer conn.Close()
- if err := s.performHandshake(conn); err != nil {
- log.Println("handshake error:", err)
- return
- }
-
- user, err := s.performAuth(conn)
+ cType, err := s.performHandshake(conn)
if err != nil {
- log.Println("auth error:", err)
+ log.Println("handshake error:", err)
return
}
- defer func() {
- s.mu.Lock()
- log.Println("client disconnected: ", user.Name)
- delete(s.users[user.Name].Conns, conn)
- if len(s.users[user.Name].Conns) == 0 {
- log.Println("all connections closed for user:", user.Name)
- delete(s.users, user.Name)
- }
- s.mu.Unlock()
- }()
-
- if err := s.readInput(conn, user); err != nil {
- log.Println(err)
+ switch cType {
+ case core.ConnTypeUnknown:
+ log.Println("invalid connection type")
+ case core.ConnTypeUser:
+ s.handleUserConn(conn)
+ case core.ConnTypeServer:
+ s.handleServerConn(conn)
}
}
-func (s *Server) performHandshake(conn net.Conn) error {
- serverHs := core.Handshake{0, 1}
+func (s *Server) performHandshake(conn net.Conn) (core.ConnType, error) {
+ serverHs := core.Handshake{0, 2, core.ConnTypeServer}
if err := core.Send(conn, serverHs); err != nil {
- return err
+ return core.ConnTypeUnknown, err
}
clientPayload, err := core.Decode(conn)
if err != nil {
- return err
+ return core.ConnTypeUnknown, err
}
clientHs, ok := clientPayload.(core.Handshake)
if !ok {
- return core.ErrUnexpectedPayloadType
+ return core.ConnTypeUnknown, core.ErrUnexpectedPayloadType
}
if serverHs.Major != clientHs.Major {
- return errors.New("server and client are using different protocol version")
- }
-
- return nil
-}
-
-func (s *Server) performAuth(conn net.Conn) (User, error) {
- clientPayload, err := core.Decode(conn)
- if err != nil {
- return User{}, err
- }
-
- clientAuth, ok := clientPayload.(core.Auth)
- if !ok {
- return User{}, core.ErrUnexpectedPayloadType
- }
-
- // For testing ---
- if clientAuth.Pass != "valid" {
- if err := core.Send(conn, core.Error{core.ErrorAuthFailed}); err != nil {
- log.Println("cannot send auth error:", err)
- }
-
- return User{}, core.ErrAuthInvalidPassword
- }
- // ---
-
- if err := core.Send(conn, core.Success{}); err != nil {
- return User{}, err
- }
-
- user, ok := s.users[clientAuth.Name]
- if !ok {
- log.Println("initial connection from user:", clientAuth.Name)
- user = NewUser(conn, clientAuth)
- s.users[clientAuth.Name] = user
- return user, nil
+ return clientHs.ConnType, errors.New("server and client are using different protocol version")
}
- log.Println("next connection from user:", user.Name)
- user.Conns[conn] = struct{}{}
- return user, nil
+ return clientHs.ConnType, nil
}
-func (s *Server) readInput(conn net.Conn, user User) error {
+func (s *Server) readInput(conn net.Conn) error {
for {
payload, err := core.Decode(conn)
if err != nil {
return err
}
- if err := s.handlePayload(conn, user, payload); err != nil {
+ if err := s.handlePayload(payload); err != nil {
log.Print("handler error: ", err)
}
}
}
-func (s *Server) handlePayload(conn net.Conn, user User, payload any) error {
+func (s *Server) handlePayload(payload any) error {
switch v := payload.(type) {
case core.Message:
- return s.handleMessage(conn, user, v)
+ return s.handleMessage(v)
default:
return core.ErrUnexpectedPayloadType
}
}
-
-func (s *Server) handleMessage(conn net.Conn, user User, msg core.Message) error {
- log.Println("message:", user.Name, "->", msg.Target, msg.Content)
-
- channel, host, err := core.ReadAddr(msg.Target)
- if err != nil {
- return err
- }
-
- if host == s.name {
- return s.handleLocalMessage(channel, msg)
- }
-
- return s.handleOutboundMessage(channel, host, msg)
-}
-
-func (s *Server) handleLocalMessage(channel string, msg core.Message) error {
- user, ok := s.users[channel]
- if !ok {
- return errors.New("target not found")
- }
-
- return user.Send(msg)
-}
-
-func (s *Server) handleOutboundMessage(channel, host string, msg core.Message) error {
- conn, err := net.Dial("tcp", host+":9999")
- if err != nil {
- return err
- }
- defer conn.Close()
-
- hs := core.Handshake{0, 1}
- if err := core.Send(conn, hs); err != nil {
- return err
- }
-
- // TODO, servers should not use this type of auth
- auth := core.Auth{"server", "valid"}
- if err := core.Send(conn, auth); err != nil {
- return err
- }
-
- return core.Send(conn, msg)
-}
diff --git a/server/user.go b/server/user.go
index f69d126..5d2731c 100644
--- a/server/user.go
+++ b/server/user.go
@@ -1,6 +1,7 @@
package server
import (
+ "log"
"net"
"go.rctt.net/solec/core"
@@ -11,9 +12,9 @@ type User struct {
Conns map[net.Conn]struct{}
}
-func NewUser(conn net.Conn, auth core.Auth) User {
+func NewUser(conn net.Conn, name string) User {
u := User{
- Name: auth.Name,
+ Name: name,
Conns: make(map[net.Conn]struct{}),
}
@@ -31,12 +32,65 @@ func (u *User) Send(payload core.Wrapper) error {
return nil
}
-func (u *User) Auth(pass string) error {
- // TODO: Implement auth
+func (s *Server) handleUserConn(conn net.Conn) {
+ name, err := s.performUserAuth(conn)
+ if err != nil {
+ log.Println("user auth error:", err)
+ return
+ }
- if pass != "valid" {
- return core.ErrAuthInvalidPassword
+ s.usersMu.Lock()
+ user, ok := s.users[name]
+ if ok {
+ log.Println("next connection from user:", user.Name)
+ user.Conns[conn] = struct{}{}
+ } else {
+ log.Println("initial connection from user:", name)
+ user = NewUser(conn, name)
+ s.users[name] = user
}
+ s.usersMu.Unlock()
- return nil
+ defer func() {
+ s.usersMu.Lock()
+ log.Println("client disconnected: ", user.Name)
+ delete(s.users[user.Name].Conns, conn)
+ if len(s.users[user.Name].Conns) == 0 {
+ log.Println("all connections closed for user:", user.Name)
+ delete(s.users, user.Name)
+ }
+ s.usersMu.Unlock()
+ }()
+
+ if err := s.readInput(conn); err != nil {
+ log.Println(err)
+ }
+}
+
+func (s *Server) performUserAuth(conn net.Conn) (string, error) {
+ clientPayload, err := core.Decode(conn)
+ if err != nil {
+ return "", err
+ }
+
+ clientAuth, ok := clientPayload.(core.UserAuth)
+ if !ok {
+ return "", core.ErrUnexpectedPayloadType
+ }
+
+ // For testing ---
+ if clientAuth.Pass != "valid" {
+ if err := core.Send(conn, core.Error{core.ErrorAuthFailed}); err != nil {
+ log.Println("cannot send auth error:", err)
+ }
+
+ return "", core.ErrAuthInvalidPassword
+ }
+ // ---
+
+ if err := core.Send(conn, core.Success{}); err != nil {
+ return "", err
+ }
+
+ return clientAuth.Name, nil
}