1 module websocketd.server;
2 
3 import std.socket;
4 import std.experimental.logger;
5 
6 import websocketd.request;
7 import websocketd.frame;
8 
/// Identifier the server assigns to each accepted connection
/// (taken from a running counter in `WebSocketServer.add`).
alias PeerID = size_t;
10 
/**
 * Per-connection state kept by the server: the client socket, whether the
 * opening handshake has completed, and the fragments of a data message that
 * is still being reassembled.
 */
class WebSocketState {
   /// Connected client socket.
   Socket socket;
   /// Set once `performHandshake` has sent the 101 response.
   bool handshaken;
   /// Fragments (first TEXT/BINARY frame plus CONT frames) of an unfinished message.
   Frame[] frames = [];
   /// Server-assigned identifier of this peer.
   public immutable PeerID id;
   /// Remote address captured at construction time.
   public immutable Address address;
   /// Request path from the opening handshake; empty until it succeeds.
   public string path;

   @disable this();

   /**
    * Params:
    *   id     = identifier assigned by the server
    *   socket = the freshly accepted client socket
    *
    * Throws: whatever `Socket.remoteAddress` throws when the peer address
    * cannot be queried (e.g. the client already disconnected).
    */
   this(PeerID id, Socket socket) {
      this.socket = socket;
      this.handshaken = false;
      this.id = id;
      this.address = cast(immutable Address)(socket.remoteAddress);
   }

   /**
    * Parses `message` as the client's opening HTTP request and, when it is
    * complete and carries `Sec-WebSocket-Key`, replies with the
    * `101 Switching Protocols` response (RFC 6455 §4.2.2), records the
    * request path and sets `handshaken`.
    *
    * Otherwise returns without sending anything, leaving `handshaken` false.
    */
   public void performHandshake(ubyte[] message) {
      import std.base64 : Base64;
      import std.digest.sha : sha1Of;
      import std.conv : to;

      assert(!handshaken);
      // GUID fixed by RFC 6455 for deriving Sec-WebSocket-Accept.
      enum MAGIC = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
      enum KEY = "Sec-WebSocket-Key";
      Request request = Request.parse(message);
      // NOTE(review): `in` looks the header up with this exact spelling;
      // HTTP header names are case-insensitive, so this relies on the client
      // (or Request.parse) using this casing — TODO confirm.
      if (!request.done || KEY !in request.headers)
         return;
      this.path = request.path;
      // Accept value = base64(SHA-1(key ~ GUID)).
      string accept = Base64.encode(sha1Of(request.headers[KEY] ~ MAGIC)).to!string;
      assert(socket.isAlive);
      // The standard reason phrase for status 101 is "Switching Protocols".
      socket.send(
            "HTTP/1.1 101 Switching Protocols\r\n" ~ "Upgrade: websocket\r\n" ~ "Connection: Upgrade\r\n"
            ~ "Sec-WebSocket-Accept: " ~ accept ~ "\r\n\r\n");
      handshaken = true;
   }
}
48 
/**
 * Base class for a single-threaded, `select`-driven WebSocket server.
 *
 * Subclasses implement the `on*` callbacks; `run` accepts connections,
 * performs the opening handshake, decodes frames and dispatches complete
 * messages to the callbacks.
 */
abstract class WebSocketServer {

   /// Connected peers, keyed by the id assigned in `add`.
   private WebSocketState[PeerID] sockets;
   /// Listening TCP socket; created in the constructor, bound in `run`.
   private Socket listener;
   /// Connection limit, taken from the `run` template argument.
   private size_t maxConnections;

   /// Called once a peer has completed the opening handshake.
   abstract void onOpen(PeerID s, string path);
   /// Called with a complete (possibly reassembled) text message.
   abstract void onTextMessage(PeerID s, string message);
   /// Called with a complete (possibly reassembled) binary message.
   abstract void onBinaryMessage(PeerID s, ubyte[] o);
   /// Called when a peer previously announced via `onOpen` is removed.
   abstract void onClose(PeerID s);

   /// Source of peer ids; post-incremented for every registered connection.
   /// (`static` class members are thread-local in D.)
   private static PeerID counter = 0;

   this() {
      listener = new TcpSocket();
   }

   /**
    * Registers a freshly accepted socket, or closes it right away when the
    * server already holds `maxConnections` peers.
    *
    * Throws: `SocketException` if the peer address cannot be queried; the
    * socket is closed before the exception propagates.
    */
   private void add(Socket socket) {
      if (sockets.length >= maxConnections) {
         infof("Maximum number of connections reached (%d)", maxConnections);
         socket.close();
         return;
      }
      // Do not leak the descriptor if anything below throws.
      scope (failure) socket.close();
      auto s = new WebSocketState(counter++, socket);
      infof("Accepting connection from %s (id=%s)", socket.remoteAddress, s.id);
      sockets[s.id] = s;
   }

