aboutsummaryrefslogtreecommitdiff
path: root/py/mpz.c
diff options
context:
space:
mode:
author Doug Currie <github.9.eeeeeee@spamgourmet.com> 2016-01-30 22:35:58 -0500
committer Damien George <damien.p.george@gmail.com> 2016-02-03 22:13:39 +0000
commit 2e2e15cec2f85ece763f3f80152d759aecfad47c (patch)
tree fc70fb1001ea72640bdb5f5b49c6951a65410817 /py/mpz.c
parent 5f3e005b6791634b104fa6385c8a9bf5ed1af164 (diff)
py/mpz: Complete implementation of mpz_{and,or,xor} for negative args.
For these 3 bitwise operations there are now fast functions for positive-only arguments, and general functions for arbitrary sign arguments (the fast functions are the existing implementation). By default the fast functions are not used (to save space) and instead the general functions are used for all operations. Enable MICROPY_OPT_MPZ_BITWISE to use the fast functions for positive arguments.
Diffstat (limited to 'py/mpz.c')
-rw-r--r-- py/mpz.c 267
1 files changed, 199 insertions, 68 deletions
diff --git a/py/mpz.c b/py/mpz.c
index b3f8b15b6..f02b75c2b 100644
--- a/py/mpz.c
+++ b/py/mpz.c
@@ -29,9 +29,6 @@
#include "py/mpz.h"
-// this is only needed for mp_not_implemented, which should eventually be removed
-#include "py/runtime.h"
-
#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
#define DIG_SIZE (MPZ_DIG_SIZE)
@@ -199,6 +196,14 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
return idig + 1 - oidig;
}
+STATIC mp_uint_t mpn_remove_trailing_zeros(mpz_dig_t *oidig, mpz_dig_t *idig) {
+ for (--idig; idig >= oidig && *idig == 0; --idig) {
+ }
+ return idig + 1 - oidig;
+}
+
+#if MICROPY_OPT_MPZ_BITWISE
+
/* computes i = j & k
returns number of digits in i
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen (jlen argument not needed)
@@ -211,41 +216,46 @@ STATIC mp_uint_t mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, const mpz_dig_t
*idig = *jdig & *kdig;
}
- // remove trailing zeros
- for (--idig; idig >= oidig && *idig == 0; --idig) {
- }
-
- return idig + 1 - oidig;
+ return mpn_remove_trailing_zeros(oidig, idig);
}
-/* computes i = j & -k = j & (~k + 1)
+#endif
+
+/* i = -((-j) & (-k)) = ~((~j + 1) & (~k + 1)) + 1
+ i = (j & (-k)) = (j & (~k + 1)) = ( j & (~k + 1))
+ i = ((-j) & k) = ((~j + 1) & k) = ((~j + 1) & k )
+ computes general form:
+ i = (im ^ (((j ^ jm) + jc) & ((k ^ km) + kc))) + ic where Xm = Xc == 0 ? 0 : DIG_MASK
returns number of digits in i
- assumes enough memory in i; assumes normalised j, k
+ assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
can have i, j, k pointing to same memory
*/
-STATIC mp_uint_t mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen) {
+STATIC mp_uint_t mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+ mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
mpz_dig_t *oidig = idig;
- mpz_dbl_dig_t carry = 1;
+ mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
+ mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
+ mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
- for (; jlen > 0 && klen > 0; --jlen, --klen, ++idig, ++jdig, ++kdig) {
- carry += *kdig ^ DIG_MASK;
- *idig = (*jdig & carry) & DIG_MASK;
- carry >>= DIG_SIZE;
+ for (; jlen > 0; ++idig, ++jdig) {
+ carryj += *jdig ^ jmask;
+ carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
+ carryi += ((carryj & carryk) ^ imask) & DIG_MASK;
+ *idig = carryi & DIG_MASK;
+ carryk >>= DIG_SIZE;
+ carryj >>= DIG_SIZE;
+ carryi >>= DIG_SIZE;
}
- for (; jlen > 0; --jlen, ++idig, ++jdig) {
- carry += DIG_MASK;
- *idig = (*jdig & carry) & DIG_MASK;
- carry >>= DIG_SIZE;
- }
-
- // remove trailing zeros
- for (--idig; idig >= oidig && *idig == 0; --idig) {
+ if (0 != carryi) {
+ *idig++ = carryi;
}
- return idig + 1 - oidig;
+ return mpn_remove_trailing_zeros(oidig, idig);
}
+#if MICROPY_OPT_MPZ_BITWISE
+
/* computes i = j | k
returns number of digits in i
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
@@ -267,6 +277,74 @@ STATIC mp_uint_t mpn_or(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
return idig - oidig;
}
+#endif
+
+/* i = -((-j) | (-k)) = ~((~j + 1) | (~k + 1)) + 1
+ i = -(j | (-k)) = -(j | (~k + 1)) = ~( j | (~k + 1)) + 1
+ i = -((-j) | k) = -((~j + 1) | k) = ~((~j + 1) | k ) + 1
+ computes general form:
+ i = ~(((j ^ jm) + jc) | ((k ^ km) + kc)) + 1 where Xm = Xc == 0 ? 0 : DIG_MASK
+ returns number of digits in i
+ assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
+ can have i, j, k pointing to same memory
+*/
+
+#if MICROPY_OPT_MPZ_BITWISE
+
+STATIC mp_uint_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+ mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
+ mpz_dig_t *oidig = idig;
+ mpz_dbl_dig_t carryi = 1;
+ mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
+ mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
+
+ for (; jlen > 0; ++idig, ++jdig) {
+ carryj += *jdig ^ jmask;
+ carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
+ carryi += ((carryj | carryk) ^ DIG_MASK) & DIG_MASK;
+ *idig = carryi & DIG_MASK;
+ carryk >>= DIG_SIZE;
+ carryj >>= DIG_SIZE;
+ carryi >>= DIG_SIZE;
+ }
+
+ if (0 != carryi) {
+ *idig++ = carryi;
+ }
+
+ return mpn_remove_trailing_zeros(oidig, idig);
+}
+
+#else
+
+STATIC mp_uint_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+ mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
+ mpz_dig_t *oidig = idig;
+ mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
+ mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
+ mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
+
+ for (; jlen > 0; ++idig, ++jdig) {
+ carryj += *jdig ^ jmask;
+ carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
+ carryi += ((carryj | carryk) ^ imask) & DIG_MASK;
+ *idig = carryi & DIG_MASK;
+ carryk >>= DIG_SIZE;
+ carryj >>= DIG_SIZE;
+ carryi >>= DIG_SIZE;
+ }
+
+ if (0 != carryi) {
+ *idig++ = carryi;
+ }
+
+ return mpn_remove_trailing_zeros(oidig, idig);
+}
+
+#endif
+
+#if MICROPY_OPT_MPZ_BITWISE
+
/* computes i = j ^ k
returns number of digits in i
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
@@ -285,11 +363,39 @@ STATIC mp_uint_t mpn_xor(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
*idig = *jdig;
}
- // remove trailing zeros
- for (--idig; idig >= oidig && *idig == 0; --idig) {
+ return mpn_remove_trailing_zeros(oidig, idig);
+}
+
+#endif
+
+/* i = (-j) ^ (-k) = ~(j - 1) ^ ~(k - 1) = (j - 1) ^ (k - 1)
+ i = -(j ^ (-k)) = -(j ^ ~(k - 1)) = ~(j ^ ~(k - 1)) + 1 = (j ^ (k - 1)) + 1
+ i = -((-j) ^ k) = -(~(j - 1) ^ k) = ~(~(j - 1) ^ k) + 1 = ((j - 1) ^ k) + 1
+ computes general form:
+ i = ((j - 1 + jc) ^ (k - 1 + kc)) + ic
+ returns number of digits in i
+ assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
+ can have i, j, k pointing to same memory
+*/
+STATIC mp_uint_t mpn_xor_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+ mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
+ mpz_dig_t *oidig = idig;
+
+ for (; jlen > 0; ++idig, ++jdig) {
+ carryj += *jdig + DIG_MASK;
+ carryk += (--klen <= --jlen) ? (*kdig++ + DIG_MASK) : DIG_MASK;
+ carryi += (carryj ^ carryk) & DIG_MASK;
+ *idig = carryi & DIG_MASK;
+ carryk >>= DIG_SIZE;
+ carryj >>= DIG_SIZE;
+ carryi >>= DIG_SIZE;
}
- return idig + 1 - oidig;
+ if (0 != carryi) {
+ *idig++ = carryi;
+ }
+
+ return mpn_remove_trailing_zeros(oidig, idig);
}
/* computes i = i * d1 + d2
@@ -1097,81 +1203,106 @@ void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
can have dest, lhs, rhs the same
*/
void mpz_and_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
- if (lhs->neg == rhs->neg) {
- if (lhs->neg == 0) {
- // make sure lhs has the most digits
- if (lhs->len < rhs->len) {
- const mpz_t *temp = lhs;
- lhs = rhs;
- rhs = temp;
- }
- // do the and'ing
- mpz_need_dig(dest, rhs->len);
- dest->len = mpn_and(dest->dig, lhs->dig, rhs->dig, rhs->len);
- dest->neg = 0;
- } else {
- // TODO both args are negative
- mp_not_implemented("bignum and with negative args");
- }
+ // make sure lhs has the most digits
+ if (lhs->len < rhs->len) {
+ const mpz_t *temp = lhs;
+ lhs = rhs;
+ rhs = temp;
+ }
+
+ #if MICROPY_OPT_MPZ_BITWISE
+
+ if ((0 == lhs->neg) && (0 == rhs->neg)) {
+ mpz_need_dig(dest, lhs->len);
+ dest->len = mpn_and(dest->dig, lhs->dig, rhs->dig, rhs->len);
+ dest->neg = 0;
} else {
- // args have different sign
- // make sure lhs is the positive arg
- if (rhs->neg == 0) {
- const mpz_t *temp = lhs;
- lhs = rhs;
- rhs = temp;
- }
mpz_need_dig(dest, lhs->len + 1);
- dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
- assert(dest->len <= dest->alloc);
- dest->neg = 0;
+ dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
+ lhs->neg == rhs->neg, 0 != lhs->neg, 0 != rhs->neg);
+ dest->neg = lhs->neg & rhs->neg;
}
+
+ #else
+
+ mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
+ dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
+ (lhs->neg == rhs->neg) ? lhs->neg : 0, lhs->neg, rhs->neg);
+ dest->neg = lhs->neg & rhs->neg;
+
+ #endif
}
/* computes dest = lhs | rhs
can have dest, lhs, rhs the same
*/
void mpz_or_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
- if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
+ // make sure lhs has the most digits
+ if (lhs->len < rhs->len) {
const mpz_t *temp = lhs;
lhs = rhs;
rhs = temp;
}
- if (lhs->neg == rhs->neg) {
+ #if MICROPY_OPT_MPZ_BITWISE
+
+ if ((0 == lhs->neg) && (0 == rhs->neg)) {
mpz_need_dig(dest, lhs->len);
dest->len = mpn_or(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
+ dest->neg = 0;
} else {
- mpz_need_dig(dest, lhs->len);
- // TODO
- mp_not_implemented("bignum or with negative args");
-// dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
+ mpz_need_dig(dest, lhs->len + 1);
+ dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
+ 0 != lhs->neg, 0 != rhs->neg);
+ dest->neg = 1;
}
- dest->neg = lhs->neg;
+ #else
+
+ mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
+ dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
+ (lhs->neg || rhs->neg), lhs->neg, rhs->neg);
+ dest->neg = lhs->neg | rhs->neg;
+
+ #endif
}
/* computes dest = lhs ^ rhs
can have dest, lhs, rhs the same
*/
void mpz_xor_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
- if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
+ // make sure lhs has the most digits
+ if (lhs->len < rhs->len) {
const mpz_t *temp = lhs;
lhs = rhs;
rhs = temp;
}
+ #if MICROPY_OPT_MPZ_BITWISE
+
if (lhs->neg == rhs->neg) {
mpz_need_dig(dest, lhs->len);
- dest->len = mpn_xor(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
+ if (lhs->neg == 0) {
+ dest->len = mpn_xor(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
+ } else {
+ dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len, 0, 0, 0);
+ }
+ dest->neg = 0;
} else {
- mpz_need_dig(dest, lhs->len);
- // TODO
- mp_not_implemented("bignum xor with negative args");
-// dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
+ mpz_need_dig(dest, lhs->len + 1);
+ dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len, 1,
+ 0 == lhs->neg, 0 == rhs->neg);
+ dest->neg = 1;
}
- dest->neg = 0;
+ #else
+
+ mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
+ dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
+ (lhs->neg != rhs->neg), 0 == lhs->neg, 0 == rhs->neg);
+ dest->neg = lhs->neg ^ rhs->neg;
+
+ #endif
}
/* computes dest = lhs * rhs