diff options
Diffstat (limited to 'server/server.go')
| -rw-r--r-- | server/server.go | 139 |
1 files changed, 24 insertions, 115 deletions
diff --git a/server/server.go b/server/server.go index 712f654..7573968 100644 --- a/server/server.go +++ b/server/server.go @@ -13,7 +13,9 @@ type Server struct { listenAddr string name string users map[string]User - mu sync.Mutex + servers map[string]RemoteServer + usersMu sync.RWMutex + serversMu sync.RWMutex } func NewServer(listenAddr string, name string) *Server { @@ -21,6 +23,7 @@ func NewServer(listenAddr string, name string) *Server { listenAddr: listenAddr, name: name, users: make(map[string]User), + servers: make(map[string]RemoteServer), } } @@ -45,158 +48,64 @@ func (s *Server) Start() error { func (s *Server) handleConn(conn net.Conn) { defer conn.Close() - if err := s.performHandshake(conn); err != nil { - log.Println("handshake error:", err) - return - } - - user, err := s.performAuth(conn) + cType, err := s.performHandshake(conn) if err != nil { - log.Println("auth error:", err) + log.Println("handshake error:", err) return } - defer func() { - s.mu.Lock() - log.Println("client disconnected: ", user.Name) - delete(s.users[user.Name].Conns, conn) - if len(s.users[user.Name].Conns) == 0 { - log.Println("all connections closed for user:", user.Name) - delete(s.users, user.Name) - } - s.mu.Unlock() - }() - - if err := s.readInput(conn, user); err != nil { - log.Println(err) + switch cType { + case core.ConnTypeUnknown: + log.Println("invalid connection type") + case core.ConnTypeUser: + s.handleUserConn(conn) + case core.ConnTypeServer: + s.handleServerConn(conn) } } -func (s *Server) performHandshake(conn net.Conn) error { - serverHs := core.Handshake{0, 1} +func (s *Server) performHandshake(conn net.Conn) (core.ConnType, error) { + serverHs := core.Handshake{0, 2, core.ConnTypeServer} if err := core.Send(conn, serverHs); err != nil { - return err + return core.ConnTypeUnknown, err } clientPayload, err := core.Decode(conn) if err != nil { - return err + return core.ConnTypeUnknown, err } clientHs, ok := clientPayload.(core.Handshake) if !ok { - return core.ErrUnexpectedPayloadType + return core.ConnTypeUnknown, core.ErrUnexpectedPayloadType } if serverHs.Major != clientHs.Major { - return errors.New("server and client are using different protocol version") - } - - return nil -} - -func (s *Server) performAuth(conn net.Conn) (User, error) { - clientPayload, err := core.Decode(conn) - if err != nil { - return User{}, err - } - - clientAuth, ok := clientPayload.(core.Auth) - if !ok { - return User{}, core.ErrUnexpectedPayloadType - } - - // For testing --- - if clientAuth.Pass != "valid" { - if err := core.Send(conn, core.Error{core.ErrorAuthFailed}); err != nil { - log.Println("cannot send auth error:", err) - } - - return User{}, core.ErrAuthInvalidPassword - } - // --- - - if err := core.Send(conn, core.Success{}); err != nil { - return User{}, err - } - - user, ok := s.users[clientAuth.Name] - if !ok { - log.Println("initial connection from user:", clientAuth.Name) - user = NewUser(conn, clientAuth) - s.users[clientAuth.Name] = user - return user, nil + return clientHs.ConnType, errors.New("server and client are using different protocol version") } - log.Println("next connection from user:", user.Name) - user.Conns[conn] = struct{}{} - return user, nil + return clientHs.ConnType, nil } -func (s *Server) readInput(conn net.Conn, user User) error { +func (s *Server) readInput(conn net.Conn) error { for { payload, err := core.Decode(conn) if err != nil { return err } - if err := s.handlePayload(conn, user, payload); err != nil { + if err := s.handlePayload(payload); err != nil { log.Print("handler error: ", err) } } } -func (s *Server) handlePayload(conn net.Conn, user User, payload any) error { +func (s *Server) handlePayload(payload any) error { switch v := payload.(type) { case core.Message: - return s.handleMessage(conn, user, v) + return s.handleMessage(v) default: return core.ErrUnexpectedPayloadType } } - -func (s *Server) handleMessage(conn net.Conn, user User, msg core.Message) error { - log.Println("message:", user.Name, "->", msg.Target, msg.Content) - - channel, host, err := core.ReadAddr(msg.Target) - if err != nil { - return err - } - - if host == s.name { - return s.handleLocalMessage(channel, msg) - } - - return s.handleOutboundMessage(channel, host, msg) -} - -func (s *Server) handleLocalMessage(channel string, msg core.Message) error { - user, ok := s.users[channel] - if !ok { - return errors.New("target not found") - } - - return user.Send(msg) -} - -func (s *Server) handleOutboundMessage(channel, host string, msg core.Message) error { - conn, err := net.Dial("tcp", host+":9999") - if err != nil { - return err - } - defer conn.Close() - - hs := core.Handshake{0, 1} - if err := core.Send(conn, hs); err != nil { - return err - } - - // TODO, servers should not use this type of auth - auth := core.Auth{"server", "valid"} - if err := core.Send(conn, auth); err != nil { - return err - } - - return core.Send(conn, msg) -} |
