tableQA
1.0.0
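An AI tool for querying tabular data in natural language. Built using QA models from transformers.
This work is described in the following paper:
TableQuery: Querying tabular data with natural language, by Abhijith Neil Abraham, Fariz Rahman and Damanpreet Kaur.
If you use TableQA, please cite the paper.
Here is a detailed blog post that explains how it works.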
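Tabular data can be supplied as a pandas dataframe or as a path to CSV data (local or on Amazon S3), as in the examples that follow.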
pip install tableqa
git clone https://github.com/abhijithneilabraham/tableQA
cd tableQA
python setup.py install
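import pandas as pd
from tableqa.agent import Agent

df = pd.read_csv("your_data.csv")  # hypothetical file name; load your own dataframe
agent = Agent(df)  # pass your dataframe
response = agent.query_db("Your question here")
print(response)
sql = agent.get_query("Your question here")
print(sql)  # returns an SQL query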
{
    "name": DATABASE NAME,
    "keywords": [DATABASE KEYWORDS],
    "columns":
    [
        {
            "name": COLUMN 1 NAME,
            "mapping": {
                CATEGORY 1: [CATEGORY 1 KEYWORDS],
                CATEGORY 2: [CATEGORY 2 KEYWORDS]
            }
        },
        {
            "name": COLUMN 2 NAME,
            "keywords": [COLUMN 2 KEYWORDS]
        },
        {
            "name": COLUMN 3 NAME,
            "keywords": [COLUMN 3 KEYWORDS],
            "summable": "True"
        }
    ]
}
summable marks a column whose values already represent counts, e.g. Death Count or Cases. Example (using manual mode):
from tableqa.agent import Agent
agent=Agent(df,schema) #pass the dataframe and schema objects
response=agent.query_db("how many people died of stomach cancer in 2011")
print(response)
#Response =[(22,)]
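For reference, a filled-in schema for the stomach cancer question might look like the sketch below. The table and column names are taken from the SQL query shown in this README; every keyword list and the Lung category are illustrative assumptions, and summable_columns is a hypothetical helper, not part of tableqa.

```python
import json

# Hypothetical schema for a cancer_death table. Column names follow the SQL
# query shown in this README; all keyword lists are illustrative guesses.
schema = {
    "name": "cancer_death",
    "keywords": ["cancer", "death"],
    "columns": [
        {
            "name": "Cancer_site",
            # category value -> words in a question that should map to it
            "mapping": {
                "Stomach": ["stomach"],
                "Lung": ["lung", "lungs"],
            },
        },
        {"name": "Year", "keywords": ["year"]},
        {
            "name": "Death_Count",
            "keywords": ["died", "deaths"],
            "summable": "True",  # values are already counts
        },
    ],
}


def summable_columns(schema):
    """Return the names of columns flagged as summable (hypothetical helper)."""
    return [
        col["name"]
        for col in schema["columns"]
        if col.get("summable") == "True"
    ]


# Serialise the schema so it can be saved as a JSON schema file.
schema_json = json.dumps(schema, indent=2)
print(summable_columns(schema))
```

Keeping the schema as a plain dictionary makes it easy to check it in Python before saving it to a file.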
from tableqa.agent import Agent
agent = Agent(df, schema_file, 'postgres', username='username', password='password', database='DBname', host='localhost', port=5432, aws_db=False)
response=agent.query_db("how many people died of stomach cancer in 2011")
print(response)
#Response =[(22,)]
from tableqa.agent import Agent
agent = Agent(df, schema_file, 'mysql', username='username', password='password', database='DBname', host='localhost', port=3306, aws_db=False)  # 3306 is the MySQL default port
response=agent.query_db("how many people died of stomach cancer in 2011")
print(response)
#Response =[(22,)]
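Refer to step 1 in the documentation to create a MySQL DB instance on Amazon RDS. A PostgreSQL DB instance can be created by following the same steps and selecting PostgreSQL in the Engine tab. Obtain the username, password, database, endpoint and port from the database connection details on Amazon RDS.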
from tableqa.agent import Agent
agent = Agent(df, schema_file, 'postgres', username='Master username', password='Master password', database='DB name', host='Endpoint', port='Port', aws_db=True)
response=agent.query_db("how many people died of stomach cancer in 2011")
print(response)
#Response =[(22,)]
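To keep the RDS credentials out of source code, the connection details can be read from environment variables. This is a minimal sketch: the keyword names match the Agent call above, but rds_connection_kwargs and the TABLEQA_DB_* variable names are hypothetical and not part of tableqa.

```python
import os


def rds_connection_kwargs(env=os.environ):
    """Collect Agent keyword arguments for an Amazon RDS database.

    The TABLEQA_DB_* variable names are hypothetical; use whatever names
    your deployment defines.
    """
    return {
        "username": env["TABLEQA_DB_USER"],
        "password": env["TABLEQA_DB_PASSWORD"],
        "database": env["TABLEQA_DB_NAME"],
        "host": env["TABLEQA_DB_HOST"],  # the RDS endpoint
        "port": int(env.get("TABLEQA_DB_PORT", "5432")),  # PostgreSQL default
        "aws_db": True,
    }


# Usage (requires tableqa, a dataframe and a schema file):
# agent = Agent(df, schema_file, 'postgres', **rds_connection_kwargs())
```

A missing variable raises KeyError immediately, before any connection attempt is made.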
sql=agent.get_query("How many people died of stomach cancer in 2011")
print(sql)
#sql query: SELECT SUM(Death_Count) FROM cancer_death WHERE Cancer_site = "Stomach" AND Year = "2011"
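To see why the response has the shape [(22,)], the sketch below runs a query of the same form against an in-memory SQLite table. The rows are made up for illustration, and the literals are bound as parameters instead of being written inline. Because Death_Count already holds counts, the query sums the column rather than counting rows.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE cancer_death "
    "(Cancer_site TEXT, Year INTEGER, Death_Count INTEGER)"
)

# Illustrative rows only; these are not real statistics.
rows = [
    ("Stomach", 2011, 10),
    ("Stomach", 2011, 12),
    ("Stomach", 2012, 5),
    ("Lung", 2011, 30),
]
conn.executemany("INSERT INTO cancer_death VALUES (?, ?, ?)", rows)

# Same shape as the generated query, with the literals bound as parameters.
query = (
    "SELECT SUM(Death_Count) FROM cancer_death "
    "WHERE Cancer_site = ? AND Year = ?"
)
result = conn.execute(query, ("Stomach", 2011)).fetchall()
print(result)  # [(22,)]
conn.close()
```

The result is a list with one row, and that row is a one-element tuple, so the scalar answer is result[0][0].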
csv_path="/content/tableQA/tableqa/cleaned_data"
schema_path="/content/tableQA/tableqa/schema"
agent=Agent(csv_path,schema_path)
csv_path="s3://{bucket}/cleaned_data"
schema_path="s3://{bucket}/schema"
agent = Agent(csv_path, schema_path, aws_s3=True, access_key_id=access_key_id, secret_access_key=secret_access_key)
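The S3 paths and credentials can be assembled in one place, as in this sketch. The folder names and keyword names come from the example above, and AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are the standard AWS environment variable names; s3_agent_args itself is a hypothetical helper, not part of tableqa.

```python
import os


def s3_agent_args(bucket, env=os.environ):
    """Build positional and keyword arguments for an S3-backed Agent.

    Hypothetical helper; the folder names follow the example above.
    """
    paths = (f"s3://{bucket}/cleaned_data", f"s3://{bucket}/schema")
    kwargs = {
        "aws_s3": True,
        "access_key_id": env["AWS_ACCESS_KEY_ID"],
        "secret_access_key": env["AWS_SECRET_ACCESS_KEY"],
    }
    return paths, kwargs


# Usage (requires tableqa and real credentials):
# paths, kwargs = s3_agent_args("my-bucket")
# agent = Agent(*paths, **kwargs)
```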
Join our workspace: Slack