diff options
| -rw-r--r-- | cmd/client/main.go | 5 | ||||
| -rw-r--r-- | core/data.go | 2 | ||||
| -rw-r--r-- | core/errors.go | 12 | ||||
| -rw-r--r-- | server/message.go | 29 | ||||
| -rw-r--r-- | server/server.go | 90 | ||||
| -rw-r--r-- | server/user.go | 42 |
6 files changed, 117 insertions, 63 deletions
diff --git a/cmd/client/main.go b/cmd/client/main.go index 061efdc..c362247 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -2,6 +2,7 @@ package main import ( "flag" + "fmt" "log" "net" "strings" @@ -42,7 +43,7 @@ func main() { panic(err) } - auth := core.Auth{user, "pass"} + auth := core.Auth{user, "valid"} if err := core.Send(conn, auth); err != nil { panic(err) } @@ -67,6 +68,8 @@ func handlePayload(conn net.Conn, payload any) { switch v := payload.(type) { case core.Message: handleMessage(conn, v) + default: + fmt.Println(payload) } } diff --git a/core/data.go b/core/data.go index 9783c47..a96b56f 100644 --- a/core/data.go +++ b/core/data.go @@ -182,7 +182,7 @@ func decodeString(buf io.Reader, ptr *string) error { func ReadAddr(addr string) (string, string, error) { channel, host, ok := strings.Cut(addr, "@") if !ok { - return "", "", fmt.Errorf("invalid address") + return "", "", ErrInvalidAddress } return channel, host, nil diff --git a/core/errors.go b/core/errors.go new file mode 100644 index 0000000..a50fbf7 --- /dev/null +++ b/core/errors.go @@ -0,0 +1,12 @@ +package core + +import "errors" + +var ( + ErrUnexpectedPayloadType = errors.New("unexpected payload type") + ErrAuthInvalidUser = errors.New("invalid user") + ErrAuthInvalidPassword = errors.New("invalid password") + ErrInvalidAddress = errors.New("invalid address") + ErrNotSupported = errors.New("not supported") + ErrDisconnected = errors.New("disconnected") +) diff --git a/server/message.go b/server/message.go new file mode 100644 index 0000000..8d89797 --- /dev/null +++ b/server/message.go @@ -0,0 +1,29 @@ +package server + +import ( + "log" + "time" + + "git.rctt.net/solec/core" +) + +func (s *Server) SendBroadcast(msg string) { + payload := core.Message{ + Source: "op@example.org", + Target: "*@example.org", + Timestamp: time.Now(), + Content: msg, + } + data, err := core.Encode(payload) + if err != nil { + panic(err) + } + + for _, u := range s.users { + for c := range u.Conns { + if _, err := c.Write(data); err != nil { + log.Println("cannot send", err) + } + } + } +} diff --git a/server/server.go b/server/server.go index abc90ed..7524bba 100644 --- a/server/server.go +++ b/server/server.go @@ -2,11 +2,9 @@ package server import ( "errors" - "fmt" "log" "net" "sync" - "time" "git.rctt.net/solec/core" ) @@ -18,21 +16,6 @@ type Server struct { mu sync.Mutex } -type User struct { - Name string - Conns map[net.Conn]struct{} -} - -func (u *User) Send(payload core.Wrapper) error { - for c := range u.Conns { - if err := core.Send(c, payload); err != nil { - return err - } - } - - return nil -} - func NewServer(listenAddr string, name string) *Server { return &Server{ listenAddr: listenAddr, @@ -50,45 +33,26 @@ func (s *Server) Start() error { for { conn, err := ln.Accept() if err != nil { - log.Print("cannot accept connection: ", err) + log.Println("cannot listen:", err) + continue } - log.Print("received connection from: ", conn.RemoteAddr()) - + log.Println("client connected:", conn.RemoteAddr()) go s.handleConn(conn) } } -func (s *Server) SendBroadcast(msg string) { - payload := core.Message{ - Source: "op@example.org", - Target: "*@example.org", - Timestamp: time.Now(), - Content: msg, - } - data, err := core.Encode(payload) - if err != nil { - panic(err) - } - - for _, u := range s.users { - for c := range u.Conns { - if _, err := c.Write(data); err != nil { - log.Println("cannot send", err) - } - } - } -} - func (s *Server) handleConn(conn net.Conn) { + defer conn.Close() + if err := s.performHandshake(conn); err != nil { - log.Println("handshake error", err) + log.Println("handshake error:", err) return } user, err := s.performAuth(conn) if err != nil { - log.Println("auth error", err) + log.Println("auth error:", err) return } @@ -104,7 +68,7 @@ func (s *Server) handleConn(conn net.Conn) { }() if err := s.readInput(conn, user); err != nil { - log.Println("cannot read incomming data", err) + log.Println(err) } } @@ -122,7 +86,7 @@ func (s *Server) performHandshake(conn net.Conn) error { clientHs, ok := clientPayload.(core.Handshake) if !ok { - return errors.New("received payload of invalid type") + return core.ErrUnexpectedPayloadType } if serverHs.Major != clientHs.Major { @@ -140,13 +104,27 @@ func (s *Server) performAuth(conn net.Conn) (User, error) { clientAuth, ok := clientPayload.(core.Auth) if !ok { - return User{}, errors.New("received payload of invalid type") + return User{}, core.ErrUnexpectedPayloadType + } + + // For testing --- + if clientAuth.Pass != "valid" { + if err := core.Send(conn, core.Error{core.ErrorAuthFailed}); err != nil { + log.Println("cannot send auth error:", err) + } + + return User{}, core.ErrAuthInvalidPassword + } + // --- + + if err := core.Send(conn, core.Success{}); err != nil { + return User{}, err } user, ok := s.users[clientAuth.Name] if !ok { log.Println("initial connection from user:", clientAuth.Name) - user = newUser(conn, clientAuth) + user = NewUser(conn, clientAuth) s.users[clientAuth.Name] = user return user, nil } @@ -156,24 +134,14 @@ func (s *Server) performAuth(conn net.Conn) (User, error) { return user, nil } -func newUser(conn net.Conn, auth core.Auth) User { - u := User{ - Name: auth.Name, - Conns: make(map[net.Conn]struct{}), - } - - u.Conns[conn] = struct{}{} - return u -} - func (s *Server) readInput(conn net.Conn, user User) error { for { payload, err := core.Decode(conn) if err != nil { - return fmt.Errorf("decoder error: %w", err) + return err } if err := s.handlePayload(conn, user, payload); err != nil { - return fmt.Errorf("payload handler error: %w", err) + log.Print("handler error: ", err) } } } @@ -184,7 +152,7 @@ func (s *Server) handlePayload(conn net.Conn, user User, payload any) error { return s.handleMessage(conn, user, v) default: - return errors.New("invalid payload") + return core.ErrUnexpectedPayloadType } } @@ -200,7 +168,7 @@ func (s *Server) handleMessage(conn net.Conn, user User, msg core.Message) error return s.handleLocalMessage(channel, msg) } - return fmt.Errorf("not supported") + return core.ErrNotSupported } func (s *Server) handleLocalMessage(channel string, msg core.Message) error { diff --git a/server/user.go b/server/user.go new file mode 100644 index 0000000..204fbbe --- /dev/null +++ b/server/user.go @@ -0,0 +1,42 @@ +package server + +import ( + "net" + + "git.rctt.net/solec/core" +) + +type User struct { + Name string + Conns map[net.Conn]struct{} +} + +func NewUser(conn net.Conn, auth core.Auth) User { + u := User{ + Name: auth.Name, + Conns: make(map[net.Conn]struct{}), + } + + u.Conns[conn] = struct{}{} + return u +} + +func (u *User) Send(payload core.Wrapper) error { + for c := range u.Conns { + if err := core.Send(c, payload); err != nil { + return err + } + } + + return nil +} + +func (u *User) Auth(pass string) error { + // TODO: Implement auth + + if pass != "valid" { + return core.ErrAuthInvalidPassword + } + + return nil +} |
