# main.py
import json
import logging
import os
import sqlite3

import paho.mqtt.client as mqtt
  6. DB_FILE = os.getenv('DB_FILE', 'agent-summary.db3')
  7. MQTT_HOST = os.getenv('MQTT_HOST', 'iris.dgtlu.net')
  8. MQTT_TOPIC = os.getenv('MQTT_TOPIC','agent/summary')
  9. MQTT_USER = os.getenv('MQTT_USER','agent')
  10. MQTT_PASS = os.getenv('MQTT_PASS','agent')
  11. logger = logging.get("__name__")
  12. debug = os.getenv("DEBUG", "")
  13. if debug:
  14. logger.setLevel(logging.DEBUG)
  15. else
  16. logger.setLevel(logging.INFO)
  17. logger.addHandler(logging.StreamHandler())
  18. def db_init():
  19. conn = sqlite3.connect(DB_FILE)
  20. c = conn.cursor()
  21. c.execute('''CREATE TABLE IF NOT EXISTS AgentSummary (
  22. camera_id int,
  23. dt datetime default current_timestamp,
  24. file char(500),
  25. model char(100),
  26. summary text,
  27. primary key (dt, file)
  28. )''')
  29. conn.commit()
  30. return conn
  31. def db_save_input(conn: sqlite3.Connection, json_string):
  32. dt = json.loads(json_string)
  33. c = conn.cursor()
  34. c.execute('INSERT INTO AgentSummary (camera_id, file, model, summary) VALUES (?, ?, ?, ?)',
  35. (dt["id"], dt["file"], dt["data"]["model"], dt["data"]["choices"][0]["message"]["content"].strip()))
  36. conn.commit()
  37. def mqtt_init() -> mqtt:
  38. def on_connect(mqt_client, userdata, flags, rc, properties):
  39. if rc == 0:
  40. logger.info("Connected to MQTT Broker!")
  41. else:
  42. logger.error("Failed to connect, return code %d\n", rc)
  43. client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
  44. client.on_connect = on_connect
  45. client.username_pw_set(username=MQTT_USER, password=MQTT_PASS)
  46. client.connect(host=MQTT_HOST, keepalive=60)
  47. return client
  48. def mqtt_subscribe(client: mqtt, db_conn: sqlite3.Connection):
  49. def on_message(mqt_client, userdata, msg):
  50. logger.debug(f"Received `{msg.payload.decode()}` from `{msg.topic}` topic")
  51. db_save_input(db_conn, msg.payload.decode())
  52. client.subscribe(MQTT_TOPIC)
  53. client.on_message = on_message
  54. def main():
  55. db_conn = db_init()
  56. mqtt_client = mqtt_init()
  57. mqtt_subscribe(mqtt_client, db_conn)
  58. mqtt_client.loop_forever()
  59. db_conn.close()
  60. if __name__ == '__main__':
  61. main()