Browse Source

[var_type] added the models, the database base methods, and the tests

Maxime Alves LIRMM@home 3 years ago
parent
commit
4c0576377b

BIN
db/pyheatpump.db View File


+ 45
- 3
pyheatpump/db.py View File

@@ -1,9 +1,51 @@
1 1
 #!/usr/bin/env python3
2 2
 import sqlite3
3
-from ..conf import config
3
+from subprocess import Popen
4
+from .conf import config
5
+import sys
6
+
7
+conn = None
8
+
9
+def connect():
10
+    global conn
11
+    if conn is None:
12
+        print('Will connect to database {}'.format(
13
+            config['heatpump']['database']))
14
+        conn = sqlite3.connect(config['heatpump']['database'])
15
+        conn.row_factory = sqlite3.Row
16
+    return conn
17
+
18
+def initialize(filename):
19
+    p = Popen(
20
+        '/usr/bin/env sqlite3 -init {} {}'.format(filename, config['heatpump']['database']),
21
+        shell=True
22
+    )
23
+    return True if p.wait() == 0 else False
4 24
 
5
-conn = sqlite3.connect(config['heatpump']['database'])
6 25
 
7 26
 def sql(query):
8 27
     global conn
9
-    return conn.execute(query)
28
+    if conn is None:
29
+        connect()
30
+
31
+    cursor = conn.cursor()
32
+
33
+    print(f'Will execute query : \n{query}\n')
34
+    cursor.execute(query)
35
+
36
+    return cursor
37
+
38
+class RowClass(object):
39
+    def __init__(self, **kwargs):
40
+        for key in kwargs.keys():
41
+            if hasattr(self, key):
42
+                setattr(self, key, kwargs[key])
43
+
44
+    def select(self, key, tablename):
45
+        attr = getattr(self, key)
46
+        if type(attr) == str:
47
+            q = f"SELECT * FROM {tablename} WHERE {key} LIKE '{attr}'"
48
+        elif type(attr) == int:
49
+            q = f"SELECT * FROM {tablename} WHERE {key} = {attr}"
50
+
51
+        return sql(q)

+ 0
- 0
pyheatpump/models/__init__.py View File


+ 18
- 0
pyheatpump/models/variable.py View File

@@ -0,0 +1,18 @@
1
+from pyheatpump.db import RowClass
2
+from pyheatpump.db import sql
3
+
4
+class VariableType(RowClass):
5
+    slabel: str = None
6
+    label: str = None
7
+    type: str = None
8
+    start_address: int = None
9
+    end_address: int = None
10
+ 
11
+    def __init__(self, **kwargs):
12
+        super().__init__(**kwargs)
13
+
14
+    @staticmethod
15
+    def getall():
16
+        return dict([
17
+            (row['label'], VariableType(**dict(row)))
18
+            for row in sql('SELECT * FROM var_type') ])

+ 42
- 0
pyheatpump/models/variable_type.py View File

@@ -0,0 +1,42 @@
1
+from pyheatpump.db import RowClass
2
+from pyheatpump.db import sql
3
+
4
+class VariableType(RowClass):
5
+    slabel: str = None
6
+    label: str = None
7
+    type: str = None
8
+    start_address: int = None
9
+    end_address: int = None
10
+ 
11
+    def __init__(self, **kwargs):
12
+        super().__init__(**kwargs)
13
+
14
+        if self.slabel is None and self.label is not None:
15
+            self.slabel = self.label[0]
16
+
17
+
18
+    def select():
19
+        try:
20
+            elt = next(super().select('slabel', 'var_type'))
21
+        except StopIteration:
22
+            print('No element exists')
23
+
24
+    def save(self):
25
+        q = ['UPDATE var_type SET']
26
+        updates = []
27
+        if self.start_address is not None:
28
+            updates.append(f'start_address = {self.start_address}')
29
+        if self.end_address is not None:
30
+            updates.append(f'end_address = {self.end_address}')
31
+        if len(updates) == 0:
32
+            return
33
+        q.append(','.join(updates))
34
+        q.append(f"WHERE slabel LIKE '{self.slabel}'")
35
+
36
+        return sql(' '.join(q))
37
+
38
+    @staticmethod
39
+    def getall():
40
+        return dict([
41
+            (row['label'], VariableType(**dict(row)))
42
+            for row in sql('SELECT * FROM var_type') ])

+ 18
- 0
pyheatpump/models/variable_value.py View File

@@ -0,0 +1,18 @@
1
+from pyheatpump.db import RowClass
2
+from pyheatpump.db import sql
3
+
4
+class VariableType(RowClass):
5
+    slabel: str = None
6
+    label: str = None
7
+    type: str = None
8
+    start_address: int = None
9
+    end_address: int = None
10
+ 
11
+    def __init__(self, **kwargs):
12
+        super().__init__(**kwargs)
13
+
14
+    @staticmethod
15
+    def getall():
16
+        return dict([
17
+            (row['label'], VariableType(**dict(row)))
18
+            for row in sql('SELECT * FROM var_type') ])

