diff options
| author | bt <bt@rctt.net> | 2026-03-14 23:11:15 +0100 |
|---|---|---|
| committer | bt <bt@rctt.net> | 2026-03-18 16:21:58 +0100 |
| commit | 54ddec67c477a6fd73b0f623258c0849ba695b88 (patch) | |
| tree | 3a34b67e62b8788f0611abc4f9f8cfe7954aae46 /server/server.go | |
| parent | 8932846aa4d29d59fd208f40bbfd44d1bb9cf1ff (diff) | |
| download | solec-54ddec67c477a6fd73b0f623258c0849ba695b88.tar.gz solec-54ddec67c477a6fd73b0f623258c0849ba695b88.zip | |
Basic server implementation
Diffstat (limited to 'server/server.go')
| -rw-r--r-- | server/server.go | 188 |
1 files changed, 188 insertions, 0 deletions
diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..51e71b1 --- /dev/null +++ b/server/server.go @@ -0,0 +1,188 @@ +package server + +import ( + "errors" + "fmt" + "log" + "net" + "sync" + + "git.sr.ht/~rctt/solec/core" +) + +type Server struct { + listenAddr string + users map[string]User + mu sync.Mutex +} + +type User struct { + Name string + Conns map[net.Conn]struct{} +} + +func NewServer(listenAddr string) *Server { + return &Server{ + listenAddr: listenAddr, + users: make(map[string]User), + } +} + +func (s *Server) Start() error { + ln, err := net.Listen("tcp", s.listenAddr) + if err != nil { + return err + } + + for { + conn, err := ln.Accept() + if err != nil { + log.Print("cannot accept connection: ", err) + } + + log.Print("received connection from: ", conn.RemoteAddr()) + + go s.handleConn(conn) + } +} + +func (s *Server) handleConn(conn net.Conn) { + if err := s.performHandshake(conn); err != nil { + log.Println("handshake error", err) + return + } + + user, err := s.performAuth(conn) + if err != nil { + log.Println("auth error", err) + return + } + + defer func() { + s.mu.Lock() + log.Println("client disconnected: ", user.Name) + delete(s.users[user.Name].Conns, conn) + if len(s.users[user.Name].Conns) == 0 { + log.Println("all connections closed for user:", user.Name) + delete(s.users, user.Name) + } + s.mu.Unlock() + }() + + if err := s.readInput(conn, user); err != nil { + log.Println("cannot read incomming data", err) + } +} + +func (s *Server) performHandshake(conn net.Conn) error { + serverHs := core.Handshake{0, 1} + + if err := core.Send(conn, serverHs); err != nil { + return err + } + + clientPayload, err := core.Decode(conn) + if err != nil { + return err + } + + clientHs, ok := clientPayload.(core.Handshake) + if !ok { + return errors.New("received payload of invalid type") + } + + if serverHs.Major != clientHs.Major { + return errors.New("server and client are using different protocol version") + } + + return nil +} + +func (s *Server) performAuth(conn net.Conn) (User, error) { + clientPayload, err := core.Decode(conn) + if err != nil { + return User{}, err + } + + clientAuth, ok := clientPayload.(core.Auth) + if !ok { + return User{}, errors.New("received payload of invalid type") + } + + user, ok := s.users[clientAuth.Name] + if !ok { + log.Println("initial connection from user:", clientAuth.Name) + user = newUser(conn, clientAuth) + s.users[clientAuth.Name] = user + return user, nil + } + + log.Println("next connection from user:", user.Name) + user.Conns[conn] = struct{}{} + return user, nil +} + +func newUser(conn net.Conn, auth core.Auth) User { + u := User{ + Name: auth.Name, + Conns: make(map[net.Conn]struct{}), + } + + u.Conns[conn] = struct{}{} + return u +} + +func (s *Server) readInput(conn net.Conn, user User) error { + for { + payload, err := core.Decode(conn) + if err != nil { + return fmt.Errorf("decoder error: %w", err) + } + if err := s.handlePayload(conn, user, payload); err != nil { + return fmt.Errorf("payload handler error: %w", err) + } + } +} + +func (s *Server) handlePayload(conn net.Conn, user User, payload any) error { + switch v := payload.(type) { + case core.Message: + return s.handleMessage(conn, user, v) + + default: + return errors.New("invalid payload") + } +} + +func (s *Server) handleMessage(conn net.Conn, user User, msg core.Message) error { + log.Println("message:", user.Name, "->", msg.Target, msg.Content) + + if msg.Target == "*" { + return s.broadcastMessage(conn, user, msg) + } + + return nil +} + +func (s *Server) broadcastMessage(conn net.Conn, user User, msg core.Message) error { + for _, u := range s.users { + if u.Name == user.Name { + continue + } + + for c := range u.Conns { + if err := core.Send(c, msg); err != nil { + log.Println("cannot send", err) + } + } + } + + // Forward message for other connections from the same user + for c := range s.users[user.Name].Conns { + if err := core.Send(c, msg); err != nil { + log.Println("cannot send", err) + } + } + + return nil +} |
