1 module mpfrd;
2 
3 import deimos.mpfr;
4 import std.traits;
5 
/**
 * Value-type wrapper around an MPFR `mpfr_t`.
 *
 * The wrapped `mpfr_t` is initialized by the constructor, deep-copied by the
 * postblit and released by the destructor. Every operation below rounds to
 * nearest (`MPFR_RNDN`). Thanks to `alias this`, an `Mpfr` can be handed
 * directly to the raw `mpfr_*` C functions.
 */
struct Mpfr {
    mpfr_t mpfr;
    alias mpfr this;

    /// Default construction is disabled: an `Mpfr` always starts from a value.
    @disable this();

    /// Postblit: give the copy its own storage (same precision, same value)
    /// instead of sharing the source's significand.
    this(this) {
        mpfr_t new_mpfr;
        mpfr_init2(new_mpfr, mpfr_get_prec(mpfr));
        mpfr_set(new_mpfr, mpfr, mpfr_rnd_t.MPFR_RNDN);
        mpfr = new_mpfr;
    }

    /**
     * Creates a value equal to `value` rounded to `p` bits.
     * Params:
     *   value = initial built-in numeric value.
     *   p     = precision in bits (32 by default).
     */
    this(T)(const T value, mpfr_prec_t p = 32) if(isNumeric!T) {
        mpfr_init2(mpfr, p);
        this = value;
    }

    /// Releases the memory held by the wrapped `mpfr_t`.
    ~this() {
        mpfr_clear(mpfr);
    }

    /// True for built-in numeric types and for `Mpfr` itself.
    private static template isNumericValue(T) {
        enum isNumericValue = isNumeric!T || is(T == Mpfr);
    }

    /// MPFR function-name suffix for an operand of type `T`:
    /// `_si` (signed integral), `_ui` (unsigned integral), `_d` (floating point)
    /// or nothing (`Mpfr`).
    private static string getTypeString(T)() {
        static if (isIntegral!T && isSigned!T) {
            return "_si";
        } else static if (isIntegral!T && !isSigned!T) {
            return "_ui";
        } else static if (is(T : double)) {
            return "_d";
        } else static if (is(T == Mpfr)) {
            return "";
        } else {
            static assert(false, "Unhandled type " ~ T.stringof);
        }
    }

    ////////////////////////////////////////////////////////////////////////////
    // properties
    ////////////////////////////////////////////////////////////////////////////

    /**
     * Changes the precision to `p` bits, rounding the current value to nearest.
     *
     * `mpfr_prec_round` is used instead of `mpfr_set_prec` because the latter
     * throws the current value away (it becomes NaN).
     */
    @property void precision(mpfr_prec_t p) {
        mpfr_prec_round(mpfr, p, mpfr_rnd_t.MPFR_RNDN);
    }

    /// Current precision in bits.
    @property mpfr_prec_t precision() const {
        return mpfr_get_prec(mpfr);
    }

    ////////////////////////////////////////////////////////////////////////////
    // Comparisons
    ////////////////////////////////////////////////////////////////////////////

    /// Three-way comparison with a built-in number or an `Mpfr` rvalue,
    /// dispatched to the matching `mpfr_cmp*` function.
    int opCmp(T)(const T value) const if(isNumericValue!T) {
        mixin("return mpfr_cmp" ~ getTypeString!T() ~ "(mpfr, value);");
    }

    /// Three-way comparison with another `Mpfr` lvalue: negative, zero or
    /// positive as `this` is less than, equal to or greater than `value`.
    int opCmp(ref const Mpfr value) const {
        // An object always compares equal to itself.
        if (&this is &value) {
            return 0;
        }
        // Return mpfr_cmp's int directly; folding it into a bool would lose
        // the less/greater distinction.
        return mpfr_cmp(mpfr, value);
    }

    /// Equality with a built-in number or an `Mpfr` rvalue.
    bool opEquals(T)(const T value) const if(isNumericValue!T) {
        return opCmp(value) == 0;
    }

    /// Equality with another `Mpfr` lvalue.
    bool opEquals(ref const Mpfr value) const {
        return opCmp(value) == 0;
    }

    /// MPFR function-name fragment for an arithmetic operator.
    private static string getOperatorString(string op)() {
        final switch(op) {
            case "+": return "_add";
            case "-": return "_sub";
            case "*": return "_mul";
            case "/": return "_div";
            case "^^": return "_pow";
        }
    }

    /// Shifts are implemented as multiplication/division by a power of two.
    private static string getShiftOperatorString(string op)() {
        final switch(op) {
            case "<<": return "_mul";
            case ">>": return "_div";
        }
    }

    /// Suffix of the `mpfr_mul_2*` / `mpfr_div_2*` family for shift amount type `T`.
    private static string getShiftTypeString(T)() {
        static if (isIntegral!T && isSigned!T) {
            return "_2si";
        } else static if (isIntegral!T && !isSigned!T) {
            return "_2ui";
        } else {
            static assert(false, "Unhandled type " ~ T.stringof);
        }
    }

