python源码剖析No.5--大整数运算

0x00

上节我们分析了python实现大整数的方式(切分数值的绝对值存放到数组),如此一来对于大整数的运算会相应的变得复杂,本节探讨的就是python大整数运算的处理。

更多大整数实现细节,读者可以查看上节文章 -> Click me


数值运算

先前讨论python对象的相关内容,知道对象的行为由对象的类型决定的。因此,我们也将从整数类型对象中考察这个问题,整数类型对象对应PyLong_Type,定义位置/Objects/longobject.c

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
PyTypeObject PyLong_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
"int", /* tp_name */
offsetof(PyLongObject, ob_digit), /* tp_basicsize */
sizeof(digit), /* tp_itemsize */
long_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
long_to_decimal_string, /* tp_repr */
&long_as_number, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
(hashfunc)long_hash, /* tp_hash */
0, /* tp_call */
long_to_decimal_string, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
Py_TPFLAGS_LONG_SUBCLASS, /* tp_flags */
long_doc, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
long_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
long_methods, /* tp_methods */
0, /* tp_members */
long_getset, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
long_new, /* tp_new */
PyObject_Del, /* tp_free */
};

可以看到数值型操作的PyLong_Type.tp_as_number字段不为空,说明整数对象支持数值型操作。

PyLong_Type.tp_as_sequence字段以及PyLong_Type.tp_as_mapping字段为空,说明整数对象不支持这两个类型的操作。

跟进数值型操作PyLong_Type.tp_as_number字段的long_as_number,同文件下发现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
static PyNumberMethods long_as_number = {
(binaryfunc)long_add, /*nb_add*/
(binaryfunc)long_sub, /*nb_subtract*/
(binaryfunc)long_mul, /*nb_multiply*/
long_mod, /*nb_remainder*/
long_divmod, /*nb_divmod*/
long_pow, /*nb_power*/
(unaryfunc)long_neg, /*nb_negative*/
(unaryfunc)long_long, /*tp_positive*/
(unaryfunc)long_abs, /*tp_absolute*/
(inquiry)long_bool, /*tp_bool*/
(unaryfunc)long_invert, /*nb_invert*/
long_lshift, /*nb_lshift*/
(binaryfunc)long_rshift, /*nb_rshift*/
long_and, /*nb_and*/
long_xor, /*nb_xor*/
long_or, /*nb_or*/
long_long, /*nb_int*/
0, /*nb_reserved*/
long_float, /*nb_float*/
0, /* nb_inplace_add */
0, /* nb_inplace_subtract */
0, /* nb_inplace_multiply */
0, /* nb_inplace_remainder */
0, /* nb_inplace_power */
0, /* nb_inplace_lshift */
0, /* nb_inplace_rshift */
0, /* nb_inplace_and */
0, /* nb_inplace_xor */
0, /* nb_inplace_or */
long_div, /* nb_floor_divide */
long_true_divide, /* nb_true_divide */
0, /* nb_inplace_floor_divide */
0, /* nb_inplace_true_divide */
long_long, /* nb_index */
};

不为空的字段说明整数对象提供了相应的操作函数。例如,nb_addnb_subtract..等,现在我们可以结合前面的指示画出:


数值加法–nb_add

