package server

import (
	"fmt"
	"log"
	"net"

	"go.rctt.net/solec/core"
)

// User is an authenticated account together with every connection that is
// currently logged in as it.
type User struct {
	// Addr is the user's address as returned by performUserAuth.
	Addr string
	// Conns is the set of live connections for this user. Because it is a
	// map, every copy of a User value shares it. handleUserConn adds and
	// removes entries while holding Server.usersMu.
	Conns map[net.Conn]struct{}
}

// NewUser returns a User for addr whose connection set initially contains
// only conn.
func NewUser(conn net.Conn, addr string) User {
	u := User{
		Addr:  addr,
		Conns: make(map[net.Conn]struct{}),
	}
	u.Conns[conn] = struct{}{}
	return u
}

// Send delivers payload with core.Send to every connection of u except
// senderConn, which is skipped. It returns the first send error it meets;
// any connections not yet visited are then left out.
//
// NOTE(review): Conns is mutated by handleUserConn under Server.usersMu, so
// callers must hold that lock (or otherwise exclude concurrent connects and
// disconnects) while calling Send. Ranging over a map during a concurrent
// write is a data race.
func (u *User) Send(senderConn net.Conn, payload core.Wrapper) error {
	for c := range u.Conns {
		if c == senderConn {
			continue
		}
		if err := core.Send(c, payload); err != nil {
			return err
		}
	}
	return nil
}

// handleUserConn serves one user connection from start to finish:
//   - it authenticates the peer;
//   - it registers conn under the user's address in s.users;
//   - it processes incoming payloads until reading fails;
//   - it unregisters conn, and drops the user entry once its last
//     connection is gone.
//
// NOTE(review): conn is never closed here, so presumably the caller owns
// it. Confirm.
func (s *Server) handleUserConn(conn net.Conn) {
	addr, err := s.performUserAuth(conn)
	if err != nil {
		log.Println("user auth error:", err)
		return
	}

	s.usersMu.Lock()
	user, ok := s.users[addr]
	if ok {
		log.Println("next connection from user:", user.Addr)
		// user is a copy, but Conns is a shared map, so this also updates
		// the entry stored in s.users.
		user.Conns[conn] = struct{}{}
	} else {
		log.Println("initial connection from user:", addr)
		user = NewUser(conn, addr)
		s.users[addr] = user
	}
	s.usersMu.Unlock()

	defer func() {
		s.usersMu.Lock()
		defer s.usersMu.Unlock()

		log.Println("client disconnected:", user.Addr)
		current, ok := s.users[user.Addr]
		if !ok {
			// Nothing registered under this address any more.
			return
		}
		delete(current.Conns, conn)
		if len(current.Conns) == 0 {
			log.Println("all connections closed for user:", user.Addr)
			delete(s.users, user.Addr)
		}
	}()

	if err := s.readUserInput(&user, conn); err != nil {
		log.Println(err)
	}
}

// performUserAuth reads a core.UserAuth payload from conn and checks its
// password against the hash stored for that address.
//
// On success it replies with core.Success and returns the authenticated
// address. An unknown user or a wrong password is answered via authFail
// before the error is returned. A payload of the wrong type is rejected
// with core.ErrUnexpectedPayloadType and no reply.
func (s *Server) performUserAuth(conn net.Conn) (string, error) {
	clientPayload, err := core.Decode(conn)
	if err != nil {
		return "", fmt.Errorf("decoding auth payload: %w", err)
	}

	clientAuth, ok := clientPayload.(core.UserAuth)
	if !ok {
		return "", core.ErrUnexpectedPayloadType
	}

	hash, err := s.Storage.GetUserPass(clientAuth.Addr)
	if err != nil {
		s.authFail(conn)
		// Keep the sentinel matchable with errors.Is, but do not discard the
		// underlying storage failure: a backend outage should not look like
		// an unknown user in the logs.
		return "", fmt.Errorf("%w: %v", core.ErrAuthInvalidUser, err)
	}

	if !core.CheckPass(clientAuth.Pass, hash) {
		s.authFail(conn)
		return "", core.ErrAuthInvalidPassword
	}

	if err := core.Send(conn, core.Success{}); err != nil {
		return "", err
	}
	return clientAuth.Addr, nil
}

