#!/usr/bin/env python
# coding: utf-8

# In[1]:


# Comprehensive example: quantum anomalous Hall effect
# ====================================================
#
# Physics background
# ------------------
# + Quantum anomalous Hall effect
#
# Features highlighted
# --------------------
# + Use of `kwant.continuum` to discretize a continuum Hamiltonian
# + Use of `kwant.operator` to compute local current
# + Use of `kwant.plotter.current` to plot local current

import math
import matplotlib.pyplot
import kwant
import kwant.continuum


# In[2]:


import matplotlib
import matplotlib.pyplot
from matplotlib_inline.backend_inline import set_matplotlib_formats

# Notebook plotting setup: make the default figure size twice the size
# returned by figaspect(1), and request 'svg' as the inline figure format.
matplotlib.rcParams['figure.figsize'] = matplotlib.pyplot.figaspect(1) * 2
set_matplotlib_formats('svg')


# In[3]:


def make_model(a):
    """Discretize the continuum Hamiltonian of this example.

    The Hamiltonian is written in terms of the momenta ``k_x``, ``k_y``,
    the matrices ``sigma_0``, ``sigma_x``, ``sigma_y``, ``sigma_z`` and the
    free symbols ``alpha``, ``beta``, ``gamma``, ``m`` and ``U``.  The
    shorthand ``kk`` is substituted by ``k_x**2 + k_y**2``.

    Parameters
    ----------
    a
        Value forwarded to ``kwant.continuum.discretize`` as ``grid``.

    Returns
    -------
    The result of ``kwant.continuum.discretize`` for this Hamiltonian.
    """
    # One entry per term; joined without separators so that the resulting
    # expression string is a single sum.
    terms = (
        "alpha * (k_x * sigma_x - k_y * sigma_y)",
        "+ (m + beta * kk) * sigma_z",
        "+ (gamma * kk + U) * sigma_0",
    )
    hamiltonian = "".join(terms)
    substitutions = dict(kk="k_x**2 + k_y**2")
    return kwant.continuum.discretize(
        hamiltonian,
        locals=substitutions,
        grid=a,
    )


# In[4]:


def make_system(model, L):
    """Build and finalize a constriction attached to two leads.

    Site positions are divided by ``L`` before being tested, so all shape
    conditions below are expressed in units of ``L``.

    * Leads: sites with ``|y| < 0.5``.
    * Central region: sites with ``|x| < 3/5`` and
      ``|y| < 0.5 - 0.4 * exp(-40 * x**2)``, i.e. a rectangle with a
      Gaussian etched out of its top and bottom edges (QPC shape).

    Parameters
    ----------
    model
        Template passed to ``Builder.fill``; its ``lattice`` attribute
        provides the lead's translation vector.
    L
        Length scale by which site positions are divided.

    Returns
    -------
    The finalized system, with the first lead using the translational
    symmetry along lattice vector ``(-1, 0)`` and the second lead being
    its reversed copy.
    """
    def in_lead(site):
        _, y = site.pos / L
        return abs(y) < 0.5

    def in_scattering_region(site):
        x, y = site.pos / L
        if not abs(x) < 3/5:
            return False
        # Half-width of the channel at this x, narrowed by the Gaussian.
        half_width = 0.5 - 0.4 * math.exp(-40 * x**2)
        return abs(y) < half_width

    origin = (0, 0)

    # Scattering region.
    scattering_region = kwant.Builder()
    scattering_region.fill(model, in_scattering_region, origin)

    # Lead that is periodic along the (-1, 0) lattice vector.
    period = model.lattice.vec((-1, 0))
    first_lead = kwant.Builder(kwant.TranslationalSymmetry(period))
    first_lead.fill(model, in_lead, origin)

    # Attach the lead and its mirror image, in that order.
    scattering_region.attach_lead(first_lead)
    scattering_region.attach_lead(first_lead.reversed())

    return scattering_region.finalized()

# Model parameters, the discretized model (grid value 1) and the
# finalized system (length scale 70).
params = {
    'alpha': 0.365,
    'beta': 0.686,
    'gamma': 0.512,
    'm': -0.01,
    'U': 0,
}
model = make_model(1)
syst = make_system(model, 70)

# Scattering states at energy 'm' coming from the left lead (lead 0);
# the associated particle current is derived from these.
psi = kwant.wave_function(
    syst,
    energy=params['m'],
    params=params,
)(0)


# In[5]:


# Current operator bound to the model parameters; the total current is the
# sum of the contributions of every scattering state in `psi`.
J = kwant.operator.Current(syst).bind(params=params)
current = sum(map(J, psi))
kwant.plotter.current(syst, current)

