From 518e6ca05dbbab70d897d1805074c66ad6d5fa0c Mon Sep 17 00:00:00 2001
From: thopri <thopri@noc.ac.uk>
Date: Wed, 8 Apr 2020 10:46:58 +0100
Subject: [PATCH] multiple HC added to FES

---
 pynemo/tide/fes_extract_HC.py | 66 +++++++++++++++++++----------------
 1 file changed, 35 insertions(+), 31 deletions(-)

diff --git a/pynemo/tide/fes_extract_HC.py b/pynemo/tide/fes_extract_HC.py
index d8cc3fb..6aaebf6 100644
--- a/pynemo/tide/fes_extract_HC.py
+++ b/pynemo/tide/fes_extract_HC.py
@@ -43,6 +43,7 @@ class HcExtract(object):
             constituents = ['M2','S2']
 
             self.cons = constituents
+            self.mask_dataset = {}
 
             # extract lon and lat z data
             lon_z = np.array(Dataset(settings['tide_fes']+constituents[0]+'_Z.nc').variables['lon'])
@@ -50,12 +51,6 @@ class HcExtract(object):
             lon_resolution = lon_z[1] - lon_z[0]
             data_in_km = 0 # added to maintain the reference to matlab tmd code
 
-            # create empty dictionaries to store harmonics in.
-            self.height_dataset = {}
-            self.Uvelocity_dataset = {}
-            self.Vvelocity_dataset = {}
-            self.mask_dataset = {}
-
             # extract example amplitude grid for Z, U and V and change NaNs to 0 (for land) and other values to 1 (for water)
             mask_z = np.array(np.rot90(Dataset(settings['tide_fes']+constituents[0]+'_Z.nc').variables['amplitude'][:]))
             where_are_NaNs = np.isnan(mask_z)
@@ -78,35 +73,47 @@ class HcExtract(object):
             mask_v[where_are_NaNs] = 0
             self.mask_dataset[mv_name] = mask_v
 
+
             #read and convert the height_dataset file to complex and store in dicts
+            hRe = []
+            hIm = []
+            lat_z = np.array(Dataset(settings['tide_fes'] + constituents[0] + '_Z.nc').variables['lat'][:])
+            lon_z = np.array(Dataset(settings['tide_fes'] + constituents[0] + '_Z.nc').variables['lon'][:])
             for ncon in range(len(constituents)):
                 amp = np.array(np.rot90(Dataset(settings['tide_fes']+str(constituents[ncon])+'_Z.nc').variables['amplitude'][:]))
                 phase = np.array(np.rot90(Dataset(settings['tide_fes']+constituents[ncon]+'_Z.nc').variables['phase'][:]))
-                lat_z = np.array(Dataset(settings['tide_fes']+constituents[ncon]+'_Z.nc').variables['lat'][:])
-                lon_z = np.array(Dataset(settings['tide_fes']+constituents[ncon]+'_Z.nc').variables['lon'][:])
-                hRe = amp*np.sin(phase)
-                hIm = amp*np.cos(phase)
-                self.height_dataset[constituents[ncon]] = {'lat_z':lat_z,'lon_z':lon_z,'hRe':hRe,'hIm':hIm}
+                hRe.append(amp*np.sin(phase))
+                hIm.append(amp*np.cos(phase))
+            hRe = np.stack(hRe)
+            hIm = np.stack(hIm)
+            self.height_dataset = [lon_z,lat_z,hRe,hIm]
 
             #read and convert the velocity_dataset files to complex
+            URe = []
+            UIm = []
+            lat_u = np.array(Dataset(settings['tide_fes'] + constituents[0] + '_U.nc').variables['lat'][:])
+            lon_u = np.array(Dataset(settings['tide_fes'] + constituents[0] + '_U.nc').variables['lon'][:])
             for ncon in range(len(constituents)):
                 amp = np.array(np.rot90(Dataset(settings['tide_fes']+constituents[ncon]+'_U.nc').variables['Ua'][:]))
                 phase = np.array(np.rot90(Dataset(settings['tide_fes']+constituents[ncon]+'_U.nc').variables['Ug'][:]))
