1 module mpfrd;
2 
3 import deimos.mpfr;
4 import std.traits;
5 
/**
 * Value-type wrapper around an MPFR multiple-precision floating-point number.
 *
 * The constructor initialises the underlying `mpfr_t` with `mpfr_init2`, the
 * postblit deep-copies it and the destructor releases it with `mpfr_clear`.
 * `alias this` exposes the raw `mpfr_t`, so an `Mpfr` can be passed directly
 * to `mpfr_*` functions. Every operation below rounds to nearest (`MPFR_RNDN`).
 */
struct Mpfr {
    mpfr_t mpfr;
    alias mpfr this;

    // No default constructor: presumably so that every instance is set up by
    // mpfr_init2 before the destructor calls mpfr_clear on it.
    @disable this();

    /// Postblit. After D's bitwise copy, `mpfr` may still share the source's
    /// significand storage. Replace it with a fresh value of the same precision
    /// that holds the same number.
    this(this) {
        mpfr_t new_mpfr;
        mpfr_init2(new_mpfr, mpfr_get_prec(mpfr));
        mpfr_set(new_mpfr, mpfr, mpfr_rnd_t.MPFR_RNDN);
        mpfr = new_mpfr;
    }

    /**
     * Creates a number initialised from a built-in numeric value.
     *
     * Params:
     *   value     = initial value; rounded to nearest if it does not fit.
     *   precision = significand precision in bits (defaults to 32).
     */
    this(T)(const T value, mpfr_prec_t precision = 32) if(isNumeric!T) {
        mpfr_init2(mpfr, precision);
        this = value;
    }

    /// Releases the MPFR storage.
    ~this() {
        mpfr_clear(mpfr);
    }

    /// True for the operand types the operators accept: built-in numeric
    /// types and `Mpfr` itself.
    private static template isNumericValue(T) {
        enum isNumericValue = isNumeric!T || is(T == Mpfr);
    }

    /// Returns the suffix that selects the MPFR function variant for operand
    /// type `T`: `_si` (signed), `_ui` (unsigned), `_d` (floating point), or
    /// none for `Mpfr`.
    /// NOTE(review): the `_si`/`_ui` variants take C `long`/`unsigned long`.
    /// Where those are 32-bit (e.g. Windows), 64-bit `long`/`ulong` operands
    /// may be truncated — confirm against the deimos binding.
    private static string getTypeString(T)() {
        static if (isIntegral!T && isSigned!T) {
            return "_si";
        } else static if (isIntegral!T && !isSigned!T) {
            return "_ui";
        } else static if (is(T : double)) {
            return "_d";
        } else static if (is(T == Mpfr)) {
            return "";
        } else {
            static assert(false, "Unhandled type " ~ T.stringof);
        }
    }

    /// Precision (in bits) for the result of a binary operation: the
    /// precision of `this`, or the larger of both precisions when the other
    /// operand is also an `Mpfr`.
    private mpfr_prec_t resultPrecision(T)(ref const T value) const {
        static if (is(T == Mpfr)) {
            const lhs = mpfr_get_prec(mpfr);
            const rhs = mpfr_get_prec(value.mpfr);
            return lhs > rhs ? lhs : rhs;
        } else {
            return mpfr_get_prec(mpfr);
        }
    }

    ////////////////////////////////////////////////////////////////////////////
    // Comparisons
    ////////////////////////////////////////////////////////////////////////////

    /// Three-way comparison with a built-in number or an `Mpfr` rvalue.
    /// Returns a negative, zero, or positive value, as reported by `mpfr_cmp*`.
    /// NOTE: per the MPFR manual, `mpfr_cmp*` returns 0 (and sets the erange
    /// flag) when an operand is NaN.
    int opCmp(T)(const T value) const if(isNumericValue!T) {
        mixin("return mpfr_cmp" ~ getTypeString!T() ~ "(mpfr, value);");
    }

    /// Three-way comparison with an `Mpfr` lvalue.
    int opCmp(ref const Mpfr value) const {
        // No identity shortcut and no boolean coercion: the sign of mpfr_cmp
        // must be returned unchanged, and comparing a value with itself is 0.
        return mpfr_cmp(mpfr, value);
    }

