1 /**
2    A reference-counted smart pointer.
3  */
4 module automem.ref_counted;
5 
6 import automem.traits: isAllocator;
7 import automem.test_utils: TestUtils;
8 import automem.unique: Unique;
9 import stdx.allocator: theAllocator, processAllocator;
10 import std.typecons: Flag;
11 
12 version(unittest) {
13     import unit_threaded;
14     import test_allocator: TestAllocator;
15 }
16 
// Mixes in the shared test fixtures for the unittests below.
// NOTE(review): Struct, SharedStruct, NoGcStruct, Class, NoGcClass,
// SafeAllocator and theTestAllocator are used below but not defined in this
// module; presumably they come from this mixin — confirm in automem.test_utils.
mixin TestUtils;
18 
19 /**
20    A reference-counted smart pointer similar to C++'s std::shared_ptr.
21  */
struct RefCounted(Type,
                  Allocator = typeof(theAllocator),
                  Flag!"supportGC" supportGC = Flag!"supportGC".yes)
    if(isAllocator!Allocator)
{

    import std.traits: hasMember;

    enum isSingleton = hasMember!(Allocator, "instance");
    enum isTheAllocator = is(Allocator == typeof(theAllocator));
    enum isGlobal = isSingleton || isTheAllocator;

    static if(isGlobal)
        /**
           The allocator is a singleton, so no need to pass it in to the
           constructor
        */
        this(Args...)(auto ref Args args) {
            // If allocation or the payload constructor throws, return the
            // memory instead of leaking a half-built Impl.
            scope(failure) freeImpl();
            this.makeObject!args();
        }
    else
        /**
           Non-singleton allocator, must be passed in
        */
        this(Args...)(Allocator allocator, auto ref Args args) {
            _allocator = allocator;
            // Registered after the allocator is stored: freeImpl needs it.
            scope(failure) freeImpl();
            this.makeObject!args();
        }

    static if(isGlobal)
        /**
            Factory method so can construct with zero args.
        */
        static typeof(this) construct(Args...)(auto ref Args args) {
            static if (Args.length != 0) {
                import std.functional: forward;
                // forward preserves rvalue-ness (e.g. for non-copyable args)
                return typeof(return)(forward!args);
            } else {
                typeof(return) ret;
                // Without this, a throwing payload constructor would leave
                // ret with count 1 and its destructor would destroy garbage.
                scope(failure) ret.freeImpl();
                ret.makeObject!()();
                return ret;
            }
        }
    else
        /**
            Factory method. Not necessary with non-global allocator
            but included for symmetry.
        */
        static typeof(this) construct(Args...)(auto ref Allocator allocator, auto ref Args args) {
            return typeof(return)(allocator, args);
        }

    /**
       Copying shares the payload and increments the reference count.
       Copying an empty (default-initialized) pointer yields another
       empty pointer.
    */
    this(this) {
        if(_impl !is null) inc;
    }

    ///
    ~this() {
        release;
    }

    /**
       Assign to an lvalue RefCounted
    */
    void opAssign(ref RefCounted other) {

        if (_impl == other._impl) return;

        // Take the new reference before dropping the old one: `other` might
        // only be reachable through the payload we are about to release.
        auto newImpl = other._impl;
        static if(!isGlobal)
            auto newAllocator = other._allocator;
        if(newImpl !is null) other.inc;

        release; // no-op when _impl is null

        static if(!isGlobal)
            _allocator = newAllocator;

        _impl = newImpl;
    }

    /**
       Assign to an rvalue RefCounted
     */
    void opAssign(RefCounted other) {
        import std.algorithm: swap;
        // the old state ends up in `other` and is released by its destructor
        swap(_impl, other._impl);
        static if(!isGlobal)
            swap(_allocator, other._allocator);
    }

    /**
       Dereference the smart pointer and yield a reference
       to the contained type.
     */
    ref auto opUnary(string s)() inout if (s == "*") {
        return _impl._get;
    }

    /**
        Prevent opSlice and opIndex from being hidden by Impl*.
        This comment is deliberately not DDOC.
    */
    auto ref opSlice(A...)(auto ref A args)
    if (__traits(compiles, Type.init.opSlice(args)))
    {
        return _impl._get.opSlice(args);
    }
    /// ditto
    auto ref opIndex(A...)(auto ref A args)
    if (__traits(compiles, Type.init.opIndex(args)))
    {
        return _impl._get.opIndex(args);
    }
    /// ditto
    auto ref opIndexAssign(A...)(auto ref A args)
    if (__traits(compiles, Type.init.opIndexAssign(args)))
    {
        return _impl._get.opIndexAssign(args);
    }

    alias _impl this;

private:

    static struct Impl {

        static if(is(Type == class)) {

            align ((void*).alignof)
            void[__traits(classInstanceSize, Type)] _rawMemory;

        } else
            Type _object;

        static if(is(Type == shared))
            shared size_t _count;
        else
            size_t _count;

        static if (is(Type == class)) {
            inout(Type) _get() inout {
                return cast(inout(Type))&_rawMemory[0];
            }

            inout(shared(Type)) _get() inout shared {
                return cast(inout(shared(Type)))&_rawMemory[0];
            }
        } else {
            ref inout(Type) _get() inout {
                return _object;
            }

            ref inout(shared(Type)) _get() inout shared {
                return _object;
            }
        }

        alias _get this;
    }

    static if(isSingleton)
        alias _allocator = Allocator.instance;
    else static if(isTheAllocator) {
        static if (is(Type == shared))
            // 'processAllocator' should be used for allocating
            // memory shared across threads
            alias _allocator = processAllocator;
        else
            alias _allocator = theAllocator;
    }
    else
        Allocator _allocator;

    static if(is(Type == shared))
        alias ImplType = shared Impl;
    else
        alias ImplType = Impl;

    public ImplType* _impl; // public or alias this doesn't work

    // Allocates uninitialized storage for Impl, sets the count to 1 and
    // registers GC ranges. The payload itself is constructed by makeObject.
    void allocateImpl() {
        import std.traits: hasIndirections;
        import core.exception: onOutOfMemoryError;

        auto mem = _allocator.allocate(Impl.sizeof);
        // Allocators report failure with a null block; writing the count
        // through it would dereference null.
        if(mem.ptr is null) onOutOfMemoryError();

        _impl = cast(typeof(_impl)) mem.ptr;
        _impl._count= 1;

        static if (is(Type == class)) {
            // class representation:
            // void* classInfoPtr
            // void* monitorPtr
            // []    interfaces
            // T...  members
            import core.memory: GC;
            if (supportGC && !(typeid(Type).m_flags & TypeInfo_Class.ClassFlags.noPointers))
                // members have pointers: we have to watch the monitor
                // and all members; skip the classInfoPtr
                GC.addRange(&_impl._rawMemory[(void*).sizeof],
                        __traits(classInstanceSize, Type) - (void*).sizeof);
            else
                // representation doesn't have pointers, just watch the
                // monitor pointer; skip the classInfoPtr
                // need to watch the monitor pointer even if supportGC is false.
                GC.addRange(&_impl._rawMemory[(void*).sizeof], (void*).sizeof);
        } else static if (supportGC && hasIndirections!Type) {
            import core.memory: GC;
            GC.addRange(&_impl._object, Type.sizeof);
        }
    }

    // Drops one reference; the last owner destroys the payload and frees it.
    void release() {
        import automem.utils : destruct;
        if(_impl is null) return;
        assert(_impl._count > 0, "Trying to release a RefCounted but ref count is 0 or less");

        // Decide on the value dec() produced rather than re-reading the
        // count, so for shared payloads exactly one releaser sees zero.
        if(dec() == 0) {
            destruct(_impl._get);
            freeImpl();
        }
    }

    // Undoes allocateImpl: unregisters GC ranges, returns the memory to the
    // allocator and nulls _impl. Does NOT run the payload's destructor.
    void freeImpl() {
        import std.traits : hasIndirections;
        import core.memory : GC;

        if(_impl is null) return;

        static if (is(Type == class)) {
            // need to watch the monitor pointer even if supportGC is false.
            GC.removeRange(&_impl._rawMemory[(void*).sizeof]);
        } else static if (supportGC && hasIndirections!Type) {
            GC.removeRange(&_impl._object);
        }

        auto mem = cast(void*)_impl;
        _impl = null;
        _allocator.deallocate(() @trusted { return mem[0 .. Impl.sizeof]; }());
    }

    void inc() {
        static if(is(Type == shared)) {
            import core.atomic: atomicOp;
            _impl._count.atomicOp!"+="(1);
        } else
            ++_impl._count;

    }

    // Decrements the count and returns the new value (atomically for
    // shared payloads).
    size_t dec() {
        static if(is(Type == shared)) {
            import core.atomic: atomicOp;
            return _impl._count.atomicOp!"-="(1);
        } else
            return --_impl._count;
    }

}
271 
// Allocates the Impl behind `rc` and constructs the payload in place,
// forwarding `args` (the calling constructor's parameters) to the payload's
// constructor. supportGC is deduced as well, so this also matches
// RefCounted instances whose third template argument is not the default.
private template makeObject(args...)
{
    void makeObject(Type, A, Flag!"supportGC" supportGC)(ref RefCounted!(Type, A, supportGC) rc) @trusted {
        import std.conv: emplace;
        import std.functional : forward;

        rc.allocateImpl;

        static if(is(Type == class))
            // classes are built inside the raw instance buffer
            emplace!Type(rc._impl._rawMemory, forward!args);
        else
            emplace(&rc._impl._object, forward!args);
    }
}
286 
287 ///
288 @("struct test allocator no copies")
289 @system unittest {
290     auto allocator = TestAllocator();
291     {
292         auto ptr = RefCounted!(Struct, TestAllocator*)(&allocator, 5);
293         Struct.numStructs.shouldEqual(1);
294     }
295     Struct.numStructs.shouldEqual(0);
296 }
297 
@("struct test allocator one lvalue assignment")
@system unittest {
    alias RC = RefCounted!(Struct, TestAllocator*);
    auto testAlloc = TestAllocator();
    {
        auto source = RC(&testAlloc, 5);
        Struct.numStructs.shouldEqual(1);

        // assigning an lvalue shares the payload rather than copying it
        RC target;
        target = source;
        Struct.numStructs.shouldEqual(1);
    }
    Struct.numStructs.shouldEqual(0);
}
311 
@("struct test allocator one lvalue assignment from T.init")
@system unittest {
    alias RC = RefCounted!(Struct, TestAllocator*);
    auto testAlloc = TestAllocator();

    {
        RC empty;
        Struct.numStructs.shouldEqual(0);

        auto full = RC(&testAlloc, 5);
        Struct.numStructs.shouldEqual(1);

        // overwriting with an empty pointer drops the last reference
        full = empty;
        Struct.numStructs.shouldEqual(0);
    }

    Struct.numStructs.shouldEqual(0);
}
330 
@("struct test allocator one lvalue assignment both non-null")
@system unittest {
    alias RC = RefCounted!(Struct, TestAllocator*);
    auto testAlloc = TestAllocator();

    {
        auto five = RC(&testAlloc, 5);
        Struct.numStructs.shouldEqual(1);

        auto seven = RC(&testAlloc, 7);
        Struct.numStructs.shouldEqual(2);

        // the payload formerly held by `seven` loses its only owner
        seven = five;
        Struct.numStructs.shouldEqual(1);
    }

    Struct.numStructs.shouldEqual(0);
}
349 
350 
351 
@("struct test allocator one rvalue assignment test allocator")
@system unittest {
    alias RC = RefCounted!(Struct, TestAllocator*);
    auto testAlloc = TestAllocator();
    {
        RC target;
        target = RC(&testAlloc, 5); // assigned from a temporary
        Struct.numStructs.shouldEqual(1);
    }
    Struct.numStructs.shouldEqual(0);
}
362 
@("struct test allocator one rvalue assignment mallocator")
@system unittest {
    import stdx.allocator.mallocator: Mallocator;
    alias RC = RefCounted!(Struct, Mallocator);
    {
        // global allocator: no allocator argument is passed
        RC target;
        target = RC(5);
        Struct.numStructs.shouldEqual(1);
    }
    Struct.numStructs.shouldEqual(0);
}
373 
374 
@("struct test allocator one lvalue copy constructor")
@system unittest {
    alias RC = RefCounted!(Struct, TestAllocator*);
    auto testAlloc = TestAllocator();
    {
        auto first = RC(&testAlloc, 5);
        Struct.numStructs.shouldEqual(1);
        auto second = first; // copying shares the payload
        Struct.numStructs.shouldEqual(1);

        first.i.shouldEqual(5);
        second.i.shouldEqual(5);
    }
    Struct.numStructs.shouldEqual(0);
}
389 
@("struct test allocator one rvalue copy constructor")
@system unittest {
    alias RC = RefCounted!(Struct, TestAllocator*);
    auto testAlloc = TestAllocator();
    {
        auto fromTemporary = RC(&testAlloc, 5);
        Struct.numStructs.shouldEqual(1);
    }
    Struct.numStructs.shouldEqual(0);
}
399 
@("many copies made")
@system unittest {
    alias RC = RefCounted!(Struct, TestAllocator*);
    auto testAlloc = TestAllocator();

    // Reads the reference count directly; keeping that in one helper
    // isolates the test from changes to the internal layout.
    size_t countOf(T)(ref T rc) {
        return rc._impl._count;
    }

    // Builds a brand-new pointer holding Struct(3).
    auto fresh() {
        return RC(&testAlloc, 3);
    }

    {
        auto first = RC(&testAlloc, 5);
        Struct.numStructs.shouldEqual(1);

        auto second = first;
        Struct.numStructs.shouldEqual(1);

        {
            auto third = second;
            Struct.numStructs.shouldEqual(1);

            countOf(first).shouldEqual(3);
            countOf(second).shouldEqual(3);
            countOf(third).shouldEqual(3);
        }

        Struct.numStructs.shouldEqual(1);
        countOf(first).shouldEqual(2);
        countOf(second).shouldEqual(2);

        // rebinding `first` leaves `second` as the sole owner of Struct(5)
        first = fresh();
        Struct.numStructs.shouldEqual(2);
        countOf(first).shouldEqual(1);
        countOf(second).shouldEqual(1);

        first.twice.shouldEqual(6);
        second.twice.shouldEqual(10);
    }

    Struct.numStructs.shouldEqual(0);
}
445 
@("default allocator")
@system unittest {
    {
        // no allocator argument: the default theAllocator is used
        auto rc = RefCounted!Struct(5);
        Struct.numStructs.shouldEqual(1);
    }
    Struct.numStructs.shouldEqual(0);
}
454 
static if (__VERSION__ >= 2079)
@("default allocator (shared)")
@system unittest {
    {
        // shared payloads with the default allocator use processAllocator
        auto rc = RefCounted!(shared SharedStruct)(5);
        SharedStruct.numStructs.shouldEqual(1);
    }
    SharedStruct.numStructs.shouldEqual(0);
}
464 
@("deref")
@system unittest {
    auto testAlloc = TestAllocator();
    auto first = RefCounted!(int, TestAllocator*)(&testAlloc, 5);
    (*first).shouldEqual(5);

    // writes through one copy are visible through the other
    auto second = first;
    *second = 42;
    (*first).shouldEqual(42);
}
475 
@("swap")
@system unittest {
    import std.algorithm: swap;
    // swapping two empty pointers must compile and run
    RefCounted!(int, TestAllocator*) left, right;
    swap(left, right);
}
482 
@("phobos bug 6606")
@system unittest {

    // A payload containing a union that overlaps a pointer with an
    // integer must still allow RefCounted to be instantiated.
    union IntOrPtr {
       size_t i;
       void* p;
    }

    struct Holder {
       IntOrPtr u;
    }

    alias HolderPtr = RefCounted!(Holder, TestAllocator*);
}
497 
@("phobos bug 6436")
@system unittest
{
    // A ref constructor parameter must bind to the caller's variable,
    // so the increment below is visible after construction.
    static struct TakesRef {
        this(ref int val, string file = __FILE__, size_t line = __LINE__) {
            val.shouldEqual(3, file, line);
            ++val;
        }
    }

    auto testAlloc = TestAllocator();
    int counter = 3;
    auto rc = RefCounted!(TakesRef, TestAllocator*)(&testAlloc, counter);
    counter.shouldEqual(4);
}
513 
@("assign from T")
@system unittest {
    import stdx.allocator.mallocator: Mallocator;
    alias RC = RefCounted!(Struct, Mallocator);

    {
        auto owner = RC(3);
        Struct.numStructs.shouldEqual(1);

        // assigning through the dereferenced pointer replaces the payload value
        *owner = Struct(5);
        Struct.numStructs.shouldEqual(1);
        (*owner).shouldEqual(Struct(5));

        RC sharer;
        sharer = owner;
        (*sharer).shouldEqual(Struct(5));
        Struct.numStructs.shouldEqual(1);
    }

    Struct.numStructs.shouldEqual(0);
}
534 
@("assign self")
@system unittest {
    auto testAlloc = TestAllocator();
    {
        auto rc = RefCounted!(Struct, TestAllocator*)(&testAlloc, 1);
        // self-assignment must neither free nor duplicate the payload
        rc = rc;
        Struct.numStructs.shouldEqual(1);
    }
    Struct.numStructs.shouldEqual(0);
}
545 
static if (__VERSION__ >= 2079)
@("SharedStruct")
@system unittest {
    auto testAlloc = TestAllocator();
    {
        // shared payload with an explicitly passed allocator
        auto rc = RefCounted!(shared SharedStruct, TestAllocator*)(&testAlloc, 5);
        SharedStruct.numStructs.shouldEqual(1);
    }
    SharedStruct.numStructs.shouldEqual(0);
}
556 
@("@nogc @safe")
@safe @nogc unittest {

    // NOTE(review): this local is never read; the pointer below is given
    // its own SafeAllocator() temporary.
    auto unused = SafeAllocator();

    {
        const rc = RefCounted!(NoGcStruct, SafeAllocator)(SafeAllocator(), 6);
        assert(rc.i == 6);
        assert(NoGcStruct.numStructs == 1);
    }

    assert(NoGcStruct.numStructs == 0);
}
570 
571 
@("const object")
@system unittest {
    // the payload type itself may be const-qualified
    auto testAlloc = TestAllocator();
    auto rc = RefCounted!(const Struct, TestAllocator*)(&testAlloc, 5);
}
577 
578 
@("theAllocator")
@system unittest {

    with(theTestAllocator) {
        // default-allocator RefCounted created inside the with scope
        auto rc = RefCounted!Struct(42);
        (*rc).shouldEqual(Struct(42));
        Struct.numStructs.shouldEqual(1);
    }

    Struct.numStructs.shouldEqual(0);
}
590 
static if (__VERSION__ >= 2079) {

    @("threads Mallocator")
    @system unittest {
        import stdx.allocator.mallocator: Mallocator;
        static assert(__traits(compiles, sendRefCounted!Mallocator(7)));
    }

    @("threads SafeAllocator by value")
    @system unittest {
        // TestAllocator can't be used because it has indirections, and
        // passing by pointer is an indirection too
        auto byValue = SafeAllocator();
        static assert(__traits(compiles, sendRefCounted!(SafeAllocator)(byValue, 7)));
    }

    @("threads SafeAllocator by shared pointer")
    @system unittest {
        // TestAllocator can't be used because it has indirections; a
        // pointer can only be passed if it is shared
        auto sharedAlloc = shared SafeAllocator();
        static assert(__traits(compiles, sendRefCounted!(shared SafeAllocator*)(&sharedAlloc, 7)));
    }
}
615 
616 
/**
   Creates a RefCounted that uses the same allocator as `ptr` and holds a
   copy of its payload.

   The payload is copied into freshly allocated storage; the Unique
   argument is taken by value. NOTE(review): what happens to the Unique's
   own payload is up to automem.unique — presumably it is released when the
   parameter goes out of scope; confirm.
 */
