#!/usr/bin/env python
# coding: utf-8

# In[1]:


# Tutorial 2.7. Spin textures
# ===========================
#
# Physics background
# ------------------
#  - Spin textures
#  - Skyrmions
#
# Kwant features highlighted
# --------------------------
#  - operators
#  - plotting vector fields

from math import sin, cos, tanh, pi
import itertools
import numpy as np
import tinyarray as ta
import matplotlib.pyplot as plt

import kwant

sigma_0 = ta.array([[1, 0], [0, 1]])
sigma_x = ta.array([[0, 1], [1, 0]])
sigma_y = ta.array([[0, -1j], [1j, 0]])
sigma_z = ta.array([[1, 0], [0, -1]])

# Vector of Pauli matrices stored as sigma[α, i, β], where i ∈ {x, y, z} is the
# vector index and the greek letters α, β are spinor indices. Placing i on the
# middle axis (array shape (2, 3, 2)) lets np.dot(m, sigma) contract a
# length-3 vector m with i and return the 2x2 matrix m · σ.
# np.moveaxis is used instead of np.rollaxis, which NumPy keeps only for
# backward compatibility; both move axis 1 to the front here.
sigma = np.moveaxis(np.array([sigma_x, sigma_y, sigma_z]), 1, 0)


# In[2]:


import matplotlib
import matplotlib.pyplot
from matplotlib_inline.backend_inline import set_matplotlib_formats

# Default figure: the square size returned by figaspect(1), scaled by 2;
# inline figures are rendered as SVG.
matplotlib.rcParams.update({'figure.figsize': matplotlib.pyplot.figaspect(1) * 2})
set_matplotlib_formats('svg')


# In[3]:


def field_direction(pos, r0, delta):
    """Return the unit texture direction m = (m_x, m_y, m_z) at ``pos``.

    The polar angle is theta = (tanh((r - r0) / delta) - 1) * pi / 2, with
    r = |pos|. Far outside r0, theta -> 0 and m points along +z. At r = r0,
    theta = -pi/2 and m lies in-plane, equal to -(x, y) / r. Deep inside,
    when r0 >> delta, theta approaches -pi and m points along -z.

    Parameters
    ----------
    pos : pair of numbers
        Position (x, y).
    r0 : float
        Radius at which the texture is in-plane.
    delta : float
        Width of the transition region around r0.

    Returns
    -------
    numpy.ndarray of length 3. At the origin, where the radial direction is
    undefined, the integer array [0, 0, -1] is returned.
    """
    x, y = pos
    r = np.linalg.norm(pos)
    theta = (tanh((r - r0) / delta) - 1) * (pi / 2)

    if r == 0:
        return np.array([0, 0, -1])

    in_plane = sin(theta)
    return np.array([(x / r) * in_plane, (y / r) * in_plane, cos(theta)])


def scattering_onsite(site, r0, delta, J):
    """Onsite term of the scattering region: J * (m_i · sigma).

    m_i is the unit texture direction from ``field_direction`` at
    ``site.pos``. The result is a 2x2 matrix in spin space.
    """
    direction = field_direction(site.pos, r0, delta)
    return J * direction.dot(sigma)


def lead_onsite(site, J):
    """Uniform onsite term of the leads, J * sigma_z.

    ``site`` is accepted to match the value-function signature but unused.
    """
    onsite = J * sigma_z
    return onsite


# In[4]:


# Square lattice with unit lattice constant and two orbitals (spin up, spin down) per site.
lat = kwant.lattice.square(a=1, norbs=2)

def make_system(L=80):
    """Build the (unfinalized) scattering system for the spin-texture tutorial.

    The scattering region holds all lattice sites with -L/2 < x, y < L/2. Its
    onsite term is ``scattering_onsite`` (parameters ``r0``, ``delta``, ``J``)
    and every nearest-neighbour hopping is ``-sigma_0``.

    Two leads with the same cross-section are attached along x: one with
    translational symmetry (-1, 0) and its reversed copy. The leads use
    ``lead_onsite`` (parameter ``J``), hopping ``-sigma_0``, and are built
    with ``conservation_law=-sigma_z``.
    """
    def in_box(pos):
        # True when every coordinate lies strictly inside (-L/2, L/2).
        return all(-L / 2 < coord < L / 2 for coord in pos)

    region = kwant.Builder()
    region[lat.shape(in_box, (0, 0))] = scattering_onsite
    region[lat.neighbors()] = -sigma_0

    lead_symmetry = kwant.TranslationalSymmetry((-1, 0))
    lead = kwant.Builder(lead_symmetry, conservation_law=-sigma_z)
    lead[lat.shape(in_box, (0, 0))] = lead_onsite
    lead[lat.neighbors()] = -sigma_0

    region.attach_lead(lead)
    region.attach_lead(lead.reversed())
    return region


