1 module websocketd.frame;
2 import websocketd.checkpoint;
3 
/// WebSocket frame opcodes; the value occupies the low four bits of a frame's first byte.
enum Op : ubyte {
   /// Continuation frame.
   CONT = 0x0,
   /// Text frame.
   TEXT = 0x1,
   /// Binary frame.
   BINARY = 0x2,
   /// Close control frame.
   CLOSE = 0x8,
   /// Ping control frame.
   PING = 0x9,
   /// Pong control frame.
   PONG = 0xA
}
12 
/// A single WebSocket frame: header fields plus the (unmasked) payload collected so far.
struct Frame {
   bool fin;            /// FIN flag (top bit of the first header byte)
   Op op;               /// opcode (low four bits of the first header byte)
   bool masked;         /// whether the payload is XOR-masked on the wire
   ulong length;        /// payload length announced in the header
   ubyte[4] mask;       /// masking key; only written to the wire when `masked` is set
   bool done = false;   /// set by the parser once the whole frame has been read
   ubyte[] data;        /// payload bytes, stored unmasked

   /// Number of payload bytes still missing, i.e. `length` minus what `data` already holds.
   ulong remaining() @property {
      return length - data.length;
   }

   /**
    * Encodes this frame into its wire representation.
    *
    * The header announces `length` (not `data.length`); the payload written is `data`,
    * XOR-ed with `mask` when `masked` is set.
    *
    * Returns: a freshly built byte array containing header, optional mask and payload.
    */
   ubyte[] serialize() {
      ubyte[] wire;

      // First byte: FIN flag in the top bit, opcode in the low bits.
      const ubyte finBit = fin ? 0x80 : 0;
      wire ~= cast(ubyte)(finBit ^ cast(ubyte) op);

      // Second byte: MASK flag in the top bit, then either the length itself
      // or one of the markers 126 / 127 announcing an extended length.
      const ubyte maskBit = masked ? 0x80 : 0;
      if (length < 126) {
         wire ~= cast(ubyte)(maskBit ^ cast(ubyte) length);
      } else if (length <= ushort.max) {
         wire ~= cast(ubyte)(maskBit ^ 126);
         // 16-bit length, most significant byte first
         wire ~= cast(ubyte)(length >> 8);
         wire ~= cast(ubyte) length;
      } else {
         wire ~= cast(ubyte)(maskBit ^ 127);
         // 64-bit length, most significant byte first
         foreach_reverse (byteIndex; 0 .. 8)
            wire ~= cast(ubyte)(length >> (byteIndex * 8));
      }

      if (masked) {
         wire ~= mask[];
         foreach (idx, octet; data)
            wire ~= cast(ubyte)(octet ^ mask[idx % 4]);
      } else {
         wire ~= data;
      }

      return wire;
   }
}
60 
/**
 * Consumes bytes from the front of `data`, shrinking the slice in place.
 *
 * With the default `n == 1` a single `ubyte` is removed and returned.
 * For any other `n` the first `m` bytes are removed and returned as a slice
 * of the original array (`m` defaults to `n`).
 *
 * In both cases `data` must hold at least `m` bytes (checked with `assert`).
 */
auto next(size_t n = 1)(ref ubyte[] data, size_t m = n) {
   assert(data.length >= m);
   static if (n != 1) {
      // multi-byte form: hand back a slice and advance past it
      ubyte[] taken = data[0 .. m];
      data = data[m .. $];
      return taken;
   } else {
      // single-byte form: always consumes exactly one byte
      const ubyte head = data[0];
      data = data[1 .. $];
      return head;
   }
}
73 
/**
 * Incrementally decodes a WebSocket frame from `data`.
 *
 * The body is a state machine generated by `CheckpointSetup` (not defined in this
 * file; presumably provided by `websocketd.checkpoint`). Each checkpoint consumes
 * part of the frame header or payload; when a checkpoint's entry condition is not
 * met, execution continues after the mixin and the frame built so far is returned.
 *
 * Params:
 *   source = session identifier handed to the checkpoint machinery. The unittests in
 *            this module rely on successive calls with the same `source` resuming the
 *            same parse, including bytes left unconsumed by an earlier call.
 *   data   = newly received bytes to consume
 *
 * Returns: the frame as parsed so far; `done` is set only by the "done" checkpoint,
 *          once no payload bytes remain outstanding.
 *
 * Note: the reserved-bits check in "fin_rsv_opcode" is an `assert`, so it is only
 * active in builds that keep assertions enabled.
 */
