#include "../include/rational.hpp"
// #include "bigint.cpp"
#include <cmath>
#include <stdexcept>

// Number of series terms summed by ln() and the exponent n used by exp()'s
// (1 + x/n)^n approximation. Larger values are more accurate but slower.
uint64_t CRational::termPrecision{ 4 };
CRational::CRational( CBigInt *num, CBigInt *den )
: m_numerator(num), m_denominator(den) {}
CRational::CRational( CBigInt *num )
: m_numerator(num), m_denominator( new CBigInt(1) ) {}
// Deep copy: numerator and denominator of `parent` are each duplicated into
// freshly owned CBigInt objects, so the copy shares no state with `parent`.
CRational::CRational( const CRational &parent )
    : m_numerator( std::make_unique<CBigInt>( *parent.m_numerator ) ),
      m_denominator( std::make_unique<CBigInt>( *parent.m_denominator ) ) {}
// Both parts are held by std::unique_ptr members, which release them automatically.
CRational::~CRational() = default;

// Reduces this rational in place (via simplify()), then renders it as
// "<numerator> / <denominator>" using CBigInt::toString for each part.
std::string CRational::toString() {
    simplify();
    std::string text = m_numerator->toString();
    text += " / ";
    text += m_denominator->toString();
    return text;
}

// Validates and reduces this rational, then returns a new heap-allocated copy
// of it; the caller owns the returned pointer.
// Throws std::logic_error("division by zero") if the denominator is zero.
CRational *CRational::evaluate() {
    // Check before simplify(): reducing 0/0 would divide both parts by
    // gcd(0, 0), which is presumably zero.
    if (m_denominator->isZero())
        throw std::logic_error("division by zero");
    simplify();
    return new CRational( *this );
}

// Returns (*this + *oth) as a new, simplified, heap-allocated rational owned
// by the caller. Side effect: setSameDenominator() rescales both *this and
// *oth in place onto a common denominator (their values do not change).
CRational *CRational::add( CRational *oth ) {
    setSameDenominator( oth );

    CBigInt *sumNumerator = m_numerator->add( oth->m_numerator.get() );
    CBigInt *sharedDenominator = new CBigInt( *m_denominator );

    CRational *result = new CRational( sumNumerator, sharedDenominator );
    result->simplify();
    return result;
}

// Returns (*this - *oth) as a new, simplified, heap-allocated rational owned
// by the caller. Like add(), it first rescales *this and *oth in place onto a
// shared denominator (values unchanged).
CRational *CRational::sub( CRational *oth ) {
    setSameDenominator( oth );

    CBigInt *differenceNumerator = m_numerator->sub( oth->m_numerator.get() );
    auto *difference = new CRational( differenceNumerator, new CBigInt( *m_denominator ) );
    difference->simplify();
    return difference;
}

// Returns (*this * *oth) = (a*c)/(b*d) as a new, simplified, heap-allocated
// rational owned by the caller. Neither operand is modified.
CRational *CRational::mul( CRational *oth ) {
    CBigInt *productNumerator = m_numerator->mul( oth->m_numerator.get() );
    CBigInt *productDenominator = m_denominator->mul( oth->m_denominator.get() );

    CRational *product = new CRational( productNumerator, productDenominator );
    product->simplify();
    return product;
}

// Returns (*this / *oth) as a new, simplified, heap-allocated rational owned by
// the caller, computed as *this multiplied by the reciprocal of *oth.
// No zero check is made here: dividing by zero yields a zero denominator.
CRational *CRational::div( CRational *oth ) {
    // Reciprocal d/c of oth = c/d, built from independent copies of its parts.
    CRational reciprocal( new CBigInt( *oth->m_denominator ),
                          new CBigInt( *oth->m_numerator ) );

    CRational *quotient = mul( &reciprocal );
    quotient->simplify();
    return quotient;
}

// Returns (a/b)^pow = a^pow / b^pow as a new, simplified, heap-allocated
// rational owned by the caller.
CRational *CRational::power( uint64_t pow ) {
    CBigInt *raisedNumerator = m_numerator->power( pow );
    CBigInt *raisedDenominator = m_denominator->power( pow );

    auto *raised = new CRational( raisedNumerator, raisedDenominator );
    raised->simplify();
    return raised;
}

// Returns x^pow / pow, where x is *this, as a new heap-allocated rational owned
// by the caller.
// NOTE(review): pow == 0 divides by zero (the result's denominator is zero).
CRational *CRational::term( uint64_t pow ) { // term for taylor series
    // unique_ptr frees the intermediate power even if div() throws.
    std::unique_ptr<CRational> raised( power( pow ) );
    CRational divisor( new CBigInt( pow ), new CBigInt( 1 ) );
    return raised->div( &divisor );
}