# In[5]:


def plot_vector_field(syst, params):
    """Quiver-plot the texture m_i on the bounding box of ``syst``'s site tags.

    Parameters
    ----------
    syst : finalized kwant system
        Only the tags of its sites are used, to set the plotting window.
    params : dict
        Must contain ``r0`` and ``delta``, which are forwarded to
        ``field_direction``. Other keys (e.g. ``J``) are ignored, so the same
        dict used for the Hamiltonian can be passed.

    Arrows show the in-plane components (m_x, m_y); the arrow colour is m_z.
    """
    # Per-axis extent of the tags. min()/max() over tag tuples would compare
    # lexicographically, which is only the true bounding box for rectangles.
    tags = np.array([site.tag for site in syst.sites])
    (xmin, ymin), (xmax, ymax) = tags.min(axis=0), tags.max(axis=0)
    x, y = np.meshgrid(np.arange(xmin, xmax + 1), np.arange(ymin, ymax + 1))

    r0, delta = params['r0'], params['delta']
    m_i = [field_direction(p, r0, delta) for p in zip(x.flat, y.flat)]
    # (ny, nx, 3) -> (3, ny, nx) so the components unpack as U, V, C below.
    m_i = np.moveaxis(np.reshape(m_i, x.shape + (3,)), -1, 0)

    fig, ax = plt.subplots(1, 1)
    im = ax.quiver(x, y, *m_i, pivot='mid', scale=75)
    fig.colorbar(im, label='$m_z$')
    plt.show()


def plot_densities(syst, densities):
    """Plot each (title, rho) pair as a colour map, side by side in one row.

    ``densities`` is a sequence of ``(title, rho)`` pairs, where ``rho`` is a
    per-site array such as the output of a ``kwant.operator.Density``.
    """
    # squeeze=False always returns a 2-D array of Axes. Without it a single
    # panel comes back as a bare Axes, which is not iterable, so plotting one
    # density failed in the zip below.
    fig, axes = plt.subplots(1, len(densities), figsize=(13, 10), squeeze=False)
    for ax, (title, rho) in zip(axes.flat, densities):
        kwant.plotter.map(syst, rho, ax=ax, a=4)
        ax.set_title(title)
    plt.show()


def plot_currents(syst, currents):
    """Plot each (title, current) pair as a streamline panel in one row.

    ``currents`` is a sequence of ``(title, current)`` pairs, where
    ``current`` is an array returned by a ``kwant.operator.Current``.
    A single pair is also supported.
    """
    # squeeze=False yields a (1, n) array of Axes for any n, so row 0 can
    # always be iterated, including when n == 1.
    fig, axes = plt.subplots(1, len(currents), figsize=(13, 10), squeeze=False)
    for panel, (label, values) in zip(axes[0], currents):
        kwant.plotter.current(syst, values, ax=panel, colorbar=False,
                              fig_size=(13, 10))
        panel.set_title(label)
    plt.show()


# In[6]:


# Build the system with the default size (L=80) and finalize it for calculations.
syst = make_system().finalized()


# In[7]:


# Texture with radius r0=20 and wall width delta=10.
plot_vector_field(syst, params={'r0': 20, 'delta': 10})


# In[8]:


# Hamiltonian parameters: texture radius r0, wall width delta, exchange J.
params = {'r0': 20, 'delta': 10, 'J': 1}
wf = kwant.wave_function(syst, energy=-1, params=params)
# First scattering state returned for lead 0.
psi = wf(0)[0]


# In[9]:


# With two orbitals per site, even (odd) entries of psi are the spin-up
# (spin-down) amplitudes.
up = psi[0::2]
down = psi[1::2]
density = abs(up) ** 2 + abs(down) ** 2


# In[10]:


# <sigma_z> per site: the -1 entry of sigma_z gives the spin-down part a minus sign.
spin_z = abs(up) ** 2 - abs(down) ** 2


# In[11]:


# <sigma_y> per site: psi^† σ_y psi = i (down* up - up* down) = 2 Im(up* down).
# Using the imaginary part keeps the result real-valued; the literal product
# 1j * (...) produces a complex array whose imaginary part is zero.
spin_y = 2 * np.imag(up.conjugate() * down)


