import sys
try:
import wx
WX_VERSION = int(wx.version()[0])
hasWx = True
except Exception as e:
hasWx = False
WX_VERSION = 0
if hasWx:
import wx.xrc
from wx.lib.buttons import GenBitmapTextButton
from pubsub import pub
import wx.adv
import os
import time
import math
import random
import numpy
import scipy.stats
import datetime
import base
import pytransit.transit_tools as transit_tools
import pytransit.tnseq_tools as tnseq_tools
import pytransit.norm_tools as norm_tools
import pytransit.stat_tools as stat_tools
#method_name = "hmm"
############# GUI ELEMENTS ##################
# Identity and description of this analysis method. These four values are
# passed to the base-class constructors by HMMAnalysis and HMMMethod.
short_name = "hmm"
long_name = "HMM"
short_desc = "Analysis of genomic regions using a Hidden Markov Model"
long_desc = """Analysis of essentiality in the entire genome using a Hidden Markov Model. Capable of determining regions with different levels of essentiality representing Essential, Growth-Defect, Non-Essential and Growth-Advantage regions.
Reference: DeJesus et al. (2013; BMC Bioinformatics)
"""
# Transposon types this method accepts; HMMMethod.fromGUI validates the
# selected datasets against this list.
transposons = ["himar1"]
# Column headers handed to HMMSitesFile (the per-site output).
columns_sites = ["Location","Read Count","Probability - ES","Probability - GD","Probability - NE","Probability - GA","State","Gene"]
# Column headers handed to HMMGenesFile (the per-gene output).
columns_genes = ["Orf","Name","Description","Total Sites","Num. ES","Num. GD","Num. NE","Num. GA", "Avg. Insertions", "Avg. Reads", "State Call"]
############# Analysis Method ##############
class HMMAnalysis(base.TransitAnalysis):
    """Descriptor that bundles everything belonging to the HMM analysis.

    Hands the module-level names and descriptions, the supported transposons,
    the method class, the GUI class and the two output-file types to
    ``base.TransitAnalysis``.
    """

    def __init__(self):
        output_file_types = [HMMSitesFile, HMMGenesFile]
        base.TransitAnalysis.__init__(
            self,
            short_name,
            long_name,
            short_desc,
            long_desc,
            transposons,
            HMMMethod,
            HMMGUI,
            output_file_types,
        )
################## FILE ###################
class HMMSitesFile(base.TransitFile):
    """File type for the per-site HMM output (identifier ``#HMM - Sites``)."""

    def __init__(self):
        identifier = "#HMM - Sites"
        base.TransitFile.__init__(self, identifier, columns_sites)
class HMMGenesFile(base.TransitFile):
    """File type for the per-gene HMM output (identifier ``#HMM - Genes``)."""

    def __init__(self):
        identifier = "#HMM - Genes"
        base.TransitFile.__init__(self, identifier, columns_genes)