auto refCounted(Type, Allocator)(Unique!(Type, Allocator) ptr) {
    import std.conv: emplace;

    RefCounted!(Type, Allocator) ret;

    static if(!ptr.isGlobal)
        ret._allocator = ptr.allocator;

    ret.allocateImpl;

    static if(is(Type == class))
        *ret = *ptr;
    else
        // allocateImpl leaves the payload uninitialized: construct into it
        // instead of assigning, which would run opAssign/the destructor on
        // garbage for types with elaborate assignment.
        emplace(&ret._impl._object, *ptr);

    return ret;
}
629 
@("Construct RefCounted from Unique")
@system unittest {
    import automem.unique: Unique;
    auto testAlloc = TestAllocator();
    // the Unique is passed as a temporary straight into refCounted
    auto rc = refCounted(Unique!(int, TestAllocator*)(&testAlloc, 42));
    (*rc).shouldEqual(42);
}
637 
@("RefCounted with class")
@system unittest {
    auto testAlloc = TestAllocator();
    {
        writelnUt("Creating ptr");
        auto rc = RefCounted!(Class, TestAllocator*)(&testAlloc, 33);
        // for classes, *rc yields the class reference
        (*rc).i.shouldEqual(33);
        Class.numClasses.shouldEqual(1);
    }
    Class.numClasses.shouldEqual(0);
}
649 
@("@nogc class destructor")
@nogc unittest {

    // NOTE(review): never read; the pointer below gets its own SafeAllocator().
    auto unused = SafeAllocator();

    {
        // NOTE(review): this exercises Unique rather than RefCounted even
        // though it lives in the ref_counted module — confirm intent.
        const up = Unique!(NoGcClass, SafeAllocator)(SafeAllocator(), 6);
        // shouldEqual isn't @nogc
        assert(up.i == 6);
        assert(NoGcClass.numClasses == 1);
    }

    assert(NoGcClass.numClasses == 0);
}
664 
@("RefCounted opSlice and opIndex")
@system unittest {
    import std.mmfile: MmFile;
    auto mapped = RefCounted!MmFile(null, MmFile.Mode.readWriteNew, 120, null);
    // Indexing must yield ubyte (MmFile's element), not Impl.
    static assert(is(typeof(mapped[0]) == typeof(MmFile.init[0])));
    // Slicing must yield void[] (MmFile's slice), not Impl[].
    static assert(is(typeof(mapped[0 .. size_t.max]) == typeof(MmFile.init[0 .. size_t.max])));
    ubyte[] bytes = cast(ubyte[]) mapped[0 .. cast(size_t) mapped.length];
    immutable ubyte before = mapped[1];
    mapped[1] = cast(ubyte) (before + 1);
    assert(bytes[1] == cast(ubyte) (before + 1));
}
678 
@("Construct RefCounted using global allocator for struct with zero-args ctor")
@system unittest {
    struct S {
        private ulong zeroArgsCtorTest = 3;
    }
    // construct() with no arguments must still allocate and initialize
    auto rc = RefCounted!S.construct();
    static assert(is(typeof(rc) == RefCounted!S));
    assert(rc._impl !is null);
    assert(rc.zeroArgsCtorTest == 3);
}
689 
version(unittest):

/*
   Test helper: spawns threadFunc, builds a RefCounted!(shared SharedStruct,
   Allocator) from `args`, and sends it to the spawned thread.
 */
void sendRefCounted(Allocator, Args...)(Args args) {
    import std.concurrency: spawn, send;

    auto worker = spawn(&threadFunc);
    auto payload = RefCounted!(shared SharedStruct, Allocator)(args);

    worker.send(payload);
}
700 
// Thread entry point spawned by sendRefCounted; its body is empty.
void threadFunc() {}