109 lines
3.4 KiB
Python
109 lines
3.4 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import pathlib
|
|
import sys
|
|
import types
|
|
import unittest
|
|
|
|
|
|
# Resolve this file's absolute path and take the directory three levels up
# from the file itself (parents[0] is the containing directory).
ROOT_DIR = pathlib.Path(__file__).resolve().parents[2]
PAGE_IMPORTER_DIR = ROOT_DIR / "Page Importer"

# Put the "Page Importer" directory at the front of sys.path, but only once,
# so repeated imports of this test module do not stack duplicate entries.
if str(PAGE_IMPORTER_DIR) not in sys.path:
    sys.path.insert(0, str(PAGE_IMPORTER_DIR))

# Placeholder binding; rebound to the loaded app module after
# load_app_module() has been defined.
APP_MODULE = None
|
|
|
|
|
|
def load_app_module() -> types.ModuleType:
    """Execute ``app.py`` from ``PAGE_IMPORTER_DIR`` against stub dependencies.

    While the file runs, ``sys.modules`` holds lightweight stand-ins for
    ``streamlit`` and for ``page_importer.dates``, ``page_importer.models``,
    ``page_importer.scraper`` and ``page_importer.wxr``, so imports of those
    names resolve to the stubs. The module is registered temporarily as
    ``page_importer_app_test``. Every ``sys.modules`` entry touched here is
    put back to its previous state before the function exits, whether
    loading succeeded or not.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for
            ``app.py``.
        Exception: Anything raised while executing ``app.py`` propagates
            unchanged (after ``sys.modules`` has been restored).
    """
    touched_names = (
        "streamlit",
        "page_importer.dates",
        "page_importer.models",
        "page_importer.scraper",
        "page_importer.wxr",
        "page_importer_app_test",
    )

    # A private sentinel marks "was not in sys.modules". None cannot be used
    # for that: sys.modules may legitimately map a name to None (the marker
    # that blocks an import), and such an entry must be restored, not dropped.
    absent = object()
    saved_modules = {name: sys.modules.get(name, absent) for name in touched_names}

    try:
        # Bare module: provides the name only, no attributes.
        sys.modules["streamlit"] = types.ModuleType("streamlit")

        # parse_datetime stub accepts one value and always returns None.
        dates_module = types.ModuleType("page_importer.dates")
        dates_module.parse_datetime = lambda value: None
        sys.modules["page_importer.dates"] = dates_module

        models_module = types.ModuleType("page_importer.models")

        class ScrapeOptions:
            pass

        class ScrapedPost:
            pass

        models_module.ScrapeOptions = ScrapeOptions
        models_module.ScrapedPost = ScrapedPost
        sys.modules["page_importer.models"] = models_module

        # Scraper is aliased to the builtin ``object``.
        scraper_module = types.ModuleType("page_importer.scraper")
        scraper_module.Scraper = object
        sys.modules["page_importer.scraper"] = scraper_module

        # build_wxr stub accepts the posts and always returns an empty string.
        wxr_module = types.ModuleType("page_importer.wxr")
        wxr_module.build_wxr = lambda posts: ""
        sys.modules["page_importer.wxr"] = wxr_module

        app_path = PAGE_IMPORTER_DIR / "app.py"
        spec = importlib.util.spec_from_file_location("page_importer_app_test", app_path)
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot create an import spec for {app_path}")

        module = importlib.util.module_from_spec(spec)
        # Register the module before executing it, as the import system does.
        sys.modules["page_importer_app_test"] = module
        spec.loader.exec_module(module)
        return module
    finally:
        # Undo every sys.modules change made above, on success and on failure.
        for name, saved in saved_modules.items():
            if saved is absent:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = saved
|
|
|
|
|
|
# Load the app once at import time; the test cases in this file call
# functions on this shared module object.
APP_MODULE = load_app_module()
|
|
|
|
|
|
class UploadStateTests(unittest.TestCase):
    """Tests for ``APP_MODULE.sync_uploaded_file_state`` on plain dict state."""

    def test_sync_uploaded_file_state_clears_stale_results_for_new_file(self) -> None:
        # State left over from a previous upload whose fingerprint was "old".
        session_state = {
            "uploaded_csv_fingerprint": "old",
            "results": ["stale"],
            "input_rows": [{"url": "https://example.com"}],
            "input_headers": ["url"],
            "scrape_context": {"url_column": "url"},
        }

        APP_MODULE.sync_uploaded_file_state(session_state, "new")

        # The new fingerprint is stored and every per-upload key is gone.
        self.assertEqual(session_state["uploaded_csv_fingerprint"], "new")
        self.assertNotIn("results", session_state)
        self.assertNotIn("input_rows", session_state)
        self.assertNotIn("input_headers", session_state)
        self.assertNotIn("scrape_context", session_state)

    def test_sync_uploaded_file_state_keeps_results_for_same_file(self) -> None:
        # The stored fingerprint already matches the one passed in below.
        session_state = {
            "uploaded_csv_fingerprint": "same",
            "results": ["keep"],
            "input_rows": [{"url": "https://example.com"}],
        }

        APP_MODULE.sync_uploaded_file_state(session_state, "same")

        # Existing results and rows must be left untouched.
        self.assertEqual(session_state["results"], ["keep"])
        self.assertEqual(session_state["input_rows"], [{"url": "https://example.com"}])
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()
|