    /// Builds the MPFR function suffix for `op` with an operand of type `T`.
    /// `isRight` places the type suffix first (e.g. `_ui_sub` for `mpfr_ui_sub`).
    private static string getFunctionSuffix(string op, T, bool isRight) () {
        static if(op == "<<" || op == ">>") {
            static assert(!isRight, "Binary Right Shift not allowed, try using lower level mpfr_ui_pow.");
            return getShiftOperatorString!op() ~ getShiftTypeString!T();
        } else {
            return isRight ?
                getTypeString!T() ~ getOperatorString!op():
                getOperatorString!op() ~ getTypeString!T();
        }
    }

    /// Full MPFR function name for `op` with an operand of type `T`.
    private static string getFunction(string op, T, bool isRight) () {
        return "mpfr" ~ getFunctionSuffix!(op, T, isRight);
    }

    ////////////////////////////////////////////////////////////////////////////
    // Arithmetic
    ////////////////////////////////////////////////////////////////////////////

    /// `this op value`; the result has the precision of `this`.
    Mpfr opBinary(string op, T)(const T value) const if(isNumericValue!T) {
        // MPFR offers no mpfr_pow_d, so reject this combination explicitly
        // (same rule as opOpAssign).
        static assert(!(op == "^^" && isFloatingPoint!T), "No operator ^^ with floating point.");
        auto output = Mpfr(0, mpfr_get_prec(mpfr));
        mixin(getFunction!(op, T, false)() ~ "(output, mpfr, value, mpfr_rnd_t.MPFR_RNDN);");
        return output;
    }

    /**
     * `value op this` for a built-in left operand; the result has the precision
     * of `this`. An `Mpfr` left operand is handled by its own `opBinary`, so it is
     * excluded here to avoid competing with it.
     */
    Mpfr opBinaryRight(string op, T)(const T value) const if(isNumeric!T) {
        static if (op == "^^") {
            static assert(!isFloatingPoint!T, "No operator ^^ with floating point.");
            auto output = Mpfr(0, mpfr_get_prec(mpfr));
            static if (isUnsigned!T) {
                mpfr_ui_pow(output, value, mpfr, mpfr_rnd_t.MPFR_RNDN);
            } else {
                // MPFR has no signed-base counterpart of mpfr_ui_pow: promote the
                // base to an Mpfr with 64 bits, which holds any 64-bit integer exactly.
                auto base = Mpfr(value, 64);
                mpfr_pow(output, base, mpfr, mpfr_rnd_t.MPFR_RNDN);
            }
            return output;
        } else static if(op == "-" || op == "/" || op == "<<" || op == ">>") {
            // Non-commutative: call the reversed-operand MPFR variant (e.g. mpfr_ui_sub).
            auto output = Mpfr(0, mpfr_get_prec(mpfr));
            mixin(getFunction!(op, T, true)() ~ "(output, value, mpfr, mpfr_rnd_t.MPFR_RNDN);");
            return output;
        } else {
            // Commutative operators: reuse the left-operand implementation.
            return opBinary!op(value);
        }
    }

    /// Negation; the result has the precision of `this`.
    Mpfr opUnary(string op)() const if(op == "-") {
        auto output = Mpfr(0, mpfr_get_prec(mpfr));
        mpfr_neg(output, mpfr, mpfr_rnd_t.MPFR_RNDN);
        return output;
    }

    ////////////////////////////////////////////////////////////////////////////
    // Mutation
    ////////////////////////////////////////////////////////////////////////////

    /// Assigns `value`, rounded to the current precision (which is unchanged).
    ref Mpfr opAssign(T)(const T value) if(isNumericValue!T) {
        mixin("mpfr_set" ~ getTypeString!T() ~ "(mpfr, value, mpfr_rnd_t.MPFR_RNDN);");
        return this;
    }

    /// Assigns another `Mpfr` lvalue, rounded to the current precision.
    ref Mpfr opAssign(ref const Mpfr value) {
        mpfr_set(mpfr, value, mpfr_rnd_t.MPFR_RNDN);
        return this;
    }

    /// In-place `this op= value` for a built-in number or an `Mpfr` rvalue.
    ref Mpfr opOpAssign(string op, T)(const T value) if(isNumericValue!T) {
        static assert(!(op == "^^" && isFloatingPoint!T), "No operator ^^= with floating point.");
        mixin(getFunction!(op, T, false)() ~ "(mpfr, mpfr, value, mpfr_rnd_t.MPFR_RNDN);");
        return this;
    }

    /// In-place `this op= value` for another `Mpfr` lvalue.
    ref Mpfr opOpAssign(string op)(ref const Mpfr value) {
        // MPFR allows the destination to alias an operand, so `x op= x`
        // (e.g. `x *= x`) is computed normally rather than skipped.
        mixin(getFunction!(op, Mpfr, false)() ~ "(mpfr, mpfr, value, mpfr_rnd_t.MPFR_RNDN);");
        return this;
    }

    ////////////////////////////////////////////////////////////////////////////
    // String
    ////////////////////////////////////////////////////////////////////////////