整数对象的数值加法(nb_add)特化到整数对象下是long_add,代码位置/Objects/longobject.c

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
static PyObject *
long_add(PyLongObject *a, PyLongObject *b)
{
PyLongObject *z;

CHECK_BINOP(a, b);

if (Py_ABS(Py_SIZE(a)) <= 1 && Py_ABS(Py_SIZE(b)) <= 1) {
return PyLong_FromLong(MEDIUM_VALUE(a) + MEDIUM_VALUE(b));
}
if (Py_SIZE(a) < 0) {
if (Py_SIZE(b) < 0) {
z = x_add(a, b);
if (z != NULL) {
/* x_add received at least one multiple-digit int,
and thus z must be a multiple-digit int.
That also means z is not an element of
small_ints, so negating it in-place is safe. */
assert(Py_REFCNT(z) == 1);
Py_SIZE(z) = -(Py_SIZE(z));
}
}
else
z = x_sub(b, a);
}
else {
if (Py_SIZE(b) < 0)
z = x_sub(a, b);
else
z = x_add(a, b);
}
return (PyObject *)z;
}
  • 第六行做了操作数的检查(大概是long_add只对整数对象进行加法,其他类型对象不适用)

    1
    2
    3
    4
    5
    #define CHECK_BINOP(v,w)                                \
    do { \
    if (!PyLong_Check(v) || !PyLong_Check(w)) \
    Py_RETURN_NOTIMPLEMENTED; \
    } while(0)
  • 上述代码中频繁出现的Py_Size是个宏定义,位置在同文件下:

    #define Py_SIZE(ob) (((PyVarObject*)(ob))->ob_size) 前面分析过ob_size字段存放的是用于记录digit数组中元素的个数

  • MEDIUN_VALUE的定义也在同文件下:

    1
    2
    3
    4
    5
    /* convert a PyLong of size 1, 0 or -1 to an sdigit */
    #define MEDIUM_VALUE(x) (assert(-1 <= Py_SIZE(x) && Py_SIZE(x) <= 1), \
    Py_SIZE(x) < 0 ? -(sdigit)(x)->ob_digit[0] : \
    (Py_SIZE(x) == 0 ? (sdigit)0 : \
    (sdigit)(x)->ob_digit[0]))
    1
    2
    3
    4
    #if PYLONG_BITS_IN_DIGIT == 30
    typedef int32_t sdigit; /* signed variant of digit */
    #elif PYLONG_BITS_IN_DIGIT == 15
    typedef short sdigit;

    可以看到,sdigit只是C语言下的int32_tshort的别名,具体是哪个得参考PYLONG_BITS_IN_DIGIT的内容

    到这里,就清楚MEDIUM_VALUE完成的是当指示digit数组长度的ob_size为1、0、-1时,此时仅仅使用C语言就足够运算了,因此转为C语言整数进行运算,运算结果使用PyLong_FromLong转为python对象

  • 第11-22行,当a、b皆为负数(ob_size的符号表示正负。)时,使用x_add,进行运算。

  • 第11行+第23-25行,当a为负数,b为正数,使用x_sub进行运算

  • 第11行+第26-28行,当a为正数,b为负数,使用x_sub进行运算

  • 第11行+第29-31行,当a为正数,b为正数,使用x_add进行运算

可以看到,根据a、b符号进行了不同的操作,但最终回到两个关键的函数x_add & x_sub,那么接着探讨这两个函数的处理过程


x_add

