import yt
import numpy as np

from matplotlib.patches import Rectangle
import matplotlib.gridspec as gridspec
from matplotlib.ticker import NullFormatter
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import cm
from matplotlib import rc
from matplotlib import rcParams
# Render all figure text through LaTeX in a bold serif font.
rc('text', usetex=True)
# Use a plain string rather than a list: older matplotlib splits a string into
# a one-element list, while matplotlib >= 3.3 deprecates (and later rejects)
# list values for this rcParam.
rcParams['text.latex.preamble'] = r'\boldmath'
rc('font', family='serif')

import argparse

def _mag(field, data):
        return np.sqrt(data['magx'].v*data['magx'].v + data['magy'].v*data['magy'].v + data['magz'].v*data['magz'].v)
# Register _mag as the derived field 'mag'; units="" because _mag returns
# unit-stripped values.
yt.add_field('mag', function=_mag, units="")

def _encrvol(field, data):
        return data['encr'].v*data['dens'].v
# Register _encrvol as the derived field 'encrvol'; units="" because
# _encrvol returns unit-stripped values.
yt.add_field('encrvol', function=_encrvol, units="")

# ---------------------------------------------------------------------------
# Command-line options
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Plotting script')
parser.add_argument('-svg', action="store_true",
                    help="also write the figure as SVG")
parser.add_argument('-n', default=2500, type=int,
                    help="plot-file number (zero-padded to four digits)")
args = parser.parse_args()

pltnr = str(args.n).zfill(4)

interpol_method = "gaussian"


def _slice_frb(ds, normal, field, frb_width, resolution, frb_height):
    """Return a fixed-resolution slice of `field` as a plain numpy array.

    The slice comes from yt.SlicePlot's default placement along `normal` and
    is resampled with ``to_frb(frb_width, resolution, height=frb_height)``.
    """
    source = yt.SlicePlot(ds, normal, fields=field).data_source
    return np.array(source.to_frb(frb_width, resolution, height=frb_height)[field])


def _log_image(ax, data, extent, vmin, vmax, interpolation):
    """Draw `data` on `ax` with a log colour scale limited to [vmin, vmax]."""
    # The limits live in the norm itself: passing `norm` together with
    # vmin/vmax is deprecated in matplotlib 3.3 and an error in later releases.
    return ax.imshow(data, origin='upper', norm=LogNorm(vmin=vmin, vmax=vmax),
                     extent=extent, cmap="jet", interpolation=interpolation)


def _density_colorbar(fig, ax, image):
    """Attach a vertical density colour bar for `image` to the right of `ax`."""
    divider = make_axes_locatable(ax)
    cbar_ax = divider.append_axes("right", size=0.15, pad=0.2)
    cbar = fig.colorbar(image, ax=ax, cax=cbar_ax, orientation='vertical')
    cbar.set_label("density~~(${\\rm{g~cm}}^{-3}$)")
    return cbar


########################
# prepare combined plot
########################

kpc = 3.085e21  # one kpc in the domain's length unit (presumably cm)

image_size_inch_x = 12
image_size_inch_y = 10

# plot file interval 1 kyr
file1 = "../../../CR00/SILCC_hdf5_plt_cnt_" + pltnr
file2 = "../../../CRXX/SILCC_hdf5_plt_cnt_" + pltnr
file3 = "../../../CR10/SILCC_hdf5_plt_cnt_" + pltnr

pf1 = yt.load(file1)
pf2 = yt.load(file2)
pf3 = yt.load(file3)
datasets = (pf1, pf2, pf3)

field1 = "dens"
field2 = "dens"

# colour limits of the top (cmin/cmax) and bottom (c2min/c2max) rows
cmin = 1e-27
cmax = 1e-21
c2min = 1e-27
c2max = 1e-21

width1 = pf1.domain_width[0]
height1 = 1.5 * pf1.domain_width[1]
width2 = pf1.domain_width[0]
height2 = pf1.domain_width[1]

resh1 = 192
resw1 = 128
resh2 = 128
resw2 = 128

# store corners (in kpc)
le = pf1.domain_left_edge / kpc
re = pf1.domain_right_edge / kpc
# Same output as the former Python-2-only `print le, re`, valid on 2 and 3.
print("%s %s" % (le, re))

fig = plt.figure(figsize=(image_size_inch_x, image_size_inch_y))
gs = gridspec.GridSpec(2, 3, width_ratios=(0.89, 0.89, 1), height_ratios=(1.45, 1))
# gs cells are numbered row-major: 0-2 top row, 3-5 bottom row.
ax11, ax12, ax13, ax21, ax22, ax23 = [plt.subplot(gs[i]) for i in range(6)]

# Top row: y-normal slices, transposed before plotting. height1 (the vertical
# extent of the final image, labelled z) is passed as the FRB width and width1
# as its height -- ordering kept from the original script; presumably it matches
# yt's image-plane axis order for y-normal slices (confirm before changing).
top_data = [_slice_frb(ds, 'y', field1, height1, (resh1, resw1), width1).T
            for ds in datasets]
# Bottom row: z-normal slices.
bottom_data = [_slice_frb(ds, 'z', field2, width2, (resh2, resw2), height2)
               for ds in datasets]
# The former third row (CR energy density, field 'encrvol') is disabled; its
# slices were computed but never drawn, so they are no longer extracted.

# Only the leftmost column keeps y tick labels, only the bottom row x tick labels.
for ax in (ax12, ax22, ax13, ax23):
    ax.yaxis.set_major_formatter(NullFormatter())
for ax in (ax11, ax12, ax13):
    ax.xaxis.set_major_formatter(NullFormatter())

# 3.1536e13 = 1e6 * 365 * 86400, i.e. Myr if current_time is in seconds.
time = pf1.current_time / 3.1536e13
timestr = "Time: %3.1f Myr" % np.around(time, decimals=1)
ax11.text(-0.92, 1.3, timestr, fontsize=14, color='white')

ax11.set_title("SNe: only thermal ($10^{51}$ erg)")
ax12.set_title("SNe: only CR ($10^{50}$ erg)")
ax13.set_title("SNe: thermal + CR")

top_extent = [le[0], re[0], 1.5 * le[1], 1.5 * re[1]]
bottom_extent = [le[0], re[0], le[1], re[1]]

top_images = [_log_image(ax, data, top_extent, cmin, cmax, interpol_method)
              for ax, data in zip((ax11, ax12, ax13), top_data)]
_density_colorbar(fig, ax13, top_images[-1])

bottom_images = [_log_image(ax, data, bottom_extent, c2min, c2max, interpol_method)
                 for ax, data in zip((ax21, ax22, ax23), bottom_data)]
_density_colorbar(fig, ax23, bottom_images[-1])

xlabel = "x (kpc)"
for ax in (ax21, ax22, ax23):
    ax.set_xlabel(xlabel)
ax11.set_ylabel("z (kpc)")
ax21.set_ylabel("y (kpc)")

plt.subplots_adjust(wspace=0.1, hspace=0.0, bottom=0.15)

name = "d-d-" + pltnr

# save plot
plt.savefig(name + ".png", bbox_inches='tight')
plt.savefig(name + ".pdf", bbox_inches='tight')
if args.svg:
    plt.savefig(name + ".svg", bbox_inches='tight')