Frame parse(string source, ubyte[] data) {
   // Arguments to CheckpointSetup, in order:
   //   "data"   - name of the variable that contains the data to consume
   //   "frame"  - name of the variable being built (a `Frame`)
   //   "source" - name of the variable holding the session identifier
   // followed by the checkpoints. Each Checkpoint pairs a state name with the
   // condition required to enter that state and the code to run in it
   // (when the condition fails, what's after the mixin gets executed).
   mixin(CheckpointSetup!Frame("data", "frame", "source",
         // first header byte: FIN flag, three reserved bits, 4-bit opcode
         "fin_rsv_opcode".Checkpoint(q{ data.length >= 1 }, q{
            frame = Frame.init;
            ubyte b = data.next; // next() modifies `data` by consuming the first byte
            frame.fin = cast(bool) (b >>> 7);
            assert(((b >>> 4) & 0b111) == 0);
            frame.op = cast(Op) (b & 0b1111);
        }),
         // second header byte: MASK flag and 7-bit length / extended-length marker;
         // a value of 126 falls through to the next checkpoint
         "mask_len".Checkpoint(q{ data.length >= 1 }, q{
            ubyte b = data.next;
            frame.masked = cast(bool) (b >>> 7);
            frame.length = cast(ulong) (b & 0b1111111);
            if (frame.length <= 125) mixin (changeState("maskOn_mask"));
            if (frame.length == 127) mixin (changeState("len127_ext_len"));
        }),
         // 16-bit extended length, most significant byte first
         "len126_ext_len".Checkpoint(q{ data.length >= 2 }, q{
            frame.length = cast(ulong) data.next;
            frame.length <<= 8;
            frame.length += cast(ulong) data.next;
            mixin (changeState("maskOn_mask")); // skip "len127_ext_len" (edge case: length=127)
        }),
         // 64-bit extended length, most significant byte first
         "len127_ext_len".Checkpoint(q{ data.length >= 8 }, q{
            frame.length = cast(ulong) data.next;
            for (int i=0; i<7; i++) {
                frame.length <<= 8;
                frame.length += cast(ulong) data.next;
            }
        }),
         // 4-byte masking key, present only when the MASK flag was set
         "maskOn_mask".Checkpoint(q{ data.length >= (frame.masked ? 4 : 0) }, q{
            if (frame.masked) {
                frame.mask = data.next!4; // next!n when n > 1 returns ubyte[]
            }
        }),
         // we don't want to wait for all the data to arrive at once (if we wanted then the
         // condition should be `data.length >= frame.length`), we prefer processing as it comes
         "message_extraction".Checkpoint(q{ (frame.length > 0 && data.length >= 1) || (frame.length == 0) }, q{
            if (frame.masked) {
                size_t i = frame.data.length;
                while (frame.remaining > 0 && data.length > 0) {
                    ubyte b = data.next;
                    frame.data ~= b ^ frame.mask[i % 4];
                    i++;
                }
            } else if (data.length >= frame.remaining)
                // explicit narrowing: `remaining` is ulong while next() takes size_t;
                // safe because this branch guarantees remaining <= data.length
                frame.data ~= data.next!2(cast(size_t) frame.remaining);
            else frame.data ~= data.next!2(data.length);
        }),
         "done".Checkpoint(q{ true }, q{
            // to allow for streaming we have this changeState(..) loop
            if (frame.remaining > 0)
                mixin (changeState("message_extraction"));
            frame.done = true;
        })));

   return frame;
}
131 
unittest { // several frames delivered in a single buffer
   auto first = Frame(true, Op.TEXT, true, 6, [0, 0, 0, 0], true, [0, 1, 2, 3, 4, 5]);
   auto second = Frame(false, Op.BINARY, true, 3, [0, 1, 2, 3], true, [8, 7, 6]);
   auto third = Frame(false, Op.CLOSE, true, 10, [0, 1, 2, 3], true, [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);

   ubyte[] wire = first.serialize ~ second.serialize ~ third.serialize;

   // Only the first call supplies bytes; the two follow-up calls pass nothing new
   // and are expected to yield the frames that were still pending in session "u0".
   auto parsedFirst = parse("u0", wire);
   auto parsedSecond = parse("u0", []);
   auto parsedThird = parse("u0", []);

   assert(parsedFirst == first);
   assert(parsedSecond == second);
   assert(parsedThird == third);
}
144 
unittest { // a frame fed to the parser one byte per call
   auto expected = Frame(true, Op.TEXT, true, 6, [1, 2, 3, 4], true, [0, 1, 2, 3, 4, 5]);
   ubyte[] wire = expected.serialize;

   // every byte except the last must leave the frame unfinished
   foreach (idx; 0 .. wire.length - 1) {
      auto partial = parse("u1", [wire[idx]]);
      assert(!partial.done);
   }

   // the final byte completes it
   auto parsed = parse("u1", [wire[$ - 1]]);
   assert(parsed.done);
   assert(expected == parsed);
}
156 
unittest { // a 1 MiB frame delivered in irregularly sized chunks
   auto payload = new ubyte[1024 * 1024];
   foreach (idx, ref octet; payload)
      octet = cast(ubyte) idx;

   auto expected = Frame(false, Op.BINARY, true, payload.length, [0, 0, 0, 0], true, payload);
   ubyte[] wire = expected.serialize;

   size_t end = 0;
   for (size_t tick = 1;; tick++) {
      const size_t begin = end;
      // pseudo-random chunk size in 0 .. 63, derived from the offset and a counter
      end = begin + (((begin | 0b11) ^ tick) & 0b111111);
      if (end >= wire.length)
         end = wire.length;

      auto parsed = parse("u2", wire[begin .. end]);
      if (end != wire.length) {
         assert(!parsed.done);
         continue;
      }

      // last chunk: the frame must now be complete and identical to the input
      assert(parsed.done);
      assert(parsed == expected);
      break;
   }
}
179 
unittest { // edge case: payload length 127, which serialize() encodes with the 16-bit length form
   auto payload = new ubyte[127];
   foreach (idx, ref octet; payload)
      octet = cast(ubyte) idx;

   auto expected = Frame(true, Op.BINARY, false, payload.length, [0, 0, 0, 0], true, payload);
   auto parsed = parse("u3", expected.serialize);
   assert(expected == parsed);
}
188 
unittest { // test edge-case length=0
   // An unmasked CLOSE frame with an empty payload must round-trip through
   // serialize() and parse() and come back marked as done.
   auto f = Frame(true, Op.CLOSE, false, 0, [0, 0, 0, 0], true, []);
   auto _f = "u4".parse(f.serialize);
   assert(f == _f);
}