// authFail sends a core.ErrorAuthFailed error to conn. A send failure is
// only logged, because the caller is already returning an auth error.
func (s *Server) authFail(conn net.Conn) {
	if err := core.Send(conn, core.Error{core.ErrorAuthFailed}); err != nil {
		log.Println("cannot send auth error:", err)
	}
}

// readUserInput decodes payloads from conn in a loop and dispatches each one
// to handleUserPayload. Handler errors are logged and do not stop the loop.
// The first decode error is returned, which presumably includes the error
// core.Decode reports when the peer disconnects.
func (s *Server) readUserInput(user *User, conn net.Conn) error {
	for {
		payload, err := core.Decode(conn)
		if err != nil {
			return err
		}
		if err := s.handleUserPayload(user, conn, payload); err != nil {
			log.Print("handler error: ", err)
		}
	}
}

// handleUserPayload routes a decoded payload from sender to the handler for
// its concrete type. Unknown types yield core.ErrUnexpectedPayloadType.
func (s *Server) handleUserPayload(user *User, sender net.Conn, payload any) error {
	switch v := payload.(type) {
	case core.Message:
		return s.handleMessage(sender, core.ConnTypeUser, user, v)
	case core.Usermode:
		return s.handleUsermode(user, sender, v)
	case core.History:
		return s.handleHistory(user, sender, v)
	default:
		return core.ErrUnexpectedPayloadType
	}
}

// handleUsermode stores the channel permission described by mode. Both
// mode.UserAddr and mode.ChannelAddr must parse as addresses. The mode
// controls read and write:
//   - core.UsermodeInChannel grants both read and write;
//   - any other mode stores the permission with read and write left false.
//
// NOTE(review): neither user nor conn is consulted, and handleUserPayload
// dispatches here directly, so nothing on this path checks that the
// requester may change this permission. Confirm whether an authorization
// check is missing.
func (s *Server) handleUsermode(user *User, conn net.Conn, mode core.Usermode) error {
	if _, err := core.ReadAddr(mode.UserAddr); err != nil {
		return fmt.Errorf("invalid user address: %w", err)
	}
	if _, err := core.ReadAddr(mode.ChannelAddr); err != nil {
		return fmt.Errorf("invalid channel address: %w", err)
	}

	perm := core.PermissionData{
		User:    mode.UserAddr,
		Channel: mode.ChannelAddr,
	}
	if mode.Mode == core.UsermodeInChannel {
		perm.Read = true
		perm.Write = true
	}
	return s.Storage.SetPermission(perm)
}

// handleHistory answers a core.History request received on conn.
//
// It requires the requesting user to have read permission on hist.Channel.
// It then loads messages from storage, passing hist.Since, hist.Count and
// hist.Offset through, and writes each encoded message to conn. When
// hist.Channel parses as a core.AddrUser address, GetHistoryUser is queried
// with the requester's address; otherwise GetHistory is used.
//
// Every lookup failure, including a missing read permission, is answered on
// conn with core.ErrorNotFound (presumably so a caller cannot probe which
// channels exist). The error returned is the result of sending that reply.
// Encoding or write errors while streaming messages are returned as-is.
//
// TODO: Better errors
func (s *Server) handleHistory(user *User, conn net.Conn, hist core.History) error {
	// Reply on the requesting connection itself. User.Send would skip conn
	// and deliver the error to the user's other sessions instead.
	replyNotFound := func() error {
		return core.Send(conn, core.Error{core.ErrorNotFound})
	}

	addr, err := core.ReadAddr(hist.Channel)
	if err != nil {
		log.Println("cannot parse address:", err)
		return replyNotFound()
	}

	perm, err := s.Storage.GetPermission(user.Addr, hist.Channel)
	if err != nil {
		log.Println("cannot get message history:", err)
		return replyNotFound()
	}
	if !perm.Read {
		log.Println("cannot get message history: not authorized")
		return replyNotFound()
	}

	count := int(hist.Count)
	offset := int(hist.Offset)

	var messages []core.Message
	if addr.Type == core.AddrUser {
		messages, err = s.Storage.GetHistoryUser(user.Addr, hist.Channel, hist.Since, count, offset)
	} else {
		messages, err = s.Storage.GetHistory(hist.Channel, hist.Since, count, offset)
	}
	if err != nil {
		log.Println("cannot get message history:", err)
		return replyNotFound()
	}

	for _, m := range messages {
		data, err := core.Encode(m)
		if err != nil {
			return err
		}
		if _, err := conn.Write(data); err != nil {
			return err
		}
	}
	return nil
}