From eec10d41af62fb9a93cd5fd79dcf94616701cc2a Mon Sep 17 00:00:00 2001 From: bt Date: Sun, 19 Apr 2026 21:32:53 +0200 Subject: [common] Basic group channels support --- server/channel.go | 48 ++++++++++++++++++++++++++++++++++++++ server/message.go | 40 ++++++++++++++++++++------------ server/remote.go | 23 ++++++++++++++++++- server/server.go | 34 +++++++++------------------ server/user.go | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++----- 5 files changed, 169 insertions(+), 45 deletions(-) create mode 100644 server/channel.go (limited to 'server') diff --git a/server/channel.go b/server/channel.go new file mode 100644 index 0000000..11969f2 --- /dev/null +++ b/server/channel.go @@ -0,0 +1,48 @@ +package server + +import ( + "log" + "net" + "sync" + + "go.rctt.net/solec/core" +) + +type Channel struct { + Name string + Users map[string]*User + UsersMu sync.RWMutex +} + +func NewChannel(name string) *Channel { + return &Channel{ + Name: name, + Users: make(map[string]*User), + } +} + +func (c *Channel) Add(u *User) { + c.UsersMu.Lock() + c.Users[u.Name] = u + u.Channels[c.Name] = c + c.UsersMu.Unlock() + + log.Println("user joined a channel") +} + +func (c *Channel) Remove(u *User) { + c.UsersMu.Lock() + delete(c.Users, u.Name) + delete(u.Channels, c.Name) + c.UsersMu.Unlock() + + log.Println("user left a channel") +} + +func (c *Channel) Send(senderConn net.Conn, msg core.Message) { + for _, u := range c.Users { + if err := u.Send(senderConn, msg); err != nil { + log.Print("cannot send a message to user on channel", err) + } + } +} diff --git a/server/message.go b/server/message.go index 2487abd..58f2908 100644 --- a/server/message.go +++ b/server/message.go @@ -1,9 +1,9 @@ package server import ( - "errors" "fmt" "log" + "net" "time" "go.rctt.net/solec/core" @@ -30,34 +30,44 @@ func (s *Server) SendBroadcast(msg string) { } } -func (s *Server) handleMessage(msg core.Message) error { +func (s *Server) handleMessage(sender net.Conn, msg core.Message) error { log.Println("message:", msg.Source, "->", msg.Target, msg.Content) - channel, host, err := core.ReadAddr(msg.Target) + addr, err := core.ReadAddr(msg.Target) if err != nil { return err } - if host == s.name { - return s.handleLocalMessage(channel, msg) + if addr.Host == s.name { + return s.handleLocalMessage(sender, addr, msg) } - return s.handleOutboundMessage(channel, host, msg) + return s.handleOutboundMessage(sender, addr, msg) } -func (s *Server) handleLocalMessage(channel string, msg core.Message) error { - s.usersMu.RLock() - user, ok := s.users[channel] - if !ok { - return errors.New("target not found") +func (s *Server) handleLocalMessage(sender net.Conn, addr core.Addr, msg core.Message) error { + if addr.Type == core.AddrUser { + s.usersMu.RLock() + user, ok := s.users[addr.Channel] + if !ok { + return core.Send(sender, core.Error{core.ErrorNotFound}) + } + s.usersMu.RUnlock() + return user.Send(sender, msg) } - s.usersMu.RUnlock() - return user.Send(msg) + s.channelsMu.RLock() + channel, ok := s.channels[addr.Channel] + if !ok { + return core.Send(sender, core.Error{core.ErrorNotFound}) + } + s.channelsMu.RUnlock() + channel.Send(sender, msg) + return nil } -func (s *Server) handleOutboundMessage(channel, host string, msg core.Message) error { - remote, err := s.getRemote(host) +func (s *Server) handleOutboundMessage(sender net.Conn, addr core.Addr, msg core.Message) error { + remote, err := s.getRemote(addr.Host) if err != nil { return fmt.Errorf("cannot access remote server: %w", err) } diff --git a/server/remote.go b/server/remote.go index 2449511..5d86da2 100644 --- a/server/remote.go +++ b/server/remote.go @@ -45,7 +45,7 @@ func (s *Server) handleServerConn(conn net.Conn) { s.serversMu.Unlock() }() - if err := s.readInput(conn); err != nil { + if err := s.readRemoteInput(conn); err != nil { log.Println(err) } } @@ -110,3 +110,24 @@ func (s *Server) initRemoteConn(name string) (net.Conn, error) { return conn, nil } + +func (s *Server) readRemoteInput(conn net.Conn) error { + for { + payload, err := core.Decode(conn) + if err != nil { + return err + } + if err := s.handleRemotePayload(conn, payload); err != nil { + log.Print("handler error: ", err) + } + } +} + +func (s *Server) handleRemotePayload(sender net.Conn, payload any) error { + switch v := payload.(type) { + case core.Message: + return s.handleMessage(sender, v) + default: + return core.ErrUnexpectedPayloadType + } +} diff --git a/server/server.go b/server/server.go index 7573968..ef2bb5f 100644 --- a/server/server.go +++ b/server/server.go @@ -12,10 +12,12 @@ import ( type Server struct { listenAddr string name string - users map[string]User + users map[string]User // TODO: Use full address instead of just name servers map[string]RemoteServer + channels map[string]*Channel usersMu sync.RWMutex serversMu sync.RWMutex + channelsMu sync.RWMutex } func NewServer(listenAddr string, name string) *Server { @@ -24,6 +26,7 @@ func NewServer(listenAddr string, name string) *Server { name: name, users: make(map[string]User), servers: make(map[string]RemoteServer), + channels: make(map[string]*Channel), } } @@ -45,6 +48,13 @@ func (s *Server) Start() error { } } +func (s *Server) AddChannel(name string) { + s.channelsMu.Lock() + s.channelsMu.Unlock() + s.channels[name] = NewChannel(name) + log.Println("created channel", name) +} + func (s *Server) handleConn(conn net.Conn) { defer conn.Close() @@ -87,25 +97,3 @@ func (s *Server) performHandshake(conn net.Conn) (core.ConnType, error) { return clientHs.ConnType, nil } - -func (s *Server) readInput(conn net.Conn) error { - for { - payload, err := core.Decode(conn) - if err != nil { - return err - } - if err := s.handlePayload(payload); err != nil { - log.Print("handler error: ", err) - } - } -} - -func (s *Server) handlePayload(payload any) error { - switch v := payload.(type) { - case core.Message: - return s.handleMessage(v) - - default: - return core.ErrUnexpectedPayloadType - } -} diff --git a/server/user.go b/server/user.go index 5d2731c..4f78b6a 100644 --- a/server/user.go +++ b/server/user.go @@ -8,22 +8,27 @@ import ( ) type User struct { - Name string - Conns map[net.Conn]struct{} + Name string + Conns map[net.Conn]struct{} + Channels map[string]*Channel } func NewUser(conn net.Conn, name string) User { u := User{ - Name: name, - Conns: make(map[net.Conn]struct{}), + Name: name, + Conns: make(map[net.Conn]struct{}), + Channels: make(map[string]*Channel), } u.Conns[conn] = struct{}{} return u } -func (u *User) Send(payload core.Wrapper) error { +func (u *User) Send(senderConn net.Conn, payload core.Wrapper) error { for c := range u.Conns { + if c == senderConn { + continue + } if err := core.Send(c, payload); err != nil { return err } @@ -62,7 +67,7 @@ func (s *Server) handleUserConn(conn net.Conn) { s.usersMu.Unlock() }() - if err := s.readInput(conn); err != nil { + if err := s.readUserInput(&user, conn); err != nil { log.Println(err) } } @@ -94,3 +99,55 @@ func (s *Server) performUserAuth(conn net.Conn) (string, error) { return clientAuth.Name, nil } + +func (s *Server) readUserInput(user *User, conn net.Conn) error { + for { + payload, err := core.Decode(conn) + if err != nil { + return err + } + if err := s.handleUserPayload(user, conn, payload); err != nil { + log.Print("handler error: ", err) + } + } +} + +func (s *Server) handleUserPayload(user *User, sender net.Conn, payload any) error { + switch v := payload.(type) { + case core.Message: + return s.handleMessage(sender, v) + case core.Usermode: + return s.handleUsermode(user, sender, v) + default: + return core.ErrUnexpectedPayloadType + } +} + +func (s *Server) handleUsermode(user *User, conn net.Conn, mode core.Usermode) error { + addr, err := core.ReadAddr(mode.UserAddr) + if err != nil { + return err + } + + if user.Name != addr.Channel { + log.Println("unauthorized") + return user.Send(conn, core.Error{core.ErrorUnauthorized}) + } + + s.channelsMu.RLock() + channel, ok := s.channels[mode.ChannelName] + if !ok { + log.Println("not found", addr.Channel) + return user.Send(conn, core.Error{core.ErrorNotFound}) + } + s.channelsMu.RUnlock() + + switch mode.Mode { + case core.UsermodeNone: + channel.Remove(user) + case core.UsermodeInChannel: + channel.Add(user) + } + + return nil +} -- cgit v1.2.3