    /// Equality with a built-in number or an `Mpfr` rvalue.
    bool opEquals(T)(const T value) const if(isNumericValue!T) {
        return opCmp(value) == 0;
    }

    /// Equality with an `Mpfr` lvalue.
    bool opEquals(ref const Mpfr value) const {
        return opCmp(value) == 0;
    }

    /// Maps an arithmetic operator to the MPFR function name fragment.
    private static string getOperatorString(string op)() {
        final switch(op) {
            case "+": return "_add";
            case "-": return "_sub";
            case "*": return "_mul";
            case "/": return "_div";
            case "^^": return "_pow";
        }
    }

    /// Maps a shift operator to multiplication/division by a power of two.
    private static string getShiftOperatorString(string op)() {
        final switch(op) {
            case "<<": return "_mul";
            case ">>": return "_div";
        }
    }

    /// Suffix for the power-of-two variants (`_2si` / `_2ui`); integral only.
    private static string getShiftTypeString(T)() {
        static if (isIntegral!T && isSigned!T) {
            return "_2si";
        } else static if (isIntegral!T && !isSigned!T) {
            return "_2ui";
        } else {
            static assert(false, "Unhandled type " ~ T.stringof);
        }
    }

    /// Builds the MPFR function suffix for `op` with operand type `T`.
    /// When `isRight` is true, the built-in operand comes first (e.g.
    /// `_ui_sub`).
    private static string getFunctionSuffix(string op, T, bool isRight) () {
        static if(op == "<<" || op == ">>") {
            static assert(!isRight, "Binary Right Shift not allowed, try using lower level mpfr_ui_pow.");
            return getShiftOperatorString!op() ~ getShiftTypeString!T();
        } else {
            return isRight ?
                getTypeString!T() ~ getOperatorString!op():
                getOperatorString!op() ~ getTypeString!T();
        }
    }

    /// Full MPFR function name, e.g. `mpfr_add_si`.
    private static string getFunction(string op, T, bool isRight) () {
        return "mpfr" ~ getFunctionSuffix!(op, T, isRight);
    }

    ////////////////////////////////////////////////////////////////////////////
    // Arithmetic
    ////////////////////////////////////////////////////////////////////////////

    /// `this op value`. The result has the precision of `this` (or the larger
    /// precision when `value` is an `Mpfr`), so operands are not truncated.
    Mpfr opBinary(string op, T)(const T value) const if(isNumericValue!T) {
        // MPFR provides no mpfr_pow_d.
        static assert(!(op == "^^" && isFloatingPoint!T), "No operator ^^ with floating point.");
        auto output = Mpfr(0, resultPrecision(value));
        mixin(getFunction!(op, T, false)() ~ "(output, mpfr, value, mpfr_rnd_t.MPFR_RNDN);");
        return output;
    }

    /// `value op this`, for when the left operand is a built-in number.
    Mpfr opBinaryRight(string op, T)(const T value) const if(isNumericValue!T) {
        static if (op == "^^") {
            // ^^ is not commutative. MPFR has no signed or double "x_pow"
            // variant, so convert the base to an exact temporary and use
            // mpfr_pow. 64 bits hold every built-in integral or double exactly.
            auto output = Mpfr(0, resultPrecision(value));
            static if (is(T == Mpfr)) {
                mpfr_pow(output, value, mpfr, mpfr_rnd_t.MPFR_RNDN);
            } else {
                auto base = Mpfr(value, 64);
                mpfr_pow(output, base, mpfr, mpfr_rnd_t.MPFR_RNDN);
            }
            return output;
        } else static if(op == "-" || op == "/" || op == "<<" || op == ">>") {
            auto output = Mpfr(0, resultPrecision(value));
            mixin(getFunction!(op, T, true)() ~ "(output, value, mpfr, mpfr_rnd_t.MPFR_RNDN);");
            return output;
        } else {
            // + and * are commutative.
            return opBinary!op(value);
        }
    }

