package server

import (
	"fmt"
	"log"
	"net"

	"go.rctt.net/solec/core"
)

// User is an authenticated account together with the set of network
// connections currently open for it and the channels it is associated with.
//
// NOTE(review): handleUserConn stores User by value in s.users and works on
// per-connection copies. Every copy shares the same Conns and Channels maps
// (Go maps are reference types). Conns is mutated under s.usersMu in
// handleUserConn but iterated without it in Send; Channels is read without a
// lock in handleHistory. Confirm how these accesses are meant to be
// serialized.
type User struct {
	// Name is the account name the user authenticated with.
	Name string
	// Conns is the set of open connections belonging to this user.
	Conns map[net.Conn]struct{}
	// Channels maps channel names to channels. It is only read in this file;
	// presumably Channel.Add/Remove maintain it (verify).
	Channels map[string]*Channel
}

// NewUser returns a User called name whose connection set contains only conn
// and whose channel map is empty.
func NewUser(conn net.Conn, name string) User {
	u := User{
		Name:     name,
		Conns:    make(map[net.Conn]struct{}),
		Channels: make(map[string]*Channel),
	}
	u.Conns[conn] = struct{}{}
	return u
}

// Send writes payload to every connection of u except senderConn. It stops at
// the first write error and returns it, so connections later in the (random)
// map iteration order are then skipped.
//
// Because senderConn is skipped, Send cannot be used to reply to the
// connection that made a request; write to that connection with core.Send.
func (u *User) Send(senderConn net.Conn, payload core.Wrapper) error {
	for c := range u.Conns {
		if c == senderConn {
			continue
		}
		if err := core.Send(c, payload); err != nil {
			return err
		}
	}
	return nil
}

// handleUserConn serves one user connection for its whole lifetime:
//  1. it authenticates the connection (returning early on failure);
//  2. it registers conn under the user's name, creating the User on the
//     user's first connection;
//  3. it processes incoming payloads until reading from conn fails;
//  4. it unregisters conn, removing the User entirely once its last
//     connection is gone.
//
// NOTE(review): conn is not closed in this function; presumably the caller
// closes it. Confirm.
func (s *Server) handleUserConn(conn net.Conn) {
	name, err := s.performUserAuth(conn)
	if err != nil {
		log.Println("user auth error:", err)
		return
	}

	s.usersMu.Lock()
	user, ok := s.users[name]
	if ok {
		log.Println("next connection from user:", user.Name)
		// user is a copy of the stored value, but its Conns map is shared
		// with the stored entry, so this registers conn for all copies.
		user.Conns[conn] = struct{}{}
	} else {
		log.Println("initial connection from user:", name)
		user = NewUser(conn, name)
		s.users[name] = user
	}
	s.usersMu.Unlock()

	defer func() {
		s.usersMu.Lock()
		defer s.usersMu.Unlock()
		log.Println("client disconnected: ", user.Name)
		delete(s.users[user.Name].Conns, conn)
		if len(s.users[user.Name].Conns) == 0 {
			log.Println("all connections closed for user:", user.Name)
			delete(s.users, user.Name)
		}
	}()

	// user is this goroutine's own copy; handlers receive a pointer to it.
	if err := s.readUserInput(&user, conn); err != nil {
		log.Println(err)
	}
}

// performUserAuth reads the first payload from conn, which must be a
// core.UserAuth, and checks its password against the stored hash.
//
// On success it sends core.Success to conn and returns the authenticated
// name. It returns:
//   - core.ErrUnexpectedPayloadType if the first payload is not a UserAuth;
//   - core.ErrAuthInvalidUser if the stored password cannot be fetched;
//   - core.ErrAuthInvalidPassword if the password does not match.
//
// In the last two cases an auth-failed error is sent to the client first.
//
// NOTE(review): the error from Storage.GetUserPass is discarded, so a storage
// failure is indistinguishable from an unknown user.
func (s *Server) performUserAuth(conn net.Conn) (string, error) {
	clientPayload, err := core.Decode(conn)
	if err != nil {
		return "", err
	}

	clientAuth, ok := clientPayload.(core.UserAuth)
	if !ok {
		return "", core.ErrUnexpectedPayloadType
	}

	hash, err := s.Storage.GetUserPass(clientAuth.Name)
	if err != nil {
		s.authFail(conn)
		return "", core.ErrAuthInvalidUser
	}

	if !core.CheckPass(clientAuth.Pass, hash) {
		s.authFail(conn)
		return "", core.ErrAuthInvalidPassword
	}

	if err := core.Send(conn, core.Success{}); err != nil {
		return "", err
	}

	return clientAuth.Name, nil
}

