92 lines
3.7 KiB
Python
92 lines
3.7 KiB
Python
from constants import CHUNK_SIZE, DATASET_DIR, TRAINING_DIR, VALIDATION_DIR, TESTING_DIR
|
|
import pyarrow.parquet as pq
|
|
import pyarrow as pa
|
|
import os
|
|
from helper import get_time_boundaries
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
def split_dataset_random(filename: str) -> None:
    """Randomly split a Parquet dataset into training, validation and testing files.

    Streams ``DATASET_DIR/filename`` in record batches of ``CHUNK_SIZE`` rows
    and assigns every row independently by drawing ``u`` uniformly from
    [0, 1):

    * ``u < 0.75``          -> training   (~75 %)
    * ``0.75 <= u < 0.85``  -> validation (~10 %)
    * ``u >= 0.85``         -> testing    (~15 %)

    Each split is written as ``filename`` inside ``TRAINING_DIR``,
    ``VALIDATION_DIR`` and ``TESTING_DIR`` respectively. All three output
    files are always created, even when a split receives no rows.

    Draws come from NumPy's global random state (``np.random.rand``), so the
    split is reproducible only if the caller seeds it beforehand.

    Args:
        filename: Name of the Parquet file inside ``DATASET_DIR``; also used
            as the output file name in each split directory.
    """
    # Local import keeps the function self-contained.
    from contextlib import ExitStack

    pq_file = pq.ParquetFile(os.path.join(DATASET_DIR, filename))
    # Batches yielded by iter_batches carry the file's Arrow schema, so the
    # writers can be opened up front instead of lazily on the first non-empty
    # slice. This way no writer is ever left as None, and close() cannot fail
    # for a split that got no rows.
    schema = pq_file.schema_arrow

    with ExitStack() as stack:
        # ExitStack closes every writer opened so far, even if a later
        # constructor or a write raises.
        writers = [
            stack.enter_context(
                pq.ParquetWriter(os.path.join(directory, filename), schema)
            )
            for directory in (TRAINING_DIR, VALIDATION_DIR, TESTING_DIR)
        ]

        for batch in pq_file.iter_batches(batch_size=CHUNK_SIZE):
            table = pa.Table.from_batches([batch])

            # One uniform draw per row; the masks are disjoint and cover [0, 1).
            draws = np.random.rand(table.num_rows)
            masks = (
                draws < 0.75,
                (draws >= 0.75) & (draws < 0.85),
                draws >= 0.85,
            )

            for writer, mask in zip(writers, masks):
                part = table.filter(mask)
                # Skip empty slices so no zero-row row groups are written.
                if part.num_rows:
                    writer.write_table(part)
|
|
|
|
def split_dataset(filename: str) -> None:
    """Split a Parquet dataset chronologically into training, validation and testing files.

    Loads ``DATASET_DIR/filename`` into memory, parses ``scraped_at`` as UTC
    timestamps and orders rows by it. It then writes the oldest 80 % of rows
    to ``TRAINING_DIR``, the next 10 % to ``VALIDATION_DIR`` and the newest
    10 % to ``TESTING_DIR``, each under the same ``filename``.

    Boundaries are taken by row position, so rows sharing a timestamp may
    straddle two adjacent splits.

    Args:
        filename: Name of the Parquet file inside ``DATASET_DIR``; also used
            as the output file name in each split directory.
    """
    # NOTE(review): the whole file is loaded at once; very large inputs would
    # need a streaming, two-pass variant (find cut timestamps, then split).
    df = pd.read_parquet(os.path.join(DATASET_DIR, filename))

    # Unparseable timestamps become NaT (errors='coerce'). sort_values puts NaT
    # last by default, so those rows land in the testing split.
    df['scraped_at'] = pd.to_datetime(df['scraped_at'], format='ISO8601', errors='coerce', utc=True)

    # A stable sort keeps rows with equal timestamps in their original file
    # order, so the split boundaries are deterministic.
    df = df.sort_values(by='scraped_at', kind='stable')

    n = len(df)
    train_end = int(n * 0.8)
    val_end = int(n * 0.9)

    # NOTE(review): after sorting, the index is no longer a RangeIndex, so
    # pandas likely stores it as an extra column in each output file. Confirm
    # whether downstream wants index=False.
    df.iloc[:train_end].to_parquet(os.path.join(TRAINING_DIR, filename))
    df.iloc[train_end:val_end].to_parquet(os.path.join(VALIDATION_DIR, filename))
    df.iloc[val_end:].to_parquet(os.path.join(TESTING_DIR, filename))
|