    /**
     * Formats the value with MPFR's `%Rg` conversion.
     * Throws: `Exception` if `mpfr_snprintf` reports an error.
     */
    string toString() const {
        char[1024] buffer;
        const count = mpfr_snprintf(buffer.ptr, buffer.length, "%Rg".ptr, &mpfr);
        if (count < 0) {
            throw new Exception("mpfr_snprintf failed");
        }
        const n = cast(size_t) count;
        if (n < buffer.length) {
            return buffer[0 .. n].idup;
        }
        // mpfr_snprintf returned the untruncated length: retry with a buffer that
        // fits the text plus its terminating NUL.
        auto large = new char[n + 1];
        mpfr_snprintf(large.ptr, large.length, "%Rg".ptr, &mpfr);
        return large[0 .. n].idup;
    }
}
183 
version (unittest)
{
    import std.meta;
    import std.stdio : writeln, writefln;

    // Shared building blocks; the flattened lists below keep the same order.
    private alias UnsignedIntegrals = AliasSeq!(ubyte, ushort, uint, ulong);
    private alias SignedIntegrals = AliasSeq!(byte, short, int, long);

    /// Built-in numeric types exercised by the tests, plus Mpfr itself.
    alias AllNumericTypes = AliasSeq!(UnsignedIntegrals, float, double, SignedIntegrals, Mpfr);
    /// Integral built-in types plus Mpfr.
    alias AllIntegralTypes = AliasSeq!(UnsignedIntegrals, SignedIntegrals, Mpfr);
    /// Integral built-in types only.
    alias AllIntegralNoMpfr = AliasSeq!(UnsignedIntegrals, SignedIntegrals);
}
192 
unittest {
    // Assignment accepts every built-in numeric type as well as another Mpfr.
    auto target = Mpfr(0);
    target = 1;
    foreach (Numeric; AllNumericTypes) {
        target = Numeric(1);
    }
}
201 
unittest {
    // Assignment and copy construction both produce independent values.
    auto a = Mpfr(0);
    auto b = Mpfr(0);
    a = 2;
    b = a;
    assert(b == 2);
    // Assignment copies the value; later changes to the source do not leak.
    a = 5;
    assert(b == 2);

    // Copy construction goes through the postblit and must not share storage.
    auto c = b;
    assert(c == 2);
    c = 3;
    assert(c == 3);
    assert(b == 2);
}
210 
unittest {
    // Ordering and equality against every numeric type (and Mpfr).
    auto two = Mpfr(0);
    two = 2;
    foreach (N; AllNumericTypes) {
        assert(two == N(2));
        assert(two >= N(1));
        assert(two >= N(2));
        assert(two <= N(2));
        assert(two <= N(3));
    }
}
223 
unittest {
    // opOpAssign with every supported operand type.
    auto value = Mpfr(1);
    assert(value == 1);
    value = value;
    assert(value == 1);
    foreach(T ; AllNumericTypes) {
        value = 2;
        value += T(2);
        assert(value == 4);
    }
    foreach(T ; AllNumericTypes) {
        value = 2;
        value -= T(2);
        assert(value == 0);
    }
    foreach(T ; AllNumericTypes) {
        value = 2;
        value *= T(3);
        assert(value == 6);
    }
    foreach(T ; AllNumericTypes) {
        value = 2;
        value /= T(2);
        assert(value == 1);
    }
    foreach(T ; AllIntegralTypes) {
        value = 2;
        value ^^= T(2);
        assert(value == 4);
    }
    // Use the loop type for the shift amount so each integral overload
    // (mpfr_mul_2si / mpfr_mul_2ui ...) is actually exercised.
    foreach(T ; AllIntegralNoMpfr) {
        value = 2;
        value <<= T(3);
        assert(value == 16);
    }
    foreach(T ; AllIntegralNoMpfr) {
        value = 16;
        value >>= T(3);
        assert(value == 2);
    }
}
266 
unittest {
    // opBinary and opBinaryRight with every supported operand type.
    auto x = Mpfr(0);

    foreach (U; AllNumericTypes) {
        x = 2;
        assert(x + U(2) == 4);
        assert(U(2) + x == 4);
    }

    foreach (U; AllNumericTypes) {
        x = 3;
        assert(x - U(2) == 1);
        x = 2;
        assert(U(3) - x == 1);
    }

    foreach (U; AllNumericTypes) {
        x = 2;
        assert(x * U(3) == 6);
        assert(U(3) * x == 6);
    }

    foreach (U; AllNumericTypes) {
        x = 4;
        assert(x / U(2) == 2);
        x = 2;
        assert(U(6) / x == 3);
    }

    foreach (U; AllIntegralTypes) {
        x = 2;
        assert(x ^^ U(2) == 4);
        assert(U(2) ^^ x == 4);
    }

    foreach (U; AllIntegralNoMpfr) {
        x = 2;
        assert(x << U(3) == 16);
    }

    foreach (U; AllIntegralNoMpfr) {
        x = 16;
        assert(x >> U(3) == 2);
    }
}
306 
unittest {
    // Precision defaults to 32 bits, can be changed, and propagates to results.
    auto x = Mpfr(0);
    assert(x.precision == 32);

    x.precision = 128;
    assert(x.precision == 128);

    const sum = x + 1;
    assert(sum.precision == 128);
}