x_add定义位置在同文件下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/* Add the absolute values of two integers. */
static PyLongObject *
x_add(PyLongObject *a, PyLongObject *b)
{
//记录二者长度的绝对值,代表了数值大小
Py_ssize_t size_a = Py_ABS(Py_SIZE(a)), size_b = Py_ABS(Py_SIZE(b));
PyLongObject *z;
Py_ssize_t i;
digit carry = 0;

/* Ensure a is the larger of the two: */
if (size_a < size_b) {
{ PyLongObject *temp = a; a = b; b = temp; }
{ Py_ssize_t size_temp = size_a;
size_a = size_b;
size_b = size_temp; }
}
z = _PyLong_New(size_a+1);
if (z == NULL)
return NULL;
for (i = 0; i < size_b; ++i) {
carry += a->ob_digit[i] + b->ob_digit[i];
z->ob_digit[i] = carry & PyLong_MASK;
carry >>= PyLong_SHIFT;
}
for (; i < size_a; ++i) {
carry += a->ob_digit[i];
z->ob_digit[i] = carry & PyLong_MASK;
carry >>= PyLong_SHIFT;
}
z->ob_digit[i] = carry;
return long_normalize(z);
}
  • 第1行,注释中表明该函数完成的是两个整数对象绝对值相加

  • 第5行,比较a、b整数对象存放数值的digit数组的长度(忽略正负),分别存放到size_a size_b

  • 第11-16行,对比size_a size_b,保证a为二者长度较大的一个,否则交换a、b对象

  • 第17行,_PyLong_New创建新的整数对象z,digit数组长度取a+1(用于潜在进位保留)

  • 第21-24行,ob_digit每一个元素表示的范围2 ** 30当超过则会溢出,因此carry保存的进位信息

    这里需要解决22行中z->ob_digit[i] = carry & PyLong_MASK;,跟踪PyLong_MASK在同文件下找到定义

    1
    2
    3
    4
    5
    #define PyLong_MASK     ((digit)(PyLong_BASE - 1))
    //跟进PyLong_BASE,发现
    #define PyLong_BASE ((digit)1 << PyLong_SHIFT)
    //再跟进PyLong_SHIFT,发现
    #define PyLong_SHIFT 15

    大致完成:第21行carry += a->ob_digit[i] + b->ob_digit[i]并做了相同的移位操作,第22行位置z->ob_digit[i]中保留非进位部分,第23行,将进位carry右移参与下一个digit数组的运算

    注:PyLong_MASK将1进行左移PyLong_SHIFT位(15)得到 1000 0000 0000 0000,将这个值-1,得到PyLong_MASK = 0111 1111 1111 1111

  • 第25-29行,这块处理的是a中比b高出的部分。,行为同第21-24行一致。

  • 第30行,将z对象digit数组最高位保存进位。

  • 第31行,调用long_normalize对z对象进行处理。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    static PyLongObject *
    long_normalize(PyLongObject *v)
    {
    Py_ssize_t j = Py_ABS(Py_SIZE(v));
    Py_ssize_t i = j;

    while (i > 0 && v->ob_digit[i-1] == 0)
    --i;
    if (i != j)
    Py_SIZE(v) = (Py_SIZE(v) < 0) ? -(i) : i;
    return v;
    }

    取出对象高位为0 的部分,并完成ob_size的设置

实例演示过程

x_sub

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
static PyLongObject *
x_sub(PyLongObject *a, PyLongObject *b)
{
Py_ssize_t size_a = Py_ABS(Py_SIZE(a)), size_b = Py_ABS(Py_SIZE(b));
PyLongObject *z;
Py_ssize_t i;
int sign = 1;
digit borrow = 0;

/* Ensure a is the larger of the two: */
if (size_a < size_b) {
sign = -1;
{ PyLongObject *temp = a; a = b; b = temp; }
{ Py_ssize_t size_temp = size_a;
size_a = size_b;
size_b = size_temp; }
}
else if (size_a == size_b) {
/* Find highest digit where a and b differ: */
i = size_a;
while (--i >= 0 && a->ob_digit[i] == b->ob_digit[i])
;
if (i < 0)
return (PyLongObject *)PyLong_FromLong(0);
if (a->ob_digit[i] < b->ob_digit[i]) {
sign = -1;
{ PyLongObject *temp = a; a = b; b = temp; }
}
size_a = size_b = i+1;
}
z = _PyLong_New(size_a);
if (z == NULL)
return NULL;
for (i = 0; i < size_b; ++i) {
/* The following assumes unsigned arithmetic
works module 2**N for some N>PyLong_SHIFT. */
borrow = a->ob_digit[i] - b->ob_digit[i] - borrow;
z->ob_digit[i] = borrow & PyLong_MASK;
borrow >>= PyLong_SHIFT;
borrow &= 1; /* Keep only one sign bit */
}
for (; i < size_a; ++i) {
borrow = a->ob_digit[i] - borrow;
z->ob_digit[i] = borrow & PyLong_MASK;
borrow >>= PyLong_SHIFT;
borrow &= 1; /* Keep only one sign bit */
}
assert(borrow == 0);
if (sign < 0) {
Py_SIZE(z) = -Py_SIZE(z);
}
return long_normalize(z);
}

经过x_add,相信读者已经可以开始自己尝试分析x_sub了。

Author: Victory+
Link: https://cvjark.github.io/2022/05/17/python源码剖析-大整数运算/
Copyright Notice: All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.