// Natural logarithm of this rational, returned as a new heap-allocated
// rational owned by the caller. *this is not modified.
//
// Throws the C string "ln of nonpositive number" for x <= 0, and
// std::logic_error("division by zero") for a zero denominator.
//
// Accuracy is governed by termPrecision (the number of series terms). For
// 0 < x <= 1 the series runs directly on y = x - 1, which converges slowly as
// x approaches 0.
CRational *CRational::ln() {
    if (m_numerator->isZero() || isNegative())
        throw "ln of nonpositive number";
    if (m_denominator->isZero())
        throw std::logic_error("division by zero");

    // Truncated series ln(1 + y) = y - y^2/2 + y^3/3 - ... using the first
    // termPrecision terms. Converges for -1 < y <= 1.
    auto log1pSeries = []( CRational *y ) {
        std::unique_ptr<CRational> sum( new CRational( new CBigInt(0) ) );
        for (uint64_t i = 1; i <= termPrecision; ++i) {
            std::unique_ptr<CRational> piece( y->power(i) );
            CRational divisor( new CBigInt(i) );
            piece.reset( piece->div( &divisor ) );
            if (i % 2 == 0)  // even-index terms are subtracted
                piece->negate();
            sum.reset( sum->add( piece.get() ) );
        }
        return sum;
    };

    // Reduced copy: simplify() puts the sign on the numerator and makes the
    // denominator non-negative, so magnitudes compare directly below.
    CRational x( *this );
    x.simplify();
    CRational one( new CBigInt(1) );

    if (x.m_numerator->cmp( x.m_denominator.get() ) <= 0) {
        // 0 < x <= 1: expand around 1 with y = x - 1 in (-1, 0].
        std::unique_ptr<CRational> y( x.sub( &one ) );
        return log1pSeries( y.get() ).release();
    }

    // x > 1: write x = 2^k * r with k = bitlen(num) - bitlen(den) >= 0. Then
    // r = num / (den * 2^k) lies in (1/2, 2), and ln x = ln r + k * ln 2.
    // Scans for the most significant non-zero chunk, so __builtin_clzll never sees 0.
    auto bitLength = []( const CBigInt &v ) -> uint64_t {
        for (size_t i = v.m_data.size(); i-- > 0; )
            if (v.m_data[i] != 0)
                return static_cast<uint64_t>(i) * CHUNK_BITS
                     + static_cast<uint64_t>( 64 - __builtin_clzll( v.m_data[i] ) );
        return 0;
    };
    const uint64_t k = bitLength( *x.m_numerator ) - bitLength( *x.m_denominator );

    std::unique_ptr<CBigInt> scaledDen( new CBigInt( *x.m_denominator ) );
    if (k > 0) {
        CBigInt two(2);
        std::unique_ptr<CBigInt> scale( two.power(k) );
        scaledDen.reset( scaledDen->mul( scale.get() ) );
    }
    CRational reduced( new CBigInt( *x.m_numerator ), scaledDen.release() );
    std::unique_ptr<CRational> y( reduced.sub( &one ) );  // y in (-1/2, 1)
    std::unique_ptr<CRational> lnReduced( log1pSeries( y.get() ) );

    CRational ln2( new CBigInt(3552463), new CBigInt(5125120) ); //https://scipp.ucsc.edu/~haber/webpage/Log2.pdf
    CRational kRational( new CBigInt(k), new CBigInt(1) );
    std::unique_ptr<CRational> shift( ln2.mul( &kRational ) );
    return lnReduced->add( shift.get() );
}

// Approximates e^x for x = a/b, returning a new heap-allocated rational owned
// by the caller:
//   e^(a/b) = lim n->inf (1 + (a/b)/n)^n = lim n->inf ((b*n + a) / (b*n))^n,
// truncated at n = termPrecision. Accuracy degrades quickly for large |x|.
// *this is not modified.
CRational *CRational::exp() {
    // Stack-allocated n: mul() only reads its argument, so nothing leaks.
    CBigInt n( termPrecision );
    std::unique_ptr<CBigInt> num( m_denominator->mul( &n ) );   // b*n
    std::unique_ptr<CBigInt> den( new CBigInt( *num ) );        // b*n
    num->addDestructive( m_numerator.get() );                   // b*n + a

    CRational base( num.release(), den.release() );
    base.simplify();
    return base.power( termPrecision );
}

