from sage.rings.power_series_ring import PowerSeriesRing

class EtaProduct(SageObject):
	r""" The eta quotient $\prod_{d \mid N} \eta(q^d)^{r_d}$ of level $N$.

	The constructor only accepts exponents satisfying the Ligozat criteria
	(see ``__init__``). Attributes: ``N`` (the level), ``rdict`` (a dict
	giving $r_d$ for every divisor $d$ of $N$, with 0 for divisors absent
	from the input) and ``sumDR`` ($\sum_d d r_d$, used to shift the
	q-expansion by $q^{\sum_d d r_d / 24}$)."""

	def __init__(self, N, rdict):
		r""" Create the eta product with exponents ``rdict`` at level ``N``.

		INPUT:
		    - ``N`` -- the level.
		    - ``rdict`` -- dict mapping divisors $d$ of $N$ to exponents
		      $r_d$; divisors not present get exponent 0. The caller's
		      dict is copied, never modified.

		Raises ValueError if a key of ``rdict`` does not divide ``N``, or if
		one of the Ligozat criteria fails: $\sum r_d = 0$,
		$\sum d r_d \equiv 0 \pmod{24}$, $\sum (N/d) r_d \equiv 0 \pmod{24}$,
		or $\prod (N/d)^{r_d}$ is not a square."""
		for d in rdict:
			if N % d:
				raise ValueError("%s does not divide %s" % (d, N))

		# Copy so the caller's dict is left untouched, and store an explicit 0
		# for every divisor so other methods can index self.rdict[d] directly.
		rdict = dict(rdict)
		sumR = sumDR = sumNoverDr = 0
		prod = 1
		for d in divisors(N):
			r = rdict.setdefault(d, 0)
			if r == 0:
				continue
			sumR += r
			sumDR += r * d
			# d | N, so floor division is exact and keeps everything in ZZ.
			sumNoverDr += r * (N // d)
			prod *= ZZ(N // d) ** r

		if sumR != 0:
			raise ValueError("sum r_d is not 0")
		if sumDR % 24 != 0:
			raise ValueError("sum d r_d is not 0 mod 24")
		if sumNoverDr % 24 != 0:
			raise ValueError("sum (N/d) r_d is not 0 mod 24")
		if not is_square(prod):
			raise ValueError("product (N/d)^(r_d) is not a square")

		self.N = N
		self.sumDR = sumDR
		self.rdict = rdict

	def __repr__(self):
		r""" Return e.g. ``Eta product of level 6 : (eta_1)^2 * (eta_6)^-2``.

		Factors are listed in increasing order of d and zero exponents are
		omitted; a product with no nonzero exponent is shown as ``1``."""
		terms = ["(eta_%s)^%s" % (d, r) for (d, r) in sorted(self.rdict.items()) if r != 0]
		return "Eta product of level %s : %s" % (self.N, " * ".join(terms) or "1")

	def qexp(self, n):
		r""" Return the q-expansion of this eta product in $\mathbb{Q}[[q]]$.

		Each factor $\eta(q^d)^{r_d}$ is built from ``qexp_eta(R, n)``
		(the series $\prod_m (1 - q^m)$ to precision ``n``) with $q$
		replaced by $q^d$. The result is then multiplied by
		$q^{\sum d r_d / 24}$, which is an integral power because
		``__init__`` checked $\sum d r_d \equiv 0 \pmod{24}$."""
		R, q = PowerSeriesRing(QQ, 'q').objgen()
		eta = qexp_eta(R, n)
		pr = R(1)
		for d, r in self.rdict.items():
			if r != 0:
				pr *= eta(q**d)**r
		# Exact integer exponent (sumDR is divisible by 24).
		return pr * q**(ZZ(self.sumDR) // 24)

	def order_at_cusp(self, cusp):
		r""" Return the order of vanishing of this eta product at ``cusp``.

		The returned value is
		$\frac{1}{24 \gcd(w, N/w)} \sum_{\ell \mid N} \frac{\ell r_\ell}{w} \gcd(w, N/\ell)^2$,
		where $w$ is the width of the cusp.

		Raises ValueError if ``cusp`` is not a CuspFamily or lies on a
		curve of a different level."""
		if not isinstance(cusp, CuspFamily):
			raise ValueError("Argument (=%s) should be a CuspFamily" % cusp)
		if cusp.N != self.N:
			raise ValueError("Cusp not on right curve!")
		N = self.N
		w = cusp.width
		# ell comes from divisors(N) and is therefore a Sage Integer, so
		# "/" below is exact rational division. N // ell and N // w are
		# exact because both ell and w divide N.
		total = sum([ell * self.rdict[ell] / w * gcd(w, N // ell)**2 for ell in divisors(N)])
		return total / (ZZ(24) * gcd(w, N // w))

	def divisor(self):
		r""" Return the divisor of this eta product as a FormalSum of
		(order, cusp) pairs over all cusps of $X_0(N)$."""
		return FormalSum([(self.order_at_cusp(c), c) for c in AllCusps(self.N)])

	def degree(self):
		r""" Return the sum of the positive orders of this eta product over
		all cusps of $X_0(N)$."""
		orders = [self.order_at_cusp(c) for c in AllCusps(self.N)]
		return sum([o for o in orders if o > 0])

	def L2Norm(self):
		r""" Return the L_2-norm: the sum of the squares of the orders of
		this eta product at all cusps of $X_0(N)$."""
		return sum([self.order_at_cusp(c)**2 for c in AllCusps(self.N)])

def BasisEtaProducts(N):
	r""" Return eta products of level N whose exponent vectors form a
	$\mathbb{Z}$-basis of all exponent vectors satisfying the Ligozat
	criteria checked by EtaProduct.

	For each proper divisor d of N a row of conditions mod 24 is built:
	(d - N), (N/d - 1), and 12 v_p(N/d) for each prime p | N. Choosing
	r_N = -sum_{d<N} r_d enforces sum r_d = 0. With that choice the other
	three criteria say exactly that (r_d)_{d<N} * M = 0 mod 24. Writing
	the Smith form as S = U*M*V, each row u of U satisfies u*M*V = s e_i
	for one diagonal entry s (possibly 0). Scaling u by 24/gcd(s, 24)
	therefore gives a basis of the solution lattice."""
	divs = divisors(N)[:-1]  # proper divisors; r_N is determined at the end
	if not divs:
		# N == 1: only the trivial eta product exists, so the basis is empty.
		return []
	primedivs = prime_divisors(N)

	rows = []
	for d in divs:
		# N // d is exact (d | N) and keeps the entries integral.
		row = [Mod(d, 24) - Mod(N, 24), Mod(N // d, 24) - Mod(1, 24)]
		row.extend([Mod(12 * valuation(N // d, p), 24) for p in primedivs])
		rows.append(row)
	Mlift = matrix(rows).change_ring(Integers())

	S, U, V = Mlift.smith_form()
	dicts = []
	for vect in U.rows():
		# vect*Mlift*V is a row of S: at most one nonzero entry, whose
		# position is hard to predict, so summing the row extracts it.
		nf = sum(vect * Mlift * V)
		# Integer scaling keeps the exponents in ZZ rather than QQ.
		v = (vect * (24 // gcd(nf, 24))).list()
		rd = dict(zip(divs, v))
		rd[N] = -sum(v)
		dicts.append(rd)
	return [EtaProduct(N, rd) for rd in dicts]


def DivisorMatrix(etas):
	r""" Return the matrix whose i-th row lists the orders of ``etas[i]`` at
	the cusps of $X_0(N)$ (N taken from ``etas[0]``), each converted to ZZ.
	Columns follow the order of ``AllCusps(N)``."""
	level = etas[0].N
	cusp_list = AllCusps(level)
	rows = []
	for eta in etas:
		rows.append([ZZ(eta.order_at_cusp(c)) for c in cusp_list])
	return Matrix(rows)
	
def EtaLatticeReduce(EtaProducts):
	r""" LLL-reduce the lattice spanned by the divisors of ``EtaProducts``.

	Returns one EtaProduct per row of the LLL-reduced divisor matrix. For
	each reduced row, its coordinates with respect to the original
	divisors are computed, and the same integer combination of the input
	exponent dicts is taken. Returns [] for empty input."""
	if not EtaProducts:
		return []
	r = DivisorMatrix(EtaProducts)
	N = EtaProducts[0].N
	V = FreeModule(ZZ, r.ncols())
	A = V.submodule_with_basis([vector(rw) for rw in r.rows()])
	rred = r.LLL()
	short_etas = []
	for shortvect in rred.rows():
		coords = A.coordinates(shortvect)
		# Combine the exponents with the same coefficients as the divisors.
		rdict = dict([(d, sum([c * eta.rdict[d] for (c, eta) in zip(coords, EtaProducts)]))
			for d in divisors(N)])
		short_etas.append(EtaProduct(N, rdict))
	return short_etas
	
def CuspsOfWidth(N, d):
	r""" Return the number of cusps of width d on $X_0(N)$, namely
	$\phi(\gcd(d, N/d))$.

	Raises ValueError if d does not divide N. (An explicit check is used
	rather than ``assert``, which is skipped under ``python -O``.)"""
	if N % d:
		raise ValueError("%s does not divide %s" % (d, N))
	# N // d is exact because d | N; it keeps gcd's arguments integral.
	return euler_phi(gcd(d, N // d))

# Lowercase alphabet "a".."z", used to label cusps that share a width.
letters = "".join([chr(code) for code in range(ord("a"), ord("z") + 1)])

def AllCusps(N):
	r""" Return a list of CuspFamily objects, one for each cusp of $X_0(N)$,
	ordered by the divisors d of N taken as widths.

	A width with a single cusp gets no label. When several cusps share a
	width they are labelled "a", "b", ..., "z", then "aa", "ab", ..., so
	that widths with more than 26 cusps are handled."""
	def label(i):
		# Bijective base-26 label: 0 -> "a", 25 -> "z", 26 -> "aa", 27 -> "ab".
		s = ""
		i += 1
		while i > 0:
			i, rem = divmod(i - 1, 26)
			s = letters[rem] + s
		return s

	cusps = []
	for d in divisors(N):
		n = CuspsOfWidth(N, d)
		if n == 1:
			cusps.append(CuspFamily(N, d))
		elif n > 1:
			for i in range(n):
				cusps.append(CuspFamily(N, d, label=label(i)))
	return cusps

class CuspFamily(SageObject):
	r""" A family of elliptic curves parametrising a region of
	$X_0(N)$."""

	def __init__(self, N, width, label = None):
		r""" Create the cusp of the given ``width`` on $X_0(N)$, corresponding
		to the family $(\mathbb{C}_p/q^d, \langle \zeta q\rangle)$ with
		$d$ = ``width``. Here $\zeta$ is a primitive root of unity of order
		$r$ with $\mathrm{lcm}(r,d) = N$. The cusp doesn't need to store
		$\zeta$, so no such attribute is kept.

		INPUT:
		    - ``N`` -- the level.
		    - ``width`` -- a divisor of N.
		    - ``label`` -- optional string distinguishing cusps of equal width.

		Raises ValueError ("Bad width") if ``width`` does not divide ``N``."""
		# Validate before storing anything so no half-built object exists.
		if N % width:
			raise ValueError("Bad width")
		self.N = N
		self.width = width
		self.label = label

	def CuspObject(self):
		""" Return a SAGE representation of the corresponding cusp. """
		raise NotImplementedError

	def __repr__(self):
		r""" Return "(Inf)" for width 1, "(0)" for width N, and
		"(c_{<width><label>})" otherwise."""
		if self.width == 1:
			return "(Inf)"
		elif self.width == self.N:
			return "(0)"
		else:
			return "(c_{%s%s})" % (self.width, (self.label or ""))

	
def qexp_eta(ps_ring, n):
	r""" Return the q-expansion of $\eta(q) / q^{1/24}$ as an element of
	ps_ring, to precision n. Here
	$$\eta(q) = q^{1/24} \prod_{m=1}^\infty (1-q^m)$$
	is Dedekind's eta function, so the result is
	$\prod_{m \ge 1} (1 - q^m) + O(q^n)$.

	Uses Euler's pentagonal number theorem,
	$$\prod_{m \ge 1} (1 - q^m) = \sum_{k \in \mathbb{Z}} (-1)^k q^{k(3k-1)/2},$$
	so only O(sqrt(n)) coefficients are nonzero and no series products are
	needed. For n <= 1 the exact series 1 is returned."""
	if n <= 1:
		return ps_ring(1)
	n = int(n)
	coeffs = [0] * n
	coeffs[0] = 1  # k = 0 term
	k = 1
	while True:
		sign = -1 if k % 2 else 1  # (-1)^k
		e = k * (3 * k - 1) // 2  # generalized pentagonal number for +k
		if e >= n:
			break
		coeffs[e] += sign
		e = k * (3 * k + 1) // 2  # generalized pentagonal number for -k
		if e < n:
			coeffs[e] += sign
		k += 1
	# Second positional argument is the absolute precision: O(q^n).
	return ps_ring(coeffs, n)

