Условие
В IT-отдел привезли экспериментальное устройство для автоматизации расчетов на производстве. Оно работает на урезанном интерпретаторе Python: никаких условий, сравнений или встроенных функций — только арифметика и битовые операции. Знаки сравнения (>, < == и другие) использовать не получится, так как интерпретатор их не поймет и выдаст ошибку. Однако без них писать код довольно сложно. Чтобы помочь коллегам, программист решил реализовать базовую логику выбора большего из двух чисел a и b.
Задача
Как бы вы справились задачей на месте программиста?
Есть два числа: a и b. Найдите наибольшее из них, используя только сложение, вычитание, деление и умножение, а также битовые операции.
Нельзя использовать операторы сравнения (>, <, ==, != и т. д.), тернарный оператор, функции вроде max(), min() и прочее.
Решение
Узнать, какое число больше, можно на основе знаков этих чисел:
- если a — положительное, а число b — отрицательное, то число a больше;
- аналогично, если a — отрицательное, а число b — положительное, то b больше.
Если же оба числа имеют одинаковый знак, можно посмотреть на их разность:
- если a – b > 0, значит первое число больше;
- если a – b < 0, значит второе число больше.
Первый случай: когда знаки одинаковые
Для начала рассмотрим пример, когда оба числа имеют одинаковые знаки. По условию нельзя использовать операторы сравнения, но можно посмотреть на старший бит числа, который указывает знак.
Для этого напишем функцию:
def sign(a: int, b: int) -> int:
result = a - b #считаем разность
sign_bit = (result >> 31) & 1 #сдвигаем на 31 бит, чтобы получить знаковый
return sign_bit
int k = sign(a, b); #записываем результат в переменную k
Самая важная строка:
sign_bit = (result >> 31) & 1
Разберем ее подробнее.
- Предполагается, что для хранения int используется 4 байта или 32 бита. Значит, знак разности хранится в 31 бите, так как нумерация начинается с нуля.
- Оператор >> сдвигает биты вправо 31 раз. Например, вместо 10010011 00001111 0100011 00001111 будет 00000000…00000001. Оператор & 1 берет самый младший бит, то есть самый правый.
- В итоге sign_bit (а также переменная k) содержит 1, если a – b < 0, и содержит 0, если a – b > 0.
На основе этого можем определить, какое из чисел больше:
def getMax(a: int, b: int) -> int:
k = sign(a, b) #записываем результат в переменную k
return a * (1 - k) + b * k
В строке a * (1 – k) + b * k используем переменную k как переключатель, чтобы вернуть нужное значение.
Второй случай: когда знаки разные
Теперь рассмотрим вариант, когда a и b имеют разные знаки. Узнать знак числа мы можем с помощью функции sign(), которую написали выше. Нам даже необязательно выяснять знак числа b — достаточно узнать, положительным или отрицательным является a:
sign_a = sign(a)
return a * (1 - sign_a) + b * sign_a
Если sign(a) вернет 1 (то есть число a — отрицательное), значит мы вернем число b:
return a * (1 - 1) + b * 1
Если же sign(a) вернет 0 (то есть число a — положительное), значит мы также вернем число a:
return a * (1 - 0) + b * 0
Объединяем оба случая
Осталось только объединить оба случая. Для этого:
#находим знаки переменных, а также знак разности a - b
sign_a = sign(a)
sign_b = sign(b)
sign_difference = sign(a - b)
#используем оператор XOR, чтобы узнать одинаковые знаки чисел a и b или нет
use_sign_of_a = sign_a ^ sign_b
use_sign_of_diff = 1 ^ use_sign_of_a
k = use_sign_of_a * sign_a + use_sign_of_diff * sign_difference
return a * (1 - k) + b * k
Разберем код:
- Находим знаки всех переменных.
- Чтобы определить, являются ли a и b переменными с одинаковыми знаками, используем XOR.
- Используем use_sign_of_a, чтобы определить значение use_sign_of_diff.
- Вычисляем значение k.
Также нужно переделать функцию sign():
def sign(v: int) -> int:
sign_bit = (v >> 31) & 1
return sign_bit
В итоге получаем такой код:
def sign(v: int) -> int:
sign_bit = (v >> 31) & 1
return sign_bit
def flip(bit: int) -> int:
return 1 ^ bit
def getMax(a: int, b: int) -> int:
sign_a = sign(a)
sign_b = sign(b)
sign_difference = sign(a - b)
use_sign_of_a = sign_a ^ sign_b
use_sign_of_diff = flip(use_sign_of_a)
k = use_sign_of_a * sign_a + use_sign_of_diff * sign_difference
return a * (1 - k) + b * k
print(getMax(10, 5))