std::string CRational::round( uint64_t decimals ) {
    std::string ret;
    if (m_denominator->m_data.size() <= 2 &&
        m_numerator->m_data.size() <= 2) {
        m_numerator->m_data.push_back(0);
        m_denominator->m_data.push_back(0);
        ret = std::to_string( *((double *)m_numerator->m_data.data())
                             / *((double *)m_denominator->m_data.data()));
        size_t dotPosition = ret.find('.');
        if (dotPosition != std::string::npos)
            ret = ret.substr(0, dotPosition + decimals + 1);
        m_numerator->m_data.pop_back();
        m_denominator->m_data.pop_back();
        return ret;
    }
    throw std::length_error("too long to round");

    CRational tmp(*this);
    CBigInt newDen(std::pow(10, decimals));
    int dif = m_denominator->cmp(&newDen);
    
    if (dif > 0) {
        CBigInt *factor = tmp.m_denominator->div(&newDen);
        tmp.m_numerator.reset( tmp.m_numerator->div(factor) );
        delete factor;
    }
    if (dif < 0) {
        CBigInt *factor = newDen.div(tmp.m_denominator.get());
        tmp.m_numerator.reset( tmp.m_numerator->mul(factor) );
        delete factor;
    }
    ret = tmp.m_numerator->toString();
    if ((ret.size() - decimals) > 0)
        ret.insert( ret.size() - decimals, 1, ',' );
    return ret;
}

// Rescales *this and *oth in place so both share one denominator. Neither
// value changes. Always returns true.
bool CRational::setSameDenominator( CRational *oth ) {
    if (m_denominator->cmp( oth->m_denominator.get() ) == 0)
        return true;

    // A zero numerator can adopt the other denominator unchanged in value,
    // which skips the lcm computation. Test the value, not the pointer.
    if (m_numerator->isZero()) {
        m_denominator.reset( new CBigInt( *oth->m_denominator ) );
        return true;
    }
    if (oth->m_numerator->isZero()) {
        oth->m_denominator.reset( new CBigInt( *m_denominator ) );
        return true;
    }

    // General case: scale both fractions up to lcm(den, oth.den).
    std::unique_ptr<CBigInt> lcm( m_denominator->lcm( oth->m_denominator.get() ) );
    std::unique_ptr<CBigInt> thisCoef( lcm->div( m_denominator.get() ) );
    std::unique_ptr<CBigInt> othCoef( lcm->div( oth->m_denominator.get() ) );

    m_denominator.reset( m_denominator->mul( thisCoef.get() ) );
    m_numerator.reset(   m_numerator  ->mul( thisCoef.get() ) );

    oth->m_denominator.reset( oth->m_denominator->mul( othCoef.get() ) );
    oth->m_numerator.reset(   oth->m_numerator  ->mul( othCoef.get() ) );
    return true;
}

void CRational::simplify() {
    if (m_numerator == nullptr)
        return;

    std::unique_ptr<CBigInt> gcd( m_numerator->gcd(m_denominator.get()) );
    m_numerator.reset( m_numerator->div(gcd.get()) );
    m_denominator.reset( m_denominator->div(gcd.get()) );

    m_numerator->m_negative = isNegative();
    m_denominator->m_negative = false;
    return;
}

// Returns i! as int64_t (0! == 1).
// Throws std::overflow_error for i > 20: 20! is the largest factorial that
// fits in int64_t, and going further would be signed overflow (UB).
inline int64_t CRational::intfactorial( uint64_t i ) {
    if (i > 20)
        throw std::overflow_error("factorial too large for int64_t");
    int64_t ret = 1;
    for (uint64_t j = 2; j <= i; ++j)
        ret *= static_cast<int64_t>( j );
    return ret;
}

// True when exactly one of the numerator and denominator carries the negative
// flag, i.e. the represented value is negative (zero may also report true if
// its flag is set).
bool CRational::isNegative() const {
    const bool numeratorNegative = m_numerator->m_negative;
    const bool denominatorNegative = m_denominator->m_negative;
    return numeratorNegative != denominatorNegative;
}

// Flips the sign of this rational in place by toggling the numerator's
// negative flag; the denominator is left untouched.
void CRational::negate() {
    auto &signFlag = m_numerator->m_negative;
    signFlag = !signFlag;
}

// Debug helper: converts the CBigInt::debugDump() results of both parts to
// double and returns their quotient.
double CRational::debugDump() {
    const double numeratorValue = static_cast<double>( m_numerator->debugDump() );
    const double denominatorValue = static_cast<double>( m_denominator->debugDump() );
    return numeratorValue / denominatorValue;
}