import json
import pandas as pd
import string
from lemminflect import getLemma
from sklearn.feature_extraction.text import CountVectorizer
#data2=pd.read_csv('rel2.csv')
# Lookup tables loaded once at import time; filter() reads them as module globals.

# Entity table: filter() selects rows by the 'entity_id' column and reads the
# 'industry' and 'alias' columns (the 'alias' cell is passed to eval()).
data=pd.read_csv('data.csv')
# filter() picks the column named after the entity's 'industry' value and reads
# it up to the first NaN cell -- presumably one NaN-padded column of terms per
# industry; confirm against the CSV.
ibowd=pd.read_csv('ibow_pre.csv')
# filter() walks every column of this table up to the first NaN cell; each cell
# is passed to eval() and indexed as (term, flag). Column names (with spaces
# replaced by underscores) become the tag names in filter()'s result.
bowd=pd.read_csv('bow_pre.csv')

def _truncate_at(values, sentinel):
    """Return the items of *values* that precede the first sentinel.

    An item is a sentinel when ``str(item) == sentinel``. Every item is kept
    when no sentinel is present, and an empty input yields an empty list.
    """
    kept = []
    for item in values:
        if str(item) == sentinel:
            break
        kept.append(item)
    return kept


def _tally(candidates):
    """Count substring hits for an iterable of ``(label, needle, haystack)``.

    Returns ``{'value': [...], 'count': [...], 'total': int}`` where 'value'
    holds the labels whose needle occurs in its haystack, 'count' the matching
    ``str.count`` results, and 'total' the sum of those counts.

    NOTE: ``str.count`` counts non-overlapping matches, so two adjacent
    occurrences that share their separating space are counted once.
    """
    labels = []
    counts = []
    for label, needle, haystack in candidates:
        hits = haystack.count(needle)
        if hits:
            labels.append(label)
            counts.append(hits)
    return {'value': labels, 'count': counts, 'total': sum(counts)}


def filter(body, entity_ids):
    """Score *body* against one entity's aliases, industry terms and tag terms.

    Note: the name shadows the ``filter`` builtin; it is kept because callers
    import it under this name.

    Args:
        body: Text to analyse; it is tokenised with scikit-learn's default
            ``CountVectorizer`` analyzer (accents stripped, case preserved).
        entity_ids: A single entity id, compared against ``data['entity_id']``
            (despite the plural name, one id is expected).

    Returns:
        ``False`` when no alias matches and fewer than two industry terms
        match. Otherwise a 7-item list:

        0. relevance flag: industry-term total >= 2, or a case-sensitive
           alias match;
        1. case-insensitive alias match flag;
        2. case-sensitive alias match flag;
        3. case-sensitive alias match AND at least one industry-term hit;
        4. list of tag names (``bowd`` column names, spaces -> underscores)
           with at least one hit;
        5. ``{'n_gram': ..., 'n_gram_s': ..., 'bow': ...}`` where each value
           is a ``{'value', 'count', 'total'}`` dict for the entity;
        6. ``{tag_name: {'value', 'count', 'total'}}`` for every tag column.

    Raises:
        IndexError: if no row of ``data`` has the given entity id.
    """
    rows = data[data['entity_id'] == entity_ids]
    if rows.empty:
        raise IndexError(f"no row in data for entity_id {entity_ids!r}")

    # Normalised views of the text. Each is wrapped in single spaces so that
    # whole tokens can be matched as ' term '.
    analyzer = CountVectorizer(strip_accents='ascii', lowercase=False).build_analyzer()
    bodystr = ' ' + ' '.join(analyzer(body)) + ' '
    body_lower = bodystr.lower()
    tokens = bodystr.replace('\n', ' ').replace('\r', ' ').replace('  ', ' ').split(' ')
    lembody = " " + " ".join(getLemma(w, 'NOUN')[-1] for w in tokens) + " "
    lembody_lower = lembody.lower()

    industry = list(rows['industry'])[0]
    # NOTE(security): eval() executes whatever the CSV cell contains. Only
    # safe while data.csv is trusted; consider ast.literal_eval if the cells
    # are plain literals.
    alias = eval(list(rows['alias'])[0])

    # Both sequences are read up to their first padding cell. Slicing at the
    # loop index instead dropped the last item whenever no padding was present.
    alias = _truncate_at(alias, '')
    ibow = _truncate_at(list(ibowd[industry]), 'nan')

    n_gram = _tally((a.lower(), ' ' + a.lower() + ' ', body_lower) for a in alias)
    n_gram_s = _tally((a, ' ' + a + ' ', bodystr) for a in alias)
    bow = _tally((term, ' ' + term + ' ', lembody) for term in ibow)

    ngram_flag = n_gram['total'] > 0
    ngrams_flag = n_gram_s['total'] > 0
    # A case-sensitive alias hit makes the text relevant on its own.
    bow_flag = bow['total'] >= 2 or ngrams_flag
    bows_flag = bow['total'] >= 1

    if not (ngram_flag or ngrams_flag or bow_flag):
        return False

    entity_stats = {'n_gram': n_gram, 'n_gram_s': n_gram_s, 'bow': bow}

    tags_array = []
    tag_stats = {}
    for column in list(bowd.keys()):
        tag = column.replace(' ', '_')
        candidates = []
        for raw in _truncate_at(list(bowd[column]), 'nan'):
            # NOTE(security): same eval() caveat as above, for bow_pre.csv.
            entry = eval(raw)
            term = entry[0]
            if entry[1] == 1:
                # Flag 1: match against the lemmatised, lower-cased text.
                # Multi-word terms are matched without a trailing space, so
                # they also match as a prefix of a longer token.
                # NOTE(review): looks deliberate, but confirm.
                needle = ' ' + term if ' ' in term else ' ' + term + ' '
                candidates.append((term, needle, lembody_lower))
            else:
                candidates.append((term, ' ' + term + ' ', body_lower))
        stats = _tally(candidates)
        if stats['total'] > 0:
            tags_array.append(tag)
        tag_stats[tag] = stats

    return [
        bow_flag,
        ngram_flag,
        ngrams_flag,
        ngrams_flag and bows_flag,
        tags_array,
        entity_stats,
        tag_stats,
    ]
    

    