+ 54
- 0
pyheatpump/variable_types.py View File

@@ -0,0 +1,54 @@
1
+#!/usr/bin/env python3
2
+import os
3
+from datetime import datetime
4
+from configparser import ConfigParser
5
+from starlette.routing import Route, Router
6
+from starlette.responses import PlainTextResponse, JSONResponse
7
+from pprint import pprint
8
+import uvicorn
9
+import json
10
+
11
+# pyHeatpump modules
12
+from pyheatpump.db import sql
13
+from pyheatpump.models.variable_type import VariableType
14
+
15
+def variable_types():
16
+    assert type(VariableType.getall()) == list
17
+
18
+async def get_variable_types(request):
19
+    return JSONResponse(dict([
20
+        (key, val.__dict__)
21
+        for key, val in VariableType.getall().items()
22
+    ]))
23
+
24
+async def set_variable_types(request):
25
+
26
+    body = json.loads(await request.json())
27
+
28
+    for var_type_label, var_type_values in body.items():
29
+        vt = VariableType(label=var_type_label)
30
+        for key, val in var_type_values.items():
31
+            if key in [ 'start_address', 'end_address' ]:
32
+                setattr(vt, key, val)
33
+
34
+        vt.save()
35
+
36
+    return PlainTextResponse('OK')
37
+
38
+async def variable_types_routes(request, *args, **kwargs):
39
+    if request['method'] == 'GET':
40
+        return await get_variable_types(request)
41
+    elif request['method'] == 'POST':
42
+        return await set_variable_types(request)
43
+
44
+
45
+ROUTES=[
46
+    Route('/', variable_types_routes, methods=['GET', 'POST'])
47
+]
48
+
49
+app = Router(routes=ROUTES)
50
+
51
+if __name__ == '__main__':
52
+    uvicorn.run('pyHeatpump:conf.app',
53
+        host='127.0.0.1',
54
+        port=8000)

+ 67
- 0
tests/test_variable_types.py View File

@@ -0,0 +1,67 @@
1
+#!/usr/bin/env python3
2
+import pytest
3
+from starlette.authentication import UnauthenticatedUser
4
+from starlette.testclient import TestClient
5
+#from pyheatpump.conf import app, config, default_config, CONFIG_FILES, get_config, set_config, config_route, ROUTES
6
+from unittest.mock import patch, MagicMock
7
+from pprint import pprint
8
+import json
9
+from tempfile import mkstemp
10
+from configparser import ConfigParser
11
+import os
12
+import sys
13
+
14
+from pyheatpump.conf import config
15
+from pyheatpump.db import initialize, connect
16
+from pyheatpump.variable_types import app, get_variable_types, set_variable_types, ROUTES
17
+
18
+@pytest.fixture(scope='module')
19
+def set_test_db():
20
+    _, tmpdb = mkstemp(suffix='.db', dir=os.getcwd(), )
21
+    print(f'Will store database in {tmpdb}')
22
+    config['heatpump']['database'] = tmpdb 
23
+    if not initialize(os.path.join(os.getcwd(), 'db/pyheatpump.sql')):
24
+        sys.exit(-1)
25
+
26
+    yield
27
+
28
+    os.unlink(tmpdb)
29
+
30
+
31
+def test_get_(set_test_db):
32
+    c = TestClient(app)
33
+    r = c.get('/')
34
+    assert r.status_code == 200
35
+
36
+class RequestMock(MagicMock):
37
+    def __get__(self, key):
38
+        if key == 'method':
39
+            return 'GET'
40
+
41
+@pytest.mark.asyncio
42
+async def test_get_variable_types(set_test_db):
43
+    resp = await get_variable_types(RequestMock())
44
+    assert resp.status_code  == 200
45
+    d_resp = json.loads(resp.body.decode())
46
+    assert 'Analog' in d_resp.keys()
47
+    assert type(d_resp['Analog']) == dict
48
+    assert 'Integer' in d_resp.keys()
49
+    assert type(d_resp['Integer']) == dict
50
+    assert 'Digital' in d_resp.keys()
51
+    assert type(d_resp['Digital']) == dict
52
+
53
+def test_set_variable_types(set_test_db):
54
+    c = TestClient(app)
55
+    r = c.post('/', json=json.dumps({
56
+        'Analog': {
57
+            'start_address': 42,
58
+            'end_address': 420,
59
+        }
60
+    }))
61
+
62
+    assert r.status_code == 200
63
+
64
+    r = c.get('/')
65
+    d_resp = json.loads(r.content.decode())
66
+    assert d_resp['Analog']['start_address'] == 42
67
+    assert d_resp['Analog']['end_address'] == 420

Loading…
Cancel
Save