# In[12]:


# Density operators: plain density and the sigma_z / sigma_y spin densities.
rho, rho_sz, rho_sy = (
    kwant.operator.Density(syst),
    kwant.operator.Density(syst, sigma_z),
    kwant.operator.Density(syst, sigma_y),
)

# Expectation values of each operator in the state 'psi'.
density, spin_z, spin_y = rho(psi), rho_sz(psi), rho_sy(psi)


# In[13]:


plot_densities(syst, densities=(
    ('$σ_0$', density),
    ('$σ_z$', spin_z),
    ('$σ_y$', spin_y),
))


# In[14]:


# Current operators: particle current and sigma_z / sigma_y spin currents.
J_0, J_z, J_y = (
    kwant.operator.Current(syst),
    kwant.operator.Current(syst, sigma_z),
    kwant.operator.Current(syst, sigma_y),
)

# Expectation values of each operator in the state 'psi'.
current, spin_z_current, spin_y_current = J_0(psi), J_z(psi), J_y(psi)


# In[15]:


plot_currents(syst, currents=(
    ('$J_{σ_0}$', current),
    ('$J_{σ_z}$', spin_z_current),
    ('$J_{σ_y}$', spin_y_current),
))


# In[16]:


def following_m_i(site, r0, delta):
    """Onsite matrix m_i · sigma: the spin component along the local texture direction.

    m_i comes from ``field_direction`` at ``site.pos``; the result is 2x2.
    """
    return np.dot(field_direction(site.pos, r0, delta), sigma)

J_m = kwant.operator.Current(syst, following_m_i)

# Evaluate the operator with the same texture parameters (r0, delta) that were
# used to compute psi, so the spin axis follows the texture actually present
# in the system (a different r0 would describe a different texture).
m_current = J_m(psi, params=dict(r0=params['r0'], delta=params['delta']))


# In[17]:


plot_currents(syst, currents=(
    (r'$J_{\mathbf{m}_i}$', m_current),
    ('$J_{σ_z}$', spin_z_current),
))


# In[18]:


def circle(site):
    """True for sites strictly closer than 20 to the origin (matches r0=20 used above)."""
    radius = 20
    distance = np.linalg.norm(site.pos)
    return distance < radius

# Density summed over the sites selected by 'circle'.
rho_circle = kwant.operator.Density(syst, where=circle, sum=True)

# All scattering states returned for lead 0 and lead 1, stacked row-wise.
all_states = np.vstack([wf(0), wf(1)])
dos_in_circle = sum(map(rho_circle, all_states)) / (2 * pi)
print('density of states in circle:', dos_in_circle)


# In[19]:


def left_cut(site_to, site_from):
    """Select hoppings from x <= -39 to x > -39.

    -39 is the leftmost column of the default L=80 scattering region.
    """
    boundary = -39
    return site_from.pos[0] <= boundary and site_to.pos[0] > boundary

def right_cut(site_to, site_from):
    """Select hoppings from x < 39 to x >= 39.

    39 is the rightmost column of the default L=80 scattering region.
    """
    boundary = 39
    return site_from.pos[0] < boundary and site_to.pos[0] >= boundary

# Total (summed) particle and sigma_z currents through the left and right cuts.
J_left, J_right = (kwant.operator.Current(syst, where=cut, sum=True)
                   for cut in (left_cut, right_cut))
Jz_left, Jz_right = (kwant.operator.Current(syst, sigma_z, where=cut, sum=True)
                     for cut in (left_cut, right_cut))

print('J_left:', J_left(psi), ' J_right:', J_right(psi))
print('Jz_left:', Jz_left(psi), ' Jz_right:', Jz_right(psi))


# In[20]:


J_m = kwant.operator.Current(syst, following_m_i)
J_z = kwant.operator.Current(syst, sigma_z)

# Bind the same `params` that were used to compute `wf`, so J_m measures spin
# current along the texture actually present in the system rather than one
# with a different radius r0.
J_m_bound = J_m.bind(params=params)
J_z_bound = J_z.bind(params=params)

# Sum the local currents of all scattering states returned for the left lead
# (lead 0) at energy=-1.
wf_left = wf(0)
J_m_left = sum(J_m_bound(p) for p in wf_left)
J_z_left = sum(J_z_bound(p) for p in wf_left)


# In[21]:


plot_currents(syst, currents=(
    (r'$J_{\mathbf{m}_i}$ (from left)', J_m_left),
    (r'$J_{σ_z}$ (from left)', J_z_left),
))

