// Package server implements the solec server: it accepts TCP connections,
// performs a protocol-version handshake and a name-based auth step, and
// relays broadcast messages between connected clients.
package server

import (
	"errors"
	"fmt"
	"log"
	"net"
	"sync"

	"git.sr.ht/~rctt/solec/core"
)

// Server accepts client connections on listenAddr and tracks every connected
// user together with all of that user's open connections.
type Server struct {
	listenAddr string
	// users maps a user name to that user's state. It is guarded by mu,
	// as are the Conns maps of the stored User values.
	users map[string]User
	// mu guards users; each connection is served by its own goroutine, so
	// every access to users (and to any User.Conns) must hold it.
	mu sync.Mutex
}

// User is a named client that may be connected through several connections
// at once.
type User struct {
	Name string
	// Conns is the set of open connections for this user. The map is shared
	// with the copy stored in Server.users, so it must only be accessed with
	// Server.mu held.
	Conns map[net.Conn]struct{}
}

// NewServer returns a Server that will listen on listenAddr once Start is
// called.
func NewServer(listenAddr string) *Server {
	return &Server{
		listenAddr: listenAddr,
		users:      make(map[string]User),
	}
}

// Start listens for TCP connections on the server's address and serves each
// accepted connection in its own goroutine. It blocks and only returns on
// error: the error from net.Listen, or the Accept error if the listener has
// been closed.
func (s *Server) Start() error {
	ln, err := net.Listen("tcp", s.listenAddr)
	if err != nil {
		return err
	}
	defer ln.Close()

	for {
		conn, err := ln.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return err
			}
			// conn is nil on error; skip it instead of dereferencing it.
			log.Print("cannot accept connection: ", err)
			continue
		}
		log.Print("received connection from: ", conn.RemoteAddr())
		go s.handleConn(conn)
	}
}

// handleConn serves a single client connection: handshake, auth, then the
// read loop. On return the connection is unregistered from its user (if auth
// succeeded) and closed.
func (s *Server) handleConn(conn net.Conn) {
	// Release the socket on every exit path, including failed handshakes,
	// failed auths and ordinary client disconnects.
	defer conn.Close()

	if err := s.performHandshake(conn); err != nil {
		log.Println("handshake error", err)
		return
	}
	user, err := s.performAuth(conn)
	if err != nil {
		log.Println("auth error", err)
		return
	}
	// Deferred after Close, so it runs first: broadcasts stop targeting conn
	// before the socket is closed.
	defer s.removeConn(user.Name, conn)

	if err := s.readInput(conn, user); err != nil {
		log.Println("cannot read incomming data", err)
	}
}

// removeConn unregisters conn from the named user and forgets the user
// entirely once its last connection is gone.
func (s *Server) removeConn(name string, conn net.Conn) {
	s.mu.Lock()
	defer s.mu.Unlock()

	log.Println("client disconnected: ", name)
	u, ok := s.users[name]
	if !ok {
		return
	}
	delete(u.Conns, conn)
	if len(u.Conns) == 0 {
		log.Println("all connections closed for user:", name)
		delete(s.users, name)
	}
}

// performHandshake sends the server's handshake (version 0, 1) and then
// expects a core.Handshake from the client. It fails if the client sends a
// different payload type or a different Major version.
func (s *Server) performHandshake(conn net.Conn) error {
	serverHs := core.Handshake{0, 1}
	if err := core.Send(conn, serverHs); err != nil {
		return err
	}
	clientPayload, err := core.Decode(conn)
	if err != nil {
		return err
	}
	clientHs, ok := clientPayload.(core.Handshake)
	if !ok {
		return errors.New("received payload of invalid type")
	}
	if serverHs.Major != clientHs.Major {
		return errors.New("server and client are using different protocol version")
	}
	return nil
}

// performAuth reads a core.Auth payload from conn and registers conn under
// the user name it carries. It returns the (possibly pre-existing) User.
func (s *Server) performAuth(conn net.Conn) (User, error) {
	clientPayload, err := core.Decode(conn)
	if err != nil {
		return User{}, err
	}
	clientAuth, ok := clientPayload.(core.Auth)
	if !ok {
		return User{}, errors.New("received payload of invalid type")
	}
	return s.addConn(conn, clientAuth), nil
}

// addConn registers conn under the user named in auth, creating the user on
// its first connection, and returns the stored User. It holds s.mu because
// users and User.Conns are shared between connection goroutines.
func (s *Server) addConn(conn net.Conn, auth core.Auth) User {
	s.mu.Lock()
	defer s.mu.Unlock()

	user, ok := s.users[auth.Name]
	if !ok {
		log.Println("initial connection from user:", auth.Name)
		user = newUser(conn, auth)
		s.users[auth.Name] = user
		return user
	}
	log.Println("next connection from user:", user.Name)
	user.Conns[conn] = struct{}{}
	return user
}

// newUser builds a User named after auth whose connection set contains only
// conn.
func newUser(conn net.Conn, auth core.Auth) User {
	u := User{
		Name:  auth.Name,
		Conns: make(map[net.Conn]struct{}),
	}
	u.Conns[conn] = struct{}{}
	return u
}

// readInput decodes payloads from conn and dispatches them until decoding or
// handling fails. It never returns nil; the returned error is wrapped with
// the stage that failed.
func (s *Server) readInput(conn net.Conn, user User) error {
	for {
		payload, err := core.Decode(conn)
		if err != nil {
			return fmt.Errorf("decoder error: %w", err)
		}
		if err := s.handlePayload(conn, user, payload); err != nil {
			return fmt.Errorf("payload handler error: %w", err)
		}
	}
}

// handlePayload dispatches a decoded payload. Only core.Message is accepted
// after auth; any other type is an error, which ends the read loop.
func (s *Server) handlePayload(conn net.Conn, user User, payload any) error {
	switch v := payload.(type) {
	case core.Message:
		return s.handleMessage(conn, user, v)
	default:
		return errors.New("invalid payload")
	}
}

// handleMessage logs msg and broadcasts it when its Target is "*". Messages
// with any other target are currently accepted and dropped.
func (s *Server) handleMessage(conn net.Conn, user User, msg core.Message) error {
	log.Println("message:", user.Name, "->", msg.Target, msg.Content)
	if msg.Target == "*" {
		return s.broadcastMessage(conn, user, msg)
	}
	return nil
}

// broadcastMessage sends msg to every open connection of every connected
// user, other users first, then all of the sender's own connections. conn is
// not excluded, so the sending client receives an echo. Send failures are
// logged and do not abort the broadcast; the function always returns nil.
//
// The recipient list is snapshotted under s.mu and the sends happen after
// the lock is released, so a client that is slow to read cannot stall other
// goroutines waiting on s.mu.
// NOTE(review): sends from concurrent broadcasts to the same conn are not
// serialised here; confirm core.Send writes each frame atomically.
func (s *Server) broadcastMessage(conn net.Conn, user User, msg core.Message) error {
	s.mu.Lock()
	var targets []net.Conn
	for _, u := range s.users {
		if u.Name == user.Name {
			continue
		}
		for c := range u.Conns {
			targets = append(targets, c)
		}
	}
	// The sender's own connections, including conn itself.
	for c := range s.users[user.Name].Conns {
		targets = append(targets, c)
	}
	s.mu.Unlock()

	for _, c := range targets {
		if err := core.Send(c, msg); err != nil {
			log.Println("cannot send", err)
		}
	}
	return nil
}