// authFail tells the client on conn that authentication failed. This is
// best-effort: a send failure is logged, not returned.
func (s *Server) authFail(conn net.Conn) {
	if err := core.Send(conn, core.Error{core.ErrorAuthFailed}); err != nil {
		log.Println("cannot send auth error:", err)
	}
}

// readUserInput decodes payloads from conn and dispatches each one to
// handleUserPayload. Handler errors are logged and do not end the loop. The
// function returns only when decoding fails, and returns that error.
func (s *Server) readUserInput(user *User, conn net.Conn) error {
	for {
		payload, err := core.Decode(conn)
		if err != nil {
			return err
		}

		if err := s.handleUserPayload(user, conn, payload); err != nil {
			log.Print("handler error: ", err)
		}
	}
}

// handleUserPayload routes a decoded payload from sender to the handler for
// its type. Unknown payload types yield core.ErrUnexpectedPayloadType.
func (s *Server) handleUserPayload(user *User, sender net.Conn, payload any) error {
	switch v := payload.(type) {
	case core.Message:
		return s.handleMessage(sender, core.ConnTypeUser, v)
	case core.Usermode:
		return s.handleUsermode(user, sender, v)
	case core.History:
		return s.handleHistory(user, sender, v)
	default:
		return core.ErrUnexpectedPayloadType
	}
}

// handleUsermode adds user to, or removes user from, the channel named in
// mode, according to mode.Mode. It rejects the request with an Unauthorized
// error if the target user address does not name the requesting user. It
// replies NotFound if the channel does not exist. Error replies go to conn,
// the connection that made the request.
//
// Address-parse failures are returned to the caller without a reply. Mode
// values other than UsermodeNone and UsermodeInChannel are ignored.
func (s *Server) handleUsermode(user *User, conn net.Conn, mode core.Usermode) error {
	userAddr, err := core.ReadAddr(mode.UserAddr)
	if err != nil {
		return fmt.Errorf("parsing user address %q: %w", mode.UserAddr, err)
	}

	chanAddr, err := core.ReadAddr(mode.ChannelName)
	if err != nil {
		return fmt.Errorf("parsing channel address %q: %w", mode.ChannelName, err)
	}

	// The target user address must refer to the requesting user itself.
	if user.Name != userAddr.Channel {
		log.Println("unauthorized")
		return core.Send(conn, core.Error{core.ErrorUnauthorized})
	}

	// Release the read lock before any return; previously the not-found path
	// returned while still holding it, blocking every later writer.
	s.channelsMu.RLock()
	channel, ok := s.channels[chanAddr.Channel]
	s.channelsMu.RUnlock()
	if !ok {
		log.Println("not found", chanAddr.Channel)
		return core.Send(conn, core.Error{core.ErrorNotFound})
	}

	switch mode.Mode {
	case core.UsermodeNone:
		channel.Remove(user)
	case core.UsermodeInChannel:
		channel.Add(user)
	}

	return nil
}

// handleHistory streams stored messages for hist.Channel back to conn. The
// selection is described by hist.Since, hist.Count and hist.Offset, which are
// passed through to Storage.GetHistory.
//
// If the address cannot be parsed, the channel is not in user.Channels, or
// the storage lookup fails, it replies to conn with a NotFound error. An
// encode or write failure while streaming is returned to the caller.
//
// TODO: Better errors.
func (s *Server) handleHistory(user *User, conn net.Conn, hist core.History) error {
	addr, err := core.ReadAddr(hist.Channel)
	if err != nil {
		log.Println("cannot parse address:", err)
		return core.Send(conn, core.Error{core.ErrorNotFound})
	}

	if _, ok := user.Channels[addr.Channel]; !ok {
		log.Println("cannot get message history: not authorized")
		return core.Send(conn, core.Error{core.ErrorNotFound})
	}

	messages, err := s.Storage.GetHistory(hist.Channel, hist.Since, int(hist.Count), int(hist.Offset))
	if err != nil {
		log.Println("cannot get message history:", err)
		return core.Send(conn, core.Error{core.ErrorNotFound})
	}

	for _, m := range messages {
		data, err := core.Encode(m)
		if err != nil {
			return err
		}

		if _, err := conn.Write(data); err != nil {
			return err
		}
	}

	return nil
}