   /**
    * Unregisters `socket`, closes it and, if the peer was announced through
    * `onOpen`, fires `onClose`. Repeated calls for the same peer are no-ops.
    */
   private void remove(WebSocketState socket) {
      // AA.remove reports whether the key was present.
      if (!sockets.remove(socket.id))
         return;
      infof("Closing connection with client id %s", socket.id);
      if (socket.socket.isAlive)
         socket.socket.close();
      // Keep onOpen/onClose paired: only handshaken peers were announced.
      if (socket.handshaken)
         onClose(socket.id);
   }

   /**
    * Processes one chunk of received bytes: the opening handshake while the
    * peer is not yet handshaken, WebSocket frames afterwards.
    */
   private void handle(WebSocketState socket, ubyte[] message) {
      import std.conv : to;
      import std.algorithm : swap;

      if (!socket.handshaken) {
         socket.performHandshake(message);
         // Announce the peer only when the handshake actually succeeded.
         if (socket.handshaken) {
            infof("Handshake with %s done (path=%s)", socket.id, socket.path);
            onOpen(socket.id, socket.path);
         }
         return;
      }

      // Per-peer key under which websocketd.frame.parse keeps its state
      // (presumably buffered partial frames — TODO confirm).
      string processId = typeof(this).stringof ~ socket.id.to!string;
      Frame prevFrame = processId.parse(message);
      Frame newFrame;
      // NOTE(review): the loop stops once parse([]) yields a frame equal to
      // the previous one; presumably that means "nothing more buffered", but
      // two byte-identical consecutive frames would also end it — TODO confirm.
      do {
         handleFrame(socket, prevFrame);
         // A handler may have dropped the peer (CLOSE or protocol error);
         // later frames must not reach the callbacks.
         if (socket.id !in sockets)
            return;
         newFrame = processId.parse([]);
         swap(newFrame, prevFrame);
      }
      while (newFrame != prevFrame);
   }

   /// Dispatches a fully parsed frame by opcode; incomplete frames are ignored.
   private void handleFrame(WebSocketState socket, Frame frame) {
      tracef("From client %s received frame: done=%s; fin=%s; op=%s; length=%d", socket.id, frame.done, frame.fin, frame.op, frame.length);
      if (!frame.done)
         return;
      final switch (frame.op) {
         case Op.CONT:
            return handleCont(socket, frame);
         case Op.TEXT:
            return handleText(socket, frame);
         case Op.BINARY:
            return handleBinary(socket, frame);
         case Op.CLOSE:
            return handleClose(socket, frame);
         case Op.PING:
            return handlePing(socket, frame);
         case Op.PONG:
            return handlePong(socket, frame);
      }
   }

   /**
    * Appends a continuation fragment; on the final fragment reassembles the
    * whole message and delivers it according to the first fragment's opcode.
    */
   private void handleCont(WebSocketState socket, Frame frame) {
      // Client-controlled input: a stray continuation is a protocol error,
      // not an internal invariant, so fail the connection instead of asserting.
      if (socket.frames.length == 0) {
         failConnection(socket, "continuation frame without a preceding data frame");
         return;
      }
      if (!frame.fin) {
         socket.frames ~= frame;
         return;
      }
      immutable Op originalOp = socket.frames[0].op;
      ubyte[] data;
      foreach (ref fragment; socket.frames)
         data ~= fragment.data;
      data ~= frame.data;
      socket.frames = [];
      if (originalOp == Op.TEXT)
         onTextMessage(socket.id, cast(string)data);
      else if (originalOp == Op.BINARY)
         onBinaryMessage(socket.id, data);
   }

   /// Delivers a single-frame text message or starts a fragmented one.
   private void handleText(WebSocketState socket, Frame frame) {
      if (socket.frames.length != 0) {
         failConnection(socket, "text frame while a fragmented message is in progress");
         return;
      }
      if (frame.fin)
         onTextMessage(socket.id, cast(string)frame.data);
      else
         socket.frames ~= frame;
   }

   /// Delivers a single-frame binary message or starts a fragmented one.
   private void handleBinary(WebSocketState socket, Frame frame) {
      if (socket.frames.length != 0) {
         failConnection(socket, "binary frame while a fragmented message is in progress");
         return;
      }
      if (frame.fin)
         onBinaryMessage(socket.id, frame.data);
      else
         socket.frames ~= frame;
   }