-                lat_u = np.array(Dataset(settings['tide_fes']+constituents[ncon]+'_U.nc').variables['lat'][:])
-                lon_u = np.array(Dataset(settings['tide_fes']+constituents[ncon]+'_U.nc').variables['lon'][:])
-                URe = amp*np.sin(phase)
-                UIm = amp*np.cos(phase)
-                self.Uvelocity_dataset[constituents[ncon]] = {'lat_u':lat_u,'lon_u':lon_u,'URe':URe,'UIm':UIm}
-
+                URe.append(amp*np.sin(phase))
+                UIm.append(amp*np.cos(phase))
+            URe = np.stack(URe)
+            UIm = np.stack(UIm)
+            self.Uvelocity_dataset = [lon_u,lat_u,URe,UIm]
+
+            VRe = []
+            VIm = []
+            lat_v = np.array(Dataset(settings['tide_fes'] + constituents[ncon] + '_V.nc').variables['lat'][:])
+            lon_v = np.array(Dataset(settings['tide_fes'] + constituents[ncon] + '_V.nc').variables['lon'][:])
             for ncon in range(len(constituents)):
                 amp = np.array(np.rot90(Dataset(settings['tide_fes']+constituents[ncon]+'_V.nc').variables['Va'][:]))
                 phase = np.array(np.rot90(Dataset(settings['tide_fes']+constituents[ncon]+'_V.nc').variables['Vg'][:]))
-                lat_v = np.array(Dataset(settings['tide_fes']+constituents[ncon]+'_V.nc').variables['lat'][:])
-                lon_v = np.array(Dataset(settings['tide_fes']+constituents[ncon]+'_V.nc').variables['lon'][:])
-                VRe = amp*np.sin(phase)
-                VIm = amp*np.cos(phase)
-                self.Vvelocity_dataset[constituents[ncon]] = {'lat_v':lat_v,'lon_v':lon_v,'VRe':VRe,'VIm':VIm}
-
+                VRe.append(amp*np.sin(phase))
+                VIm.append(amp*np.cos(phase))
+            VRe = np.stack(VRe)
+            VIm = np.stack(VIm)
+            self.Vvelocity_dataset = [lon_v,lat_v,VRe,VIm]
 
             # open grid variables these are resampled TPXO parameters so may not work correctly.
             self.grid = Dataset(settings['tide_fes']+'grid_fes.nc')
@@ -188,18 +195,15 @@ class HcExtract(object):
         amp = np.zeros((len(nc_dataset), lon.shape[0]))
         gph = np.zeros((len(nc_dataset), lon.shape[0]))
 
-        # TODO: need to sort multiple HC, at the momemnt it uses M2 for every harmonic
-
-        data = np.array((nc_dataset['M2'][real_var_name]), dtype=complex)
-        data.imag = np.array((nc_dataset['M2'][img_var_name]))
-        # add extra dim to be compatable with adapted code that expects a list of HC
-        data = np.expand_dims(data,axis=0)
+        data = np.array(np.ravel(nc_dataset[2]), dtype=complex)
+        data.imag = np.array(np.ravel(nc_dataset[3]))
+        data = data.reshape(nc_dataset[2].shape)
         #data = data.reshape(1,nc_dataset['M2'][real_var_name].shape)
         # data[data==0] = np.NaN
 
         # Lat Lon values
-        x_values = nc_dataset['M2'][lon_var_name]
-        y_values = nc_dataset['M2'][lat_var_name]
+        x_values = nc_dataset[0]
+        y_values = nc_dataset[1]
         x_resolution = x_values[1] - x_values[0]
         glob = 0
         if x_values[-1]-x_values[0] == 360-x_resolution:
-- 
GitLab