summaryrefslogtreecommitdiffstats
path: root/server
diff options
context:
space:
mode:
Diffstat (limited to 'server')
-rw-r--r--server/channel.go48
-rw-r--r--server/message.go40
-rw-r--r--server/remote.go23
-rw-r--r--server/server.go34
-rw-r--r--server/user.go69
5 files changed, 169 insertions, 45 deletions
diff --git a/server/channel.go b/server/channel.go
new file mode 100644
index 0000000..11969f2
--- /dev/null
+++ b/server/channel.go
@@ -0,0 +1,48 @@
+package server
+
+import (
+ "log"
+ "net"
+ "sync"
+
+ "go.rctt.net/solec/core"
+)
+
+type Channel struct {
+ Name string
+ Users map[string]*User
+ UsersMu sync.RWMutex
+}
+
+func NewChannel(name string) *Channel {
+ return &Channel{
+ Name: name,
+ Users: make(map[string]*User),
+ }
+}
+
+func (c *Channel) Add(u *User) {
+ c.UsersMu.Lock()
+ c.Users[u.Name] = u
+ u.Channels[c.Name] = c
+ c.UsersMu.Unlock()
+
+ log.Println("user joined a channel")
+}
+
+func (c *Channel) Remove(u *User) {
+ c.UsersMu.Lock()
+ delete(c.Users, u.Name)
+ delete(u.Channels, c.Name)
+ c.UsersMu.Unlock()
+
+ log.Println("user left a channel")
+}
+
+func (c *Channel) Send(senderConn net.Conn, msg core.Message) {
+ for _, u := range c.Users {
+ if err := u.Send(senderConn, msg); err != nil {
+ log.Print("cannot send a message to user on channel", err)
+ }
+ }
+}
diff --git a/server/message.go b/server/message.go
index 2487abd..58f2908 100644
--- a/server/message.go
+++ b/server/message.go
@@ -1,9 +1,9 @@
package server
import (
- "errors"
"fmt"
"log"
+ "net"
"time"
"go.rctt.net/solec/core"
@@ -30,34 +30,44 @@ func (s *Server) SendBroadcast(msg string) {
}
}
-func (s *Server) handleMessage(msg core.Message) error {
+func (s *Server) handleMessage(sender net.Conn, msg core.Message) error {
log.Println("message:", msg.Source, "->", msg.Target, msg.Content)
- channel, host, err := core.ReadAddr(msg.Target)
+ addr, err := core.ReadAddr(msg.Target)
if err != nil {
return err
}
- if host == s.name {
- return s.handleLocalMessage(channel, msg)
+ if addr.Host == s.name {
+ return s.handleLocalMessage(sender, addr, msg)
}
- return s.handleOutboundMessage(channel, host, msg)
+ return s.handleOutboundMessage(sender, addr, msg)
}
-func (s *Server) handleLocalMessage(channel string, msg core.Message) error {
- s.usersMu.RLock()
- user, ok := s.users[channel]
- if !ok {
- return errors.New("target not found")
+func (s *Server) handleLocalMessage(sender net.Conn, addr core.Addr, msg core.Message) error {
+ if addr.Type == core.AddrUser {
+ s.usersMu.RLock()
+ user, ok := s.users[addr.Channel]
+ if !ok {
+ return core.Send(sender, core.Error{core.ErrorNotFound})
+ }
+ s.usersMu.RUnlock()
+ return user.Send(sender, msg)
}
- s.usersMu.RUnlock()
- return user.Send(msg)
+ s.channelsMu.RLock()
+ channel, ok := s.channels[addr.Channel]
+ if !ok {
+ return core.Send(sender, core.Error{core.ErrorNotFound})
+ }
+ s.channelsMu.RUnlock()
+ channel.Send(sender, msg)
+ return nil
}
-func (s *Server) handleOutboundMessage(channel, host string, msg core.Message) error {
- remote, err := s.getRemote(host)
+func (s *Server) handleOutboundMessage(sender net.Conn, addr core.Addr, msg core.Message) error {
+ remote, err := s.getRemote(addr.Host)
if err != nil {
return fmt.Errorf("cannot access remote server: %w", err)
}
diff --git a/server/remote.go b/server/remote.go
index 2449511..5d86da2 100644
--- a/server/remote.go
+++ b/server/remote.go
@@ -45,7 +45,7 @@ func (s *Server) handleServerConn(conn net.Conn) {
s.serversMu.Unlock()
}()
- if err := s.readInput(conn); err != nil {
+ if err := s.readRemoteInput(conn); err != nil {
log.Println(err)
}
}
@@ -110,3 +110,24 @@ func (s *Server) initRemoteConn(name string) (net.Conn, error) {
return conn, nil
}
+
+func (s *Server) readRemoteInput(conn net.Conn) error {
+ for {
+ payload, err := core.Decode(conn)
+ if err != nil {
+ return err
+ }
+ if err := s.handleRemotePayload(conn, payload); err != nil {
+ log.Print("handler error: ", err)
+ }
+ }
+}
+
+func (s *Server) handleRemotePayload(sender net.Conn, payload any) error {
+ switch v := payload.(type) {
+ case core.Message:
+ return s.handleMessage(sender, v)
+ default:
+ return core.ErrUnexpectedPayloadType
+ }
+}
diff --git a/server/server.go b/server/server.go
index 7573968..ef2bb5f 100644
--- a/server/server.go
+++ b/server/server.go
@@ -12,10 +12,12 @@ import (
type Server struct {
listenAddr string
name string
- users map[string]User
+ users map[string]User // TODO: Use full address instead of just name
servers map[string]RemoteServer
+ channels map[string]*Channel
usersMu sync.RWMutex
serversMu sync.RWMutex
+ channelsMu sync.RWMutex
}
func NewServer(listenAddr string, name string) *Server {
@@ -24,6 +26,7 @@ func NewServer(listenAddr string, name string) *Server {
name: name,
users: make(map[string]User),
servers: make(map[string]RemoteServer),
+ channels: make(map[string]*Channel),
}
}
@@ -45,6 +48,13 @@ func (s *Server) Start() error {
}
}
+func (s *Server) AddChannel(name string) {
+ s.channelsMu.Lock()
+ s.channelsMu.Unlock()
+ s.channels[name] = NewChannel(name)
+ log.Println("created channel", name)
+}
+
func (s *Server) handleConn(conn net.Conn) {
defer conn.Close()
@@ -87,25 +97,3 @@ func (s *Server) performHandshake(conn net.Conn) (core.ConnType, error) {
return clientHs.ConnType, nil
}
-
-func (s *Server) readInput(conn net.Conn) error {
- for {
- payload, err := core.Decode(conn)
- if err != nil {
- return err
- }
- if err := s.handlePayload(payload); err != nil {
- log.Print("handler error: ", err)
- }
- }
-}
-
-func (s *Server) handlePayload(payload any) error {
- switch v := payload.(type) {
- case core.Message:
- return s.handleMessage(v)
-
- default:
- return core.ErrUnexpectedPayloadType
- }
-}
diff --git a/server/user.go b/server/user.go
index 5d2731c..4f78b6a 100644
--- a/server/user.go
+++ b/server/user.go
@@ -8,22 +8,27 @@ import (
)
type User struct {
- Name string
- Conns map[net.Conn]struct{}
+ Name string
+ Conns map[net.Conn]struct{}
+ Channels map[string]*Channel
}
func NewUser(conn net.Conn, name string) User {
u := User{
- Name: name,
- Conns: make(map[net.Conn]struct{}),
+ Name: name,
+ Conns: make(map[net.Conn]struct{}),
+ Channels: make(map[string]*Channel),
}
u.Conns[conn] = struct{}{}
return u
}
-func (u *User) Send(payload core.Wrapper) error {
+func (u *User) Send(senderConn net.Conn, payload core.Wrapper) error {
for c := range u.Conns {
+ if c == senderConn {
+ continue
+ }
if err := core.Send(c, payload); err != nil {
return err
}
@@ -62,7 +67,7 @@ func (s *Server) handleUserConn(conn net.Conn) {
s.usersMu.Unlock()
}()
- if err := s.readInput(conn); err != nil {
+ if err := s.readUserInput(&user, conn); err != nil {
log.Println(err)
}
}
@@ -94,3 +99,55 @@ func (s *Server) performUserAuth(conn net.Conn) (string, error) {
return clientAuth.Name, nil
}
+
+func (s *Server) readUserInput(user *User, conn net.Conn) error {
+ for {
+ payload, err := core.Decode(conn)
+ if err != nil {
+ return err
+ }
+ if err := s.handleUserPayload(user, conn, payload); err != nil {
+ log.Print("handler error: ", err)
+ }
+ }
+}
+
+func (s *Server) handleUserPayload(user *User, sender net.Conn, payload any) error {
+ switch v := payload.(type) {
+ case core.Message:
+ return s.handleMessage(sender, v)
+ case core.Usermode:
+ return s.handleUsermode(user, sender, v)
+ default:
+ return core.ErrUnexpectedPayloadType
+ }
+}
+
+func (s *Server) handleUsermode(user *User, conn net.Conn, mode core.Usermode) error {
+ addr, err := core.ReadAddr(mode.UserAddr)
+ if err != nil {
+ return err
+ }
+
+ if user.Name != addr.Channel {
+ log.Println("unauthorized")
+ return user.Send(conn, core.Error{core.ErrorUnauthorized})
+ }
+
+ s.channelsMu.RLock()
+ channel, ok := s.channels[mode.ChannelName]
+ if !ok {
+ log.Println("not found", addr.Channel)
+ return user.Send(conn, core.Error{core.ErrorNotFound})
+ }
+ s.channelsMu.RUnlock()
+
+ switch mode.Mode {
+ case core.UsermodeNone:
+ channel.Remove(user)
+ case core.UsermodeInChannel:
+ channel.Add(user)
+ }
+
+ return nil
+}