   /// Answers a Close frame with a Close frame (echoing the status code, if
   /// any) and drops the peer (RFC 6455 §5.5.1).
   private void handleClose(WebSocketState socket, Frame frame) {
      ubyte[] status;
      if (frame.data.length >= 2)
         status = frame.data[0 .. 2];
      sendControl(socket, Op.CLOSE, status);
      remove(socket);
   }

   /// Answers a Ping with a Pong carrying the same application data
   /// (RFC 6455 §5.5.3).
   private void handlePing(WebSocketState socket, Frame frame) {
      sendControl(socket, Op.PONG, frame.data);
   }

   private void handlePong(WebSocketState socket, Frame frame) {
      tracef("Received pong from %s", socket.id);
   }

   /// Sends one final, unmasked control frame carrying `payload`.
   private void sendControl(WebSocketState socket, Op op, ubyte[] payload) {
      if (socket.socket.isAlive)
         socket.socket.send(Frame(true, op, false, payload.length, [0, 0, 0, 0], true, payload).serialize);
   }

   /// Fails the connection after a protocol violation: logs it, sends Close
   /// with status 1002 (protocol error) and drops the peer.
   private void failConnection(WebSocketState socket, string reason) {
      warningf("Protocol error from client %s: %s", socket.id, reason);
      sendControl(socket, Op.CLOSE, [0x03, 0xEA]); // 1002, big-endian
      remove(socket);
   }

   /// Sends `message` to `dest` as a single TEXT frame; logs a warning and
   /// does nothing if `dest` is not connected.
   public void sendText(PeerID dest, string message) {
      import std.string : representation;

      if (dest !in sockets) {
         warningf("Tried to send a message to %s which is not connected", dest);
         return;
      }
      sendMessage(dest, Op.TEXT, message.representation.dup);
   }

   /// Sends `message` to `dest` as a single BINARY frame; logs a warning and
   /// does nothing if `dest` is not connected.
   public void sendBinary(PeerID dest, ubyte[] message) {
      sendMessage(dest, Op.BINARY, message);
   }

   /// Shared implementation of `sendText`/`sendBinary`: one final, unmasked
   /// data frame with opcode `op`.
   private void sendMessage(PeerID dest, Op op, ubyte[] payload) {
      auto state = dest in sockets;
      if (state is null) {
         warningf("Tried to send a message to %s which is not connected", dest);
         return;
      }
      auto serial = Frame(true, op, false, payload.length, [0, 0, 0, 0], true, payload).serialize;
      tracef("Sending %d bytes to %s in one frame of %d bytes long", payload.length, dest, serial.length);
      (*state).socket.send(serial);
   }

   /**
    * Binds to `port` and serves clients forever.
    *
    * Params:
    *   port           = TCP port to listen on
    *   maxConnections = peers beyond this count are closed on accept
    *   bufferSize     = maximum bytes read from a peer per receive call
    */
   public void run(ushort port, size_t maxConnections, size_t bufferSize = 1024)() {
      this.maxConnections = maxConnections;

      listener.blocking = false;
      listener.bind(new InternetAddress(port));
      listener.listen(10);

      infof("Listening on port: %d", port);
      infof("Maximum allowed connections: %d", maxConnections);

      auto set = new SocketSet(maxConnections + 1);
      // NOTE(review): handle() gets a slice of this reused buffer; anything
      // that must outlive the call has to be copied (presumably done by
      // websocketd.frame.parse — TODO confirm).
      ubyte[bufferSize] buffer;
      while (true) {
         set.reset();
         set.add(listener);
         foreach (s; sockets.byValue)
            set.add(s.socket);
         // select returns -1 when interrupted; the set is then meaningless.
         if (Socket.select(set, null, null) <= 0)
            continue;

         // Iterate over a snapshot: remove() deletes from `sockets`, and an
         // associative array must not be modified while foreach-ing over it.
         foreach (socket; sockets.values) {
            if (socket.id !in sockets || !set.isSet(socket.socket))
               continue;
            auto receivedLength = socket.socket.receive(buffer[]);
            tracef("Received %d bytes from %s", receivedLength, socket.id);
            if (receivedLength > 0)
               handle(socket, buffer[0 .. receivedLength]);
            else
               remove(socket); // 0 = orderly shutdown, negative = error
         }

         if (set.isSet(listener)) {
            // A failed accept (e.g. peer reset before accept) must not stop
            // the server loop.
            try {
               add(listener.accept());
            } catch (SocketException e) {
               warningf("Failed to accept connection: %s", e.msg);
            }
         }
      }
   }
}