Source code for xir.decimal_complex
"""Module containing the DecimalComplex class which stores complex numbers with
``decimal.Decimal`` precision."""
from __future__ import annotations
import functools
from decimal import Decimal
from numbers import Complex
from typing import Callable, Union
def convert_input(func: Callable) -> Callable:
"""Converts inputs into valid ``DecimalComplex`` objects."""
@functools.wraps(func)
def convert_wrapper(self, c: Union[Decimal, Complex]) -> DecimalComplex:
if isinstance(c, DecimalComplex):
return func(self, c)
if isinstance(c, Decimal):
c = DecimalComplex(c)
return func(self, c)
if isinstance(c, Complex):
c = DecimalComplex(str(c.real), str(c.imag))
return func(self, c)
raise TypeError("Must be of type numbers.Complex.")
return convert_wrapper
[docs]class DecimalComplex(Complex):
"""Complex numbers represented by precision ``decimal.Decimal`` terms.
Args:
real (str, Decimal): the real part of the complex number. Defaults to
``0.0`` if not input.
imag (str, Decimal): The imaginary part of the complex number. Defaults
to ``0.0`` if not input.
"""
def __init__(
self, real: Union[str, Decimal] = "0.0", imag: Union[str, Decimal] = "0.0"
) -> None:
self._real = Decimal(real)
self._imag = Decimal(imag)
@convert_input
def __add__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
real = self.real + c.real
imag = self.imag + c.imag
return DecimalComplex(real, imag)
def __radd__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
return self + c
@convert_input
def __sub__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
real = self.real - c.real
imag = self.imag - c.imag
return DecimalComplex(real, imag)
def __rsub__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
return -(self - c)
@convert_input
def __mul__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
real = self.real * c.real - self.imag * c.imag
imag = self.imag * c.real + self.real * c.imag
return DecimalComplex(real, imag)
def __rmul__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
return self * c
# pylint: disable=missing-function-docstring
@convert_input
def __div__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
real = (self.real * c.real + self.imag * c.imag) / (c.real**2 + c.imag**2)
imag = (self.imag * c.real - self.real * c.imag) / (c.real**2 + c.imag**2)
return DecimalComplex(real, imag)
# pylint: disable=missing-function-docstring
@convert_input
def __rdiv__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
return c / self
@convert_input
def __truediv__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
return self.__div__(c)
@convert_input
def __rtruediv__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
return c / self
# pylint: disable=missing-function-docstring
def __floordiv__(self, _: Union[Decimal, Complex]) -> DecimalComplex:
raise TypeError("Cannot take floor of DecimalComplex")
def __abs__(self) -> Decimal:
return Decimal(self.real**2 + self.imag**2).sqrt()
def __pos__(self) -> DecimalComplex:
return self
def __neg__(self) -> DecimalComplex:
return DecimalComplex(-self.real, -self.imag)
@convert_input
def __eq__(self, c: DecimalComplex) -> bool:
return self.real == c.real and self.imag == c.imag
@convert_input
def __ne__(self, c: DecimalComplex) -> bool:
return not self.__eq__(c)
def __str__(self) -> str:
return f"{self.real}{self.imag:+}j"
def __repr__(self):
return f"DecimalComplex('{self.__str__()}')"
def __pow__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
# TODO: calculate powers with precision using Decimal
res = complex(self) ** complex(c)
return DecimalComplex(str(res.real), str(res.imag))
def __rpow__(self, c: Union[Decimal, Complex]) -> DecimalComplex:
# TODO: calculate powers with precision using Decimal
res = complex(c) ** complex(self)
return DecimalComplex(str(res.real), str(res.imag))
def __gt__(self, _: Union[Decimal, Complex]):
self._not_implemented(">")
def __lt__(self, _: Union[Decimal, Complex]):
self._not_implemented("<")
def __ge__(self, _: Union[Decimal, Complex]):
self._not_implemented(">=")
def __le__(self, _: Union[Decimal, Complex]):
self._not_implemented("<=")
def __bool__(self) -> bool:
return self.real != 0 or self.imag != 0
def __complex__(self) -> complex:
return float(self.real) + float(self.imag) * 1j
def _not_implemented(self, obs):
"""Raises a TypeError for the given (unimplemented) operation."""
raise TypeError(f"Cannot use {obs} with complex numbers")
def __hash__(self):
return hash((self.real, self.imag))
@property
def real(self):
"""Return the real part of the complex number."""
return self._real
@property
def imag(self):
"""Return the imaginary part of the complex number."""
return self._imag
[docs] def conjugate(self):
"""Return the complex conjugate of its argument."""
return DecimalComplex(self.real, -self.imag)
_modules/xir/decimal_complex
Download Python script
Download Notebook
View on GitHub