diff --git a/src/aki_prj23_transparenzregister/utils/postgres/connector.py b/src/aki_prj23_transparenzregister/utils/postgres/connector.py index 15bde3e..487a529 100644 --- a/src/aki_prj23_transparenzregister/utils/postgres/connector.py +++ b/src/aki_prj23_transparenzregister/utils/postgres/connector.py @@ -25,11 +25,16 @@ def get_engine(conn_args: PostgreConnectionString): return create_engine(url) -if __name__ == "__main__": - """Main flow creating tables""" +def init_db(): + """Initialize DB with all defined entities.""" config_provider = JsonFileConfigProvider("./secrets.json") engine = get_engine(config_provider.get_postgre_connection_string()) - with engine.connect() as connection: - Base = declarative_base() + with engine.connect(): + base = declarative_base() - Base.metadata.create_all(engine) + base.metadata.create_all(engine) + + +if __name__ == "__main__": + """Main flow creating tables""" + init_db() diff --git a/src/aki_prj23_transparenzregister/utils/postgres/entities.py b/src/aki_prj23_transparenzregister/utils/postgres/entities.py index e249f77..2729ec3 100644 --- a/src/aki_prj23_transparenzregister/utils/postgres/entities.py +++ b/src/aki_prj23_transparenzregister/utils/postgres/entities.py @@ -101,7 +101,9 @@ class Sentiment(Base): # type: ignore company_hr = Column(Integer) company_court = Column(Integer) date = Column(DateTime(), default=datetime.now) - sentiment_type = Column(Enum(SentimentTypeEnum), nullable=False) + sentiment_type = Column( + Enum(SentimentTypeEnum), nullable=False + ) # type: SentimentTypeEnum value = Column(Float(), nullable=False) source = Column(String(100)) @@ -150,7 +152,7 @@ class PersonRelation(Base): # type: ignore person_id = mapped_column(ForeignKey("person.id")) date_from = Column(DateTime(), default=datetime.now) date_to = Column(DateTime(), default=datetime.now) - relation = Column(Enum(RelationTypeEnum), nullable=False) + relation = Column(Enum(RelationTypeEnum), nullable=False) # type: RelationTypeEnum # company = relationship("Company") # person = relationship("Person", foreign_keys=[person_id]) @@ -182,7 +184,9 @@ class CompanyRelation(Base): # type: ignore company2_id = Column(Integer, nullable=False) date_from = Column(DateTime(), default=datetime.now) date_to = Column(DateTime(), default=datetime.now) - relation = Column(Enum(RelationTypeCompanyEnum), nullable=False) + relation = Column( + Enum(RelationTypeCompanyEnum), nullable=False + ) # type: RelationTypeCompanyEnum # company = relationship("Company") diff --git a/tests/utils/postgres/connector_test.py b/tests/utils/postgres/connector_test.py index 50a9c51..1e80985 100644 --- a/tests/utils/postgres/connector_test.py +++ b/tests/utils/postgres/connector_test.py @@ -1,7 +1,7 @@ -from unittest.mock import patch +from unittest.mock import Mock, patch from aki_prj23_transparenzregister.config.config_template import PostgreConnectionString -from aki_prj23_transparenzregister.utils.postgres.connector import get_engine +from aki_prj23_transparenzregister.utils.postgres.connector import get_engine, init_db def test_get_engine(): @@ -12,3 +12,19 @@ def test_get_engine(): result = "someThing" mock_create_engine.return_value = result assert get_engine(conn_args) == result + + +def test_init_db(): + with patch( + "aki_prj23_transparenzregister.utils.postgres.connector.get_engine" + ) as mock_get_engine, patch( + "aki_prj23_transparenzregister.utils.postgres.connector.declarative_base" + ) as mock_declarative_base: + mock_get_engine.connect.return_value = {} + + mock_value = Mock() + mock_value.metadata.create_all.return_value = None + mock_declarative_base.return_value = mock_value + + init_db() + assert True