############# GUI ##################
class HMMGUI(base.AnalysisGUI):
    """wxPython options panel for the HMM method."""

    def definePanel(self, wxobj):
        """Build the "HMM Options" panel and store it in ``self.panel``.

        The normalization choice, replicate choice, LOESS checkbox and LOESS
        preview button are attached to ``wxobj`` (as ``hmmNormChoice``,
        ``hmmRepChoice``, ``hmmLoessCheck`` and ``hmmLoessPrev``) so other
        code can read them. The "Run HMM" button is bound to
        ``wxobj.RunMethod`` and the preview button to ``wxobj.LoessPrevFunc``.

        Args:
            wxobj: Main GUI object; must provide ``optionsWindow``,
                ``RunMethod`` and ``LoessPrevFunc``.
        """
        self.wxobj = wxobj
        panel = wx.Panel(
            self.wxobj.optionsWindow,
            wx.ID_ANY,
            wx.DefaultPosition,
            wx.DefaultSize,
            wx.TAB_TRAVERSAL,
        )
        outer_sizer = wx.BoxSizer(wx.VERTICAL)

        # Panel title.
        title = wx.StaticText(panel, wx.ID_ANY, u"HMM Options", wx.DefaultPosition, (110, -1), 0)
        title.SetFont(wx.Font(10, wx.DEFAULT, wx.NORMAL, wx.BOLD))
        outer_sizer.Add(title, 0, wx.ALL | wx.ALIGN_CENTER_HORIZONTAL, 5)

        options_sizer = wx.BoxSizer(wx.VERTICAL)

        # Normalization choice.
        norm_choices = [u"TTR", u"nzmean", u"totreads", u'zinfnb', u'quantile', u"betageom", u"nonorm"]
        (_norm_label, self.wxobj.hmmNormChoice, norm_sizer) = self.defineChoiceBox(
            panel,
            u"Normalization:",
            norm_choices,
            "Choice of normalization method. The default choice, 'TTR', normalizes datasets to have the same expected count (while not being sensative to outliers). Read documentation for a description other methods.",
        )
        options_sizer.Add(norm_sizer, 1, wx.ALIGN_CENTER_HORIZONTAL | wx.EXPAND, 5)

        # Replicate handling choice.
        rep_choices = [u"Sum", u"Mean"]
        (_rep_label, self.wxobj.hmmRepChoice, rep_sizer) = self.defineChoiceBox(
            panel,
            u"Replicates:",
            rep_choices,
            "Determines how to handle replicates, and their read-counts. When using many replicates, using 'Mean' may be recommended over 'Sum'",
        )
        options_sizer.Add(rep_sizer, 1, wx.ALIGN_CENTER_HORIZONTAL | wx.EXPAND, 5)

        # LOESS checkbox.
        (self.wxobj.hmmLoessCheck, loess_sizer) = self.defineCheckBox(
            panel,
            labelText="Correct for Genome Positional Bias",
            widgetCheck=False,
            widgetSize=(-1, -1),
            tooltipText="Check to correct read-counts for possible regional biase using LOESS. Clicking on the button below will plot a preview, which is helpful to visualize the possible bias in the counts.",
        )
        options_sizer.Add(loess_sizer, 0, wx.EXPAND, 5)

        # LOESS preview button.
        self.wxobj.hmmLoessPrev = wx.Button(
            panel, wx.ID_ANY, u"Preview LOESS fit", wx.DefaultPosition, wx.DefaultSize, 0
        )
        options_sizer.Add(self.wxobj.hmmLoessPrev, 0, wx.ALL | wx.CENTER, 5)

        outer_sizer.Add(options_sizer, 1, wx.EXPAND, 5)

        # Run button.
        run_button = wx.Button(panel, wx.ID_ANY, u"Run HMM", wx.DefaultPosition, wx.DefaultSize, 0)
        outer_sizer.Add(run_button, 0, wx.ALL | wx.ALIGN_CENTER_HORIZONTAL, 5)

        panel.SetSizer(outer_sizer)
        panel.Layout()
        outer_sizer.Fit(panel)

        # Connect events.
        run_button.Bind(wx.EVT_BUTTON, self.wxobj.RunMethod)
        self.wxobj.hmmLoessPrev.Bind(wx.EVT_BUTTON, self.wxobj.LoessPrevFunc)

        self.panel = panel
########## CLASS #######################
class HMMMethod(base.SingleConditionMethod):
    """Single-condition Tn-seq analysis using a four-state Hidden Markov Model.

    The states are labelled ES, GD, NE and GA (see ``Run``).
    """

    def __init__(self,
                 ctrldata,
                 annotation_path,
                 output_file,
                 replicates="Mean",
                 normalization=None,
                 LOESS=False,
                 ignoreCodon=True,
                 NTerminus=0.0,
                 CTerminus=0.0, wxobj=None):
        """Store the run parameters and size the progress bar.

        Args:
            ctrldata: List of paths to the control .wig files.
            annotation_path: Path to the annotation file.
            output_file: Open, writable file object for the sites output.
            replicates: How replicates are combined ("Sum" or "Mean").
            normalization: Name of the normalization method.
            LOESS: Whether to apply LOESS correction to the read counts.
            ignoreCodon: Accepted for interface compatibility; it is not
                forwarded to the base class.
            NTerminus: Fraction of the N terminus to ignore.
            CTerminus: Fraction of the C terminus to ignore.
            wxobj: GUI object, or None when run from the console.
        """
        base.SingleConditionMethod.__init__(
            self, short_name, long_name, short_desc, long_desc,
            ctrldata, annotation_path, output_file,
            replicates=replicates, normalization=normalization, LOESS=LOESS,
            NTerminus=NTerminus, CTerminus=CTerminus, wxobj=wxobj)

        # The progress counter advances once per site in each of four passes
        # (Viterbi forward + backtrace, alpha pass, beta pass), so the range
        # is about four times the number of data lines in the first .wig file.
        try:
            with open(ctrldata[0]) as wig:
                T = sum(1 for line in wig if not line.startswith("#"))
            self.maxiterations = T*4 + 1
        except Exception:
            # Best effort only: an unreadable or missing file must not stop
            # construction, so fall back to a nominal progress range.
            self.maxiterations = 100
        self.count = 1