    /// Unary `-` (negation) and `+` (copy). The result keeps this precision.
    Mpfr opUnary(string op)() const if(op == "-" || op == "+") {
        auto output = Mpfr(0, mpfr_get_prec(mpfr));
        static if (op == "-") {
            mpfr_neg(output, mpfr, mpfr_rnd_t.MPFR_RNDN);
        } else {
            mpfr_set(output, mpfr, mpfr_rnd_t.MPFR_RNDN);
        }
        return output;
    }

    ////////////////////////////////////////////////////////////////////////////
    // Mutation
    ////////////////////////////////////////////////////////////////////////////

    /// Assigns a built-in number or an `Mpfr` rvalue. The destination keeps
    /// its own precision, and the value is rounded to it.
    ref Mpfr opAssign(T)(const T value) if(isNumericValue!T) {
        mixin("mpfr_set" ~ getTypeString!T() ~ "(mpfr, value, mpfr_rnd_t.MPFR_RNDN);");
        return this;
    }

    /// Assigns an `Mpfr` lvalue. The destination keeps its own precision.
    ref Mpfr opAssign(ref const Mpfr value) {
        mpfr_set(mpfr, value, mpfr_rnd_t.MPFR_RNDN);
        return this;
    }

    /// `this op= value` for a built-in number or an `Mpfr` rvalue.
    ref Mpfr opOpAssign(string op, T)(const T value) if(isNumericValue!T) {
        static assert(!(op == "^^" && isFloatingPoint!T), "No operator ^^= with floating point.");
        mixin(getFunction!(op, T, false)() ~ "(mpfr, mpfr, value, mpfr_rnd_t.MPFR_RNDN);");
        return this;
    }

    /// `this op= value` for an `Mpfr` lvalue, including `value` being `this`.
    ref Mpfr opOpAssign(string op)(ref const Mpfr value) {
        // MPFR allows the output to alias an input, so `x += x` is well
        // defined and must not be skipped.
        mixin(getFunction!(op, Mpfr, false)() ~ "(mpfr, mpfr, value, mpfr_rnd_t.MPFR_RNDN);");
        return this;
    }

    ////////////////////////////////////////////////////////////////////////////
    // String
    ////////////////////////////////////////////////////////////////////////////

    /// Formats the value with MPFR's `%Rg` conversion.
    /// Throws: `Exception` if `mpfr_snprintf` reports an error.
    string toString() const {
        char[1024] buffer;
        const count = mpfr_snprintf(buffer.ptr, buffer.length, "%Rg".ptr, &mpfr);
        if (count < 0) {
            throw new Exception("mpfr_snprintf failed");
        }
        if (cast(size_t) count < buffer.length) {
            return buffer[0 .. count].idup;
        }
        // The output was truncated. count is the full length (excluding the
        // NUL terminator), so format again into a buffer large enough to hold it.
        auto large = new char[count + 1];
        mpfr_snprintf(large.ptr, large.length, "%Rg".ptr, &mpfr);
        return large[0 .. count].idup;
    }

