summaryrefslogtreecommitdiffstats
path: root/server/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'server/server.go')
-rw-r--r--server/server.go90
1 files changed, 29 insertions, 61 deletions
diff --git a/server/server.go b/server/server.go
index abc90ed..7524bba 100644
--- a/server/server.go
+++ b/server/server.go
@@ -2,11 +2,9 @@ package server
import (
"errors"
- "fmt"
"log"
"net"
"sync"
- "time"
"git.rctt.net/solec/core"
)
@@ -18,21 +16,6 @@ type Server struct {
mu sync.Mutex
}
-type User struct {
- Name string
- Conns map[net.Conn]struct{}
-}
-
-func (u *User) Send(payload core.Wrapper) error {
- for c := range u.Conns {
- if err := core.Send(c, payload); err != nil {
- return err
- }
- }
-
- return nil
-}
-
func NewServer(listenAddr string, name string) *Server {
return &Server{
listenAddr: listenAddr,
@@ -50,45 +33,26 @@ func (s *Server) Start() error {
for {
conn, err := ln.Accept()
if err != nil {
- log.Print("cannot accept connection: ", err)
+ log.Println("cannot listen:", err)
+ continue
}
- log.Print("received connection from: ", conn.RemoteAddr())
-
+ log.Println("client connected:", conn.RemoteAddr())
go s.handleConn(conn)
}
}
-func (s *Server) SendBroadcast(msg string) {
- payload := core.Message{
- Source: "op@example.org",
- Target: "*@example.org",
- Timestamp: time.Now(),
- Content: msg,
- }
- data, err := core.Encode(payload)
- if err != nil {
- panic(err)
- }
-
- for _, u := range s.users {
- for c := range u.Conns {
- if _, err := c.Write(data); err != nil {
- log.Println("cannot send", err)
- }
- }
- }
-}
-
func (s *Server) handleConn(conn net.Conn) {
+ defer conn.Close()
+
if err := s.performHandshake(conn); err != nil {
- log.Println("handshake error", err)
+ log.Println("handshake error:", err)
return
}
user, err := s.performAuth(conn)
if err != nil {
- log.Println("auth error", err)
+ log.Println("auth error:", err)
return
}
@@ -104,7 +68,7 @@ func (s *Server) handleConn(conn net.Conn) {
}()
if err := s.readInput(conn, user); err != nil {
- log.Println("cannot read incomming data", err)
+ log.Println(err)
}
}
@@ -122,7 +86,7 @@ func (s *Server) performHandshake(conn net.Conn) error {
clientHs, ok := clientPayload.(core.Handshake)
if !ok {
- return errors.New("received payload of invalid type")
+ return core.ErrUnexpectedPayloadType
}
if serverHs.Major != clientHs.Major {
@@ -140,13 +104,27 @@ func (s *Server) performAuth(conn net.Conn) (User, error) {
clientAuth, ok := clientPayload.(core.Auth)
if !ok {
- return User{}, errors.New("received payload of invalid type")
+ return User{}, core.ErrUnexpectedPayloadType
+ }
+
+ // For testing ---
+ if clientAuth.Pass != "valid" {
+ if err := core.Send(conn, core.Error{core.ErrorAuthFailed}); err != nil {
+ log.Println("cannot send auth error:", err)
+ }
+
+ return User{}, core.ErrAuthInvalidPassword
+ }
+ // ---
+
+ if err := core.Send(conn, core.Success{}); err != nil {
+ return User{}, err
}
user, ok := s.users[clientAuth.Name]
if !ok {
log.Println("initial connection from user:", clientAuth.Name)
- user = newUser(conn, clientAuth)
+ user = NewUser(conn, clientAuth)
s.users[clientAuth.Name] = user
return user, nil
}
@@ -156,24 +134,14 @@ func (s *Server) performAuth(conn net.Conn) (User, error) {
return user, nil
}
-func newUser(conn net.Conn, auth core.Auth) User {
- u := User{
- Name: auth.Name,
- Conns: make(map[net.Conn]struct{}),
- }
-
- u.Conns[conn] = struct{}{}
- return u
-}
-
func (s *Server) readInput(conn net.Conn, user User) error {
for {
payload, err := core.Decode(conn)
if err != nil {
- return fmt.Errorf("decoder error: %w", err)
+ return err
}
if err := s.handlePayload(conn, user, payload); err != nil {
- return fmt.Errorf("payload handler error: %w", err)
+ log.Print("handler error: ", err)
}
}
}
@@ -184,7 +152,7 @@ func (s *Server) handlePayload(conn net.Conn, user User, payload any) error {
return s.handleMessage(conn, user, v)
default:
- return errors.New("invalid payload")
+ return core.ErrUnexpectedPayloadType
}
}
@@ -200,7 +168,7 @@ func (s *Server) handleMessage(conn net.Conn, user User, msg core.Message) error
return s.handleLocalMessage(channel, msg)
}
- return fmt.Errorf("not supported")
+ return core.ErrNotSupported
}
func (s *Server) handleLocalMessage(channel string, msg core.Message) error {