@classmethod
def fromGUI(cls, wxobj):
    """Build an HMMMethod from the current state of the GUI.

    Args:
        wxobj: Main GUI object providing the annotation path, the selected
            control datasets, the HMM option widgets and ``SaveFile``.

    Returns:
        A new method instance, or None when validation fails or the user
        cancels the save dialog.
    """
    # Get annotation file
    annotationPath = wxobj.annotation
    if not transit_tools.validate_annotation(annotationPath):
        return None

    # Get selected files
    ctrldata = wxobj.ctrlSelected()
    if not transit_tools.validate_control_datasets(ctrldata):
        return None

    # Validate transposon types
    if not transit_tools.validate_transposons_used(ctrldata, transposons):
        return None

    # Read the parameters from the wxPython widgets
    replicates = wxobj.hmmRepChoice.GetString(wxobj.hmmRepChoice.GetCurrentSelection())
    ignoreCodon = True
    NTerminus = float(wxobj.globalNTerminusText.GetValue())
    CTerminus = float(wxobj.globalCTerminusText.GetValue())
    normalization = wxobj.hmmNormChoice.GetString(wxobj.hmmNormChoice.GetCurrentSelection())
    # Honour the "Correct for Genome Positional Bias" checkbox; it used to be
    # ignored because LOESS was hard-coded to False.
    # NOTE(review): assumes defineCheckBox returns a wx.CheckBox-like widget
    # exposing GetValue() -- confirm against base.AnalysisGUI.
    LOESS = wxobj.hmmLoessCheck.GetValue()

    # Get output path
    defaultFileName = "hmm_output.dat"
    defaultDir = os.getcwd()
    output_path = wxobj.SaveFile(defaultDir, defaultFileName)
    if not output_path:
        return None
    output_file = open(output_path, "w")

    return cls(ctrldata,
               annotationPath,
               output_file,
               replicates,
               normalization,
               LOESS,
               ignoreCodon,
               NTerminus,
               CTerminus, wxobj)
@classmethod
def fromargs(cls, rawargs):
    """Build an HMMMethod from command-line arguments.

    Positional arguments are the comma-separated .wig files, the annotation
    path and the output path. Recognised flags:

        -r   replicate handling (default "Mean")
        -n   normalization method (default "TTR")
        -l   enable LOESS correction (default off)
        -iN  fraction of the N terminus to ignore (default 0.0)
        -iC  fraction of the C terminus to ignore (default 0.0)

    Args:
        rawargs: Argument list, without the program and method names.

    Returns:
        A new method instance whose output file is already open for writing.
    """
    (args, kwargs) = transit_tools.cleanargs(rawargs)

    ctrldata = args[0].split(",")
    annotationPath = args[1]
    outpath = args[2]
    output_file = open(outpath, "w")

    replicates = kwargs.get("r", "Mean")
    # Normalization has its own flag; it used to be read from "-r", so that
    # "-r Sum" was also taken as the normalization method.
    normalization = kwargs.get("n", "TTR")
    LOESS = kwargs.get("l", False)
    ignoreCodon = True
    NTerminus = float(kwargs.get("iN", 0.0))
    CTerminus = float(kwargs.get("iC", 0.0))

    return cls(ctrldata,
               annotationPath,
               output_file,
               replicates,
               normalization,
               LOESS,
               ignoreCodon,
               NTerminus,
               CTerminus)