    unittest {
        // Lvalue comparisons keep their sign and handle self-comparison.
        auto a = Mpfr(1);
        auto b = Mpfr(2);
        assert(a < b);
        assert(b > a);
        assert(a.opCmp(a) == 0);
        assert(a == a);
        assert(a != b);

        // Compound assignment with an Mpfr lvalue, including aliasing itself.
        auto c = Mpfr(3);
        c += b;
        assert(c == 5);
        c += c;
        assert(c == 10);
        c -= c;
        assert(c == 0);

        // ^^ with a built-in left operand is not commutative.
        auto e = Mpfr(2);
        assert(3 ^^ e == 9);
        long negative = -3;
        assert(negative ^^ e == 9);
        uint three = 3;
        assert(three ^^ e == 9);
        assert(2.0 ^^ e == 4);

        // Unary plus copies the value.
        assert(+b == 2);

        // Results keep the operand precision instead of dropping to 32 bits.
        auto tiny = Mpfr(1, 128);
        tiny >>= 100;
        auto sum = tiny + 1;
        assert(mpfr_get_prec(sum) == 128);
        assert(sum > 1);

        // Formatting a high-precision value does not overrun the buffer.
        auto big = Mpfr(1, 8192);
        big /= 3;
        const text = big.toString();
        assert(text.length >= 5 && text[0 .. 5] == "0.333");
    }
}
171 
version (unittest)
{
    import std.meta;
    import std.stdio : writeln, writefln;

    // Built-in integral operand types exercised by the tests.
    alias AllIntegralNoMpfr = AliasSeq!(ubyte, ushort, uint, ulong, byte, short, int, long);

    // The same integral types, followed by Mpfr itself.
    alias AllIntegralTypes = AliasSeq!(AllIntegralNoMpfr, Mpfr);

    // Every operand type accepted by the arithmetic/comparison operators.
    alias AllNumericTypes = AliasSeq!(ubyte, ushort, uint, ulong, float, double,
                                      byte, short, int, long, Mpfr);
}
180 
unittest {
    // Assigning from every numeric type (or another Mpfr) must store the value.
    auto value = Mpfr(0);
    value = 1;
    assert(value == 1);
    foreach(T ; AllNumericTypes) {
        value = 0;
        value = T(1);
        assert(value == 1);
    }
}
189 
unittest {
    // Assignment copies the value.
    auto a = Mpfr(0);
    auto b = Mpfr(0);
    a = 2;
    b = a;
    assert(b == 2);

    // Copy construction (postblit) must produce an independent deep copy.
    auto c = a;
    assert(c == 2);
    a = 5;
    assert(a == 5);
    assert(b == 2);
    assert(c == 2);
}
198 
unittest {
    // Every comparison operator against every numeric operand type.
    auto value = Mpfr(0);
    value = 2;
    foreach(T ; AllNumericTypes) {
        assert(value == T(2));
        assert(value != T(3));
        assert(value < T(3));
        assert(value > T(1));
        assert(value <= T(2));
        assert(value <= T(3));
        assert(value >= T(2));
        assert(value >= T(1));
    }
}
211 
unittest {
    // opOpAssign for every operator and operand type.
    auto value = Mpfr(1);
    assert(value == 1);
    value = value;
    assert(value == 1);
    foreach(T ; AllNumericTypes) {
        value = 2;
        value += T(2);
        assert(value == 4);
    }
    foreach(T ; AllNumericTypes) {
        value = 2;
        value -= T(2);
        assert(value == 0);
    }
    foreach(T ; AllNumericTypes) {
        value = 2;
        value *= T(3);
        assert(value == 6);
    }
    foreach(T ; AllNumericTypes) {
        value = 2;
        value /= T(2);
        assert(value == 1);
    }
    foreach(T ; AllIntegralTypes) {
        value = 2;
        value ^^= T(2);
        assert(value == 4);
    }
    // Shift amounts use the loop type so each integral overload is exercised.
    foreach(T ; AllIntegralNoMpfr) {
        value = 2;
        value <<= T(3);
        assert(value == 16);
    }
    foreach(T ; AllIntegralNoMpfr) {
        value = 16;
        value >>= T(3);
        assert(value == 2);
    }
}
254 
unittest {
    // opBinary and opBinaryRight: each operator with the Mpfr on either side.
    auto x = Mpfr(0);

    foreach(U ; AllNumericTypes) {
        // Addition.
        x = 2;
        assert(x + U(2) == 4);
        assert(U(2) + x == 4);

        // Subtraction.
        x = 3;
        assert(x - U(2) == 1);
        x = 2;
        assert(U(3) - x == 1);

        // Multiplication.
        x = 2;
        assert(x * U(3) == 6);
        assert(U(3) * x == 6);

        // Division.
        x = 4;
        assert(x / U(2) == 2);
        x = 2;
        assert(U(6) / x == 3);
    }

    foreach(U ; AllIntegralTypes) {
        x = 2;
        assert(x ^^ U(2) == 4);
        assert(U(2) ^^ x == 4);
    }

    foreach(U ; AllIntegralNoMpfr) {
        x = 2;
        assert(x << U(3) == 16);
        x = 16;
        assert(x >> U(3) == 2);
    }
}