package server

import (
	"errors"
	"log"
	"net"
	"sync"
	"time"

	"git.rctt.net/solec/core"
)

// Server accepts TCP client connections, performs the protocol handshake
// and authentication on each, and routes incoming messages to local users.
type Server struct {
	listenAddr string
	name       string
	users      map[string]User
	// mu guards users and the Conns sets of the users stored in it. Each
	// connection is served by its own goroutine (see Start), so every access
	// to users must hold mu.
	mu sync.Mutex
}

// NewServer returns a Server that will listen on listenAddr and treat
// message targets whose host equals name as local.
func NewServer(listenAddr string, name string) *Server {
	return &Server{
		listenAddr: listenAddr,
		name:       name,
		users:      make(map[string]User),
	}
}

// Start listens for TCP connections on the configured address and serves
// each accepted connection in its own goroutine. It returns an error only if
// the listener cannot be created; otherwise it runs forever.
func (s *Server) Start() error {
	ln, err := net.Listen("tcp", s.listenAddr)
	if err != nil {
		return err
	}

	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)
	var backoff time.Duration

	for {
		conn, err := ln.Accept()
		if err != nil {
			// Without a delay, a persistent Accept failure (e.g. running out
			// of file descriptors) turns this loop into a busy spin. Back off
			// exponentially, as net/http's Server.Serve does.
			if backoff == 0 {
				backoff = minAcceptBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}
			log.Println("cannot accept connection:", err)
			time.Sleep(backoff)
			continue
		}
		backoff = 0

		log.Println("client connected:", conn.RemoteAddr())
		go s.handleConn(conn)
	}
}

// handleConn runs the handshake and authentication on conn, then reads
// payloads until the connection fails. On return the connection is
// unregistered from its user and closed.
func (s *Server) handleConn(conn net.Conn) {
	defer conn.Close()

	if err := s.performHandshake(conn); err != nil {
		log.Println("handshake error:", err)
		return
	}

	user, err := s.performAuth(conn)
	if err != nil {
		log.Println("auth error:", err)
		return
	}
	// Runs before the conn.Close above (defers are LIFO).
	defer s.removeConn(user, conn)

	if err := s.readInput(conn, user); err != nil {
		log.Println(err)
	}
}

// removeConn removes conn from the user's connection set and drops the
// user's entry from s.users once no connections remain.
func (s *Server) removeConn(user User, conn net.Conn) {
	s.mu.Lock()
	defer s.mu.Unlock()

	log.Println("client disconnected: ", user.Name)
	conns := s.users[user.Name].Conns
	delete(conns, conn)
	if len(conns) == 0 {
		log.Println("all connections closed for user:", user.Name)
		delete(s.users, user.Name)
	}
}

// performHandshake sends the server's handshake, reads the client's, and
// fails if the payload is not a handshake or the Major versions differ.
func (s *Server) performHandshake(conn net.Conn) error {
	// NOTE(review): presumably {Major, Minor} = {0, 1}; confirm the field
	// order of core.Handshake.
	serverHs := core.Handshake{0, 1}
	if err := core.Send(conn, serverHs); err != nil {
		return err
	}

	clientPayload, err := core.Decode(conn)
	if err != nil {
		return err
	}

	clientHs, ok := clientPayload.(core.Handshake)
	if !ok {
		return core.ErrUnexpectedPayloadType
	}

	if serverHs.Major != clientHs.Major {
		return errors.New("server and client are using different protocol version")
	}

	return nil
}

// performAuth reads a core.Auth payload from conn, checks the password,
// replies with core.Success, and registers conn under the user's name.
// It creates the user entry on the first connection and adds conn to the
// existing entry's Conns otherwise.
func (s *Server) performAuth(conn net.Conn) (User, error) {
	clientPayload, err := core.Decode(conn)
	if err != nil {
		return User{}, err
	}

	clientAuth, ok := clientPayload.(core.Auth)
	if !ok {
		return User{}, core.ErrUnexpectedPayloadType
	}

	// For testing ---
	if clientAuth.Pass != "valid" {
		if err := core.Send(conn, core.Error{core.ErrorAuthFailed}); err != nil {
			log.Println("cannot send auth error:", err)
		}
		return User{}, core.ErrAuthInvalidPassword
	}
	// ---

	if err := core.Send(conn, core.Success{}); err != nil {
		return User{}, err
	}

	// Other connection goroutines read and modify s.users concurrently.
	s.mu.Lock()
	defer s.mu.Unlock()

	user, ok := s.users[clientAuth.Name]
	if !ok {
		log.Println("initial connection from user:", clientAuth.Name)
		user = NewUser(conn, clientAuth)
		s.users[clientAuth.Name] = user
		return user, nil
	}

	log.Println("next connection from user:", user.Name)
	user.Conns[conn] = struct{}{}
	return user, nil
}

// readInput decodes payloads from conn in a loop and dispatches each one.
// Handler errors are logged and do not end the loop; a decode error ends it
// and is returned.
func (s *Server) readInput(conn net.Conn, user User) error {
	for {
		payload, err := core.Decode(conn)
		if err != nil {
			return err
		}

		if err := s.handlePayload(conn, user, payload); err != nil {
			log.Print("handler error: ", err)
		}
	}
}

// handlePayload dispatches a decoded payload by its concrete type. Only
// core.Message is handled; any other type yields
// core.ErrUnexpectedPayloadType.
func (s *Server) handlePayload(conn net.Conn, user User, payload any) error {
	switch v := payload.(type) {
	case core.Message:
		return s.handleMessage(conn, user, v)
	default:
		return core.ErrUnexpectedPayloadType
	}
}

// handleMessage parses msg.Target into a channel and host. It delivers the
// message locally when the host matches this server's name and returns
// core.ErrNotSupported otherwise.
func (s *Server) handleMessage(conn net.Conn, user User, msg core.Message) error {
	log.Println("message:", user.Name, "->", msg.Target, msg.Content)

	channel, host, err := core.ReadAddr(msg.Target)
	if err != nil {
		return err
	}

	if host == s.name {
		return s.handleLocalMessage(channel, msg)
	}

	return core.ErrNotSupported
}

// handleLocalMessage looks up the user named channel and passes msg to its
// Send method. It returns an error if no such user is connected.
func (s *Server) handleLocalMessage(channel string, msg core.Message) error {
	// Hold mu for the lookup and the send: s.users is mutated by other
	// connection goroutines. NOTE(review): presumably User.Send writes to
	// every conn in user.Conns (confirm). Keeping the lock held prevents that
	// set from changing mid-send and serializes writers. The cost is that a
	// target that stops reading can stall other callers.
	s.mu.Lock()
	defer s.mu.Unlock()

	user, ok := s.users[channel]
	if !ok {
		return errors.New("target not found")
	}

	return user.Send(msg)
}