def Run(self):
    """Run the HMM analysis and write the sites and genes output files.

    Steps: load and (optionally) normalize and LOESS-correct the read
    counts, combine replicates, set up a four-state HMM with geometric
    emissions, decode the most likely state path with Viterbi, compute
    per-site state probabilities from the forward and backward passes,
    write the per-site file to ``self.output``, then write the per-gene
    file next to it with ``_genes`` inserted before the extension.
    """
    self.transit_message("Starting HMM Method")

    # Get data
    self.transit_message("Getting Data")
    (data, position) = transit_tools.get_validated_data(self.ctrldata, wxobj=self.wxobj)
    (K, _num_sites) = data.shape

    # Normalize data
    if self.normalization != "nonorm":
        self.transit_message("Normalizing using: %s" % self.normalization)
        (data, _factors) = norm_tools.normalize_data(data, self.normalization, self.ctrldata, self.annotation_path)

    # Do LOESS
    if self.LOESS:
        self.transit_message("Performing LOESS Correction")
        for j in range(K):
            data[j] = stat_tools.loess_correction(position, data[j])

    pos2genes = transit_tools.get_pos_hash(self.annotation_path)
    rv2info = transit_tools.get_gene_info(self.annotation_path)

    if len(self.ctrldata) > 1:
        self.transit_message("Combining Replicates as '%s'" % self.replicates)
    # Adding 1 because scipy's geometric distribution is shifted (support
    # starts at 1), so an empty site is observation 1.
    O = tnseq_tools.combine_replicates(data, method=self.replicates) + 1

    # Parameters
    Nstates = 4
    label = {0: "ES", 1: "GD", 2: "NE", 3: "GA"}

    reads = O - 1
    reads_nz = sorted(reads[reads != 0])
    size = len(reads_nz)
    # Mean of the non-zero reads after dropping the top 5%.
    mean_r = numpy.average(reads_nz[:int(0.95 * size)])
    mu = numpy.array([1/0.99, 0.01 * mean_r + 2, mean_r, mean_r*5.0])
    L = 1.0/mu

    # Emission probability distributions: one geometric pmf per state.
    B = [scipy.stats.geom(L[i]).pmf for i in range(Nstates)]

    pins = self.calculate_pins(reads)
    pins_obs = sum([1 for rd in O if rd >= 2])/float(len(O))
    pnon = 1.0 - pins

    # Smallest run length r (capped at 99) for which a run of r empty sites
    # has probability below 0.01.
    for r in range(100):
        if pnon ** r < 0.01:
            break

    # Log transition matrix. p_empty is the NE-state probability of an empty
    # site; self-transitions get log(1 - p_empty^r), all others share
    # log(p_empty^r / 3).
    p_empty = B[Nstates // 2](1)
    A = numpy.zeros((Nstates, Nstates))
    a = math.log1p(-(p_empty ** r))
    b = r*math.log(p_empty) + math.log(1.0/3)  # change to Nstates-1?
    for i in range(Nstates):
        A[i] = [b]*Nstates
        A[i][i] = a

    # Initial state distribution
    PI = numpy.zeros(Nstates)
    PI[0] = 0.7
    PI[1:] = 0.3/(Nstates-1)

    self.progress_range(self.maxiterations)

    # Viterbi works on the log matrix; the alpha/beta passes need
    # probabilities.
    (Q_opt, delta, Q) = self.viterbi(A, B, PI, O)
    (log_Prob_Obs, alpha, C) = self.forward_procedure(numpy.exp(A), B, PI, O)
    beta = self.backward_procedure(numpy.exp(A), B, PI, O, C)

    T = len(O)
    total = 0
    state2count = dict.fromkeys(range(Nstates), 0)
    for t in range(T):
        state2count[Q_opt[t]] += 1
        total += 1

    self.output.write("#HMM - Sites\n")
    self.output.write("# Tn-HMM\n")

    if self.wxobj:
        self.output.write("#GUI with: ctrldata=%s, annotation=%s, output=%s\n" % (",".join(self.ctrldata).encode('utf-8'), self.annotation_path.encode('utf-8'), self.output.name.encode('utf-8')))
    else:
        self.output.write("#Console: python %s\n" % " ".join(sys.argv))

    self.output.write("# \n")
    self.output.write("# Mean:\t%2.2f\n" % (numpy.average(reads_nz)))
    self.output.write("# Median:\t%2.2f\n" % numpy.median(reads_nz))
    self.output.write("# pins (obs):\t%f\n" % pins_obs)
    self.output.write("# pins (est):\t%f\n" % pins)
    self.output.write("# Run length (r):\t%d\n" % r)
    self.output.write("# State means:\n")
    self.output.write("# %s\n" % " ".join(["%s: %8.4f" % (label[i], mu[i]) for i in range(Nstates)]))
    self.output.write("# Self-Transition Prob:\n")
    self.output.write("# %s\n" % " ".join(["%s: %2.4e" % (label[i], A[i][i]) for i in range(Nstates)]))
    self.output.write("# State Emission Parameters (theta):\n")
    self.output.write("# %s\n" % " ".join(["%s: %1.4f" % (label[i], L[i]) for i in range(Nstates)]))
    # The newline was missing here, which fused this header with the
    # percentages written on the next line.
    self.output.write("# State Distributions:\n")
    self.output.write("# %s\n" % " ".join(["%s: %2.2f%%" % (label[i], state2count[i]*100.0/total) for i in range(Nstates)]))

    states = [int(Q_opt[t]) for t in range(T)]

    for t in range(T):
        s_lab = label.get(states[t], "Unknown State")
        # Posterior state probabilities at site t.
        joint = alpha[:, t] * beta[:, t]
        gamma_t = joint/numpy.sum(joint)
        genes_at_site = pos2genes.get(position[t], [""])
        genestr = ""
        if not (len(genes_at_site) == 1 and not genes_at_site[0]):
            genestr = ",".join(["%s_(%s)" % (g, rv2info.get(g, "-")[0]) for g in genes_at_site])

        self.output.write("%s\t%s\t%s\t%s\t%s\n" % (int(position[t]), int(O[t])-1, "\t".join(["%-9.2e" % g for g in gamma_t]), s_lab, genestr))

    self.output.close()

    self.transit_message("")  # Printing empty line to flush stdout
    self.transit_message("Finished HMM - Sites Method")
    self.transit_message("Adding File: %s" % (self.output.name))
    self.add_file(filetype="HMM - Sites")

    # Gene Files
    self.transit_message("Creating HMM Genes Level Output")
    # splitext keeps extension-less names and dotted directory names intact,
    # which splitting on "." did not.
    (root, ext) = os.path.splitext(self.output.name)
    genes_path = root + "_genes" + ext

    tempObs = numpy.zeros((1, len(O)))
    tempObs[0, :] = O - 1
    self.post_process_genes(tempObs, position, states, genes_path)

    self.transit_message("Adding File: %s" % (genes_path))
    self.add_file(path=genes_path, filetype="HMM - Genes")
    self.finish()
    self.transit_message("Finished HMM Method")
@classmethod
def usage_string(self):
    """Return the command-line help text for the ``hmm`` method.

    The name of the running script (``sys.argv[0]``) is substituted into the
    first line.
    """
    template = """python %s hmm <comma-separated .wig files> <annotation .prot_table or GFF3> <output file>
Optional Arguments:
-r <string> := How to handle replicates. Sum, Mean. Default: -r Mean
-l := Perform LOESS Correction; Helps remove possible genomic position bias. Default: Off.
-iN <float> := Ignore TAs occuring at given fraction of the N terminus. Default: -iN 0.0
-iC <float> := Ignore TAs occuring at given fraction of the C terminus. Default: -iC 0.0
"""
    script_name = sys.argv[0]
    return template % (script_name,)
def forward_procedure(self, A, B, PI, O):
    """Scaled forward (alpha) pass of the HMM.

    Args:
        A: Transition matrix as probabilities (not logs).
        B: One emission function per state; ``B[i](o)`` gives the
            probability of observation ``o`` in state ``i``.
        PI: Initial state distribution (numpy array).
        O: Sequence of observations.

    Returns:
        Tuple ``(log_Prob_Obs, alpha, C)``: the log-probability of the
        observations, the scaled alpha matrix (states x sites) and the
        per-site scaling factors.
    """
    num_states = len(B)
    num_sites = len(O)
    alpha = numpy.zeros((num_states, num_sites))
    C = numpy.zeros(num_sites)

    # Initialisation at the first site, rescaled so the column sums to one.
    alpha[:, 0] = PI * [emit(O[0]) for emit in B]
    C[0] = 1.0 / numpy.sum(alpha[:, 0])
    alpha[:, 0] = C[0] * alpha[:, 0]

    for t in range(1, num_sites):
        emissions = [emit(O[t]) for emit in B]
        alpha[:, t] = numpy.dot(alpha[:, t - 1], A) * emissions

        C[t] = numpy.nan_to_num(1.0 / numpy.sum(alpha[:, t]))
        alpha[:, t] = numpy.nan_to_num(alpha[:, t] * C[t])

        # Keep a tiny non-zero mass when the column underflows to zero.
        if numpy.sum(alpha[:, t]) == 0:
            alpha[:, t] = 0.0000000000001

        percent = 100.0 * self.count / self.maxiterations
        self.progress_update("Running HMM Method... %1.1f%%" % percent, self.count)
        self.count += 1

    log_Prob_Obs = -(numpy.sum(numpy.log(C)))
    return (log_Prob_Obs, alpha, C)
def backward_procedure(self, A, B, PI, O, C=numpy.array([])):
    """Backward (beta) pass of the HMM, optionally scaled.

    Args:
        A: Transition matrix as probabilities (not logs).
        B: One emission function per state; ``B[i](o)`` gives the
            probability of observation ``o`` in state ``i``.
        PI: Initial state distribution (not used by this pass).
        O: Sequence of observations.
        C: Per-site scaling factors from the forward pass. Scaling is
            applied only when the array has at least one non-zero entry.

    Returns:
        The beta matrix (states x sites).
    """
    num_states = len(B)
    num_sites = len(O)
    final = num_sites - 1
    # C is never modified here, so the check is done once.
    apply_scaling = C.any()

    beta = numpy.zeros((num_states, num_sites))
    beta[:, final] = 1.0
    if apply_scaling:
        beta[:, final] = beta[:, final] * C[final]

    for t in reversed(range(final)):
        emissions = [emit(O[t]) for emit in B]
        beta[:, t] = numpy.nan_to_num(numpy.dot(A, (emissions * beta[:, t + 1])))

        # Keep a tiny non-zero mass when the column underflows to zero.
        if sum(beta[:, t]) == 0:
            beta[:, t] = 0.0000000000001

        if apply_scaling:
            beta[:, t] = beta[:, t] * C[t]

        percent = 100.0 * self.count / self.maxiterations
        self.progress_update("Running HMM Method... %1.1f%%" % percent, self.count)
        self.count += 1

    return beta
def viterbi(self, A, B, PI, O):
    """Find the most likely state path with the Viterbi algorithm.

    Args:
        A: Transition matrix in log space; it is added to the running
            log-scores without further transformation.
        B: One emission function per state; ``B[i](o)`` gives the
            probability of observation ``o`` in state ``i``.
        PI: Initial state distribution (probabilities).
        O: Sequence of observations.

    Returns:
        Tuple ``(Q_opt, delta, Q)``: the decoded state path as a list with
        one entry per site, the matrix of best log-scores (states x sites)
        and the back-pointer matrix (states x sites).
    """
    N = len(B)
    T = len(O)
    delta = numpy.zeros((N, T))
    Q = numpy.zeros((N, T), dtype=int)

    # log(0) is expected for impossible emissions. Silence the divide
    # warning only inside this block, so the caller's numpy error settings
    # are restored afterwards, even if an exception is raised.
    with numpy.errstate(divide='ignore'):
        b_o = [B[i](O[0]) for i in range(N)]
        delta[:, 0] = numpy.log(PI) + numpy.log(b_o)

        for t in range(1, T):
            b_o = [B[i](O[t]) for i in range(N)]
            nus = delta[:, t-1] + A
            delta[:, t] = nus.max(1) + numpy.log(b_o)
            Q[:, t] = nus.argmax(1)
            text = "Running HMM Method... %5.1f%%" % (100.0*self.count/self.maxiterations)
            self.progress_update(text, self.count)
            self.count += 1

    # Backtrace from the best final state. The path is collected in reverse
    # and flipped once at the end, which avoids a quadratic number of
    # front insertions.
    reversed_path = [int(numpy.argmax(delta[:, T-1]))]
    for t in range(T-2, -1, -1):
        reversed_path.append(Q[reversed_path[-1], t+1])
        text = "Running HMM Method... %5.1f%%" % (100.0*self.count/self.maxiterations)
        self.progress_update(text, self.count)
        self.count += 1
    reversed_path.reverse()
    Q_opt = reversed_path

    text = "Running HMM Method... %5.1f%%" % (100.0*self.count/self.maxiterations)
    self.progress_update(text, self.count)

    return (Q_opt, delta, Q)
def calculate_pins(self, reads):
    """Estimate the insertion density outside long runs of empty sites.

    Walks the read counts in order. A run of empty sites (count below 1) is
    kept only when it is shorter than 10 sites and is followed by a site
    with reads; longer runs, and any trailing run, are left out. The result
    is the fraction of kept sites that have at least one read.

    Args:
        reads: Iterable of per-site read counts.

    Returns:
        The estimated insertion density as a float, or 0.0 when no site has
        any reads (this case used to raise ZeroDivisionError).
    """
    non_ess_reads = []
    empty_run = []
    for rd in reads:
        if rd >= 1:
            if len(empty_run) < 10:
                non_ess_reads.extend(empty_run)
            non_ess_reads.append(rd)
            empty_run = []
        else:
            empty_run.append(rd)

    if not non_ess_reads:
        return 0.0

    occupied = sum(1 for rd in non_ess_reads if rd >= 1)
    return occupied / float(len(non_ess_reads))
def post_process_genes(self, data, position, states, output_path):
    """Summarise the per-site states per gene and write the genes file.

    For every gene, counts the sites assigned to each state and makes a
    call: "ES" when all sites are ES or when the ES count reaches the
    expected run length plus three standard deviations; otherwise the most
    frequent state (ties go to the higher state index); "N/A" for genes
    without sites.

    Args:
        data: Read counts as a 2-D array (datasets x sites).
        position: Site coordinates, in the same order as ``states``.
        states: Decoded state index (0-3) for every site.
        output_path: Path of the genes file to write.
    """
    num2label = {0: "ES", 1: "GD", 2: "NE", 3: "GA"}

    # The context manager closes the file even if a gene fails to process.
    with open(output_path, "w") as output:
        pos2state = dict((position[t], state) for t, state in enumerate(states))
        # Fraction of entries in data that have at least one read.
        theta = numpy.mean(data > 0)
        G = tnseq_tools.Genes(self.ctrldata, self.annotation_path, data=data, position=position, ignoreCodon=False)

        output.write("#HMM - Genes\n")
        for gene in G:
            reads_nz = [c for c in gene.reads.flatten() if c > 0]
            avg_read_nz = 0
            if len(reads_nz) > 0:
                avg_read_nz = numpy.average(reads_nz)

            # Number of the gene's sites assigned to each state.
            statedist = {}
            for p in gene.position:
                st = pos2state[p]
                statedist[st] = statedist.get(st, 0) + 1
            counts = [statedist.get(s, 0) for s in (0, 1, 2, 3)]
            (n0, n1, n2, n3) = counts

            if gene.n > 0:
                E = tnseq_tools.ExpectedRuns(gene.n, 1.0 - theta)
                V = tnseq_tools.VarR(gene.n, 1.0 - theta)
                if n0 == gene.n:
                    S = "ES"
                elif n0 >= int(E+(3*math.sqrt(V))):
                    S = "ES"
                else:
                    # Most frequent state; ties resolve to the larger index.
                    S = num2label[max((count, s) for s, count in enumerate(counts))[1]]
            else:
                S = "N/A"

            output.write("%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%1.4f\t%1.2f\t%s\n" % (gene.orf, gene.name, gene.desc, gene.n, n0, n1, n2, n3, gene.theta(), avg_read_nz, S))
if __name__ == "__main__":
    # Script entry point: build the method from the command line, print its
    # member variables, then run it.
    (args, kwargs) = transit_tools.cleanargs(sys.argv)
    hmm_method = HMMMethod.fromargs(sys.argv[1:])

    hmm_method.console_message("Printing the member variables:")
    hmm_method.print_members()

    print("")
    print("Running:")

    hmm_method.Run()