tableQA
1.0.0
AI工具用于在表格数据上查询自然语言。使用变压器的QA模型构建。
这项工作在以下论文中描述:
表格:用自然语言查询表格数据,作者:Abhijith Neil Abraham,Fariz Rahman和Damanpreet Kaur。
如果您使用TableQA,请引用纸张。
这是一个详细的博客,可以理解它的工作原理。
表格数据可以是:
。
。
pip install tableqa
git clone https://github.com/abhijithneilabraham/tableQA
cd tableqa
python setup.py install
from tableqa.agent import Agent
agent=Agent(df) #input your dataframe
response=agent.query_db("Your question here")
print(response)
sql=agent.get_query("Your question here")
print(sql) #returns an sql query
{
"name": DATABASE NAME,
"keywords":[DATABASE KEYWORDS],
"columns":
[
{
"name": COLUMN 1 NAME,
"mapping":{
CATEGORY 1: [CATEGORY 1 KEYWORDS],
CATEGORY 2: [CATEGORY 2 KEYWORDS]
}
},
{
"name": COLUMN 2 NAME,
"keywords": [COLUMN 2 KEYWORDS]
},
{
"name": "COLUMN 3 NAME",
"keywords": [COLUMN 3 KEYWORDS],
"summable":"True"
}
]
}
summable ,其值已经是计数表示。例如。 Death Count,Cases等是已经代表计数的值。示例(使用手动模式):
from tableqa.agent import Agent
agent=Agent(df,schema) #pass the dataframe and schema objects
response=agent.query_db("how many people died of stomach cancer in 2011")
print(response)
#Response =[(22,)]
from tableqa.agent import Agent
agent = Agent(df, schema_file, 'postgres', username='username', password='password', database='DBname', host='localhost', port=5432, aws_db=False)
response=agent.query_db("how many people died of stomach cancer in 2011")
print(response)
#Response =[(22,)]
from tableqa.agent import Agent
agent = Agent(df, schema_file, 'mysql', username='username', password='password', database='DBname', host='localhost', port=5432, aws_db=False)
response=agent.query_db("how many people died of stomach cancer in 2011")
print(response)
#Response =[(22,)]
请参阅文档中的步骤1,以在Amazon RDS上创建MySQL DB实例。可以通过在“引擎”选项卡中选择PostgreSQL来遵循相同的步骤来创建PostgreSQL DB实例。从Amazon RDS上的数据库连接详细信息中获取用户名,密码,数据库,端点和端口。
from tableqa.agent import Agent
agent = Agent(df, schema_file, 'postgres', username='Master username', password='Master password', database='DB name', host='Endpoint', port='Port', aws_db=True)
response=agent.query_db("how many people died of stomach cancer in 2011")
print(response)
#Response =[(22,)]
sql=agent.get_query("How many people died of stomach cancer in 2011")
print(sql)
#sql query: SELECT SUM(Death_Count) FROM cancer_death WHERE Cancer_site = "Stomach" AND Year = "2011"
csv_path="/content/tableQA/tableqa/cleaned_data"
schema_path="/content/tableQA/tableqa/schema"
agent=Agent(csv_path,schema_path)
csv_path="s3://{bucket}/cleaned_data"
schema_path="s3://{bucket}/schema"
agent = Agent(csv_path, schema_path, aws_s3=True, access_key_id=access_key_id, secret_access_key=secret_access_key)
加入我们的工作区:懈怠